Introduction to Model Pool Module

A modelpool is a collection of models used in model fusion. Simple fusion techniques such as averaging only use models that share the same architecture. More complex methods, such as AdaMerging [1], additionally pair each model with its own set of unlabeled test data, which is used during the test-time adaptation phase.

YAML Configuration

A modelpool is specified by a YAML configuration file, which typically contains the following fields:

  • type: The modelpool type, which determines which model pool class is instantiated (see load_modelpool_from_config below).
  • models: A list of models; each model is a dict with the following fields:
    • name: The name of the model. Some names are reserved for special purposes, such as _pretrained_ for the pretrained model; names that start or end with an underscore are excluded from model_names.
    • path: The path to the model file.
    • type: The type of the model. If this field is omitted, the type is inferred from model_type.

For more complex model fusion techniques that require data, the modelpool configuration file may also contain the following fields:

  • dataset_type: The type of the dataset used for training the models in the modelpool.
  • datasets: A list of datasets; each dataset is a dict with the following fields:
    • name: The name of the dataset, which is used to pair the dataset with its model, so it must match the name of the corresponding model.
    • path: The path to the dataset file.
    • type: The type of the dataset. If this field is omitted, the type is inferred from dataset_type.
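To make the field layout concrete, the sketch below writes a configuration with these fields as plain Python dicts (in practice the YAML file would be loaded into a DictConfig). The pool type, dataset type, model names, and paths are hypothetical placeholders, and pair_models_with_datasets is a hypothetical helper that only illustrates the rule that datasets are paired with models by name.

```python
from typing import Any, Dict

# Hypothetical modelpool configuration mirroring the documented fields.
# All type names and paths are placeholders, not real fusion_bench values.
modelpool_config: Dict[str, Any] = {
    "type": "my_modelpool_type",
    "models": [
        {"name": "_pretrained_", "path": "path/to/pretrained"},
        {"name": "task_a", "path": "path/to/task_a"},
        {"name": "task_b", "path": "path/to/task_b"},
    ],
    "dataset_type": "my_dataset_type",
    "datasets": [
        {"name": "task_a", "path": "path/to/task_a_data"},
        {"name": "task_b", "path": "path/to/task_b_data"},
    ],
}


def pair_models_with_datasets(config: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Pair each dataset entry with the model entry of the same name (hypothetical helper)."""
    models = {m["name"]: m for m in config["models"]}
    pairs: Dict[str, Dict[str, Any]] = {}
    for dataset in config.get("datasets", []):
        name = dataset["name"]
        if name not in models:
            raise ValueError(f"Dataset {name} has no model with a matching name")
        pairs[name] = {"model": models[name], "dataset": dataset}
    return pairs


pairs = pair_models_with_datasets(modelpool_config)
print(sorted(pairs))  # ['task_a', 'task_b']
```

The reserved _pretrained_ entry has no dataset here, which is consistent with datasets only being needed for the task-specific models.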

We provide a list of modelpools that contain models trained on different datasets and with different architectures. Each modelpool is described in a separate document.

Basic Usage

Models are not loaded when a modelpool is initialized. To load a model from a modelpool, call the load_model method:

model = modelpool.load_model('model_name')
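The lazy-loading behaviour can be illustrated with a toy pool that keeps only configurations at construction time and builds a model when load_model is called. ToyModelPool and its load_count counter are hypothetical; a real pool subclasses ModelPool and returns an nn.Module, whereas this toy returns a plain dict so the sketch runs without PyTorch.

```python
from typing import Any, Dict, List


class ToyModelPool:
    """Hypothetical stand-in for a modelpool: stores configs, builds models on demand."""

    def __init__(self, model_configs: List[Dict[str, Any]]):
        # Only the configurations are kept at construction time; nothing is loaded yet.
        self._configs = {cfg["name"]: cfg for cfg in model_configs}
        self.load_count = 0

    def load_model(self, model_name: str) -> Dict[str, Any]:
        # A real pool would read weights from cfg["path"]; here we fake a "model".
        cfg = self._configs[model_name]
        self.load_count += 1
        return {"name": model_name, "loaded_from": cfg["path"]}


modelpool = ToyModelPool([
    {"name": "task_a", "path": "path/to/task_a"},
    {"name": "task_b", "path": "path/to/task_b"},
])
print(modelpool.load_count)  # 0: nothing is loaded at initialization
model = modelpool.load_model("task_a")
print(modelpool.load_count)  # 1
```

In this sketch every call to load_model builds a fresh object; whether a concrete pool caches loaded models depends on its own implementation.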

References

load_modelpool_from_config(modelpool_config)

Loads a model pool based on the provided configuration.

The function checks the 'type' attribute of the configuration and returns an instance of the corresponding model pool. If the 'type' attribute is not found or does not match any known model pool types, a ValueError is raised.

Parameters:

  • modelpool_config (DictConfig) –

    The configuration for the model pool. Must contain a 'type' attribute that specifies the type of the model pool.

Returns:

  • An instance of the specified model pool.

Raises:

  • ValueError

    If the 'type' attribute is not found in the configuration or does not match any known model pool types.

Source code in fusion_bench/modelpool/__init__.py
def load_modelpool_from_config(modelpool_config: DictConfig):
    """
    Loads a model pool based on the provided configuration.

    The function checks the 'type' attribute of the configuration and returns an instance of the corresponding model pool.
    If the 'type' attribute is not found or does not match any known model pool types, a ValueError is raised.

    Args:
        modelpool_config (DictConfig): The configuration for the model pool. Must contain a 'type' attribute that specifies the type of the model pool.

    Returns:
        An instance of the specified model pool.

    Raises:
        ValueError: If 'type' attribute is not found in the configuration or does not match any known model pool types.
    """
    return ModelPoolFactory.create_modelpool(modelpool_config)
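The internals of ModelPoolFactory are not shown on this page. The sketch below shows one common way a type-based factory of this kind can be written, assuming a registry dict keyed by the 'type' string; MODELPOOL_REGISTRY, DummyPool, and create_modelpool are hypothetical names, not the library's actual implementation. It reproduces the two documented ValueError cases.

```python
from typing import Any, Callable, Dict


class DummyPool:
    """Hypothetical model pool class used only to populate the registry."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config


# Hypothetical registry mapping the config's 'type' string to a pool class.
MODELPOOL_REGISTRY: Dict[str, Callable[[Dict[str, Any]], Any]] = {
    "dummy": DummyPool,
}


def create_modelpool(config: Dict[str, Any]) -> Any:
    """Instantiate the pool named by config['type'], raising ValueError as documented."""
    pool_type = config.get("type")
    if pool_type is None:
        raise ValueError("Model pool configuration has no 'type' attribute")
    if pool_type not in MODELPOOL_REGISTRY:
        raise ValueError(f"Unknown model pool type: {pool_type}")
    return MODELPOOL_REGISTRY[pool_type](config)


pool = create_modelpool({"type": "dummy", "models": []})
print(type(pool).__name__)  # DummyPool
```

A registry keeps the dispatch in one place, so supporting a new pool type only requires adding one entry.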

ModelPool

Bases: ABC

This is the base class for all modelpools.

Source code in fusion_bench/modelpool/base_pool.py
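from abc import ABC
from typing import Dict, List, Optional, Union

import torch
from omegaconf import DictConfig
from torch import nn

# timeit_context is a fusion_bench timing utility; its import is not shown on this page.

class ModelPool(ABC):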
    """
    This is the base class for all modelpools.
    """

    _model_names = None

    def __init__(self, modelpool_config: Optional[DictConfig] = None):
        super().__init__()
        self.config = modelpool_config

        # check for duplicate model names
        if self.config is not None and self.config.get("models", None) is not None:
            model_names = [model["name"] for model in self.config["models"]]
            assert len(model_names) == len(
                set(model_names)
            ), "Duplicate model names found in model pool"
            self._model_names = model_names

    def __len__(self):
        return len(self.model_names)

    @property
    def model_names(self) -> List[str]:
        """
        This property returns a list of model names from the configuration, excluding any names that start or end with an underscore.
        To obtain all model names, including those starting or ending with an underscore, use the `_model_names` attribute.

        Returns:
            list: A list of model names.
        """
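        if self._model_names is None:
            # No "models" entry in the config, so there is nothing to list.
            return []
        names = [
            name
            for name in self._model_names
            if not (name.startswith("_") or name.endswith("_"))
        ]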
        return names

    @property
    def has_pretrained(self):
        """
        Check if the pretrained model is available in the model pool.
        """
        if self.config is None:
            return False
        for model_config in self.config.get("models", None) or []:
            if model_config.get("name", None) == "_pretrained_":
                return True
        return False

    def get_model_config(self, model_name: str):
        """
        Retrieves the configuration for a specific model from the model pool.

        Args:
            model_name (str): The name of the model for which to retrieve the configuration.

        Returns:
            dict: The configuration dictionary for the specified model.

        Raises:
            ValueError: If the specified model is not found in the model pool.
        """
        for model in self.config["models"]:
            if model["name"] == model_name:
                return model
        raise ValueError(f"Model {model_name} not found in model pool")

    def load_model(self, model_config: Union[str, DictConfig]) -> nn.Module:
        """
        Load a model from the model pool.

        Models are loaded lazily, so subclasses must implement this method.

        Args:
            model_config (Union[str, DictConfig]): The configuration dictionary for the model to load.

        Returns:
            Any: The loaded model.
        """
        raise NotImplementedError

    def load_pretrained_or_first_model(self, *args, **kwargs):
        """
        Load the pretrained model if available, otherwise load the first model in the list.

        This method checks if a pretrained model is available. If it is, it loads the pretrained model.
        If not, it loads the first model from the list of model names.

        Returns:
            nn.Module: The loaded model.
        """
        if self.has_pretrained:
            model = self.load_model("_pretrained_", *args, **kwargs)
        else:
            model = self.load_model(self.model_names[0], *args, **kwargs)
        return model

    def save_model(self, model: nn.Module, path: str):
        """
        Save the state dictionary of the model to the specified path.

        Args:
            model (nn.Module): The model whose state dictionary is to be saved.
            path (str): The path where the state dictionary will be saved.
        """
        with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)

    def models(self):
        for model_name in self.model_names:
            yield self.load_model(model_name)

    def named_models(self):
        for model_name in self.model_names:
            yield model_name, self.load_model(model_name)

    def get_train_dataset(self, model_name: str):
        """
        Get the training dataset for the model.
        """
        raise NotImplementedError

    def get_test_dataset(self, model_name: str):
        """
        Get the testing dataset for the model.
        """
        raise NotImplementedError

    def setup_taskpool(self, taskpool):
        """
        Set up the taskpool before evaluation,
        for example by setting the fabric, processor, or tokenizer.
        """
        pass

    def to_modellist(self) -> List[nn.Module]:
        """
        Convert the model pool to a list of models.
        """
        return [self.load_model(m) for m in self.model_names]

    def to_modeldict(self) -> Dict[str, nn.Module]:
        """
        Convert the model pool to a dictionary of models.
        """
        return {m: self.load_model(m) for m in self.model_names}

has_pretrained property

Check if the pretrained model is available in the model pool.

model_names: List[str] property

This property returns a list of model names from the configuration, excluding any names that start or end with an underscore. To obtain all model names, including those starting or ending with an underscore, use the _model_names attribute.

Returns:

  • list ( List[str] ) –

    A list of model names.
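The filtering rule can be checked in isolation. visible_model_names is a hypothetical helper that reproduces the property's logic using str.startswith and str.endswith, which, unlike indexing name[0], does not fail on an empty name.

```python
from typing import List


def visible_model_names(all_names: List[str]) -> List[str]:
    """Drop names that start or end with an underscore, as ModelPool.model_names does."""
    return [
        name
        for name in all_names
        if not (name.startswith("_") or name.endswith("_"))
    ]


names = ["_pretrained_", "task_a", "_hidden", "draft_", "task_b"]
print(visible_model_names(names))  # ['task_a', 'task_b']
```

Because __len__ returns len(self.model_names), reserved entries such as _pretrained_ are not counted by len(modelpool).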

get_model_config(model_name)

Retrieves the configuration for a specific model from the model pool.

Parameters:

  • model_name (str) –

    The name of the model for which to retrieve the configuration.

Returns:

  • dict

    The configuration dictionary for the specified model.

Raises:

  • ValueError

    If the specified model is not found in the model pool.

get_test_dataset(model_name)

Get the testing dataset for the model.

get_train_dataset(model_name)

Get the training dataset for the model.

load_model(model_config)

Load a model from the model pool.

Models are loaded lazily, so subclasses must implement this method; the base implementation raises NotImplementedError.

Parameters:

  • model_config (Union[str, DictConfig]) –

    The name of the model, or the configuration dictionary for the model to load.

Returns:

  • nn.Module –

    The loaded model.
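A concrete pool overrides load_model. The sketch below, built around a hypothetical DictModelPool and a hypothetical build_model loader, shows one way to accept either a model name or a configuration dict, as the Union[str, DictConfig] signature suggests. It uses plain dicts instead of DictConfig and returns a placeholder instead of an nn.Module, so it runs without OmegaConf or PyTorch.

```python
from typing import Any, Dict, Union


def build_model(path: str) -> Dict[str, str]:
    """Hypothetical loader; a real pool would read weights from `path` and return an nn.Module."""
    return {"weights_from": path}


class DictModelPool:
    """Hypothetical pool illustrating a load_model override."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        # Same linear lookup as ModelPool.get_model_config.
        for model in self.config["models"]:
            if model["name"] == model_name:
                return model
        raise ValueError(f"Model {model_name} not found in model pool")

    def load_model(self, model_config: Union[str, Dict[str, Any]]) -> Dict[str, str]:
        # A string is treated as a model name and resolved to its config entry.
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)
        return build_model(model_config["path"])


pool = DictModelPool({"models": [{"name": "task_a", "path": "ckpt/task_a.pt"}]})
print(pool.load_model("task_a"))  # {'weights_from': 'ckpt/task_a.pt'}
print(pool.load_model({"name": "adhoc", "path": "ckpt/adhoc.pt"}))  # {'weights_from': 'ckpt/adhoc.pt'}
```

Resolving names through get_model_config keeps a single code path for loading and lets an unknown name surface as the documented ValueError.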

load_pretrained_or_first_model(*args, **kwargs)

Load the pretrained model if available, otherwise load the first model in the list.

This method checks if a pretrained model is available. If it is, it loads the pretrained model. If not, it loads the first model from the list of model names.

Returns:

  • nn.Module: The loaded model.
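The fallback order can be traced with a hypothetical FallbackPool whose load_model simply returns the requested name, so the sketch shows which entry would be loaded without touching any weights. Its has_pretrained and model_names reproduce the documented logic.

```python
from typing import Any, Dict, List


class FallbackPool:
    """Hypothetical pool reproducing the documented has_pretrained / fallback logic."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @property
    def model_names(self) -> List[str]:
        names = [m["name"] for m in self.config["models"]]
        return [n for n in names if not (n.startswith("_") or n.endswith("_"))]

    @property
    def has_pretrained(self) -> bool:
        return any(m.get("name") == "_pretrained_" for m in self.config["models"])

    def load_model(self, model_name: str) -> str:
        # Stand-in for real loading: report which entry was chosen.
        return model_name

    def load_pretrained_or_first_model(self) -> str:
        if self.has_pretrained:
            return self.load_model("_pretrained_")
        return self.load_model(self.model_names[0])


with_pretrained = FallbackPool({"models": [{"name": "task_a"}, {"name": "_pretrained_"}]})
without_pretrained = FallbackPool({"models": [{"name": "task_a"}, {"name": "task_b"}]})
print(with_pretrained.load_pretrained_or_first_model())     # _pretrained_
print(without_pretrained.load_pretrained_or_first_model())  # task_a
```

The _pretrained_ entry wins regardless of its position in the list. If a pool has neither a _pretrained_ entry nor any regular model, self.model_names[0] raises IndexError.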

save_model(model, path)

Save the state dictionary of the model to the specified path.

Parameters:

  • model (Module) –

    The model whose state dictionary is to be saved.

  • path (str) –

    The path where the state dictionary will be saved.

setup_taskpool(taskpool)

Set up the taskpool before evaluation, for example by setting the fabric, processor, or tokenizer. The base implementation does nothing.

to_modeldict()

Convert the model pool to a dictionary of models.

to_modellist()

Convert the model pool to a list of models.
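The base class offers both eager conversion (to_modellist, to_modeldict) and lazy iteration (models and named_models, which are generators in the class source above). The toy below, using a hypothetical CountingPool that counts loads, shows the difference: the generators load one model per step, whereas the conversions load every model up front, which matters when models are large.

```python
from typing import Dict, Iterator, List, Tuple


class CountingPool:
    """Hypothetical pool that records how many models have been loaded."""

    def __init__(self, names: List[str]):
        self.model_names = names
        self.loaded = 0

    def load_model(self, name: str) -> str:
        self.loaded += 1
        return f"model<{name}>"

    # Lazy: one model is loaded per iteration step.
    def named_models(self) -> Iterator[Tuple[str, str]]:
        for name in self.model_names:
            yield name, self.load_model(name)

    # Eager: every model is loaded before the dict is returned.
    def to_modeldict(self) -> Dict[str, str]:
        return {name: self.load_model(name) for name in self.model_names}


pool = CountingPool(["task_a", "task_b", "task_c"])
first_name, first_model = next(pool.named_models())
print(pool.loaded)  # 1
models = pool.to_modeldict()
print(pool.loaded)  # 4
```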


  1. AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575