Creating a Custom ModelPool¶
A ModelPool is the component in FusionBench responsible for managing a collection of models and their associated datasets. It handles model loading, instantiation, and serialization. This guide walks you through creating a custom ModelPool from scratch.
Understanding the BaseModelPool¶
The base class `fusion_bench.modelpool.base_pool.BaseModelPool` provides the foundation for all model pools. Out of the box it offers:

- Model registry: A dictionary (`_models`) mapping model names to configurations or instances.
- Dataset management: Optional `train_datasets`, `val_datasets`, and `test_datasets` dictionaries.
- Special model names: Support for `_pretrained_` (base/pretrained model) and `_merged_` (merged model) keys.
- YAML serialization: Inherits from `BaseYAMLSerializable` for config-to-object conversion.
- Hydra integration: Inherits from `HydraConfigMixin` for seamless Hydra configuration handling.
Key Methods in BaseModelPool¶
| Method | Purpose |
|---|---|
| `load_model(model_name_or_config)` | Load a model by name (from `_models`) or from a direct `DictConfig` |
| `save_model(model, path)` | Save the model's state dict to `path` via `torch.save` |
| `add_model(model_name, model_or_config)` | Add a model to the pool at runtime |
| `get_model_config(model_name)` | Get the raw config for a model |
| `models()` | Generator yielding all regular (non-special) models |
| `named_models()` | Generator yielding `(name, model)` tuples |
| `load_pretrained_model()` | Load the model registered under `_pretrained_` |
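Taken together, these methods give every pool a small, dictionary-like interface. A sketch of how a fusion algorithm typically consumes one (assuming `pool` is an already-instantiated `BaseModelPool` subclass):

```python
# Load the base model registered under the special _pretrained_ key.
pretrained = pool.load_pretrained_model()

# Iterate over all regular (non-special) models by name.
for name, model in pool.named_models():
    print(f"{name}: {type(model).__name__}")

# Persist a model; by default this is torch.save(model.state_dict(), path).
pool.save_model(pretrained, "outputs/pretrained.pt")
```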
Step 1: Inherit from BaseModelPool¶
Create a new Python file in `fusion_bench/modelpool/`. Your class must inherit from `BaseModelPool`:
```python
from typing import Any, Dict, Optional, Union

from omegaconf import DictConfig
from torch import nn

from fusion_bench import BaseModelPool


class MyCustomModelPool(BaseModelPool):
    """A custom model pool for my specific model type."""

    def __init__(
        self,
        models: DictConfig,
        *,
        some_custom_param: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self.some_custom_param = some_custom_param
```
Step 2: Implement load_model¶
The `load_model` method is the heart of any ModelPool. It takes a model name (string) or a `DictConfig` and returns an instantiated `nn.Module`.
The base implementation already handles three cases (the first is illustrated in the sketch below):

- A string name in `_models` whose entry is a `dict`/`DictConfig` -> calls `instantiate()` with `_target_`.
- A pre-instantiated `nn.Module` in `_models` -> returns it directly.
- A `DictConfig` passed directly as the argument -> calls `instantiate()`.
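For example, a `_models` entry that falls under the first case might look like this in YAML (an illustrative sketch; the model ID is a placeholder you would swap for your own):

```yaml
models:
  # Case 1: a dict/DictConfig with _target_, instantiated via Hydra.
  model_a:
    _target_: transformers.AutoModel.from_pretrained
    pretrained_model_name_or_path: bert-base-uncased
```

The other two cases arise programmatically, e.g., when a model object is added at runtime via `add_model`, or when an algorithm passes a `DictConfig` straight to `load_model`.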
Override this method when you need custom loading logic (e.g., calling `from_pretrained`, resolving platform-specific paths, or applying post-processing):
```python
from typing_extensions import override


class MyCustomModelPool(BaseModelPool):
    @override
    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> nn.Module:
        """Load a model from the pool with custom logic."""
        if isinstance(model_name_or_config, str) and model_name_or_config in self._models:
            model_config = self._models[model_name_or_config]
            if isinstance(model_config, str):
                # String path: use it as a HuggingFace model ID or local path.
                return MyModelClass.from_pretrained(model_config, *args, **kwargs)
            if isinstance(model_config, nn.Module):
                # Already instantiated.
                return model_config
        # For dict/DictConfig entries (or a direct DictConfig), delegate to the parent.
        return super().load_model(model_name_or_config, *args, **kwargs)
```
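A quick smoke test of the override (a sketch; the `myorg/...` IDs are placeholders, and `MyModelClass` is the hypothetical model class used above):

```python
from omegaconf import OmegaConf

pool = MyCustomModelPool(
    models=OmegaConf.create(
        {
            "_pretrained_": "myorg/my-base-model",
            "task_a": "myorg/my-model-finetuned-task-a",
        }
    )
)
model = pool.load_model("task_a")    # hits the string branch above
base = pool.load_pretrained_model()  # loads the _pretrained_ entry
```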
Real Example: CLIPVisionModelPool¶
The `CLIPVisionModelPool` in `fusion_bench/modelpool/clip_vision/modelpool.py` demonstrates a production-ready override:
```python
@override
def load_model(self, model_name_or_config, *args, **kwargs) -> CLIPVisionModel:
    if isinstance(model_name_or_config, str) and model_name_or_config in self._models:
        match self._models[model_name_or_config]:
            case str() as model_path:
                # Resolve the path (supports HuggingFace and ModelScope).
                repo_path = resolve_repo_path(
                    model_path, repo_type="model", platform=self._platform
                )
                return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
            case nn.Module() as model:
                return model
            case _:
                return super().load_model(model_name_or_config, *args, **kwargs)
    return super().load_model(model_name_or_config, *args, **kwargs)
```
Step 3: Optionally Override save_model¶
The default `save_model` uses `torch.save(model.state_dict(), path)`. Override it when your models require a different serialization format:
```python
@override
def save_model(self, model: nn.Module, path: str, *args, **kwargs):
    """Save the model using the HuggingFace format."""
    model.save_pretrained(path)
```
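With this override, `path` becomes a directory in HuggingFace format (`config.json`, weight files, etc.) rather than a single state-dict file, so the natural way to reload is `from_pretrained`. A sketch, with a placeholder output path and the hypothetical `MyModelClass` from earlier:

```python
pool.save_model(model, "outputs/merged_model")                   # writes a directory
reloaded = MyModelClass.from_pretrained("outputs/merged_model")  # reload it back
```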
Step 4: Register _config_mapping (if adding new attributes)¶
If your ModelPool has custom attributes that should be serialized to/from YAML, add them to `_config_mapping`:
```python
class MyCustomModelPool(BaseModelPool):
    # Maps attribute names to the YAML keys they serialize to; the keys must
    # match the corresponding __init__ parameters for round-tripping to work.
    _config_mapping = BaseModelPool._config_mapping | {
        "_custom_processor": "custom_processor",
        "_platform": "platform",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        custom_processor: Optional[DictConfig] = None,
        platform: str = "hf",
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._custom_processor = custom_processor
        self._platform = platform
```
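With this mapping in place, serializing the pool's configuration should emit YAML keyed by the mapping's values, roughly like the following sketch (the exact output depends on the registered attributes):

```yaml
_target_: fusion_bench.modelpool.MyCustomModelPool
models:
  _pretrained_: myorg/my-base-model  # whatever was registered
custom_processor: null               # from _custom_processor
platform: hf                         # from _platform
```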
Alternatively, use the `@auto_register_config` decorator to auto-register all `__init__` parameters:
```python
from fusion_bench.mixins import auto_register_config


@auto_register_config
class MyCustomModelPool(BaseModelPool):
    def __init__(
        self,
        models,
        custom_processor: Optional[DictConfig] = None,
        platform: str = "hf",
        **kwargs,
    ):
        super().__init__(models=models, **kwargs)
        self.custom_processor = custom_processor
        self.platform = platform
```
Step 5: Create the YAML Configuration¶
Every ModelPool needs a matching YAML config file. Place it under `config/modelpool/your_pool_name/`:
```yaml
# config/modelpool/MyCustomModelPool/example.yaml
_target_: fusion_bench.modelpool.MyCustomModelPool
_recursive_: False

models:
  _pretrained_: myorg/my-base-model
  task_a: myorg/my-model-finetuned-task-a
  task_b: myorg/my-model-finetuned-task-b
  task_c: myorg/my-model-finetuned-task-c

# Optional: train/val/test datasets
train_datasets: null
val_datasets: null
test_datasets: null

# Custom parameters
custom_processor:
  _target_: transformers.AutoProcessor.from_pretrained
  pretrained_model_name_or_path: myorg/my-base-model
platform: hf
```
Key points:

- `_target_`: Must point to the fully qualified Python class path.
- `_recursive_: False`: Prevents recursive instantiation of nested configs.
- `models`: Dictionary mapping names to model IDs or `DictConfig` objects. Always include `_pretrained_` when using methods like Task Arithmetic.
- Special names: Model names starting and ending with underscores (e.g., `_pretrained_`, `_merged_`) are treated as special models.
Complete Working Example¶
Here is a complete, minimal custom ModelPool for a hypothetical BERT-like classifier:
```python
# fusion_bench/modelpool/bert_classifier_pool.py
import logging
from typing import Optional, Union

from omegaconf import DictConfig
from torch import nn
from transformers import AutoModelForSequenceClassification
from typing_extensions import override

from fusion_bench import BaseModelPool

log = logging.getLogger(__name__)


class BertClassifierModelPool(BaseModelPool):
    """Model pool for BERT-based sequence classification models."""

    _config_mapping = BaseModelPool._config_mapping | {
        "_num_labels": "num_labels",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        num_labels: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._num_labels = num_labels

    @override
    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> nn.Module:
        if isinstance(model_name_or_config, str) and model_name_or_config in self._models:
            model_config = self._models[model_name_or_config]
            if isinstance(model_config, str):
                log.info(f"Loading BERT classifier from: {model_config}")
                # Only inject num_labels if the pool actually sets one.
                if self._num_labels is not None:
                    kwargs.setdefault("num_labels", self._num_labels)
                return AutoModelForSequenceClassification.from_pretrained(
                    model_config, *args, **kwargs
                )
            if isinstance(model_config, nn.Module):
                return model_config
        return super().load_model(model_name_or_config, *args, **kwargs)

    @override
    def save_model(self, model: nn.Module, path: str, *args, **kwargs):
        """Save using the HuggingFace format."""
        model.save_pretrained(path)
```
With config file:
```yaml
# config/modelpool/BertClassifierPool/glue_tasks.yaml
_target_: fusion_bench.modelpool.bert_classifier_pool.BertClassifierModelPool
_recursive_: False

models:
  _pretrained_: bert-base-uncased
  sst2: user/bert-sst2-finetuned
  mnli: user/bert-mnli-finetuned

num_labels: 2
```
Usage: a minimal sketch of instantiating the pool from its config (from the CLI, the equivalent would be passing `modelpool=BertClassifierPool/glue_tasks` as a Hydra override to the `fusion_bench` command):
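```python
# Sketch: programmatic instantiation via Hydra, honoring _target_ and
# _recursive_: False from the YAML above.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/modelpool/BertClassifierPool/glue_tasks.yaml")
pool = instantiate(cfg)

sst2 = pool.load_model("sst2")       # AutoModelForSequenceClassification
base = pool.load_pretrained_model()  # bert-base-uncased
```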
Best Practices¶
- Always call `super().__init__()`: The base class handles model validation, special name checks, and Hydra integration.
- Use the `@override` decorator: From `typing_extensions`, it marks overridden methods so static type checkers can catch signature mismatches.
- Handle all three config types: In `load_model`, handle the `str`, `nn.Module`, and `DictConfig` cases, and delegate unexpected types to `super().load_model()`.
- Log model loading: Use rank-zero-only logging to avoid duplicate logs in distributed settings (see the sketch after this list).
- Support `*args, **kwargs`: Always forward extra arguments so algorithms can pass device, dtype, or other parameters.
- Validate model names: The base class validates names during `__init__`, but you can add custom validation in `load_model`.
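For the logging practice, one common pattern is to guard the log call with a rank-zero decorator. A sketch, assuming Lightning's `rank_zero_only` utility (swap in whatever distributed guard your setup provides):

```python
import logging

from lightning.pytorch.utilities import rank_zero_only

log = logging.getLogger(__name__)


@rank_zero_only
def log_loading(model_name: str) -> None:
    # Runs only on global rank 0, so multi-GPU runs log each message once.
    log.info(f"Loading model: {model_name}")
```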
Next Steps¶
- See `fusion_bench/modelpool/clip_vision/modelpool.py` and `fusion_bench/modelpool/resnet_for_image_classification.py` for production examples.
- Read the Custom TaskPool guide to create a matching evaluation component.