
ConvNeXt Models for Image Classification

This page documents the ConvNeXt image classification model pool in FusionBench. It wraps Hugging Face Transformers ConvNeXt models with convenient loading, processor management, dataset-aware head adaptation, and save utilities.

Implementation: ConvNextForImageClassificationPool


Quick start

Minimal Python usage with a single pretrained ConvNeXt model (e.g., base-224):

from fusion_bench.modelpool import ConvNextForImageClassificationPool

pool = ConvNextForImageClassificationPool(
    models={
        "_pretrained_": {
            "config_path": "facebook/convnext-base-224",
            "pretrained": True,
            # set to a known dataset key (e.g., "cifar10") to resize classifier
            # and populate id2label/label2id mappings
            "dataset_name": None,
        }
    }
)

model = pool.load_model("_pretrained_")
processor = pool.load_processor()  # AutoImageProcessor
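
The processor handles image resizing and normalization, and the model returns logits over its label set. A minimal inference sketch follows (the image path is only a placeholder, and this uses the standard Transformers forward pass rather than any FusionBench-specific API):

import torch
from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# Map the top prediction back to a human-readable label
predicted_class = logits.argmax(-1).item()
print(model.config.id2label[predicted_class])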

Low-level construction is also available via helper functions.
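
Those helpers are not listed here; as a rough illustration of what low-level construction amounts to (this bypasses the pool and uses the Transformers API directly, so it omits the pool's dataset-aware head adaptation):

from transformers import AutoImageProcessor, ConvNextForImageClassification

# Build the same pretrained model and processor without going through the pool
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")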

Ready-to-use config

Use the provided Hydra config to set up a pretrained ConvNeXt-base model:

config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml
_target_: fusion_bench.modelpool.ConvNextForImageClassificationPool
_recursive_: False
models:
  _pretrained_:
    config_path: facebook/convnext-base-224
    pretrained: true
    dataset_name: null
train_datasets: null
val_datasets: null
test_datasets: null
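
A hedged sketch of loading this config programmatically, following the same pattern as the full merging example further below (it assumes the default fabric_model_fusion entry point composes without additional overrides):

from fusion_bench.utils import initialize_hydra_config, instantiate

config = initialize_hydra_config(
    config_name="fabric_model_fusion",
    overrides=["modelpool=ConvNextForImageClassification/convnext-base-224"],
)
pool = instantiate(config.modelpool)
model = pool.load_model("_pretrained_")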

Tip: set dataset_name to a supported dataset key (e.g., cifar10, svhn, gtsrb, …) to automatically resize the classifier and populate the id2label/label2id mappings.
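
For intuition, the following is a rough sketch of what dataset-aware head adaptation amounts to, written against plain Transformers/PyTorch rather than the pool's internal code, and assuming CIFAR-10's 10 classes:

import torch.nn as nn
from transformers import ConvNextForImageClassification

model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")

# Illustrative class names; the pool derives them from the dataset key
class_names = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

# Replace the ImageNet classifier with a freshly initialized 10-way head
model.classifier = nn.Linear(model.classifier.in_features, len(class_names))

# Keep the config consistent with the new head
model.config.num_labels = len(class_names)
model.config.id2label = {i: name for i, name in enumerate(class_names)}
model.config.label2id = {name: i for i, name in enumerate(class_names)}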

The 8-Task Benchmark

This pool includes a preset configuration for an 8-task image classification benchmark. The tasks are: SUN397, Stanford Cars, RESISC45, EuroSAT, SVHN, GTSRB, MNIST, and DTD.

config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml
_target_: fusion_bench.modelpool.ConvNextForImageClassificationPool
_recursive_: False
models:
  _pretrained_: facebook/convnext-base-224
  sun397: tanganke/convnext-base-224_sun397_sgd_batch-size-64_lr-0.01_steps-4000
  stanford-cars: tanganke/convnext-base-224_stanford-cars_sgd_batch-size-64_lr-0.01_steps-4000
  resisc45: tanganke/convnext-base-224_resisc45_sgd_batch-size-64_lr-0.01_steps-4000
  eurosat: tanganke/convnext-base-224_eurosat_sgd_batch-size-64_lr-0.01_steps-4000
  svhn: tanganke/convnext-base-224_svhn_sgd_batch-size-64_lr-0.01_steps-4000
  gtsrb: tanganke/convnext-base-224_gtsrb_sgd_batch-size-64_lr-0.01_steps-4000
  mnist: tanganke/convnext-base-224_mnist_sgd_batch-size-64_lr-0.01_steps-4000
  dtd: tanganke/convnext-base-224_dtd_sgd_batch-size-64_lr-0.01_steps-4000
train_datasets: null
val_datasets: null
test_datasets: null

These models are fine-tuned from facebook/convnext-base-224.

When merging these models (e.g., with simple_average), we typically want to fuse the shared backbone while keeping the task-specific classification heads separate, as the example below does with a SwitchModule. The ConvNeXt implementation sets:

model._fusion_bench_target_modules = ["convnext"]

This attribute tells partial fusion algorithms to target the convnext backbone module, so that standard parameter merging focuses on the shared feature extractor.
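
Before the full example below, here is a hedged sketch (not FusionBench's merging API) of what averaging only the targeted backbone looks like with plain state-dict arithmetic, assuming all models share the facebook/convnext-base-224 architecture:

import torch

def average_backbone(models: dict, prefix: str = "convnext."):
    """Average only the parameters under the `convnext` backbone across models."""
    state_dicts = [m.state_dict() for m in models.values()]
    merged = {}
    for key in state_dicts[0]:
        if key.startswith(prefix):
            merged[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(0)
    return merged

# Usage sketch: load the averaged backbone into one of the models,
# leaving its classifier untouched.
# backbone_avg = average_backbone(models)
# model.load_state_dict(backbone_avg, strict=False)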

examples/simple_average/convnext-base-224.py
"""
Example of merging ConvNeXt models using simple averaging.
"""

import lightning as L

from fusion_bench.method import SimpleAverageAlgorithm
from fusion_bench.modelpool import ConvNextForImageClassificationPool, BaseModelPool
from fusion_bench.models.wrappers.switch import SwitchModule, set_active_option
from fusion_bench.taskpool.image_classification import ImageClassificationTaskPool
from fusion_bench.utils import initialize_hydra_config, instantiate

fabric = L.Fabric(accelerator="auto", devices=1)
fabric.launch()

config = initialize_hydra_config(
    config_name="fabric_model_fusion",
    overrides=[
        "method=simple_average",
        "modelpool=ConvNextForImageClassification/convnext-base-224_8-tasks",
        "taskpool=ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml",
    ],
)

algorithm: SimpleAverageAlgorithm = instantiate(config.method)
modelpool: ConvNextForImageClassificationPool = instantiate(config.modelpool)
taskpool: ImageClassificationTaskPool = instantiate(config.taskpool)
taskpool.fabric = fabric

models = {
    model_name: modelpool.load_model(model_name) for model_name in modelpool.model_names
}

# Wrap classification heads in a SwitchModule
heads = {model_name: m.classifier for model_name, m in models.items()}
head = SwitchModule(heads)

merged_model = algorithm.run(modelpool=BaseModelPool(models))
merged_model.classifier = head
report = taskpool.evaluate(merged_model)
print(report)

Implementation Details