fusion_bench.modelpool

Base Class

BaseModelPool

Bases: HydraConfigMixin, BaseYAMLSerializable

A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain an image preprocessor, and a language model pool may contain a tokenizer.

Attributes:

  • _models (DictConfig) –

    Configuration for all models in the pool.

  • _train_datasets (Optional[DictConfig]) –

    Configuration for training datasets.

  • _val_datasets (Optional[DictConfig]) –

    Configuration for validation datasets.

  • _test_datasets (Optional[DictConfig]) –

    Configuration for testing datasets.

  • _usage_ (Optional[str]) –

    Optional usage information.

  • _version_ (Optional[str]) –

    Optional version information.
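
A minimal usage sketch, assuming `BaseModelPool` can be imported from `fusion_bench.modelpool` (adjust the import to your installation); the toy `nn.Linear` modules stand in for real checkpoints:

```python
import torch.nn as nn

from fusion_bench.modelpool import BaseModelPool  # import path assumed

# Pre-instantiated modules are accepted directly; Hydra-style DictConfig
# entries describing how to build each model work as well.
pool = BaseModelPool(
    models={
        "_pretrained_": nn.Linear(8, 8),  # special model, excluded from `model_names`
        "modelA": nn.Linear(8, 8),
        "modelB": nn.Linear(8, 8),
    }
)

print(pool.has_pretrained)  # True
print(pool.model_names)     # ['modelA', 'modelB']
print(len(pool))            # 2 (special models are not counted)

for name, model in pool.named_models():
    print(name, type(model).__name__)
```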

Source code in fusion_bench/modelpool/base_pool.py
class BaseModelPool(
    HydraConfigMixin,
    BaseYAMLSerializable,
):
    """
    A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain an image preprocessor, and a language model pool may contain a tokenizer.

    Attributes:
        _models (DictConfig): Configuration for all models in the pool.
        _train_datasets (Optional[DictConfig]): Configuration for training datasets.
        _val_datasets (Optional[DictConfig]): Configuration for validation datasets.
        _test_datasets (Optional[DictConfig]): Configuration for testing datasets.
        _usage_ (Optional[str]): Optional usage information.
        _version_ (Optional[str]): Optional version information.
    """

    _program = None
    _config_key = "modelpool"
    _models: Union[DictConfig, Dict[str, nn.Module]]
    _config_mapping = BaseYAMLSerializable._config_mapping | {
        "_models": "models",
        "_train_datasets": "train_datasets",
        "_val_datasets": "val_datasets",
        "_test_datasets": "test_datasets",
    }

    def __init__(
        self,
        models: Union[DictConfig, Dict[str, nn.Module], List[nn.Module]],
        *,
        train_datasets: Optional[DictConfig] = None,
        val_datasets: Optional[DictConfig] = None,
        test_datasets: Optional[DictConfig] = None,
        **kwargs,
    ):
        if isinstance(models, List):
            models = {str(model_idx): model for model_idx, model in enumerate(models)}

        if isinstance(models, dict):
            try:  # try to convert to DictConfig
                models = OmegaConf.create(models)
            except UnsupportedValueType:
                pass

        if not models:
            log.warning("Initialized BaseModelPool with empty models dictionary.")
        else:
            # Validate model names
            for model_name in models.keys():
                try:
                    validate_model_name(model_name, allow_special=True)
                except ValidationError as e:
                    log.warning(f"Invalid model name '{model_name}': {e}")

        self._models = models
        self._train_datasets = train_datasets
        self._val_datasets = val_datasets
        self._test_datasets = test_datasets
        super().__init__(**kwargs)

    @property
    def has_pretrained(self) -> bool:
        """
        Check if the model pool contains a pretrained model.

        Returns:
            bool: True if a pretrained model is available, False otherwise.
        """
        return "_pretrained_" in self._models

    @property
    def all_model_names(self) -> List[str]:
        """
        Get the names of all models in the pool, including special models.

        Returns:
            List[str]: A list of all model names.
        """
        return [name for name in self._models]

    @property
    def model_names(self) -> List[str]:
        """
        Get the names of regular models, excluding special models.

        Returns:
            List[str]: A list of regular model names.
        """
        return [name for name in self._models if not self.is_special_model(name)]

    @property
    def train_dataset_names(self) -> List[str]:
        """
        Get the names of training datasets.

        Returns:
            List[str]: A list of training dataset names.
        """
        return (
            list(self._train_datasets.keys())
            if self._train_datasets is not None
            else []
        )

    @property
    def val_dataset_names(self) -> List[str]:
        """
        Get the names of validation datasets.

        Returns:
            List[str]: A list of validation dataset names.
        """
        return list(self._val_datasets.keys()) if self._val_datasets is not None else []

    @property
    def test_dataset_names(self) -> List[str]:
        """
        Get the names of testing datasets.

        Returns:
            List[str]: A list of testing dataset names.
        """
        return (
            list(self._test_datasets.keys()) if self._test_datasets is not None else []
        )

    def __len__(self):
        return len(self.model_names)

    @staticmethod
    def is_special_model(model_name: str) -> bool:
        """
        Determine if a model is special based on its name.

        Args:
            model_name (str): The name of the model.

        Returns:
            bool: True if the model name indicates a special model, False otherwise.
        """
        return model_name.startswith("_") and model_name.endswith("_")

    def get_model_config(
        self, model_name: str, return_copy: bool = True
    ) -> Union[DictConfig, str, Any]:
        """
        Get the configuration for the specified model.

        Args:
            model_name (str): The name of the model.
            return_copy (bool): Whether to return a deep copy of the model configuration. Defaults to True.

        Returns:
            Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.

        Raises:
            ValidationError: If model_name is invalid.
            KeyError: If model_name is not found in the pool.
        """
        # Validate model name
        validate_model_name(model_name, allow_special=True)

        # raise friendly error if model not found in the pool
        if model_name not in self._models:
            available_models = list(self._models.keys())
            raise KeyError(
                f"Model '{model_name}' not found in model pool. "
                f"Available models: {available_models}"
            )

        model_config = self._models[model_name]
        if isinstance(model_config, nn.Module):
            log.warning(
                f"Model configuration for '{model_name}' is a pre-instantiated model. "
                "Returning the model instance instead of configuration."
            )

        if return_copy:
            if isinstance(model_config, nn.Module):
                # raise performance warning
                log.warning(
                    f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
                )
            model_config = deepcopy(model_config)
        return model_config

    def get_model_path(self, model_name: str) -> str:
        """
        Get the path for the specified model.

        Args:
            model_name (str): The name of the model.

        Returns:
            str: The path for the specified model.

        Raises:
            ValidationError: If model_name is invalid.
            KeyError: If model_name is not found in the pool.
            ValueError: If model configuration is not a string path.
        """
        # Validate model name
        validate_model_name(model_name, allow_special=True)

        if model_name not in self._models:
            available_models = list(self._models.keys())
            raise KeyError(
                f"Model '{model_name}' not found in model pool. "
                f"Available models: {available_models}"
            )

        if isinstance(self._models[model_name], str):
            return self._models[model_name]
        else:
            raise ValueError(
                f"Model configuration for '{model_name}' is not a string path. "
                "Try to override this method in derived modelpool class."
            )

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> nn.Module:
        """
        Load a model from the pool based on the provided configuration.

        Args:
            model_name_or_config (Union[str, DictConfig]): The model name or configuration.
                - If str: should be a key in self._models
                - If DictConfig: should be a configuration dict for instantiation
            *args: Additional positional arguments passed to model instantiation.
            **kwargs: Additional keyword arguments passed to model instantiation.

        Returns:
            nn.Module: The instantiated or retrieved model.
        """
        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)

        if isinstance(model_name_or_config, str):
            model_name = model_name_or_config
            # Handle string model names - lookup in the model pool
            if model_name not in self._models:
                raise KeyError(
                    f"Model '{model_name}' not found in model pool. "
                    f"Available models: {list(self._models.keys())}"
                )
            model_config = self._models[model_name]

            # Handle different types of model configurations
            match model_config:
                case dict() | DictConfig() as config:
                    # Configuration that needs instantiation
                    log.debug(f"Instantiating model '{model_name}' from configuration")
                    return instantiate(config, *args, **kwargs)

                case nn.Module() as model:
                    # Pre-instantiated model - return directly
                    log.debug(
                        f"Returning pre-instantiated model '{model_name}' of type {type(model)}"
                    )
                    return model

                case _:
                    # Unsupported model configuration type
                    raise ValueError(
                        f"Unsupported model configuration type for '{model_name}': {type(model_config)}. "
                        f"Expected nn.Module, dict, or DictConfig."
                    )

        elif isinstance(model_name_or_config, (dict, DictConfig)):
            # Direct configuration - instantiate directly
            log.debug("Instantiating model from direct DictConfig")
            model_config = model_name_or_config
            return instantiate(model_config, *args, **kwargs)

        else:
            # Unsupported input type
            raise TypeError(
                f"Unsupported input type: {type(model_name_or_config)}. "
                f"Expected str or DictConfig."
            )

    def load_pretrained_model(self, *args, **kwargs):
        """Load the pretrained model, i.e. the `_pretrained_` entry of the pool."""
        assert (
            self.has_pretrained
        ), "No pretrained model available. Ensure a `_pretrained_` entry exists under `models`."
        model = self.load_model("_pretrained_", *args, **kwargs)
        return model

    def load_pretrained_or_first_model(self, *args, **kwargs):
        """
        Load the pretrained model if available, otherwise load the first available model.

        Returns:
            nn.Module: The loaded model.
        """
        if self.has_pretrained:
            model = self.load_model("_pretrained_", *args, **kwargs)
        else:
            model = self.load_model(self.model_names[0], *args, **kwargs)
        return model

    def models(self) -> Generator[nn.Module, None, None]:
        for model_name in self.model_names:
            yield self.load_model(model_name)

    def named_models(self) -> Generator[Tuple[str, nn.Module], None, None]:
        for model_name in self.model_names:
            yield model_name, self.load_model(model_name)

    @property
    def has_train_dataset(self) -> bool:
        """
        Check if the model pool contains training datasets.

        Returns:
            bool: True if training datasets are available, False otherwise.
        """
        return self._train_datasets is not None and len(self._train_datasets) > 0

    @property
    def has_val_dataset(self) -> bool:
        """
        Check if the model pool contains validation datasets.

        Returns:
            bool: True if validation datasets are available, False otherwise.
        """
        return self._val_datasets is not None and len(self._val_datasets) > 0

    @property
    def has_test_dataset(self) -> bool:
        """
        Check if the model pool contains testing datasets.

        Returns:
            bool: True if testing datasets are available, False otherwise.
        """
        return self._test_datasets is not None and len(self._test_datasets) > 0

    def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
        """
        Load the training dataset with the specified name.

        Args:
            dataset_name (str): The name of the training dataset.

        Returns:
            Dataset: The instantiated training dataset.
        """
        return instantiate(self._train_datasets[dataset_name], *args, **kwargs)

    def train_datasets(self):
        for dataset_name in self.train_dataset_names:
            yield self.load_train_dataset(dataset_name)

    def load_val_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
        """
        Load the validation dataset with the specified name.

        Args:
            dataset_name (str): The name of the validation dataset.

        Returns:
            Dataset: The instantiated validation dataset.
        """
        return instantiate(self._val_datasets[dataset_name], *args, **kwargs)

    def val_datasets(self):
        for dataset_name in self.val_dataset_names:
            yield self.load_val_dataset(dataset_name)

    def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
        """
        Load the testing dataset with the specified name.

        Args:
            dataset_name (str): The name of the testing dataset.

        Returns:
            Dataset: The instantiated testing dataset.
        """
        return instantiate(self._test_datasets[dataset_name], *args, **kwargs)

    def test_datasets(self):
        for dataset_name in self.test_dataset_names:
            yield self.load_test_dataset(dataset_name)

    def save_model(self, model: nn.Module, path: str, *args, **kwargs):
        """
        Save the state dictionary of the model to the specified path.

        Args:
            model (nn.Module): The model whose state dictionary is to be saved.
            path (str): The path where the state dictionary will be saved.
        """
        with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)

    def __contains__(self, model_name: str) -> bool:
        """
        Check if a model with the given name exists in the model pool.

        Examples:
            >>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
            >>> "modelA" in modelpool
            True
            >>> "modelC" in modelpool
            False

        Args:
            model_name (str): The name of the model to check.

        Returns:
            bool: True if the model exists, False otherwise.
        """
        if self._models is None:
            raise RuntimeError("Model pool is not initialized")
        validate_model_name(model_name, allow_special=True)
        return model_name in self._models

all_model_names property

Get the names of all models in the pool, including special models.

Returns:

  • List[str]

    List[str]: A list of all model names.

has_pretrained property

Check if the model pool contains a pretrained model.

Returns:

  • bool ( bool ) –

    True if a pretrained model is available, False otherwise.

has_test_dataset property

Check if the model pool contains testing datasets.

Returns:

  • bool ( bool ) –

    True if testing datasets are available, False otherwise.

has_train_dataset property

Check if the model pool contains training datasets.

Returns:

  • bool ( bool ) –

    True if training datasets are available, False otherwise.

has_val_dataset property

Check if the model pool contains validation datasets.

Returns:

  • bool ( bool ) –

    True if validation datasets are available, False otherwise.

model_names property

Get the names of regular models, excluding special models.

Returns:

  • List[str]

    List[str]: A list of regular model names.

test_dataset_names property

Get the names of testing datasets.

Returns:

  • List[str]

    List[str]: A list of testing dataset names.

train_dataset_names property

Get the names of training datasets.

Returns:

  • List[str]

    List[str]: A list of training dataset names.

val_dataset_names property

Get the names of validation datasets.

Returns:

  • List[str]

    List[str]: A list of validation dataset names.

__contains__(model_name)

Check if a model with the given name exists in the model pool.

Examples:

>>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
>>> "modelA" in modelpool
True
>>> "modelC" in modelpool
False

Parameters:

  • model_name (str) –

    The name of the model to check.

Returns:

  • bool ( bool ) –

    True if the model exists, False otherwise.

Source code in fusion_bench/modelpool/base_pool.py
def __contains__(self, model_name: str) -> bool:
    """
    Check if a model with the given name exists in the model pool.

    Examples:
        >>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
        >>> "modelA" in modelpool
        True
        >>> "modelC" in modelpool
        False

    Args:
        model_name (str): The name of the model to check.

    Returns:
        bool: True if the model exists, False otherwise.
    """
    if self._models is None:
        raise RuntimeError("Model pool is not initialized")
    validate_model_name(model_name, allow_special=True)
    return model_name in self._models

get_model_config(model_name, return_copy=True)

Get the configuration for the specified model.

Parameters:

  • model_name (str) –

    The name of the model.

  • return_copy (bool) –

    Whether to return a deep copy of the model configuration. Defaults to True.

Returns:

  • Union[DictConfig, str, Any]

    Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.

Raises:

  • ValidationError

    If model_name is invalid.

  • KeyError

    If model_name is not found in the pool.
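
By default the configuration is returned as a deep copy, so callers can mutate it without affecting the pool. A brief sketch (the `torch.nn.Linear` entry is only illustrative; the import path is assumed):

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool import BaseModelPool  # import path assumed

pool = BaseModelPool(
    models=OmegaConf.create(
        {"modelA": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 8}}
    )
)

cfg = pool.get_model_config("modelA")  # deep copy by default
cfg.in_features = 16                   # does not affect the pool's stored config
assert pool.get_model_config("modelA").in_features == 8
```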

Source code in fusion_bench/modelpool/base_pool.py
def get_model_config(
    self, model_name: str, return_copy: bool = True
) -> Union[DictConfig, str, Any]:
    """
    Get the configuration for the specified model.

    Args:
        model_name (str): The name of the model.
        return_copy (bool): Whether to return a deep copy of the model configuration. Defaults to True.

    Returns:
        Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.

    Raises:
        ValidationError: If model_name is invalid.
        KeyError: If model_name is not found in the pool.
    """
    # Validate model name
    validate_model_name(model_name, allow_special=True)

    # raise friendly error if model not found in the pool
    if model_name not in self._models:
        available_models = list(self._models.keys())
        raise KeyError(
            f"Model '{model_name}' not found in model pool. "
            f"Available models: {available_models}"
        )

    model_config = self._models[model_name]
    if isinstance(model_config, nn.Module):
        log.warning(
            f"Model configuration for '{model_name}' is a pre-instantiated model. "
            "Returning the model instance instead of configuration."
        )

    if return_copy:
        if isinstance(model_config, nn.Module):
            # raise performance warning
            log.warning(
                f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
            )
        model_config = deepcopy(model_config)
    return model_config

get_model_path(model_name)

Get the path for the specified model.

Parameters:

  • model_name (str) –

    The name of the model.

Returns:

  • str ( str ) –

    The path for the specified model.

Raises:

  • ValidationError

    If model_name is invalid.

  • KeyError

    If model_name is not found in the pool.

  • ValueError

    If model configuration is not a string path.

Source code in fusion_bench/modelpool/base_pool.py
def get_model_path(self, model_name: str) -> str:
    """
    Get the path for the specified model.

    Args:
        model_name (str): The name of the model.

    Returns:
        str: The path for the specified model.

    Raises:
        ValidationError: If model_name is invalid.
        KeyError: If model_name is not found in the pool.
        ValueError: If model configuration is not a string path.
    """
    # Validate model name
    validate_model_name(model_name, allow_special=True)

    if model_name not in self._models:
        available_models = list(self._models.keys())
        raise KeyError(
            f"Model '{model_name}' not found in model pool. "
            f"Available models: {available_models}"
        )

    if isinstance(self._models[model_name], str):
        return self._models[model_name]
    else:
        raise ValueError(
            f"Model configuration for '{model_name}' is not a string path. "
            "Try to override this method in derived modelpool class."
        )

is_special_model(model_name) staticmethod

Determine if a model is special based on its name.

Parameters:

  • model_name (str) –

    The name of the model.

Returns:

  • bool ( bool ) –

    True if the model name indicates a special model, False otherwise.
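
For example, names wrapped in a leading and trailing underscore, such as `_pretrained_`, are treated as special (import path assumed):

```python
from fusion_bench.modelpool import BaseModelPool  # import path assumed

assert BaseModelPool.is_special_model("_pretrained_") is True
assert BaseModelPool.is_special_model("modelA") is False
```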

Source code in fusion_bench/modelpool/base_pool.py
@staticmethod
def is_special_model(model_name: str) -> bool:
    """
    Determine if a model is special based on its name.

    Args:
        model_name (str): The name of the model.

    Returns:
        bool: True if the model name indicates a special model, False otherwise.
    """
    return model_name.startswith("_") and model_name.endswith("_")

load_model(model_name_or_config, *args, **kwargs)

Load a model from the pool based on the provided configuration.

Parameters:

  • model_name_or_config (Union[str, DictConfig]) –

    The model name or configuration.

    - If str: should be a key in self._models
    - If DictConfig: should be a configuration dict for instantiation

  • *args

    Additional positional arguments passed to model instantiation.

  • **kwargs

    Additional keyword arguments passed to model instantiation.

Returns:

  • Module

    nn.Module: The instantiated or retrieved model.
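
Both call styles in a brief sketch (the `torch.nn.Linear` targets and the import path are assumptions for illustration):

```python
from omegaconf import DictConfig

from fusion_bench.modelpool import BaseModelPool  # import path assumed

pool = BaseModelPool(
    models={"modelA": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 8}}
)

# 1) By name: looks up the entry in the pool and instantiates it.
model_a = pool.load_model("modelA")

# 2) By configuration: bypasses the pool and instantiates the config directly.
model_b = pool.load_model(
    DictConfig({"_target_": "torch.nn.Linear", "in_features": 4, "out_features": 4})
)
```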

Source code in fusion_bench/modelpool/base_pool.py
def load_model(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> nn.Module:
    """
    Load a model from the pool based on the provided configuration.

    Args:
        model_name_or_config (Union[str, DictConfig]): The model name or configuration.
            - If str: should be a key in self._models
            - If DictConfig: should be a configuration dict for instantiation
        *args: Additional positional arguments passed to model instantiation.
        **kwargs: Additional keyword arguments passed to model instantiation.

    Returns:
        nn.Module: The instantiated or retrieved model.
    """
    log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)

    if isinstance(model_name_or_config, str):
        model_name = model_name_or_config
        # Handle string model names - lookup in the model pool
        if model_name not in self._models:
            raise KeyError(
                f"Model '{model_name}' not found in model pool. "
                f"Available models: {list(self._models.keys())}"
            )
        model_config = self._models[model_name]

        # Handle different types of model configurations
        match model_config:
            case dict() | DictConfig() as config:
                # Configuration that needs instantiation
                log.debug(f"Instantiating model '{model_name}' from configuration")
                return instantiate(config, *args, **kwargs)

            case nn.Module() as model:
                # Pre-instantiated model - return directly
                log.debug(
                    f"Returning pre-instantiated model '{model_name}' of type {type(model)}"
                )
                return model

            case _:
                # Unsupported model configuration type
                raise ValueError(
                    f"Unsupported model configuration type for '{model_name}': {type(model_config)}. "
                    f"Expected nn.Module, dict, or DictConfig."
                )

    elif isinstance(model_name_or_config, (dict, DictConfig)):
        # Direct configuration - instantiate directly
        log.debug("Instantiating model from direct DictConfig")
        model_config = model_name_or_config
        return instantiate(model_config, *args, **kwargs)

    else:
        # Unsupported input type
        raise TypeError(
            f"Unsupported input type: {type(model_name_or_config)}. "
            f"Expected str or DictConfig."
        )

load_pretrained_or_first_model(*args, **kwargs)

Load the pretrained model if available, otherwise load the first available model.

Returns:

  • nn.Module: The loaded model.

Source code in fusion_bench/modelpool/base_pool.py
def load_pretrained_or_first_model(self, *args, **kwargs):
    """
    Load the pretrained model if available, otherwise load the first available model.

    Returns:
        nn.Module: The loaded model.
    """
    if self.has_pretrained:
        model = self.load_model("_pretrained_", *args, **kwargs)
    else:
        model = self.load_model(self.model_names[0], *args, **kwargs)
    return model

load_test_dataset(dataset_name, *args, **kwargs)

Load the testing dataset with the specified name.

Parameters:

  • dataset_name (str) –

    The name of the testing dataset.

Returns:

  • Dataset ( Dataset ) –

    The instantiated testing dataset.

Source code in fusion_bench/modelpool/base_pool.py
def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
    """
    Load the testing dataset with the specified name.

    Args:
        dataset_name (str): The name of the testing dataset.

    Returns:
        Dataset: The instantiated testing dataset.
    """
    return instantiate(self._test_datasets[dataset_name], *args, **kwargs)

load_train_dataset(dataset_name, *args, **kwargs)

Load the training dataset with the specified name.

Parameters:

  • dataset_name (str) –

    The name of the training dataset.

Returns:

  • Dataset ( Dataset ) –

    The instantiated training dataset.
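
Dataset entries are Hydra-style configurations that are instantiated on demand. A minimal sketch (the `torchvision.datasets.CIFAR10` target and the import path are illustrative assumptions):

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool import BaseModelPool  # import path assumed

pool = BaseModelPool(
    models={"modelA": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 8}},
    train_datasets=OmegaConf.create(
        {
            "modelA": {  # illustrative dataset config; any Hydra-instantiable dataset works
                "_target_": "torchvision.datasets.CIFAR10",
                "root": "./data",
                "train": True,
                "download": True,
            }
        }
    ),
)

train_set = pool.load_train_dataset("modelA")
```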

Source code in fusion_bench/modelpool/base_pool.py
def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
    """
    Load the training dataset with the specified name.

    Args:
        dataset_name (str): The name of the training dataset.

    Returns:
        Dataset: The instantiated training dataset.
    """
    return instantiate(self._train_datasets[dataset_name], *args, **kwargs)

load_val_dataset(dataset_name, *args, **kwargs)

Load the validation dataset with the specified name.

Parameters:

  • dataset_name (str) –

    The name of the validation dataset.

Returns:

  • Dataset ( Dataset ) –

    The instantiated validation dataset.

Source code in fusion_bench/modelpool/base_pool.py
def load_val_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
    """
    Load the validation dataset with the specified name.

    Args:
        dataset_name (str): The name of the validation dataset.

    Returns:
        Dataset: The instantiated validation dataset.
    """
    return instantiate(self._val_datasets[dataset_name], *args, **kwargs)

save_model(model, path, *args, **kwargs)

Save the state dictionary of the model to the specified path.

Parameters:

  • model (Module) –

    The model whose state dictionary is to be saved.

  • path (str) –

    The path where the state dictionary will be saved.

Source code in fusion_bench/modelpool/base_pool.py
def save_model(self, model: nn.Module, path: str, *args, **kwargs):
    """
    Save the state dictionary of the model to the specified path.

    Args:
        model (nn.Module): The model whose state dictionary is to be saved.
        path (str): The path where the state dictionary will be saved.
    """
    with timeit_context(f"Saving the state dict of model to {path}"):
        torch.save(model.state_dict(), path)

Vision Model Pool

NYUv2 Tasks (ResNet)

NYUv2ModelPool

Bases: ModelPool

Source code in fusion_bench/modelpool/nyuv2_modelpool.py
class NYUv2ModelPool(ModelPool):
    """A model pool for NYUv2 multi-task models: a dilated ResNet encoder with per-task DeepLab heads."""

    def load_model(
        self, model_config: str | DictConfig, encoder_only: bool = True
    ) -> ResnetDilated | NYUv2Model:
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)

        encoder = resnet_dilated(model_config.encoder)
        decoders = nn.ModuleDict(
            {
                task: DeepLabHead(2048, NYUv2.num_out_channels[task])
                for task in model_config.decoders
            }
        )
        model = NYUv2Model(encoder=encoder, decoders=decoders)
        if model_config.get("ckpt_path", None) is not None:
            ckpt = torch.load(model_config.ckpt_path, map_location="cpu")
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            model.load_state_dict(ckpt, strict=False)

        if encoder_only:
            return model.encoder
        else:
            return model

CLIP Vision Encoder

CLIPVisionModelPool

Bases: BaseModelPool

A model pool for managing Hugging Face's CLIP Vision models.

This class extends the base ModelPool class and overrides its methods to handle the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
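
A minimal usage sketch, assuming the class is importable from `fusion_bench.modelpool`; the repository IDs follow the examples used elsewhere on this page:

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool import CLIPVisionModelPool  # import path assumed

pool = CLIPVisionModelPool(
    models=OmegaConf.create(
        {
            "_pretrained_": "openai/clip-vit-base-patch32",
            "sun397": "tanganke/clip-vit-base-patch32_sun397",
        }
    ),
    processor="openai/clip-vit-base-patch32",
)

vision_model = pool.load_model("sun397")  # transformers.CLIPVisionModel
processor = pool.load_processor()         # transformers.CLIPProcessor
```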

Source code in fusion_bench/modelpool/clip_vision/modelpool.py
class CLIPVisionModelPool(BaseModelPool):
    """
    A model pool for managing Hugging Face's CLIP Vision models.

    This class extends the base `ModelPool` class and overrides its methods to handle
    the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
    """

    _config_mapping = BaseModelPool._config_mapping | {
        "_processor": "processor",
        "_platform": "hf",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        processor: Optional[DictConfig] = None,
        platform: Literal["hf", "huggingface", "modelscope"] = "hf",
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._processor = processor
        self._platform = platform

    def load_processor(self, *args, **kwargs) -> CLIPProcessor:
        assert self._processor is not None, "Processor is not defined in the config"
        if isinstance(self._processor, str):
            if rank_zero_only.rank == 0:
                log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
            repo_path = resolve_repo_path(
                repo_id=self._processor, repo_type="model", platform=self._platform
            )
            processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
        else:
            processor = instantiate(self._processor, *args, **kwargs)
        return processor

    def load_clip_model(self, model_name: str, *args, **kwargs) -> CLIPModel:
        model_config = self._models[model_name]

        if isinstance(model_config, str):
            if rank_zero_only.rank == 0:
                log.info(f"Loading `transformers.CLIPModel`: {model_config}")
            repo_path = resolve_repo_path(
                repo_id=model_config, repo_type="model", platform=self._platform
            )
            clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
            return clip_model
        else:
            assert isinstance(
                model_config, DictConfig
            ), "Model config must be a DictConfig"
            model_config = deepcopy(model_config)
            with open_dict(model_config):
                model_config._target_ = "transformers.CLIPModel.from_pretrained"
            clip_model = instantiate(model_config, *args, **kwargs)
            return clip_model

    @override
    def save_model(self, model: CLIPVisionModel, path: str):
        """
        Save a CLIP Vision model to the given path.

        Args:
            model (CLIPVisionModel): The model to save.
            path (str): The path to save the model to.
        """
        with timeit_context(f'Saving clip vision model to "{path}"'):
            model.save_pretrained(path)

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> CLIPVisionModel:
        """
        Load a CLIPVisionModel from the model pool with support for various configuration formats.

        This method provides flexible model loading capabilities, handling different types of model
        configurations including string paths, pre-instantiated models, and complex configurations.

        Supported configuration formats:
        1. String model paths (e.g., Hugging Face model IDs)
        2. Pre-instantiated nn.Module objects
        3. DictConfig objects for complex configurations

        Example configuration:
        ```yaml
        models:
            # Simple string paths to Hugging Face models
            cifar10: tanganke/clip-vit-base-patch32_cifar10
            sun397: tanganke/clip-vit-base-patch32_sun397
            stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars

            # Complex configuration with additional parameters
            custom_model:
                _target_: transformers.CLIPVisionModel.from_pretrained
                pretrained_model_name_or_path: openai/clip-vit-base-patch32
                torch_dtype: float16
        ```

        Args:
            model_name_or_config (Union[str, DictConfig]): Either a model name from the pool
                or a configuration dictionary for instantiating the model.
            *args: Additional positional arguments passed to model loading/instantiation.
            **kwargs: Additional keyword arguments passed to model loading/instantiation.

        Returns:
            CLIPVisionModel: The loaded CLIPVisionModel instance.
        """
        # Check if we have a string model name that exists in our model pool
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_name = model_name_or_config

            # handle different model configuration types
            match self._models[model_name_or_config]:
                case str() as model_path:
                    # Handle string model paths (e.g., Hugging Face model IDs)
                    if rank_zero_only.rank == 0:
                        log.info(
                            f"Loading model `{model_name}` of type `transformers.CLIPVisionModel` from {model_path}"
                        )
                    # Resolve the repository path (supports both HuggingFace and ModelScope)
                    repo_path = resolve_repo_path(
                        model_path, repo_type="model", platform=self._platform
                    )
                    # Load and return the CLIPVisionModel from the resolved path
                    return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)

                case nn.Module() as model:
                    # Handle pre-instantiated model objects
                    if rank_zero_only.rank == 0:
                        log.info(
                            f"Returning existing model `{model_name}` of type {type(model)}"
                        )
                    return model

                case _:
                    # Handle other configuration types (e.g., DictConfig) via parent class
                    # This fallback prevents returning None when the model config doesn't
                    # match the expected string or nn.Module patterns
                    return super().load_model(model_name_or_config, *args, **kwargs)

        # If model_name_or_config is not a string in our pool, delegate to parent class
        # This handles cases where model_name_or_config is a DictConfig directly
        return super().load_model(model_name_or_config, *args, **kwargs)

    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._train_datasets[dataset_name]
        if isinstance(dataset_config, str):
            if rank_zero_only.rank == 0:
                log.info(
                    f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
                )
            dataset = self._load_dataset(dataset_config, split="train")
        else:
            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._val_datasets[dataset_name]
        if isinstance(dataset_config, str):
            if rank_zero_only.rank == 0:
                log.info(
                    f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
                )
            dataset = self._load_dataset(dataset_config, split="validation")
        else:
            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._test_datasets[dataset_name]
        if isinstance(dataset_config, str):
            if rank_zero_only.rank == 0:
                log.info(
                    f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
                )
            dataset = self._load_dataset(dataset_config, split="test")
        else:
            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
        return dataset

    def _load_dataset(self, name: str, split: str):
        """
        Load a dataset by its name and split.

        Args:
            name (str): The name or repository ID of the dataset.
            split (str): The split of the dataset to load (e.g., "train", "validation", "test").

        Returns:
            Dataset: The loaded dataset.
        """
        dataset_dir = resolve_repo_path(
            name, repo_type="dataset", platform=self._platform
        )
        dataset = load_dataset(dataset_dir, split=split)
        return dataset
load_model(model_name_or_config, *args, **kwargs)

Load a CLIPVisionModel from the model pool with support for various configuration formats.

This method provides flexible model loading capabilities, handling different types of model configurations including string paths, pre-instantiated models, and complex configurations.

Supported configuration formats:

  1. String model paths (e.g., Hugging Face model IDs)
  2. Pre-instantiated nn.Module objects
  3. DictConfig objects for complex configurations

Example configuration:

models:
    # Simple string paths to Hugging Face models
    cifar10: tanganke/clip-vit-base-patch32_cifar10
    sun397: tanganke/clip-vit-base-patch32_sun397
    stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars

    # Complex configuration with additional parameters
    custom_model:
        _target_: transformers.CLIPVisionModel.from_pretrained
        pretrained_model_name_or_path: openai/clip-vit-base-patch32
        torch_dtype: float16

Parameters:

  • model_name_or_config (Union[str, DictConfig]) –

    Either a model name from the pool or a configuration dictionary for instantiating the model.

  • *args

    Additional positional arguments passed to model loading/instantiation.

  • **kwargs

    Additional keyword arguments passed to model loading/instantiation.

Returns:

  • CLIPVisionModel ( CLIPVisionModel ) –

    The loaded CLIPVisionModel instance.
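
A brief sketch of loading both kinds of entries (the repository IDs mirror the example configuration above; the import path is assumed):

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool import CLIPVisionModelPool  # import path assumed

pool = CLIPVisionModelPool(
    models=OmegaConf.create(
        {
            "cifar10": "tanganke/clip-vit-base-patch32_cifar10",
            "custom_model": {
                "_target_": "transformers.CLIPVisionModel.from_pretrained",
                "pretrained_model_name_or_path": "openai/clip-vit-base-patch32",
            },
        }
    )
)

model = pool.load_model("cifar10")        # loaded from the string repository path
custom = pool.load_model("custom_model")  # instantiated from the Hydra-style config
```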

Source code in fusion_bench/modelpool/clip_vision/modelpool.py
def load_model(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> CLIPVisionModel:
    """
    Load a CLIPVisionModel from the model pool with support for various configuration formats.

    This method provides flexible model loading capabilities, handling different types of model
    configurations including string paths, pre-instantiated models, and complex configurations.

    Supported configuration formats:
    1. String model paths (e.g., Hugging Face model IDs)
    2. Pre-instantiated nn.Module objects
    3. DictConfig objects for complex configurations

    Example configuration:
    ```yaml
    models:
        # Simple string paths to Hugging Face models
        cifar10: tanganke/clip-vit-base-patch32_cifar10
        sun397: tanganke/clip-vit-base-patch32_sun397
        stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars

        # Complex configuration with additional parameters
        custom_model:
            _target_: transformers.CLIPVisionModel.from_pretrained
            pretrained_model_name_or_path: openai/clip-vit-base-patch32
            torch_dtype: float16
    ```

    Args:
        model_name_or_config (Union[str, DictConfig]): Either a model name from the pool
            or a configuration dictionary for instantiating the model.
        *args: Additional positional arguments passed to model loading/instantiation.
        **kwargs: Additional keyword arguments passed to model loading/instantiation.

    Returns:
        CLIPVisionModel: The loaded CLIPVisionModel instance.
    """
    # Check if we have a string model name that exists in our model pool
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model_name = model_name_or_config

        # handle different model configuration types
        match self._models[model_name_or_config]:
            case str() as model_path:
                # Handle string model paths (e.g., Hugging Face model IDs)
                if rank_zero_only.rank == 0:
                    log.info(
                        f"Loading model `{model_name}` of type `transformers.CLIPVisionModel` from {model_path}"
                    )
                # Resolve the repository path (supports both HuggingFace and ModelScope)
                repo_path = resolve_repo_path(
                    model_path, repo_type="model", platform=self._platform
                )
                # Load and return the CLIPVisionModel from the resolved path
                return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)

            case nn.Module() as model:
                # Handle pre-instantiated model objects
                if rank_zero_only.rank == 0:
                    log.info(
                        f"Returning existing model `{model_name}` of type {type(model)}"
                    )
                return model

            case _:
                # Handle other configuration types (e.g., DictConfig) via parent class
                # This fallback prevents returning None when the model config doesn't
                # match the expected string or nn.Module patterns
                return super().load_model(model_name_or_config, *args, **kwargs)

    # If model_name_or_config is not a string in our pool, delegate to parent class
    # This handles cases where model_name_or_config is a DictConfig directly
    return super().load_model(model_name_or_config, *args, **kwargs)
save_model(model, path)

Save a CLIP Vision model to the given path.

Parameters:

  • model (CLIPVisionModel) –

    The model to save.

  • path (str) –

    The path to save the model to.

Source code in fusion_bench/modelpool/clip_vision/modelpool.py
@override
def save_model(self, model: CLIPVisionModel, path: str):
    """
    Save a CLIP Vision model to the given path.

    Args:
        model (CLIPVisionModel): The model to save.
        path (str): The path to save the model to.
    """
    with timeit_context(f'Saving clip vision model to "{path}"'):
        model.save_pretrained(path)

OpenCLIP Vision Encoder

OpenCLIPVisionModelPool

Bases: BaseModelPool

A model pool for managing OpenCLIP Vision models (models from the task vector paper).
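
A minimal configuration sketch; the `ViT-B-32` model name and the checkpoint paths are placeholders, and the import path is assumed:

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool import OpenCLIPVisionModelPool  # import path assumed

pool = OpenCLIPVisionModelPool(
    models=OmegaConf.create(
        {
            # placeholder checkpoint paths
            "_pretrained_": {"model_name": "ViT-B-32", "state_dict_path": "checkpoints/zeroshot.pt"},
            "svhn": {"model_name": "ViT-B-32", "state_dict_path": "checkpoints/svhn.pt"},
        }
    )
)

encoder = pool.load_model("svhn")  # ImageEncoder with the fine-tuned weights
preprocess = pool.test_processor   # validation-time preprocessing transform
```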

Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
class OpenCLIPVisionModelPool(BaseModelPool):
    """
    A model pool for managing OpenCLIP Vision models (models from the task vector paper).
    """

    _train_processor = None
    _test_processor = None

    def __init__(
        self,
        models: DictConfig,
        classification_heads: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._classification_heads = classification_heads

    @property
    def train_processor(self):
        if self._train_processor is None:
            encoder: ImageEncoder = self.load_pretrained_or_first_model()
            self._train_processor = encoder.train_preprocess
            if self._test_processor is None:
                self._test_processor = encoder.val_preprocess
        return self._train_processor

    @property
    def test_processor(self):
        if self._test_processor is None:
            encoder: ImageEncoder = self.load_pretrained_or_first_model()
            if self._train_processor is None:
                self._train_processor = encoder.train_preprocess
            self._test_processor = encoder.val_preprocess
        return self._test_processor

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> ImageEncoder:
        R"""
        The model config can be:

        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
        - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
        - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
        - Default, load the model using `instantiate` from hydra.
        """
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_config = self._models[model_name_or_config]
        else:
            model_config = model_name_or_config
        if isinstance(model_config, DictConfig):
            model_config = OmegaConf.to_container(model_config, resolve=True)

        if isinstance(model_config, str):
            # the model config is a string, which is the path to the model checkpoint in pickle format
            # load the model using `torch.load`
            # this is the original usage in the task arithmetic codebase
            _check_and_redirect_open_clip_modeling()
            log.info(f"loading ImageEncoder from {model_config}")
            # pop so that `weights_only` is not passed twice via **kwargs below
            weights_only = kwargs.pop("weights_only", False)
            try:
                encoder = torch.load(
                    model_config, weights_only=weights_only, *args, **kwargs
                )
            except RuntimeError as e:
                encoder = pickle.load(open(model_config, "rb"))
        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
            # the model config is a dictionary with the following keys:
            # - model_name: str, the name of the model
            # - pickle_path: str, the path to the binary file (pickle format)
            # load the model from the binary file (pickle format)
            # this is useful when you use a newer version of torchvision
            _check_and_redirect_open_clip_modeling()
            log.info(
                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
            )
            # pop so that `weights_only` is not passed twice via **kwargs below
            weights_only = kwargs.pop("weights_only", False)
            try:
                encoder = torch.load(
                    model_config["pickle_path"],
                    weights_only=weights_only,
                    *args,
                    **kwargs,
                )
            except RuntimeError as e:
                encoder = pickle.load(open(model_config["pickle_path"], "rb"))
            _encoder = ImageEncoder(model_config["model_name"])
            _encoder.load_state_dict(encoder.state_dict())
            encoder = _encoder
        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
            # the model config is a dictionary with the following keys:
            # - model_name: str, the name of the model
            # - state_dict_path: str, the path to the state dict file
            # load the model from the state dict file
            log.info(
                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
            )
            encoder = ImageEncoder(model_config["model_name"])
            encoder.load_state_dict(
                torch.load(
                    model_config["state_dict_path"], weights_only=True, *args, **kwargs
                )
            )
        elif isinstance(model_config, nn.Module):
            # the model config is an existing model
            log.info(f"Returning existing model: {model_config}")
            encoder = model_config
        else:
            encoder = super().load_model(model_name_or_config, *args, **kwargs)
        encoder = cast(ImageEncoder, encoder)

        # setup the train and test processors
        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
            self._train_processor = encoder.train_preprocess
        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
            self._test_processor = encoder.val_preprocess

        return encoder

    def load_classification_head(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> ClassificationHead:
        R"""
        The model config can be:

        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
        - Default, load the model using `instantiate` from hydra.
        """
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._classification_heads
        ):
            model_config = self._classification_heads[model_name_or_config]
        else:
            model_config = model_name_or_config

        head = load_classifier_head(model_config, *args, **kwargs)
        return head

    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._train_datasets[dataset_name]
        if isinstance(dataset_config, str):
            log.info(
                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
            )
            dataset = load_dataset(dataset_config, split="train")
        else:
            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._val_datasets[dataset_name]
        if isinstance(dataset_config, str):
            log.info(
                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
            )
            dataset = load_dataset(dataset_config, split="validation")
        else:
            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._test_datasets[dataset_name]
        if isinstance(dataset_config, str):
            log.info(
                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
            )
            dataset = load_dataset(dataset_config, split="test")
        else:
            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
        return dataset
load_classification_head(model_name_or_config, *args, **kwargs)

The model config can be:

  • A string, which is the path to the model checkpoint in pickle format. Load directly using torch.load.
  • Default, load the model using instantiate from hydra.
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
def load_classification_head(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> ClassificationHead:
    R"""
    The model config can be:

    - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
    - Default, load the model using `instantiate` from hydra.
    """
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._classification_heads
    ):
        model_config = self._classification_heads[model_name_or_config]
    else:
        model_config = model_name_or_config

    head = load_classifier_head(model_config, *args, **kwargs)
    return head
load_model(model_name_or_config, *args, **kwargs)

The model config can be:

  • A string, which is the path to the model checkpoint in pickle format. Load directly using torch.load.
  • {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using ImageEncoder(model_name), and then load the state dict from model located in the pickle file.
  • {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using ImageEncoder(model_name), and then load the state dict from the file.
  • Default, load the model using instantiate from hydra.
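
The same forms can be passed directly to `load_model` as plain values. A brief sketch, assuming `pool` is an `OpenCLIPVisionModelPool` such as the one constructed earlier in this section; all paths are placeholders:

```python
# 1) Path to a pickled ImageEncoder checkpoint, loaded with `torch.load` (placeholder path).
encoder = pool.load_model("checkpoints/ViT-B-32/svhn/finetuned.pt")

# 2) Rebuild the encoder from its name, then copy weights from a pickled checkpoint.
encoder = pool.load_model({"model_name": "ViT-B-32", "pickle_path": "checkpoints/finetuned.pt"})

# 3) Rebuild the encoder from its name, then load a plain state dict.
encoder = pool.load_model({"model_name": "ViT-B-32", "state_dict_path": "checkpoints/state_dict.pt"})
```
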
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
def load_model(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> ImageEncoder:
    R"""
    The model config can be:

    - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
    - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
    - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
    - Otherwise, by default, load the model using `instantiate` from hydra.
    """
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model_config = self._models[model_name_or_config]
    else:
        model_config = model_name_or_config
    if isinstance(model_config, DictConfig):
        model_config = OmegaConf.to_container(model_config, resolve=True)

    if isinstance(model_config, str):
        # the model config is a string, which is the path to the model checkpoint in pickle format
        # load the model using `torch.load`
        # this is the original usage in the task arithmetic codebase
        _check_and_redirect_open_clip_modeling()
        log.info(f"loading ImageEncoder from {model_config}")
        weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
        try:
            encoder = torch.load(
                model_config, weights_only=weights_only, *args, **kwargs
            )
        except RuntimeError as e:
            encoder = pickle.load(open(model_config, "rb"))
    elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
        # the model config is a dictionary with the following keys:
        # - model_name: str, the name of the model
        # - pickle_path: str, the path to the binary file (pickle format)
        # load the model from the binary file (pickle format)
        # this is useful when you use a newer version of torchvision
        _check_and_redirect_open_clip_modeling()
        log.info(
            f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
        )
        weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
        try:
            encoder = torch.load(
                model_config["pickle_path"],
                weights_only=weights_only,
                *args,
                **kwargs,
            )
        except RuntimeError as e:
            encoder = pickle.load(open(model_config["pickle_path"], "rb"))
        _encoder = ImageEncoder(model_config["model_name"])
        _encoder.load_state_dict(encoder.state_dict())
        encoder = _encoder
    elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
        # the model config is a dictionary with the following keys:
        # - model_name: str, the name of the model
        # - state_dict_path: str, the path to the state dict file
        # load the model from the state dict file
        log.info(
            f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
        )
        encoder = ImageEncoder(model_config["model_name"])
        encoder.load_state_dict(
            torch.load(
                model_config["state_dict_path"], weights_only=True, *args, **kwargs
            )
        )
    elif isinstance(model_config, nn.Module):
        # the model config is an existing model
        log.info(f"Returning existing model: {model_config}")
        encoder = model_config
    else:
        encoder = super().load_model(model_name_or_config, *args, **kwargs)
    encoder = cast(ImageEncoder, encoder)

    # setup the train and test processors
    if self._train_processor is None and hasattr(encoder, "train_preprocess"):
        self._train_processor = encoder.train_preprocess
    if self._test_processor is None and hasattr(encoder, "val_preprocess"):
        self._test_processor = encoder.val_preprocess

    return encoder

ResNet for Image Classification

ResNet Model Pool for Image Classification.

This module provides a flexible model pool implementation for ResNet models used in image classification tasks. It supports both torchvision and transformers implementations of ResNet architectures with configurable preprocessing, loading, and saving capabilities.

Example Usage

Create a pool with a torchvision ResNet model:

>>> # Torchvision ResNet pool
>>> pool = ResNetForImageClassificationPool(
...     type="torchvision",
...     models={"resnet18_cifar10": {"model_name": "resnet18", "dataset_name": "cifar10"}}
... )
>>> model = pool.load_model("resnet18_cifar10")
>>> processor = pool.load_processor(stage="train")

Create a pool with a transformers ResNet model:

>>> # Transformers ResNet pool
>>> pool = ResNetForImageClassificationPool(
...     type="transformers",
...     models={"resnet_model": {"config_path": "microsoft/resnet-50", "pretrained": True}}
... )

ResNetForImageClassificationPool

Bases: BaseModelPool

Model pool for ResNet-based image classification models.

This class provides a unified interface for managing ResNet models from different sources (torchvision and transformers) with automatic preprocessing, loading, and saving capabilities. It supports multiple ResNet architectures and can automatically adapt models to different datasets by adjusting the number of output classes.

The pool supports two main types:

  • "torchvision": Uses torchvision's ResNet implementations with standard ImageNet preprocessing
  • "transformers": Uses Hugging Face transformers' ResNetForImageClassification with auto processors

Parameters:

  • type (str) –

    Model source type, must be either "torchvision" or "transformers".

  • **kwargs

    Additional arguments passed to the base BaseModelPool class.

Attributes:

  • type (str) –

    The model source type specified during initialization.

Raises:

  • AssertionError

    If type is not "torchvision" or "transformers".

Example

Create a pool with a torchvision ResNet model:

>>> # Torchvision-based pool
>>> pool = ResNetForImageClassificationPool(
...     type="torchvision",
...     models={
...         "resnet18_cifar10": {
...             "model_name": "resnet18",
...             "weights": "DEFAULT",
...             "dataset_name": "cifar10"
...         }
...     }
... )
Create a pool with a transformers ResNet model:

>>> # Transformers-based pool
>>> pool = ResNetForImageClassificationPool(
...     type="transformers",
...     models={
...         "resnet_model": {
...             "config_path": "microsoft/resnet-50",
...             "pretrained": True,
...             "dataset_name": "imagenet"
...         }
...     }
... )

Source code in fusion_bench/modelpool/resnet_for_image_classification.py
@auto_register_config
class ResNetForImageClassificationPool(BaseModelPool):
    """Model pool for ResNet-based image classification models.

    This class provides a unified interface for managing ResNet models from different sources
    (torchvision and transformers) with automatic preprocessing, loading, and saving capabilities.
    It supports multiple ResNet architectures and can automatically adapt models to different
    datasets by adjusting the number of output classes.

    The pool supports two main types:
    - "torchvision": Uses torchvision's ResNet implementations with standard ImageNet preprocessing
    - "transformers": Uses Hugging Face transformers' ResNetForImageClassification with auto processors

    Args:
        type (str): Model source type, must be either "torchvision" or "transformers".
        **kwargs: Additional arguments passed to the base BaseModelPool class.

    Attributes:
        type (str): The model source type specified during initialization.

    Raises:
        AssertionError: If type is not "torchvision" or "transformers".

    Example:
        Create a pool with a torchvision ResNet model:

        ```python
        >>> # Torchvision-based pool
        >>> pool = ResNetForImageClassificationPool(
        ...     type="torchvision",
        ...     models={
        ...         "resnet18_cifar10": {
        ...             "model_name": "resnet18",
        ...             "weights": "DEFAULT",
        ...             "dataset_name": "cifar10"
        ...         }
        ...     }
        ... )
        ```

        Create a pool with a transformers ResNet model:

        ```python
        >>> # Transformers-based pool
        >>> pool = ResNetForImageClassificationPool(
        ...     type="transformers",
        ...     models={
        ...         "resnet_model": {
        ...             "config_path": "microsoft/resnet-50",
        ...             "pretrained": True,
        ...             "dataset_name": "imagenet"
        ...         }
        ...     }
        ... )
        ```
    """

    def __init__(self, models, type: str, **kwargs):
        super().__init__(models=models, **kwargs)
        assert type in [
            "torchvision",
            "transformers",
        ], "type must be either 'torchvision' or 'transformers'"

    def load_processor(
        self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
    ):
        """Load the appropriate image processor/transform for the specified training stage.

        Creates stage-specific image preprocessing pipelines optimized for the model type:

        For torchvision models:
        - Train stage: Includes data augmentation (random resize crop, horizontal flip)
        - Val/test stages: Standard preprocessing (resize, center crop) without augmentation
        - All stages: Apply ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        For transformers models:
        - Uses AutoImageProcessor from the pretrained model configuration
        - Automatically handles model-specific preprocessing requirements

        Args:
            stage (Literal["train", "val", "test"]): The training stage determining preprocessing type.
                - "train": Applies data augmentation for training
                - "val"/"test": Uses standard preprocessing for evaluation
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Union[transforms.Compose, AutoImageProcessor]: The image processor/transform pipeline
            appropriate for the specified stage and model type.

        Raises:
            ValueError: If no valid config_path can be found for transformers models.

        Example:
            ```python
            >>> # Get training transforms for torchvision model
            >>> train_transform = pool.load_processor(stage="train")
            >>> # Get evaluation processor for transformers model
            >>> eval_processor = pool.load_processor(stage="test")
            ```
        """
        if self.type == "torchvision":
            from torchvision import transforms

            to_tensor = transforms.ToTensor()
            normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
            if stage == "train":
                train_transform = transforms.Compose(
                    [
                        transforms.RandomResizedCrop(224),
                        transforms.RandomHorizontalFlip(),
                        to_tensor,
                        normalize,
                    ]
                )
                return train_transform
            else:
                val_transform = transforms.Compose(
                    [
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        to_tensor,
                        normalize,
                    ]
                )
                return val_transform

        elif self.type == "transformers":
            from transformers import AutoImageProcessor

            if self.has_pretrained:
                config_path = self._models["_pretrained_"].config_path
            else:
                for model_cfg in self._models.values():
                    if isinstance(model_cfg, str):
                        config_path = model_cfg
                        break
                    if "config_path" in model_cfg:
                        config_path = model_cfg["config_path"]
                        break
            return AutoImageProcessor.from_pretrained(config_path)

    @override
    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
        """Load a ResNet model based on the provided configuration or model name.

        This method supports flexible model loading from different sources and configurations:
        - Direct model names (e.g., "resnet18", "resnet50") for standard architectures
        - Model pool keys that map to configurations
        - Dictionary/DictConfig objects with detailed model specifications
        - Hugging Face model identifiers for transformers models

        For torchvision models, supports:
        - Standard ResNet architectures: resnet18, resnet34, resnet50, resnet101, resnet152
        - Custom configurations with model_name, weights, and num_classes specifications
        - Automatic dataset adaptation with class number inference

        For transformers models:
        - Loading from Hugging Face Hub or local paths
        - Pretrained or randomly initialized models
        - Automatic logits extraction by overriding forward method
        - Dataset-specific label mapping configuration

        Args:
            model_name_or_config (Union[str, DictConfig]): Model specification that can be:
                - A string model name (e.g., "resnet18") for standard architectures
                - A model pool key referencing a stored configuration
                - A dict/DictConfig with model parameters like:
                  * For torchvision: {"model_name": "resnet18", "weights": "DEFAULT", "num_classes": 10}
                  * For transformers: {"config_path": "microsoft/resnet-50", "pretrained": True, "dataset_name": "cifar10"}
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            Union[TorchVisionResNet, ResNetForImageClassification]: The loaded ResNet model
            configured for the specified task. For transformers models, the forward method
            is modified to return logits directly instead of the full model output.

        Raises:
            ValueError: If model_name_or_config type is invalid or if model type is unknown.
            AssertionError: If num_classes from dataset doesn't match explicit num_classes specification.

        Example:
            ```python
            >>> # Load standard torchvision model
            >>> model = pool.load_model("resnet18")

            >>> # Load with custom configuration
            >>> config = {"model_name": "resnet50", "weights": "DEFAULT", "dataset_name": "cifar10"}
            >>> model = pool.load_model(config)

            >>> # Load transformers model
            >>> config = {"config_path": "microsoft/resnet-50", "pretrained": True}
            >>> model = pool.load_model(config)
            ```
        """
        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_name_or_config = self._models[model_name_or_config]

        if self.type == "torchvision":
            from torchvision.models import (
                resnet18,
                resnet34,
                resnet50,
                resnet101,
                resnet152,
            )

            match model_name_or_config:
                case "resnet18":
                    model = resnet18()
                case "resnet34":
                    model = resnet34()
                case "resnet50":
                    model = resnet50()
                case "resnet101":
                    model = resnet101()
                case "resnet152":
                    model = resnet152()
                case dict() | DictConfig() as model_config:
                    if "dataset_name" in model_config:
                        num_classes = get_num_classes(model_config["dataset_name"])
                        if "num_classes" in model_config:
                            assert (
                                num_classes == model_config["num_classes"]
                            ), f"num_classes mismatch: {num_classes} vs {model_config['num_classes']}"
                    elif "num_classes" in model_config:
                        num_classes = model_config["num_classes"]
                    else:
                        num_classes = None
                    model = load_torchvision_resnet(
                        model_name=model_config["model_name"],
                        weights=model_config.get("weights", None),
                        num_classes=num_classes,
                    )
                case _:
                    raise ValueError(
                        f"Invalid model_name_or_config type: {type(model_name_or_config)}"
                    )
        elif self.type == "transformers":
            match model_name_or_config:
                case str() as model_path:
                    from transformers import AutoModelForImageClassification

                    model = AutoModelForImageClassification.from_pretrained(model_path)
                case dict() | DictConfig() as model_config:

                    model = load_transformers_resnet(
                        config_path=model_config["config_path"],
                        pretrained=model_config.get("pretrained", True),
                        dataset_name=model_config.get("dataset_name", None),
                    )
                case _:
                    raise ValueError(
                        f"Invalid model_name_or_config type: {type(model_name_or_config)}"
                    )

            # override forward to return logits only
            original_forward = model.forward
            model.forward = lambda pixel_values, **kwargs: original_forward(
                pixel_values=pixel_values, **kwargs
            ).logits
            model.original_forward = original_forward
        else:
            raise ValueError(f"Unknown model type: {self.type}")
        return model

    @override
    def save_model(
        self,
        model,
        path,
        algorithm_config: Optional[DictConfig] = None,
        description: Optional[str] = None,
        base_model: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Save a ResNet model to the specified path using the appropriate format.

        This method handles model saving based on the model pool type:
        - For torchvision models: Saves only the state_dict using torch.save()
        - For transformers models: Saves the complete model and processor using save_pretrained()

        The saving format ensures compatibility with the corresponding loading mechanisms
        and preserves all necessary components for model restoration.

        Args:
            model: The ResNet model to save. Should be compatible with the pool's model type.
            path (str): Destination path for saving the model. For torchvision models, this
                should be a file path (e.g., "model.pth"). For transformers models, this
                should be a directory path where model files will be stored.
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Raises:
            ValueError: If the model type is unknown or unsupported.

        Note:
            For transformers models, both the model weights and the associated image processor
            are saved to ensure complete reproducibility of the preprocessing pipeline.

        Example:
            ```python
            >>> # Save torchvision model
            >>> pool.save_model(model, "checkpoints/resnet18_cifar10.pth")

            >>> # Save transformers model (saves to directory)
            >>> pool.save_model(model, "checkpoints/resnet50_model/")
            ```
        """
        if self.type == "torchvision":
            os.makedirs(os.path.dirname(path), exist_ok=True)
            torch.save(model.state_dict(), path)
        elif self.type == "transformers":
            model.save_pretrained(path)
            self.load_processor().save_pretrained(path)

            if algorithm_config is not None and rank_zero_only.rank == 0:
                from fusion_bench.models.hf_utils import create_default_model_card

                model_card_str = create_default_model_card(
                    base_model=base_model,
                    algorithm_config=algorithm_config,
                    description=description,
                    modelpool_config=self.config,
                )
                with open(os.path.join(path, "README.md"), "w") as f:
                    f.write(model_card_str)
        else:
            raise ValueError(f"Unknown model type: {self.type}")
load_model(model_name_or_config, *args, **kwargs)

Load a ResNet model based on the provided configuration or model name.

This method supports flexible model loading from different sources and configurations:

  • Direct model names (e.g., "resnet18", "resnet50") for standard architectures
  • Model pool keys that map to configurations
  • Dictionary/DictConfig objects with detailed model specifications
  • Hugging Face model identifiers for transformers models

For torchvision models, supports:

  • Standard ResNet architectures: resnet18, resnet34, resnet50, resnet101, resnet152
  • Custom configurations with model_name, weights, and num_classes specifications
  • Automatic dataset adaptation with class number inference

For transformers models:

  • Loading from Hugging Face Hub or local paths
  • Pretrained or randomly initialized models
  • Automatic logits extraction by overriding forward method
  • Dataset-specific label mapping configuration

Parameters:

  • model_name_or_config (Union[str, DictConfig]) –

    Model specification that can be:
      - A string model name (e.g., "resnet18") for standard architectures
      - A model pool key referencing a stored configuration
      - A dict/DictConfig with model parameters like:
        * For torchvision: {"model_name": "resnet18", "weights": "DEFAULT", "num_classes": 10}
        * For transformers: {"config_path": "microsoft/resnet-50", "pretrained": True, "dataset_name": "cifar10"}

  • *args

    Additional positional arguments (unused).

  • **kwargs

    Additional keyword arguments (unused).

Returns:

  • Union[TorchVisionResNet, ResNetForImageClassification]

    The loaded ResNet model configured for the specified task. For transformers models, the forward method is modified to return logits directly instead of the full model output.

Raises:

  • ValueError

    If model_name_or_config type is invalid or if model type is unknown.

  • AssertionError

    If num_classes from dataset doesn't match explicit num_classes specification.

Example
>>> # Load standard torchvision model
>>> model = pool.load_model("resnet18")

>>> # Load with custom configuration
>>> config = {"model_name": "resnet50", "weights": "DEFAULT", "dataset_name": "cifar10"}
>>> model = pool.load_model(config)

>>> # Load transformers model
>>> config = {"config_path": "microsoft/resnet-50", "pretrained": True}
>>> model = pool.load_model(config)
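
Because the forward method of transformers models is wrapped to return logits, the loaded model can be called directly on a batch of pixel values (a minimal sketch with a dummy input; the batch shape is illustrative):

>>> import torch
>>> pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
>>> logits = model(pixel_values)                # logits tensor, not a full ModelOutput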
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
@override
def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
    """Load a ResNet model based on the provided configuration or model name.

    This method supports flexible model loading from different sources and configurations:
    - Direct model names (e.g., "resnet18", "resnet50") for standard architectures
    - Model pool keys that map to configurations
    - Dictionary/DictConfig objects with detailed model specifications
    - Hugging Face model identifiers for transformers models

    For torchvision models, supports:
    - Standard ResNet architectures: resnet18, resnet34, resnet50, resnet101, resnet152
    - Custom configurations with model_name, weights, and num_classes specifications
    - Automatic dataset adaptation with class number inference

    For transformers models:
    - Loading from Hugging Face Hub or local paths
    - Pretrained or randomly initialized models
    - Automatic logits extraction by overriding forward method
    - Dataset-specific label mapping configuration

    Args:
        model_name_or_config (Union[str, DictConfig]): Model specification that can be:
            - A string model name (e.g., "resnet18") for standard architectures
            - A model pool key referencing a stored configuration
            - A dict/DictConfig with model parameters like:
              * For torchvision: {"model_name": "resnet18", "weights": "DEFAULT", "num_classes": 10}
              * For transformers: {"config_path": "microsoft/resnet-50", "pretrained": True, "dataset_name": "cifar10"}
        *args: Additional positional arguments (unused).
        **kwargs: Additional keyword arguments (unused).

    Returns:
        Union[TorchVisionResNet, ResNetForImageClassification]: The loaded ResNet model
        configured for the specified task. For transformers models, the forward method
        is modified to return logits directly instead of the full model output.

    Raises:
        ValueError: If model_name_or_config type is invalid or if model type is unknown.
        AssertionError: If num_classes from dataset doesn't match explicit num_classes specification.

    Example:
        ```python
        >>> # Load standard torchvision model
        >>> model = pool.load_model("resnet18")

        >>> # Load with custom configuration
        >>> config = {"model_name": "resnet50", "weights": "DEFAULT", "dataset_name": "cifar10"}
        >>> model = pool.load_model(config)

        >>> # Load transformers model
        >>> config = {"config_path": "microsoft/resnet-50", "pretrained": True}
        >>> model = pool.load_model(config)
        ```
    """
    log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model_name_or_config = self._models[model_name_or_config]

    if self.type == "torchvision":
        from torchvision.models import (
            resnet18,
            resnet34,
            resnet50,
            resnet101,
            resnet152,
        )

        match model_name_or_config:
            case "resnet18":
                model = resnet18()
            case "resnet34":
                model = resnet34()
            case "resnet50":
                model = resnet50()
            case "resnet101":
                model = resnet101()
            case "resnet152":
                model = resnet152()
            case dict() | DictConfig() as model_config:
                if "dataset_name" in model_config:
                    num_classes = get_num_classes(model_config["dataset_name"])
                    if "num_classes" in model_config:
                        assert (
                            num_classes == model_config["num_classes"]
                        ), f"num_classes mismatch: {num_classes} vs {model_config['num_classes']}"
                elif "num_classes" in model_config:
                    num_classes = model_config["num_classes"]
                else:
                    num_classes = None
                model = load_torchvision_resnet(
                    model_name=model_config["model_name"],
                    weights=model_config.get("weights", None),
                    num_classes=num_classes,
                )
            case _:
                raise ValueError(
                    f"Invalid model_name_or_config type: {type(model_name_or_config)}"
                )
    elif self.type == "transformers":
        match model_name_or_config:
            case str() as model_path:
                from transformers import AutoModelForImageClassification

                model = AutoModelForImageClassification.from_pretrained(model_path)
            case dict() | DictConfig() as model_config:

                model = load_transformers_resnet(
                    config_path=model_config["config_path"],
                    pretrained=model_config.get("pretrained", True),
                    dataset_name=model_config.get("dataset_name", None),
                )
            case _:
                raise ValueError(
                    f"Invalid model_name_or_config type: {type(model_name_or_config)}"
                )

        # override forward to return logits only
        original_forward = model.forward
        model.forward = lambda pixel_values, **kwargs: original_forward(
            pixel_values=pixel_values, **kwargs
        ).logits
        model.original_forward = original_forward
    else:
        raise ValueError(f"Unknown model type: {self.type}")
    return model
load_processor(stage='test', *args, **kwargs)

Load the appropriate image processor/transform for the specified training stage.

Creates stage-specific image preprocessing pipelines optimized for the model type:

For torchvision models:

  • Train stage: Includes data augmentation (random resize crop, horizontal flip)
  • Val/test stages: Standard preprocessing (resize, center crop) without augmentation
  • All stages: Apply ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

For transformers models:

  • Uses AutoImageProcessor from the pretrained model configuration
  • Automatically handles model-specific preprocessing requirements

Parameters:

  • stage (Literal['train', 'val', 'test'], default: 'test' ) –

    The training stage determining preprocessing type.
      - "train": Applies data augmentation for training
      - "val"/"test": Uses standard preprocessing for evaluation

  • *args

    Additional positional arguments (unused).

  • **kwargs

    Additional keyword arguments (unused).

Returns:

  • Union[transforms.Compose, AutoImageProcessor]

    The image processor/transform pipeline appropriate for the specified stage and model type.

Raises:

  • ValueError

    If no valid config_path can be found for transformers models.

Example
>>> # Get training transforms for torchvision model
>>> train_transform = pool.load_processor(stage="train")
>>> # Get evaluation processor for transformers model
>>> eval_processor = pool.load_processor(stage="test")
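
For torchvision pools the returned object is a transforms.Compose, so it can be applied directly to a PIL image (a minimal sketch with a dummy image):

>>> from PIL import Image
>>> image = Image.new("RGB", (256, 256))
>>> tensor = train_transform(image)  # torch.Tensor of shape (3, 224, 224)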
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
def load_processor(
    self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
):
    """Load the appropriate image processor/transform for the specified training stage.

    Creates stage-specific image preprocessing pipelines optimized for the model type:

    For torchvision models:
    - Train stage: Includes data augmentation (random resize crop, horizontal flip)
    - Val/test stages: Standard preprocessing (resize, center crop) without augmentation
    - All stages: Apply ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    For transformers models:
    - Uses AutoImageProcessor from the pretrained model configuration
    - Automatically handles model-specific preprocessing requirements

    Args:
        stage (Literal["train", "val", "test"]): The training stage determining preprocessing type.
            - "train": Applies data augmentation for training
            - "val"/"test": Uses standard preprocessing for evaluation
        *args: Additional positional arguments (unused).
        **kwargs: Additional keyword arguments (unused).

    Returns:
        Union[transforms.Compose, AutoImageProcessor]: The image processor/transform pipeline
        appropriate for the specified stage and model type.

    Raises:
        ValueError: If no valid config_path can be found for transformers models.

    Example:
        ```python
        >>> # Get training transforms for torchvision model
        >>> train_transform = pool.load_processor(stage="train")
        >>> # Get evaluation processor for transformers model
        >>> eval_processor = pool.load_processor(stage="test")
        ```
    """
    if self.type == "torchvision":
        from torchvision import transforms

        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        if stage == "train":
            train_transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    to_tensor,
                    normalize,
                ]
            )
            return train_transform
        else:
            val_transform = transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    to_tensor,
                    normalize,
                ]
            )
            return val_transform

    elif self.type == "transformers":
        from transformers import AutoImageProcessor

        if self.has_pretrained:
            config_path = self._models["_pretrained_"].config_path
        else:
            for model_cfg in self._models.values():
                if isinstance(model_cfg, str):
                    config_path = model_cfg
                    break
                if "config_path" in model_cfg:
                    config_path = model_cfg["config_path"]
                    break
        return AutoImageProcessor.from_pretrained(config_path)
save_model(model, path, algorithm_config=None, description=None, base_model=None, *args, **kwargs)

Save a ResNet model to the specified path using the appropriate format.

This method handles model saving based on the model pool type:

  • For torchvision models: Saves only the state_dict using torch.save()
  • For transformers models: Saves the complete model and processor using save_pretrained()

The saving format ensures compatibility with the corresponding loading mechanisms and preserves all necessary components for model restoration.

Parameters:

  • model

    The ResNet model to save. Should be compatible with the pool's model type.

  • path (str) –

    Destination path for saving the model. For torchvision models, this should be a file path (e.g., "model.pth"). For transformers models, this should be a directory path where model files will be stored.

  • *args

    Additional positional arguments (unused).

  • **kwargs

    Additional keyword arguments (unused).

Raises:

  • ValueError

    If the model type is unknown or unsupported.

Note

For transformers models, both the model weights and the associated image processor are saved to ensure complete reproducibility of the preprocessing pipeline.

Example
>>> # Save torchvision model
>>> pool.save_model(model, "checkpoints/resnet18_cifar10.pth")

>>> # Save transformers model (saves to directory)
>>> pool.save_model(model, "checkpoints/resnet50_model/")
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
@override
def save_model(
    self,
    model,
    path,
    algorithm_config: Optional[DictConfig] = None,
    description: Optional[str] = None,
    base_model: Optional[str] = None,
    *args,
    **kwargs,
):
    """Save a ResNet model to the specified path using the appropriate format.

    This method handles model saving based on the model pool type:
    - For torchvision models: Saves only the state_dict using torch.save()
    - For transformers models: Saves the complete model and processor using save_pretrained()

    The saving format ensures compatibility with the corresponding loading mechanisms
    and preserves all necessary components for model restoration.

    Args:
        model: The ResNet model to save. Should be compatible with the pool's model type.
        path (str): Destination path for saving the model. For torchvision models, this
            should be a file path (e.g., "model.pth"). For transformers models, this
            should be a directory path where model files will be stored.
        *args: Additional positional arguments (unused).
        **kwargs: Additional keyword arguments (unused).

    Raises:
        ValueError: If the model type is unknown or unsupported.

    Note:
        For transformers models, both the model weights and the associated image processor
        are saved to ensure complete reproducibility of the preprocessing pipeline.

    Example:
        ```python
        >>> # Save torchvision model
        >>> pool.save_model(model, "checkpoints/resnet18_cifar10.pth")

        >>> # Save transformers model (saves to directory)
        >>> pool.save_model(model, "checkpoints/resnet50_model/")
        ```
    """
    if self.type == "torchvision":
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(model.state_dict(), path)
    elif self.type == "transformers":
        model.save_pretrained(path)
        self.load_processor().save_pretrained(path)

        if algorithm_config is not None and rank_zero_only.rank == 0:
            from fusion_bench.models.hf_utils import create_default_model_card

            model_card_str = create_default_model_card(
                base_model=base_model,
                algorithm_config=algorithm_config,
                description=description,
                modelpool_config=self.config,
            )
            with open(os.path.join(path, "README.md"), "w") as f:
                f.write(model_card_str)
    else:
        raise ValueError(f"Unknown model type: {self.type}")

load_torchvision_resnet(model_name, weights, num_classes)

Load a ResNet model from torchvision with optional custom classifier head.

This function creates a ResNet model using torchvision's model zoo and optionally replaces the final classification layer to match the required number of classes.

Parameters:

  • model_name (str) –

    Name of the ResNet model to load (e.g., 'resnet18', 'resnet50'). Must be a valid torchvision model name.

  • weights (Optional[str]) –

    Pretrained weights to load. Can be 'DEFAULT', 'IMAGENET1K_V1', or None for random initialization. See torchvision documentation for available options.

  • num_classes (Optional[int]) –

    Number of output classes. If provided, replaces the final fully connected layer. If None, keeps the original classifier (typically 1000 classes).

Returns:

  • TorchVisionResNet ( ResNet ) –

    The loaded ResNet model with appropriate classifier head.

Raises:

  • AttributeError

    If model_name is not a valid torchvision model.

Example
>>> model = load_torchvision_resnet("resnet18", "DEFAULT", 10)  # CIFAR-10
>>> model = load_torchvision_resnet("resnet50", None, 100)     # Random init, 100 classes
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
def load_torchvision_resnet(
    model_name: str, weights: Optional[str], num_classes: Optional[int]
) -> "TorchVisionResNet":
    """Load a ResNet model from torchvision with optional custom classifier head.

    This function creates a ResNet model using torchvision's model zoo and optionally
    replaces the final classification layer to match the required number of classes.

    Args:
        model_name (str): Name of the ResNet model to load (e.g., 'resnet18', 'resnet50').
            Must be a valid torchvision model name.
        weights (Optional[str]): Pretrained weights to load. Can be 'DEFAULT', 'IMAGENET1K_V1',
            or None for random initialization. See torchvision documentation for available options.
        num_classes (Optional[int]): Number of output classes. If provided, replaces the final
            fully connected layer. If None, keeps the original classifier (typically 1000 classes).

    Returns:
        TorchVisionResNet: The loaded ResNet model with appropriate classifier head.

    Raises:
        AttributeError: If model_name is not a valid torchvision model.

    Example:
        ```python
        >>> model = load_torchvision_resnet("resnet18", "DEFAULT", 10)  # CIFAR-10
        >>> model = load_torchvision_resnet("resnet50", None, 100)     # Random init, 100 classes
        ```
    """
    import torchvision.models

    model_fn = getattr(torchvision.models, model_name)
    model: "TorchVisionResNet" = model_fn(weights=weights)

    if num_classes is not None:
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model

load_transformers_resnet(config_path, pretrained, dataset_name)

Load a ResNet model from transformers with optional dataset-specific adaptation.

This function creates a ResNet model using the transformers library and optionally adapts it for a specific dataset by updating the classifier head and label mappings.

Parameters:

  • config_path (str) –

    Path or identifier for the model configuration. Can be a local path or a Hugging Face model identifier (e.g., 'microsoft/resnet-50').

  • pretrained (bool) –

    Whether to load pretrained weights. If True, loads from the specified config_path. If False, initializes with random weights using the config.

  • dataset_name (Optional[str]) –

    Name of the target dataset for adaptation. If provided, updates the model's classifier and label mappings to match the dataset's classes. If None, keeps the original model configuration.

Returns:

  • ResNetForImageClassification

    The loaded and optionally adapted ResNet model.

Example
>>> # Load pretrained model adapted for CIFAR-10
>>> model = load_transformers_resnet("microsoft/resnet-50", True, "cifar10")
>>> # Load random initialized model with default classes
>>> model = load_transformers_resnet("microsoft/resnet-50", False, None)
Source code in fusion_bench/modelpool/resnet_for_image_classification.py
def load_transformers_resnet(
    config_path: str, pretrained: bool, dataset_name: Optional[str]
):
    """Load a ResNet model from transformers with optional dataset-specific adaptation.

    This function creates a ResNet model using the transformers library and optionally
    adapts it for a specific dataset by updating the classifier head and label mappings.

    Args:
        config_path (str): Path or identifier for the model configuration. Can be a local path
            or a Hugging Face model identifier (e.g., 'microsoft/resnet-50').
        pretrained (bool): Whether to load pretrained weights. If True, loads from the
            specified config_path. If False, initializes with random weights using the config.
        dataset_name (Optional[str]): Name of the target dataset for adaptation. If provided,
            updates the model's classifier and label mappings to match the dataset's classes.
            If None, keeps the original model configuration.

    Returns:
        ResNetForImageClassification: The loaded and optionally adapted ResNet model.

    Example:
        ```python
        >>> # Load pretrained model adapted for CIFAR-10
        >>> model = load_transformers_resnet("microsoft/resnet-50", True, "cifar10")
        >>> # Load random initialized model with default classes
        >>> model = load_transformers_resnet("microsoft/resnet-50", False, None)
        ```
    """
    from transformers import AutoConfig, ResNetForImageClassification

    if pretrained:
        model = ResNetForImageClassification.from_pretrained(config_path)
    else:
        config = AutoConfig.from_pretrained(config_path)
        model = ResNetForImageClassification(config)

    if dataset_name is None:
        return model

    classnames = get_classnames(dataset_name)
    id2label = {i: c for i, c in enumerate(classnames)}
    label2id = {c: i for i, c in enumerate(classnames)}
    model.config.id2label = id2label
    model.config.label2id = label2id
    model.num_labels = model.config.num_labels

    model.classifier[1] = (
        nn.Linear(
            model.classifier[1].in_features,
            len(classnames),
        )
        if model.config.num_labels > 0
        else nn.Identity()
    )
    return model

ConvNeXt for Image Classification

Hugging Face ConvNeXt image classification model pool.

This module provides a BaseModelPool implementation that loads and saves ConvNeXt models for image classification via transformers. It optionally reconfigures the classification head to match a dataset's class names and overrides forward to return logits only for simpler downstream usage.

See also: fusion_bench.modelpool.resnet_for_image_classification for a parallel implementation for ResNet-based classifiers.

ConvNextForImageClassificationPool

Bases: BaseModelPool

Model pool for ConvNeXt image classification models (HF Transformers).

Responsibilities:

  • Load an AutoImageProcessor compatible with the configured ConvNeXt model.
  • Load ConvNeXt models either from a pretrained checkpoint or from config.
  • Optionally adapt the classifier head to match dataset classnames.
  • Override forward to return logits for consistent interfaces within FusionBench.

See fusion_bench.modelpool.resnet_for_image_classification for a closely related ResNet-based pool with analogous behavior.
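
A construction and loading sketch, mirroring the ResNet pool above (the model key "convnext_cifar10" is illustrative; config_path, pretrained, and dataset_name are the config keys described below):

>>> pool = ConvNextForImageClassificationPool(
...     models={
...         "convnext_cifar10": {
...             "config_path": "facebook/convnext-base-224",
...             "pretrained": True,
...             "dataset_name": "cifar10",
...         }
...     }
... )
>>> model = pool.load_model("convnext_cifar10")  # forward returns logits directly
>>> processor = pool.load_processor()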

Source code in fusion_bench/modelpool/convnext_for_image_classification.py
@auto_register_config
class ConvNextForImageClassificationPool(BaseModelPool):
    """Model pool for ConvNeXt image classification models (HF Transformers).

    Responsibilities:
    - Load an `AutoImageProcessor` compatible with the configured ConvNeXt model.
    - Load ConvNeXt models either from a pretrained checkpoint or from config.
    - Optionally adapt the classifier head to match dataset classnames.
    - Override `forward` to return logits for consistent interfaces within
      FusionBench.

    See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
    related ResNet-based pool with analogous behavior.
    """

    def load_processor(self, *args, **kwargs):
        from transformers import AutoImageProcessor

        if self.has_pretrained:
            config_path = self._models["_pretrained_"].config_path
        else:
            for model_cfg in self._models.values():
                if isinstance(model_cfg, str):
                    config_path = model_cfg
                    break
                if "config_path" in model_cfg:
                    config_path = model_cfg["config_path"]
                    break
        return AutoImageProcessor.from_pretrained(config_path)

    @override
    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
        """Load a ConvNeXt model described by a name, path, or DictConfig.

        Accepts either a string (pretrained identifier or local path) or a
        config mapping with keys: `config_path`, optional `pretrained` (bool),
        and optional `dataset_name` to resize the classifier.

        Returns:
            A model whose `forward` is wrapped to return only logits to align
            with FusionBench expectations.
        """
        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_name_or_config = self._models[model_name_or_config]

        match model_name_or_config:
            case str() as model_path:
                from transformers import AutoModelForImageClassification

                model = AutoModelForImageClassification.from_pretrained(model_path)
            case dict() | DictConfig() as model_config:
                model = load_transformers_convnext(
                    model_config["config_path"],
                    pretrained=model_config.get("pretrained", True),
                    dataset_name=model_config.get("dataset_name", None),
                )
            case _:
                raise ValueError(
                    f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
                )

        # override forward to return logits only
        original_forward = model.forward
        model.forward = lambda pixel_values, **kwargs: original_forward(
            pixel_values=pixel_values, **kwargs
        ).logits
        model.original_forward = original_forward

        return model

    @override
    def save_model(
        self,
        model,
        path,
        algorithm_config: Optional[DictConfig] = None,
        description: Optional[str] = None,
        base_model: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Save the model, processor, and an optional model card to disk.

        Artifacts written to `path`:
        - The ConvNeXt model via `model.save_pretrained`.
        - The paired image processor via `AutoImageProcessor.save_pretrained`.
        - If `algorithm_config` is provided and on rank-zero, a README model card
          documenting the FusionBench configuration.
        """
        model.save_pretrained(path)
        self.load_processor().save_pretrained(path)

        if algorithm_config is not None and rank_zero_only.rank == 0:
            from fusion_bench.models.hf_utils import create_default_model_card

            model_card_str = create_default_model_card(
                algorithm_config=algorithm_config,
                description=description,
                modelpool_config=self.config,
                base_model=base_model,
            )
            with open(os.path.join(path, "README.md"), "w") as f:
                f.write(model_card_str)
load_model(model_name_or_config, *args, **kwargs)

Load a ConvNeXt model described by a name, path, or DictConfig.

Accepts either a string (pretrained identifier or local path) or a config mapping with keys: config_path, optional pretrained (bool), and optional dataset_name to resize the classifier.

Returns:

  • A model whose forward is wrapped to return only logits to align with FusionBench expectations.

Source code in fusion_bench/modelpool/convnext_for_image_classification.py
@override
def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
    """Load a ConvNeXt model described by a name, path, or DictConfig.

    Accepts either a string (pretrained identifier or local path) or a
    config mapping with keys: `config_path`, optional `pretrained` (bool),
    and optional `dataset_name` to resize the classifier.

    Returns:
        A model whose `forward` is wrapped to return only logits to align
        with FusionBench expectations.
    """
    log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model_name_or_config = self._models[model_name_or_config]

    match model_name_or_config:
        case str() as model_path:
            from transformers import AutoModelForImageClassification

            model = AutoModelForImageClassification.from_pretrained(model_path)
        case dict() | DictConfig() as model_config:
            model = load_transformers_convnext(
                model_config["config_path"],
                pretrained=model_config.get("pretrained", True),
                dataset_name=model_config.get("dataset_name", None),
            )
        case _:
            raise ValueError(
                f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
            )

    # override forward to return logits only
    original_forward = model.forward
    model.forward = lambda pixel_values, **kwargs: original_forward(
        pixel_values=pixel_values, **kwargs
    ).logits
    model.original_forward = original_forward

    return model
save_model(model, path, algorithm_config=None, description=None, base_model=None, *args, **kwargs)

Save the model, processor, and an optional model card to disk.

Artifacts written to path:

  • The ConvNeXt model via model.save_pretrained.
  • The paired image processor via AutoImageProcessor.save_pretrained.
  • If algorithm_config is provided and on rank-zero, a README model card documenting the FusionBench configuration.

Source code in fusion_bench/modelpool/convnext_for_image_classification.py
@override
def save_model(
    self,
    model,
    path,
    algorithm_config: Optional[DictConfig] = None,
    description: Optional[str] = None,
    base_model: Optional[str] = None,
    *args,
    **kwargs,
):
    """Save the model, processor, and an optional model card to disk.

    Artifacts written to `path`:
    - The ConvNeXt model via `model.save_pretrained`.
    - The paired image processor via `AutoImageProcessor.save_pretrained`.
    - If `algorithm_config` is provided and on rank-zero, a README model card
      documenting the FusionBench configuration.
    """
    model.save_pretrained(path)
    self.load_processor().save_pretrained(path)

    if algorithm_config is not None and rank_zero_only.rank == 0:
        from fusion_bench.models.hf_utils import create_default_model_card

        model_card_str = create_default_model_card(
            algorithm_config=algorithm_config,
            description=description,
            modelpool_config=self.config,
            base_model=base_model,
        )
        with open(os.path.join(path, "README.md"), "w") as f:
            f.write(model_card_str)

load_transformers_convnext(config_path, pretrained, dataset_name)

Create a ConvNeXt image classification model from a config or checkpoint.

Parameters:

  • config_path (str) –

    A model identifier or local path understood by transformers.AutoConfig/AutoModel (e.g., "facebook/convnext-base-224").

  • pretrained (bool) –

    If True, load weights via from_pretrained; otherwise, build the model from config only.

  • dataset_name (Optional[str]) –

    Optional dataset key used by FusionBench to derive class names via get_classnames. When provided, the model's id/label maps are updated and the classifier head is resized accordingly.

Returns:

  • ConvNextForImageClassification

    A transformers.ConvNextForImageClassification instance. If dataset_name is set, the classifier head is adapted to the number of classes. The model's config.id2label and config.label2id are also populated.

Notes

The overall structure mirrors the ResNet implementation in fusion_bench.modelpool.resnet_for_image_classification.
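
A usage sketch, analogous to load_transformers_resnet above (the dataset key is illustrative):

>>> # Pretrained ConvNeXt adapted for CIFAR-10 classes
>>> model = load_transformers_convnext("facebook/convnext-base-224", True, "cifar10")
>>> # Randomly initialized model, keeping the default label mapping from the config
>>> model = load_transformers_convnext("facebook/convnext-base-224", False, None)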

Source code in fusion_bench/modelpool/convnext_for_image_classification.py
def load_transformers_convnext(
    config_path: str, pretrained: bool, dataset_name: Optional[str]
):
    """Create a ConvNeXt image classification model from a config or checkpoint.

    Args:
        config_path: A model identifier or local path understood by
            `transformers.AutoConfig/AutoModel` (e.g., "facebook/convnext-base-224").
        pretrained: If True, load weights via `from_pretrained`; otherwise, build
            the model from config only.
        dataset_name: Optional dataset key used by FusionBench to derive class
            names via `get_classnames`. When provided, the model's id/label maps
            are updated and the classifier head is resized accordingly.

    Returns:
        ConvNextForImageClassification: A `transformers.ConvNextForImageClassification` instance. If
            `dataset_name` is set, the classifier head is adapted to the number of
            classes. The model's `config.id2label` and `config.label2id` are also
            populated.

    Notes:
        The overall structure mirrors the ResNet implementation in
        `fusion_bench.modelpool.resnet_for_image_classification`.
    """
    from transformers import AutoConfig, ConvNextForImageClassification

    if pretrained:
        model = ConvNextForImageClassification.from_pretrained(config_path)
    else:
        config = AutoConfig.from_pretrained(config_path)
        model = ConvNextForImageClassification(config)

    if dataset_name is None:
        return model

    classnames = get_classnames(dataset_name)
    id2label = {i: c for i, c in enumerate(classnames)}
    label2id = {c: i for i, c in enumerate(classnames)}
    model.config.id2label = id2label
    model.config.label2id = label2id
    model.num_labels = model.config.num_labels

    model.classifier = (
        nn.Linear(
            model.classifier.in_features,
            len(classnames),
            device=model.classifier.weight.device,
            dtype=model.classifier.weight.dtype,
        )
        if model.config.num_labels > 0
        else nn.Identity()
    )
    return model

DINOv2 for Image Classification

Hugging Face DINOv2 image classification model pool.

This module provides a BaseModelPool implementation that loads and saves DINOv2 models for image classification via transformers. It optionally reconfigures the classification head to match a dataset's class names and overrides forward to return logits only for simpler downstream usage.

See also: fusion_bench.modelpool.convnext_for_image_classification for a parallel implementation for ConvNeXt-based classifiers.

Dinov2ForImageClassificationPool

Bases: BaseModelPool

Model pool for DINOv2 image classification models (HF Transformers).

Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
@auto_register_config
class Dinov2ForImageClassificationPool(BaseModelPool):
    """Model pool for DINOv2 image classification models (HF Transformers)."""

    def load_processor(self, *args, **kwargs):
        """Load the paired image processor for this model pool.

        Uses the configured model's identifier or config path to retrieve the
        appropriate `transformers.AutoImageProcessor` instance. If a pretrained
        model entry exists in the pool configuration, its checkpoint is preferred
        for deriving the processor, ensuring preprocessing/normalization parity.
        """
        from transformers import AutoImageProcessor

        if self.has_pretrained:
            config_path = self._models["_pretrained_"].config_path
        else:
            for model_cfg in self._models.values():
                if isinstance(model_cfg, str):
                    config_path = model_cfg
                    break
                if "config_path" in model_cfg:
                    config_path = model_cfg["config_path"]
                    break
        return AutoImageProcessor.from_pretrained(config_path)

    @override
    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
        """Load a DINOv2 model described by a name, path, or DictConfig.

        Accepts either a string (pretrained identifier or local path) or a
        config mapping with keys: `config_path`, optional `pretrained` (bool),
        and optional `dataset_name` to resize the classifier.

        Returns:
            A model whose `forward` is wrapped to return only logits to align
            with FusionBench expectations.
        """
        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_name_or_config = self._models[model_name_or_config]

        match model_name_or_config:
            case str() as model_path:
                from transformers import AutoModelForImageClassification

                model = AutoModelForImageClassification.from_pretrained(model_path)
            case dict() | DictConfig() as model_config:
                model = load_transformers_dinov2(
                    model_config["config_path"],
                    pretrained=model_config.get("pretrained", True),
                    dataset_name=model_config.get("dataset_name", None),
                )
            case _:
                raise ValueError(
                    f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
                )

        # Override forward to return logits only, to unify the interface across
        # FusionBench model pools and simplify downstream usage.
        original_forward = model.forward
        model.forward = lambda pixel_values, **kwargs: original_forward(
            pixel_values=pixel_values, **kwargs
        ).logits
        model.original_forward = original_forward

        return model

    @override
    def save_model(
        self,
        model,
        path,
        algorithm_config: Optional[DictConfig] = None,
        description: Optional[str] = None,
        base_model: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """Save the model, processor, and an optional model card to disk.

        Artifacts written to `path`:
        - The DINOv2 model via `model.save_pretrained`.
        - The paired image processor via `AutoImageProcessor.save_pretrained`.
        - If `algorithm_config` is provided and on rank-zero, a README model card
          documenting the FusionBench configuration.
        """
        model.save_pretrained(path)
        self.load_processor().save_pretrained(path)

        if algorithm_config is not None and rank_zero_only.rank == 0:
            from fusion_bench.models.hf_utils import create_default_model_card

            model_card_str = create_default_model_card(
                algorithm_config=algorithm_config,
                description=description,
                modelpool_config=self.config,
                base_model=base_model,
            )
            with open(os.path.join(path, "README.md"), "w") as f:
                f.write(model_card_str)
load_model(model_name_or_config, *args, **kwargs)

Load a DINOv2 model described by a name, path, or DictConfig.

Accepts either a string (pretrained identifier or local path) or a config mapping with keys: config_path, optional pretrained (bool), and optional dataset_name to resize the classifier.

Returns:

  • A model whose forward is wrapped to return only logits, to align with FusionBench expectations.

Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
@override
def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
    """Load a DINOv2 model described by a name, path, or DictConfig.

    Accepts either a string (pretrained identifier or local path) or a
    config mapping with keys: `config_path`, optional `pretrained` (bool),
    and optional `dataset_name` to resize the classifier.

    Returns:
        A model whose `forward` is wrapped to return only logits to align
        with FusionBench expectations.
    """
    log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model_name_or_config = self._models[model_name_or_config]

    match model_name_or_config:
        case str() as model_path:
            from transformers import AutoModelForImageClassification

            model = AutoModelForImageClassification.from_pretrained(model_path)
        case dict() | DictConfig() as model_config:
            model = load_transformers_dinov2(
                model_config["config_path"],
                pretrained=model_config.get("pretrained", True),
                dataset_name=model_config.get("dataset_name", None),
            )
        case _:
            raise ValueError(
                f"Unsupported model_name_or_config type: {type(model_name_or_config)}"
            )

    # Override forward to return logits only, to unify the interface across
    # FusionBench model pools and simplify downstream usage.
    original_forward = model.forward
    model.forward = lambda pixel_values, **kwargs: original_forward(
        pixel_values=pixel_values, **kwargs
    ).logits
    model.original_forward = original_forward

    return model
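
A minimal sketch of calling this method (the pool construction and checkpoint are illustrative; only the documented config keys are used):

import torch
from omegaconf import DictConfig
from fusion_bench.modelpool.dinov2_for_image_classification import (
    Dinov2ForImageClassificationPool,
)

# Hypothetical pool configuration; "_pretrained_" follows the FusionBench
# convention for the base checkpoint entry.
pool = Dinov2ForImageClassificationPool(
    models=DictConfig(
        {"_pretrained_": {"config_path": "facebook/dinov2-base", "pretrained": True}}
    )
)
model = pool.load_model("_pretrained_")
# The wrapped forward returns logits directly rather than a ModelOutput object.
logits = model(torch.randn(1, 3, 224, 224))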
load_processor(*args, **kwargs)

Load the paired image processor for this model pool.

Uses the configured model's identifier or config path to retrieve the appropriate transformers.AutoImageProcessor instance. If a pretrained model entry exists in the pool configuration, its checkpoint is preferred for deriving the processor, ensuring preprocessing/normalization parity.

Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
def load_processor(self, *args, **kwargs):
    """Load the paired image processor for this model pool.

    Uses the configured model's identifier or config path to retrieve the
    appropriate `transformers.AutoImageProcessor` instance. If a pretrained
    model entry exists in the pool configuration, it is preferred to derive
    the processor to ensure tokenization/normalization parity.
    """
    from transformers import AutoImageProcessor

    if self.has_pretrained:
        config_path = self._models["_pretrained_"].config_path
    else:
        for model_cfg in self._models.values():
            if isinstance(model_cfg, str):
                config_path = model_cfg
                break
            if "config_path" in model_cfg:
                config_path = model_cfg["config_path"]
                break
    return AutoImageProcessor.from_pretrained(config_path)
save_model(model, path, algorithm_config=None, description=None, base_model=None, *args, **kwargs)

Save the model, processor, and an optional model card to disk.

Artifacts written to path:

  • The DINOv2 model via model.save_pretrained.
  • The paired image processor via AutoImageProcessor.save_pretrained.
  • If algorithm_config is provided and on rank-zero, a README model card documenting the FusionBench configuration.

Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
@override
def save_model(
    self,
    model,
    path,
    algorithm_config: Optional[DictConfig] = None,
    description: Optional[str] = None,
    base_model: Optional[str] = None,
    *args,
    **kwargs,
):
    """Save the model, processor, and an optional model card to disk.

    Artifacts written to `path`:
    - The DINOv2 model via `model.save_pretrained`.
    - The paired image processor via `AutoImageProcessor.save_pretrained`.
    - If `algorithm_config` is provided and on rank-zero, a README model card
      documenting the FusionBench configuration.
    """
    model.save_pretrained(path)
    self.load_processor().save_pretrained(path)

    if algorithm_config is not None and rank_zero_only.rank == 0:
        from fusion_bench.models.hf_utils import create_default_model_card

        model_card_str = create_default_model_card(
            algorithm_config=algorithm_config,
            description=description,
            modelpool_config=self.config,
            base_model=base_model,
        )
        with open(os.path.join(path, "README.md"), "w") as f:
            f.write(model_card_str)

load_transformers_dinov2(config_path, pretrained, dataset_name)

Create a DINOv2 image classification model from a config or checkpoint.

Parameters:

  • config_path (str) –

    A model identifier or local path understood by transformers.AutoConfig/AutoModel (e.g., "facebook/dinov2-base").

  • pretrained (bool) –

    If True, load weights via from_pretrained; otherwise, build the model from config only.

  • dataset_name (Optional[str]) –

    Optional dataset key used by FusionBench to derive class names via get_classnames. When provided, the model's id/label maps are updated and the classifier head is resized accordingly.

Returns:

  • Dinov2ForImageClassification

    A transformers.Dinov2ForImageClassification instance. If dataset_name is set, the classifier head is adapted to the number of classes. The model's config.id2label and config.label2id are also populated.

Notes

The overall structure mirrors the ConvNeXt implementation in fusion_bench.modelpool.convnext_for_image_classification.

Source code in fusion_bench/modelpool/dinov2_for_image_classification.py
def load_transformers_dinov2(
    config_path: str, pretrained: bool, dataset_name: Optional[str]
):
    """Create a DINOv2 image classification model from a config or checkpoint.

    Args:
        config_path: A model identifier or local path understood by
            `transformers.AutoConfig/AutoModel` (e.g., "facebook/dinov2-base").
        pretrained: If True, load weights via `from_pretrained`; otherwise, build
            the model from config only.
        dataset_name: Optional dataset key used by FusionBench to derive class
            names via `get_classnames`. When provided, the model's id/label maps
            are updated and the classifier head is resized accordingly.

    Returns:
        Dinov2ForImageClassification: A `transformers.Dinov2ForImageClassification` instance. If
            `dataset_name` is set, the classifier head is adapted to the number of
            classes. The model's `config.id2label` and `config.label2id` are also
            populated.

    Notes:
        The overall structure mirrors the ConvNeXt implementation in
        `fusion_bench.modelpool.convnext_for_image_classification`.
    """
    from transformers import AutoConfig, Dinov2ForImageClassification

    if pretrained:
        model = Dinov2ForImageClassification.from_pretrained(config_path)
    else:
        config = AutoConfig.from_pretrained(config_path)
        model = Dinov2ForImageClassification(config)

    if dataset_name is None:
        return model

    classnames = get_classnames(dataset_name)
    id2label = {i: c for i, c in enumerate(classnames)}
    label2id = {c: i for i, c in enumerate(classnames)}
    model.config.id2label = id2label
    model.config.label2id = label2id
    model.num_labels = model.config.num_labels

    # If the model is configured with a positive number of labels, resize the
    # classifier to match the dataset classes; otherwise leave it as identity.
    model.classifier = (
        nn.Linear(
            model.classifier.in_features,
            len(classnames),
            device=model.classifier.weight.device,
            dtype=model.classifier.weight.dtype,
        )
        if model.config.num_labels > 0
        else nn.Identity()
    )
    return model

NLP Model Pool

GPT-2

HuggingFaceGPT2ClassificationPool = GPT2ForSequenceClassificationPool module-attribute

GPT2ForSequenceClassificationPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/huggingface_gpt2_classification.py
class GPT2ForSequenceClassificationPool(BaseModelPool):
    _config_mapping = BaseModelPool._config_mapping | {"_tokenizer": "tokenizer"}

    def __init__(self, tokenizer: DictConfig, **kwargs):
        self._tokenizer = tokenizer
        super().__init__(**kwargs)
        self.setup()

    def setup(self):
        global tokenizer
        self.tokenizer = tokenizer = instantiate(self._tokenizer)

    def load_classifier(
        self, model_config: str | DictConfig
    ) -> GPT2ForSequenceClassification:
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config, return_copy=True)
        model_config._target_ = (
            "transformers.GPT2ForSequenceClassification.from_pretrained"
        )
        model = instantiate(model_config)
        return model
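
A hypothetical configuration sketch showing how a classifier is loaded through this pool (the model and tokenizer entries are placeholders):

from omegaconf import DictConfig
from fusion_bench.modelpool.huggingface_gpt2_classification import (
    GPT2ForSequenceClassificationPool,
)

pool = GPT2ForSequenceClassificationPool(
    models=DictConfig(
        {
            "cola": {
                "_target_": "transformers.GPT2ForSequenceClassification.from_pretrained",
                "pretrained_model_name_or_path": "gpt2",
            }
        }
    ),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "gpt2",
        }
    ),
)
classifier = pool.load_classifier("cola")  # GPT2ForSequenceClassification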

Seq2Seq Language Models (Flan-T5)

Seq2SeqLMPool

Bases: BaseModelPool

A model pool specialized for sequence-to-sequence language models.

This model pool provides management and loading capabilities for sequence-to-sequence (seq2seq) language models such as T5, BART, and mT5. It extends the base model pool functionality with seq2seq-specific features including tokenizer management and model configuration handling.

Seq2seq models are particularly useful for tasks that require generating output sequences from input sequences, such as translation, summarization, question answering, and text generation. This pool streamlines the process of loading and configuring multiple seq2seq models for fusion and ensemble scenarios.

Key Features
  • Specialized loading for AutoModelForSeq2SeqLM models
  • Integrated tokenizer management
  • Support for model-specific keyword arguments
  • Automatic dtype parsing and configuration
  • Compatible with PEFT (Parameter-Efficient Fine-Tuning) adapters

Attributes:

  • _tokenizer

    Configuration for the tokenizer associated with the models

  • _model_kwargs

    Default keyword arguments applied to all model loading operations

Example
pool = Seq2SeqLMPool(
    models={
        "t5_base": "t5-base",
        "t5_large": "t5-large",
        "custom_model": "/path/to/local/model"
    },
    tokenizer={"_target_": "transformers.T5Tokenizer",
              "pretrained_model_name_or_path": "t5-base"},
    model_kwargs={"torch_dtype": "float16", "device_map": "auto"}
)
model = pool.load_model("t5_base")
tokenizer = pool.load_tokenizer()
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
class Seq2SeqLMPool(BaseModelPool):
    """A model pool specialized for sequence-to-sequence language models.

    This model pool provides management and loading capabilities for sequence-to-sequence
    (seq2seq) language models such as T5, BART, and mT5. It extends the base model pool
    functionality with seq2seq-specific features including tokenizer management and
    model configuration handling.

    Seq2seq models are particularly useful for tasks that require generating output
    sequences from input sequences, such as translation, summarization, question
    answering, and text generation. This pool streamlines the process of loading
    and configuring multiple seq2seq models for fusion and ensemble scenarios.

    Key Features:
        - Specialized loading for AutoModelForSeq2SeqLM models
        - Integrated tokenizer management
        - Support for model-specific keyword arguments
        - Automatic dtype parsing and configuration
        - Compatible with PEFT (Parameter-Efficient Fine-Tuning) adapters

    Attributes:
        _tokenizer: Configuration for the tokenizer associated with the models
        _model_kwargs: Default keyword arguments applied to all model loading operations

    Example:
        ```python
        pool = Seq2SeqLMPool(
            models={
                "t5_base": "t5-base",
                "t5_large": "t5-large",
                "custom_model": "/path/to/local/model"
            },
            tokenizer={"_target_": "transformers.T5Tokenizer",
                      "pretrained_model_name_or_path": "t5-base"},
            model_kwargs={"torch_dtype": "float16", "device_map": "auto"}
        )
        model = pool.load_model("t5_base")
        tokenizer = pool.load_tokenizer()
        ```
    """

    _config_mapping = BaseModelPool._config_mapping | {
        "_tokenizer": "tokenizer",
        "_model_kwargs": "model_kwargs",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        **kwargs,
    ):
        """Initialize the sequence-to-sequence language model pool.

        Sets up the model pool with configurations for models, tokenizer, and
        default model loading parameters. Automatically processes model kwargs
        to handle special configurations like torch_dtype parsing.

        Args:
            models: Configuration dictionary specifying the seq2seq models to manage.
                Keys are model names, values can be model paths/names or detailed configs.
            tokenizer: Configuration for the tokenizer to use with the models.
                Can be a simple path/name or detailed configuration with _target_.
            model_kwargs: Default keyword arguments applied to all model loading
                operations. Common options include torch_dtype, device_map, etc.
                The torch_dtype field is automatically parsed from string to dtype.
            **kwargs: Additional arguments passed to the parent BaseModelPool.

        Example:
            ```python
            pool = Seq2SeqLMPool(
                models={
                    "base": "t5-base",
                    "large": {"_target_": "transformers.AutoModelForSeq2SeqLM",
                             "pretrained_model_name_or_path": "t5-large"}
                },
                tokenizer="t5-base",
                model_kwargs={"torch_dtype": "bfloat16"}
            )
            ```
        """
        super().__init__(models, **kwargs)
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )

    def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
        """Load a sequence-to-sequence language model from the pool.

        Loads a seq2seq model using the parent class loading mechanism while
        automatically applying the pool's default model kwargs. The method
        merges the pool's model_kwargs with any additional kwargs provided,
        giving priority to the explicitly provided kwargs.

        Args:
            model_name_or_config: Either a string model name from the pool
                configuration or a DictConfig containing model loading parameters.
            *args: Additional positional arguments passed to the parent load_model method.
            **kwargs: Additional keyword arguments that override the pool's default
                model_kwargs. Common options include device, torch_dtype, etc.

        Returns:
            AutoModelForSeq2SeqLM: The loaded sequence-to-sequence language model.
        """
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)
        return super().load_model(model_name_or_config, *args, **model_kwargs)

    def load_tokenizer(self, *args, **kwargs):
        """Load the tokenizer associated with the sequence-to-sequence models.

        Loads a tokenizer based on the tokenizer configuration provided during
        pool initialization. The tokenizer should be compatible with the seq2seq
        models in the pool and is typically used for preprocessing input text
        and postprocessing generated output.

        Args:
            *args: Additional positional arguments passed to the tokenizer constructor.
            **kwargs: Additional keyword arguments passed to the tokenizer constructor.

        Returns:
            PreTrainedTokenizer: The loaded tokenizer instance compatible with
                the seq2seq models in this pool.

        Raises:
            AssertionError: If no tokenizer configuration is provided.
        """
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer
__init__(models, *, tokenizer, model_kwargs=None, **kwargs)

Initialize the sequence-to-sequence language model pool.

Sets up the model pool with configurations for models, tokenizer, and default model loading parameters. Automatically processes model kwargs to handle special configurations like torch_dtype parsing.

Parameters:

  • models (DictConfig) –

    Configuration dictionary specifying the seq2seq models to manage. Keys are model names, values can be model paths/names or detailed configs.

  • tokenizer (Optional[DictConfig]) –

    Configuration for the tokenizer to use with the models. Can be a simple path/name or detailed configuration with _target_.

  • model_kwargs (Optional[DictConfig], default: None ) –

    Default keyword arguments applied to all model loading operations. Common options include torch_dtype, device_map, etc. The torch_dtype field is automatically parsed from string to dtype.

  • **kwargs

    Additional arguments passed to the parent BaseModelPool.

Example
pool = Seq2SeqLMPool(
    models={
        "base": "t5-base",
        "large": {"_target_": "transformers.AutoModelForSeq2SeqLM",
                 "pretrained_model_name_or_path": "t5-large"}
    },
    tokenizer="t5-base",
    model_kwargs={"torch_dtype": "bfloat16"}
)
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
def __init__(
    self,
    models: DictConfig,
    *,
    tokenizer: Optional[DictConfig],
    model_kwargs: Optional[DictConfig] = None,
    **kwargs,
):
    """Initialize the sequence-to-sequence language model pool.

    Sets up the model pool with configurations for models, tokenizer, and
    default model loading parameters. Automatically processes model kwargs
    to handle special configurations like torch_dtype parsing.

    Args:
        models: Configuration dictionary specifying the seq2seq models to manage.
            Keys are model names, values can be model paths/names or detailed configs.
        tokenizer: Configuration for the tokenizer to use with the models.
            Can be a simple path/name or detailed configuration with _target_.
        model_kwargs: Default keyword arguments applied to all model loading
            operations. Common options include torch_dtype, device_map, etc.
            The torch_dtype field is automatically parsed from string to dtype.
        **kwargs: Additional arguments passed to the parent BaseModelPool.

    Example:
        ```python
        pool = Seq2SeqLMPool(
            models={
                "base": "t5-base",
                "large": {"_target_": "transformers.AutoModelForSeq2SeqLM",
                         "pretrained_model_name_or_path": "t5-large"}
            },
            tokenizer="t5-base",
            model_kwargs={"torch_dtype": "bfloat16"}
        )
        ```
    """
    super().__init__(models, **kwargs)
    self._tokenizer = tokenizer
    self._model_kwargs = model_kwargs
    if self._model_kwargs is None:
        self._model_kwargs = DictConfig({})
    with flag_override(self._model_kwargs, "allow_objects", True):
        if hasattr(self._model_kwargs, "torch_dtype"):
            self._model_kwargs.torch_dtype = parse_dtype(
                self._model_kwargs.torch_dtype
            )
load_model(model_name_or_config, *args, **kwargs)

Load a sequence-to-sequence language model from the pool.

Loads a seq2seq model using the parent class loading mechanism while automatically applying the pool's default model kwargs. The method merges the pool's model_kwargs with any additional kwargs provided, giving priority to the explicitly provided kwargs.

Parameters:

  • model_name_or_config (str | DictConfig) –

    Either a string model name from the pool configuration or a DictConfig containing model loading parameters.

  • *args

    Additional positional arguments passed to the parent load_model method.

  • **kwargs

    Additional keyword arguments that override the pool's default model_kwargs. Common options include device, torch_dtype, etc.

Returns:

  • AutoModelForSeq2SeqLM

    The loaded sequence-to-sequence language model.

Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
    """Load a sequence-to-sequence language model from the pool.

    Loads a seq2seq model using the parent class loading mechanism while
    automatically applying the pool's default model kwargs. The method
    merges the pool's model_kwargs with any additional kwargs provided,
    giving priority to the explicitly provided kwargs.

    Args:
        model_name_or_config: Either a string model name from the pool
            configuration or a DictConfig containing model loading parameters.
        *args: Additional positional arguments passed to the parent load_model method.
        **kwargs: Additional keyword arguments that override the pool's default
            model_kwargs. Common options include device, torch_dtype, etc.

    Returns:
        AutoModelForSeq2SeqLM: The loaded sequence-to-sequence language model.
    """
    model_kwargs = deepcopy(self._model_kwargs)
    model_kwargs.update(kwargs)
    return super().load_model(model_name_or_config, *args, **model_kwargs)
load_tokenizer(*args, **kwargs)

Load the tokenizer associated with the sequence-to-sequence models.

Loads a tokenizer based on the tokenizer configuration provided during pool initialization. The tokenizer should be compatible with the seq2seq models in the pool and is typically used for preprocessing input text and postprocessing generated output.

Parameters:

  • *args

    Additional positional arguments passed to the tokenizer constructor.

  • **kwargs

    Additional keyword arguments passed to the tokenizer constructor.

Returns:

  • PreTrainedTokenizer

    The loaded tokenizer instance compatible with the seq2seq models in this pool.

Raises:

  • AssertionError

    If no tokenizer configuration is provided.

Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
def load_tokenizer(self, *args, **kwargs):
    """Load the tokenizer associated with the sequence-to-sequence models.

    Loads a tokenizer based on the tokenizer configuration provided during
    pool initialization. The tokenizer should be compatible with the seq2seq
    models in the pool and is typically used for preprocessing input text
    and postprocessing generated output.

    Args:
        *args: Additional positional arguments passed to the tokenizer constructor.
        **kwargs: Additional keyword arguments passed to the tokenizer constructor.

    Returns:
        PreTrainedTokenizer: The loaded tokenizer instance compatible with
            the seq2seq models in this pool.

    Raises:
        AssertionError: If no tokenizer configuration is provided.
    """
    assert self._tokenizer is not None, "Tokenizer is not defined in the config"
    tokenizer = instantiate(self._tokenizer, *args, **kwargs)
    return tokenizer
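
A short sketch of the tokenizer configuration this method instantiates (identifiers are illustrative):

from omegaconf import DictConfig
from fusion_bench.modelpool.seq2seq_lm.modelpool import Seq2SeqLMPool

pool = Seq2SeqLMPool(
    models=DictConfig({}),  # tokenizer-only usage for illustration
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "google/flan-t5-base",
        }
    ),
)
tokenizer = pool.load_tokenizer()  # instantiated via the Hydra-style config above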

SequenceClassificationModelPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
class SequenceClassificationModelPool(BaseModelPool):

    def __init__(
        self,
        models,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        # process `model_kwargs`
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )

    @override
    def load_model(
        self,
        model_name_or_config: str | DictConfig,
        *args,
        **kwargs,
    ) -> Union[PreTrainedModel, "LlamaForSequenceClassification"]:
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)
        if isinstance(model_name_or_config, str):
            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
        return super().load_model(model_name_or_config, *args, **model_kwargs)

    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        log.info("Loading tokenizer.", stacklevel=2)
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer

    @override
    def save_model(
        self,
        model: PreTrainedModel,
        path: str,
        push_to_hub: bool = False,
        model_dtype: Optional[str] = None,
        save_tokenizer: bool = False,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """
        Save the model to the specified path.

        Args:
            model (PreTrainedModel): The model to be saved.
            path (str): The path where the model will be saved.
            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
        """
        path = os.path.expanduser(path)
        if save_tokenizer:
            if tokenizer_kwargs is None:
                tokenizer_kwargs = {}
            # load the tokenizer
            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
            tokenizer.save_pretrained(
                path,
                push_to_hub=push_to_hub,
            )
        if model_dtype is not None:
            model.to(dtype=parse_dtype(model_dtype))
        model.save_pretrained(
            path,
            push_to_hub=push_to_hub,
            **kwargs,
        )
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, **kwargs)

Save the model to the specified path.

Parameters:

  • model (PreTrainedModel) –

    The model to be saved.

  • path (str) –

    The path where the model will be saved.

  • push_to_hub (bool, default: False ) –

    Whether to push the model to the Hugging Face Hub. Defaults to False.

  • save_tokenizer (bool, default: False ) –

    Whether to save the tokenizer along with the model. Defaults to False.

  • **kwargs

    Additional keyword arguments passed to the save_pretrained method.

Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
@override
def save_model(
    self,
    model: PreTrainedModel,
    path: str,
    push_to_hub: bool = False,
    model_dtype: Optional[str] = None,
    save_tokenizer: bool = False,
    tokenizer_kwargs=None,
    **kwargs,
):
    """
    Save the model to the specified path.

    Args:
        model (PreTrainedModel): The model to be saved.
        path (str): The path where the model will be saved.
        push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
        save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
        **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
    """
    path = os.path.expanduser(path)
    if save_tokenizer:
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        # load the tokenizer
        tokenizer = self.load_tokenizer(**tokenizer_kwargs)
        tokenizer.save_pretrained(
            path,
            push_to_hub=push_to_hub,
        )
    if model_dtype is not None:
        model.to(dtype=parse_dtype(model_dtype))
    model.save_pretrained(
        path,
        push_to_hub=push_to_hub,
        **kwargs,
    )
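
A minimal end-to-end sketch for this pool (model entry, tokenizer, and output path are placeholders, and it assumes the base pool instantiates _target_-style model entries):

from omegaconf import DictConfig
from fusion_bench.modelpool.seq_classification_lm.seq_classification_lm import (
    SequenceClassificationModelPool,
)

pool = SequenceClassificationModelPool(
    models=DictConfig(
        {
            "_pretrained_": {
                "_target_": "transformers.AutoModelForSequenceClassification.from_pretrained",
                "pretrained_model_name_or_path": "distilbert-base-uncased",
            }
        }
    ),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "distilbert-base-uncased",
        }
    ),
)
model = pool.load_model("_pretrained_")
# Save the model in float16 together with its tokenizer.
pool.save_model(model, "outputs/seq_cls_model", save_tokenizer=True, model_dtype="float16")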

PeftModelForSeq2SeqLMPool

Bases: ModelPool

Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
class PeftModelForSeq2SeqLMPool(ModelPool):
    def load_model(self, model_config: str | DictConfig):
        """
        Load a model based on the provided configuration.

        The configuration options of `model_config` are:

        - name: The name of the model. If it is "_pretrained_", a pretrained Seq2Seq language model is returned.
        - path: The path where the model is stored.
        - is_trainable: A boolean indicating whether the model parameters should be trainable. Default is `True`.
        - merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is `True`.


        Args:
            model_config (str | DictConfig): The configuration for the model. This can be either a string (name of the model) or a DictConfig object containing the model configuration.


        Returns:
            model: The loaded model. If the model name is "_pretrained_", it returns a pretrained Seq2Seq language model. Otherwise, it returns a PEFT model.
        """
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)
        with timeit_context(f"Loading model {model_config['name']}"):
            if model_config["name"] == "_pretrained_":
                model = AutoModelForSeq2SeqLM.from_pretrained(model_config["path"])
                return model
            else:
                model = self.load_model("_pretrained_")
                peft_model = PeftModel.from_pretrained(
                    model,
                    model_config["path"],
                    is_trainable=model_config.get("is_trainable", True),
                )
                if model_config.get("merge_and_unload", True):
                    return peft_model.merge_and_unload()
                else:
                    return peft_model
load_model(model_config)

Load a model based on the provided configuration.

The configuration options of model_config are:

  • name: The name of the model. If it is "_pretrained_", a pretrained Seq2Seq language model is returned.
  • path: The path where the model is stored.
  • is_trainable: A boolean indicating whether the model parameters should be trainable. Default is True.
  • merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is True.
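
For reference, a hypothetical adapter entry with these options, written as the mapping that load_model receives after configuration loading (names and paths are placeholders):

from omegaconf import DictConfig

adapter_entry = DictConfig(
    {
        "name": "glue_cola_lora",
        "path": "outputs/flan-t5-base/glue_cola_lora",  # PEFT adapter directory
        "is_trainable": False,      # keep adapter parameters frozen
        "merge_and_unload": True,   # fold adapter weights into the base model
    }
)
# A configured PeftModelForSeq2SeqLMPool would then return the merged model via
# pool.load_model(adapter_entry).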

Parameters:

  • model_config (str | DictConfig) –

    The configuration for the model. This can be either a string (name of the model) or a DictConfig object containing the model configuration.

Returns:

  • model

    The loaded model. If the model name is "_pretrained_", it returns a pretrained Seq2Seq language model. Otherwise, it returns a PEFT model.

Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
def load_model(self, model_config: str | DictConfig):
    """
    Load a model based on the provided configuration.

    The configuration options of `model_config` are:

    - name: The name of the model. If it is "_pretrained_", a pretrained Seq2Seq language model is returned.
    - path: The path where the model is stored.
    - is_trainable: A boolean indicating whether the model parameters should be trainable. Default is `True`.
    - merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is `True`.


    Args:
        model_config (str | DictConfig): The configuration for the model. This can be either a string (name of the model) or a DictConfig object containing the model configuration.


    Returns:
        model: The loaded model. If the model name is "_pretrained_", it returns a pretrained Seq2Seq language model. Otherwise, it returns a PEFT model.
    """
    if isinstance(model_config, str):
        model_config = self.get_model_config(model_config)
    with timeit_context(f"Loading model {model_config['name']}"):
        if model_config["name"] == "_pretrained_":
            model = AutoModelForSeq2SeqLM.from_pretrained(model_config["path"])
            return model
        else:
            model = self.load_model("_pretrained_")
            peft_model = PeftModel.from_pretrained(
                model,
                model_config["path"],
                is_trainable=model_config.get("is_trainable", True),
            )
            if model_config.get("merge_and_unload", True):
                return peft_model.merge_and_unload()
            else:
                return peft_model

Causal Language Models (Llama, Mistral, Qwen...)

CausalLMPool

Bases: BaseModelPool

A model pool for managing and loading causal language models.

This class provides a unified interface for loading and managing multiple causal language models, typically used in model fusion and ensemble scenarios. It supports both eager and lazy loading strategies, and handles model configuration through YAML configs or direct instantiation.

The pool can manage models from Hugging Face Hub, local paths, or custom configurations. It also provides tokenizer management and model saving capabilities with optional Hugging Face Hub integration.

Parameters:

  • models

    Dictionary or configuration specifying the models to be managed. Can contain model names mapped to paths or detailed configurations.

  • tokenizer (Optional[DictConfig | str]) –

    Tokenizer configuration, either a string path/name or a DictConfig with detailed tokenizer settings.

  • model_kwargs (Optional[DictConfig], default: None ) –

    Additional keyword arguments passed to model loading. Common options include torch_dtype, device_map, etc.

  • enable_lazy_loading (bool, default: False ) –

    Whether to use lazy loading for models. When True, models are loaded as LazyStateDict objects instead of actual models, which can save memory for large model collections.

  • **kwargs

    Additional arguments passed to the parent BaseModelPool.

Example
>>> pool = CausalLMPool(
...     models={
...         "model_a": "microsoft/DialoGPT-medium",
...         "model_b": "/path/to/local/model"
...     },
...     tokenizer="microsoft/DialoGPT-medium",
...     model_kwargs={"torch_dtype": "bfloat16"}
... )
>>> model = pool.load_model("model_a")
>>> tokenizer = pool.load_tokenizer()
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@auto_register_config
class CausalLMPool(BaseModelPool):
    """A model pool for managing and loading causal language models.

    This class provides a unified interface for loading and managing multiple
    causal language models, typically used in model fusion and ensemble scenarios.
    It supports both eager and lazy loading strategies, and handles model
    configuration through YAML configs or direct instantiation.

    The pool can manage models from Hugging Face Hub, local paths, or custom
    configurations. It also provides tokenizer management and model saving
    capabilities with optional Hugging Face Hub integration.

    Args:
        models: Dictionary or configuration specifying the models to be managed.
            Can contain model names mapped to paths or detailed configurations.
        tokenizer: Tokenizer configuration, either a string path/name or
            a DictConfig with detailed tokenizer settings.
        model_kwargs: Additional keyword arguments passed to model loading.
            Common options include torch_dtype, device_map, etc.
        enable_lazy_loading: Whether to use lazy loading for models. When True,
            models are loaded as LazyStateDict objects instead of actual models,
            which can save memory for large model collections.
        **kwargs: Additional arguments passed to the parent BaseModelPool.

    Example:
        ```python
        >>> pool = CausalLMPool(
        ...     models={
        ...         "model_a": "microsoft/DialoGPT-medium",
        ...         "model_b": "/path/to/local/model"
        ...     },
        ...     tokenizer="microsoft/DialoGPT-medium",
        ...     model_kwargs={"torch_dtype": "bfloat16"}
        ... )
        >>> model = pool.load_model("model_a")
        >>> tokenizer = pool.load_tokenizer()
        ```
    """

    def __init__(
        self,
        models,
        *,
        tokenizer: Optional[DictConfig | str],
        model_kwargs: Optional[DictConfig] = None,
        enable_lazy_loading: bool = False,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        if model_kwargs is None:
            self.model_kwargs = DictConfig({})

    def get_model_path(self, model_name: str):
        """Extract the model path from the model configuration.

        Args:
            model_name: The name of the model as defined in the models configuration.

        Returns:
            str: The path or identifier for the model. For string configurations,
                returns the string directly. For dict configurations, extracts
                the 'pretrained_model_name_or_path' field.

        Raises:
            RuntimeError: If the model configuration is invalid or the model
                name is not found in the configuration.
        """
        model_name_or_config = self._models[model_name]
        if isinstance(model_name_or_config, str):
            return model_name_or_config
        elif isinstance(model_name_or_config, (DictConfig, dict)):
            return model_name_or_config.get("pretrained_model_name_or_path")
        else:
            raise RuntimeError("Invalid model configuration")

    def get_model_kwargs(self):
        """Get processed model keyword arguments for model loading.

        Converts the stored `model_kwargs` from DictConfig to a regular dictionary
        and processes special arguments like torch_dtype for proper model loading.

        Returns:
            dict: Processed keyword arguments ready to be passed to model
                loading functions. The torch_dtype field, if present, is
                converted from string to the appropriate torch dtype object.
        """
        model_kwargs = (
            OmegaConf.to_container(self.model_kwargs, resolve=True)
            if isinstance(self.model_kwargs, DictConfig)
            else self.model_kwargs
        )
        if "torch_dtype" in model_kwargs:
            model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
        return model_kwargs

    @override
    def load_model(
        self,
        model_name_or_config: str | DictConfig,
        *args,
        **kwargs,
    ) -> Union[PreTrainedModel, LazyStateDict]:
        """Load a causal language model from the model pool.

        This method supports multiple loading strategies:
        1. Loading by model name from the configured model pool
        2. Loading from a direct configuration dictionary
        3. Lazy loading using LazyStateDict for memory efficiency

        The method automatically handles different model configuration formats
        and applies the appropriate loading strategy based on the enable_lazy_loading flag.

        Args:
            model_name_or_config: Either a string model name that exists in the
                model pool configuration, or a DictConfig/dict containing the
                model configuration directly.
            *args: Additional positional arguments passed to the model constructor.
            **kwargs: Additional keyword arguments passed to the model constructor.
                These will be merged with the pool's model_kwargs.

        Returns:
            Union[PreTrainedModel, LazyStateDict]: The loaded model. Returns a
                PreTrainedModel for normal loading or a LazyStateDict for lazy loading.

        Raises:
            RuntimeError: If the model configuration is invalid.
            KeyError: If the model name is not found in the model pool.

        Example YAML configurations:
            Simple string configuration:
            ```yaml
            models:
              _pretrained_: path_to_pretrained_model
              model_a: path_to_model_a
              model_b: path_to_model_b
            ```

            Detailed configuration:
            ```yaml
            models:
              _pretrained_:
                _target_: transformers.AutoModelForCausalLM
                pretrained_model_name_or_path: path_to_pretrained_model
              model_a:
                _target_: transformers.AutoModelForCausalLM
                pretrained_model_name_or_path: path_to_model_a
            ```
        """
        model_kwargs = self.get_model_kwargs()
        model_kwargs.update(kwargs)

        if isinstance(model_name_or_config, str):
            # If model_name_or_config is a string, it is the name or the path of the model
            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
            if model_name_or_config in self._models.keys():
                model_config = self._models[model_name_or_config]
                if isinstance(model_config, str):
                    # model_config is a string
                    if not self.enable_lazy_loading:
                        model = AutoModelForCausalLM.from_pretrained(
                            model_config,
                            *args,
                            **model_kwargs,
                        )
                    else:
                        # model_config is a string, but we want to use LazyStateDict
                        model = LazyStateDict(
                            checkpoint=model_config,
                            meta_module_class=AutoModelForCausalLM,
                            *args,
                            **model_kwargs,
                        )
                    return model
        elif isinstance(model_name_or_config, (DictConfig, Dict)):
            model_config = model_name_or_config

        if not self.enable_lazy_loading:
            model = instantiate(model_config, *args, **model_kwargs)
        else:
            meta_module_class = model_config.pop("_target_")
            checkpoint = model_config.pop("pretrained_model_name_or_path")
            model = LazyStateDict(
                checkpoint=checkpoint,
                meta_module_class=meta_module_class,
                *args,
                **model_kwargs,
            )
        return model

    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
        """Load the tokenizer associated with this model pool.

        Loads a tokenizer based on the tokenizer configuration provided during
        pool initialization. Supports both simple string paths and detailed
        configuration dictionaries.

        Args:
            *args: Additional positional arguments passed to the tokenizer constructor.
            **kwargs: Additional keyword arguments passed to the tokenizer constructor.

        Returns:
            PreTrainedTokenizer: The loaded tokenizer instance.

        Raises:
            AssertionError: If no tokenizer is defined in the configuration.

        Example YAML configurations:
            Simple string configuration:
            ```yaml
            tokenizer: google/gemma-2-2b-it
            ```

            Detailed configuration:
            ```yaml
            tokenizer:
              _target_: transformers.AutoTokenizer
              pretrained_model_name_or_path: google/gemma-2-2b-it
              use_fast: true
              padding_side: left
            ```
        """
        assert self.tokenizer is not None, "Tokenizer is not defined in the config"
        log.info("Loading tokenizer.", stacklevel=2)
        if isinstance(self.tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
        else:
            tokenizer = instantiate(self.tokenizer, *args, **kwargs)
        return tokenizer

    @override
    def save_model(
        self,
        model: PreTrainedModel,
        path: str,
        push_to_hub: bool = False,
        model_dtype: Optional[str] = None,
        save_tokenizer: bool = False,
        tokenizer_kwargs=None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        algorithm_config: Optional[DictConfig] = None,
        description: Optional[str] = None,
        base_model_in_modelcard: bool = True,
        **kwargs,
    ):
        """Save a model to the specified path with optional tokenizer and Hub upload.

        This method provides comprehensive model saving capabilities including
        optional tokenizer saving, dtype conversion, model card creation, and
        Hugging Face Hub upload. The model is saved in the standard Hugging Face format.

        Args:
            model: The PreTrainedModel instance to be saved.
            path: The local path where the model will be saved. Supports tilde
                expansion for home directory paths.
            push_to_hub: Whether to push the saved model to the Hugging Face Hub.
                Requires proper authentication and repository permissions.
            model_dtype: Optional string specifying the target dtype for the model
                before saving (e.g., "float16", "bfloat16"). The model will be
                converted to this dtype before saving.
            save_tokenizer: Whether to save the tokenizer alongside the model.
                If True, the tokenizer will be loaded using the pool's tokenizer
                configuration and saved to the same path.
            tokenizer_kwargs: Additional keyword arguments for tokenizer loading
                when save_tokenizer is True.
            tokenizer: Optional pre-loaded tokenizer instance. If provided, this
                tokenizer will be saved regardless of the save_tokenizer flag.
            algorithm_config: Optional DictConfig containing algorithm configuration.
                If provided, a model card will be created with algorithm details.
            description: Optional description for the model card. If not provided
                and algorithm_config is given, a default description will be generated.
            **kwargs: Additional keyword arguments passed to the model's
                save_pretrained method.

        Example:
            ```python
            >>> pool = CausalLMPool(models=..., tokenizer=...)
            >>> model = pool.load_model("my_model")
            >>> pool.save_model(
            ...     model,
            ...     "/path/to/save",
            ...     save_tokenizer=True,
            ...     model_dtype="float16",
            ...     push_to_hub=True,
            ...     algorithm_config=algorithm_config,
            ...     description="Custom merged model"
            ... )
            ```
        """
        path = os.path.expanduser(path)
        # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
        if save_tokenizer or tokenizer is not None:
            if tokenizer is None:
                if tokenizer_kwargs is None:
                    tokenizer_kwargs = {}
                # load the tokenizer
                tokenizer = self.load_tokenizer(**tokenizer_kwargs)
            tokenizer.save_pretrained(
                path,
                push_to_hub=push_to_hub,
            )
        if model_dtype is not None:
            model.to(dtype=parse_dtype(model_dtype))
        model.save_pretrained(
            path,
            push_to_hub=push_to_hub,
            **kwargs,
        )

        # Create and save model card if algorithm_config is provided
        if algorithm_config is not None and rank_zero_only.rank == 0:
            if description is None:
                description = "Model created using FusionBench."
            model_card_str = create_default_model_card(
                base_model=(
                    self.get_model_path("_pretrained_")
                    if base_model_in_modelcard and self.has_pretrained
                    else None
                ),
                models=[self.get_model_path(m) for m in self.model_names],
                description=description,
                algorithm_config=algorithm_config,
                modelpool_config=self.config,
            )
            with open(os.path.join(path, "README.md"), "w") as f:
                f.write(model_card_str)
get_model_kwargs()

Get processed model keyword arguments for model loading.

Converts the stored model_kwargs from DictConfig to a regular dictionary and processes special arguments like torch_dtype for proper model loading.

Returns:

  • dict

    Processed keyword arguments ready to be passed to model loading functions. The torch_dtype field, if present, is converted from string to the appropriate torch dtype object.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def get_model_kwargs(self):
    """Get processed model keyword arguments for model loading.

    Converts the stored `model_kwargs` from DictConfig to a regular dictionary
    and processes special arguments like torch_dtype for proper model loading.

    Returns:
        dict: Processed keyword arguments ready to be passed to model
            loading functions. The torch_dtype field, if present, is
            converted from string to the appropriate torch dtype object.
    """
    model_kwargs = (
        OmegaConf.to_container(self.model_kwargs, resolve=True)
        if isinstance(self.model_kwargs, DictConfig)
        else self.model_kwargs
    )
    if "torch_dtype" in model_kwargs:
        model_kwargs["torch_dtype"] = parse_dtype(model_kwargs["torch_dtype"])
    return model_kwargs
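
A small illustration of the dtype handling described above (the checkpoint identifier is a placeholder reused from the class-level example):

from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool

pool = CausalLMPool(
    models={"model_a": "microsoft/DialoGPT-medium"},
    tokenizer="microsoft/DialoGPT-medium",
    model_kwargs={"torch_dtype": "bfloat16"},
)
kwargs = pool.get_model_kwargs()
print(kwargs["torch_dtype"])  # torch.bfloat16, parsed from the string "bfloat16"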
get_model_path(model_name)

Extract the model path from the model configuration.

Parameters:

  • model_name (str) –

    The name of the model as defined in the models configuration.

Returns:

  • str

    The path or identifier for the model. For string configurations, returns the string directly. For dict configurations, extracts the 'pretrained_model_name_or_path' field.

Raises:

  • RuntimeError

    If the model configuration is invalid or the model name is not found in the configuration.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def get_model_path(self, model_name: str):
    """Extract the model path from the model configuration.

    Args:
        model_name: The name of the model as defined in the models configuration.

    Returns:
        str: The path or identifier for the model. For string configurations,
            returns the string directly. For dict configurations, extracts
            the 'pretrained_model_name_or_path' field.

    Raises:
        RuntimeError: If the model configuration is invalid or the model
            name is not found in the configuration.
    """
    model_name_or_config = self._models[model_name]
    if isinstance(model_name_or_config, str):
        return model_name_or_config
    elif isinstance(model_name_or_config, (DictConfig, dict)):
        return model_name_or_config.get("pretrained_model_name_or_path")
    else:
        raise RuntimeError("Invalid model configuration")
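
To make the two accepted configuration shapes concrete, here is a self-contained sketch of the same resolution logic; the entry names and paths are hypothetical.

```python
from omegaconf import DictConfig, OmegaConf

models = OmegaConf.create({
    "model_a": "org/model-a",  # string form: the entry is the path itself
    "model_b": {"pretrained_model_name_or_path": "org/model-b"},  # dict form
})

def resolve_path(entry):
    # Mirrors get_model_path: strings are returned directly, dict-like configs
    # contribute their 'pretrained_model_name_or_path' field.
    if isinstance(entry, str):
        return entry
    if isinstance(entry, (DictConfig, dict)):
        return entry.get("pretrained_model_name_or_path")
    raise RuntimeError("Invalid model configuration")

print(resolve_path(models["model_a"]))  # org/model-a
print(resolve_path(models["model_b"]))  # org/model-b
```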
load_model(model_name_or_config, *args, **kwargs)

Load a causal language model from the model pool.

This method supports multiple loading strategies:

1. Loading by model name from the configured model pool
2. Loading from a direct configuration dictionary
3. Lazy loading using LazyStateDict for memory efficiency

The method automatically handles different model configuration formats and applies the appropriate loading strategy based on the enable_lazy_loading flag.

Parameters:

  • model_name_or_config (str | DictConfig) –

    Either a string model name that exists in the model pool configuration, or a DictConfig/dict containing the model configuration directly.

  • *args

    Additional positional arguments passed to the model constructor.

  • **kwargs

    Additional keyword arguments passed to the model constructor. These will be merged with the pool's model_kwargs.

Returns:

  • Union[PreTrainedModel, LazyStateDict] –

    The loaded model. Returns a PreTrainedModel for normal loading or a LazyStateDict for lazy loading.

Raises:

  • RuntimeError

    If the model configuration is invalid.

  • KeyError

    If the model name is not found in the model pool.

Example YAML configurations

Simple string configuration:

models:
  _pretrained_: path_to_pretrained_model
  model_a: path_to_model_a
  model_b: path_to_model_b

Detailed configuration:

models:
  _pretrained_:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_pretrained_model
  model_a:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_model_a

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@override
def load_model(
    self,
    model_name_or_config: str | DictConfig,
    *args,
    **kwargs,
) -> Union[PreTrainedModel, LazyStateDict]:
    """Load a causal language model from the model pool.

    This method supports multiple loading strategies:
    1. Loading by model name from the configured model pool
    2. Loading from a direct configuration dictionary
    3. Lazy loading using LazyStateDict for memory efficiency

    The method automatically handles different model configuration formats
    and applies the appropriate loading strategy based on the enable_lazy_loading flag.

    Args:
        model_name_or_config: Either a string model name that exists in the
            model pool configuration, or a DictConfig/dict containing the
            model configuration directly.
        *args: Additional positional arguments passed to the model constructor.
        **kwargs: Additional keyword arguments passed to the model constructor.
            These will be merged with the pool's model_kwargs.

    Returns:
        Union[PreTrainedModel, LazyStateDict]: The loaded model. Returns a
            PreTrainedModel for normal loading or a LazyStateDict for lazy loading.

    Raises:
        RuntimeError: If the model configuration is invalid.
        KeyError: If the model name is not found in the model pool.

    Example YAML configurations:
        Simple string configuration:
        ```yaml
        models:
          _pretrained_: path_to_pretrained_model
          model_a: path_to_model_a
          model_b: path_to_model_b
        ```

        Detailed configuration:
        ```yaml
        models:
          _pretrained_:
            _target_: transformers.AutoModelForCausalLM
            pretrained_model_name_or_path: path_to_pretrained_model
          model_a:
            _target_: transformers.AutoModelForCausalLM
            pretrained_model_name_or_path: path_to_model_a
        ```
    """
    model_kwargs = self.get_model_kwargs()
    model_kwargs.update(kwargs)

    if isinstance(model_name_or_config, str):
        # If model_name_or_config is a string, it is the name or the path of the model
        log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
        if model_name_or_config in self._models.keys():
            model_config = self._models[model_name_or_config]
            if isinstance(model_config, str):
                # model_config is a string
                if not self.enable_lazy_loading:
                    model = AutoModelForCausalLM.from_pretrained(
                        model_config,
                        *args,
                        **model_kwargs,
                    )
                else:
                    # model_config is a string, but we want to use LazyStateDict
                    model = LazyStateDict(
                        checkpoint=model_config,
                        meta_module_class=AutoModelForCausalLM,
                        *args,
                        **model_kwargs,
                    )
                return model
    elif isinstance(model_name_or_config, (DictConfig, Dict)):
        model_config = model_name_or_config

    if not self.enable_lazy_loading:
        model = instantiate(model_config, *args, **model_kwargs)
    else:
        meta_module_class = model_config.pop("_target_")
        checkpoint = model_config.pop("pretrained_model_name_or_path")
        model = LazyStateDict(
            checkpoint=checkpoint,
            meta_module_class=meta_module_class,
            *args,
            **model_kwargs,
        )
    return model
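
A hedged usage sketch follows. The constructor arguments mirror the `save_model` example further down this page (`models=..., tokenizer=...`); `model_kwargs` is assumed to be accepted as a constructor option because it is read as a pool attribute above, the import path is assumed from the source paths shown, and all model paths are placeholders.

```python
from fusion_bench.modelpool import CausalLMPool

pool = CausalLMPool(
    models={
        "_pretrained_": "path_to_pretrained_model",
        "model_a": "path_to_model_a",
    },
    tokenizer="path_to_pretrained_model",
    model_kwargs={"torch_dtype": "bfloat16"},  # assumed constructor option
)

# String entries resolve to AutoModelForCausalLM.from_pretrained(<path>, **model_kwargs)
model_a = pool.load_model("model_a")
base_model = pool.load_model("_pretrained_")  # the shared base model, if defined
```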
load_tokenizer(*args, **kwargs)

Load the tokenizer associated with this model pool.

Loads a tokenizer based on the tokenizer configuration provided during pool initialization. Supports both simple string paths and detailed configuration dictionaries.

Parameters:

  • *args

    Additional positional arguments passed to the tokenizer constructor.

  • **kwargs

    Additional keyword arguments passed to the tokenizer constructor.

Returns:

  • PreTrainedTokenizer –

    The loaded tokenizer instance.

Raises:

  • AssertionError

    If no tokenizer is defined in the configuration.

Example YAML configurations

Simple string configuration:

tokenizer: google/gemma-2-2b-it

Detailed configuration:

tokenizer:
  _target_: transformers.AutoTokenizer
  pretrained_model_name_or_path: google/gemma-2-2b-it
  use_fast: true
  padding_side: left

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
    """Load the tokenizer associated with this model pool.

    Loads a tokenizer based on the tokenizer configuration provided during
    pool initialization. Supports both simple string paths and detailed
    configuration dictionaries.

    Args:
        *args: Additional positional arguments passed to the tokenizer constructor.
        **kwargs: Additional keyword arguments passed to the tokenizer constructor.

    Returns:
        PreTrainedTokenizer: The loaded tokenizer instance.

    Raises:
        AssertionError: If no tokenizer is defined in the configuration.

    Example YAML configurations:
        Simple string configuration:
        ```yaml
        tokenizer: google/gemma-2-2b-it
        ```

        Detailed configuration:
        ```yaml
        tokenizer:
          _target_: transformers.AutoTokenizer
          pretrained_model_name_or_path: google/gemma-2-2b-it
          use_fast: true
          padding_side: left
        ```
    """
    assert self.tokenizer is not None, "Tokenizer is not defined in the config"
    log.info("Loading tokenizer.", stacklevel=2)
    if isinstance(self.tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(self.tokenizer, *args, **kwargs)
    else:
        tokenizer = instantiate(self.tokenizer, *args, **kwargs)
    return tokenizer
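
For readers unfamiliar with Hydra-style `_target_` configurations, the detailed tokenizer configuration above is roughly equivalent to the direct call below; this is a hedged equivalent written against transformers, not FusionBench internals, and reuses the model id from the YAML example.

```python
from transformers import AutoTokenizer

# What instantiating the detailed tokenizer config amounts to
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2-2b-it",
    use_fast=True,
    padding_side="left",
)
```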
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, tokenizer=None, algorithm_config=None, description=None, base_model_in_modelcard=True, **kwargs)

Save a model to the specified path with optional tokenizer and Hub upload.

This method provides comprehensive model saving capabilities including optional tokenizer saving, dtype conversion, model card creation, and Hugging Face Hub upload. The model is saved in the standard Hugging Face format.

Parameters:

  • model (PreTrainedModel) –

    The PreTrainedModel instance to be saved.

  • path (str) –

    The local path where the model will be saved. Supports tilde expansion for home directory paths.

  • push_to_hub (bool, default: False ) –

    Whether to push the saved model to the Hugging Face Hub. Requires proper authentication and repository permissions.

  • model_dtype (Optional[str], default: None ) –

    Optional string specifying the target dtype for the model before saving (e.g., "float16", "bfloat16"). The model will be converted to this dtype before saving.

  • save_tokenizer (bool, default: False ) –

    Whether to save the tokenizer alongside the model. If True, the tokenizer will be loaded using the pool's tokenizer configuration and saved to the same path.

  • tokenizer_kwargs

    Additional keyword arguments for tokenizer loading when save_tokenizer is True.

  • tokenizer (Optional[PreTrainedTokenizer], default: None ) –

    Optional pre-loaded tokenizer instance. If provided, this tokenizer will be saved regardless of the save_tokenizer flag.

  • algorithm_config (Optional[DictConfig], default: None ) –

    Optional DictConfig containing algorithm configuration. If provided, a model card will be created with algorithm details.

  • description (Optional[str], default: None ) –

    Optional description for the model card. If not provided and algorithm_config is given, a default description will be generated.

  • base_model_in_modelcard (bool, default: True ) –

    Whether to list the pool's pretrained base model (if any) as the base model in the generated model card.

  • **kwargs

    Additional keyword arguments passed to the model's save_pretrained method.

Example
>>> pool = CausalLMPool(models=..., tokenizer=...)
>>> model = pool.load_model("my_model")
>>> pool.save_model(
...     model,
...     "/path/to/save",
...     save_tokenizer=True,
...     model_dtype="float16",
...     push_to_hub=True,
...     algorithm_config=algorithm_config,
...     description="Custom merged model"
... )
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@override
def save_model(
    self,
    model: PreTrainedModel,
    path: str,
    push_to_hub: bool = False,
    model_dtype: Optional[str] = None,
    save_tokenizer: bool = False,
    tokenizer_kwargs=None,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    algorithm_config: Optional[DictConfig] = None,
    description: Optional[str] = None,
    base_model_in_modelcard: bool = True,
    **kwargs,
):
    """Save a model to the specified path with optional tokenizer and Hub upload.

    This method provides comprehensive model saving capabilities including
    optional tokenizer saving, dtype conversion, model card creation, and
    Hugging Face Hub upload. The model is saved in the standard Hugging Face format.

    Args:
        model: The PreTrainedModel instance to be saved.
        path: The local path where the model will be saved. Supports tilde
            expansion for home directory paths.
        push_to_hub: Whether to push the saved model to the Hugging Face Hub.
            Requires proper authentication and repository permissions.
        model_dtype: Optional string specifying the target dtype for the model
            before saving (e.g., "float16", "bfloat16"). The model will be
            converted to this dtype before saving.
        save_tokenizer: Whether to save the tokenizer alongside the model.
            If True, the tokenizer will be loaded using the pool's tokenizer
            configuration and saved to the same path.
        tokenizer_kwargs: Additional keyword arguments for tokenizer loading
            when save_tokenizer is True.
        tokenizer: Optional pre-loaded tokenizer instance. If provided, this
            tokenizer will be saved regardless of the save_tokenizer flag.
        algorithm_config: Optional DictConfig containing algorithm configuration.
            If provided, a model card will be created with algorithm details.
        description: Optional description for the model card. If not provided
            and algorithm_config is given, a default description will be generated.
        **kwargs: Additional keyword arguments passed to the model's
            save_pretrained method.

    Example:
        ```python
        >>> pool = CausalLMPool(models=..., tokenizer=...)
        >>> model = pool.load_model("my_model")
        >>> pool.save_model(
        ...     model,
        ...     "/path/to/save",
        ...     save_tokenizer=True,
        ...     model_dtype="float16",
        ...     push_to_hub=True,
        ...     algorithm_config=algorithm_config,
        ...     description="Custom merged model"
        ... )
        ```
    """
    path = os.path.expanduser(path)
    # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
    if save_tokenizer or tokenizer is not None:
        if tokenizer is None:
            if tokenizer_kwargs is None:
                tokenizer_kwargs = {}
            # load the tokenizer
            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
        tokenizer.save_pretrained(
            path,
            push_to_hub=push_to_hub,
        )
    if model_dtype is not None:
        model.to(dtype=parse_dtype(model_dtype))
    model.save_pretrained(
        path,
        push_to_hub=push_to_hub,
        **kwargs,
    )

    # Create and save model card if algorithm_config is provided
    if algorithm_config is not None and rank_zero_only.rank == 0:
        if description is None:
            description = "Model created using FusionBench."
        model_card_str = create_default_model_card(
            base_model=(
                self.get_model_path("_pretrained_")
                if base_model_in_modelcard and self.has_pretrained
                else None
            ),
            models=[self.get_model_path(m) for m in self.model_names],
            description=description,
            algorithm_config=algorithm_config,
            modelpool_config=self.config,
        )
        with open(os.path.join(path, "README.md"), "w") as f:
            f.write(model_card_str)

CausalLMBackbonePool

Bases: CausalLMPool

A specialized model pool that loads only the transformer backbone layers.

This class extends CausalLMPool to provide access to just the transformer layers (backbone) of causal language models, excluding the language modeling head and embeddings. This is useful for model fusion scenarios where only the core transformer layers are needed.

The class automatically extracts the model.layers component from loaded AutoModelForCausalLM instances, providing direct access to the transformer blocks. Lazy loading is not supported for this pool type.

Note

This pool automatically disables lazy loading as it needs to access the internal structure of the model to extract the backbone layers.

Example
>>> backbone_pool = CausalLMBackbonePool(
...     models={"model_a": "microsoft/DialoGPT-medium"},
...     tokenizer="microsoft/DialoGPT-medium"
... )
>>> layers = backbone_pool.load_model("model_a")  # Returns nn.ModuleList of transformer layers
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
class CausalLMBackbonePool(CausalLMPool):
    """A specialized model pool that loads only the transformer backbone layers.

    This class extends CausalLMPool to provide access to just the transformer
    layers (backbone) of causal language models, excluding the language modeling
    head and embeddings. This is useful for model fusion scenarios where only
    the core transformer layers are needed.

    The class automatically extracts the `model.layers` component from loaded
    AutoModelForCausalLM instances, providing direct access to the transformer
    blocks. Lazy loading is not supported for this pool type.

    Note:
        This pool automatically disables lazy loading as it needs to access
        the internal structure of the model to extract the backbone layers.

    Example:
        ```python
        >>> backbone_pool = CausalLMBackbonePool(
        ...     models={"model_a": "microsoft/DialoGPT-medium"},
        ...     tokenizer="microsoft/DialoGPT-medium"
        ... )
        >>> layers = backbone_pool.load_model("model_a")  # Returns nn.ModuleList of transformer layers
        ```
    """

    def load_model(
        self, model_name_or_config: str | DictConfig, *args, **kwargs
    ) -> Module:
        """Load only the transformer backbone layers from a causal language model.

        This method loads a complete causal language model and then extracts
        only the transformer layers (backbone), discarding the embedding layers
        and language modeling head. This is useful for model fusion scenarios
        where only the core transformer computation is needed.

        Args:
            model_name_or_config: Either a string model name from the pool
                configuration or a DictConfig with model loading parameters.
            *args: Additional positional arguments passed to the parent load_model method.
            **kwargs: Additional keyword arguments passed to the parent load_model method.

        Returns:
            Module: The transformer layers (typically a nn.ModuleList) containing
                the core transformer blocks without embeddings or output heads.

        Note:
            Lazy loading is automatically disabled for this method as it needs
            to access the internal model structure to extract the layers.
        """
        if self.enable_lazy_loading:
            log.warning(
                "CausalLMBackbonePool does not support lazy loading. "
                "Falling back to normal loading."
            )
            self.enable_lazy_loading = False
        model: AutoModelForCausalLM = super().load_model(
            model_name_or_config, *args, **kwargs
        )
        return model.model.layers
load_model(model_name_or_config, *args, **kwargs)

Load only the transformer backbone layers from a causal language model.

This method loads a complete causal language model and then extracts only the transformer layers (backbone), discarding the embedding layers and language modeling head. This is useful for model fusion scenarios where only the core transformer computation is needed.

Parameters:

  • model_name_or_config (str | DictConfig) –

    Either a string model name from the pool configuration or a DictConfig with model loading parameters.

  • *args

    Additional positional arguments passed to the parent load_model method.

  • **kwargs

    Additional keyword arguments passed to the parent load_model method.

Returns:

  • Module –

    The transformer layers (typically a nn.ModuleList) containing the core transformer blocks without embeddings or output heads.

Note

Lazy loading is automatically disabled for this method as it needs to access the internal model structure to extract the layers.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def load_model(
    self, model_name_or_config: str | DictConfig, *args, **kwargs
) -> Module:
    """Load only the transformer backbone layers from a causal language model.

    This method loads a complete causal language model and then extracts
    only the transformer layers (backbone), discarding the embedding layers
    and language modeling head. This is useful for model fusion scenarios
    where only the core transformer computation is needed.

    Args:
        model_name_or_config: Either a string model name from the pool
            configuration or a DictConfig with model loading parameters.
        *args: Additional positional arguments passed to the parent load_model method.
        **kwargs: Additional keyword arguments passed to the parent load_model method.

    Returns:
        Module: The transformer layers (typically a nn.ModuleList) containing
            the core transformer blocks without embeddings or output heads.

    Note:
        Lazy loading is automatically disabled for this method as it needs
        to access the internal model structure to extract the layers.
    """
    if self.enable_lazy_loading:
        log.warning(
            "CausalLMBackbonePool does not support lazy loading. "
            "Falling back to normal loading."
        )
        self.enable_lazy_loading = False
    model: AutoModelForCausalLM = super().load_model(
        model_name_or_config, *args, **kwargs
    )
    return model.model.layers

load_peft_causal_lm(base_model_path, peft_model_path, torch_dtype='bfloat16', is_trainable=True, merge_and_unload=False)

Load a causal language model with PEFT (Parameter-Efficient Fine-Tuning) adapters.

This function loads a base causal language model and applies PEFT adapters (such as LoRA, AdaLoRA, or other parameter-efficient fine-tuning methods) to create a fine-tuned model. It supports either keeping the adapters separate or merging them into the base model.

Parameters:

  • base_model_path (str) –

    Path or identifier for the base causal language model. Can be a Hugging Face model name or local path.

  • peft_model_path (str) –

    Path to the PEFT adapter configuration and weights. This should contain the adapter_config.json and adapter weights.

  • torch_dtype (str, default: 'bfloat16' ) –

    The torch data type to use for the model. Common options include "float16", "bfloat16", "float32". Defaults to "bfloat16".

  • is_trainable (bool, default: True ) –

    Whether the loaded PEFT model should be trainable. Set to False for inference-only usage to save memory.

  • merge_and_unload (bool, default: False ) –

    Whether to merge the PEFT adapters into the base model and unload the adapter weights. When True, returns a standard PreTrainedModel instead of a PeftModel.

Returns:

  • Union[PeftModel, PreTrainedModel] –

    The loaded model with PEFT adapters. Returns a PeftModel if merge_and_unload is False, or a PreTrainedModel if the adapters are merged and unloaded.

Example
>>> # Load model with adapters for training
>>> model = load_peft_causal_lm(
...     "microsoft/DialoGPT-medium",
...     "/path/to/lora/adapters",
...     is_trainable=True
... )

>>> # Load and merge adapters for inference
>>> merged_model = load_peft_causal_lm(
...     "microsoft/DialoGPT-medium",
...     "/path/to/lora/adapters",
...     merge_and_unload=True,
...     is_trainable=False
... )
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def load_peft_causal_lm(
    base_model_path: str,
    peft_model_path: str,
    torch_dtype: str = "bfloat16",
    is_trainable: bool = True,
    merge_and_unload: bool = False,
):
    """Load a causal language model with PEFT (Parameter-Efficient Fine-Tuning) adapters.

    This function loads a base causal language model and applies PEFT adapters
    (such as LoRA, AdaLoRA, or other parameter-efficient fine-tuning methods)
    to create a fine-tuned model. It supports either keeping the adapters separate
    or merging them into the base model.

    Args:
        base_model_path: Path or identifier for the base causal language model.
            Can be a Hugging Face model name or local path.
        peft_model_path: Path to the PEFT adapter configuration and weights.
            This should contain the adapter_config.json and adapter weights.
        torch_dtype: The torch data type to use for the model. Common options
            include "float16", "bfloat16", "float32". Defaults to "bfloat16".
        is_trainable: Whether the loaded PEFT model should be trainable.
            Set to False for inference-only usage to save memory.
        merge_and_unload: Whether to merge the PEFT adapters into the base model
            and unload the adapter weights. When True, returns a standard
            PreTrainedModel instead of a PeftModel.

    Returns:
        Union[PeftModel, PreTrainedModel]: The loaded model with PEFT adapters.
            Returns a PeftModel if merge_and_unload is False, or a PreTrainedModel
            if the adapters are merged and unloaded.

    Example:
        ```python
        >>> # Load model with adapters for training
        >>> model = load_peft_causal_lm(
        ...     "microsoft/DialoGPT-medium",
        ...     "/path/to/lora/adapters",
        ...     is_trainable=True
        ... )

        >>> # Load and merge adapters for inference
        >>> merged_model = load_peft_causal_lm(
        ...     "microsoft/DialoGPT-medium",
        ...     "/path/to/lora/adapters",
        ...     merge_and_unload=True,
        ...     is_trainable=False
        ... )
        ```
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch_dtype
    )
    model = peft.PeftModel.from_pretrained(
        base_model,
        peft_model_path,
        is_trainable=is_trainable,
    )
    if merge_and_unload:
        model = model.merge_and_unload()
    return model

Others

Transformers AutoModel

AutoModelPool

Bases: ModelPool

Source code in fusion_bench/modelpool/huggingface_automodel.py
class AutoModelPool(ModelPool):
    def load_model(self, model_config: str | DictConfig) -> Module:
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)
        else:
            model_config = model_config

        model = AutoModel.from_pretrained(model_config.path)
        return model
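
Since AutoModelPool ships without its own example, the sketch below shows a configuration shape that would satisfy the `model_config.path` access in `load_model`; the `path` field name is inferred from the source snippet above rather than from separate documentation, and the model ids are placeholders.

```python
from omegaconf import OmegaConf

# Hypothetical pool configuration for AutoModelPool
models = OmegaConf.create({
    "encoder_a": {"path": "bert-base-uncased"},
    "encoder_b": {"path": "roberta-base"},
})

# With such a configuration, pool.load_model("encoder_a") would resolve to
# AutoModel.from_pretrained("bert-base-uncased").
```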