Skip to content

ResNet Models for Image Classification

This page documents the ResNet model pool used in FusionBench for supervised image classification. The pool supports both torchvision and Hugging Face Transformers implementations and provides a unified interface for loading models, processors, and saving artifacts.

Overview

The ResNetForImageClassificationPool offers:

  • Two backends
    • torchvision: classic ResNet backbones (resnet18/34/50/101/152) with standard ImageNet preprocessing
    • transformers: ResNetForImageClassification models with AutoImageProcessor
  • Dataset-aware adaptation
    • When a dataset_name is provided, the pool resizes the classifier head and sets id2label/label2id mappings via the dataset class names
  • Processor management
    • Torchvision: returns stage-specific transforms (train/val/test)
    • Transformers: loads a compatible AutoImageProcessor
  • Clean logits API
    • For Transformers models, forward is wrapped to return logits only for consistent evaluation interfaces

Quick start (Transformers backend)

Minimal Python usage with a single pretrained model adapted to a dataset (e.g., CIFAR-10):

from fusion_bench.modelpool import ResNetForImageClassificationPool

pool = ResNetForImageClassificationPool(
        type="transformers",
        models={
                "_pretrained_": {
                        "config_path": "microsoft/resnet-50",
                        "pretrained": True,
                        "dataset_name": "cifar10",
                }
        },
)

model = pool.load_model("_pretrained_")
processor = pool.load_processor()  # AutoImageProcessor

Torchvision backend

When using torchvision models, the pool constructs appropriate transforms per stage and can optionally resize the classifier to match your dataset:

from fusion_bench.modelpool import ResNetForImageClassificationPool

pool = ResNetForImageClassificationPool(
        type="torchvision",
        models={
                "resnet18_cifar10": {
                        "model_name": "resnet18",
                        "weights": "DEFAULT",  # or None
                        "dataset_name": "cifar10",  # classifier resized to 10 classes
                }
        },
)

train_tf = pool.load_processor(stage="train")
val_tf = pool.load_processor(stage="test")
model = pool.load_model("resnet18_cifar10")

Low-level helpers also exist if you want to create models directly:

Ready-to-use configs

You can use the following model pool configs out-of-the-box. They set up a single pretrained model adapted to a specific dataset. Include them with the modelpool= flag when running fusion_bench.

config/modelpool/ResNetForImageClassification/transformers/resnet18_cifar10.yaml
defaults:
  - /dataset/image_classification/train@train_datasets:
      - cifar10
  - /dataset/image_classification/test@val_datasets:
      - cifar10
  - _self_
_target_: fusion_bench.modelpool.ResNetForImageClassificationPool
_recursive_: False
type: transformers
models:
  _pretrained_:
    config_path: microsoft/resnet-18
    pretrained: true
    dataset_name: cifar10
config/modelpool/ResNetForImageClassification/transformers/resnet50_cifar10.yaml
defaults:
  - /dataset/image_classification/train@train_datasets:
      - cifar10
  - /dataset/image_classification/test@val_datasets:
      - cifar10
  - _self_
_target_: fusion_bench.modelpool.ResNetForImageClassificationPool
_recursive_: False
type: transformers
models:
  _pretrained_:
    config_path: microsoft/resnet-50
    pretrained: true
    dataset_name: cifar10

These configs follow the same structure used across other model pools in this directory (see CLIP-ViT or GPT-2 pages for reference) and are suitable starting points for evaluation and merging workflows.

Saving models

  • Torchvision: state_dict is saved via torch.save to the given file path
  • Transformers: save_pretrained() is used for both model and processor; a README model card is written on rank-zero when algorithm_config is supplied

See fusion_bench.modelpool.resnet_for_image_classification.ResNetForImageClassificationPool.save_model.

Implementation Details