ResNet Models for Image Classification¶
This page documents the ResNet model pool used in FusionBench for supervised image classification. The pool supports both torchvision and Hugging Face Transformers implementations and provides a unified interface for loading models and processors and for saving artifacts.
Overview¶
The ResNetForImageClassificationPool offers:
- Two backends
  - torchvision: classic ResNet backbones (resnet18/34/50/101/152) with standard ImageNet preprocessing
  - transformers: ResNetForImageClassification models with AutoImageProcessor
- Dataset-aware adaptation
  - When a dataset_name is provided, the pool resizes the classifier head and sets id2label/label2id mappings via the dataset class names
- Processor management
  - Torchvision: returns stage-specific transforms (train/val/test)
  - Transformers: loads a compatible AutoImageProcessor
- Clean logits API
  - For Transformers models, forward is wrapped to return logits only for consistent evaluation interfaces
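The dataset-aware adaptation listed above comes down to deriving two inverse mappings from the dataset's class names and sizing the classifier head to match. A minimal sketch in plain Python, assuming the CIFAR-10 class names are available as an ordered list; the helper name build_label_maps is hypothetical and not part of FusionBench:

```python
def build_label_maps(class_names):
    """Return (id2label, label2id) for an ordered list of class names."""
    id2label = {idx: name for idx, name in enumerate(class_names)}
    label2id = {name: idx for idx, name in id2label.items()}
    return id2label, label2id


# CIFAR-10 class names in their conventional index order.
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

id2label, label2id = build_label_maps(cifar10_classes)

# Number of output units a resized classifier head would need.
num_labels = len(id2label)
```

The two dictionaries are exact inverses of each other, which is what lets a predicted class index be translated back to a readable label during evaluation.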
Quick start (Transformers backend)¶
Minimal Python usage with a single pretrained model adapted to a dataset (e.g., CIFAR-10):
from fusion_bench.modelpool import ResNetForImageClassificationPool

pool = ResNetForImageClassificationPool(
    type="transformers",
    models={
        "_pretrained_": {
            "config_path": "microsoft/resnet-50",
            "pretrained": True,
            "dataset_name": "cifar10",
        }
    },
)
model = pool.load_model("_pretrained_")
processor = pool.load_processor()  # AutoImageProcessor
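The pool wraps forward for Transformers models so that it returns logits only. The pattern can be sketched without any deep learning dependency. This is an illustration of the general idea, not FusionBench's implementation: DummyClassifier and wrap_forward_to_logits are hypothetical names, and the stand-in output object only mimics an output that exposes a logits attribute.

```python
import functools
from types import SimpleNamespace


class DummyClassifier:
    """Hypothetical stand-in for a model whose forward returns an output object."""

    def forward(self, pixel_values):
        # Pretend each input value yields two class scores.
        logits = [[float(x), -float(x)] for x in pixel_values]
        return SimpleNamespace(logits=logits, loss=None)


def wrap_forward_to_logits(model):
    """Replace model.forward with a version that returns only `.logits`."""
    original_forward = model.forward  # bound method, keeps a reference to the model

    @functools.wraps(original_forward)
    def logits_only_forward(*args, **kwargs):
        return original_forward(*args, **kwargs).logits

    model.forward = logits_only_forward
    return model


model = wrap_forward_to_logits(DummyClassifier())
logits = model.forward([1, 2])  # a plain nested list, not an output object
```

Returning the logits directly means evaluation code can treat both backends the same way, without checking whether the result is a tensor or a structured output.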
Torchvision backend¶
When using torchvision models, the pool constructs appropriate transforms per stage and can optionally resize the classifier to match your dataset:
from fusion_bench.modelpool import ResNetForImageClassificationPool

pool = ResNetForImageClassificationPool(
    type="torchvision",
    models={
        "resnet18_cifar10": {
            "model_name": "resnet18",
            "weights": "DEFAULT",  # or None
            "dataset_name": "cifar10",  # classifier resized to 10 classes
        }
    },
)
train_tf = pool.load_processor(stage="train")
val_tf = pool.load_processor(stage="test")
model = pool.load_model("resnet18_cifar10")
Low-level helpers also exist if you want to create models directly:
- fusion_bench.modelpool.resnet_for_image_classification.load_torchvision_resnet
- fusion_bench.modelpool.resnet_for_image_classification.load_transformers_resnet
Ready-to-use configs¶
You can use the following model pool configs out of the box. Each sets up a single pretrained model adapted to a specific dataset. Select one with the modelpool= flag when running fusion_bench.
# ResNet-18 (microsoft/resnet-18) adapted to CIFAR-10
defaults:
  - /dataset/image_classification/train@train_datasets:
      - cifar10
  - /dataset/image_classification/test@val_datasets:
      - cifar10
  - _self_
_target_: fusion_bench.modelpool.ResNetForImageClassificationPool
_recursive_: False
type: transformers
models:
  _pretrained_:
    config_path: microsoft/resnet-18
    pretrained: true
    dataset_name: cifar10

# ResNet-50 (microsoft/resnet-50) adapted to CIFAR-10
defaults:
  - /dataset/image_classification/train@train_datasets:
      - cifar10
  - /dataset/image_classification/test@val_datasets:
      - cifar10
  - _self_
_target_: fusion_bench.modelpool.ResNetForImageClassificationPool
_recursive_: False
type: transformers
models:
  _pretrained_:
    config_path: microsoft/resnet-50
    pretrained: true
    dataset_name: cifar10
These configs follow the same structure used across other model pools in this directory (see CLIP-ViT or GPT-2 pages for reference) and are suitable starting points for evaluation and merging workflows.
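In these configs, _target_ is a dotted import path that Hydra resolves to a class and calls with the remaining keys as arguments, while _recursive_: False tells Hydra not to instantiate nested entries itself. A simplified sketch of the resolution step follows. The helper resolve_target is hypothetical and far less thorough than Hydra's own logic, and it is demonstrated on a standard-library class so that it runs without FusionBench installed.

```python
import importlib


def resolve_target(dotted_path):
    """Import and return the object named by a dotted path such as a `_target_` value."""
    module_name, _, attr_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.attr', got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Stand-in target: a standard-library class instead of the FusionBench pool.
ordered_dict_cls = resolve_target("collections.OrderedDict")

# The remaining config keys would be passed as keyword arguments.
instance = ordered_dict_cls(type="transformers")
```

With the real config, the same step would resolve fusion_bench.modelpool.ResNetForImageClassificationPool and pass type and models to its constructor.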
Saving models¶
- Torchvision: state_dict is saved via torch.save to the given file path
- Transformers: save_pretrained() is used for both model and processor; a README model card is written on rank-zero when algorithm_config is supplied
See fusion_bench.modelpool.resnet_for_image_classification.ResNetForImageClassificationPool.save_model.
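The rank-zero condition means that in a distributed run only one process writes the README. A minimal sketch of such a guard, under stated assumptions: the rank is read from the RANK environment variable (a common launcher convention; FusionBench's actual mechanism may differ), the config is treated as a plain dict, and write_model_card and the example config value are hypothetical.

```python
import os
import tempfile
from pathlib import Path


def write_model_card(save_dir, algorithm_config=None, rank=None):
    """Write README.md only on rank zero and only when a config is supplied.

    Returns the README path if it was written, otherwise None.
    """
    if rank is None:
        # Assumption: the launcher exports RANK; default to 0 for single-process runs.
        rank = int(os.environ.get("RANK", "0"))
    if rank != 0 or algorithm_config is None:
        return None
    readme = Path(save_dir) / "README.md"
    lines = ["# Model card", "", "Algorithm config:", ""]
    lines += [f"- {key}: {value}" for key, value in algorithm_config.items()]
    readme.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return readme


with tempfile.TemporaryDirectory() as tmp:
    example_config = {"name": "simple_average"}  # illustrative value only
    written = write_model_card(tmp, example_config, rank=0)
    card_text = written.read_text(encoding="utf-8")
    skipped_rank = write_model_card(tmp, example_config, rank=1)
    skipped_config = write_model_card(tmp, None, rank=0)
```

Guarding on rank avoids several processes writing the same file at once, and guarding on the config avoids emitting a card with no merging details to report.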
Implementation Details¶
- Pool class: fusion_bench.modelpool.resnet_for_image_classification.ResNetForImageClassificationPool
- Torchvision loader: fusion_bench.modelpool.resnet_for_image_classification.load_torchvision_resnet
- Transformers loader: fusion_bench.modelpool.resnet_for_image_classification.load_transformers_resnet
- Task pool (evaluation): fusion_bench.taskpool.resnet_for_image_classification.ResNetForImageClassificationTaskPool