ConvNeXt Models for Image Classification¶
This page documents the ConvNeXt image classification model pool in FusionBench. It wraps Hugging Face Transformers ConvNeXt models with convenient loading, processor management, dataset-aware head adaptation, and save utilities.
Implementation: ConvNextForImageClassificationPool

Quick start¶
Minimal Python usage with a single pretrained ConvNeXt model (e.g., base-224):
```python
from fusion_bench.modelpool import ConvNextForImageClassificationPool

pool = ConvNextForImageClassificationPool(
    models={
        "_pretrained_": {
            "config_path": "facebook/convnext-base-224",
            "pretrained": True,
            # set to a known dataset key (e.g., "cifar10") to resize the classifier
            # and populate the id2label/label2id mappings
            "dataset_name": None,
        }
    }
)

model = pool.load_model("_pretrained_")
processor = pool.load_processor()  # AutoImageProcessor
```
Lower-level construction is also available through helper functions; see the implementation linked above for details.
Ready-to-use config¶
Use the provided Hydra config to set up a pretrained ConvNeXt-base model:
```yaml
_target_: fusion_bench.modelpool.ConvNextForImageClassificationPool
_recursive_: False
models:
  _pretrained_:
    config_path: facebook/convnext-base-224
    pretrained: true
    dataset_name: null
train_datasets: null
val_datasets: null
test_datasets: null
```
Tip: set dataset_name to a supported dataset key (e.g., cifar10, svhn, gtsrb, …) to auto-resize the classifier and label mappings.
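For example, a variant of the config above adapted to CIFAR-10 (the `cifar10` key is one of the supported dataset names mentioned in the tip; the exact set of supported keys is defined by the implementation) would look like:

```yaml
_target_: fusion_bench.modelpool.ConvNextForImageClassificationPool
_recursive_: False
models:
  _pretrained_:
    config_path: facebook/convnext-base-224
    pretrained: true
    # resizes the classifier to 10 classes and fills id2label/label2id
    dataset_name: cifar10
```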
The 8-Task Benchmark¶
This pool includes a preset configuration for an 8-task image classification benchmark. The tasks are: SUN397, Stanford Cars, RESISC45, EuroSAT, SVHN, GTSRB, MNIST, and DTD.
```yaml
_target_: fusion_bench.modelpool.ConvNextForImageClassificationPool
_recursive_: False
models:
  _pretrained_: facebook/convnext-base-224
  sun397: tanganke/convnext-base-224_sun397_sgd_batch-size-64_lr-0.01_steps-4000
  stanford-cars: tanganke/convnext-base-224_stanford-cars_sgd_batch-size-64_lr-0.01_steps-4000
  resisc45: tanganke/convnext-base-224_resisc45_sgd_batch-size-64_lr-0.01_steps-4000
  eurosat: tanganke/convnext-base-224_eurosat_sgd_batch-size-64_lr-0.01_steps-4000
  svhn: tanganke/convnext-base-224_svhn_sgd_batch-size-64_lr-0.01_steps-4000
  gtsrb: tanganke/convnext-base-224_gtsrb_sgd_batch-size-64_lr-0.01_steps-4000
  mnist: tanganke/convnext-base-224_mnist_sgd_batch-size-64_lr-0.01_steps-4000
  dtd: tanganke/convnext-base-224_dtd_sgd_batch-size-64_lr-0.01_steps-4000
train_datasets: null
val_datasets: null
test_datasets: null
```
These models are fine-tuned from facebook/convnext-base-224.
When merging these models (e.g., with simple_average), we typically want to fuse the shared backbone while keeping the task-specific classification heads separate or handling them specially. To support this, the ConvNeXt implementation marks the convnext backbone module as the fusion target, which informs partial fusion algorithms to merge only the shared feature extractor rather than the per-task heads.
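To make the merging step concrete: simple averaging computes the element-wise mean of corresponding parameters across models. A minimal, self-contained sketch of the idea (plain Python lists stand in for real weight tensors, and `average_state_dicts` is an illustrative helper, not part of the FusionBench API):

```python
def average_state_dicts(state_dicts):
    """Element-wise mean of corresponding parameters across models."""
    merged = {}
    for key in state_dicts[0]:
        params = [sd[key] for sd in state_dicts]
        # average position-by-position across all models
        merged[key] = [sum(vals) / len(vals) for vals in zip(*params)]
    return merged

# Two toy "models" sharing the same parameter names
sd_a = {"backbone.weight": [1.0, 2.0], "backbone.bias": [0.0, 0.0]}
sd_b = {"backbone.weight": [3.0, 4.0], "backbone.bias": [2.0, 2.0]}
merged = average_state_dicts([sd_a, sd_b])
# merged["backbone.weight"] == [2.0, 3.0]
```

In the real pipeline, `SimpleAverageAlgorithm` performs this averaging over the tensors of the shared backbone, while the heads are handled separately as shown in the full example.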
"""
Example of merging ConvNeXt models using simple averaging.
"""
import lightning as L
from fusion_bench.method import SimpleAverageAlgorithm
from fusion_bench.modelpool import ConvNextForImageClassificationPool, BaseModelPool
from fusion_bench.models.wrappers.switch import SwitchModule, set_active_option
from fusion_bench.taskpool.image_classification import ImageClassificationTaskPool
from fusion_bench.utils import initialize_hydra_config, instantiate
fabric = L.Fabric(accelerator="auto", devices=1)
fabric.launch()
config = initialize_hydra_config(
config_name="fabric_model_fusion",
overrides=[
"method=simple_average",
"modelpool=ConvNextForImageClassification/convnext-base-224_8-tasks",
"taskpool=ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml",
],
)
algorithm: SimpleAverageAlgorithm = instantiate(config.method)
modelpool: ConvNextForImageClassificationPool = instantiate(config.modelpool)
taskpool: ImageClassificationTaskPool = instantiate(config.taskpool)
taskpool.fabric = fabric
models = {
model_name: modelpool.load_model(model_name) for model_name in modelpool.model_names
}
# Wrap classification heads in a SwitchModule
heads = {model_name: m.classifier for model_name, m in models.items()}
head = SwitchModule(heads)
merged_model = algorithm.run(modelpool=BaseModelPool(models))
merged_model.classifier = head
report = taskpool.evaluate(merged_model)
print(report)
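The SwitchModule pattern used above keeps one classification head per task and routes the forward pass through whichever head is currently active. A minimal stand-in illustrating the idea (plain callables instead of `nn.Module`; `Switch` and `set_active` are illustrative names, not the FusionBench API):

```python
class Switch:
    """Holds several named options; calls are routed to the active one."""

    def __init__(self, options):
        self.options = dict(options)
        self.active = next(iter(self.options))  # default: first option

    def set_active(self, name):
        if name not in self.options:
            raise KeyError(f"unknown option: {name}")
        self.active = name

    def __call__(self, x):
        return self.options[self.active](x)


# Two toy "heads" operating on a shared feature x
head = Switch({
    "mnist": lambda x: x + 1,  # stand-in for a 10-class head
    "gtsrb": lambda x: x * 2,  # stand-in for a 43-class head
})
head.set_active("gtsrb")
# head(3) == 6; after head.set_active("mnist"), head(3) == 4
```

During evaluation, the task pool can switch the active head per task while the merged backbone stays fixed, which is how a single averaged feature extractor serves all eight benchmarks.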