Skip to content

Model Mixing

Layer-level Mixing

Depth Upscaling

DepthUpscalingAlgorithm

Bases: BaseAlgorithm

Implements the Depth Upscaling Algorithm.

  • Kim et al. SOLAR 10.7B: Scaling Large Language Models with Simple yet Effective Depth Up-Scaling. http://arxiv.org/abs/2312.15166

This class extends BaseAlgorithm to handle depth upscaling of models. It upscales the depth of a model by copying the specified layers, in order, into a new layer stack; repeating an index duplicates that layer.

Parameters:

  • layer_indices (list) –

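    Indices of the source layers to copy into the upscaled model. Each entry is an integer or a string expression such as "range(0, 24)" that evaluates to an integer or an iterable of integers.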

  • **kwargs –

    Additional keyword arguments.

Source code in fusion_bench/method/depth_upscaling/depth_upscaling.py
class DepthUpscalingAlgorithm(BaseAlgorithm):
    R"""
    Implements the Depth Upscaling Algorithm.

    - Kim et al. SOLAR 10.7B: Scaling Large Language Models with Simple yet Effective Depth Up-Scaling. http://arxiv.org/abs/2312.15166

    This class extends `BaseAlgorithm` to handle depth upscaling of models.
    It builds a new layer stack by copying the specified layers of the source
    model in the given order; repeating an index duplicates that layer.

    Args:
        layer_indices (Union[str, List[int]]): Indices of the source layers to copy.
            Each entry is either an ``int`` or a string expression (for example
            ``"range(0, 24)"``) that evaluates to an int or an iterable of ints.
            A single string is treated as one such expression.
        **kwargs: Additional keyword arguments forwarded to `BaseAlgorithm`.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "layer_indices": "layer_indices",
    }

    def __init__(self, layer_indices: Union[str, List[int]], **kwargs):
        self.layer_indices = layer_indices
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: nn.ModuleList | BaseModelPool) -> nn.ModuleList:
        """
        Executes the depth upscaling algorithm on a given model pool.

        This method checks the type of the model pool, ensures that it contains only one model,
        and verifies that the model is an instance of `nn.ModuleList`. It then returns a new
        `nn.ModuleList` holding deep copies of the layers selected by `layer_indices`.

        Args:
            modelpool (nn.ModuleList | BaseModelPool): The pool of models to upscale. Must contain only one model.

        Returns:
            nn.ModuleList: The upscaled model.

        Raises:
            AssertionError: If the model pool contains more than one model, if the model is not an
                instance of `nn.ModuleList`, or if `modelpool` has an unsupported type.
            ValueError: If an invalid layer specification is provided in the configuration.
        """
        # check the modelpool type; explicit raises (rather than `assert`) keep
        # these checks active under `python -O`.
        if isinstance(modelpool, BaseModelPool):
            if len(modelpool) != 1:
                raise AssertionError("DepthUpscaling only support one model")
            model = modelpool.load_model(modelpool.model_names[0])
            if not isinstance(model, nn.ModuleList):
                raise AssertionError(
                    f"The model should be a `nn.ModuleList`, but got {type(model)}"
                )
        elif isinstance(modelpool, nn.ModuleList):
            model = modelpool
        else:
            raise AssertionError(
                f"Invalid modelpool type: {type(modelpool)}. Expected `ModelPool` or `nn.ModuleList`."
            )

        # parse the layers
        layer_indices = self.layer_indices
        # A bare string is a single expression such as "range(0, 24)"; iterating
        # it directly would evaluate it one character at a time.
        if isinstance(layer_indices, str):
            layer_indices = [layer_indices]
        parsed_layer_indices = []
        for layer in layer_indices:
            if isinstance(layer, int):
                parsed_layer_indices.append(layer)
            elif isinstance(layer, str):
                # SECURITY: `eval` executes arbitrary code, so `layer_indices`
                # must only come from a trusted configuration.
                evaluated = eval(layer)
                # e.g. "3" evaluates to a single int rather than an iterable
                if isinstance(evaluated, int):
                    parsed_layer_indices.append(evaluated)
                else:
                    parsed_layer_indices.extend(evaluated)
            else:
                raise ValueError("Invalid layer specification: {}".format(layer))

        # create a new model with the specified layers (deep copies, so repeated
        # indices yield independent parameters)
        new_model = nn.ModuleList(
            [
                deepcopy(model[i])
                for i in tqdm(
                    parsed_layer_indices, desc="constructing depth-upscaled model"
                )
            ]
        )

        return new_model
run(modelpool)

Executes the depth upscaling algorithm on a given model pool.

This method checks the type of the model pool, ensures that it contains only one model, and verifies that the model is an instance of nn.ModuleList.

Parameters:

  • modelpool (ModuleList | ModelPool) –

    The pool of models to upscale. Must contain only one model.

Returns:

  • ModuleList –

    nn.ModuleList: The upscaled model.

Raises:

  • AssertionError –

    If the model pool contains more than one model or if the model is not an instance of nn.ModuleList.

  • ValueError –

    If an invalid layer specification is provided in the configuration.

Source code in fusion_bench/method/depth_upscaling/depth_upscaling.py
@torch.no_grad()
def run(self, modelpool: nn.ModuleList | BaseModelPool) -> nn.ModuleList:
    """
    Executes the depth upscaling algorithm on a given model pool.

    This method checks the type of the model pool, ensures that it contains only one model,
    and verifies that the model is an instance of `nn.ModuleList`. It then returns a new
    `nn.ModuleList` holding deep copies of the layers selected by `layer_indices`.

    Args:
        modelpool (nn.ModuleList | BaseModelPool): The pool of models to upscale. Must contain only one model.

    Returns:
        nn.ModuleList: The upscaled model.

    Raises:
        AssertionError: If the model pool contains more than one model, if the model is not an
            instance of `nn.ModuleList`, or if `modelpool` has an unsupported type.
        ValueError: If an invalid layer specification is provided in the configuration.
    """
    # check the modelpool type; explicit raises (rather than `assert`) keep
    # these checks active under `python -O`.
    if isinstance(modelpool, BaseModelPool):
        if len(modelpool) != 1:
            raise AssertionError("DepthUpscaling only support one model")
        model = modelpool.load_model(modelpool.model_names[0])
        if not isinstance(model, nn.ModuleList):
            raise AssertionError(
                f"The model should be a `nn.ModuleList`, but got {type(model)}"
            )
    elif isinstance(modelpool, nn.ModuleList):
        model = modelpool
    else:
        raise AssertionError(
            f"Invalid modelpool type: {type(modelpool)}. Expected `ModelPool` or `nn.ModuleList`."
        )

    # parse the layers
    layer_indices = self.layer_indices
    # A bare string is a single expression such as "range(0, 24)"; iterating
    # it directly would evaluate it one character at a time.
    if isinstance(layer_indices, str):
        layer_indices = [layer_indices]
    parsed_layer_indices = []
    for layer in layer_indices:
        if isinstance(layer, int):
            parsed_layer_indices.append(layer)
        elif isinstance(layer, str):
            # SECURITY: `eval` executes arbitrary code, so `layer_indices`
            # must only come from a trusted configuration.
            evaluated = eval(layer)
            # e.g. "3" evaluates to a single int rather than an iterable
            if isinstance(evaluated, int):
                parsed_layer_indices.append(evaluated)
            else:
                parsed_layer_indices.extend(evaluated)
        else:
            raise ValueError("Invalid layer specification: {}".format(layer))

    # create a new model with the specified layers (deep copies, so repeated
    # indices yield independent parameters)
    new_model = nn.ModuleList(
        [
            deepcopy(model[i])
            for i in tqdm(
                parsed_layer_indices, desc="constructing depth-upscaled model"
            )
        ]
    )

    return new_model

DepthUpscalingForLlama

Bases: DepthUpscalingAlgorithm

Implements depth upscaling for Llama models.

This class extends the DepthUpscalingAlgorithm to handle Llama models specifically. It supports saving the upscaled model to a specified path.

Parameters:

  • layer_indices (list) –

    List of layer indices to upscale.

  • model_save_path (Optional[str]) –

    Path to save the upscaled model.

  • **kwargs –

    Additional keyword arguments.

Source code in fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py
class DepthUpscalingForLlama(DepthUpscalingAlgorithm):
    """
    Implements depth upscaling for Llama models.

    This class extends the DepthUpscalingAlgorithm to handle Llama models specifically:
    it upscales `model.model.layers`, updates `config.num_hidden_layers`, and renumbers
    the copied decoder layers. It can optionally save the upscaled model and tokenizer.

    Args:
        layer_indices (list): List of layer indices to upscale.
        model_save_path (Optional[str]): Path to save the upscaled model; a leading ``~``
            is expanded. If ``None`` (the default), nothing is saved.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self, layer_indices: list, model_save_path: Optional[str] = None, **kwargs
    ):
        if isinstance(model_save_path, str):
            model_save_path = os.path.expanduser(model_save_path)
        self.model_save_path = model_save_path
        super().__init__(layer_indices, **kwargs)

    @override
    def run(self, modelpool: CausalLMPool):
        """
        Executes the depth upscaling algorithm on a given model pool.

        This method loads the pretrained model or the first model in the pool,
        applies the depth upscaling algorithm, updates the number of hidden layers in the
        model configuration, and renumbers the `layer_idx` of the new decoder layers.
        If a save path is provided, it saves the upscaled model and tokenizer to the specified path.

        Args:
            modelpool (CausalLMPool): The pool of models to upscale.

        Returns:
            CausalLM: The upscaled model.
        """
        if self.model_save_path is not None:
            tokenizer = modelpool.load_tokenizer()

        model: PreTrainedModel = modelpool.load_pretrained_or_first_model()
        model.model.layers = super().run(model.model.layers)
        model.config.num_hidden_layers = len(model.model.layers)

        # The upscaled layers are deep copies that still carry the `layer_idx` of
        # the layer they were copied from. In Hugging Face `transformers`, attention
        # modules use that index to select their slot in the KV cache, so duplicated
        # layers would share (and corrupt) one slot. Renumber to the new positions.
        for new_idx, layer in enumerate(model.model.layers):
            if hasattr(layer, "layer_idx"):
                layer.layer_idx = new_idx
            self_attn = getattr(layer, "self_attn", None)
            if self_attn is not None and hasattr(self_attn, "layer_idx"):
                self_attn.layer_idx = new_idx

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model
run(modelpool)

Executes the depth upscaling algorithm on a given model pool.

This method loads the pretrained model or the first model in the pool, applies the depth upscaling algorithm, and updates the number of hidden layers in the model configuration. If a save path is provided, it saves the upscaled model and tokenizer to the specified path.

Parameters:

  • modelpool (CausalLMPool) –

    The pool of models to upscale.

Returns:

  • CausalLM –

    The upscaled model.

Source code in fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py
@override
def run(self, modelpool: CausalLMPool):
    """
    Executes the depth upscaling algorithm on a given model pool.

    This method loads the pretrained model or the first model in the pool,
    applies the depth upscaling algorithm, updates the number of hidden layers in the
    model configuration, and renumbers the `layer_idx` of the new decoder layers.
    If a save path is provided, it saves the upscaled model and tokenizer to the specified path.

    Args:
        modelpool (CausalLMPool): The pool of models to upscale.

    Returns:
        CausalLM: The upscaled model.
    """
    if self.model_save_path is not None:
        tokenizer = modelpool.load_tokenizer()

    model: PreTrainedModel = modelpool.load_pretrained_or_first_model()
    model.model.layers = super().run(model.model.layers)
    model.config.num_hidden_layers = len(model.model.layers)

    # The upscaled layers are deep copies that still carry the `layer_idx` of
    # the layer they were copied from. In Hugging Face `transformers`, attention
    # modules use that index to select their slot in the KV cache, so duplicated
    # layers would share (and corrupt) one slot. Renumber to the new positions.
    for new_idx, layer in enumerate(model.model.layers):
        if hasattr(layer, "layer_idx"):
            layer.layer_idx = new_idx
        self_attn = getattr(layer, "self_attn", None)
        if self_attn is not None and hasattr(self_attn, "layer_idx"):
            self_attn.layer_idx = new_idx

    if self.model_save_path is not None:
        with timeit_context(f"Saving the model to {self.model_save_path}"):
            tokenizer.save_pretrained(self.model_save_path)
            model.save_pretrained(self.model_save_path)
    return model

Model Recombination

ModelRecombinationAlgorithm

Bases: BaseAlgorithm

Model recombination recombines the layers of the given models to create a new set of models.

Source code in fusion_bench/method/model_recombination.py
class ModelRecombinationAlgorithm(BaseAlgorithm):
    """
    Model recombination recombines the layers of the given models to create a new set of models.

    Args:
        return_modelpool (bool): Whether `run` returns the whole recombined model pool
            (``True``) or only its first model (``False``).
        **kwargs: Additional keyword arguments forwarded to `BaseAlgorithm`.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "return_modelpool": "return_modelpool",
    }

    def __init__(self, return_modelpool: bool, **kwargs):
        self.return_modelpool = return_modelpool
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(
        self,
        modelpool: BaseModelPool,
        return_modelpool: bool = True,
    ) -> Union[nn.Module, BaseModelPool]:
        """
        Executes the model recombination algorithm on a given model pool.

        This method loads models from the model pool, determines their type, and applies the appropriate recombination method.
        It then creates a new model pool with the recombined models. Depending on the `return_modelpool` flag, it either returns
        the entire new model pool or just the first model from it.

        - If the models in the model pool are of type `nn.ModuleList`, the recombination method `recombine_modellist` is used. Where each module in the list is shuffled across the models.
        - If the models are of type `nn.ModuleDict`, the recombination method `recombine_modeldict` is used. Where each module in the dictionary is shuffled across the models.
        - If the models are of type `nn.Module`, the recombination method `recombine_state_dict` is used. Where the state dictionaries of the models are shuffled across the models.

        The type of the first model decides which recombination method is used.

        Args:
            modelpool (BaseModelPool): The pool of models to recombine. Anything else is wrapped in a `BaseModelPool`.
            return_modelpool (bool, optional): Flag indicating whether to return the entire model pool or just the first model. Defaults to True. If the algorithm's config sets `return_modelpool`, that value is used and this argument is ignored.

        Returns:
            Union[nn.Module, BaseModelPool]: The recombined model pool or the first model from the recombined pool, depending on the `return_modelpool` flag.

        Raises:
            ValueError: If the model pool is empty, or if its models are of an unsupported type.
        """
        # If the config has a return_modelpool flag, use that, otherwise use the argument
        if self.config.get("return_modelpool", None) is not None:
            return_modelpool = self.config.return_modelpool
        # check the modelpool type
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info(f"Running model recombination algorithm with {len(modelpool)} models")

        # TODO: optimize the `recombine_*` functions, if `return_modelpool` is False, we don't need to create the new modelpool, just the first model
        models = [modelpool.load_model(m) for m in modelpool.model_names]
        # Fail clearly instead of with an IndexError on `models[0]` below.
        if not models:
            raise ValueError("Model recombination requires at least one model in the modelpool.")

        # Order matters: ModuleList/ModuleDict are themselves nn.Module subclasses.
        if isinstance(models[0], nn.ModuleList):
            new_models = recombine_modellist(models)
        elif isinstance(models[0], nn.ModuleDict):
            new_models = recombine_modeldict(models)
        elif isinstance(models[0], nn.Module):
            new_models = recombine_state_dict(models)
        else:
            raise ValueError(f"Unsupported model type {type(models[0])}")

        new_modelpool = BaseModelPool(
            {n: m for n, m in zip(modelpool.model_names, new_models)}
        )
        if return_modelpool:
            return new_modelpool
        else:
            return new_modelpool.load_model(new_modelpool.model_names[0])
run(modelpool, return_modelpool=True)

Executes the model recombination algorithm on a given model pool.

This method loads models from the model pool, determines their type, and applies the appropriate recombination method. It then creates a new model pool with the recombined models. Depending on the return_modelpool flag, it either returns the entire new model pool or just the first model from it.

  • If the models in the model pool are of type nn.ModuleList, the recombination method recombine_modellist is used. Where each module in the list is shuffled across the models.
  • If the models are of type nn.ModuleDict, the recombination method recombine_modeldict is used. Where each module in the dictionary is shuffled across the models.
  • If the models are of type nn.Module, the recombination method recombine_state_dict is used. Where the state dictionaries of the models are shuffled across the models.

Parameters:

  • modelpool (BaseModelPool) –

    The pool of models to recombine.

  • return_modelpool (bool, default: True ) –

    Flag indicating whether to return the entire model pool or just the first model. Defaults to True. If this algorithm is initialized with config, the value of return_modelpool in the config will be used and this argument passed to the method will be ignored.

Returns:

  • Union[Module, BaseModelPool] –

    Union[nn.Module, BaseModelPool]: The recombined model pool or the first model from the recombined pool, depending on the return_modelpool flag.

Raises:

  • ValueError –

    If the models in the model pool are of an unsupported type.

Source code in fusion_bench/method/model_recombination.py
@torch.no_grad()
def run(
    self,
    modelpool: BaseModelPool,
    return_modelpool: bool = True,
) -> Union[nn.Module, BaseModelPool]:
    """
    Executes the model recombination algorithm on a given model pool.

    This method loads models from the model pool, determines their type, and applies the appropriate recombination method.
    It then creates a new model pool with the recombined models. Depending on the `return_modelpool` flag, it either returns
    the entire new model pool or just the first model from it.

    - If the models in the model pool are of type `nn.ModuleList`, the recombination method `recombine_modellist` is used. Where each module in the list is shuffled across the models.
    - If the models are of type `nn.ModuleDict`, the recombination method `recombine_modeldict` is used. Where each module in the dictionary is shuffled across the models.
    - If the models are of type `nn.Module`, the recombination method `recombine_state_dict` is used. Where the state dictionaries of the models are shuffled across the models.

    The type of the first model decides which recombination method is used.

    Args:
        modelpool (BaseModelPool): The pool of models to recombine. Anything else is wrapped in a `BaseModelPool`.
        return_modelpool (bool, optional): Flag indicating whether to return the entire model pool or just the first model. Defaults to True. If the algorithm's config sets `return_modelpool`, that value is used and this argument is ignored.

    Returns:
        Union[nn.Module, BaseModelPool]: The recombined model pool or the first model from the recombined pool, depending on the `return_modelpool` flag.

    Raises:
        ValueError: If the model pool is empty, or if its models are of an unsupported type.
    """
    # If the config has a return_modelpool flag, use that, otherwise use the argument
    if self.config.get("return_modelpool", None) is not None:
        return_modelpool = self.config.return_modelpool
    # check the modelpool type
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    log.info(f"Running model recombination algorithm with {len(modelpool)} models")

    # TODO: optimize the `recombine_*` functions, if `return_modelpool` is False, we don't need to create the new modelpool, just the first model
    models = [modelpool.load_model(m) for m in modelpool.model_names]
    # Fail clearly instead of with an IndexError on `models[0]` below.
    if not models:
        raise ValueError("Model recombination requires at least one model in the modelpool.")

    # Order matters: ModuleList/ModuleDict are themselves nn.Module subclasses.
    if isinstance(models[0], nn.ModuleList):
        new_models = recombine_modellist(models)
    elif isinstance(models[0], nn.ModuleDict):
        new_models = recombine_modeldict(models)
    elif isinstance(models[0], nn.Module):
        new_models = recombine_state_dict(models)
    else:
        raise ValueError(f"Unsupported model type {type(models[0])}")

    new_modelpool = BaseModelPool(
        {n: m for n, m in zip(modelpool.model_names, new_models)}
    )
    if return_modelpool:
        return new_modelpool
    else:
        return new_modelpool.load_model(new_modelpool.model_names[0])

MoE-based Mixing

MoE Upscaling

MixtralUpscalingAlgorithm

Bases: BaseAlgorithm

This class is responsible for upscaling a model to a MixtralModel. It inherits from BaseAlgorithm.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
class MixtralUpscalingAlgorithm(BaseAlgorithm):
    """
    Upscale a dense `LlamaModel` / `MistralModel` into a `MixtralModel`.

    The dense model's config is converted to a Mixtral config with the requested
    number of experts. A `MixtralModel` is built from it with weight
    initialization skipped. The dense weights are then transferred by
    `upscale_to_mixtral_model`. This class derives from `BaseAlgorithm`.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "num_experts": "num_experts",
        "experts_per_token": "experts_per_token",
        "save_checkpoint": "save_checkpoint",
    }

    def __init__(
        self,
        num_experts: int,
        experts_per_token: int,
        save_checkpoint: str,
        **kwargs,
    ):
        """
        Store the upscaling hyper-parameters and initialize the base algorithm.

        Args:
            num_experts (int): The number of experts in the Mixtral model.
            experts_per_token (int): The number of experts per token.
            save_checkpoint (str): Where `run` saves the upscaled model; if the
                configured value is ``None``, nothing is saved.
            **kwargs: Forwarded to `BaseAlgorithm.__init__`.
        """
        self.num_experts, self.experts_per_token, self.save_checkpoint = (
            num_experts,
            experts_per_token,
            save_checkpoint,
        )
        super().__init__(**kwargs)

    @torch.no_grad()
    def _run(
        self, modelpool: BaseModelPool | LlamaModel | MistralModel
    ) -> MixtralModel:
        """
        Build the upscaled `MixtralModel` without saving it.

        Args:
            modelpool (BaseModelPool | LlamaModel | MistralModel): Either a model pool
                that provides a ``_pretrained_`` model, or the dense model itself.

        Returns:
            MixtralModel: The upscaled model.

        Raises:
            AssertionError: If a model pool without a pretrained model is given.
            ValueError: If `modelpool` is of any other type.
        """
        is_pool = isinstance(modelpool, BaseModelPool)
        if not is_pool and not isinstance(modelpool, (LlamaModel, MistralModel)):
            raise ValueError("Invalid modelpool type")

        if is_pool:
            assert modelpool.has_pretrained, "ModelPool must have pretrained model."
            source_model = modelpool.load_model("_pretrained_")
        else:
            source_model = modelpool

        settings = self.config
        target_config = _convert_config_to_mixtral(
            source_model.config,
            settings.num_experts,
            settings.experts_per_token,
        )

        # Construct without random weight init; weights are filled in from the
        # source model by `upscale_to_mixtral_model` afterwards. The one-step tqdm
        # loop only displays a progress message while the model is built.
        with ContextManagers([no_init_weights(True)]):
            for _step in tqdm(range(1), desc="Initializing Mixtral model"):
                upscaled_model = MixtralModel(target_config)
        upscale_to_mixtral_model(source_model, upscaled_model)

        return upscaled_model

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | LlamaModel | MistralModel) -> MixtralModel:
        """
        Upscale the model and, if `save_checkpoint` is configured, save it.

        Args:
            modelpool (BaseModelPool | LlamaModel | MistralModel): The model to be upscaled.

        Returns:
            MixtralModel: The upscaled model.
        """
        upscaled_model = self._run(modelpool)

        checkpoint_dir = self.config.get("save_checkpoint", None)
        if checkpoint_dir is not None:
            upscaled_model.save_pretrained(checkpoint_dir)
        return upscaled_model
__init__(num_experts, experts_per_token, save_checkpoint, **kwargs)

Initialize the MixtralUpscalingAlgorithm.

Parameters:

  • num_experts (int) –

    The number of experts in the Mixtral model.

  • experts_per_token (int) –

    The number of experts per token.

  • save_checkpoint (str) –

    The path to save the checkpoint.

  • **kwargs –

    Additional keyword arguments.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
def __init__(
    self,
    num_experts: int,
    experts_per_token: int,
    save_checkpoint: str,
    **kwargs,
):
    """
    Store the upscaling hyper-parameters and initialize the base algorithm.

    Args:
        num_experts (int): The number of experts in the Mixtral model.
        experts_per_token (int): The number of experts per token.
        save_checkpoint (str): Where `run` saves the upscaled model; if the
            configured value is ``None``, nothing is saved.
        **kwargs: Forwarded to the base class ``__init__``.
    """
    self.num_experts, self.experts_per_token, self.save_checkpoint = (
        num_experts,
        experts_per_token,
        save_checkpoint,
    )
    super().__init__(**kwargs)
run(modelpool)

Runs the upscaling process.

Parameters:

  • modelpool (ModelPool | LlamaModel | MistralModel) –

    The model to be upscaled.

Returns:

  • MixtralModel ( MixtralModel ) –

    The upscaled model.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | LlamaModel | MistralModel) -> MixtralModel:
    """
    Upscale the model and, if `save_checkpoint` is configured, save it.

    Args:
        modelpool (BaseModelPool | LlamaModel | MistralModel): The model to be upscaled.

    Returns:
        MixtralModel: The upscaled model.
    """
    upscaled_model = self._run(modelpool)

    checkpoint_dir = self.config.get("save_checkpoint", None)
    if checkpoint_dir is not None:
        upscaled_model.save_pretrained(checkpoint_dir)
    return upscaled_model

MixtralForCausalLMUpscalingAlgorithm

Bases: BaseAlgorithm

This class is responsible for upscaling a model to a MixtralForCausalLM. It inherits from BaseAlgorithm.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
    """
    Upscale a dense `LlamaForCausalLM` / `MistralForCausalLM` into a `MixtralForCausalLM`.

    The dense model's config is converted to a Mixtral config with the requested
    number of experts. A `MixtralForCausalLM` is built from it with weight
    initialization skipped. The dense weights are then transferred by
    `upscale_to_mixtral_for_causal_lm`. This class derives from `BaseAlgorithm`.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "num_experts": "num_experts",
        "experts_per_token": "experts_per_token",
        "save_checkpoint": "save_checkpoint",
    }

    def __init__(
        self,
        num_experts: int,
        experts_per_token: int,
        save_checkpoint: str,
        **kwargs,
    ):
        """
        Store the upscaling hyper-parameters and initialize the base algorithm.

        Args:
            num_experts (int): The number of experts in the Mixtral model.
            experts_per_token (int): The number of experts per token.
            save_checkpoint (str): Where `run` saves the upscaled model; if the
                configured value is ``None``, nothing is saved.
            **kwargs: Forwarded to `BaseAlgorithm.__init__`.
        """
        self.num_experts, self.experts_per_token, self.save_checkpoint = (
            num_experts,
            experts_per_token,
            save_checkpoint,
        )
        super().__init__(**kwargs)

    @torch.no_grad()
    def _run(
        self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
    ) -> MixtralForCausalLM:
        """
        Build the upscaled `MixtralForCausalLM` without saving it.

        Args:
            modelpool (BaseModelPool | LlamaForCausalLM | MistralForCausalLM): Either a
                model pool that provides a ``_pretrained_`` model, or the dense model itself.

        Returns:
            MixtralForCausalLM: The upscaled model.

        Raises:
            AssertionError: If a model pool without a pretrained model is given.
            ValueError: If `modelpool` is of any other type.
        """
        is_pool = isinstance(modelpool, BaseModelPool)
        if not is_pool and not isinstance(
            modelpool, (LlamaForCausalLM, MistralForCausalLM)
        ):
            raise ValueError("Invalid modelpool type")

        if is_pool:
            assert modelpool.has_pretrained, "ModelPool must have pretrained model."
            source_model = modelpool.load_model("_pretrained_")
        else:
            source_model = modelpool

        settings = self.config
        target_config = _convert_config_to_mixtral(
            source_model.config,
            settings.num_experts,
            settings.experts_per_token,
        )

        # Construct without random weight init; weights are filled in from the
        # source model by `upscale_to_mixtral_for_causal_lm` afterwards. The
        # one-step tqdm loop only displays a progress message while building.
        with ContextManagers([no_init_weights(True)]):
            for _step in tqdm(range(1), desc="Initializing Mixtral model"):
                upscaled_model = MixtralForCausalLM(target_config)
        upscale_to_mixtral_for_causal_lm(source_model, upscaled_model)

        return upscaled_model

    @torch.no_grad()
    def run(
        self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
    ) -> MixtralForCausalLM:
        """
        Upscale the model and, if `save_checkpoint` is configured, save it.

        Args:
            modelpool (BaseModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.

        Returns:
            MixtralForCausalLM: The upscaled model.
        """
        upscaled_model = self._run(modelpool)

        checkpoint_dir = self.config.get("save_checkpoint", None)
        if checkpoint_dir is not None:
            upscaled_model.save_pretrained(checkpoint_dir)
        return upscaled_model
__init__(num_experts, experts_per_token, save_checkpoint, **kwargs)

Initialize the MixtralForCausalLMUpscalingAlgorithm.

Parameters:

  • num_experts (int) –

    The number of experts in the Mixtral model.

  • experts_per_token (int) –

    The number of experts per token.

  • save_checkpoint (str) –

    The path to save the checkpoint.

  • **kwargs –

    Additional keyword arguments.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
def __init__(
    self,
    num_experts: int,
    experts_per_token: int,
    save_checkpoint: str,
    **kwargs,
):
    """
    Store the upscaling hyper-parameters and initialize the base algorithm.

    Args:
        num_experts (int): The number of experts in the Mixtral model.
        experts_per_token (int): The number of experts per token.
        save_checkpoint (str): Where `run` saves the upscaled model; if the
            configured value is ``None``, nothing is saved.
        **kwargs: Forwarded to the base class ``__init__``.
    """
    self.num_experts, self.experts_per_token, self.save_checkpoint = (
        num_experts,
        experts_per_token,
        save_checkpoint,
    )
    super().__init__(**kwargs)
run(modelpool)

Runs the upscaling process.

Parameters:

  • modelpool (ModelPool | LlamaForCausalLM | MistralForCausalLM) –

    The model to be upscaled.

Returns:

  • MixtralForCausalLM ( MixtralForCausalLM ) –

    The upscaled model.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
@torch.no_grad()
def run(
    self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
) -> MixtralForCausalLM:
    """
    Upscale the model and, if `save_checkpoint` is configured, save it.

    Args:
        modelpool (BaseModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.

    Returns:
        MixtralForCausalLM: The upscaled model.
    """
    upscaled_model = self._run(modelpool)

    checkpoint_dir = self.config.get("save_checkpoint", None)
    if checkpoint_dir is not None:
        upscaled_model.save_pretrained(checkpoint_dir)
    return upscaled_model

MixtralMoEMergingAlgorithm

Bases: MixtralUpscalingAlgorithm

This class is responsible for merging models into a MixtralModel.

Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py
class MixtralMoEMergingAlgorithm(MixtralUpscalingAlgorithm):
    """
    This class is responsible for merging models into a MixtralModel.

    Each model in the pool becomes one expert of the resulting MoE model.
    """

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool) -> MixtralModel:
        """
        Runs the merging process.

        The pretrained model is first upscaled to a `MixtralModel` with one expert per
        model in the pool. The experts are then replaced, in pool order, by the models.

        Args:
            modelpool (BaseModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralModel` or `LlamaModel`.

        Returns:
            MixtralModel: The merged model.

        Raises:
            ValueError: If the pool is empty, or holds fewer models than `experts_per_token`.
        """
        num_experts = len(modelpool)
        # Validate early: an MoE with no experts, or with fewer experts than are
        # routed per token, cannot be used.
        if num_experts == 0:
            raise ValueError(
                "MixtralMoEMergingAlgorithm requires at least one model in the modelpool."
            )
        experts_per_token = self.config.get("experts_per_token", None)
        if experts_per_token is not None and experts_per_token > num_experts:
            raise ValueError(
                f"experts_per_token ({experts_per_token}) cannot exceed the number "
                f"of models in the modelpool ({num_experts})."
            )

        # Also update the attribute, in case `self.config` is derived from the
        # instance attributes listed in `_config_mapping` rather than stored.
        self.num_experts = num_experts
        with open_dict(self.config):
            self.config.num_experts = num_experts

        # firstly, we upscale the models to MixtralModel
        mixtral_model = super()._run(modelpool)

        # then we substitute the experts of the MixtralModel with the models from the modelpool
        for model_idx, model_name in enumerate(modelpool.model_names):
            expert_model: MistralModel | LlamaModel = modelpool.load_model(model_name)
            _substitute_experts(model_idx, expert_model, mixtral_model)

        if self.config.get("save_checkpoint", None) is not None:
            mixtral_model.save_pretrained(self.config.save_checkpoint)
        return mixtral_model
run(modelpool)

Runs the merging process.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a MistralModel or LlamaModel.

Returns:

  • MixtralModel ( MixtralModel ) –

    The merged model.

Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool) -> MixtralModel:
    """
    Runs the merging process.

    The pretrained model is first upscaled to a `MixtralModel` with one expert per
    model in the pool. The experts are then replaced, in pool order, by the models.

    Args:
        modelpool (BaseModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralModel` or `LlamaModel`.

    Returns:
        MixtralModel: The merged model.

    Raises:
        ValueError: If the pool is empty, or holds fewer models than `experts_per_token`.
    """
    num_experts = len(modelpool)
    # Validate early: an MoE with no experts, or with fewer experts than are
    # routed per token, cannot be used.
    if num_experts == 0:
        raise ValueError(
            "MixtralMoEMergingAlgorithm requires at least one model in the modelpool."
        )
    experts_per_token = self.config.get("experts_per_token", None)
    if experts_per_token is not None and experts_per_token > num_experts:
        raise ValueError(
            f"experts_per_token ({experts_per_token}) cannot exceed the number "
            f"of models in the modelpool ({num_experts})."
        )

    # Also update the attribute, in case `self.config` is derived from the
    # instance attributes listed in `_config_mapping` rather than stored.
    self.num_experts = num_experts
    with open_dict(self.config):
        self.config.num_experts = num_experts

    # firstly, we upscale the models to MixtralModel
    mixtral_model = super()._run(modelpool)

    # then we substitute the experts of the MixtralModel with the models from the modelpool
    for model_idx, model_name in enumerate(modelpool.model_names):
        expert_model: MistralModel | LlamaModel = modelpool.load_model(model_name)
        _substitute_experts(model_idx, expert_model, mixtral_model)

    if self.config.get("save_checkpoint", None) is not None:
        mixtral_model.save_pretrained(self.config.save_checkpoint)
    return mixtral_model

MixtralForCausalLMMergingAlgorithm

Bases: MixtralForCausalLMUpscalingAlgorithm

This class is responsible for merging models into a MixtralForCausalLM.

Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py
class MixtralForCausalLMMergingAlgorithm(MixtralForCausalLMUpscalingAlgorithm):
    """
    Merge several causal language models into one `MixtralForCausalLM`, using
    each model of the pool as one expert.
    """

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool) -> MixtralForCausalLM:
        """
        Build a `MixtralForCausalLM` whose experts come from ``modelpool``.

        The parent upscaling routine first creates the Mixtral model with one
        expert per pool member; then the inner ``.model`` module of each member
        is substituted into the matching expert slot.

        Args:
            modelpool (ModelPool): Models to merge, one expert per model; each
                should be a `MistralForCausalLM` or `LlamaForCausalLM`.

        Returns:
            MixtralForCausalLM: The merged model (also saved to
            ``config.save_checkpoint`` when that key is set).
        """
        # `open_dict` lets us write a key that may not exist in the config yet
        with open_dict(self.config):
            self.config.num_experts = len(modelpool)

        # build the upscaled MixtralForCausalLM skeleton
        merged = super()._run(modelpool)

        # expert slot i receives the `.model` sub-module of pool member i
        for slot, name in enumerate(modelpool.model_names):
            source: MistralForCausalLM | LlamaForCausalLM = modelpool.load_model(
                name
            )
            _substitute_experts(slot, source.model, merged.model)

        if self.config.get("save_checkpoint", None) is None:
            return merged
        merged.save_pretrained(self.config.save_checkpoint)
        return merged
run(modelpool)

Runs the merging process. It first upscales the models to MixtralForCausalLM, then substitutes the experts of the MixtralForCausalLM with the models from the modelpool.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a MistralForCausalLM or LlamaForCausalLM.

Returns:

  • MixtralForCausalLM ( MixtralForCausalLM ) –

    The merged model.

Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool) -> MixtralForCausalLM:
    """
    Merge the causal LMs of ``modelpool`` into a single `MixtralForCausalLM`.

    The number of experts is set to the pool size and the parent upscaling
    routine builds the Mixtral model. Then the inner ``.model`` module of every
    pool member is substituted into the corresponding expert slot.

    Args:
        modelpool (ModelPool): Pool of models to merge, one expert per model.
            Members should be `MistralForCausalLM` or `LlamaForCausalLM`.

    Returns:
        MixtralForCausalLM: The merged model (also saved to
        ``config.save_checkpoint`` when that key is set).
    """
    # `open_dict` lets us write a key that may not exist in the config yet
    with open_dict(self.config):
        self.config.num_experts = len(modelpool)

    # build the upscaled MixtralForCausalLM skeleton
    merged = super()._run(modelpool)

    # replace expert weights with the `.model` modules of the pool members
    for slot, name in enumerate(modelpool.model_names):
        source: MistralForCausalLM | LlamaForCausalLM = modelpool.load_model(name)
        _substitute_experts(slot, source.model, merged.model)

    if self.config.get("save_checkpoint", None) is None:
        return merged
    merged.save_pretrained(self.config.save_checkpoint)
    return merged

Weight-Ensembling Mixture of Experts (WE-MoE)

CLIPWeightEnsemblingMoEAlgorithm

Bases: WeightEnsemblingMoEAlgorithm, CLIPClassificationMixin

CLIPWeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm for CLIP models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.

Attributes:

  • modelpool (CLIPVisionModelPool) –

    The model pool containing the CLIP models.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
class CLIPWeightEnsemblingMoEAlgorithm(
    WeightEnsemblingMoEAlgorithm,
    CLIPClassificationMixin,
):
    """
    Weight-Ensembling Mixture-of-Experts (WE-MoE) fusion for CLIP vision models.

    Combines the generic `WeightEnsemblingMoEAlgorithm` workflow with the
    zero-shot classification helpers of `CLIPClassificationMixin`.

    Attributes:
        modelpool (CLIPVisionModelPool): Pool holding the pretrained model
            (``"_pretrained_"``) and the fine-tuned CLIP models.
    """

    modelpool: CLIPVisionModelPool = None

    def load_checkpoint(self, model, checkpoint):
        """
        Restore ``model`` from ``checkpoint`` via ``self._fabric.load``.

        Args:
            model: Module whose state is filled in from the checkpoint.
            checkpoint: Path of the checkpoint file to read.
        """
        self._fabric.load(checkpoint, {"model": model})

    def save_checkpoint(self, model, checkpoint):
        """
        Write ``model`` to ``checkpoint`` via ``self._fabric.save``.

        Args:
            model: Module to serialise.
            checkpoint: Destination path of the checkpoint file.
        """
        payload = {"model": model}
        self._fabric.save(checkpoint, payload)

    def construct_moe_model(self) -> WeightEnsemblingMoE:
        """
        Build the WE-MoE model from the pretrained and fine-tuned models.

        The non-MLP weights come from a task-arithmetic merge of the experts
        (frozen afterwards); every encoder layer's MLP is then replaced by a
        `WeightEnsemblingMoE` that mixes the pretrained MLP with the
        corresponding expert MLPs.

        Returns:
            WeightEnsemblingMoE: The constructed MoE model.
        """
        pretrained = self.modelpool.load_model("_pretrained_")
        finetuned = list(map(self.modelpool.load_model, self.modelpool.model_names))

        # task_arithmetic_merge edits its first argument in place -> deepcopy
        merged = task_arithmetic_merge(
            deepcopy(pretrained),
            finetuned,
            scaling_factor=self.config.init_lambda,
        )
        merged.requires_grad_(False)

        base_encoder: CLIPEncoder = pretrained.vision_model.encoder
        merged_layers = merged.vision_model.encoder.layers
        expert_layer_lists = [m.vision_model.encoder.layers for m in finetuned]

        # upscale each layer's MLP into a weight-ensembling MoE module
        for idx, base_layer in enumerate(base_encoder.layers):
            merged_layers[idx].mlp = WeightEnsemblingMoE(
                hidden_size=base_encoder.config.hidden_size,
                base_model=base_layer.mlp,
                expert_models=[layers[idx].mlp for layers in expert_layer_lists],
                init_lambda=self.config.init_lambda,
                batch_first=True,  # For open_clip models this is False
                router_hidden_layers=self.config.router_hidden_layers,
                batch_reduce=self.config.batch_reduce,
            )

        return merged

    @functools.cache
    def get_shuffled_test_loader_iter(self, tta_dataset: str):
        """
        Return an endless, shuffled iterator over the test split of a task.

        Cached per ``(self, tta_dataset)``, so repeated calls keep advancing
        the same iterator instead of restarting it.

        Args:
            tta_dataset (str): Name of the test-time adaptation dataset.

        Returns:
            Iterator: Iterator over an `InfiniteDataLoader` wrapping the
            Fabric-prepared DataLoader.
        """
        raw = self.modelpool.load_test_dataset(tta_dataset)
        wrapped = CLIPDataset(raw, processor=self.clip_processor)
        log.info("get_shuffled_test_loader_iter")
        loader = self.fabric.setup_dataloaders(
            DataLoader(
                wrapped,
                batch_size=self.config.batch_size,
                shuffle=True,
                num_workers=self.config.num_workers,
                pin_memory=True,
            )
        )
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Hook run before test-time adaptation; builds the zero-shot
        classification heads via `setup_zero_shot_classification_head`.
        """
        setup_heads = self.setup_zero_shot_classification_head
        setup_heads()

    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Score a batch of images against the zero-shot text embeddings of a task.

        Args:
            module: Vision model; element ``[1]`` of its output is projected
                with ``self.visual_projection``.
            batch: ``(images, labels)`` pair; labels are ignored.
            task: Key into ``self.zeroshot_weights``.

        Returns:
            Tensor: Image-to-text logits, scaled by ``self.logit_scale_exp``.
        """
        images, _labels = batch
        class_embeds = self.zeroshot_weights[task]

        pooled = module(images)[1]
        projected = self.visual_projection(pooled)
        # L2-normalise the image features so the product is a cosine similarity
        projected = projected / projected.norm(p=2, dim=-1, keepdim=True)

        text_to_image = (
            torch.matmul(class_embeds, projected.t()) * self.logit_scale_exp
        )
        return text_to_image.t()
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

Parameters:

  • module –

    The model module.

  • batch –

    The input batch.

  • task –

    The task name.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def compute_logits(self, module, batch, task) -> Tensor:
    """
    Score a batch of images against the zero-shot text embeddings of a task.

    Args:
        module: Vision model; element ``[1]`` of its output is projected with
            ``self.visual_projection``.
        batch: ``(images, labels)`` pair; labels are ignored.
        task: Key into ``self.zeroshot_weights``.

    Returns:
        Tensor: Image-to-text logits, scaled by ``self.logit_scale_exp``.
    """
    images, _labels = batch
    class_embeds = self.zeroshot_weights[task]

    pooled = module(images)[1]
    projected = self.visual_projection(pooled)
    # L2-normalise the image features so the product is a cosine similarity
    projected = projected / projected.norm(p=2, dim=-1, keepdim=True)

    text_to_image = torch.matmul(class_embeds, projected.t()) * self.logit_scale_exp
    return text_to_image.t()
construct_moe_model()

Construct the Mixture of Experts (MoE) model using the models in the model pool.

Returns:

  • WeightEnsemblingMoE ( WeightEnsemblingMoE ) –

    The constructed MoE model.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def construct_moe_model(self) -> WeightEnsemblingMoE:
    """
    Build the WE-MoE model from the pretrained and fine-tuned models.

    The non-MLP weights come from a task-arithmetic merge of the experts
    (frozen afterwards). Every encoder layer's MLP is then replaced by a
    `WeightEnsemblingMoE` that mixes the pretrained MLP with the
    corresponding expert MLPs.

    Returns:
        WeightEnsemblingMoE: The constructed MoE model.
    """
    pretrained = self.modelpool.load_model("_pretrained_")
    finetuned = list(map(self.modelpool.load_model, self.modelpool.model_names))

    # task_arithmetic_merge edits its first argument in place -> deepcopy
    merged = task_arithmetic_merge(
        deepcopy(pretrained),
        finetuned,
        scaling_factor=self.config.init_lambda,
    )
    merged.requires_grad_(False)

    base_encoder: CLIPEncoder = pretrained.vision_model.encoder
    merged_layers = merged.vision_model.encoder.layers
    expert_layer_lists = [m.vision_model.encoder.layers for m in finetuned]

    # upscale each layer's MLP into a weight-ensembling MoE module
    for idx, base_layer in enumerate(base_encoder.layers):
        merged_layers[idx].mlp = WeightEnsemblingMoE(
            hidden_size=base_encoder.config.hidden_size,
            base_model=base_layer.mlp,
            expert_models=[layers[idx].mlp for layers in expert_layer_lists],
            init_lambda=self.config.init_lambda,
            batch_first=True,  # For open_clip models this is False
            router_hidden_layers=self.config.router_hidden_layers,
            batch_reduce=self.config.batch_reduce,
        )

    return merged
get_shuffled_test_loader_iter(tta_dataset) cached

Get an iterator for the shuffled test data loader.

Parameters:

  • tta_dataset (str) –

    The name of the test-time adaptation dataset.

Returns:

  • Iterator –

    An iterator for the shuffled test data loader.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
@functools.cache
def get_shuffled_test_loader_iter(self, tta_dataset: str):
    """
    Return an endless, shuffled iterator over the test split of a task.

    Cached per ``(self, tta_dataset)``, so repeated calls keep advancing the
    same iterator instead of restarting it.

    Args:
        tta_dataset (str): Name of the test-time adaptation dataset.

    Returns:
        Iterator: Iterator over an `InfiniteDataLoader` wrapping the
        Fabric-prepared DataLoader.
    """
    raw = self.modelpool.load_test_dataset(tta_dataset)
    wrapped = CLIPDataset(raw, processor=self.clip_processor)
    log.info("get_shuffled_test_loader_iter")
    loader = self.fabric.setup_dataloaders(
        DataLoader(
            wrapped,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
    )
    return iter(InfiniteDataLoader(loader))
load_checkpoint(model, checkpoint)

Load the checkpoint file.

Parameters:

  • model –

    The model to load the checkpoint into.

  • checkpoint –

    The path to the checkpoint file.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def load_checkpoint(self, model, checkpoint):
    """
    Restore ``model`` from ``checkpoint`` via ``self._fabric.load``.

    Args:
        model: Module whose state is filled in from the checkpoint.
        checkpoint: Path of the checkpoint file to read.
    """
    self._fabric.load(checkpoint, {"model": model})
on_test_time_adaptation_start()

Load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def on_test_time_adaptation_start(self):
    """
    Hook run before test-time adaptation; builds the zero-shot classification
    heads via `setup_zero_shot_classification_head`.
    """
    setup_heads = self.setup_zero_shot_classification_head
    setup_heads()
save_checkpoint(model, checkpoint)

Save the checkpoint file.

Parameters:

  • model –

    The model to save the checkpoint from.

  • checkpoint –

    The path to the checkpoint file.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def save_checkpoint(self, model, checkpoint):
    """
    Write ``model`` to ``checkpoint`` via ``self._fabric.save``.

    Args:
        model: Module to serialise.
        checkpoint: Destination path of the checkpoint file.
    """
    payload = {"model": model}
    self._fabric.save(checkpoint, payload)

Sparse WE-MoE

SparseWeightEnsemblingMoEAlgorithm

Bases: ModelFusionAlgorithm

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
class SparseWeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
    """
    Base class for the sparse Weight-Ensembling Mixture-of-Experts fusion algorithm.

    Subclasses supply model construction, checkpoint I/O, test-data loading
    and logit computation. `run` drives the workflow:

    1. Build the MoE model.
    2. Either load a checkpoint, or adapt the trainable parameters at test time
       by minimising the entropy of the predictions on every task.
    3. Optionally prune the gates.
    4. Switch to sample-wise routing.
    """

    _fabric: L.Fabric = None
    modelpool: BaseModelPool = None

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the SparseWeightEnsemblingMoEAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

        self.profiler = SimpleProfiler(
            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
        )

    @abstractmethod
    def load_checkpoint(self, model, checkpoint):
        """
        Load the checkpoint file.

        Args:
            model (nn.Module): The model to load the checkpoint into.
            checkpoint (str): The path to the checkpoint file.
        """
        pass

    @abstractmethod
    def save_checkpoint(self, model, checkpoint):
        """
        Save the checkpoint file.

        Args:
            model (nn.Module): The model to save the checkpoint from.
            checkpoint (str): The path to the checkpoint file.
        """
        pass

    @abstractmethod
    def construct_moe_model(self) -> SparseWeightEnsemblingMoE:
        """
        Construct the Mixture of Experts model using the models in the model pool.

        Returns:
            SparseWeightEnsemblingMoE: The constructed Mixture of Experts model.
        """
        pass

    @abstractmethod
    def construct_moe_model_sharedgate(self) -> SparseWeightEnsemblingMoE_ShardGate:
        """
        Construct the Mixture of Experts model (shared-gate variant) using the
        models in the model pool.

        Returns:
            SparseWeightEnsemblingMoE_ShardGate: The constructed Mixture of Experts model with shared gate.
        """
        pass

    def on_test_time_adaptation_start(self):
        """
        Hook that is called at the start of test-time adaptation.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Get an iterator for the shuffled test DataLoader for a specific task.

        Args:
            task (str): The task for which to get the DataLoader iterator.

        Returns:
            DataLoader: An iterator over the shuffled test data of ``task``;
            `test_time_adaptation` calls ``next()`` on it.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Compute the logits for a given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (Any): The input batch.
            task (str): The task for which to compute the logits.

        Returns:
            Tensor: The computed logits (must be 2-D).
        """
        pass

    def dynamic_prune(self, module, prune_ratio):
        """
        Dynamically prune the parameters of a module based on the given prune ratio.

        Only parameters with ``requires_grad=True`` are pruned.

        Args:
            module (nn.Module): The module to prune.
            prune_ratio (float): The ratio of parameters to prune.
        """
        for param in module.parameters():
            if param.requires_grad:
                param.data = _magnitude_prune(param, prune_ratio)

    def l1_regularization(self, module, l1_lambda):
        """
        Compute the L1 regularization loss for a module.

        Only parameters with ``requires_grad=True`` contribute.

        Args:
            module (nn.Module): The module for which to compute the L1 regularization loss.
            l1_lambda (float): The L1 regularization coefficient.

        Returns:
            Tensor: The L1 regularization loss (``0`` scaled by ``l1_lambda``
            when the module has no trainable parameters).
        """
        l1_norm = sum(
            param.abs().sum() for param in module.parameters() if param.requires_grad
        )
        return l1_lambda * l1_norm

    def _backward(self, loss: Tensor) -> None:
        """Back-propagate ``loss`` through Fabric when one is attached, else via plain autograd."""
        # `retain_graph=True` is kept from the original implementation.
        if self._fabric is not None:
            self._fabric.backward(loss, retain_graph=True)
        else:
            loss.backward(retain_graph=True)

    def test_time_adaptation(self, module: SparseWeightEnsemblingMoE):
        """
        Perform test-time adaptation for the given module.

        Each step draws one batch per task in the model pool and minimises the
        entropy of the predicted logits with respect to the trainable parameters.

        Args:
            module (SparseWeightEnsemblingMoE): The module to adapt.

        Returns:
            SparseWeightEnsemblingMoE: The adapted module (Fabric-wrapped when
            ``self._fabric`` is set).

        Raises:
            ValueError: If ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            num_steps = 1
        else:
            num_steps = self.config.max_steps
        pbar = tqdm(
            range(num_steps),
            "Test-time adaptation",
            dynamic_ncols=True,
        )

        for _ in pbar:
            if self.config.use_grad_accumulate:
                for task in self.modelpool.model_names:
                    with self.profiler.profile("data time"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profiler.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        assert (
                            logits.dim() == 2
                        ), f"Expected logits to be 2D, got {logits.dim()}"
                        loss = entropy_loss(logits)
                    # .backward() accumulates when .zero_grad() wasn't called
                    # this can save memory
                    with self.profiler.profile("backward pass"):
                        self._backward(loss)
            else:
                loss = 0
                for task in self.modelpool.model_names:
                    with self.profiler.profile("data time"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profiler.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        assert (
                            logits.dim() == 2
                        ), f"Expected logits to be 2D, got {logits.dim()}"
                        loss = loss + entropy_loss(logits)

                with self.profiler.profile("backward pass"):
                    self._backward(loss)

            with self.profiler.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

        return module

    def construct_post_spare_gate_model(self, moe_model, gate_prune_ratio):
        """
        Construct a (post) sparse gated model.

        The ``mlp.gate`` of every vision-encoder layer is replaced, in place, by
        the result of ``_module_magnitude_prune``.

        Args:
            moe_model (SparseWeightEnsemblingMoE): The Mixture of Experts model.
            gate_prune_ratio (float): The ratio of parameters to prune in the gate.

        Returns:
            SparseWeightEnsemblingMoE: The constructed (post) sparse gated model
            (the same object as ``moe_model``).
        """
        moe_encoder = moe_model.vision_model.encoder
        num_layers = len(moe_encoder.layers)
        for layer_idx in range(num_layers):
            gate = moe_encoder.layers[layer_idx].mlp.gate
            sparse_gate = _module_magnitude_prune(gate, gate_prune_ratio, layer_idx)
            moe_encoder.layers[layer_idx].mlp.gate = sparse_gate
        return moe_model

    def run(self, modelpool: BaseModelPool):
        """
        Run the SparseWeightEnsemblingMoEAlgorithm with the given model pool.

        Args:
            modelpool (BaseModelPool): The model pool to use for the algorithm.

        Returns:
            SparseWeightEnsemblingMoE: The final Mixture of Experts model.
        """
        log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
        self.modelpool = modelpool

        with timeit_context("upscaling models to a weight-ensembling MoE model"):
            if self.config.shared_gate:
                moe_model = self.construct_moe_model_sharedgate()
            else:
                moe_model = self.construct_moe_model()
            print_parameters(moe_model)

        if self.config.get("checkpoint", False):
            log.info(
                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
            )
            self.load_checkpoint(moe_model, self.config.checkpoint)
        else:
            with self.profiler.profile("test-time adaptation"):
                moe_model = self.test_time_adaptation(moe_model)
            if self.config.get("save_checkpoint", False):
                log.info(f"save checkpoint to {self.config.save_checkpoint}")
                self.save_checkpoint(moe_model, self.config.save_checkpoint)

            if lightning.fabric.wrappers.is_wrapped(moe_model):
                moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

        #  (post) sparse gate model
        if self.config.post_sparse_gate:
            moe_model = self.construct_post_spare_gate_model(
                moe_model, self.config.gate_prune_ratio
            )

        # Enable sample-wise adaptation. Setting the attribute on the top-level
        # module alone does not reach the MoE sub-modules, which are built with
        # their own `batch_reduce` argument; propagate it to every sub-module
        # that carries the flag.
        moe_model.batch_reduce = False
        for submodule in moe_model.modules():
            if hasattr(submodule, "batch_reduce"):
                submodule.batch_reduce = False
        print(self.profiler.summary())
        return moe_model
__init__(algorithm_config)

Initialize the SparseWeightEnsemblingMoEAlgorithm with the given configuration.

Parameters:

  • algorithm_config (DictConfig) –

    The configuration for the algorithm.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def __init__(self, algorithm_config: DictConfig):
    """
    Set up the algorithm from ``algorithm_config`` and attach a profiler.

    Args:
        algorithm_config (DictConfig): Algorithm configuration; the optional
            ``cache_dir`` key (default ``"outputs"``) is passed to the profiler.
    """
    super().__init__(algorithm_config)

    profiler_dir = self.config.get("cache_dir", "outputs")
    self.profiler = SimpleProfiler(profiler_dir, "we_moe_profiler.txt")
compute_logits(module, batch, task) abstractmethod

Compute the logits for a given batch and task.

Parameters:

  • module (Module) –

    The model module.

  • batch (Any) –

    The input batch.

  • task (str) –

    The task for which to compute the logits.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
@abstractmethod
def compute_logits(self, module, batch, task) -> Tensor:
    """
    Return the logits that ``module`` produces for ``batch`` on ``task``.

    Subclasses must override this; the base implementation returns ``None``.

    Args:
        module (nn.Module): Model being adapted.
        batch (Any): One batch drawn from the task's test iterator.
        task (str): Name of the task the batch belongs to.

    Returns:
        Tensor: Logits for the batch.
    """
construct_moe_model() abstractmethod

Construct the Mixture of Experts model using the models in the model pool.

Returns:

  • SparseWeightEnsemblingMoE ( SparseWeightEnsemblingMoE ) –

    The constructed Mixture of Experts model.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
@abstractmethod
def construct_moe_model(self) -> SparseWeightEnsemblingMoE:
    """
    Build the sparse Mixture-of-Experts model from the pool's models.

    Subclasses must override this; the base implementation returns ``None``.

    Returns:
        SparseWeightEnsemblingMoE: The constructed Mixture of Experts model.
    """
construct_moe_model_sharedgate() abstractmethod

Construct the Mixture of Experts model using the models in the model pool.

Returns:

  • SparseWeightEnsemblingMoE_ShardGate ( SparseWeightEnsemblingMoE_ShardGate ) –

    The constructed Mixture of Experts model with shared gate.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
@abstractmethod
def construct_moe_model_sharedgate(self) -> SparseWeightEnsemblingMoE_ShardGate:
    """
    Build the shared-gate variant of the sparse Mixture-of-Experts model.

    Subclasses must override this; the base implementation returns ``None``.

    Returns:
        SparseWeightEnsemblingMoE_ShardGate: The constructed Mixture of Experts model with shared gate.
    """
construct_post_spare_gate_model(moe_model, gate_prune_ratio)

Construct a (post) sparse gated model.

Parameters:

  • moe_model (SparseWeightEnsemblingMoE) –

    The Mixture of Experts model.

  • gate_prune_ratio (float) –

    The ratio of parameters to prune in the gate.

Returns:

  • SparseWeightEnsemblingMoE –

    The constructed (post) sparse gated model.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def construct_post_spare_gate_model(self, moe_model, gate_prune_ratio):
    """
    Magnitude-prune the gate of every vision-encoder layer after training.

    Each layer's ``mlp.gate`` is replaced, in place, with the output of
    ``_module_magnitude_prune(gate, gate_prune_ratio, layer_index)``.

    Args:
        moe_model (SparseWeightEnsemblingMoE): Model whose gates are pruned.
        gate_prune_ratio (float): Fraction of gate parameters to prune.

    Returns:
        SparseWeightEnsemblingMoE: ``moe_model`` itself, with pruned gates.
    """
    for idx, layer in enumerate(moe_model.vision_model.encoder.layers):
        layer.mlp.gate = _module_magnitude_prune(layer.mlp.gate, gate_prune_ratio, idx)
    return moe_model
dynamic_prune(module, prune_ratio)

Dynamically prune the parameters of a module based on the given prune ratio.

Parameters:

  • module (Module) –

    The module to prune.

  • prune_ratio (float) –

    The ratio of parameters to prune.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def dynamic_prune(self, module, prune_ratio):
    """
    Magnitude-prune, in place, every trainable parameter of ``module``.

    Frozen parameters (``requires_grad=False``) are left untouched.

    Args:
        module (nn.Module): Module whose parameters are pruned.
        prune_ratio (float): Fraction of each parameter's entries to prune.
    """
    trainable = (w for w in module.parameters() if w.requires_grad)
    for weight in trainable:
        weight.data = _magnitude_prune(weight, prune_ratio)
get_shuffled_test_loader_iter(task) abstractmethod

Get an iterator for the shuffled test DataLoader for a specific task.

Parameters:

  • task (str) –

    The task for which to get the DataLoader iterator.

Returns:

  • DataLoader ( DataLoader ) –

    The DataLoader iterator for the specified task.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
@abstractmethod
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Return an iterator over shuffled test batches of ``task``.

    Subclasses must override this; the base implementation returns ``None``.
    Callers use ``next()`` on the returned object.

    Args:
        task (str): Task whose test data should be iterated.

    Returns:
        DataLoader: Iterator over the shuffled test data of ``task``.
    """
l1_regularization(module, l1_lambda)

Compute the L1 regularization loss for a module.

Parameters:

  • module (Module) –

    The module for which to compute the L1 regularization loss.

  • l1_lambda (float) –

    The L1 regularization coefficient.

Returns:

  • Tensor –

    The L1 regularization loss.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def l1_regularization(self, module, l1_lambda):
    """
    Return ``l1_lambda`` times the L1 norm of the trainable parameters of ``module``.

    Args:
        module (nn.Module): Module whose trainable parameters are penalised.
        l1_lambda (float): Regularisation coefficient.

    Returns:
        Tensor: The scaled L1 penalty (``l1_lambda * 0`` when ``module`` has no
        trainable parameters).
    """
    trainable = filter(lambda w: w.requires_grad, module.parameters())
    total = sum(w.abs().sum() for w in trainable)
    return l1_lambda * total
load_checkpoint(model, checkpoint) abstractmethod

Load the checkpoint file.

Parameters:

  • model (Module) –

    The model to load the checkpoint into.

  • checkpoint (str) –

    The path to the checkpoint file.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
@abstractmethod
def load_checkpoint(self, model, checkpoint):
    """
    Restore ``model`` from the file at ``checkpoint``.

    Subclasses must override this; the base implementation does nothing.

    Args:
        model (nn.Module): Module to load the state into.
        checkpoint (str): Path of the checkpoint file.
    """
on_test_time_adaptation_start()

Hook that is called at the start of test-time adaptation.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def on_test_time_adaptation_start(self):
    """
    Hook invoked right before test-time adaptation begins.

    The default implementation is a no-op; subclasses may override it.
    """
    return None
run(modelpool)

Run the SparseWeightEnsemblingMoEAlgorithm with the given model pool.

Parameters:

  • modelpool (BaseModelPool) –

    The model pool to use for the algorithm.

Returns:

  • SparseWeightEnsemblingMoE –

    The final Mixture of Experts model.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def run(self, modelpool: BaseModelPool):
    """
    Run the SparseWeightEnsemblingMoEAlgorithm with the given model pool.

    Builds the MoE model, then either loads ``config.checkpoint`` or runs
    test-time adaptation (optionally saving ``config.save_checkpoint``),
    optionally prunes the gates, and finally switches the model to sample-wise
    routing.

    Args:
        modelpool (BaseModelPool): The model pool to use for the algorithm.

    Returns:
        SparseWeightEnsemblingMoE: The final Mixture of Experts model.
    """
    log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
    self.modelpool = modelpool

    with timeit_context("upscaling models to a weight-ensembling MoE model"):
        if self.config.shared_gate:
            moe_model = self.construct_moe_model_sharedgate()
        else:
            moe_model = self.construct_moe_model()
        print_parameters(moe_model)

    if self.config.get("checkpoint", False):
        log.info(
            f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
        )
        self.load_checkpoint(moe_model, self.config.checkpoint)
    else:
        with self.profiler.profile("test-time adaptation"):
            moe_model = self.test_time_adaptation(moe_model)
        if self.config.get("save_checkpoint", False):
            log.info(f"save checkpoint to {self.config.save_checkpoint}")
            self.save_checkpoint(moe_model, self.config.save_checkpoint)

        # test_time_adaptation may return a Fabric-wrapped module
        if lightning.fabric.wrappers.is_wrapped(moe_model):
            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

    #  (post) sparse gate model
    if self.config.post_sparse_gate:
        moe_model = self.construct_post_spare_gate_model(
            moe_model, self.config.gate_prune_ratio
        )

    # Enable sample-wise adaptation. Setting the attribute on the top-level
    # module alone does not reach the MoE sub-modules, which are built with
    # their own `batch_reduce` argument; propagate it to every sub-module that
    # carries the flag.
    moe_model.batch_reduce = False
    for submodule in moe_model.modules():
        if hasattr(submodule, "batch_reduce"):
            submodule.batch_reduce = False
    print(self.profiler.summary())
    return moe_model
save_checkpoint(model, checkpoint) abstractmethod

Save the checkpoint file.

Parameters:

  • model (Module) –

    The model to save the checkpoint from.

  • checkpoint (str) –

    The path to the checkpoint file.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
@abstractmethod
def save_checkpoint(self, model, checkpoint):
    """
    Persist ``model`` to the checkpoint file located at ``checkpoint``.

    Concrete subclasses must override this hook; the abstract version does nothing.

    Args:
        model (nn.Module): The model whose state is written out.
        checkpoint (str): Destination path of the checkpoint file.
    """
    return None
test_time_adaptation(module)

Perform test-time adaptation for the given module.

Parameters:

  • module (SparseWeightEnsemblingMoE) –

    The module to adapt.

Returns:

  • SparseWeightEnsemblingMoE –

    The adapted module.

Source code in fusion_bench/method/sparse_we_moe/sparse_we_moe.py
def test_time_adaptation(self, module: SparseWeightEnsemblingMoE):
    """
    Perform test-time adaptation for the given module.

    Trainable parameters of ``module`` are optimized to minimize the entropy of the
    model's logits on shuffled test batches of every task in the model pool. The run
    lasts ``config.max_steps`` steps, or a single step when ``config.fast_dev_run`` is set.

    Args:
        module (SparseWeightEnsemblingMoE): The module to adapt.

    Returns:
        SparseWeightEnsemblingMoE: The adapted module. It is Fabric-wrapped when a
        Fabric instance is configured.

    Raises:
        ValueError: If ``config.optimizer`` is not ``"adam"``.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    if self.config.optimizer == "adam":
        optimizer = torch.optim.Adam(
            [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
        )
    else:
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    if self._fabric is not None:
        module, optimizer = self._fabric.setup(module, optimizer)

    def backward(loss):
        # Fabric setup above is optional, so backprop must not assume a Fabric
        # instance exists; fall back to plain autograd otherwise.
        if self._fabric is not None:
            self._fabric.backward(loss, retain_graph=True)
        else:
            loss.backward(retain_graph=True)

    module.train()

    if self.config.get("fast_dev_run", False):
        log.info("Running fast_dev_run, only one step")
        num_steps = 1
    else:
        num_steps = self.config.max_steps
    pbar = tqdm(range(num_steps), "Test-time adaptation", dynamic_ncols=True)

    for _ in pbar:
        if self.config.use_grad_accumulate:
            for task in self.modelpool.model_names:
                with self.profiler.profile("data time"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profiler.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    assert (
                        logits.dim() == 2
                    ), f"Expected logits to be 2D, got {logits.dim()}"
                    loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called
                # this can save memory
                with self.profiler.profile("backward pass"):
                    backward(loss)
        else:
            loss = 0
            for task in self.modelpool.model_names:
                with self.profiler.profile("data time"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profiler.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    assert (
                        logits.dim() == 2
                    ), f"Expected logits to be 2D, got {logits.dim()}"
                    loss = loss + entropy_loss(logits)

            with self.profiler.profile("backward pass"):
                backward(loss)

        with self.profiler.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()

    return module

SparseCLIPWeightEnsemblingMoEAlgorithm

Bases: SparseWeightEnsemblingMoEAlgorithm, CLIPClassificationMixin

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
class SparseCLIPWeightEnsemblingMoEAlgorithm(
    SparseWeightEnsemblingMoEAlgorithm,
    CLIPClassificationMixin,
):
    """
    Sparse weight-ensembling MoE fusion for CLIP vision models.

    Every MLP in the vision encoder of a task-arithmetic merge is replaced by a
    sparse weight-ensembling MoE module. Each module is built from the pretrained
    MLP and the fine-tuned expert MLPs of the same layer. Logits for test-time
    adaptation are zero-shot CLIP logits; labels in the test batches are ignored.
    """

    modelpool: CLIPVisionModelPool = None

    def load_checkpoint(self, model, checkpoint):
        """
        Load the checkpoint file at ``checkpoint`` into ``model`` via Fabric.
        """
        state = {"model": model}
        self._fabric.load(checkpoint, state)

    def save_checkpoint(self, model, checkpoint):
        """
        Save ``model`` to the checkpoint file at ``checkpoint`` via Fabric.
        """
        self._fabric.save(checkpoint, {"model": model})

    def _load_and_merge_models(self):
        """
        Load the pretrained and expert models and build their task-arithmetic merge.

        Returns:
            tuple: ``(base_model, expert_models, moe_model)``. Here ``moe_model`` is a
            task-arithmetic merge of ``expert_models`` into a deep copy of ``base_model``
            (scaled by ``config.init_lambda``) with all parameters frozen.
        """
        base_model = self.modelpool.load_model("_pretrained_")
        expert_models = [
            self.modelpool.load_model(m) for m in self.modelpool.model_names
        ]

        # merge the models using task arithmetic
        moe_model = task_arithmetic_merge(
            # this function modifies the model in place, so we need to pass a deepcopy
            deepcopy(base_model),
            expert_models,
            scaling_factor=self.config.init_lambda,
        ).requires_grad_(False)
        return base_model, expert_models, moe_model

    def construct_moe_model(self) -> SparseWeightEnsemblingMoE:
        """
        Construct the Mixture of Experts model using the models in the model pool.

        Each vision-encoder MLP of the task-arithmetic merge is replaced by a
        ``SparseWeightEnsemblingMoE`` module.
        """
        base_model, expert_models, moe_model = self._load_and_merge_models()

        # up-scale MLP modules
        base_encoder: CLIPEncoder = base_model.vision_model.encoder
        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
        expert_encoders = [m.vision_model.encoder for m in expert_models]

        num_layers = len(base_encoder.layers)
        for layer_idx in range(num_layers):
            base_mlp = base_encoder.layers[layer_idx].mlp
            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]

            moe_encoder.layers[layer_idx].mlp = SparseWeightEnsemblingMoE(
                hidden_size=base_encoder.config.hidden_size,
                base_model=base_mlp,
                expert_models=expert_mlps,
                init_lambda=self.config.init_lambda,
                batch_first=True,  # for open_clip models this is False
                router_hidden_layers=self.config.router_hidden_layers,
                batch_reduce=self.config.batch_reduce,
                num_layers=num_layers,
                layer_idx=layer_idx,
                tv_prune_ratio=self.config.tv_prune_ratio,
            )

        return moe_model

    def construct_moe_model_sharedgate(self) -> SparseWeightEnsemblingMoE_ShardGate:
        """
        Construct the Mixture of Experts model using the models in the model pool with a shared gate.

        A single gate is built once and passed to every layer's
        ``SparseWeightEnsemblingMoE_ShardGate`` module. When
        ``config.position_encoding`` is set, the gate's input size is
        ``hidden_size + config.position_encoding_dim``.
        """
        base_model, expert_models, moe_model = self._load_and_merge_models()

        # up-scale MLP modules
        base_encoder: CLIPEncoder = base_model.vision_model.encoder
        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
        expert_encoders = [m.vision_model.encoder for m in expert_models]

        # shared gate: one gate instance reused by every layer below
        shared_gate = construct_weight_ensembling_gate(
            hidden_size=(
                base_encoder.config.hidden_size + self.config.position_encoding_dim
                if self.config.position_encoding
                else base_encoder.config.hidden_size
            ),
            num_experts=len(expert_models),
            init_lambda=self.config.init_lambda,
            num_hidden_layers=self.config.router_hidden_layers,
        )

        num_layers = len(base_encoder.layers)
        for layer_idx in range(num_layers):
            base_mlp = base_encoder.layers[layer_idx].mlp
            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]

            moe_encoder.layers[layer_idx].mlp = SparseWeightEnsemblingMoE_ShardGate(
                hidden_size=base_encoder.config.hidden_size,
                base_model=base_mlp,
                expert_models=expert_mlps,
                init_lambda=self.config.init_lambda,
                batch_first=True,  # for open_clip models this is False
                router_hidden_layers=self.config.router_hidden_layers,
                batch_reduce=self.config.batch_reduce,
                num_layers=num_layers,
                layer_idx=layer_idx,
                tv_prune_ratio=self.config.tv_prune_ratio,
                sharedgate=shared_gate,
                position_encoding=self.config.position_encoding,
                position_encoding_dim=self.config.position_encoding_dim,
            )

        return moe_model

    # NOTE: functools.cache keys on (self, tta_dataset), so each task gets one
    # persistent iterator per algorithm instance; the cache also keeps `self` alive.
    @functools.cache
    def get_shuffled_test_loader_iter(self, tta_dataset: str):
        """
        Get an iterator for the shuffled test data loader.

        Args:
            tta_dataset (str): Name of the task whose test split is loaded.

        Returns:
            Iterator over batches of an ``InfiniteDataLoader`` wrapping a shuffled
            ``DataLoader`` of CLIP-preprocessed test samples.
        """
        log.info("get_shuffled_test_loader_iter")
        dataset = self.modelpool.load_test_dataset(tta_dataset)
        dataset = CLIPDataset(dataset, processor=self.clip_processor)
        loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    def compute_logits(
        self, module: CLIPVisionModel, batch: Tuple[Tensor, Tensor], task: str
    ) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (CLIPVisionModel): The vision model to use for computing logits.
            batch (Tuple[Tensor, Tensor]): The batch of data; the second element
                (labels) is ignored.
            task (str): The task for which to compute logits.

        Returns:
            Tensor: The computed logits, one row per image.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        # index 1 of the vision model output (presumably the pooled output) -- TODO confirm
        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

Parameters:

  • module (CLIPVisionModel) –

    The vision model to use for computing logits.

  • batch (Tuple[Tensor, Tensor]) –

    The batch of data.

  • task (str) –

    The task for which to compute logits.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
def compute_logits(
    self, module: CLIPVisionModel, batch: Tuple[Tensor, Tensor], task: str
) -> Tensor:
    """
    Score a batch of images against the zero-shot class embeddings of ``task``.

    Args:
        module (CLIPVisionModel): Vision tower used to embed the images.
        batch (Tuple[Tensor, Tensor]): ``(images, labels)``; labels are not used.
        task (str): Key into ``self.zeroshot_weights`` selecting the text embeddings.

    Returns:
        Tensor: Image-to-class logits (scaled cosine similarities), one row per image.
    """
    images, _labels = batch
    class_text_embeds = self.zeroshot_weights[task]

    projected = self.visual_projection(module(images)[1])
    unit_image_embeds = projected / projected.norm(p=2, dim=-1, keepdim=True)

    # text-by-image similarity scaled by the logit scale, then transposed so rows are images
    text_to_image = (
        torch.matmul(class_text_embeds, unit_image_embeds.t()) * self.logit_scale_exp
    )
    return text_to_image.t()
construct_moe_model()

Construct the Mixture of Experts model using the models in the model pool.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
def construct_moe_model(self) -> SparseWeightEnsemblingMoE:
    """
    Build the weight-ensembling MoE model from the pretrained and expert models.

    The experts are merged into a copy of the pretrained model by task arithmetic,
    using ``config.init_lambda``. The merge is frozen. Every vision-encoder MLP of the
    merge is then swapped for a ``SparseWeightEnsemblingMoE`` built from the matching
    pretrained and expert MLPs.
    """
    pool = self.modelpool
    base_model = pool.load_model("_pretrained_")
    expert_models = [pool.load_model(name) for name in pool.model_names]

    # task_arithmetic_merge works in place, hence the deepcopy of the base model
    merged = task_arithmetic_merge(
        deepcopy(base_model),
        expert_models,
        scaling_factor=self.config.init_lambda,
    )
    moe_model = merged.requires_grad_(False)

    base_encoder: CLIPEncoder = base_model.vision_model.encoder
    moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
    expert_encoders = [expert.vision_model.encoder for expert in expert_models]

    depth = len(base_encoder.layers)
    for idx, base_layer in enumerate(base_encoder.layers):
        moe_encoder.layers[idx].mlp = SparseWeightEnsemblingMoE(
            hidden_size=base_encoder.config.hidden_size,
            base_model=base_layer.mlp,
            expert_models=[enc.layers[idx].mlp for enc in expert_encoders],
            init_lambda=self.config.init_lambda,
            batch_first=True,  # for open_clip models this is False
            router_hidden_layers=self.config.router_hidden_layers,
            batch_reduce=self.config.batch_reduce,
            num_layers=depth,
            layer_idx=idx,
            tv_prune_ratio=self.config.tv_prune_ratio,
        )

    return moe_model
construct_moe_model_sharedgate()

Construct the Mixture of Experts model using the models in the model pool with a shared gate.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
def construct_moe_model_sharedgate(self) -> SparseWeightEnsemblingMoE_ShardGate:
    """
    Construct the Mixture of Experts model using the models in the model pool with a shared gate.

    The experts are merged into a copy of the pretrained model by task arithmetic, and
    the merge is frozen. A single gate is built once and passed to every layer's
    ``SparseWeightEnsemblingMoE_ShardGate`` module. When ``config.position_encoding`` is
    set, the gate's input size is ``hidden_size + config.position_encoding_dim``.
    """
    base_model = self.modelpool.load_model("_pretrained_")
    expert_models = [
        self.modelpool.load_model(m) for m in self.modelpool.model_names
    ]

    # merge the models using task arithmetic
    moe_model = task_arithmetic_merge(
        # this function modifies the model in place, so we need to pass a deepcopy
        deepcopy(base_model),
        expert_models,
        scaling_factor=self.config.init_lambda,
    ).requires_grad_(False)

    # up-scale MLP modules
    base_encoder: CLIPEncoder = base_model.vision_model.encoder
    moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
    expert_encoders = [m.vision_model.encoder for m in expert_models]

    # shared gate: one gate instance reused by every layer below
    shared_gate = construct_weight_ensembling_gate(
        hidden_size=(
            base_encoder.config.hidden_size + self.config.position_encoding_dim
            if self.config.position_encoding
            else base_encoder.config.hidden_size
        ),
        num_experts=len(expert_models),
        init_lambda=self.config.init_lambda,
        num_hidden_layers=self.config.router_hidden_layers,
    )

    num_layers = len(base_encoder.layers)
    for layer_idx in range(num_layers):
        base_mlp = base_encoder.layers[layer_idx].mlp
        expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]

        moe_encoder.layers[layer_idx].mlp = SparseWeightEnsemblingMoE_ShardGate(
            hidden_size=base_encoder.config.hidden_size,
            base_model=base_mlp,
            expert_models=expert_mlps,
            init_lambda=self.config.init_lambda,
            batch_first=True,  # for open_clip models this is False
            router_hidden_layers=self.config.router_hidden_layers,
            batch_reduce=self.config.batch_reduce,
            num_layers=num_layers,
            layer_idx=layer_idx,
            tv_prune_ratio=self.config.tv_prune_ratio,
            sharedgate=shared_gate,
            position_encoding=self.config.position_encoding,
            position_encoding_dim=self.config.position_encoding_dim,
        )

    return moe_model
get_shuffled_test_loader_iter(tta_dataset) cached

Get an iterator for the shuffled test data loader.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
@functools.cache
def get_shuffled_test_loader_iter(self, tta_dataset: str):
    """
    Return an iterator over shuffled, CLIP-preprocessed test batches of ``tta_dataset``.

    ``functools.cache`` memoizes the result per ``(self, tta_dataset)``, so every call
    for the same task continues the same iterator. The loader is wrapped in
    ``InfiniteDataLoader`` (presumably so it restarts when exhausted). When a Fabric
    instance exists, the loader is first prepared with ``setup_dataloaders``.
    """
    log.info("get_shuffled_test_loader_iter")
    test_split = CLIPDataset(
        self.modelpool.load_test_dataset(tta_dataset),
        processor=self.clip_processor,
    )
    shuffled_loader = DataLoader(
        test_split,
        batch_size=self.config.batch_size,
        shuffle=True,
        num_workers=self.config.num_workers,
        pin_memory=True,
    )
    fabric = self._fabric
    if fabric is not None:
        shuffled_loader = fabric.setup_dataloaders(shuffled_loader)
    return iter(InfiniteDataLoader(shuffled_loader))
load_checkpoint(model, checkpoint)

Load the checkpoint file.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
def load_checkpoint(self, model, checkpoint):
    """
    Restore ``model`` from the checkpoint file at ``checkpoint`` using Fabric.

    The model is passed to ``Fabric.load`` under the ``"model"`` key of the state dict.
    """
    self._fabric.load(checkpoint, {"model": model})
on_test_time_adaptation_start()

Here we load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
def on_test_time_adaptation_start(self):
    """
    Hook run before test-time adaptation starts.

    Delegates to ``setup_zero_shot_classification_head``, which loads the CLIP processor
    and builds the zero-shot classification head for each task.
    """
    setup_head = self.setup_zero_shot_classification_head
    setup_head()
save_checkpoint(model, checkpoint)

Save the checkpoint file.

Source code in fusion_bench/method/sparse_we_moe/sparse_clip_we_moe.py
def save_checkpoint(self, model, checkpoint):
    """
    Write ``model`` to the checkpoint file at ``checkpoint`` using Fabric.

    The model is stored under the ``"model"`` key.
    """
    payload = {"model": model}
    self._fabric.save(checkpoint, payload)

Rank-One MoE

RankOneMoEAlgorithm

Bases: ModelFusionAlgorithm

Algorithm for fusing models using RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE).

This class provides methods for constructing the MoE model, performing test-time adaptation, and running the fusion process.

Attributes:

  • _fabric (Fabric) –

    The fabric for distributed training.

  • modelpool (ModelPool) –

    The pool of models to be fused.

  • profiler (SimpleProfiler) –

    The profiler for measuring performance.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
class RankOneMoEAlgorithm(ModelFusionAlgorithm):
    """
    Algorithm for fusing models using RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE).

    This class provides methods for constructing the MoE model, performing test-time adaptation,
    and running the fusion process.

    Attributes:
        _fabric (L.Fabric): The fabric for distributed training. It stays ``None`` when no
            CUDA device is available and no fabric was provided.
        modelpool (ModelPool): The pool of models to be fused.
        profiler (SimpleProfiler): The profiler for measuring performance.
    """

    _fabric: L.Fabric = None
    modelpool: ModelPool = None

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the RankOneMoEAlgorithm with the given configuration.

        When no fabric is set and CUDA is available, a Lightning Fabric is created on
        ``config.devices`` devices (default 1) and launched. Without CUDA, a warning is
        logged and the algorithm continues without a fabric.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

        if self._fabric is None:
            if torch.cuda.is_available():
                self._fabric = L.Fabric(
                    devices=self.config.get("devices", 1),
                )
                self._fabric.launch()
            else:
                # Not fatal: test-time adaptation falls back to plain autograd.
                log.warning("No CUDA device available.")
        self.profiler = SimpleProfiler(
            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
        )

    @abstractmethod
    def load_checkpoint(self, model, checkpoint):
        """
        Load the checkpoint file.

        Args:
            model: The model to load the checkpoint into.
            checkpoint: The checkpoint file to load.
        """
        pass

    @abstractmethod
    def save_checkpoint(self, model, checkpoint):
        """
        Save the checkpoint file.

        Args:
            model: The model to save the checkpoint from.
            checkpoint: The checkpoint file to save.
        """
        pass

    @abstractmethod
    def construct_moe_model(self) -> RankOneMoE:
        """
        Construct the Mixture of Experts model using the models in the model pool.

        Returns:
            RankOneMoE: The constructed MoE model.
        """
        pass

    def on_test_time_adaptation_start(self):
        """
        Hook method called at the start of test-time adaptation.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Get an iterator for the shuffled test data loader for a specific task.

        Args:
            task (str): The task for which to get the test data loader.

        Returns:
            DataLoader: The shuffled test data loader iterator.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Compute the logits for a given batch and task.

        Args:
            module: The model module to use for computing logits.
            batch: The batch of data.
            task: The task for which to compute logits.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def _backward(self, loss: Tensor):
        """Backpropagate ``loss`` through Fabric when one is configured, else via autograd."""
        if self._fabric is not None:
            self._fabric.backward(loss, retain_graph=True)
        else:
            loss.backward(retain_graph=True)

    def test_time_adaptation(self, module: RankOneMoE):
        """
        Perform test-time adaptation for the given module.

        Trainable parameters are optimized to minimize the entropy of the logits on
        shuffled test batches of every task. The run lasts ``config.max_steps`` steps,
        or one step when ``config.fast_dev_run`` is set.

        Args:
            module (RankOneMoE): The MoE module to adapt.

        Returns:
            RankOneMoE: The adapted MoE module. It is Fabric-wrapped when a fabric is
            configured.

        Raises:
            ValueError: If ``config.optimizer`` is not ``"adam"``.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            num_steps = 1
        else:
            num_steps = self.config.max_steps
        pbar = tqdm(range(num_steps), "Test-time adaptation", dynamic_ncols=True)

        for _ in pbar:
            if self.config.use_grad_accumulate:
                for task in self.modelpool.model_names:
                    with self.profiler.profile("data time"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profiler.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        assert (
                            logits.dim() == 2
                        ), f"Expected logits to be 2D, got {logits.dim()}"
                        loss = entropy_loss(logits)
                    # .backward() accumulates when .zero_grad() wasn't called
                    # this can save memory
                    with self.profiler.profile("backward pass"):
                        self._backward(loss)
            else:
                loss = 0
                for task in self.modelpool.model_names:
                    with self.profiler.profile("data time"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profiler.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        assert (
                            logits.dim() == 2
                        ), f"Expected logits to be 2D, got {logits.dim()}"
                        loss = loss + entropy_loss(logits)
                with self.profiler.profile("backward pass"):
                    self._backward(loss)

            with self.profiler.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

        return module

    def run(self, modelpool: ModelPool):
        """
        Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.

        The MoE model is either loaded from ``config.checkpoint`` (skipping test-time
        adaptation), or adapted and optionally saved to ``config.save_checkpoint``.

        Args:
            modelpool (ModelPool): The pool of models to be fused.

        Returns:
            RankOneMoE: The fused RankOne MoE model, with ``batch_reduce`` set to False.
        """
        log.info("Fusing models using RankOne-MoE modules.")
        self.modelpool = modelpool

        with timeit_context("upscaling models to a RankOne-MoE model"):
            moe_model = self.construct_moe_model()
            print_parameters(moe_model)

        if self.config.get("checkpoint", False):
            log.info(
                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
            )
            self.load_checkpoint(moe_model, self.config.checkpoint)
        else:
            with self.profiler.profile("test-time adaptation"):
                moe_model = self.test_time_adaptation(moe_model)
            if self.config.get("save_checkpoint", False):
                log.info(f"save checkpoint to {self.config.save_checkpoint}")
                self.save_checkpoint(moe_model, self.config.save_checkpoint)

            # test-time adaptation may return a Fabric-wrapped module; unwrap it
            if lightning.fabric.wrappers.is_wrapped(moe_model):
                moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

        # enable sample-wise adaptation
        moe_model.batch_reduce = False
        print(self.profiler.summary())
        return moe_model
__init__(algorithm_config)

Initialize the RankOneMoEAlgorithm with the given configuration.

Parameters:

  • algorithm_config (DictConfig) –

    The configuration for the algorithm.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
def __init__(self, algorithm_config: DictConfig):
    """
    Initialize the RankOneMoEAlgorithm with the given configuration.

    When no fabric is set and CUDA is available, a Lightning Fabric is created on
    ``config.devices`` devices (default 1) and launched. Without CUDA, a warning is
    logged and initialization continues without a fabric.

    Args:
        algorithm_config (DictConfig): The configuration for the algorithm.
    """
    super().__init__(algorithm_config)

    if self._fabric is None:
        if torch.cuda.is_available():
            self._fabric = L.Fabric(
                devices=self.config.get("devices", 1),
            )
            self._fabric.launch()
        else:
            # Not fatal: the algorithm keeps running without a Fabric instance.
            log.warning("No CUDA device available.")
    self.profiler = SimpleProfiler(
        self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
    )
compute_logits(module, batch, task) abstractmethod

Compute the logits for a given batch and task.

Parameters:

  • module –

    The model module to use for computing logits.

  • batch –

    The batch of data.

  • task –

    The task for which to compute logits.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
@abstractmethod
def compute_logits(self, module, batch, task) -> Tensor:
    """
    Return the logits that ``module`` produces for ``batch`` under ``task``.

    Subclasses must override this hook. Test-time adaptation calls it once per
    task and batch, and checks that the result is a 2-D tensor.

    Args:
        module: Model used for the forward pass.
        batch: One batch of input data.
        task: Name of the task the batch belongs to.

    Returns:
        Tensor: The computed logits.
    """
construct_moe_model() abstractmethod

Construct the Mixture of Experts model using the models in the model pool.

Returns:

  • RankOneMoE –

    The constructed MoE model.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
@abstractmethod
def construct_moe_model(self) -> RankOneMoE:
    """
    Build the Mixture-of-Experts model from the models held in ``self.modelpool``.

    Subclasses must override this hook.

    Returns:
        RankOne-MoE: The constructed MoE model.
    """
get_shuffled_test_loader_iter(task) abstractmethod

Get an iterator for the shuffled test data loader for a specific task.

Parameters:

  • task (str) –

    The task for which to get the test data loader.

Returns:

  • DataLoader ( DataLoader ) –

    The shuffled test data loader iterator.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
@abstractmethod
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Provide an iterator over shuffled test batches for ``task``.

    Subclasses must override this hook. Test-time adaptation passes the
    returned object to ``next()`` to draw each batch.

    Args:
        task (str): The task whose test data should be iterated.

    Returns:
        DataLoader: The shuffled test data loader iterator.
    """
load_checkpoint(model, checkpoint) abstractmethod

Load the checkpoint file.

Parameters:

  • model –

    The model to load the checkpoint into.

  • checkpoint –

    The checkpoint file to load.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
@abstractmethod
def load_checkpoint(self, model, checkpoint):
    """
    Restore ``model`` from the checkpoint file ``checkpoint``.

    Subclasses must override this hook.

    Args:
        model: The model that receives the loaded state.
        checkpoint: The checkpoint file to read.
    """
on_test_time_adaptation_start()

Hook method called at the start of test-time adaptation.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
def on_test_time_adaptation_start(self):
    """
    Hook invoked once before test-time adaptation starts.

    Does nothing by default; subclasses may override it to prepare state.
    """
run(modelpool)

Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be fused.

Returns:

  • –

    RankOne-MoE: The fused RankOne MoE model.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
def run(self, modelpool: ModelPool):
    """
    Run the RankOneMoEAlgorithm to fuse models using RankOne-MoE.

    The MoE model is built with `construct_moe_model`. It then takes one of two paths:

    - If ``config.checkpoint`` is set, weights are loaded from it and test-time
      adaptation is skipped.
    - Otherwise the model is adapted with `test_time_adaptation`. The result is
      optionally saved to ``config.save_checkpoint`` and unwrapped from Lightning
      Fabric if needed.

    Finally, batch-level reduction is switched off for sample-wise routing.

    Args:
        modelpool (ModelPool): The pool of models to be fused.

    Returns:
        The fused model containing RankOne-MoE modules.
    """
    log.info("Fusing models using RankOne-MoE modules.")
    self.modelpool = modelpool

    with timeit_context("upscaling models to a RankOne-MoE model"):
        moe_model = self.construct_moe_model()
        print_parameters(moe_model)

    if self.config.get("checkpoint", False):
        log.info(
            f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
        )
        self.load_checkpoint(moe_model, self.config.checkpoint)
    else:
        with self.profiler.profile("test-time adaptation"):
            moe_model = self.test_time_adaptation(moe_model)
        if self.config.get("save_checkpoint", False):
            log.info(f"save checkpoint to {self.config.save_checkpoint}")
            self.save_checkpoint(moe_model, self.config.save_checkpoint)

        if lightning.fabric.wrappers.is_wrapped(moe_model):
            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

    # enable sample-wise adaptation
    moe_model.batch_reduce = False
    # The MoE modules are usually nested inside the returned model (e.g. as MLPs of
    # encoder layers). Setting the attribute on the root module alone does not reach
    # them, so propagate it to every submodule that carries the flag.
    for submodule in moe_model.modules():
        if hasattr(submodule, "batch_reduce"):
            submodule.batch_reduce = False
    print(self.profiler.summary())
    return moe_model
save_checkpoint(model, checkpoint) abstractmethod

Save the checkpoint file.

Parameters:

  • model –

    The model to save the checkpoint from.

  • checkpoint –

    The checkpoint file to save.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
@abstractmethod
def save_checkpoint(self, model, checkpoint):
    """
    Persist the state of ``model`` to the checkpoint file ``checkpoint``.

    Subclasses must override this hook.

    Args:
        model: The model whose state is written.
        checkpoint: The checkpoint file to write.
    """
test_time_adaptation(module)

Perform test-time adaptation for the given module.

Parameters:

  • module (RankOneMoE) –

    The MoE module to adapt.

Returns:

  • –

    RankOne-MoE: The adapted MoE module.

Source code in fusion_bench/method/rankone_moe/rankone_moe.py
def test_time_adaptation(self, module: RankOneMoE):
    """
    Perform test-time adaptation for the given module.

    The trainable parameters of ``module`` are optimized to minimize the entropy of
    its logits on test batches drawn from every task in ``self.modelpool.model_names``.

    - With ``config.use_grad_accumulate``, the backward pass runs once per task and
      gradients accumulate until the optimizer step.
    - Otherwise the per-task losses are summed and back-propagated once per step.

    Args:
        module (RankOneMoE): The MoE module to adapt.

    Returns:
        The adapted module (set up through Lightning Fabric when a Fabric instance
        exists).

    Raises:
        ValueError: If ``config.optimizer`` is not ``"adam"``.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer: only parameters that require gradients are adapted
    if self.config.optimizer == "adam":
        optimizer = torch.optim.Adam(
            [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
        )
    else:
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    if self._fabric is not None:
        module, optimizer = self._fabric.setup(module, optimizer)

    def backward(loss):
        # `self._fabric` may be None (no CUDA device at construction time); fall back
        # to plain autograd instead of failing with an AttributeError.
        if self._fabric is not None:
            self._fabric.backward(loss, retain_graph=True)
        else:
            loss.backward(retain_graph=True)

    module.train()

    if self.config.get("fast_dev_run", False):
        log.info("Running fast_dev_run, only one step")
        num_steps = 1
    else:
        num_steps = self.config.max_steps
    pbar = tqdm(
        range(num_steps),
        "Test-time adaptation",
        dynamic_ncols=True,
    )

    for _ in pbar:
        if self.config.use_grad_accumulate:
            for task in self.modelpool.model_names:
                with self.profiler.profile("data time"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profiler.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    assert (
                        logits.dim() == 2
                    ), f"Expected logits to be 2D, got {logits.dim()}"
                    loss = entropy_loss(logits)
                # .backward() accumulates gradients until .zero_grad() is called;
                # back-propagating per task can save memory compared with summing
                # all task losses first
                with self.profiler.profile("backward pass"):
                    backward(loss)
        else:
            loss = 0
            for task in self.modelpool.model_names:
                with self.profiler.profile("data time"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profiler.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    assert (
                        logits.dim() == 2
                    ), f"Expected logits to be 2D, got {logits.dim()}"
                    loss = loss + entropy_loss(logits)
            with self.profiler.profile("backward pass"):
                backward(loss)

        with self.profiler.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()

    return module

CLIPRankOneMoEAlgorithm

Bases: RankOneMoEAlgorithm, CLIPClassificationMixin

CLIPRankOneMoEAlgorithm is a class that implements the RankOneMoEAlgorithm (https://github.com/EnnengYang/RankOne-MoE) for CLIP models. It extends the RankOneMoEAlgorithm and CLIPClassificationMixin classes.

Attributes:

  • modelpool (CLIPVisionModelPool) –

    The model pool containing the CLIP models.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
class CLIPRankOneMoEAlgorithm(
    RankOneMoEAlgorithm,
    CLIPClassificationMixin,
):
    """
    RankOne-MoE (https://github.com/EnnengYang/RankOne-MoE) specialised for CLIP models.

    Combines the generic `RankOneMoEAlgorithm` workflow with the zero-shot
    classification utilities provided by `CLIPClassificationMixin`.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
    """

    modelpool: CLIPVisionModelPool = None

    def load_checkpoint(self, model, checkpoint):
        """
        Restore ``model`` from ``checkpoint`` through Lightning Fabric.

        Args:
            model: The module to load into; it is stored under the ``"model"`` key.
            checkpoint: Path of the checkpoint file to read.
        """
        self._fabric.load(checkpoint, {"model": model})

    def save_checkpoint(self, model, checkpoint):
        """
        Write ``model`` to ``checkpoint`` through Lightning Fabric.

        Args:
            model: The module to save; it is stored under the ``"model"`` key.
            checkpoint: Path of the checkpoint file to write.
        """
        payload = {"model": model}
        self._fabric.save(checkpoint, payload)

    def construct_moe_model(self) -> RankOneMoE:
        """
        Build a CLIP model whose vision-encoder MLPs are RankOne-MoE modules.

        Construction happens in two steps:

        1. The fine-tuned models are merged into a copy of the pretrained model with
           task arithmetic (scaled by ``config.init_lambda``) and gradients are
           disabled.
        2. Every encoder layer's MLP is replaced by a `RankOneMoE` built from the
           pretrained MLP and the matching fine-tuned MLPs.

        Returns:
            The merged CLIP model containing the RankOne-MoE MLPs.
        """
        pool = self.modelpool
        cfg = self.config

        base_model = pool.load_model("_pretrained_")
        expert_models = [pool.load_model(name) for name in pool.model_names]

        # task_arithmetic_merge modifies its first argument in place, hence the deepcopy
        merged = task_arithmetic_merge(
            deepcopy(base_model),
            expert_models,
            scaling_factor=cfg.init_lambda,
        )
        moe_model = merged.requires_grad_(False)

        base_encoder: CLIPEncoder = base_model.vision_model.encoder
        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
        expert_encoders = [expert.vision_model.encoder for expert in expert_models]

        for idx, base_layer in enumerate(base_encoder.layers):
            moe_encoder.layers[idx].mlp = RankOneMoE(
                hidden_size=base_encoder.config.hidden_size,
                base_model=base_layer.mlp,
                expert_models=[enc.layers[idx].mlp for enc in expert_encoders],
                init_lambda=cfg.init_lambda,
                batch_first=True,  # For open_clip models this is False
                router_hidden_layers=cfg.router_hidden_layers,
                batch_reduce=cfg.batch_reduce,
                svd_accelerator=cfg.svd_accelerator,
                rank_k=cfg.rank_k,
                select_k=cfg.select_k,
            )

        return moe_model

    @functools.cache
    def get_shuffled_test_loader_iter(self, tta_dataset: str):
        """
        Return an iterator over shuffled test batches of ``tta_dataset``.

        Because of ``functools.cache``, repeated calls with the same dataset name
        return the same iterator object. Successive ``next()`` calls therefore keep
        advancing through it.

        Args:
            tta_dataset (str): The name of the test-time adaptation dataset.

        Returns:
            Iterator: An iterator over an `InfiniteDataLoader` wrapping the
            Fabric-prepared shuffled loader.
        """
        raw_dataset = self.modelpool.load_test_dataset(tta_dataset)
        clip_dataset = CLIPDataset(raw_dataset, processor=self.clip_processor)
        log.info("get_shuffled_test_loader_iter")
        cfg = self.config
        plain_loader = DataLoader(
            clip_dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=True,
        )
        fabric_loader = self.fabric.setup_dataloaders(plain_loader)
        return iter(InfiniteDataLoader(fabric_loader))

    def on_test_time_adaptation_start(self):
        """
        Prepare the zero-shot classification heads before adaptation begins.

        Delegates to `setup_zero_shot_classification_head`.
        """
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Zero-shot classification logits of ``module`` on ``batch`` for ``task``.

        Images are encoded by ``module`` (element ``[1]`` of its output), projected
        with ``self.visual_projection`` and L2-normalised. They are then compared
        against the task's zero-shot text embeddings, scaled by
        ``self.logit_scale_exp``.

        Args:
            module: The vision model to evaluate.
            batch: An ``(images, labels)`` pair; the labels are ignored.
            task: Key into ``self.zeroshot_weights``.

        Returns:
            Tensor: Image-by-text logits.
        """
        images, _labels = batch
        class_embeds = self.zeroshot_weights[task]

        features = self.visual_projection(module(images)[1])
        features = features / features.norm(p=2, dim=-1, keepdim=True)

        # text-by-image similarities, transposed to image-by-text
        similarity = torch.matmul(class_embeds, features.t()) * self.logit_scale_exp
        return similarity.t()
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

Parameters:

  • module –

    The model module.

  • batch –

    The input batch.

  • task –

    The task name.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
def compute_logits(self, module, batch, task) -> Tensor:
    """
    Zero-shot classification logits of ``module`` on ``batch`` for ``task``.

    Images are encoded by ``module`` (element ``[1]`` of its output), projected with
    ``self.visual_projection`` and L2-normalised. They are then compared against the
    task's zero-shot text embeddings, scaled by ``self.logit_scale_exp``.

    Args:
        module: The vision model to evaluate.
        batch: An ``(images, labels)`` pair; the labels are ignored.
        task: Key into ``self.zeroshot_weights``.

    Returns:
        Tensor: Image-by-text logits.
    """
    images, _labels = batch
    class_embeds = self.zeroshot_weights[task]

    features = self.visual_projection(module(images)[1])
    features = features / features.norm(p=2, dim=-1, keepdim=True)

    # text-by-image similarities, transposed to image-by-text
    similarity = torch.matmul(class_embeds, features.t()) * self.logit_scale_exp
    return similarity.t()
construct_moe_model()

Construct the RankOne-MoE model using the models in the model pool.

Returns:

  • RankOneMoE –

    The constructed MoE model.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
def construct_moe_model(self) -> RankOneMoE:
    """
    Build a CLIP model whose vision-encoder MLPs are RankOne-MoE modules.

    Construction happens in two steps:

    1. The fine-tuned models are merged into a copy of the pretrained model with task
       arithmetic (scaled by ``config.init_lambda``) and gradients are disabled.
    2. Every encoder layer's MLP is replaced by a `RankOneMoE` built from the
       pretrained MLP and the matching fine-tuned MLPs.

    Returns:
        The merged CLIP model containing the RankOne-MoE MLPs.
    """
    pool = self.modelpool
    cfg = self.config

    base_model = pool.load_model("_pretrained_")
    expert_models = [pool.load_model(name) for name in pool.model_names]

    # task_arithmetic_merge modifies its first argument in place, hence the deepcopy
    merged = task_arithmetic_merge(
        deepcopy(base_model),
        expert_models,
        scaling_factor=cfg.init_lambda,
    )
    moe_model = merged.requires_grad_(False)

    base_encoder: CLIPEncoder = base_model.vision_model.encoder
    moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
    expert_encoders = [expert.vision_model.encoder for expert in expert_models]

    for idx, base_layer in enumerate(base_encoder.layers):
        moe_encoder.layers[idx].mlp = RankOneMoE(
            hidden_size=base_encoder.config.hidden_size,
            base_model=base_layer.mlp,
            expert_models=[enc.layers[idx].mlp for enc in expert_encoders],
            init_lambda=cfg.init_lambda,
            batch_first=True,  # For open_clip models this is False
            router_hidden_layers=cfg.router_hidden_layers,
            batch_reduce=cfg.batch_reduce,
            svd_accelerator=cfg.svd_accelerator,
            rank_k=cfg.rank_k,
            select_k=cfg.select_k,
        )

    return moe_model
get_shuffled_test_loader_iter(tta_dataset) cached

Get an iterator for the shuffled test data loader.

Parameters:

  • tta_dataset (str) –

    The name of the test-time adaptation dataset.

Returns:

  • Iterator –

    An iterator for the shuffled test data loader.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
@functools.cache
def get_shuffled_test_loader_iter(self, tta_dataset: str):
    """
    Return an iterator over shuffled test batches of ``tta_dataset``.

    Because of ``functools.cache``, repeated calls with the same arguments return the
    same iterator object, so successive ``next()`` calls keep advancing through it.

    Args:
        tta_dataset (str): The name of the test-time adaptation dataset.

    Returns:
        Iterator: An iterator over an `InfiniteDataLoader` wrapping the
        Fabric-prepared shuffled loader.
    """
    raw_dataset = self.modelpool.load_test_dataset(tta_dataset)
    clip_dataset = CLIPDataset(raw_dataset, processor=self.clip_processor)
    log.info("get_shuffled_test_loader_iter")
    cfg = self.config
    plain_loader = DataLoader(
        clip_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True,
    )
    fabric_loader = self.fabric.setup_dataloaders(plain_loader)
    return iter(InfiniteDataLoader(fabric_loader))
load_checkpoint(model, checkpoint)

Load the checkpoint file.

Parameters:

  • model –

    The model to load the checkpoint into.

  • checkpoint –

    The path to the checkpoint file.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
def load_checkpoint(self, model, checkpoint):
    """
    Restore ``model`` from ``checkpoint`` through Lightning Fabric.

    Args:
        model: The module to load into; it is stored under the ``"model"`` key.
        checkpoint: Path of the checkpoint file to read.
    """
    self._fabric.load(checkpoint, {"model": model})
on_test_time_adaptation_start()

Load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
def on_test_time_adaptation_start(self):
    """
    Prepare the zero-shot classification heads before adaptation begins.

    Delegates to `setup_zero_shot_classification_head`.
    """
    self.setup_zero_shot_classification_head()
save_checkpoint(model, checkpoint)

Save the checkpoint file.

Parameters:

  • model –

    The model to save the checkpoint from.

  • checkpoint –

    The path to the checkpoint file.

Source code in fusion_bench/method/rankone_moe/clip_rankone_moe.py
def save_checkpoint(self, model, checkpoint):
    """
    Write ``model`` to ``checkpoint`` through Lightning Fabric.

    Args:
        model: The module to save; it is stored under the ``"model"`` key.
        checkpoint: Path of the checkpoint file to write.
    """
    payload = {"model": model}
    self._fabric.save(checkpoint, payload)

Smile Upscaling

SmileUpscalingAlgorithm

Bases: SimpleProfilerMixin, BaseAlgorithm

Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py
class SmileUpscalingAlgorithm(
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Upscale a pretrained model into a SMILE mixture-of-experts model.

    Every submodule that is an instance of ``_linear_layer_cls`` is replaced by a
    `SmileMoELinear` built from the pretrained layer and the corresponding layers of the
    fine-tuned models. When ``average_experts`` is enabled, other leaf modules are
    replaced by the simple average of the fine-tuned experts.
    """

    _linear_layer_cls = (nn.Linear,)
    _config_mapping = BaseAlgorithm._config_mapping | {
        "device": "device",
        "upscaling_accelerator": "upscaling_accelerator",
        "full_matrices": "full_matrices",
        "gate_k": "gate_k",
        "k": "k",
        "top_k": "top_k",
        "routing_use_diff": "routing_use_diff",
        "average_experts": "average_experts",
        "model_path": "model_path",
    }

    def __init__(
        self,
        *,
        device: str = "cuda",
        upscaling_accelerator: str = None,
        full_matrices: bool = True,
        gate_k: int = 256,
        k: int = 256,
        top_k: int = 1,
        routing_use_diff: bool = True,
        average_experts: bool = False,
        model_path: str = None,
        **kwargs,
    ):
        """
        Initialize the SmileUpscalingAlgorithm.

        Args:
            device (str): The device to perform the computation on.
            upscaling_accelerator (str): The device to perform the SVD computation on.
            full_matrices (bool): Whether to compute the full-sized U and V matrices.
            gate_k (int): The number of singular values to keep for the gate.
            k (int): The number of singular values to keep for the experts.
            top_k (int): The number of top experts to select.
            routing_use_diff (bool): Whether to use weight differences for routing.
            average_experts (bool): Whether to average the experts.
            model_path (str): The path to save/load the model.
            **kwargs: Additional arguments; each one is logged as unrecognized and
                still stored as an attribute.
        """
        super().__init__()
        self.device = device
        self.upscaling_accelerator = upscaling_accelerator
        self.full_matrices = full_matrices
        self.gate_k = gate_k
        self.k = k
        self.top_k = top_k
        self.routing_use_diff = routing_use_diff
        self.average_experts = average_experts
        self.model_path = model_path
        for key, value in kwargs.items():
            log.warning(f"Unrecognized argument: {key}")
            setattr(self, key, value)

        # print `self.config` as yaml
        print(f"=== Config for `{type(self).__name__}` ===")
        print(OmegaConf.to_yaml(self.config))
        print(f"=== Config for `{type(self).__name__}` ===")

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Executes the upscaling process.

        If ``model_path`` points to an existing file, the model stored there is loaded
        and returned directly. Otherwise:

        1. The pretrained and fine-tuned models are loaded and moved to CUDA when
           ``device == "cuda"`` and CUDA is available.
        2. The models are merged.
        3. If ``model_path`` is set, the merged model is saved there.

        Args:
            modelpool (ModelPool): The pool of models to be used for upscaling.

        Returns:
            nn.Module: The upscaled model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        if self.config.model_path is not None and os.path.exists(
            self.config.model_path
        ):
            log.info(f"Loading model from {self.config.model_path}")
            # NOTE(review): torch.load unpickles a whole module, so only load trusted
            # files. PyTorch releases that default to `weights_only=True` may reject a
            # pickled nn.Module; confirm against the installed version.
            model = torch.load(self.config.model_path)
            print_parameters(model)
            return model

        with self.profile("loading model"):
            # load models and move to GPU if available
            with self.profile("load pretrained model"):
                pretrained_model = modelpool.load_model("_pretrained_")
            with self.profile("load fine-tuned model"):
                finetuned_models = list(
                    tqdm(modelpool.models(), total=len(modelpool.model_names))
                )

            if self.config.device == "cuda" and torch.cuda.is_available():
                pretrained_model = pretrained_model.cuda()
                finetuned_models = [m.cuda() for m in finetuned_models]

        with self.profile("merge model"):
            model = self.merge(pretrained_model, finetuned_models)

        self.print_profile_summary()
        if self.config.model_path is not None:
            model_dir = os.path.dirname(self.config.model_path)
            # os.makedirs("") raises FileNotFoundError, so skip it when model_path is
            # a bare file name without a directory component.
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            log.info(f"Saving model to {self.config.model_path}")
            torch.save(model, self.config.model_path)
        print_parameters(model)
        return model

    def merge(
        self,
        pretrained_model: nn.Module,
        finetuned_models: List[nn.Module],
        in_place: bool = True,
    ):
        """
        Merges the pretrained model with the fine-tuned models to create an upscaled model.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_models (List[nn.Module]): A list of fine-tuned models.
            in_place (bool): If True, modifies the pretrained model in place. Otherwise, creates a copy.

        Returns:
            nn.Module: The merged model.
        """
        model = pretrained_model if in_place else deepcopy(pretrained_model)
        self._upscale_submodules(model, finetuned_models)
        return model

    def _upscale_linear_layer(
        self,
        pretrained_model,
        finetuned_models,
        name: str,
    ):
        """
        Upscale a linear layer by merging it with the corresponding layers from the fine-tuned models.

        If `SmileMoELinear` raises `ExpertNotTrainedError`, the layer is left unchanged.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_models (List[nn.Module]): A list of fine-tuned models.
            name (str): The dotted name of the linear layer to upscale.
        """
        config = self.config

        name_list = name.split(".")
        module = get_attr(pretrained_model, name_list)
        experts = [get_attr(m, name_list) for m in finetuned_models]
        try:
            moe_linear = SmileMoELinear(
                module,
                experts,
                gate_k=config.gate_k,
                k=config.k,
                top_k=config.top_k,
                routing_use_diff=self.routing_use_diff,
                full_matrices=self.full_matrices,
                upscaling_accelerator=self.upscaling_accelerator,
            )
        except ExpertNotTrainedError:
            print(f"skip {name} because the experts are not trained.")
            return
        set_attr(pretrained_model, name_list, moe_linear)
        # remove the original module from fine-tuned models to save memory
        for m in finetuned_models:
            set_attr(m, name_list, None)

    def _average_experts(self, pretrained_model, finetuned_models, name: str):
        """
        Average the experts for a given layer.

        Args:
            pretrained_model (nn.Module): The pretrained model; the averaged module is written into it.
            finetuned_models (List[nn.Module]): A list of fine-tuned models.
            name (str): The dotted name of the layer to average.
        """
        name_list = name.split(".")
        experts = [get_attr(m, name_list) for m in finetuned_models]
        averaged_module = simple_average(experts)
        set_attr(pretrained_model, name_list, averaged_module)

    def _upscale_submodules(
        self,
        pretrained_model: nn.Module,
        finetuned_models: List[nn.Module],
        tqdm_desc: str = "Upscaling Linear Modules",
    ):
        """
        Upscales the submodules of the pretrained model by merging them with the corresponding submodules from the fine-tuned models.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_models (List[nn.Module]): A list of fine-tuned models.
            tqdm_desc (str): Description for the tqdm progress bar.
        """
        config = self.config
        # snapshot the module list first: upscaling replaces submodules while iterating
        for name, module in tqdm(
            tuple(pretrained_model.named_modules()),
            tqdm_desc,
            leave=False,
            dynamic_ncols=True,
        ):
            if isinstance(module, self._linear_layer_cls):
                self._upscale_linear_layer(
                    pretrained_model=pretrained_model,
                    finetuned_models=finetuned_models,
                    name=name,
                )
            elif config.average_experts and len(tuple(module.named_modules())) == 1:
                # if the module is a leaf module, we perform a parameter average
                self._average_experts(pretrained_model, finetuned_models, name)
__init__(*, device='cuda', upscaling_accelerator=None, full_matrices=True, gate_k=256, k=256, top_k=1, routing_use_diff=True, average_experts=False, model_path=None, **kwargs)

Initialize the SmileUpscalingAlgorithm.

Parameters:

  • device (str, default: 'cuda' ) –

    The device to perform the computation on.

  • upscaling_accelerator (str, default: None ) –

    The device to perform the SVD computation on.

  • full_matrices (bool, default: True ) –

    Whether to compute the full-sized U and V matrices.

  • gate_k (int, default: 256 ) –

    The number of singular values to keep for the gate.

  • k (int, default: 256 ) –

    The number of singular values to keep for the experts.

  • top_k (int, default: 1 ) –

    The number of top experts to select.

  • routing_use_diff (bool, default: True ) –

    Whether to use weight differences for routing.

  • average_experts (bool, default: False ) –

    Whether to average the experts.

  • model_path (str, default: None ) –

    The path to save/load the model.

  • **kwargs –

    Additional arguments.

Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py
def __init__(
    self,
    *,
    device: str = "cuda",
    upscaling_accelerator: str = None,
    full_matrices: bool = True,
    gate_k: int = 256,
    k: int = 256,
    top_k: int = 1,
    routing_use_diff: bool = True,
    average_experts: bool = False,
    model_path: str = None,
    **kwargs,
):
    """
    Configure the SMILE upscaling algorithm and echo its configuration.

    Args:
        device (str): The device to perform the computation on.
        upscaling_accelerator (str): The device to perform the SVD computation on.
        full_matrices (bool): Whether to compute the full-sized U and V matrices.
        gate_k (int): The number of singular values to keep for the gate.
        k (int): The number of singular values to keep for the experts.
        top_k (int): The number of top experts to select.
        routing_use_diff (bool): Whether to use weight differences for routing.
        average_experts (bool): Whether to average the experts.
        model_path (str): The path to save/load the model.
        **kwargs: Extra options; each is logged as unrecognized and still stored
            as an attribute.
    """
    super().__init__()
    known_options = {
        "device": device,
        "upscaling_accelerator": upscaling_accelerator,
        "full_matrices": full_matrices,
        "gate_k": gate_k,
        "k": k,
        "top_k": top_k,
        "routing_use_diff": routing_use_diff,
        "average_experts": average_experts,
        "model_path": model_path,
    }
    for attr_name, attr_value in known_options.items():
        setattr(self, attr_name, attr_value)
    for extra_name, extra_value in kwargs.items():
        log.warning(f"Unrecognized argument: {extra_name}")
        setattr(self, extra_name, extra_value)

    # echo `self.config` as YAML between two identical banner lines
    banner = f"=== Config for `{type(self).__name__}` ==="
    print(banner)
    print(OmegaConf.to_yaml(self.config))
    print(banner)
merge(pretrained_model, finetuned_models, in_place=True)

Merges the pretrained model with the fine-tuned models to create an upscaled model.

Parameters:

  • pretrained_model (Module) –

    The pretrained model.

  • finetuned_models (List[Module]) –

    A list of fine-tuned models.

  • in_place (bool, default: True ) –

    If True, modifies the pretrained model in place. Otherwise, creates a copy.

Returns:

  • –

    nn.Module: The merged model.

Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py
def merge(
    self,
    pretrained_model: nn.Module,
    finetuned_models: List[nn.Module],
    in_place: bool = True,
):
    """
    Combine the pretrained model with the fine-tuned models into an upscaled model.

    Args:
        pretrained_model (nn.Module): The pretrained model.
        finetuned_models (List[nn.Module]): A list of fine-tuned models.
        in_place (bool): When True the pretrained model itself is modified; when
            False a deep copy is modified instead.

    Returns:
        nn.Module: The merged model.
    """
    target = pretrained_model if in_place else deepcopy(pretrained_model)
    self._upscale_submodules(target, finetuned_models)
    return target
run(modelpool)

Executes the upscaling process.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be used for upscaling.

Returns:

  • –

    nn.Module: The upscaled model.

Source code in fusion_bench/method/smile_upscaling/smile_upscaling.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool):
    """
    Executes the upscaling process.

    If ``model_path`` points to an existing file, the model stored there is loaded and
    returned directly. Otherwise:

    1. The pretrained and fine-tuned models are loaded and moved to CUDA when
       ``device == "cuda"`` and CUDA is available.
    2. The models are merged.
    3. If ``model_path`` is set, the merged model is saved there.

    Args:
        modelpool (ModelPool): The pool of models to be used for upscaling.

    Returns:
        nn.Module: The upscaled model.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    if self.config.model_path is not None and os.path.exists(
        self.config.model_path
    ):
        log.info(f"Loading model from {self.config.model_path}")
        # NOTE(review): torch.load unpickles a whole module, so only load trusted
        # files. PyTorch releases that default to `weights_only=True` may reject a
        # pickled nn.Module; confirm against the installed version.
        model = torch.load(self.config.model_path)
        print_parameters(model)
        return model

    with self.profile("loading model"):
        # load models and move to GPU if available
        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_model("_pretrained_")
        with self.profile("load fine-tuned model"):
            finetuned_models = list(
                tqdm(modelpool.models(), total=len(modelpool.model_names))
            )

        if self.config.device == "cuda" and torch.cuda.is_available():
            pretrained_model = pretrained_model.cuda()
            finetuned_models = [m.cuda() for m in finetuned_models]

    with self.profile("merge model"):
        model = self.merge(pretrained_model, finetuned_models)

    self.print_profile_summary()
    if self.config.model_path is not None:
        model_dir = os.path.dirname(self.config.model_path)
        # os.makedirs("") raises FileNotFoundError, so skip it when model_path is a
        # bare file name without a directory component.
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        log.info(f"Saving model to {self.config.model_path}")
        torch.save(model, self.config.model_path)
    print_parameters(model)
    return model

SingularProjectionMergingAlgorithm

Bases: ModelFusionAlgorithm, SimpleProfilerMixin

A model fusion algorithm that projects parameter differences into the SVD subspace of a pretrained model.

This algorithm is experimental and aims to investigate the location of task-specific knowledge.

Source code in fusion_bench/method/smile_upscaling/singular_projection_merging.py
class SingularProjectionMergingAlgorithm(ModelFusionAlgorithm, SimpleProfilerMixin):
    """
    A model fusion algorithm that projects parameter differences into the SVD subspace of a pretrained model.

    This algorithm is experimental and aims to investigate the location of task-specific knowledge.

    Configuration fields read by this class: ``model_path``, ``device``, ``k``,
    ``full_matrices`` and ``rank``.
    """

    @torch.no_grad()
    def run(self, modelpool: ModelPool) -> nn.Module:
        """
        Run the singular projection merging algorithm on the given model pool.

        If ``self.config.model_path`` is set and points to an existing file, the
        merged model stored there is loaded and returned without recomputation.
        Otherwise the pretrained model (``"_pretrained_"``) and the first
        fine-tuned model of the pool are loaded onto ``self.config.device`` and
        merged; the result is saved to ``self.config.model_path`` when it is set.

        Args:
            modelpool (ModelPool): The pool of models to merge.

        Returns:
            nn.Module: The merged model.
        """
        modelpool = to_modelpool(modelpool)
        model_path = self.config.model_path

        if model_path is not None and os.path.exists(model_path):
            log.info(f"loading merged model from {model_path}")
            # Return the cached result instead of silently recomputing and
            # overwriting it. The file holds a whole pickled module written by
            # `torch.save` below, hence `weights_only=False` (PyTorch >= 2.6
            # defaults to True). NOTE: only load trusted files.
            return torch.load(model_path, weights_only=False)

        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_model("_pretrained_").to(
                self.config.device
            )
        with self.profile("load fine-tuned model"):
            # Only the first fine-tuned model of the pool is used.
            finetuned_model = modelpool.load_model(modelpool.model_names[0]).to(
                self.config.device
            )

        with self.profile("merge model"):
            model = self.merge(pretrained_model, finetuned_model)

        if model_path is not None:
            # `os.path.dirname` returns "" for a bare filename and
            # `os.makedirs("")` raises, so only create a directory when needed.
            save_dir = os.path.dirname(model_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            log.info(f"saving merged model to {model_path}")
            torch.save(model, model_path)

        self.print_profile_summary()
        return model

    def merge(
        self,
        pretrained_model: nn.Module,
        finetuned_model: nn.Module,
        in_place: bool = True,
    ) -> nn.Module:
        """
        Merges the pretrained model with the fine-tuned model by projecting parameter differences
        into the SVD subspace of the pretrained model.

        Every ``nn.Linear`` submodule of the working model is replaced (via
        ``set_attr``) by the layer returned from ``projection_merge_linear``,
        which is the pretrained model's layer with its weight overwritten in
        place. Consequently the pretrained model is modified regardless of
        ``in_place``, and the returned model shares those layers with it.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_model (nn.Module): The fine-tuned model.
            in_place (bool): If True, modifies the fine-tuned model in place. Otherwise, works on a deep copy of it.

        Returns:
            nn.Module: The merged model.
        """
        if in_place:
            model = finetuned_model
        else:
            model = deepcopy(finetuned_model)

        # Materialize the module list up front because set_attr swaps
        # submodules of `model` inside the loop.
        for name, module in tqdm(
            tuple(model.named_modules()),
            "Projection merging in SVD subspace of pretrained model",
        ):
            if isinstance(module, nn.Linear):
                name_list = name.split(".")
                set_attr(
                    model,
                    name_list,
                    self.projection_merge_linear(
                        get_attr(pretrained_model, name_list),
                        get_attr(finetuned_model, name_list),
                        k=self.config.k,
                    ),
                )
        return model

    def projection_merge_linear(
        self, pretrained_model: nn.Linear, finetuned_model: nn.Linear, k: int
    ) -> nn.Linear:
        """
        Projects the parameter differences of linear layers into the SVD subspace of the pretrained model.

        With ``U_k`` / ``V_k`` the selected singular-vector columns of the
        pretrained weight ``W``, the merged weight is
        ``W + U_k U_k^T (W_ft - W) V_k V_k^T``. When ``self.config.rank == "low"``
        the first ``k`` columns are used; otherwise the columns from ``k`` on.
        The pretrained layer is updated in place and returned; its bias (if any)
        is replaced by the fine-tuned layer's bias tensor.

        Args:
            pretrained_model (nn.Linear): The linear layer of the pretrained model.
            finetuned_model (nn.Linear): The linear layer of the fine-tuned model.
            k (int): The subspace split point. If negative, it is set to the smallest
                number of leading singular values whose sum reaches 50% of the total.

        Returns:
            nn.Linear: The merged linear layer with projected parameter differences.
        """
        w = pretrained_model.weight
        w_ft = finetuned_model.weight

        # NOTE(review): assumes `svd` returns (U, S, V) with W ≈ U diag(S) V^T,
        # i.e. V rather than V^T — confirm against the helper's definition.
        u, s, v = svd(w, full_matrices=self.config.full_matrices)
        if k < 0:
            # find the position where the sum of singular values is larger than 50% of the total sum
            cumsum = s.cumsum(0)
            k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1

        if self.config.rank == "low":
            u = u[:, :k]
            v = v[:, :k]
        else:
            u = u[:, k:]
            v = v[:, k:]

        w_diff = w_ft - w
        # coordinates of the weight update in the selected singular subspace
        w_diff_proj = u.T @ w_diff @ v
        w.data = w + u @ w_diff_proj @ v.T
        if pretrained_model.bias is not None:
            pretrained_model.bias.data = finetuned_model.bias.data
        return pretrained_model
merge(pretrained_model, finetuned_model, in_place=True)

Merges the pretrained model with the fine-tuned model by projecting parameter differences into the SVD subspace of the pretrained model.

Parameters:

  • pretrained_model (Module) –

    The pretrained model.

  • finetuned_model (Module) –

    The fine-tuned model.

  • in_place (bool, default: True) –

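    If True, modifies the fine-tuned model in place; otherwise, works on a deep copy of it. In both cases the pretrained model's linear layers are overwritten in place and reused in the returned model.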

Returns:

  • Module –

    nn.Module: The merged model.

Source code in fusion_bench/method/smile_upscaling/singular_projection_merging.py
def merge(
    self,
    pretrained_model: nn.Module,
    finetuned_model: nn.Module,
    in_place: bool = True,
) -> nn.Module:
    """
    Fold each fine-tuned ``nn.Linear`` update into the SVD subspace of its pretrained counterpart.

    Every ``nn.Linear`` submodule of the working model is replaced (via
    ``set_attr``) by the layer returned from ``self.projection_merge_linear``,
    using ``self.config.k`` as the subspace size.

    Args:
        pretrained_model (nn.Module): The pretrained model. Its linear layers are
            updated in place by ``projection_merge_linear`` and become part of
            the returned model, whatever the value of ``in_place``.
        finetuned_model (nn.Module): The fine-tuned model.
        in_place (bool): If True, the fine-tuned model itself is modified and
            returned; otherwise a deep copy of it is modified instead.

    Returns:
        nn.Module: The merged model.
    """
    target = finetuned_model if in_place else deepcopy(finetuned_model)

    # Snapshot the module list first: set_attr swaps submodules inside the loop.
    all_modules = tuple(target.named_modules())
    progress = tqdm(
        all_modules,
        "Projection merging in SVD subspace of pretrained model",
    )
    for qualified_name, submodule in progress:
        if not isinstance(submodule, nn.Linear):
            continue
        path = qualified_name.split(".")
        merged_layer = self.projection_merge_linear(
            get_attr(pretrained_model, path),
            get_attr(finetuned_model, path),
            k=self.config.k,
        )
        set_attr(target, path, merged_layer)
    return target
projection_merge_linear(pretrained_model, finetuned_model, k)

Projects the parameter differences of linear layers into the SVD subspace of the pretrained model.

Parameters:

  • pretrained_model (Linear) –

    The linear layer of the pretrained model.

  • finetuned_model (Linear) –

    The linear layer of the fine-tuned model.

  • k (int) –

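    The subspace split point: with rank "low" the first k singular directions are kept, otherwise the directions from k onward. If negative, k is set to the smallest number of leading singular values whose sum reaches 50% of the total.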

Returns:

  • Linear –

    nn.Linear: The merged linear layer with projected parameter differences.

Source code in fusion_bench/method/smile_upscaling/singular_projection_merging.py
def projection_merge_linear(
    self, pretrained_model: nn.Linear, finetuned_model: nn.Linear, k: int
) -> nn.Linear:
    """
    Restrict a linear layer's fine-tuning update to a singular subspace of its pretrained weight.

    With ``U_k`` / ``V_k`` the selected singular-vector columns of the pretrained
    weight ``W``, the pretrained weight is overwritten in place with
    ``W + U_k U_k^T (W_ft - W) V_k V_k^T``. ``self.config.rank == "low"`` selects
    the first ``k`` columns; any other value selects the columns from ``k`` on.
    If the pretrained layer has a bias, it is replaced by the fine-tuned bias tensor.

    Args:
        pretrained_model (nn.Linear): The linear layer of the pretrained model (modified in place).
        finetuned_model (nn.Linear): The linear layer of the fine-tuned model.
        k (int): The subspace split point. If negative, it becomes the smallest
            number of leading singular values whose sum reaches half the total.

    Returns:
        nn.Linear: ``pretrained_model`` itself, carrying the merged weight.
    """
    base_weight = pretrained_model.weight
    tuned_weight = finetuned_model.weight

    # NOTE(review): assumes `svd` returns (U, S, V) with W ≈ U diag(S) V^T,
    # i.e. V rather than V^T — confirm against the helper's definition.
    u_full, sigma, v_full = svd(base_weight, full_matrices=self.config.full_matrices)
    if k < 0:
        # Smallest k whose leading singular values carry at least half the total mass.
        running_total = sigma.cumsum(0)
        k = (running_total < running_total[-1] * 0.5).sum().item() + 1

    columns = slice(None, k) if self.config.rank == "low" else slice(k, None)
    u_sub = u_full[:, columns]
    v_sub = v_full[:, columns]

    delta = tuned_weight - base_weight
    coefficients = u_sub.T @ delta @ v_sub
    base_weight.data = base_weight + u_sub @ coefficients @ v_sub.T

    if pretrained_model.bias is not None:
        pretrained_model.bias.data = finetuned_model.bias.data
    return pretrained_model
run(modelpool)

Run the singular projection merging algorithm on the given model pool.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to merge.

Returns:

  • Module –

    nn.Module: The merged model.

Source code in fusion_bench/method/smile_upscaling/singular_projection_merging.py
@torch.no_grad()
def run(self, modelpool: ModelPool) -> nn.Module:
    """
    Run the singular projection merging algorithm on the given model pool.

    If ``self.config.model_path`` is set and points to an existing file, the
    merged model stored there is loaded and returned without recomputation.
    Otherwise the pretrained model (``"_pretrained_"``) and the first
    fine-tuned model of the pool are loaded onto ``self.config.device`` and
    merged; the result is saved to ``self.config.model_path`` when it is set.

    Args:
        modelpool (ModelPool): The pool of models to merge.

    Returns:
        nn.Module: The merged model.
    """
    modelpool = to_modelpool(modelpool)
    model_path = self.config.model_path

    if model_path is not None and os.path.exists(model_path):
        log.info(f"loading merged model from {model_path}")
        # Return the cached result instead of silently recomputing and
        # overwriting it. The file holds a whole pickled module written by
        # `torch.save` below, hence `weights_only=False` (PyTorch >= 2.6
        # defaults to True). NOTE: only load trusted files.
        return torch.load(model_path, weights_only=False)

    with self.profile("load pretrained model"):
        pretrained_model = modelpool.load_model("_pretrained_").to(
            self.config.device
        )
    with self.profile("load fine-tuned model"):
        # Only the first fine-tuned model of the pool is used.
        finetuned_model = modelpool.load_model(modelpool.model_names[0]).to(
            self.config.device
        )

    with self.profile("merge model"):
        model = self.merge(pretrained_model, finetuned_model)

    if model_path is not None:
        # `os.path.dirname` returns "" for a bare filename and
        # `os.makedirs("")` raises, so only create a directory when needed.
        save_dir = os.path.dirname(model_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        log.info(f"saving merged model to {model_path}")
        torch.save(model, model_path)

    self.print_profile_summary()
    return model