
fusion_bench.modelpool

Base Class

BaseModelPool

Bases: BaseYAMLSerializableModel, HydraConfigMixin

A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. Beyond datasets, a pool may also carry model-specific components: a vision model pool may include an image preprocessor, and a language model pool may include a tokenizer.

Attributes:

  • _models (DictConfig) –

    Configuration for all models in the pool.

  • _train_datasets (Optional[DictConfig]) –

    Configuration for training datasets.

  • _val_datasets (Optional[DictConfig]) –

    Configuration for validation datasets.

  • _test_datasets (Optional[DictConfig]) –

    Configuration for testing datasets.

  • _usage_ (Optional[str]) –

    Optional usage information.

  • _version_ (Optional[str]) –

    Optional version information.
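
For illustration, here is a minimal, hedged sketch of constructing a pool directly from in-memory modules (the model names are hypothetical; this assumes `BaseModelPool` is importable from `fusion_bench.modelpool`):

```python
import torch.nn as nn

from fusion_bench.modelpool import BaseModelPool

# a toy pool of two models, keyed by name
pool = BaseModelPool(
    models={
        "model_a": nn.Linear(8, 2),
        "model_b": nn.Linear(8, 2),
    }
)
print(pool.model_names)               # ['model_a', 'model_b']
print(len(pool))                      # 2
model_a = pool.load_model("model_a")  # returns the stored module as-is
```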

Source code in fusion_bench/modelpool/base_pool.py
class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
    """
    A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. Beyond datasets, a pool may also carry model-specific components: a vision model pool may include an image preprocessor, and a language model pool may include a tokenizer.

    Attributes:
        _models (DictConfig): Configuration for all models in the pool.
        _train_datasets (Optional[DictConfig]): Configuration for training datasets.
        _val_datasets (Optional[DictConfig]): Configuration for validation datasets.
        _test_datasets (Optional[DictConfig]): Configuration for testing datasets.
        _usage_ (Optional[str]): Optional usage information.
        _version_ (Optional[str]): Optional version information.
    """

    _program = None
    _config_key = "modelpool"
    _models: Union[DictConfig, Dict[str, nn.Module]]
    _config_mapping = BaseYAMLSerializableModel._config_mapping | {
        "_models": "models",
        "_train_datasets": "train_datasets",
        "_val_datasets": "val_datasets",
        "_test_datasets": "test_datasets",
    }

    def __init__(
        self,
        models: Union[DictConfig, Dict[str, nn.Module], List[nn.Module]],
        *,
        train_datasets: Optional[DictConfig] = None,
        val_datasets: Optional[DictConfig] = None,
        test_datasets: Optional[DictConfig] = None,
        **kwargs,
    ):
        if isinstance(models, List):
            models = {str(model_idx): model for model_idx, model in enumerate(models)}
        self._models = models
        self._train_datasets = train_datasets
        self._val_datasets = val_datasets
        self._test_datasets = test_datasets
        super().__init__(**kwargs)

    @property
    def has_pretrained(self) -> bool:
        """
        Check if the model pool contains a pretrained model.

        Returns:
            bool: True if a pretrained model is available, False otherwise.
        """
        return "_pretrained_" in self._models

    @property
    def all_model_names(self) -> List[str]:
        """
        Get the names of all models in the pool, including special models.

        Returns:
            List[str]: A list of all model names.
        """
        return [name for name in self._models]

    @property
    def model_names(self) -> List[str]:
        """
        Get the names of regular models, excluding special models.

        Returns:
            List[str]: A list of regular model names.
        """
        return [name for name in self._models if not self.is_special_model(name)]

    @property
    def train_dataset_names(self) -> List[str]:
        """
        Get the names of training datasets.

        Returns:
            List[str]: A list of training dataset names.
        """
        return (
            list(self._train_datasets.keys())
            if self._train_datasets is not None
            else []
        )

    @property
    def val_dataset_names(self) -> List[str]:
        """
        Get the names of validation datasets.

        Returns:
            List[str]: A list of validation dataset names.
        """
        return list(self._val_datasets.keys()) if self._val_datasets is not None else []

    @property
    def test_dataset_names(self) -> List[str]:
        """
        Get the names of testing datasets.

        Returns:
            List[str]: A list of testing dataset names.
        """
        return (
            list(self._test_datasets.keys()) if self._test_datasets is not None else []
        )

    def __len__(self):
        return len(self.model_names)

    @staticmethod
    def is_special_model(model_name: str) -> bool:
        """
        Determine if a model is special based on its name.

        Args:
            model_name (str): The name of the model.

        Returns:
            bool: True if the model name indicates a special model, False otherwise.
        """
        return model_name.startswith("_") and model_name.endswith("_")

    def get_model_config(self, model_name: str, return_copy: bool = True) -> DictConfig:
        """
        Get the configuration for the specified model.

        Args:
            model_name (str): The name of the model.
            return_copy (bool): Whether to return a deep copy of the configuration. Defaults to True.

        Returns:
            DictConfig: The configuration for the specified model.
        """
        model_config = self._models[model_name]
        if return_copy:
            model_config = deepcopy(model_config)
        return model_config

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> nn.Module:
        """
        Load a model from the pool based on the provided configuration.

        Args:
            model_name_or_config (Union[str, DictConfig]): The model name or configuration.

        Returns:
            nn.Module: The instantiated model.
        """
        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
        if isinstance(self._models, DictConfig):
            model_config = (
                self._models[model_name_or_config]
                if isinstance(model_name_or_config, str)
                else model_name_or_config
            )
            model = instantiate(model_config, *args, **kwargs)
        elif isinstance(self._models, Dict) and isinstance(model_name_or_config, str):
            model = self._models[model_name_or_config]
        else:
            raise ValueError(
                "The model pool configuration is not in the expected format."
                f"We expected a DictConfig or Dict, but got {type(self._models)}."
            )
        return model

    def load_pretrained_model(self, *args, **kwargs):
        """
        Load the pretrained model (the `_pretrained_` entry) from the pool.

        Returns:
            nn.Module: The loaded pretrained model.
        """
        assert (
            self.has_pretrained
        ), "No pretrained model available. Check that `_pretrained_` is a key in `models`."
        model = self.load_model("_pretrained_", *args, **kwargs)
        return model

    def load_pretrained_or_first_model(self, *args, **kwargs):
        """
        Load the pretrained model if available, otherwise load the first available model.

        Returns:
            nn.Module: The loaded model.
        """
        if self.has_pretrained:
            model = self.load_model("_pretrained_", *args, **kwargs)
        else:
            model = self.load_model(self.model_names[0], *args, **kwargs)
        return model

    def models(self):
        """Iterate over the regular models in the pool, loading each one lazily."""
        for model_name in self.model_names:
            yield self.load_model(model_name)

    def named_models(self):
        """Iterate over `(name, model)` pairs of the regular models in the pool."""
        for model_name in self.model_names:
            yield model_name, self.load_model(model_name)

    def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
        """
        Load the training dataset with the specified name.

        Args:
            dataset_name (str): The name of the dataset.

        Returns:
            Dataset: The instantiated training dataset.
        """
        return instantiate(self._train_datasets[dataset_name], *args, **kwargs)

    def train_datasets(self):
        """Iterate over all training datasets, loading each one lazily."""
        for dataset_name in self.train_dataset_names:
            yield self.load_train_dataset(dataset_name)

    def load_val_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
        """
        Load the validation dataset with the specified name.

        Args:
            dataset_name (str): The name of the dataset.

        Returns:
            Dataset: The instantiated validation dataset.
        """
        return instantiate(self._val_datasets[dataset_name], *args, **kwargs)

    def val_datasets(self):
        """Iterate over all validation datasets, loading each one lazily."""
        for dataset_name in self.val_dataset_names:
            yield self.load_val_dataset(dataset_name)

    def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
        """
        Load the testing dataset with the specified name.

        Args:
            dataset_name (str): The name of the dataset.

        Returns:
            Dataset: The instantiated testing dataset.
        """
        return instantiate(self._test_datasets[dataset_name], *args, **kwargs)

    def test_datasets(self):
        """Iterate over all testing datasets, loading each one lazily."""
        for dataset_name in self.test_dataset_names:
            yield self.load_test_dataset(dataset_name)

    def save_model(self, model: nn.Module, path: str):
        """
        Save the state dictionary of the model to the specified path.

        Args:
            model (nn.Module): The model whose state dictionary is to be saved.
            path (str): The path where the state dictionary will be saved.
        """
        with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)

all_model_names property

Get the names of all models in the pool, including special models.

Returns:

  • List[str] –

    A list of all model names.

has_pretrained property

Check if the model pool contains a pretrained model.

Returns:

  • bool –

    True if a pretrained model is available, False otherwise.

model_names property

Get the names of regular models, excluding special models.

Returns:

  • List[str] –

    A list of regular model names.

test_dataset_names property

Get the names of testing datasets.

Returns:

  • List[str] –

    A list of testing dataset names.

train_dataset_names property

Get the names of training datasets.

Returns:

  • List[str] –

    A list of training dataset names.

val_dataset_names property

Get the names of validation datasets.

Returns:

  • List[str] –

    A list of validation dataset names.

get_model_config(model_name, return_copy=True)

Get the configuration for the specified model.

Parameters:

  • model_name (str) –

    The name of the model.

  • return_copy (bool, default: True ) –

    Whether to return a deep copy of the configuration.

Returns:

  • DictConfig –

    The configuration for the specified model.

Source code in fusion_bench/modelpool/base_pool.py
def get_model_config(self, model_name: str, return_copy: bool = True) -> DictConfig:
    """
    Get the configuration for the specified model.

    Args:
        model_name (str): The name of the model.
        return_copy (bool): Whether to return a deep copy of the configuration. Defaults to True.

    Returns:
        DictConfig: The configuration for the specified model.
    """
    model_config = self._models[model_name]
    if return_copy:
        model_config = deepcopy(model_config)
    return model_config

is_special_model(model_name) staticmethod

Determine if a model is special based on its name.

Parameters:

  • model_name (str) –

    The name of the model.

Returns:

  • bool –

    True if the model name indicates a special model, False otherwise.

Source code in fusion_bench/modelpool/base_pool.py
@staticmethod
def is_special_model(model_name: str) -> bool:
    """
    Determine if a model is special based on its name.

    Args:
        model_name (str): The name of the model.

    Returns:
        bool: True if the model name indicates a special model, False otherwise.
    """
    return model_name.startswith("_") and model_name.endswith("_")

load_model(model_name_or_config, *args, **kwargs)

Load a model from the pool based on the provided configuration.

Parameters:

  • model_name_or_config (Union[str, DictConfig]) –

    The model name or configuration.

Returns:

  • nn.Module –

    The instantiated model.

Source code in fusion_bench/modelpool/base_pool.py
def load_model(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> nn.Module:
    """
    Load a model from the pool based on the provided configuration.

    Args:
        model_name_or_config (Union[str, DictConfig]): The model name or configuration.

    Returns:
        nn.Module: The instantiated model.
    """
    log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
    if isinstance(self._models, DictConfig):
        model_config = (
            self._models[model_name_or_config]
            if isinstance(model_name_or_config, str)
            else model_name_or_config
        )
        model = instantiate(model_config, *args, **kwargs)
    elif isinstance(self._models, Dict) and isinstance(model_name_or_config, str):
        model = self._models[model_name_or_config]
    else:
        raise ValueError(
            "The model pool configuration is not in the expected format."
            f"We expected a DictConfig or Dict, but got {type(self._models)}."
        )
    return model

load_pretrained_or_first_model(*args, **kwargs)

Load the pretrained model if available, otherwise load the first available model.

Returns:

  • nn.Module –

    The loaded model.

Source code in fusion_bench/modelpool/base_pool.py
def load_pretrained_or_first_model(self, *args, **kwargs):
    """
    Load the pretrained model if available, otherwise load the first available model.

    Returns:
        nn.Module: The loaded model.
    """
    if self.has_pretrained:
        model = self.load_model("_pretrained_", *args, **kwargs)
    else:
        model = self.load_model(self.model_names[0], *args, **kwargs)
    return model

load_test_dataset(dataset_name, *args, **kwargs)

Load the testing dataset with the specified name.

Parameters:

  • dataset_name (str) –

    The name of the dataset.

Returns:

  • Dataset –

    The instantiated testing dataset.

Source code in fusion_bench/modelpool/base_pool.py
def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
    """
    Load the testing dataset with the specified name.

    Args:
        dataset_name (str): The name of the dataset.

    Returns:
        Dataset: The instantiated testing dataset.
    """
    return instantiate(self._test_datasets[dataset_name], *args, **kwargs)

load_train_dataset(dataset_name, *args, **kwargs)

Load the training dataset with the specified name.

Parameters:

  • dataset_name (str) –

    The name of the dataset.

Returns:

  • Dataset –

    The instantiated training dataset.

Source code in fusion_bench/modelpool/base_pool.py
def load_train_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
    """
    Load the training dataset with the specified name.

    Args:
        dataset_name (str): The name of the dataset.

    Returns:
        Dataset: The instantiated training dataset.
    """
    return instantiate(self._train_datasets[dataset_name], *args, **kwargs)

load_val_dataset(dataset_name, *args, **kwargs)

Load the validation dataset with the specified name.

Parameters:

  • dataset_name (str) –

    The name of the dataset.

Returns:

  • Dataset –

    The instantiated validation dataset.

Source code in fusion_bench/modelpool/base_pool.py
def load_val_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
    """
    Load the validation dataset with the specified name.

    Args:
        dataset_name (str): The name of the dataset.

    Returns:
        Dataset: The instantiated validation dataset.
    """
    return instantiate(self._val_datasets[dataset_name], *args, **kwargs)

save_model(model, path)

Save the state dictionary of the model to the specified path.

Parameters:

  • model (Module) –

    The model whose state dictionary is to be saved.

  • path (str) –

    The path where the state dictionary will be saved.

Source code in fusion_bench/modelpool/base_pool.py
def save_model(self, model: nn.Module, path: str):
    """
    Save the state dictionary of the model to the specified path.

    Args:
        model (nn.Module): The model whose state dictionary is to be saved.
        path (str): The path where the state dictionary will be saved.
    """
    with timeit_context(f"Saving the state dict of model to {path}"):
        torch.save(model.state_dict(), path)
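
Putting the pieces together, here is a hedged sketch of config-driven usage: each model entry is a Hydra-style config whose `_target_` can be any callable returning an `nn.Module`, so `load_model` instantiates entries on demand (the targets and paths below are illustrative):

```python
from omegaconf import DictConfig

from fusion_bench.modelpool import BaseModelPool

pool = BaseModelPool(
    models=DictConfig(
        {
            "_pretrained_": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 2},
            "model_a": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 2},
        }
    )
)
base = pool.load_pretrained_or_first_model()  # instantiates the `_pretrained_` entry
for name, model in pool.named_models():       # iterates regular models only
    pool.save_model(model, f"/tmp/{name}.pt")  # saves each state dict
```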

Vision Model Pool

NYUv2 Tasks (ResNet)

NYUv2ModelPool

Bases: ModelPool

Source code in fusion_bench/modelpool/nyuv2_modelpool.py
class NYUv2ModelPool(ModelPool):
    def load_model(
        self, model_config: str | DictConfig, encoder_only: bool = True
    ) -> ResnetDilated | NYUv2Model:
        """
        Load an NYUv2 multi-task model: a dilated ResNet encoder with one
        DeepLabHead decoder per task. If `encoder_only` is True, only the
        encoder is returned.
        """
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)

        encoder = resnet_dilated(model_config.encoder)
        decoders = nn.ModuleDict(
            {
                task: DeepLabHead(2048, NYUv2.num_out_channels[task])
                for task in model_config.decoders
            }
        )
        model = NYUv2Model(encoder=encoder, decoders=decoders)
        if model_config.get("ckpt_path", None) is not None:
            ckpt = torch.load(model_config.ckpt_path, map_location="cpu")
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            model.load_state_dict(ckpt, strict=False)

        if encoder_only:
            return model.encoder
        else:
            return model
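
Reading off `load_model`, each model entry is expected to provide an `encoder` spec (passed to `resnet_dilated`), the list of decoder task names, and optionally a `ckpt_path`. A hypothetical entry might look like this (the encoder string, task names, and path are illustrative):

```yaml
models:
  mtl_model:
    encoder: resnet50                        # passed to resnet_dilated(...)
    decoders: [segmentation, depth, normal]  # one DeepLabHead per task
    ckpt_path: path_to_checkpoint.ckpt       # optional; loaded with strict=False
```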

CLIP Vision Encoder

CLIPVisionModelPool

Bases: BaseModelPool

A model pool for managing Hugging Face's CLIP Vision models.

This class extends `BaseModelPool` and overrides its methods to handle the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.

Source code in fusion_bench/modelpool/clip_vision/modelpool.py
class CLIPVisionModelPool(BaseModelPool):
    """
    A model pool for managing Hugging Face's CLIP Vision models.

    This class extends `BaseModelPool` and overrides its methods to handle
    the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
    """

    _config_mapping = BaseModelPool._config_mapping | {
        "_processor": "processor",
        "_platform": "hf",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        processor: Optional[DictConfig] = None,
        platform: Literal["hf", "huggingface", "modelscope"] = "hf",
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._processor = processor
        self._platform = platform

    def load_processor(self, *args, **kwargs) -> CLIPProcessor:
        assert self._processor is not None, "Processor is not defined in the config"
        if isinstance(self._processor, str):
            if rank_zero_only.rank == 0:
                log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
            repo_path = resolve_repo_path(
                repo_id=self._processor, repo_type="model", platform=self._platform
            )
            processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
        else:
            processor = instantiate(self._processor, *args, **kwargs)
        return processor

    def load_clip_model(self, model_name: str, *args, **kwargs) -> CLIPModel:
        model_config = self._models[model_name]

        if isinstance(model_config, str):
            if rank_zero_only.rank == 0:
                log.info(f"Loading `transformers.CLIPModel`: {model_config}")
            repo_path = resolve_repo_path(
                repo_id=model_config, repo_type="model", platform=self._platform
            )
            clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
            return clip_model
        else:
            assert isinstance(
                model_config, DictConfig
            ), "Model config must be a DictConfig"
            model_config = deepcopy(model_config)
            with open_dict(model_config):
                model_config._target_ = "transformers.CLIPModel.from_pretrained"
            clip_model = instantiate(model_config, *args, **kwargs)
            return clip_model

    @override
    def save_model(self, model: CLIPVisionModel, path: str):
        """
        Save a CLIP Vision model to the given path.

        Args:
            model (CLIPVisionModel): The model to save.
            path (str): The path to save the model to.
        """
        with timeit_context(f'Saving clip vision model to "{path}"'):
            model.save_pretrained(path)

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> CLIPVisionModel:
        """
        This method is used to load a CLIPVisionModel from the model pool.

        Example configuration could be:

        ```yaml
        models:
            cifar10: tanganke/clip-vit-base-patch32_cifar10
            sun397: tanganke/clip-vit-base-patch32_sun397
            stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
        ```

        Args:
            model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.

        Returns:
            CLIPVisionModel: The loaded CLIPVisionModel.
        """
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model = self._models[model_name_or_config]
            if isinstance(model, str):
                if rank_zero_only.rank == 0:
                    log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
                repo_path = resolve_repo_path(
                    model, repo_type="model", platform=self._platform
                )
                return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
            if isinstance(model, nn.Module):
                if rank_zero_only.rank == 0:
                    log.info(f"Returning existing model: {model}")
                return model
            # fall back to the default loader for other entry types (e.g. DictConfig)
            return super().load_model(model_name_or_config, *args, **kwargs)
        else:
            # if the name is not in the pool, use the default load_model method
            return super().load_model(model_name_or_config, *args, **kwargs)

    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._train_datasets[dataset_name]
        if isinstance(dataset_config, str):
            if rank_zero_only.rank == 0:
                log.info(
                    f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
                )
            dataset = self._load_dataset(dataset_config, split="train")
        else:
            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._val_datasets[dataset_name]
        if isinstance(dataset_config, str):
            if rank_zero_only.rank == 0:
                log.info(
                    f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
                )
            dataset = self._load_dataset(dataset_config, split="validation")
        else:
            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._test_datasets[dataset_name]
        if isinstance(dataset_config, str):
            if rank_zero_only.rank == 0:
                log.info(
                    f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
                )
            dataset = self._load_dataset(dataset_config, split="test")
        else:
            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
        return dataset

    def _load_dataset(self, name: str, split: str):
        """
        Load a dataset by its name and split.

        Args:
            name (str): The name or repository ID of the dataset.
            split (str): The split of the dataset to load (e.g., "train", "validation", "test").

        Returns:
            Dataset: The loaded dataset.
        """
        dataset_dir = resolve_repo_path(
            name, repo_type="dataset", platform=self._platform
        )
        dataset = load_dataset(dataset_dir, split=split)
        return dataset
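
As a usage sketch (assuming `CLIPVisionModelPool` is importable from `fusion_bench.modelpool`; the repository IDs are illustrative, taken from the example config below), model entries can be plain Hugging Face repo IDs:

```python
from omegaconf import DictConfig

from fusion_bench.modelpool import CLIPVisionModelPool

pool = CLIPVisionModelPool(
    models=DictConfig(
        {
            "_pretrained_": "openai/clip-vit-base-patch32",
            "cifar10": "tanganke/clip-vit-base-patch32_cifar10",
        }
    ),
    processor="openai/clip-vit-base-patch32",
)
processor = pool.load_processor()          # a transformers.CLIPProcessor
vision_model = pool.load_model("cifar10")  # a transformers.CLIPVisionModel
```
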
load_model(model_name_or_config, *args, **kwargs)

This method is used to load a CLIPVisionModel from the model pool.

Example configuration could be:

models:
    cifar10: tanganke/clip-vit-base-patch32_cifar10
    sun397: tanganke/clip-vit-base-patch32_sun397
    stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars

Parameters:

  • model_name_or_config (Union[str, DictConfig]) –

    The name of the model or the model configuration.

Returns:

  • CLIPVisionModel –

    The loaded CLIPVisionModel.

Source code in fusion_bench/modelpool/clip_vision/modelpool.py
def load_model(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> CLIPVisionModel:
    """
    This method is used to load a CLIPVisionModel from the model pool.

    Example configuration could be:

    ```yaml
    models:
        cifar10: tanganke/clip-vit-base-patch32_cifar10
        sun397: tanganke/clip-vit-base-patch32_sun397
        stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
    ```

    Args:
        model_name_or_config (Union[str, DictConfig]): The name of the model or the model configuration.

    Returns:
        CLIPVisionModel: The loaded CLIPVisionModel.
    """
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model = self._models[model_name_or_config]
        if isinstance(model, str):
            if rank_zero_only.rank == 0:
                log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
            repo_path = resolve_repo_path(
                model, repo_type="model", platform=self._platform
            )
            return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
        if isinstance(model, nn.Module):
            if rank_zero_only.rank == 0:
                log.info(f"Returning existing model: {model}")
            return model
        # fall back to the default loader for other entry types (e.g. DictConfig)
        return super().load_model(model_name_or_config, *args, **kwargs)
    else:
        # if the name is not in the pool, use the default load_model method
        return super().load_model(model_name_or_config, *args, **kwargs)
save_model(model, path)

Save a CLIP Vision model to the given path.

Parameters:

  • model (CLIPVisionModel) –

    The model to save.

  • path (str) –

    The path to save the model to.

Source code in fusion_bench/modelpool/clip_vision/modelpool.py
@override
def save_model(self, model: CLIPVisionModel, path: str):
    """
    Save a CLIP Vision model to the given path.

    Args:
        model (CLIPVisionModel): The model to save.
        path (str): The path to save the model to.
    """
    with timeit_context(f'Saving clip vision model to "{path}"'):
        model.save_pretrained(path)

OpenCLIP Vision Encoder

OpenCLIPVisionModelPool

Bases: BaseModelPool

A model pool for managing OpenCLIP Vision models (models from the task vector paper).

Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
class OpenCLIPVisionModelPool(BaseModelPool):
    """
    A model pool for managing OpenCLIP Vision models (models from the task vector paper).
    """

    _train_processor = None
    _test_processor = None

    def __init__(
        self,
        models: DictConfig,
        classification_heads: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._classification_heads = classification_heads

    @property
    def train_processor(self):
        if self._train_processor is None:
            encoder: ImageEncoder = self.load_pretrained_or_first_model()
            self._train_processor = encoder.train_preprocess
            if self._test_processor is None:
                self._test_processor = encoder.val_preprocess
        return self._train_processor

    @property
    def test_processor(self):
        if self._test_processor is None:
            encoder: ImageEncoder = self.load_pretrained_or_first_model()
            if self._train_processor is None:
                self._train_processor = encoder.train_preprocess
            self._test_processor = encoder.val_preprocess
        return self._test_processor

    def load_model(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> ImageEncoder:
        R"""
        The model config can be:

        - A string, which is the path to the model checkpoint in pickle format. Loaded directly using `torch.load`.
        - {"model_name": str, "pickle_path": str}: load the model from a binary file (pickle format). This first constructs the model with `ImageEncoder(model_name)`, then loads the state dict from the model stored in the pickle file.
        - {"model_name": str, "state_dict_path": str}: load the model from a state dict file. This first constructs the model with `ImageEncoder(model_name)`, then loads the state dict from the file.
        - Otherwise, load the model using Hydra's `instantiate`.
        """
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._models
        ):
            model_config = self._models[model_name_or_config]
        else:
            model_config = model_name_or_config
        if isinstance(model_config, DictConfig):
            model_config = OmegaConf.to_container(model_config, resolve=True)

        if isinstance(model_config, str):
            # the model config is a string, which is the path to the model checkpoint in pickle format
            # load the model using `torch.load`
            # this is the original usage in the task arithmetic codebase
            _check_and_redirect_open_clip_modeling()
            log.info(f"loading ImageEncoder from {model_config}")
            # pop to avoid passing `weights_only` twice through **kwargs below
            weights_only = kwargs.pop("weights_only", False)
            try:
                encoder = torch.load(
                    model_config, weights_only=weights_only, *args, **kwargs
                )
            except RuntimeError as e:
                encoder = pickle.load(open(model_config, "rb"))
        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
            # the model config is a dictionary with the following keys:
            # - model_name: str, the name of the model
            # - pickle_path: str, the path to the binary file (pickle format)
            # load the model from the binary file (pickle format)
            # this is useful when you use a newer version of torchvision
            _check_and_redirect_open_clip_modeling()
            log.info(
                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
            )
            # pop to avoid passing `weights_only` twice through **kwargs below
            weights_only = kwargs.pop("weights_only", False)
            try:
                encoder = torch.load(
                    model_config["pickle_path"],
                    weights_only=weights_only,
                    *args,
                    **kwargs,
                )
            except RuntimeError as e:
                encoder = pickle.load(open(model_config["pickle_path"], "rb"))
            _encoder = ImageEncoder(model_config["model_name"])
            _encoder.load_state_dict(encoder.state_dict())
            encoder = _encoder
        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
            # the model config is a dictionary with the following keys:
            # - model_name: str, the name of the model
            # - state_dict_path: str, the path to the state dict file
            # load the model from the state dict file
            log.info(
                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
            )
            encoder = ImageEncoder(model_config["model_name"])
            encoder.load_state_dict(
                torch.load(
                    model_config["state_dict_path"], weights_only=True, *args, **kwargs
                )
            )
        elif isinstance(model_config, nn.Module):
            # the model config is an existing model
            log.info(f"Returning existing model: {model_config}")
            encoder = model_config
        else:
            encoder = super().load_model(model_name_or_config, *args, **kwargs)
        encoder = cast(ImageEncoder, encoder)

        # setup the train and test processors
        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
            self._train_processor = encoder.train_preprocess
        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
            self._test_processor = encoder.val_preprocess

        return encoder

    def load_classification_head(
        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
    ) -> ClassificationHead:
        R"""
        The model config can be:

        - A string, which is the path to the model checkpoint in pickle format. Loaded directly using `torch.load`.
        - Otherwise, load the model using Hydra's `instantiate`.
        """
        if (
            isinstance(model_name_or_config, str)
            and model_name_or_config in self._classification_heads
        ):
            model_config = self._classification_heads[model_name_or_config]
        else:
            model_config = model_name_or_config

        head = load_classifier_head(model_config, *args, **kwargs)
        return head

    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._train_datasets[dataset_name]
        if isinstance(dataset_config, str):
            log.info(
                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
            )
            dataset = load_dataset(dataset_config, split="train")
        else:
            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._val_datasets[dataset_name]
        if isinstance(dataset_config, str):
            log.info(
                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
            )
            dataset = load_dataset(dataset_config, split="validation")
        else:
            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
        return dataset

    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
        dataset_config = self._test_datasets[dataset_name]
        if isinstance(dataset_config, str):
            log.info(
                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
            )
            dataset = load_dataset(dataset_config, split="test")
        else:
            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
        return dataset
load_classification_head(model_name_or_config, *args, **kwargs)

The model config can be:

  • A string, which is the path to the model checkpoint in pickle format. Loaded directly using torch.load.
  • Otherwise, the model is loaded using Hydra's instantiate.
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
def load_classification_head(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> ClassificationHead:
    R"""
    The model config can be:

    - A string, which is the path to the model checkpoint in pickle format. Loaded directly using `torch.load`.
    - Otherwise, load the model using Hydra's `instantiate`.
    """
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._classification_heads
    ):
        model_config = self._classification_heads[model_name_or_config]
    else:
        model_config = model_name_or_config

    head = load_classifier_head(model_config, *args, **kwargs)
    return head
load_model(model_name_or_config, *args, **kwargs)

The model config can be:

  • A string, which is the path to the model checkpoint in pickle format. Loaded directly using torch.load.
  • {"model_name": str, "pickle_path": str}: load the model from a binary file (pickle format). This first constructs the model with ImageEncoder(model_name), then loads the state dict from the model stored in the pickle file.
  • {"model_name": str, "state_dict_path": str}: load the model from a state dict file. This first constructs the model with ImageEncoder(model_name), then loads the state dict from the file.
  • Otherwise, the model is loaded using Hydra's instantiate. (A config sketch for the first three forms is shown below.)
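
Concretely, the three non-default forms might be configured like this (the model names and paths are illustrative):

```yaml
models:
  # plain string: path to a pickled ImageEncoder checkpoint
  _pretrained_: checkpoints/ViT-B-32/zeroshot.pt
  # pickle file, re-materialized into a freshly constructed ImageEncoder
  svhn:
    model_name: ViT-B-32
    pickle_path: checkpoints/ViT-B-32/svhn.pt
  # plain state dict file
  mnist:
    model_name: ViT-B-32
    state_dict_path: checkpoints/ViT-B-32/mnist_state_dict.pt
```
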
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
def load_model(
    self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
) -> ImageEncoder:
    R"""
    The model config can be:

    - A string, which is the path to the model checkpoint in pickle format. Loaded directly using `torch.load`.
    - {"model_name": str, "pickle_path": str}: load the model from a binary file (pickle format). This first constructs the model with `ImageEncoder(model_name)`, then loads the state dict from the model stored in the pickle file.
    - {"model_name": str, "state_dict_path": str}: load the model from a state dict file. This first constructs the model with `ImageEncoder(model_name)`, then loads the state dict from the file.
    - Otherwise, load the model using Hydra's `instantiate`.
    """
    if (
        isinstance(model_name_or_config, str)
        and model_name_or_config in self._models
    ):
        model_config = self._models[model_name_or_config]
    else:
        model_config = model_name_or_config
    if isinstance(model_config, DictConfig):
        model_config = OmegaConf.to_container(model_config, resolve=True)

    if isinstance(model_config, str):
        # the model config is a string, which is the path to the model checkpoint in pickle format
        # load the model using `torch.load`
        # this is the original usage in the task arithmetic codebase
        _check_and_redirect_open_clip_modeling()
        log.info(f"loading ImageEncoder from {model_config}")
        # pop to avoid passing `weights_only` twice through **kwargs below
        weights_only = kwargs.pop("weights_only", False)
        try:
            encoder = torch.load(
                model_config, weights_only=weights_only, *args, **kwargs
            )
        except RuntimeError as e:
            encoder = pickle.load(open(model_config, "rb"))
    elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
        # the model config is a dictionary with the following keys:
        # - model_name: str, the name of the model
        # - pickle_path: str, the path to the binary file (pickle format)
        # load the model from the binary file (pickle format)
        # this is useful when you use a newer version of torchvision
        _check_and_redirect_open_clip_modeling()
        log.info(
            f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
        )
        # pop to avoid passing `weights_only` twice through **kwargs below
        weights_only = kwargs.pop("weights_only", False)
        try:
            encoder = torch.load(
                model_config["pickle_path"],
                weights_only=weights_only,
                *args,
                **kwargs,
            )
        except RuntimeError as e:
            encoder = pickle.load(open(model_config["pickle_path"], "rb"))
        _encoder = ImageEncoder(model_config["model_name"])
        _encoder.load_state_dict(encoder.state_dict())
        encoder = _encoder
    elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
        # the model config is a dictionary with the following keys:
        # - model_name: str, the name of the model
        # - state_dict_path: str, the path to the state dict file
        # load the model from the state dict file
        log.info(
            f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
        )
        encoder = ImageEncoder(model_config["model_name"])
        encoder.load_state_dict(
            torch.load(
                model_config["state_dict_path"], weights_only=True, *args, **kwargs
            )
        )
    elif isinstance(model_config, nn.Module):
        # the model config is an existing model
        log.info(f"Returning existing model: {model_config}")
        encoder = model_config
    else:
        encoder = super().load_model(model_name_or_config, *args, **kwargs)
    encoder = cast(ImageEncoder, encoder)

    # setup the train and test processors
    if self._train_processor is None and hasattr(encoder, "train_preprocess"):
        self._train_processor = encoder.train_preprocess
    if self._test_processor is None and hasattr(encoder, "val_preprocess"):
        self._test_processor = encoder.val_preprocess

    return encoder

NLP Model Pool

GPT-2

HuggingFaceGPT2ClassificationPool = GPT2ForSequenceClassificationPool module-attribute

GPT2ForSequenceClassificationPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/huggingface_gpt2_classification.py
class GPT2ForSequenceClassificationPool(BaseModelPool):
    _config_mapping = BaseModelPool._config_mapping | {"_tokenizer": "tokenizer"}

    def __init__(self, tokenizer: DictConfig, **kwargs):
        self._tokenizer = tokenizer
        super().__init__(**kwargs)
        self.setup()

    def setup(self):
        # expose the tokenizer at module scope in addition to the instance attribute
        global tokenizer
        self.tokenizer = tokenizer = instantiate(self._tokenizer)

    def load_classifier(
        self, model_config: str | DictConfig
    ) -> GPT2ForSequenceClassification:
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config, return_copy=True)
        model_config._target_ = (
            "transformers.GPT2ForSequenceClassification.from_pretrained"
        )
        model = instantiate(model_config)
        return model
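
Reading off `load_classifier`, which overrides `_target_` with `transformers.GPT2ForSequenceClassification.from_pretrained`, each model entry is presumably a set of `from_pretrained` arguments. A hypothetical config sketch (the task name and path are illustrative):

```yaml
tokenizer:
  _target_: transformers.GPT2Tokenizer.from_pretrained
  pretrained_model_name_or_path: gpt2
models:
  cola:
    pretrained_model_name_or_path: path_to_finetuned_gpt2
```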

Seq2Seq Language Models (Flan-T5)

Seq2SeqLMPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
class Seq2SeqLMPool(BaseModelPool):
    _config_mapping = BaseModelPool._config_mapping | {
        "_tokenizer": "tokenizer",
        "_model_kwargs": "model_kwargs",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )

    def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)
        return super().load_model(model_name_or_config, *args, **model_kwargs)

    def load_tokenizer(self, *args, **kwargs):
        """Instantiate and return the tokenizer defined in the config."""
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer
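
A hedged config sketch for this pool, mirroring the constructor's `models`, `tokenizer`, and `model_kwargs` arguments (the model ID is illustrative; `torch_dtype` is parsed by `parse_dtype`):

```yaml
models:
  _pretrained_:
    _target_: transformers.AutoModelForSeq2SeqLM.from_pretrained
    pretrained_model_name_or_path: google/flan-t5-base
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: google/flan-t5-base
model_kwargs:
  torch_dtype: bfloat16
```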

SequenceClassificationModelPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
class SequenceClassificationModelPool(BaseModelPool):

    def __init__(
        self,
        models,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        # process `model_kwargs`
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )

    @override
    def load_model(
        self,
        model_name_or_config: str | DictConfig,
        *args,
        **kwargs,
    ) -> Union[PreTrainedModel, "LlamaForSequenceClassification"]:
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)
        if isinstance(model_name_or_config, str):
            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
        return super().load_model(model_name_or_config, *args, **model_kwargs)

    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        log.info("Loading tokenizer.", stacklevel=2)
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer

    @override
    def save_model(
        self,
        model: PreTrainedModel,
        path: str,
        push_to_hub: bool = False,
        model_dtype: Optional[str] = None,
        save_tokenizer: bool = False,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """
        Save the model to the specified path.

        Args:
            model (PreTrainedModel): The model to be saved.
            path (str): The path where the model will be saved.
            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
            model_dtype (str, optional): If given, cast the model to this dtype before saving. Defaults to None.
            tokenizer_kwargs (dict, optional): Keyword arguments passed to `load_tokenizer` when `save_tokenizer` is True. Defaults to None.
            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
        """
        path = os.path.expanduser(path)
        if save_tokenizer:
            if tokenizer_kwargs is None:
                tokenizer_kwargs = {}
            # load the tokenizer
            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
            tokenizer.save_pretrained(
                path,
                push_to_hub=push_to_hub,
            )
        if model_dtype is not None:
            model.to(dtype=parse_dtype(model_dtype))
        model.save_pretrained(
            path,
            push_to_hub=push_to_hub,
            **kwargs,
        )
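
A similar hedged config sketch applies here, with sequence-classification targets (the base model path and label count are illustrative):

```yaml
models:
  _pretrained_:
    _target_: transformers.AutoModelForSequenceClassification.from_pretrained
    pretrained_model_name_or_path: path_to_base_model
    num_labels: 2
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: path_to_base_model
model_kwargs:
  torch_dtype: bfloat16
```
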
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, **kwargs)

Save the model to the specified path.

Parameters:

  • model (PreTrainedModel) –

    The model to be saved.

  • path (str) –

    The path where the model will be saved.

  • push_to_hub (bool, default: False ) –

    Whether to push the model to the Hugging Face Hub. Defaults to False.

  • save_tokenizer (bool, default: False ) –

    Whether to save the tokenizer along with the model. Defaults to False.

  • model_dtype (str, default: None ) –

    If given, cast the model to this dtype before saving.

  • tokenizer_kwargs (default: None ) –

    Keyword arguments passed to load_tokenizer when save_tokenizer is True.

  • **kwargs –

    Additional keyword arguments passed to the save_pretrained method.

Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
@override
def save_model(
    self,
    model: PreTrainedModel,
    path: str,
    push_to_hub: bool = False,
    model_dtype: Optional[str] = None,
    save_tokenizer: bool = False,
    tokenizer_kwargs=None,
    **kwargs,
):
    """
    Save the model to the specified path.

    Args:
        model (PreTrainedModel): The model to be saved.
        path (str): The path where the model will be saved.
        push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
        save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
        model_dtype (str, optional): If given, cast the model to this dtype before saving. Defaults to None.
        tokenizer_kwargs (dict, optional): Keyword arguments passed to `load_tokenizer` when `save_tokenizer` is True. Defaults to None.
        **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
    """
    path = os.path.expanduser(path)
    if save_tokenizer:
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        # load the tokenizer
        tokenizer = self.load_tokenizer(**tokenizer_kwargs)
        tokenizer.save_pretrained(
            path,
            push_to_hub=push_to_hub,
        )
    if model_dtype is not None:
        model.to(dtype=parse_dtype(model_dtype))
    model.save_pretrained(
        path,
        push_to_hub=push_to_hub,
        **kwargs,
    )

PeftModelForSeq2SeqLMPool

Bases: ModelPool

Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
class PeftModelForSeq2SeqLMPool(ModelPool):
    def load_model(self, model_config: str | DictConfig):
        """
        Load a model based on the provided configuration.

        The configuration options of `model_config` are:

        - name: The name of the model. If it is "_pretrained_", a pretrained Seq2Seq language model is returned.
        - path: The path where the model is stored.
        - is_trainable: A boolean indicating whether the model parameters should be trainable. Default is `True`.
        - merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is `True`.


        Args:
            model_config (str | DictConfig): The configuration for the model. This can be either a string (name of the model) or a DictConfig object containing the model configuration.


        Returns:
            model: The loaded model. If the model name is "_pretrained_", it returns a pretrained Seq2Seq language model. Otherwise, it returns a PEFT model.
        """
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)
        with timeit_context(f"Loading model {model_config['name']}"):
            if model_config["name"] == "_pretrained_":
                model = AutoModelForSeq2SeqLM.from_pretrained(model_config["path"])
                return model
            else:
                model = self.load_model("_pretrained_")
                peft_model = PeftModel.from_pretrained(
                    model,
                    model_config["path"],
                    is_trainable=model_config.get("is_trainable", True),
                )
                if model_config.get("merge_and_unload", True):
                    return peft_model.merge_and_unload()
                else:
                    return peft_model
load_model(model_config)

Load a model based on the provided configuration.

The configuration options of model_config are:

  • name: The name of the model. If it is "_pretrained_", a pretrained Seq2Seq language model is returned.
  • path: The path where the model is stored.
  • is_trainable: A boolean indicating whether the model parameters should be trainable. Default is True.
  • merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is True. (A hypothetical config sketch follows this list.)
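
A hypothetical pool configuration using these options (the adapter name and paths are illustrative):

```yaml
models:
  _pretrained_:
    name: _pretrained_
    path: google/flan-t5-base
  glue-cola:
    name: glue-cola
    path: path_to_lora_adapter   # a PEFT adapter directory
    is_trainable: false
    merge_and_unload: true
```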

Parameters:

  • model_config (str | DictConfig) –

    The configuration for the model. This can be either a string (name of the model) or a DictConfig object containing the model configuration.

Returns:

  • model –

    The loaded model. If the model name is "_pretrained_", it returns a pretrained Seq2Seq language model. Otherwise, it returns a PEFT model.

Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
def load_model(self, model_config: str | DictConfig):
    """
    Load a model based on the provided configuration.

    The configuration options of `model_config` are:

    - name: The name of the model. If it is "_pretrained_", a pretrained Seq2Seq language model is returned.
    - path: The path where the model is stored.
    - is_trainable: A boolean indicating whether the model parameters should be trainable. Default is `True`.
    - merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is `True`.


    Args:
        model_config (str | DictConfig): The configuration for the model. This can be either a string (name of the model) or a DictConfig object containing the model configuration.


    Returns:
        model: The loaded model. If the model name is "_pretrained_", it returns a pretrained Seq2Seq language model. Otherwise, it returns a PEFT model.
    """
    if isinstance(model_config, str):
        model_config = self.get_model_config(model_config)
    with timeit_context(f"Loading model {model_config['name']}"):
        if model_config["name"] == "_pretrained_":
            model = AutoModelForSeq2SeqLM.from_pretrained(model_config["path"])
            return model
        else:
            model = self.load_model("_pretrained_")
            peft_model = PeftModel.from_pretrained(
                model,
                model_config["path"],
                is_trainable=model_config.get("is_trainable", True),
            )
            if model_config.get("merge_and_unload", True):
                return peft_model.merge_and_unload()
            else:
                return peft_model

Causal Language Models (Llama, Mistral, Qwen...)

CausalLMPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
class CausalLMPool(BaseModelPool):
    _config_mapping = BaseModelPool._config_mapping | {
        "_tokenizer": "tokenizer",
        "_model_kwargs": "model_kwargs",
        "load_lazy": "load_lazy",
    }

    def __init__(
        self,
        models,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        load_lazy: bool = False,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        # process `model_kwargs`
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )
        self.load_lazy = load_lazy

    @override
    def load_model(
        self,
        model_name_or_config: str | DictConfig,
        *args,
        **kwargs,
    ) -> PreTrainedModel:
        """
        Example of YAML config:

        ```yaml
        models:
          _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
          model_a: path_to_model_a
          model_b: path_to_model_b
        ```

        or equivalently,

        ```yaml
        models:
          _pretrained_:
            _target_: transformers.AutoModelForCausalLM # any callable that returns a model
            pretrained_model_name_or_path: path_to_pretrained_model
          model_a:
            _target_: transformers.AutoModelForCausalLM
            pretrained_model_name_or_path: path_to_model_a
          model_b:
            _target_: transformers.AutoModelForCausalLM
            pretrained_model_name_or_path: path_to_model_b
        ```
        """
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)

        if isinstance(model_name_or_config, str):
            # If model_name_or_config is a string, it is the name or the path of the model
            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
            if model_name_or_config in self._models.keys():
                model_config = self._models[model_name_or_config]
                if isinstance(model_config, str):
                    # model_config is a string
                    if not self.load_lazy:
                        model = AutoModelForCausalLM.from_pretrained(
                            model_config,
                            *args,
                            **model_kwargs,
                        )
                    else:
                        # model_config is a string, but we want to use LazyStateDict
                        model = LazyStateDict(
                            checkpoint=model_config,
                            meta_module_class=AutoModelForCausalLM,
                            *args,
                            **model_kwargs,
                        )
                    return model
            else:
                raise ValueError(
                    f"Model `{model_name_or_config}` is not found in the model pool."
                )
        elif isinstance(model_name_or_config, (DictConfig, Dict)):
            model_config = model_name_or_config

        if not self.load_lazy:
            model = instantiate(model_config, *args, **model_kwargs)
        else:
            meta_module_class = model_config.pop("_target_")
            checkpoint = model_config.pop("pretrained_model_name_or_path")
            model = LazyStateDict(
                checkpoint=checkpoint,
                meta_module_class=meta_module_class,
                *args,
                **model_kwargs,
            )
        return model

    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
        """
        Example of YAML config:

        ```yaml
        tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained
        ```

        or equivalently,

        ```yaml
        tokenizer:
          _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
          pretrained_model_name_or_path: google/gemma-2-2b-it
        ```

        Returns:
            PreTrainedTokenizer: The tokenizer.
        """
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        log.info("Loading tokenizer.", stacklevel=2)
        if isinstance(self._tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
        else:
            tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer

    @override
    def save_model(
        self,
        model: PreTrainedModel,
        path: str,
        push_to_hub: bool = False,
        model_dtype: Optional[str] = None,
        save_tokenizer: bool = False,
        tokenizer_kwargs=None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        **kwargs,
    ):
        """
        Save the model to the specified path.

        Args:
            model (PreTrainedModel): The model to be saved.
            path (str): The path where the model will be saved.
            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
            model_dtype (Optional[str], optional): If given, cast the model to this dtype (via `parse_dtype`) before saving. Defaults to None.
            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
            tokenizer_kwargs (dict, optional): Keyword arguments passed to `load_tokenizer` when the tokenizer is loaded for saving. Defaults to None.
            tokenizer (PreTrainedTokenizer, optional): An already-loaded tokenizer to save alongside the model; if provided, it is saved regardless of `save_tokenizer`. Defaults to None.
            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
        """
        path = os.path.expanduser(path)
        # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
        if save_tokenizer or tokenizer is not None:
            if tokenizer is None:
                if tokenizer_kwargs is None:
                    tokenizer_kwargs = {}
                # load the tokenizer
                tokenizer = self.load_tokenizer(**tokenizer_kwargs)
            tokenizer.save_pretrained(
                path,
                push_to_hub=push_to_hub,
            )
        if model_dtype is not None:
            model.to(dtype=parse_dtype(model_dtype))
        model.save_pretrained(
            path,
            push_to_hub=push_to_hub,
            **kwargs,
        )
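A minimal programmatic sketch of using the class above (checkpoint paths are hypothetical):

```python
from omegaconf import DictConfig

from fusion_bench.modelpool import CausalLMPool  # assumed import path

pool = CausalLMPool(
    models=DictConfig(
        {
            "_pretrained_": "path_to_pretrained_model",
            "model_a": "path_to_model_a",
        }
    ),
    tokenizer="path_to_pretrained_model",
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
)

# Plain-string entries go through AutoModelForCausalLM.from_pretrained
# (or LazyStateDict when load_lazy=True).
model = pool.load_model("model_a")
tokenizer = pool.load_tokenizer()
```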
load_model(model_name_or_config, *args, **kwargs)

Example of YAML config:

models:
  _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
  model_a: path_to_model_a
  model_b: path_to_model_b

or equivalently,

models:
  _pretrained_:
    _target_: transformers.AutoModelForCausalLM # any callable that returns a model
    pretrained_model_name_or_path: path_to_pretrained_model
  model_a:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_model_a
  model_b:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_model_b
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@override
def load_model(
    self,
    model_name_or_config: str | DictConfig,
    *args,
    **kwargs,
) -> PreTrainedModel:
    """
    Example of YAML config:

    ```yaml
    models:
      _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
      model_a: path_to_model_a
      model_b: path_to_model_b
    ```

    or equivalently,

    ```yaml
    models:
      _pretrained_:
        _target_: transformers.AutoModelForCausalLM # any callable that returns a model
        pretrained_model_name_or_path: path_to_pretrained_model
      model_a:
        _target_: transformers.AutoModelForCausalLM
        pretrained_model_name_or_path: path_to_model_a
      model_b:
        _target_: transformers.AutoModelForCausalLM
        pretrained_model_name_or_path: path_to_model_b
    ```
    """
    model_kwargs = deepcopy(self._model_kwargs)
    model_kwargs.update(kwargs)

    if isinstance(model_name_or_config, str):
        # If model_name_or_config is a string, it is the name or the path of the model
        log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
        if model_name_or_config in self._models.keys():
            model_config = self._models[model_name_or_config]
            if isinstance(model_config, str):
                # model_config is a string
                if not self.load_lazy:
                    model = AutoModelForCausalLM.from_pretrained(
                        model_config,
                        *args,
                        **model_kwargs,
                    )
                else:
                    # model_config is a string, but we want to use LazyStateDict
                    model = LazyStateDict(
                        checkpoint=model_config,
                        meta_module_class=AutoModelForCausalLM,
                        *args,
                        **model_kwargs,
                    )
                return model
    elif isinstance(model_name_or_config, (DictConfig, Dict)):
        model_config = model_name_or_config

    if not self.load_lazy:
        model = instantiate(model_config, *args, **model_kwargs)
    else:
        meta_module_class = model_config.pop("_target_")
        checkpoint = model_config.pop("pretrained_model_name_or_path")
        model = LazyStateDict(
            checkpoint=checkpoint,
            meta_module_class=meta_module_class,
            *args,
            **model_kwargs,
        )
    return model
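Beyond pool keys, `load_model` also accepts an ad-hoc config dict, mirroring the second YAML form above (the checkpoint path is illustrative):

```python
from omegaconf import DictConfig

model = pool.load_model(
    DictConfig(
        {
            "_target_": "transformers.AutoModelForCausalLM",
            "pretrained_model_name_or_path": "path_to_some_model",
        }
    )
)
```

With `load_lazy=True`, the same config is consumed by popping `_target_` and `pretrained_model_name_or_path` and handing them to `LazyStateDict` instead of instantiating the model eagerly.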
load_tokenizer(*args, **kwargs)

Example of YAML config:

tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained

or equivalently,

tokenizer:
  _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
  pretrained_model_name_or_path: google/gemma-2-2b-it

Returns:

  • PreTrainedTokenizer –

    The tokenizer.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
    """
    Example of YAML config:

    ```yaml
    tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained
    ```

    or equivalently,

    ```yaml
    tokenizer:
      _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
      pretrained_model_name_or_path: google/gemma-2-2b-it
    ```

    Returns:
        PreTrainedTokenizer: The tokenizer.
    """
    assert self._tokenizer is not None, "Tokenizer is not defined in the config"
    log.info("Loading tokenizer.", stacklevel=2)
    if isinstance(self._tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
    else:
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
    return tokenizer
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, tokenizer=None, **kwargs)

Save the model to the specified path.

Parameters:

  • model (PreTrainedModel) –

    The model to be saved.

  • path (str) –

    The path where the model will be saved.

  • push_to_hub (bool, default: False ) –

    Whether to push the model to the Hugging Face Hub. Defaults to False.

  • model_dtype (Optional[str], default: None ) –

    If given, the model is cast to this dtype (via parse_dtype) before saving. Defaults to None.

  • save_tokenizer (bool, default: False ) –

    Whether to save the tokenizer along with the model. Defaults to False.

  • tokenizer_kwargs (default: None ) –

    Keyword arguments passed to load_tokenizer when the tokenizer is loaded for saving.

  • tokenizer (Optional[PreTrainedTokenizer], default: None ) –

    An already-loaded tokenizer to save alongside the model; if provided, it is saved regardless of save_tokenizer.

  • **kwargs

    Additional keyword arguments passed to the save_pretrained method.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@override
def save_model(
    self,
    model: PreTrainedModel,
    path: str,
    push_to_hub: bool = False,
    model_dtype: Optional[str] = None,
    save_tokenizer: bool = False,
    tokenizer_kwargs=None,
    tokenizer: Optional[PreTrainedTokenizer] = None,
    **kwargs,
):
    """
    Save the model to the specified path.

    Args:
        model (PreTrainedModel): The model to be saved.
        path (str): The path where the model will be saved.
        push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
        model_dtype (Optional[str], optional): If given, cast the model to this dtype (via `parse_dtype`) before saving. Defaults to None.
        save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
        tokenizer_kwargs (dict, optional): Keyword arguments passed to `load_tokenizer` when the tokenizer is loaded for saving. Defaults to None.
        tokenizer (PreTrainedTokenizer, optional): An already-loaded tokenizer to save alongside the model; if provided, it is saved regardless of `save_tokenizer`. Defaults to None.
        **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
    """
    path = os.path.expanduser(path)
    # NOTE: if tokenizer is provided, it will be saved regardless of `save_tokenizer`
    if save_tokenizer or tokenizer is not None:
        if tokenizer is None:
            if tokenizer_kwargs is None:
                tokenizer_kwargs = {}
            # load the tokenizer
            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
        tokenizer.save_pretrained(
            path,
            push_to_hub=push_to_hub,
        )
    if model_dtype is not None:
        model.to(dtype=parse_dtype(model_dtype))
    model.save_pretrained(
        path,
        push_to_hub=push_to_hub,
        **kwargs,
    )
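For example, a merged model could be saved together with its tokenizer in one call (the output path is hypothetical):

```python
# Casts to bfloat16 before writing; since no explicit tokenizer is passed,
# save_tokenizer=True triggers pool.load_tokenizer() first.
pool.save_model(
    model,
    "~/outputs/merged_model",  # expanded via os.path.expanduser
    model_dtype="bfloat16",
    save_tokenizer=True,
)
```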

CausalLMBackbonePool

Bases: CausalLMPool

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
class CausalLMBackbonePool(CausalLMPool):
    def load_model(
        self, model_name_or_config: str | DictConfig, *args, **kwargs
    ) -> Module:
        if self.load_lazy:
            log.warning(
                "CausalLMBackbonePool does not support lazy loading. "
                "Falling back to normal loading."
            )
            self.load_lazy = False
        model: AutoModelForCausalLM = super().load_model(
            model_name_or_config, *args, **kwargs
        )
        return model.model.layers
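`CausalLMBackbonePool` returns only the stack of decoder layers (`model.model.layers`) instead of the full `*ForCausalLM` wrapper, which suits fusion methods that operate on the transformer backbone layer by layer. A brief sketch (pool construction omitted):

```python
layers = backbone_pool.load_model("model_a")  # an nn.ModuleList of decoder layers
for idx, layer in enumerate(layers):
    print(idx, type(layer).__name__)
```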

Others

Transformers AutoModel

AutoModelPool

Bases: ModelPool

Source code in fusion_bench/modelpool/huggingface_automodel.py
class AutoModelPool(ModelPool):
    def load_model(self, model_config: str | DictConfig) -> Module:
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)

        model = AutoModel.from_pretrained(model_config.path)
        return model
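Since `load_model` reads `model_config.path`, each pool entry is presumably a mapping with a `path` key; a sketch with an illustrative checkpoint:

```yaml
models:
  model_a:
    path: bert-base-uncased  # anything accepted by AutoModel.from_pretrained
```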