Skip to content

fusion_bench.compat

Method

ModelFusionAlgorithm

Bases: ABC

Abstract base class for model fusion algorithms (for v0.1.x versions, deprecated). To implement a new method, use fusion_bench.method.BaseModelFusionAlgorithm instead.

This class provides a template for implementing model fusion algorithms. Subclasses must implement the run method to define the fusion logic.

Attributes:

  • config (DictConfig) –

    Configuration for the algorithm.

Source code in fusion_bench/compat/method/base_algorithm.py
class ModelFusionAlgorithm(ABC):
    """
    Abstract base class for model fusion algorithms (for v0.1.x versions, deprecated).
    To implement a new method, use `fusion_bench.method.BaseModelFusionAlgorithm` instead.

    This class provides a template for implementing model fusion algorithms.
    Subclasses must implement the `run` method to define the fusion logic.

    Attributes:
        config (DictConfig): Configuration for the algorithm.
    """

    _program: "BaseHydraProgram" = None
    """A reference to the program that is running the algorithm."""

    def __init__(self, algorithm_config: Optional[DictConfig] = None):
        """
        Initialize the model fusion algorithm with the given configuration.

        Args:
            algorithm_config (Optional[DictConfig]): Configuration for the algorithm.
                Defaults to an empty `DictConfig` if not provided.
                The configuration is available afterwards as `self.config`.
        """
        if algorithm_config is None:
            algorithm_config = DictConfig({})
        self.config = algorithm_config

    def on_run_start(self):
        """
        Hook method called at the start of the run.

        The default implementation does nothing; subclasses can override it to
        perform initialization tasks.
        """
        pass

    def on_run_end(self):
        """
        Hook method called at the end of the run.

        The default implementation does nothing; subclasses can override it to
        perform cleanup tasks.
        """
        pass

    @abstractmethod
    def run(self, modelpool: "BaseModelPool") -> Any:
        """
        Fuse the models in the given model pool.

        This method must be implemented by subclasses to define the fusion logic.
        `modelpool` is an object responsible for managing the models to be fused
        and optional datasets to be used for fusion.

        Args:
            modelpool: The pool of models to fuse.

        Returns:
            The fused model.

        Examples:
            >>> algorithm = SimpleAverageAlgorithm()
            >>> modelpool = ModelPool()
            >>> merged_model = algorithm.run(modelpool)
        """
        pass

__init__(algorithm_config=None)

Initialize the model fusion algorithm with the given configuration.

Parameters:

  • algorithm_config (Optional[DictConfig], default: None ) –

    Configuration for the algorithm. Defaults to an empty configuration if not provided. Get access to the configuration using self.config.

Source code in fusion_bench/compat/method/base_algorithm.py
def __init__(self, algorithm_config: Optional[DictConfig] = None):
    """
    Store the algorithm configuration on the instance.

    Args:
        algorithm_config (Optional[DictConfig]): Configuration for the algorithm.
            When omitted (None), an empty `DictConfig` is used instead.
            The stored configuration is exposed as `self.config`.
    """
    self.config = (
        DictConfig({}) if algorithm_config is None else algorithm_config
    )

on_run_end()

Hook method called at the end of the run. Can be overridden by subclasses to perform cleanup tasks.

Source code in fusion_bench/compat/method/base_algorithm.py
def on_run_end(self):
    """
    Hook invoked at the end of the run.

    The base implementation is a no-op; subclasses may override it to perform
    cleanup tasks.
    """

on_run_start()

Hook method called at the start of the run. Can be overridden by subclasses to perform initialization tasks.

Source code in fusion_bench/compat/method/base_algorithm.py
def on_run_start(self):
    """
    Hook invoked at the start of the run.

    The base implementation is a no-op; subclasses may override it to perform
    initialization tasks.
    """

run(modelpool) abstractmethod

Fuse the models in the given model pool.

This method must be implemented by subclasses to define the fusion logic. modelpool is an object responsible for managing the models to be fused and optional datasets to be used for fusion.

Parameters:

  • modelpool (BaseModelPool) –

    The pool of models to fuse.

Returns:

  • Any

    The fused model.

Examples:

>>> algorithm = SimpleAverageAlgorithm()
>>> modelpool = ModelPool()
>>> merged_model = algorithm.run(modelpool)
Source code in fusion_bench/compat/method/base_algorithm.py
@abstractmethod
def run(self, modelpool: "BaseModelPool") -> Any:
    """
    Fuse the models in the given model pool.

    This method must be implemented by subclasses to define the fusion logic.
    `modelpool` is an object responsible for managing the models to be fused
    and optional datasets to be used for fusion.

    Args:
        modelpool: The pool of models to fuse.

    Returns:
        The fused model.

    Examples:
        >>> algorithm = SimpleAverageAlgorithm()
        >>> modelpool = ModelPool()
        >>> merged_model = algorithm.run(modelpool)
    """
    pass

AlgorithmFactory

Factory class to create and manage different model fusion algorithms.

This class provides methods to create algorithms based on a given configuration, register new algorithms, and list available algorithms.

Source code in fusion_bench/compat/method/__init__.py
class AlgorithmFactory:
    """
    Factory class to create and manage different model fusion algorithms.

    This class provides methods to create algorithms based on a given configuration,
    register new algorithms, and list available algorithms.
    """

    # Registry mapping algorithm names to either a class or a dotted import path.
    # Paths starting with "." are resolved relative to `fusion_bench.method`.
    # NOTE: the attribute name is misspelled ("aglorithms"); it is left unchanged
    # because other code may reference it directly.
    _aglorithms = {
        # single task learning (fine-tuning)
        "clip_finetune": ".classification.clip_finetune.ImageClassificationFineTuningForCLIP",
        # analysis
        # model merging methods
        "clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
        "clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
        "clip_layer_wise_adamerging_doge_ta": ".doge_ta.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
        "singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
        "clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
        "clip_task_wise_gossip": ".gossip.clip_task_wise_gossip.CLIPTaskWiseGossipAlgorithm",
        "clip_layer_wise_gossip": ".gossip.clip_layer_wise_gossip.CLIPLayerWiseGossipAlgorithm",
        # plug-and-play model merging methods
        "clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
        "clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",
        "clip_concrete_layer_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteLayerWiseAdaMergingForCLIP",
        "clip_post_defense_AWM": ".concrete_subspace.clip_post_defense.PostDefenseAWMAlgorithmForCLIP",
        "clip_post_defense_SAU": ".concrete_subspace.clip_post_defense.PostDefenseSAUAlgorithmForCLIP",
        "clip_safe_concrete_layer_wise_adamerging": ".concrete_subspace.clip_safe_concrete_adamerging.ConcreteSafeLayerWiseAdaMergingForCLIP",
        "clip_safe_concrete_task_wise_adamerging": ".concrete_subspace.clip_safe_concrete_adamerging.ConcreteSafeTaskWiseAdaMergingForCLIP",
        # model mixing methods
        "clip_weight_ensembling_moe": ".we_moe.clip_we_moe.CLIPWeightEnsemblingMoEAlgorithm",
        "sparse_clip_weight_ensembling_moe": "fusion_bench.method.SparseCLIPWeightEnsemblingMoEAlgorithm",
        "smile_mistral_upscaling": ".smile_upscaling.smile_mistral_upscaling.SmileMistralUpscalingAlgorithm",
        "rankone_moe": ".rankone_moe.clip_rankone_moe.CLIPRankOneMoEAlgorithm",
    }

    @staticmethod
    def create_algorithm(method_config: DictConfig) -> ModelFusionAlgorithm:
        """
        Create an instance of a model fusion algorithm based on the provided configuration.

        Args:
            method_config (DictConfig): The configuration for the algorithm. Must contain
                a 'name' attribute that specifies the type of the algorithm. The whole
                configuration is passed to the algorithm's constructor.

        Returns:
            ModelFusionAlgorithm: An instance of the specified algorithm.

        Raises:
            ValueError: If 'name' attribute is not found in the configuration or does not match any known algorithm names.
        """
        warnings.warn(
            "AlgorithmFactory.create_algorithm() is deprecated and will be removed in future versions. "
            "Please implement new model fusion algorithm using `fusion_bench.method.BaseModelFusionAlgorithm` instead.",
            DeprecationWarning,
        )

        # `.get` (rather than attribute access) lets a missing 'name' surface as the
        # documented ValueError instead of an attribute/key error from the config.
        algorithm_name = method_config.get("name")
        if algorithm_name is None:
            raise ValueError("Algorithm name not specified")

        registry = AlgorithmFactory._aglorithms
        if algorithm_name not in registry:
            raise ValueError(
                f"Unknown algorithm: {algorithm_name}, available algorithms: {list(registry.keys())}. "
                "You can register a new algorithm using `AlgorithmFactory.register_algorithm()` method."
            )
        algorithm_cls = registry[algorithm_name]
        if isinstance(algorithm_cls, str):
            # Only string entries need resolving, so the helper is imported here.
            from fusion_bench.utils import import_object

            if algorithm_cls.startswith("."):
                algorithm_cls = f"fusion_bench.method.{algorithm_cls[1:]}"
            algorithm_cls = import_object(algorithm_cls)
        return algorithm_cls(method_config)

    @staticmethod
    def register_algorithm(
        name: str, algorithm_cls: Type[ModelFusionAlgorithm]
    ) -> None:
        """
        Register a new algorithm with the factory.

        An existing registration under the same name is overwritten.

        Args:
            name (str): The name of the algorithm.
            algorithm_cls: The class of the algorithm to register.
        """
        AlgorithmFactory._aglorithms[name] = algorithm_cls

    @classmethod
    def available_algorithms(cls) -> List[str]:
        """
        Get a list of available algorithms.

        Returns:
            list: A list of available algorithm names.
        """
        return list(cls._aglorithms.keys())

available_algorithms() classmethod

Get a list of available algorithms.

Returns:

  • list ( List[str] ) –

    A list of available algorithm names.

Source code in fusion_bench/compat/method/__init__.py
@classmethod
def available_algorithms(cls) -> List[str]:
    """
    List the names under which algorithms are registered.

    Returns:
        list: Registered algorithm names, in registry (insertion) order.
    """
    return [*cls._aglorithms]

create_algorithm(method_config) staticmethod

Create an instance of a model fusion algorithm based on the provided configuration.

Parameters:

  • method_config (DictConfig) –

    The configuration for the algorithm. Must contain a 'name' attribute that specifies the type of the algorithm.

Returns:

  • ModelFusionAlgorithm ( ModelFusionAlgorithm ) –

    An instance of the specified algorithm.

Raises:

  • ValueError

    If 'name' attribute is not found in the configuration or does not match any known algorithm names.

Source code in fusion_bench/compat/method/__init__.py
@staticmethod
def create_algorithm(method_config: DictConfig) -> ModelFusionAlgorithm:
    """
    Create an instance of a model fusion algorithm based on the provided configuration.

    Args:
        method_config (DictConfig): The configuration for the algorithm. Must contain
            a 'name' attribute that specifies the type of the algorithm. The whole
            configuration is passed to the algorithm's constructor.

    Returns:
        ModelFusionAlgorithm: An instance of the specified algorithm.

    Raises:
        ValueError: If 'name' attribute is not found in the configuration or does not match any known algorithm names.
    """
    warnings.warn(
        "AlgorithmFactory.create_algorithm() is deprecated and will be removed in future versions. "
        "Please implement new model fusion algorithm using `fusion_bench.method.BaseModelFusionAlgorithm` instead.",
        DeprecationWarning,
    )

    # `.get` (rather than attribute access) lets a missing 'name' surface as the
    # documented ValueError instead of an attribute/key error from the config.
    algorithm_name = method_config.get("name")
    if algorithm_name is None:
        raise ValueError("Algorithm name not specified")

    registry = AlgorithmFactory._aglorithms
    if algorithm_name not in registry:
        raise ValueError(
            f"Unknown algorithm: {algorithm_name}, available algorithms: {list(registry.keys())}. "
            "You can register a new algorithm using `AlgorithmFactory.register_algorithm()` method."
        )
    algorithm_cls = registry[algorithm_name]
    if isinstance(algorithm_cls, str):
        # Only string entries need resolving, so the helper is imported here.
        from fusion_bench.utils import import_object

        if algorithm_cls.startswith("."):
            algorithm_cls = f"fusion_bench.method.{algorithm_cls[1:]}"
        algorithm_cls = import_object(algorithm_cls)
    return algorithm_cls(method_config)

register_algorithm(name, algorithm_cls) staticmethod

Register a new algorithm with the factory.

Parameters:

  • name (str) –

    The name of the algorithm.

  • algorithm_cls (Type[ModelFusionAlgorithm]) –

    The class of the algorithm to register.

Source code in fusion_bench/compat/method/__init__.py
@staticmethod
def register_algorithm(
    name: str, algorithm_cls: Type[ModelFusionAlgorithm]
) -> None:
    """
    Add an algorithm to the factory registry under `name`.

    An existing entry with the same name is replaced.

    Args:
        name (str): Key under which the algorithm is registered.
        algorithm_cls: The algorithm class to associate with `name`.
    """
    AlgorithmFactory._aglorithms.update({name: algorithm_cls})

Model Pool

ModelPool

Bases: ABC

This is the base class for all model pools in version v0.1.x (deprecated). To implement new model pools, use fusion_bench.modelpool.BaseModelPool.

Source code in fusion_bench/compat/modelpool/base_pool.py
class ModelPool(ABC):
    """
    Base class for all model pools in the v0.1.x API (deprecated).

    To implement new model pools, use `fusion_bench.modelpool.BaseModelPool` instead.
    """

    # Names of every model in the configuration, including special entries such as
    # `_pretrained_`. Remains None when no `models` list is configured.
    _model_names = None

    def __init__(self, modelpool_config: Optional[DictConfig] = None):
        """
        Initialize the ModelPool with the given configuration.

        Args:
            modelpool_config (Optional[DictConfig]): The configuration for the model pool.
                If it provides a `models` list, every entry must have a unique `name`.

        Raises:
            AssertionError: If two entries in `models` share a name (only while
                assertions are enabled).
        """
        super().__init__()
        self.config = modelpool_config

        # check for duplicate model names
        if self.config is not None and self.config.get("models", None) is not None:
            model_names = [model["name"] for model in self.config["models"]]
            assert len(model_names) == len(
                set(model_names)
            ), "Duplicate model names found in model pool"
            self._model_names = model_names

    def __len__(self) -> int:
        """
        Return the number of models in the model pool, excluding special models such as `_pretrained_`.

        Returns:
            int: The number of models in the model pool.
        """
        return len(self.model_names)

    @property
    def model_names(self) -> List[str]:
        """
        List the regular model names from the configuration.

        Names that start or end with an underscore (e.g. `_pretrained_`) are excluded.
        To obtain all model names, including those, use the `_model_names` attribute.
        An empty list is returned when no models are configured.

        Returns:
            list: A list of model names.
        """
        # `_model_names` stays None when the pool was built without a `models` list.
        if self._model_names is None:
            return []
        # startswith/endswith also tolerate empty names, unlike name[0] / name[-1].
        return [
            name
            for name in self._model_names
            if not (name.startswith("_") or name.endswith("_"))
        ]

    @property
    def has_pretrained(self) -> bool:
        """
        Check if the pretrained model (`_pretrained_`) is available in the model pool.

        Returns:
            bool: True if the pretrained model is available, False otherwise
                (including when no configuration or no `models` list is present).
        """
        if self.config is None:
            return False
        models = self.config.get("models", None)
        if models is None:
            return False
        return any(
            model_config.get("name", None) == "_pretrained_" for model_config in models
        )

    def get_model_config(self, model_name: str) -> Dict:
        """
        Retrieves the configuration for a specific model from the model pool.

        Args:
            model_name (str): The name of the model for which to retrieve the configuration.

        Returns:
            dict: The configuration dictionary for the specified model.

        Raises:
            ValueError: If the specified model is not found in the model pool.
        """
        for model in self.config["models"]:
            if model["name"] == model_name:
                return model
        raise ValueError(f"Model {model_name} not found in model pool")

    def load_model(self, model_config: Union[str, DictConfig]) -> nn.Module:
        """
        Load a model from the model pool.

        Models are loaded lazily, so concrete pools must override this method;
        the base implementation always raises.

        Args:
            model_config (Union[str, DictConfig]): The model name or the configuration
                of the model to load.

        Returns:
            nn.Module: The loaded model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def load_pretrained_or_first_model(self, *args, **kwargs):
        """
        Load the pretrained model if available, otherwise load the first model in the list.

        Extra positional and keyword arguments are forwarded to `load_model`.

        Returns:
            nn.Module: The loaded model.

        Raises:
            IndexError: If there is no pretrained model and no regular model.
        """
        if self.has_pretrained:
            model = self.load_model("_pretrained_", *args, **kwargs)
        else:
            model = self.load_model(self.model_names[0], *args, **kwargs)
        return model

    def save_model(self, model: nn.Module, path: str):
        """
        Save the state dictionary of the model to the specified path.

        Args:
            model (nn.Module): The model whose state dictionary is to be saved.
            path (str): The path where the state dictionary will be saved.
        """
        with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)

    def models(self):
        """
        Generator that lazily yields the regular models from the model pool.

        Yields:
            nn.Module: The next model in the model pool.
        """
        for model_name in self.model_names:
            yield self.load_model(model_name)

    def named_models(self):
        """
        Generator that lazily yields model names and models from the model pool.

        Yields:
            tuple: A tuple containing the model name and the model.
        """
        for model_name in self.model_names:
            yield model_name, self.load_model(model_name)

    def get_train_dataset(self, model_name: str):
        """
        Get the training dataset for the model.

        Args:
            model_name (str): The name of the model for which to get the training dataset.

        Returns:
            Any: The training dataset for the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def get_test_dataset(self, model_name: str):
        """
        Get the testing dataset for the model.

        Args:
            model_name (str): The name of the model for which to get the testing dataset.

        Returns:
            Any: The testing dataset for the model.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def setup_taskpool(self, taskpool):
        """
        Set up the taskpool before evaluation, e.g. by setting the fabric,
        processor, or tokenizer. The base implementation does nothing.

        Args:
            taskpool (Any): The taskpool to set up.
        """
        pass

    def to_modellist(self) -> List[nn.Module]:
        """
        Convert the model pool to a list of models.

        Returns:
            list: A list of models.
        """
        return [self.load_model(m) for m in self.model_names]

    def to_modeldict(self) -> Dict[str, nn.Module]:
        """
        Convert the model pool to a dictionary of models.

        Returns:
            dict: A dictionary of models.
        """
        return {m: self.load_model(m) for m in self.model_names}

has_pretrained property

Check if the pretrained model is available in the model pool.

Returns:

  • bool ( bool ) –

    True if the pretrained model is available, False otherwise.

model_names property

This property returns a list of model names from the configuration, excluding any names that start or end with an underscore. To obtain all model names, including those starting or ending with an underscore, use the _model_names attribute.

Returns:

  • list ( List[str] ) –

    A list of model names.

__init__(modelpool_config=None)

Initialize the ModelPool with the given configuration.

Parameters:

  • modelpool_config (Optional[DictConfig], default: None ) –

    The configuration for the model pool.

Source code in fusion_bench/compat/modelpool/base_pool.py
def __init__(self, modelpool_config: Optional[DictConfig] = None):
    """
    Set up the pool from an optional configuration.

    Args:
        modelpool_config (Optional[DictConfig]): The configuration for the model pool.
            When it provides a `models` list, the names of its entries are recorded
            in `self._model_names`.

    Raises:
        AssertionError: If two entries in `models` share a name (only while
            assertions are enabled).
    """
    super().__init__()
    self.config = modelpool_config

    configured = None if self.config is None else self.config.get("models", None)
    if configured is None:
        return

    # Names must be unique so that name-based lookups are unambiguous.
    names = [entry["name"] for entry in configured]
    assert len(set(names)) == len(names), "Duplicate model names found in model pool"
    self._model_names = names

__len__()

Return the number of models in the model pool, excluding special models such as _pretrained_.

Returns:

  • int ( int ) –

    The number of models in the model pool.

Source code in fusion_bench/compat/modelpool/base_pool.py
def __len__(self) -> int:
    """
    Count the regular models in the pool.

    The count is taken from `self.model_names`, so special entries such as
    `_pretrained_` are not included.

    Returns:
        int: The number of models in the model pool.
    """
    regular_names = self.model_names
    return len(regular_names)

get_model_config(model_name)

Retrieves the configuration for a specific model from the model pool.

Parameters:

  • model_name (str) –

    The name of the model for which to retrieve the configuration.

Returns:

  • dict ( Dict ) –

    The configuration dictionary for the specified model.

Raises:

  • ValueError

    If the specified model is not found in the model pool.

Source code in fusion_bench/compat/modelpool/base_pool.py
def get_model_config(self, model_name: str) -> Dict:
    """
    Look up the configuration entry of a model by name.

    When several entries share the name, the first one is returned.

    Args:
        model_name (str): Name of the model whose configuration is wanted.

    Returns:
        dict: The matching entry from `self.config["models"]`.

    Raises:
        ValueError: If no entry carries `model_name`.
    """
    not_found = object()
    entry = next(
        (cfg for cfg in self.config["models"] if cfg["name"] == model_name),
        not_found,
    )
    if entry is not_found:
        raise ValueError(f"Model {model_name} not found in model pool")
    return entry

get_test_dataset(model_name)

Get the testing dataset for the model.

Parameters:

  • model_name (str) –

    The name of the model for which to get the testing dataset.

Returns:

  • Any

    The testing dataset for the model.

Source code in fusion_bench/compat/modelpool/base_pool.py
def get_test_dataset(self, model_name: str):
    """
    Return the testing dataset associated with a model.

    The base implementation provides no datasets; subclasses override it.

    Args:
        model_name (str): Name of the model whose testing dataset is requested.

    Returns:
        Any: The testing dataset for the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

get_train_dataset(model_name)

Get the training dataset for the model.

Parameters:

  • model_name (str) –

    The name of the model for which to get the training dataset.

Returns:

  • Any

    The training dataset for the model.

Source code in fusion_bench/compat/modelpool/base_pool.py
def get_train_dataset(self, model_name: str):
    """
    Return the training dataset associated with a model.

    The base implementation provides no datasets; subclasses override it.

    Args:
        model_name (str): Name of the model whose training dataset is requested.

    Returns:
        Any: The training dataset for the model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

load_model(model_config)

The models are loaded lazily, so subclasses should implement this method to load a model from the model pool.

Load the model from the model pool.

Parameters:

  • model_config (Union[str, DictConfig]) –

    The configuration dictionary for the model to load.

Returns:

  • nn.Module ( Module ) –

    The loaded model.

Source code in fusion_bench/compat/modelpool/base_pool.py
def load_model(self, model_config: Union[str, DictConfig]) -> nn.Module:
    """
    Load a single model from the pool.

    Models are loaded lazily, so concrete pools must override this method;
    the base implementation always raises.

    Args:
        model_config (Union[str, DictConfig]): The model name (str) or the
            configuration of the model to load.

    Returns:
        nn.Module: The loaded model.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()

load_pretrained_or_first_model(*args, **kwargs)

Load the pretrained model if available, otherwise load the first model in the list.

This method checks if a pretrained model is available. If it is, it loads the pretrained model. If not, it loads the first model from the list of model names.

Returns:

  • nn.Module: The loaded model.

Source code in fusion_bench/compat/modelpool/base_pool.py
def load_pretrained_or_first_model(self, *args, **kwargs):
    """
    Load `_pretrained_` when the pool has it, else the first regular model.

    Any extra positional and keyword arguments are forwarded to `load_model`.

    Returns:
        nn.Module: The loaded model.
    """
    target = "_pretrained_" if self.has_pretrained else self.model_names[0]
    return self.load_model(target, *args, **kwargs)

models()

Generator that yields models from the model pool.

Yields:

  • nn.Module: The next model in the model pool.

Source code in fusion_bench/compat/modelpool/base_pool.py
def models(self):
    """
    Lazily yield every regular model of the pool, one at a time.

    Yields:
        nn.Module: The next model, obtained via `load_model`.
    """
    yield from (self.load_model(name) for name in self.model_names)

named_models()

Generator that yields model names and models from the model pool.

Yields:

  • tuple

    A tuple containing the model name and the model.

Source code in fusion_bench/compat/modelpool/base_pool.py
def named_models(self):
    """
    Lazily yield `(name, model)` pairs for every regular model of the pool.

    Yields:
        tuple: The model name and the model obtained via `load_model`.
    """
    yield from ((name, self.load_model(name)) for name in self.model_names)

save_model(model, path)

Save the state dictionary of the model to the specified path.

Parameters:

  • model (Module) –

    The model whose state dictionary is to be saved.

  • path (str) –

    The path where the state dictionary will be saved.

Source code in fusion_bench/compat/modelpool/base_pool.py
def save_model(self, model: nn.Module, path: str):
    """
    Persist the model's state dictionary to disk.

    The save runs inside `timeit_context`, labelled with the target path.

    Args:
        model (nn.Module): Model whose `state_dict()` is written.
        path (str): Destination file for `torch.save`.
    """
    label = f"Saving the state dict of model to {path}"
    with timeit_context(label):
        state = model.state_dict()
        torch.save(state, path)

setup_taskpool(taskpool)

Set up the taskpool before evaluation, e.g., by setting the fabric, processor, or tokenizer.

Parameters:

  • taskpool (Any) –

    The taskpool to setup.

Source code in fusion_bench/compat/modelpool/base_pool.py
def setup_taskpool(self, taskpool):
    """
    Prepare a taskpool before evaluation.

    The base implementation is a no-op. Subclasses can override it to configure
    the taskpool, e.g. by setting the fabric, processor, or tokenizer.

    Args:
        taskpool (Any): The taskpool to set up.
    """

to_modeldict()

Convert the model pool to a dictionary of models.

Returns:

  • dict ( Dict[str, Module] ) –

    A dictionary of models.

Source code in fusion_bench/compat/modelpool/base_pool.py
def to_modeldict(self) -> Dict[str, nn.Module]:
    """
    Eagerly load every regular model into a name-keyed dictionary.

    Returns:
        dict: Mapping from model name to the model returned by `load_model`.
    """
    names = self.model_names
    return {key: self.load_model(key) for key in names}

to_modellist()

Convert the model pool to a list of models.

Returns:

  • list ( List[Module] ) –

    A list of models.

Source code in fusion_bench/compat/modelpool/base_pool.py
def to_modellist(self) -> List[nn.Module]:
    """
    Eagerly load every regular model into a list.

    Returns:
        list: The models returned by `load_model`, in `model_names` order.
    """
    return list(map(self.load_model, self.model_names))

ModelPoolFactory

Factory class to create and manage different model pools.

This class provides methods to create model pools based on a given configuration, register new model pools, and list available model pools.

Source code in fusion_bench/compat/modelpool/__init__.py
class ModelPoolFactory:
    """
    Factory class to create and manage different model pools.

    This class provides methods to create model pools based on a given configuration,
    register new model pools, and list available model pools.
    """

    # Registry mapping pool type names to either a class or a dotted import path.
    # Paths starting with "." are resolved relative to `fusion_bench.compat.modelpool`.
    _modelpool = {
        "NYUv2ModelPool": "fusion_bench.modelpool.nyuv2_modelpool.NYUv2ModelPool",
        "huggingface_clip_vision": HuggingFaceClipVisionPool,
        "HF_GPT2ForSequenceClassification": GPT2ForSequenceClassificationPool,
        "AutoModelPool": ".huggingface_automodel.AutoModelPool",
        # CausualLM
        "AutoModelForCausalLMPool": ".huggingface_llm.AutoModelForCausalLMPool",
        "LLamaForCausalLMPool": ".huggingface_llm.LLamaForCausalLMPool",
        "MistralForCausalLMPool": ".huggingface_llm.MistralForCausalLMPool",
        # Seq2SeqLM
        "AutoModelForSeq2SeqLMPool": AutoModelForSeq2SeqLMPool,
        "PeftModelForSeq2SeqLMPool": PeftModelForSeq2SeqLMPool,
    }

    @staticmethod
    def create_modelpool(modelpool_config: DictConfig) -> ModelPool:
        """
        Create an instance of a model pool based on the provided configuration.
        This is for v0.1.x versions, deprecated.
        For implementing new model pool, use `fusion_bench.modelpool.BaseModelPool` instead.

        Args:
            modelpool_config (DictConfig): The configuration for the model pool.
                Must contain a 'type' attribute that specifies the type of the model pool.
                The whole configuration is passed to the pool's constructor.

        Returns:
            ModelPool: An instance of the specified model pool.

        Raises:
            ValueError: If 'type' attribute is not found in the configuration or does not match any known model pool types.
        """
        warnings.warn(
            "ModelPoolFactory.create_modelpool() is deprecated and will be removed in future versions. "
            "Please implement new model pool using `fusion_bench.modelpool.BaseModelPool` instead.",
            DeprecationWarning,
        )

        modelpool_type = modelpool_config.get("type")
        if modelpool_type is None:
            raise ValueError("Model pool type not specified")

        registry = ModelPoolFactory._modelpool
        if modelpool_type not in registry:
            raise ValueError(
                f"Unknown model pool: {modelpool_type}, available model pools: {list(registry.keys())}. You can register a new model pool using `ModelPoolFactory.register_modelpool()` method."
            )
        modelpool_cls = registry[modelpool_type]
        if isinstance(modelpool_cls, str):
            # Only string entries need resolving, so the helper is imported here.
            from fusion_bench.utils import import_object

            if modelpool_cls.startswith("."):
                modelpool_cls = f"fusion_bench.compat.modelpool.{modelpool_cls[1:]}"
            modelpool_cls = import_object(modelpool_cls)
        return modelpool_cls(modelpool_config)

    @staticmethod
    def register_modelpool(name: str, modelpool_cls):
        """
        Register a new model pool with the factory.

        An existing registration under the same name is overwritten.

        Args:
            name (str): The name of the model pool.
            modelpool_cls: The class of the model pool to register.
        """
        ModelPoolFactory._modelpool[name] = modelpool_cls

    @classmethod
    def available_modelpools(cls):
        """
        Get a list of available model pools.

        Returns:
            list: A list of available model pool names.
        """
        return list(cls._modelpool.keys())

available_modelpools() classmethod

Get a list of available model pools.

Returns:

  • list

    A list of available model pool names.

Source code in fusion_bench/compat/modelpool/__init__.py
@classmethod
def available_modelpools(cls):
    """
    List the names under which model pools are registered.

    Returns:
        list: The registry keys, in the registry's iteration order.
    """
    return [name for name in cls._modelpool]

create_modelpool(modelpool_config) staticmethod

Create an instance of a model pool based on the provided configuration. This API is for v0.1.x and is deprecated. To implement a new model pool, use fusion_bench.modelpool.BaseModelPool instead.

Parameters:

  • modelpool_config (DictConfig) –

    The configuration for the model pool.

Returns:

  • ModelPool –

    An instance of the specified model pool.

Raises:

  • ValueError

    If the 'type' attribute is missing from the configuration or does not match any registered model pool type.

Source code in fusion_bench/compat/modelpool/__init__.py
@staticmethod
def create_modelpool(modelpool_config: DictConfig) -> ModelPool:
    """
    Build a model pool from its configuration (v0.1.x API, deprecated).

    New model pools should subclass `fusion_bench.modelpool.BaseModelPool`
    instead. Every call emits a `DeprecationWarning`.

    Args:
        modelpool_config (DictConfig): Model pool configuration. Its ``type``
            entry selects the registered model pool; the whole config is then
            passed to that model pool's constructor.

    Returns:
        ModelPool: The constructed model pool.

    Raises:
        ValueError: If ``type`` is missing or is not a registered model pool name.
    """
    warnings.warn(
        "ModelPoolFactory.create_modelpool() is deprecated and will be removed in future versions. "
        "Please implement new model pool using `fusion_bench.modelpool.BaseModelPool` instead.",
        DeprecationWarning,
    )

    from fusion_bench.utils import import_object

    pool_type = modelpool_config.get("type")
    if pool_type is None:
        raise ValueError("Model pool type not specified")

    registry = ModelPoolFactory._modelpool
    if pool_type not in registry:
        raise ValueError(
            f"Unknown model pool: {pool_type}, available model pools: {registry.keys()}. You can register a new model pool using `ModelPoolFactory.register_modelpool()` method."
        )

    target = registry[pool_type]
    if isinstance(target, str):
        # Registry entries may be import paths; a leading "." is relative to
        # the fusion_bench.compat.modelpool package.
        if target.startswith("."):
            target = "fusion_bench.compat.modelpool." + target[1:]
        target = import_object(target)
    return target(modelpool_config)

register_modelpool(name, modelpool_cls) staticmethod

Register a new model pool with the factory.

Parameters:

  • name (str) –

    The name of the model pool.

  • modelpool_cls

    The class of the model pool to register.

Source code in fusion_bench/compat/modelpool/__init__.py
@staticmethod
def register_modelpool(name: str, modelpool_cls):
    """
    Add (or overwrite) an entry in the model pool registry.

    Args:
        name (str): Registry key; this is the value `create_modelpool`
            looks up from the ``type`` field of a model pool config.
        modelpool_cls: The model pool class to register. `create_modelpool`
            also accepts a dotted import path string here and resolves it
            lazily (a leading ``.`` is relative to
            ``fusion_bench.compat.modelpool``).
    """
    registry = ModelPoolFactory._modelpool
    registry[name] = modelpool_cls

Task Pool

TaskPool

A class that manages a pool of tasks for evaluation. This is the base class for v0.1.x and is deprecated; use fusion_bench.taskpool.BaseTaskPool instead.

Attributes:

  • config (DictConfig) –

    The configuration for the task pool.

  • _all_task_names (List[str]) –

    A list of all task names in the task pool.

Source code in fusion_bench/compat/taskpool/base_pool.py
class TaskPool:
    """
    A class to manage a pool of tasks for evaluation.
    This is the base class for version 0.1.x, deprecated.
    Use `fusion_bench.taskpool.BaseTaskPool` instead.

    Subclasses must implement `load_task`. `evaluate` calls it once per task
    name and then calls the returned task's `evaluate(model)`.

    Attributes:
        config (DictConfig): The configuration for the task pool.
        _all_task_names (List[str]): A list of all task names in the task pool,
            read from the ``name`` field of each entry in ``config["tasks"]``.
    """

    _program = None

    def __init__(self, taskpool_config: DictConfig):
        """
        Initialize the TaskPool with the given configuration.

        Args:
            taskpool_config (DictConfig): The configuration for the task pool.
                When it has a ``tasks`` entry, each task must carry a unique
                ``name``.

        Raises:
            ValueError: If two tasks in ``tasks`` share the same name.
        """
        super().__init__()
        self.config = taskpool_config

        # Check for duplicate task names. This is an explicit raise rather than
        # an `assert` so the validation still runs under `python -O`.
        if self.config.get("tasks", None) is not None:
            task_names = [task["name"] for task in self.config["tasks"]]
            if len(task_names) != len(set(task_names)):
                # Only computed on the error path; preserves first-seen order.
                duplicates = list(
                    dict.fromkeys(
                        name for name in task_names if task_names.count(name) > 1
                    )
                )
                raise ValueError(
                    f"Duplicate task names found in the task pool: {duplicates}"
                )
            self._all_task_names = task_names

    def evaluate(self, model):
        """
        Evaluate the model on all tasks in the task pool and return a report.

        Taking image classification as an example, the report looks like this:

        ```python
        {
            "mnist": {
                "accuracy": 0.8,
                "loss": 0.2,
            },
            <task_name>: {
                <metric_name>: <metric_value>,
                ...
            },
        }
        ```

        Args:
            model: The model to evaluate.

        Returns:
            report (dict): A dictionary containing the results of the evaluation for each task.
        """
        report = {}
        for task_name in tqdm(self.task_names, desc="Evaluating tasks"):
            task = self.load_task(task_name)
            result = task.evaluate(model)
            report[task_name] = result
        return report

    @property
    def task_names(self):
        """
        Return a list of all task names in the task pool.

        Returns:
            List[str]: A list of all task names. This is an empty list when the
                configuration has no ``tasks`` entry and no subclass has set
                ``_all_task_names``.
        """
        # `__init__` only sets `_all_task_names` when `config["tasks"]` is
        # present; fall back to an empty list instead of raising AttributeError.
        return getattr(self, "_all_task_names", [])

    def get_task_config(self, task_name: str):
        """
        Retrieve the configuration for a specific task from the task pool.

        Args:
            task_name (str): The name of the task for which to retrieve the configuration.

        Returns:
            DictConfig: The configuration dictionary for the specified task.

        Raises:
            ValueError: If the specified task is not found in the task pool
                (including when the pool has no ``tasks`` configured).
        """
        # A missing or None `tasks` entry means no task can match. Report that
        # as the documented ValueError instead of a KeyError/TypeError.
        for task in self.config.get("tasks", None) or []:
            if task["name"] == task_name:
                return task
        raise ValueError(f"Task {task_name} not found in the task pool")

    def load_task(self, task_name_or_config: Union[str, DictConfig]):
        """
        Load a task from the task pool.

        Args:
            task_name_or_config (Union[str, DictConfig]): The name or configuration of the task to load.

        Returns:
            Any: The loaded task.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError

task_names property

Return a list of all task names in the task pool.

Returns:

  • List[str] –

    A list of all task names.

__init__(taskpool_config)

Initialize the TaskPool with the given configuration.

Parameters:

  • taskpool_config (DictConfig) –

    The configuration for the task pool.

Source code in fusion_bench/compat/taskpool/base_pool.py
def __init__(self, taskpool_config: DictConfig):
    """
    Initialize the TaskPool with the given configuration.

    Args:
        taskpool_config (DictConfig): The configuration for the task pool.
            When it has a ``tasks`` entry, each task must carry a unique
            ``name``.

    Raises:
        ValueError: If two tasks in ``tasks`` share the same name.
    """
    super().__init__()
    self.config = taskpool_config

    # Check for duplicate task names. This is an explicit raise rather than
    # an `assert` so the validation still runs under `python -O`.
    if self.config.get("tasks", None) is not None:
        task_names = [task["name"] for task in self.config["tasks"]]
        if len(task_names) != len(set(task_names)):
            # Only computed on the error path; preserves first-seen order.
            duplicates = list(
                dict.fromkeys(
                    name for name in task_names if task_names.count(name) > 1
                )
            )
            raise ValueError(
                f"Duplicate task names found in the task pool: {duplicates}"
            )
        self._all_task_names = task_names

evaluate(model)

Evaluate the model on all tasks in the task pool and return a report.

Taking image classification as an example, the report looks like this:

{
    "mnist": {
        "accuracy": 0.8,
        "loss": 0.2,
    },
    <task_name>: {
        <metric_name>: <metric_value>,
        ...
    },
}

Parameters:

  • model

    The model to evaluate.

Returns:

  • report (dict) –

    A dictionary containing the results of the evaluation for each task.

Source code in fusion_bench/compat/taskpool/base_pool.py
def evaluate(self, model):
    """
    Run the model against every task in the pool and collect the results.

    Each task is obtained with ``self.load_task(name)`` and evaluated via its
    ``evaluate(model)`` method, while a ``tqdm`` progress bar tracks the loop.
    For image classification, the report might look like:

    ```python
    {
        "mnist": {
            "accuracy": 0.8,
            "loss": 0.2,
        },
        <task_name>: {
            <metric_name>: <metric_value>,
            ...
        },
    }
    ```

    Args:
        model: The model to evaluate.

    Returns:
        report (dict): Maps each task name to the value returned by that
            task's ``evaluate`` call.
    """
    progress = tqdm(self.task_names, desc="Evaluating tasks")
    return {name: self.load_task(name).evaluate(model) for name in progress}

get_task_config(task_name)

Retrieve the configuration for a specific task from the task pool.

Parameters:

  • task_name (str) –

    The name of the task for which to retrieve the configuration.

Returns:

  • DictConfig

    The configuration dictionary for the specified task.

Raises:

  • ValueError

    If the specified task is not found in the task pool.

Source code in fusion_bench/compat/taskpool/base_pool.py
def get_task_config(self, task_name: str):
    """
    Retrieve the configuration for a specific task from the task pool.

    Args:
        task_name (str): The name of the task for which to retrieve the configuration.

    Returns:
        DictConfig: The configuration dictionary for the specified task.

    Raises:
        ValueError: If the specified task is not found in the task pool
            (including when the pool has no ``tasks`` configured).
    """
    # A missing or None `tasks` entry means no task can match. Report that as
    # the documented ValueError instead of a KeyError/TypeError.
    for task in self.config.get("tasks", None) or []:
        if task["name"] == task_name:
            return task
    raise ValueError(f"Task {task_name} not found in the task pool")

load_task(task_name_or_config)

Load a task from the task pool.

Parameters:

  • task_name_or_config (Union[str, DictConfig]) –

    The name or configuration of the task to load.

Returns:

  • Any

    The loaded task.

Raises:

  • NotImplementedError

    If the method is not implemented in the subclass.

Source code in fusion_bench/compat/taskpool/base_pool.py
def load_task(self, task_name_or_config: Union[str, DictConfig]):
    """
    Subclass hook: construct the task identified by a name or a config.

    The base class provides no implementation.

    Args:
        task_name_or_config (Union[str, DictConfig]): Either a task name or a
            task configuration.

    Returns:
        Any: The loaded task, as produced by the overriding subclass.

    Raises:
        NotImplementedError: Always, unless a subclass overrides this method.
    """
    raise NotImplementedError()

TaskPoolFactory

Factory class that creates and manages different task pools. This API is for v0.1.x and is deprecated. To implement a new task pool, use fusion_bench.taskpool.BaseTaskPool instead.

This class provides methods to create task pools based on a given configuration, register new task pools, and list available task pools.

Source code in fusion_bench/compat/taskpool/__init__.py
class TaskPoolFactory:
    """
    Registry-backed factory for task pools (v0.1.x API, deprecated).

    New task pools should subclass `fusion_bench.taskpool.BaseTaskPool` instead.

    The registry maps a ``type`` name either to a task pool class or to a dotted
    import path string. Strings are resolved only when a pool is created; a
    leading ``.`` is relative to ``fusion_bench.compat.taskpool``.
    """

    # type name -> task pool class, or an import path resolved lazily
    _taskpool_types = {
        "dummy": DummyTaskPool,
        "clip_vit_classification": ".clip_image_classification.CLIPImageClassificationTaskPool",
        "FlanT5GLUETextGenerationTaskPool": ".flan_t5_glue_text_generation.FlanT5GLUETextGenerationTaskPool",
        "NYUv2TaskPool": "fusion_bench.taskpool.nyuv2_taskpool.NYUv2TaskPool",
    }

    @staticmethod
    def create_taskpool(taskpool_config: DictConfig):
        """
        Build a task pool from its configuration.

        Args:
            taskpool_config (DictConfig): Task pool configuration. Its ``type``
                entry selects the registered task pool; the whole config is then
                passed to that task pool's constructor.

        Returns:
            TaskPool: The constructed task pool.

        Raises:
            ValueError: If ``type`` is missing or is not a registered task pool name.
        """
        from fusion_bench.utils import import_object

        pool_type = taskpool_config.get("type")
        if pool_type is None:
            raise ValueError("Task pool type not specified")

        registry = TaskPoolFactory._taskpool_types
        if pool_type not in registry:
            raise ValueError(
                f"Unknown task pool: {pool_type}, available task pools: {registry.keys()}. You can register a new task pool using `TaskPoolFactory.register_taskpool()` method."
            )

        target = registry[pool_type]
        if isinstance(target, str):
            if target.startswith("."):
                target = "fusion_bench.compat.taskpool." + target[1:]
            target = import_object(target)
        return target(taskpool_config)

    @staticmethod
    def register_taskpool(name: str, taskpool_cls):
        """
        Add (or overwrite) an entry in the task pool registry.

        Args:
            name (str): Registry key; this is the value `create_taskpool` looks
                up from the ``type`` field of a task pool config.
            taskpool_cls: The task pool class, or a dotted import path string
                that `create_taskpool` resolves lazily.
        """
        registry = TaskPoolFactory._taskpool_types
        registry[name] = taskpool_cls

    @classmethod
    def available_taskpools(cls):
        """
        List the names under which task pools are registered.

        Returns:
            list: The registry keys, in the registry's iteration order.
        """
        return [name for name in cls._taskpool_types]

available_taskpools() classmethod

Get a list of available task pools.

Returns:

  • list

    A list of available task pool names.

Source code in fusion_bench/compat/taskpool/__init__.py
@classmethod
def available_taskpools(cls):
    """
    List the names under which task pools are registered.

    Returns:
        list: The registry keys, in the registry's iteration order.
    """
    return [name for name in cls._taskpool_types]

create_taskpool(taskpool_config) staticmethod

Create an instance of a task pool based on the provided configuration.

Parameters:

  • taskpool_config (DictConfig) –

    The configuration for the task pool. Must contain a 'type' attribute that specifies the type of the task pool.

Returns:

  • TaskPool

    An instance of the specified task pool.

Raises:

  • ValueError

    If the 'type' attribute is missing from the configuration or does not match any registered task pool type.

Source code in fusion_bench/compat/taskpool/__init__.py
@staticmethod
def create_taskpool(taskpool_config: DictConfig):
    """
    Build a task pool from its configuration.

    Args:
        taskpool_config (DictConfig): Task pool configuration. Its ``type``
            entry selects the registered task pool; the whole config is then
            passed to that task pool's constructor.

    Returns:
        TaskPool: The constructed task pool.

    Raises:
        ValueError: If ``type`` is missing or is not a registered task pool name.
    """
    from fusion_bench.utils import import_object

    pool_type = taskpool_config.get("type")
    if pool_type is None:
        raise ValueError("Task pool type not specified")

    registry = TaskPoolFactory._taskpool_types
    if pool_type not in registry:
        raise ValueError(
            f"Unknown task pool: {pool_type}, available task pools: {registry.keys()}. You can register a new task pool using `TaskPoolFactory.register_taskpool()` method."
        )

    target = registry[pool_type]
    if isinstance(target, str):
        # Registry entries may be import paths; a leading "." is relative to
        # the fusion_bench.compat.taskpool package.
        if target.startswith("."):
            target = "fusion_bench.compat.taskpool." + target[1:]
        target = import_object(target)
    return target(taskpool_config)

register_taskpool(name, taskpool_cls) staticmethod

Register a new task pool with the factory.

Parameters:

  • name (str) –

    The name of the task pool.

  • taskpool_cls

    The class of the task pool to register.

Source code in fusion_bench/compat/taskpool/__init__.py
@staticmethod
def register_taskpool(name: str, taskpool_cls):
    """
    Add (or overwrite) an entry in the task pool registry.

    Args:
        name (str): Registry key; this is the value `create_taskpool` looks
            up from the ``type`` field of a task pool config.
        taskpool_cls: The task pool class, or a dotted import path string
            that `create_taskpool` resolves lazily.
    """
    registry = TaskPoolFactory._taskpool_types
    registry[name] = taskpool_cls