fusion_bench.modelpool¶
Base Class¶
BaseModelPool
¶
Bases: HydraConfigMixin, BaseYAMLSerializable
A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for specifications, a vision model pool may contain an image preprocessor, and a language model pool may contain a tokenizer.
Attributes:
- _models (DictConfig) – Configuration for all models in the pool.
- _train_datasets (Optional[DictConfig]) – Configuration for training datasets.
- _val_datasets (Optional[DictConfig]) – Configuration for validation datasets.
- _test_datasets (Optional[DictConfig]) – Configuration for testing datasets.
- _usage_ (Optional[str]) – Optional usage information.
- _version_ (Optional[str]) – Optional version information.
Source code in fusion_bench/modelpool/base_pool.py
all_model_names
property
¶
Get the names of all models in the pool, including special models.
Returns:
- List[str]: A list of all model names.
has_pretrained
property
¶
Check if the model pool contains a pretrained model.
Returns:
- bool: True if a pretrained model is available, False otherwise.
has_test_dataset
property
¶
Check if the model pool contains testing datasets.
Returns:
- bool: True if testing datasets are available, False otherwise.
has_train_dataset
property
¶
Check if the model pool contains training datasets.
Returns:
- bool: True if training datasets are available, False otherwise.
has_val_dataset
property
¶
Check if the model pool contains validation datasets.
Returns:
- bool: True if validation datasets are available, False otherwise.
model_names
property
¶
Get the names of regular models, excluding special models.
Returns:
- List[str]: A list of regular model names.
test_dataset_names
property
¶
Get the names of testing datasets.
Returns:
- List[str]: A list of testing dataset names.
train_dataset_names
property
¶
Get the names of training datasets.
Returns:
- List[str]: A list of training dataset names.
val_dataset_names
property
¶
Get the names of validation datasets.
Returns:
- List[str]: A list of validation dataset names.
__contains__(model_name)
¶
Check if a model with the given name exists in the model pool.
Examples:
>>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
>>> "modelA" in modelpool
True
>>> "modelC" in modelpool
False
Parameters:
- model_name (str) – The name of the model to check.
Returns:
- bool: True if the model exists, False otherwise.
Source code in fusion_bench/modelpool/base_pool.py
get_model_config(model_name, return_copy=True)
¶
Get the configuration for the specified model.
Parameters:
- model_name (str) – The name of the model.
- return_copy (bool, default: True) – Whether to return a copy of the configuration.
Returns:
- Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, a string path, or another type.
Raises:
- ValidationError – If model_name is invalid.
- KeyError – If model_name is not found in the pool.
Source code in fusion_bench/modelpool/base_pool.py
get_model_path(model_name)
¶
Get the path for the specified model.
Parameters:
- model_name (str) – The name of the model.
Returns:
- str: The path for the specified model.
Raises:
- ValidationError – If model_name is invalid.
- KeyError – If model_name is not found in the pool.
- ValueError – If the model configuration is not a string path.
Source code in fusion_bench/modelpool/base_pool.py
is_special_model(model_name)
staticmethod
¶
Determine if a model is special based on its name.
Parameters:
- model_name (str) – The name of the model.
Returns:
- bool: True if the model name indicates a special model, False otherwise.
Source code in fusion_bench/modelpool/base_pool.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a model from the pool based on the provided configuration.
Parameters:
- model_name_or_config (Union[str, DictConfig]) – The model name or configuration. If a str, it should be a key in self._models; if a DictConfig, it should be a configuration for instantiation.
- *args – Additional positional arguments passed to model instantiation.
- **kwargs – Additional keyword arguments passed to model instantiation.
Returns:
- nn.Module: The instantiated or retrieved model.
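Example
A hedged sketch (the model name, _target_, and checkpoint identifier are illustrative, not from an actual pool configuration):
>>> from fusion_bench.modelpool import BaseModelPool
>>> pool = BaseModelPool(models={
...     "modelA": {
...         "_target_": "transformers.CLIPVisionModel.from_pretrained",
...         "pretrained_model_name_or_path": "openai/clip-vit-base-patch32",
...     }
... })
>>> model = pool.load_model("modelA")  # load by name; a DictConfig also works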
Source code in fusion_bench/modelpool/base_pool.py
load_pretrained_or_first_model(*args, **kwargs)
¶
Load the pretrained model if available, otherwise load the first available model.
Returns:
- nn.Module: The loaded model.
Source code in fusion_bench/modelpool/base_pool.py
load_test_dataset(dataset_name, *args, **kwargs)
¶
Load the testing dataset for the specified model.
Parameters:
- dataset_name (str) – The name of the dataset.
Returns:
- Dataset: The instantiated testing dataset.
Source code in fusion_bench/modelpool/base_pool.py
load_train_dataset(dataset_name, *args, **kwargs)
¶
Load the training dataset for the specified model.
Parameters:
- dataset_name (str) – The name of the dataset.
Returns:
- Dataset: The instantiated training dataset.
Source code in fusion_bench/modelpool/base_pool.py
load_val_dataset(dataset_name, *args, **kwargs)
¶
Load the validation dataset for the specified model.
Parameters:
- dataset_name (str) – The name of the dataset.
Returns:
- Dataset: The instantiated validation dataset.
Source code in fusion_bench/modelpool/base_pool.py
save_model(model, path, *args, **kwargs)
¶
Save the state dictionary of the model to the specified path.
Parameters:
- model (Module) – The model whose state dictionary is to be saved.
- path (str) – The path where the state dictionary will be saved.
Source code in fusion_bench/modelpool/base_pool.py
Vision Model Pool¶
NYUv2 Tasks (ResNet)¶
NYUv2ModelPool
¶
Bases: ModelPool
Source code in fusion_bench/modelpool/nyuv2_modelpool.py
CLIP Vision Encoder¶
CLIPVisionModelPool
¶
Bases: BaseModelPool
A model pool for managing Hugging Face's CLIP Vision models.
This class extends BaseModelPool and overrides its methods to handle the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
Source code in fusion_bench/modelpool/clip_vision/modelpool.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a CLIPVisionModel from the model pool with support for various configuration formats.
This method provides flexible model loading capabilities, handling different types of model configurations including string paths, pre-instantiated models, and complex configurations.
Supported configuration formats:
1. String model paths (e.g., Hugging Face model IDs)
2. Pre-instantiated nn.Module objects
3. DictConfig objects for complex configurations
Example configuration:
models:
  # Simple string paths to Hugging Face models
  cifar10: tanganke/clip-vit-base-patch32_cifar10
  sun397: tanganke/clip-vit-base-patch32_sun397
  stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
  # Complex configuration with additional parameters
  custom_model:
    _target_: transformers.CLIPVisionModel.from_pretrained
    pretrained_model_name_or_path: openai/clip-vit-base-patch32
    torch_dtype: float16
Parameters:
- model_name_or_config (Union[str, DictConfig]) – Either a model name from the pool or a configuration dictionary for instantiating the model.
- *args – Additional positional arguments passed to model loading/instantiation.
- **kwargs – Additional keyword arguments passed to model loading/instantiation.
Returns:
- CLIPVisionModel: The loaded CLIPVisionModel instance.
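Example
A hedged usage sketch based on the configuration above:
>>> pool = CLIPVisionModelPool(models={"cifar10": "tanganke/clip-vit-base-patch32_cifar10"})
>>> vision_model = pool.load_model("cifar10")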
Source code in fusion_bench/modelpool/clip_vision/modelpool.py
save_model(model, path)
¶
Save a CLIP Vision model to the given path.
Parameters:
- model (CLIPVisionModel) – The model to save.
- path (str) – The path to save the model to.
Source code in fusion_bench/modelpool/clip_vision/modelpool.py
OpenCLIP Vision Encoder¶
OpenCLIPVisionModelPool
¶
Bases: BaseModelPool
A model pool for managing OpenCLIP Vision models (the models used in the task vector paper).
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
load_classification_head(model_name_or_config, *args, **kwargs)
¶
The model config can be:
- A string: the path to the model checkpoint in pickle format, loaded directly using torch.load.
- Otherwise: the classification head is loaded using Hydra's instantiate.
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
load_model(model_name_or_config, *args, **kwargs)
¶
The model config can be:
- A string: the path to the model checkpoint in pickle format, loaded directly using torch.load.
- {"model_name": str, "pickle_path": str}: load the model from a binary file (pickle format). The model is first constructed with ImageEncoder(model_name), and the state dict is then loaded from the model stored in the pickle file.
- {"model_name": str, "state_dict_path": str}: load the model from a state dict file. The model is first constructed with ImageEncoder(model_name), and the state dict is then loaded from the file.
- Otherwise: the model is loaded using Hydra's instantiate.
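Example
A hedged sketch of the dict forms (the model name and file paths are placeholders):
>>> model = pool.load_model({"model_name": "ViT-B-32", "pickle_path": "checkpoints/finetuned.pt"})
>>> model = pool.load_model({"model_name": "ViT-B-32", "state_dict_path": "checkpoints/state_dict.pt"})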
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
ResNet for Image Classification¶
ResNet Model Pool for Image Classification.
This module provides a flexible model pool implementation for ResNet models used in image classification tasks. It supports both torchvision and transformers implementations of ResNet architectures with configurable preprocessing, loading, and saving capabilities.
Example Usage
Create a pool with a torchvision ResNet model:
>>> # Torchvision ResNet pool
>>> pool = ResNetForImageClassificationPool(
... type="torchvision",
... models={"resnet18_cifar10": {"model_name": "resnet18", "dataset_name": "cifar10"}}
... )
>>> model = pool.load_model("resnet18_cifar10")
>>> processor = pool.load_processor(stage="train")
Create a pool with a transformers ResNet model:
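A minimal sketch mirroring the class-level example below:
>>> # Transformers ResNet pool
>>> pool = ResNetForImageClassificationPool(
...     type="transformers",
...     models={
...         "resnet_model": {
...             "config_path": "microsoft/resnet-50",
...             "pretrained": True,
...             "dataset_name": "imagenet"
...         }
...     }
... )
>>> model = pool.load_model("resnet_model")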
ResNetForImageClassificationPool
¶
Bases: BaseModelPool
Model pool for ResNet-based image classification models.
This class provides a unified interface for managing ResNet models from different sources (torchvision and transformers) with automatic preprocessing, loading, and saving capabilities. It supports multiple ResNet architectures and can automatically adapt models to different datasets by adjusting the number of output classes.
The pool supports two main types:
- "torchvision": Uses torchvision's ResNet implementations with standard ImageNet preprocessing.
- "transformers": Uses Hugging Face transformers' ResNetForImageClassification with auto processors.
Parameters:
- type (str) – Model source type; must be either "torchvision" or "transformers".
- **kwargs – Additional arguments passed to the base BaseModelPool class.
Attributes:
- type (str) – The model source type specified during initialization.
Raises:
- AssertionError – If type is not "torchvision" or "transformers".
Example
Create a pool with a torchvision ResNet model:
>>> # Torchvision-based pool
>>> pool = ResNetForImageClassificationPool(
... type="torchvision",
... models={
... "resnet18_cifar10": {
... "model_name": "resnet18",
... "weights": "DEFAULT",
... "dataset_name": "cifar10"
... }
... }
... )
Create a pool with a transformers ResNet model:
>>> # Transformers-based pool
>>> pool = ResNetForImageClassificationPool(
... type="transformers",
... models={
... "resnet_model": {
... "config_path": "microsoft/resnet-50",
... "pretrained": True,
... "dataset_name": "imagenet"
... }
... }
... )
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a ResNet model based on the provided configuration or model name.
This method supports flexible model loading from different sources and configurations:
- Direct model names (e.g., "resnet18", "resnet50") for standard architectures
- Model pool keys that map to configurations
- Dictionary/DictConfig objects with detailed model specifications
- Hugging Face model identifiers for transformers models
For torchvision models, it supports:
- Standard ResNet architectures: resnet18, resnet34, resnet50, resnet101, resnet152
- Custom configurations with model_name, weights, and num_classes specifications
- Automatic dataset adaptation with class-number inference
For transformers models:
- Loading from the Hugging Face Hub or local paths
- Pretrained or randomly initialized models
- Automatic logits extraction by overriding the forward method
- Dataset-specific label mapping configuration
Parameters:
- model_name_or_config (Union[str, DictConfig]) – Model specification, which can be:
  - A string model name (e.g., "resnet18") for standard architectures
  - A model pool key referencing a stored configuration
  - A dict/DictConfig with model parameters, e.g.:
    - For torchvision: {"model_name": "resnet18", "weights": "DEFAULT", "num_classes": 10}
    - For transformers: {"config_path": "microsoft/resnet-50", "pretrained": True, "dataset_name": "cifar10"}
- *args – Additional positional arguments (unused).
- **kwargs – Additional keyword arguments (unused).
Returns:
- Union[TorchVisionResNet, ResNetForImageClassification]: The loaded ResNet model configured for the specified task. For transformers models, the forward method is modified to return logits directly instead of the full model output.
Raises:
- ValueError – If model_name_or_config has an invalid type or the model type is unknown.
- AssertionError – If num_classes inferred from the dataset doesn't match an explicit num_classes specification.
Example
>>> # Load standard torchvision model
>>> model = pool.load_model("resnet18")
>>> # Load with custom configuration
>>> config = {"model_name": "resnet50", "weights": "DEFAULT", "dataset_name": "cifar10"}
>>> model = pool.load_model(config)
>>> # Load transformers model
>>> config = {"config_path": "microsoft/resnet-50", "pretrained": True}
>>> model = pool.load_model(config)
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
load_processor(stage='test', *args, **kwargs)
¶
Load the appropriate image processor/transform for the specified training stage.
Creates stage-specific image preprocessing pipelines optimized for the model type:
For torchvision models:
- Train stage: includes data augmentation (random resized crop, horizontal flip)
- Val/test stages: standard preprocessing (resize, center crop) without augmentation
- All stages: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
For transformers models:
- Uses AutoImageProcessor from the pretrained model configuration
- Automatically handles model-specific preprocessing requirements
Parameters:
- stage (Literal['train', 'val', 'test'], default: 'test') – The training stage determining the preprocessing type. "train" applies data augmentation for training; "val"/"test" use standard preprocessing for evaluation.
- *args – Additional positional arguments (unused).
- **kwargs – Additional keyword arguments (unused).
Returns:
- Union[transforms.Compose, AutoImageProcessor]: The image processor/transform pipeline appropriate for the specified stage and model type.
Raises:
- ValueError – If no valid config_path can be found for transformers models.
Example
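A hedged sketch (assuming a torchvision-type pool as in the class example):
>>> train_transform = pool.load_processor(stage="train")  # with augmentation
>>> test_transform = pool.load_processor(stage="test")    # resize + center crop only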
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
save_model(model, path, algorithm_config=None, description=None, base_model=None, *args, **kwargs)
¶
Save a ResNet model to the specified path using the appropriate format.
This method handles model saving based on the model pool type:
- For torchvision models: saves only the state_dict using torch.save()
- For transformers models: saves the complete model and processor using save_pretrained()
The saving format ensures compatibility with the corresponding loading mechanisms and preserves all necessary components for model restoration.
Parameters:
- model – The ResNet model to save. Should be compatible with the pool's model type.
- path (str) – Destination path for saving the model. For torchvision models this should be a file path (e.g., "model.pth"); for transformers models, a directory path where the model files will be stored.
- *args – Additional positional arguments (unused).
- **kwargs – Additional keyword arguments (unused).
Raises:
- ValueError – If the model type is unknown or unsupported.
Note
For transformers models, both the model weights and the associated image processor are saved to ensure complete reproducibility of the preprocessing pipeline.
Example
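A hedged sketch (paths are illustrative):
>>> pool.save_model(model, "checkpoints/resnet18_cifar10.pth")  # torchvision: state_dict file
>>> pool.save_model(model, "checkpoints/resnet50_cifar10/")     # transformers: directory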
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
load_torchvision_resnet(model_name, weights, num_classes)
¶
Load a ResNet model from torchvision with optional custom classifier head.
This function creates a ResNet model using torchvision's model zoo and optionally replaces the final classification layer to match the required number of classes.
Parameters:
- model_name (str) – Name of the ResNet model to load (e.g., 'resnet18', 'resnet50'). Must be a valid torchvision model name.
- weights (Optional[str]) – Pretrained weights to load. Can be 'DEFAULT', 'IMAGENET1K_V1', or None for random initialization. See the torchvision documentation for available options.
- num_classes (Optional[int]) – Number of output classes. If provided, replaces the final fully connected layer. If None, keeps the original classifier (typically 1000 classes).
Returns:
- TorchVisionResNet (ResNet): The loaded ResNet model with the appropriate classifier head.
Raises:
- AttributeError – If model_name is not a valid torchvision model.
Example
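A minimal sketch (assumes the function is imported from fusion_bench.modelpool.resnet_for_image_classification):
>>> # ResNet-18 with default ImageNet weights and a new 10-class head
>>> model = load_torchvision_resnet("resnet18", weights="DEFAULT", num_classes=10)
>>> # randomly initialized ResNet-50 with the original 1000-class head
>>> model = load_torchvision_resnet("resnet50", weights=None, num_classes=None)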
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
load_transformers_resnet(config_path, pretrained, dataset_name)
¶
Load a ResNet model from transformers with optional dataset-specific adaptation.
This function creates a ResNet model using the transformers library and optionally adapts it for a specific dataset by updating the classifier head and label mappings.
Parameters:
- config_path (str) – Path or identifier for the model configuration. Can be a local path or a Hugging Face model identifier (e.g., 'microsoft/resnet-50').
- pretrained (bool) – Whether to load pretrained weights. If True, loads from the specified config_path; if False, initializes with random weights from the config.
- dataset_name (Optional[str]) – Name of the target dataset for adaptation. If provided, updates the model's classifier and label mappings to match the dataset's classes. If None, keeps the original model configuration.
Returns:
- ResNetForImageClassification: The loaded and optionally adapted ResNet model.
Example
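A minimal sketch (mirrors the load_model example above; the dataset key must be known to FusionBench):
>>> model = load_transformers_resnet("microsoft/resnet-50", pretrained=True, dataset_name="cifar10")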
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
ConvNeXt for Image Classification¶
Hugging Face ConvNeXt image classification model pool.
This module provides a BaseModelPool implementation that loads and saves
ConvNeXt models for image classification via transformers. It optionally
reconfigures the classification head to match a dataset's class names and
overrides forward to return logits only for simpler downstream usage.
See also: fusion_bench.modelpool.resnet_for_image_classification for a
parallel implementation for ResNet-based classifiers.
ConvNextForImageClassificationPool
¶
Bases: BaseModelPool
Model pool for ConvNeXt image classification models (HF Transformers).
Responsibilities:
- Load an AutoImageProcessor compatible with the configured ConvNeXt model.
- Load ConvNeXt models either from a pretrained checkpoint or from config.
- Optionally adapt the classifier head to match dataset classnames.
- Override forward to return logits for consistent interfaces within
FusionBench.
See fusion_bench.modelpool.resnet_for_image_classification for a closely
related ResNet-based pool with analogous behavior.
Source code in fusion_bench/modelpool/convnext_for_image_classification.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a ConvNeXt model described by a name, path, or DictConfig.
Accepts either a string (pretrained identifier or local path) or a
config mapping with keys: config_path, optional pretrained (bool),
and optional dataset_name to resize the classifier.
Returns:
- A model whose forward is wrapped to return only logits, to align with FusionBench expectations.
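Example
A hedged sketch (the identifier and dataset key are illustrative):
>>> model = pool.load_model("facebook/convnext-base-224")
>>> model = pool.load_model({"config_path": "facebook/convnext-base-224", "pretrained": True, "dataset_name": "cifar10"})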
Source code in fusion_bench/modelpool/convnext_for_image_classification.py
save_model(model, path, algorithm_config=None, description=None, base_model=None, *args, **kwargs)
¶
Save the model, processor, and an optional model card to disk.
Artifacts written to path:
- The ConvNeXt model via model.save_pretrained.
- The paired image processor via AutoImageProcessor.save_pretrained.
- If algorithm_config is provided and on rank-zero, a README model card
documenting the FusionBench configuration.
Source code in fusion_bench/modelpool/convnext_for_image_classification.py
load_transformers_convnext(config_path, pretrained, dataset_name)
¶
Create a ConvNeXt image classification model from a config or checkpoint.
Parameters:
- config_path (str) – A model identifier or local path understood by transformers.AutoConfig/AutoModel (e.g., "facebook/convnext-base-224").
- pretrained (bool) – If True, load weights via from_pretrained; otherwise, build the model from config only.
- dataset_name (Optional[str]) – Optional dataset key used by FusionBench to derive class names via get_classnames. When provided, the model's id/label maps are updated and the classifier head is resized accordingly.
Returns:
- ConvNextForImageClassification: A transformers.ConvNextForImageClassification instance. If dataset_name is set, the classifier head is adapted to the number of classes. The model's config.id2label and config.label2id are also populated.
Notes
The overall structure mirrors the ResNet implementation in
fusion_bench.modelpool.resnet_for_image_classification.
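Example
A minimal sketch (assumes the function is imported from fusion_bench.modelpool.convnext_for_image_classification):
>>> model = load_transformers_convnext("facebook/convnext-base-224", pretrained=True, dataset_name="cifar10")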
Source code in fusion_bench/modelpool/convnext_for_image_classification.py
DINOv2 for Image Classification¶
Hugging Face DINOv2 image classification model pool.
This module provides a BaseModelPool implementation that loads and saves
DINOv2 models for image classification via transformers. It optionally
reconfigures the classification head to match a dataset's class names and
overrides forward to return logits only for simpler downstream usage.
See also: fusion_bench.modelpool.convnext_for_image_classification for a
parallel implementation for ConvNeXt-based classifiers.
Dinov2ForImageClassificationPool
¶
Bases: BaseModelPool
Model pool for DINOv2 image classification models (HF Transformers).
Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a DINOv2 model described by a name, path, or DictConfig.
Accepts either a string (pretrained identifier or local path) or a
config mapping with keys: config_path, optional pretrained (bool),
and optional dataset_name to resize the classifier.
Returns:
- A model whose forward is wrapped to return only logits, to align with FusionBench expectations.
Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
load_processor(*args, **kwargs)
¶
Load the paired image processor for this model pool.
Uses the configured model's identifier or config path to retrieve the
appropriate transformers.AutoImageProcessor instance. If a pretrained
model entry exists in the pool configuration, it is preferred to derive
the processor to ensure tokenization/normalization parity.
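Example
A hedged sketch:
>>> processor = pool.load_processor()  # AutoImageProcessor paired with the pool's model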
Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
save_model(model, path, algorithm_config=None, description=None, base_model=None, *args, **kwargs)
¶
Save the model, processor, and an optional model card to disk.
Artifacts written to path:
- The DINOv2 model via model.save_pretrained.
- The paired image processor via AutoImageProcessor.save_pretrained.
- If algorithm_config is provided and on rank-zero, a README model card
documenting the FusionBench configuration.
Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
load_transformers_dinov2(config_path, pretrained, dataset_name)
¶
Create a DINOv2 image classification model from a config or checkpoint.
Parameters:
- config_path (str) – A model identifier or local path understood by transformers.AutoConfig/AutoModel (e.g., "facebook/dinov2-base").
- pretrained (bool) – If True, load weights via from_pretrained; otherwise, build the model from config only.
- dataset_name (Optional[str]) – Optional dataset key used by FusionBench to derive class names via get_classnames. When provided, the model's id/label maps are updated and the classifier head is resized accordingly.
Returns:
- Dinov2ForImageClassification: A transformers.Dinov2ForImageClassification instance. If dataset_name is set, the classifier head is adapted to the number of classes. The model's config.id2label and config.label2id are also populated.
Notes
The overall structure mirrors the ConvNeXt implementation in
fusion_bench.modelpool.convnext_for_image_classification.
Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
NLP Model Pool¶
GPT-2¶
HuggingFaceGPT2ClassificationPool = GPT2ForSequenceClassificationPool
module-attribute
¶
GPT2ForSequenceClassificationPool
¶
Bases: BaseModelPool
Source code in fusion_bench/modelpool/huggingface_gpt2_classification.py
Seq2Seq Language Models (Flan-T5)¶
Seq2SeqLMPool
¶
Bases: BaseModelPool
A model pool specialized for sequence-to-sequence language models.
This model pool provides management and loading capabilities for sequence-to-sequence (seq2seq) language models such as T5, BART, and mT5. It extends the base model pool functionality with seq2seq-specific features including tokenizer management and model configuration handling.
Seq2seq models are particularly useful for tasks that require generating output sequences from input sequences, such as translation, summarization, question answering, and text generation. This pool streamlines the process of loading and configuring multiple seq2seq models for fusion and ensemble scenarios.
Key Features
- Specialized loading for AutoModelForSeq2SeqLM models
- Integrated tokenizer management
- Support for model-specific keyword arguments
- Automatic dtype parsing and configuration
- Compatible with PEFT (Parameter-Efficient Fine-Tuning) adapters
Attributes:
- _tokenizer – Configuration for the tokenizer associated with the models.
- _model_kwargs – Default keyword arguments applied to all model loading operations.
Example
pool = Seq2SeqLMPool(
models={
"t5_base": "t5-base",
"t5_large": "t5-large",
"custom_model": "/path/to/local/model"
},
tokenizer={"_target_": "transformers.T5Tokenizer",
"pretrained_model_name_or_path": "t5-base"},
model_kwargs={"torch_dtype": "float16", "device_map": "auto"}
)
model = pool.load_model("t5_base")
tokenizer = pool.load_tokenizer()
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
__init__(models, *, tokenizer, model_kwargs=None, **kwargs)
¶
Initialize the sequence-to-sequence language model pool.
Sets up the model pool with configurations for models, tokenizer, and default model loading parameters. Automatically processes model kwargs to handle special configurations like torch_dtype parsing.
Parameters:
- models (DictConfig) – Configuration dictionary specifying the seq2seq models to manage. Keys are model names; values can be model paths/names or detailed configs.
- tokenizer (Optional[DictConfig]) – Configuration for the tokenizer to use with the models. Can be a simple path/name or a detailed configuration with a _target_ field.
- model_kwargs (Optional[DictConfig], default: None) – Default keyword arguments applied to all model loading operations. Common options include torch_dtype, device_map, etc. The torch_dtype field is automatically parsed from string to dtype.
- **kwargs – Additional arguments passed to the parent BaseModelPool.
Example
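A minimal sketch mirroring the class-level example above:
>>> pool = Seq2SeqLMPool(
...     models={"t5_base": "t5-base"},
...     tokenizer={"_target_": "transformers.T5Tokenizer",
...                "pretrained_model_name_or_path": "t5-base"},
...     model_kwargs={"torch_dtype": "float16"}
... )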
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a sequence-to-sequence language model from the pool.
Loads a seq2seq model using the parent class loading mechanism while automatically applying the pool's default model kwargs. The method merges the pool's model_kwargs with any additional kwargs provided, giving priority to the explicitly provided kwargs.
Parameters:
- model_name_or_config (str | DictConfig) – Either a string model name from the pool configuration or a DictConfig containing model loading parameters.
- *args – Additional positional arguments passed to the parent load_model method.
- **kwargs – Additional keyword arguments that override the pool's default model_kwargs. Common options include device, torch_dtype, etc.
Returns:
- AutoModelForSeq2SeqLM: The loaded sequence-to-sequence language model.
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
load_tokenizer(*args, **kwargs)
¶
Load the tokenizer associated with the sequence-to-sequence models.
Loads a tokenizer based on the tokenizer configuration provided during pool initialization. The tokenizer should be compatible with the seq2seq models in the pool and is typically used for preprocessing input text and postprocessing generated output.
Parameters:
- *args – Additional positional arguments passed to the tokenizer constructor.
- **kwargs – Additional keyword arguments passed to the tokenizer constructor.
Returns:
- PreTrainedTokenizer: The loaded tokenizer instance compatible with the seq2seq models in this pool.
Raises:
- AssertionError – If no tokenizer configuration is provided.
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
SequenceClassificationModelPool
¶
Bases: BaseModelPool
Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, **kwargs)
¶
Save the model to the specified path.
Parameters:
- model (PreTrainedModel) – The model to be saved.
- path (str) – The path where the model will be saved.
- push_to_hub (bool, default: False) – Whether to push the model to the Hugging Face Hub. Defaults to False.
- save_tokenizer (bool, default: False) – Whether to save the tokenizer along with the model. Defaults to False.
- **kwargs – Additional keyword arguments passed to the save_pretrained method.
Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
PeftModelForSeq2SeqLMPool
¶
Bases: ModelPool
Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
load_model(model_config)
¶
Load a model based on the provided configuration.
The configuration options of model_config are:
- name: The name of the model. If it is "pretrained", a pretrained Seq2Seq language model is returned.
- path: The path where the model is stored.
- is_trainable: A boolean indicating whether the model parameters should be trainable. Default is True.
- merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is True.
Parameters:
- model_config (str | DictConfig) – The configuration for the model. This can be either a string (the name of the model) or a DictConfig object containing the model configuration.
Returns:
- model: The loaded model. If the model name is "pretrained", a pretrained Seq2Seq language model is returned; otherwise, a PEFT model.
Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
Causal Language Models (Llama, Mistral, Qwen...)¶
CausalLMPool
¶
Bases: BaseModelPool
A model pool for managing and loading causal language models.
This class provides a unified interface for loading and managing multiple causal language models, typically used in model fusion and ensemble scenarios. It supports both eager and lazy loading strategies, and handles model configuration through YAML configs or direct instantiation.
The pool can manage models from Hugging Face Hub, local paths, or custom configurations. It also provides tokenizer management and model saving capabilities with optional Hugging Face Hub integration.
Parameters:
- models – Dictionary or configuration specifying the models to be managed. Can contain model names mapped to paths or detailed configurations.
- tokenizer (Optional[DictConfig | str]) – Tokenizer configuration, either a string path/name or a DictConfig with detailed tokenizer settings.
- model_kwargs (Optional[DictConfig], default: None) – Additional keyword arguments passed to model loading. Common options include torch_dtype, device_map, etc.
- enable_lazy_loading (bool, default: False) – Whether to use lazy loading for models. When True, models are loaded as LazyStateDict objects instead of actual models, which can save memory for large model collections.
- **kwargs – Additional arguments passed to the parent BaseModelPool.
Example
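A hedged sketch (model names and Hub paths are illustrative):
>>> pool = CausalLMPool(
...     models={"llama": "meta-llama/Llama-2-7b-hf"},
...     tokenizer="meta-llama/Llama-2-7b-hf",
...     model_kwargs={"torch_dtype": "bfloat16"}
... )
>>> model = pool.load_model("llama")
>>> tokenizer = pool.load_tokenizer()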
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
get_model_kwargs()
¶
Get processed model keyword arguments for model loading.
Converts the stored model_kwargs from DictConfig to a regular dictionary
and processes special arguments like torch_dtype for proper model loading.
Returns:
- dict: Processed keyword arguments ready to be passed to model loading functions. The torch_dtype field, if present, is converted from a string to the appropriate torch dtype object.
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
get_model_path(model_name)
¶
Extract the model path from the model configuration.
Parameters:
- model_name (str) – The name of the model as defined in the models configuration.
Returns:
- str: The path or identifier for the model. For string configurations, returns the string directly; for dict configurations, extracts the 'pretrained_model_name_or_path' field.
Raises:
- RuntimeError – If the model configuration is invalid or the model name is not found in the configuration.
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load a causal language model from the model pool.
This method supports multiple loading strategies:
1. Loading by model name from the configured model pool
2. Loading from a direct configuration dictionary
3. Lazy loading using LazyStateDict for memory efficiency
The method automatically handles different model configuration formats and applies the appropriate loading strategy based on the enable_lazy_loading flag.
Parameters:
- model_name_or_config (str | DictConfig) – Either a string model name that exists in the model pool configuration, or a DictConfig/dict containing the model configuration directly.
- *args – Additional positional arguments passed to the model constructor.
- **kwargs – Additional keyword arguments passed to the model constructor. These are merged with the pool's model_kwargs.
Returns:
- Union[PreTrainedModel, LazyStateDict]: The loaded model. Returns a PreTrainedModel for normal loading or a LazyStateDict for lazy loading.
Raises:
- RuntimeError – If the model configuration is invalid.
- KeyError – If the model name is not found in the model pool.
Example YAML configurations
Simple string configuration:
Detailed configuration:
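A hedged reconstruction of both forms (model names and paths are illustrative; the _target_ key is an assumption, while pretrained_model_name_or_path is the field get_model_path expects):
models:
  # simple string form
  llama: meta-llama/Llama-2-7b-hf
  # detailed form
  vicuna:
    _target_: transformers.AutoModelForCausalLM.from_pretrained
    pretrained_model_name_or_path: lmsys/vicuna-7b-v1.5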
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_tokenizer(*args, **kwargs)
¶
Load the tokenizer associated with this model pool.
Loads a tokenizer based on the tokenizer configuration provided during pool initialization. Supports both simple string paths and detailed configuration dictionaries.
Parameters:
- *args – Additional positional arguments passed to the tokenizer constructor.
- **kwargs – Additional keyword arguments passed to the tokenizer constructor.
Returns:
- PreTrainedTokenizer: The loaded tokenizer instance.
Raises:
- AssertionError – If no tokenizer is defined in the configuration.
Example YAML configurations
Simple string configuration:
Detailed configuration:
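A hedged reconstruction of both forms (names are illustrative; the _target_ key is an assumption):
# simple string form
tokenizer: meta-llama/Llama-2-7b-hf
# detailed form
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf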
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, tokenizer=None, algorithm_config=None, description=None, base_model_in_modelcard=True, **kwargs)
¶
Save a model to the specified path with optional tokenizer and Hub upload.
This method provides comprehensive model saving capabilities including optional tokenizer saving, dtype conversion, model card creation, and Hugging Face Hub upload. The model is saved in the standard Hugging Face format.
Parameters:
- model (PreTrainedModel) – The PreTrainedModel instance to be saved.
- path (str) – The local path where the model will be saved. Supports tilde expansion for home directory paths.
- push_to_hub (bool, default: False) – Whether to push the saved model to the Hugging Face Hub. Requires proper authentication and repository permissions.
- model_dtype (Optional[str], default: None) – Optional string specifying the target dtype for the model before saving (e.g., "float16", "bfloat16"). The model will be converted to this dtype before saving.
- save_tokenizer (bool, default: False) – Whether to save the tokenizer alongside the model. If True, the tokenizer will be loaded using the pool's tokenizer configuration and saved to the same path.
- tokenizer_kwargs – Additional keyword arguments for tokenizer loading when save_tokenizer is True.
- tokenizer (Optional[PreTrainedTokenizer], default: None) – Optional pre-loaded tokenizer instance. If provided, this tokenizer will be saved regardless of the save_tokenizer flag.
- algorithm_config (Optional[DictConfig], default: None) – Optional DictConfig containing the algorithm configuration. If provided, a model card will be created with algorithm details.
- description (Optional[str], default: None) – Optional description for the model card. If not provided and algorithm_config is given, a default description will be generated.
- **kwargs – Additional keyword arguments passed to the model's save_pretrained method.
Example
>>> pool = CausalLMPool(models=..., tokenizer=...)
>>> model = pool.load_model("my_model")
>>> pool.save_model(
... model,
... "/path/to/save",
... save_tokenizer=True,
... model_dtype="float16",
... push_to_hub=True,
... algorithm_config=algorithm_config,
... description="Custom merged model"
... )
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
CausalLMBackbonePool
¶
Bases: CausalLMPool
A specialized model pool that loads only the transformer backbone layers.
This class extends CausalLMPool to provide access to just the transformer layers (backbone) of causal language models, excluding the language modeling head and embeddings. This is useful for model fusion scenarios where only the core transformer layers are needed.
The class automatically extracts the model.layers component from loaded
AutoModelForCausalLM instances, providing direct access to the transformer
blocks. Lazy loading is not supported for this pool type.
Note
This pool automatically disables lazy loading as it needs to access the internal structure of the model to extract the backbone layers.
Example
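A hedged sketch (configuration mirrors CausalLMPool; the path is illustrative):
>>> backbone_pool = CausalLMBackbonePool(
...     models={"llama": "meta-llama/Llama-2-7b-hf"},
...     tokenizer="meta-llama/Llama-2-7b-hf"
... )
>>> layers = backbone_pool.load_model("llama")  # nn.ModuleList of transformer blocks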
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_model(model_name_or_config, *args, **kwargs)
¶
Load only the transformer backbone layers from a causal language model.
This method loads a complete causal language model and then extracts only the transformer layers (backbone), discarding the embedding layers and language modeling head. This is useful for model fusion scenarios where only the core transformer computation is needed.
Parameters:
- model_name_or_config (str | DictConfig) – Either a string model name from the pool configuration or a DictConfig with model loading parameters.
- *args – Additional positional arguments passed to the parent load_model method.
- **kwargs – Additional keyword arguments passed to the parent load_model method.
Returns:
- Module: The transformer layers (typically an nn.ModuleList) containing the core transformer blocks without embeddings or output heads.
Note
Lazy loading is automatically disabled for this method as it needs to access the internal model structure to extract the layers.
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_peft_causal_lm(base_model_path, peft_model_path, torch_dtype='bfloat16', is_trainable=True, merge_and_unload=False)
¶
Load a causal language model with PEFT (Parameter-Efficient Fine-Tuning) adapters.
This function loads a base causal language model and applies PEFT adapters (such as LoRA, AdaLoRA, or other parameter-efficient fine-tuning methods) to create a fine-tuned model. It supports both keeping the adapters separate or merging them into the base model.
Parameters:
- base_model_path (str) – Path or identifier for the base causal language model. Can be a Hugging Face model name or a local path.
- peft_model_path (str) – Path to the PEFT adapter configuration and weights. This should contain adapter_config.json and the adapter weights.
- torch_dtype (str, default: 'bfloat16') – The torch data type to use for the model. Common options include "float16", "bfloat16", and "float32". Defaults to "bfloat16".
- is_trainable (bool, default: True) – Whether the loaded PEFT model should be trainable. Set to False for inference-only usage to save memory.
- merge_and_unload (bool, default: False) – Whether to merge the PEFT adapters into the base model and unload the adapter weights. When True, returns a standard PreTrainedModel instead of a PeftModel.
Returns:
- Union[PeftModel, PreTrainedModel]: The loaded model with PEFT adapters. Returns a PeftModel if merge_and_unload is False, or a PreTrainedModel if the adapters are merged and unloaded.
Example
>>> # Load model with adapters for training
>>> model = load_peft_causal_lm(
... "microsoft/DialoGPT-medium",
... "/path/to/lora/adapters",
... is_trainable=True
... )
>>> # Load and merge adapters for inference
>>> merged_model = load_peft_causal_lm(
... "microsoft/DialoGPT-medium",
... "/path/to/lora/adapters",
... merge_and_unload=True,
... is_trainable=False
... )