
Model Merging

Linear Interpolation

Simple Average

SimpleAverageAlgorithm

Bases: BaseAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/simple_average.py
class SimpleAverageAlgorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
):
    _config_mapping = BaseAlgorithm._config_mapping | {
        "show_pbar": "show_pbar",
    }

    def __init__(self, show_pbar: bool = False):
        """
        Args:
            show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
        """
        super().__init__()
        self.show_pbar = show_pbar

    @torch.no_grad()
    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
        """
        Fuse the models in the given model pool using simple averaging.

        This method iterates over the names of the models in the model pool, loads each model, and appends it to a list.
        It then returns the simple average of the models in the list.

        Args:
            modelpool: The pool of models to fuse.

        Returns:
            The fused model obtained by simple averaging.
        """
        if isinstance(modelpool, dict):
            modelpool = BaseModelPool(modelpool)

        log.info(
            f"Fusing models using simple average on {len(modelpool.model_names)} models."
            f"models: {modelpool.model_names}"
        )
        sd: Optional[StateDictType] = None
        forward_model = None
        merged_model_names = []

        for model_name in modelpool.model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
                merged_model_names.append(model_name)
                print(f"load model of type: {type(model).__name__}")
            with self.profile("merge weights"):
                if sd is None:
                    # Initialize the state dictionary with the first model's state dictionary
                    sd = model.state_dict(keep_vars=True)
                    forward_model = model
                else:
                    # Add the current model's state dictionary to the accumulated state dictionary
                    sd = state_dict_add(
                        sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
                    )
        with self.profile("merge weights"):
            # Divide the accumulated state dictionary by the number of models to get the average
            sd = state_dict_div(
                sd, len(modelpool.model_names), show_pbar=self.show_pbar
            )

        if isinstance(forward_model, LazyStateDict):
            # if the model is a LazyStateDict, convert it to an empty module
            forward_model = forward_model.meta_module.to_empty(
                device=(
                    "cpu"
                    if forward_model._torch_dtype is None
                    else forward_model._torch_dtype
                )
            )
        forward_model.load_state_dict(sd)
        # print profile report and log the merged models
        self.print_profile_summary()
        log.info(f"merged {len(merged_model_names)} models:")
        for model_name in merged_model_names:
            log.info(f"  - {model_name}")
        return forward_model
__init__(show_pbar=False)

Parameters:

  • show_pbar (bool, default: False ) –

    If True, shows a progress bar during model loading and merging. Default is False.

Source code in fusion_bench/method/simple_average.py
def __init__(self, show_pbar: bool = False):
    """
    Args:
        show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
    """
    super().__init__()
    self.show_pbar = show_pbar
run(modelpool)

Fuse the models in the given model pool using simple averaging.

This method iterates over the names of the models in the model pool, loads each model, and appends it to a list. It then returns the simple average of the models in the list.

Parameters:

  • modelpool (Union[BaseModelPool, Dict[str, Module]]) –

    The pool of models to fuse.

Returns:

  • The fused model obtained by simple averaging.

Source code in fusion_bench/method/simple_average.py
@torch.no_grad()
def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
    """
    Fuse the models in the given model pool using simple averaging.

    This method iterates over the names of the models in the model pool, loads each model, and appends it to a list.
    It then returns the simple average of the models in the list.

    Args:
        modelpool: The pool of models to fuse.

    Returns:
        The fused model obtained by simple averaging.
    """
    if isinstance(modelpool, dict):
        modelpool = BaseModelPool(modelpool)

    log.info(
        f"Fusing models using simple average on {len(modelpool.model_names)} models."
        f"models: {modelpool.model_names}"
    )
    sd: Optional[StateDictType] = None
    forward_model = None
    merged_model_names = []

    for model_name in modelpool.model_names:
        with self.profile("load model"):
            model = modelpool.load_model(model_name)
            merged_model_names.append(model_name)
            print(f"load model of type: {type(model).__name__}")
        with self.profile("merge weights"):
            if sd is None:
                # Initialize the state dictionary with the first model's state dictionary
                sd = model.state_dict(keep_vars=True)
                forward_model = model
            else:
                # Add the current model's state dictionary to the accumulated state dictionary
                sd = state_dict_add(
                    sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
                )
    with self.profile("merge weights"):
        # Divide the accumulated state dictionary by the number of models to get the average
        sd = state_dict_div(
            sd, len(modelpool.model_names), show_pbar=self.show_pbar
        )

    if isinstance(forward_model, LazyStateDict):
        # if the model is a LazyStateDict, convert it to an empty module
        forward_model = forward_model.meta_module.to_empty(
            device=(
                "cpu"
                if forward_model._torch_dtype is None
                else forward_model._torch_dtype
            )
        )
    forward_model.load_state_dict(sd)
    # print profile report and log the merged models
    self.print_profile_summary()
    log.info(f"merged {len(merged_model_names)} models:")
    for model_name in merged_model_names:
        log.info(f"  - {model_name}")
    return forward_model
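A minimal usage sketch, assuming the import path implied by the source listing above (`fusion_bench.method.simple_average`); the toy `nn.Linear` modules stand in for real fine-tuned checkpoints, and a plain dict is accepted because `run()` wraps it in a `BaseModelPool`:

```python
import torch.nn as nn

# Import path assumed from the "Source code in ..." note above.
from fusion_bench.method.simple_average import SimpleAverageAlgorithm

# Two toy "fine-tuned" models sharing the same architecture.
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)

# run() accepts a plain dict and converts it into a BaseModelPool internally.
algorithm = SimpleAverageAlgorithm(show_pbar=False)
merged = algorithm.run({"model_a": model_a, "model_b": model_b})

# Each parameter of `merged` is the element-wise mean of the corresponding
# parameters of model_a and model_b.
```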

SimpleAverageForLlama

Bases: BaseAlgorithm

A simple averaging algorithm for LLama models. If merge_backbone is set to True, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.

Examples:

The following example demonstrates how to use the SimpleAverageForLlama algorithm to merge Mistral models.

fusion_bench \
    method=linear/simple_average_for_llama \
    method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
    modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
Source code in fusion_bench/method/linear/simple_average_for_llama.py
class SimpleAverageForLlama(BaseAlgorithm):
    R"""
    A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.

    Examples:
        The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.

        ```bash
        fusion_bench \
            method=linear/simple_average_for_llama \
            method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
            modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
        ```
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "merge_backbone": "merge_backbone",
        "show_pbar": "show_pbar",
    }

    def __init__(
        self,
        merge_backbone: bool,
        model_save_path: Optional[str] = None,
        show_pbar: bool = False,
    ):
        super().__init__()
        self.merge_backbone = merge_backbone
        self.model_save_path = model_save_path
        self.show_pbar = show_pbar

    @override
    def run(self, modelpool: CausalLMPool):
        if self.model_save_path:
            tokenizer = modelpool.load_tokenizer()

        if self.merge_backbone:
            assert modelpool.has_pretrained
            log.info(
                "Merging backbone of the model pool, use CausalLMBackbonePool instead of CausalLMPool."
            )
            modelpool_config = deepcopy(modelpool.config)
            with flag_override(modelpool_config, "allow_objects", True):
                modelpool_config._target_ = (
                    "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
                )
            backbone_modelpool = instantiate(modelpool_config)
            model = modelpool.load_model("_pretrained_")
            backbone_model = SimpleAverageAlgorithm(show_pbar=self.show_pbar).run(
                backbone_modelpool
            )
            model.model.layers = backbone_model
        else:
            model = SimpleAverageAlgorithm(show_pbar=self.show_pbar).run(
                modelpool=modelpool
            )

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model
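Besides the CLI invocation shown above, the same merge can be triggered programmatically. A hedged sketch, assuming the import path from the source listing and a `CausalLMPool` named `modelpool` that contains a `_pretrained_` entry (required when `merge_backbone=True`); the output path is hypothetical:

```python
from fusion_bench.method.linear.simple_average_for_llama import SimpleAverageForLlama

# merge_backbone=True averages only the transformer layers and keeps the
# remaining weights (embeddings, lm_head, ...) from the pre-trained model.
algorithm = SimpleAverageForLlama(
    merge_backbone=True,
    model_save_path="outputs/merged_model",  # hypothetical output directory
)
merged_model = algorithm.run(modelpool)  # modelpool: a CausalLMPool instance
```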

Weighted Average

LinearInterpolationAlgorithm

Bases: BaseAlgorithm

LinearInterpolationAlgorithm performs linear interpolation between two models. Returns a model whose state dict is a linear interpolation of the state dicts of the two models: $\theta = (1-t) \theta_1 + t \theta_2$.

Source code in fusion_bench/method/linear/linear_interpolation.py
class LinearInterpolationAlgorithm(BaseAlgorithm):
    R"""
    `LinearInterpolationAlgorithm` performs linear interpolation between two models.
    Returns a model with the state dict that is a linear interpolation of the state dicts of the two models.
    $\theta = (1-t) \theta_1 + t \theta_2$
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "t": "t",
    }

    def __init__(self, t: float, **kwargs):
        """
        Initialize the `LinearInterpolationAlgorithm` with the given interpolation parameter.

        Args:
            t (float): The interpolation parameter, should be in the range [0, 1].
            **kwargs: Additional keyword arguments.
        """
        assert 0 <= t <= 1, "t should be in the range [0, 1]"
        self.t = t
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Run the linear interpolation algorithm on the given model pool.

        This method performs linear interpolation between two models in the model pool
        and returns a model with the interpolated state dict.

        Args:
            modelpool (BaseModelPool): The pool of models to interpolate. Must contain exactly two models.

        Returns:
            nn.Module: The model with the interpolated state dict.
        """
        assert (
            len(modelpool.all_model_names) == 2
        ), "linear interpolation expects exactly 2 models"
        primary_model = modelpool.load_model(modelpool.all_model_names[0])
        secondary_model = modelpool.load_model(modelpool.all_model_names[1])

        with torch.no_grad():
            primary_state_dict = primary_model.state_dict()
            secondary_state_dict = secondary_model.state_dict()
            state_dict = state_dict_weighted_sum(
                [primary_state_dict, secondary_state_dict], [1 - self.t, self.t]
            )
        primary_model.load_state_dict(state_dict)
        return primary_model
__init__(t, **kwargs)

Initialize the LinearInterpolationAlgorithm with the given interpolation parameter.

Parameters:

  • t (float) –

    The interpolation parameter, should be in the range [0, 1].

  • **kwargs

    Additional keyword arguments.

Source code in fusion_bench/method/linear/linear_interpolation.py
def __init__(self, t: float, **kwargs):
    """
    Initialize the `LinearInterpolationAlgorithm` with the given interpolation parameter.

    Args:
        t (float): The interpolation parameter, should be in the range [0, 1].
        **kwargs: Additional keyword arguments.
    """
    assert 0 <= t <= 1, "t should be in the range [0, 1]"
    self.t = t
    super().__init__(**kwargs)
run(modelpool)

Run the linear interpolation algorithm on the given model pool.

This method performs linear interpolation between two models in the model pool and returns a model with the interpolated state dict.

Parameters:

  • modelpool (BaseModelPool) –

    The pool of models to interpolate. Must contain exactly two models.

Returns:

  • nn.Module: The model with the interpolated state dict.

Source code in fusion_bench/method/linear/linear_interpolation.py
def run(self, modelpool: BaseModelPool):
    """
    Run the linear interpolation algorithm on the given model pool.

    This method performs linear interpolation between two models in the model pool
    and returns a model with the interpolated state dict.

    Args:
        modelpool (BaseModelPool): The pool of models to interpolate. Must contain exactly two models.

    Returns:
        nn.Module: The model with the interpolated state dict.
    """
    assert (
        len(modelpool.all_model_names) == 2
    ), "linear interpolation expects exactly 2 models"
    primary_model = modelpool.load_model(modelpool.all_model_names[0])
    secondary_model = modelpool.load_model(modelpool.all_model_names[1])

    with torch.no_grad():
        primary_state_dict = primary_model.state_dict()
        secondary_state_dict = secondary_model.state_dict()
        state_dict = state_dict_weighted_sum(
            [primary_state_dict, secondary_state_dict], [1 - self.t, self.t]
        )
    primary_model.load_state_dict(state_dict)
    return primary_model
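The interpolation itself is ordinary state-dict arithmetic. A minimal sketch in plain PyTorch (independent of the fusion_bench helpers) implementing $\theta = (1-t) \theta_1 + t \theta_2$:

```python
import torch.nn as nn

def lerp_state_dicts(sd_1, sd_2, t: float):
    """Return a new state dict with (1 - t) * sd_1 + t * sd_2 for every tensor."""
    return {key: (1 - t) * sd_1[key] + t * sd_2[key] for key in sd_1}

model_1, model_2 = nn.Linear(4, 2), nn.Linear(4, 2)
merged = nn.Linear(4, 2)
merged.load_state_dict(lerp_state_dicts(model_1.state_dict(), model_2.state_dict(), t=0.3))
```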

WeightedAverageAlgorithm

Bases: BaseAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/weighted_average/weighted_average.py
class WeightedAverageAlgorithm(BaseAlgorithm, SimpleProfilerMixin):

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        verbose: bool = True,
        **kwargs,
    ):
        self.normalize = normalize
        self.weights = weights
        self.verbose = verbose
        log.disabled = not self.verbose
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Fuses the models in the model pool using a weighted average approach.

        Parameters
            modelpool (ModelPool): The pool of models to be fused.

        Raises
            ValueError: If the number of weights does not match the number of models in the model pool.

        Returns
            forward_model (torch.nn.Module): The resulting model after fusion.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info("Fusing models using weighted average.")
        weights = np.asarray(self.weights)
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.normalize:
            weights = weights / np.sum(weights)
        if self.verbose:
            print(f"weights: {weights}, normalized: {self.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None

        for model_name, weight in zip(modelpool.model_names, weights):
            with self.profile("load_model"):
                model = modelpool.load_model(model_name)
            with self.profile("merge weights"):
                if sd is None:
                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                    forward_model = model
                else:
                    sd = state_dict_add(
                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                    )

        forward_model.load_state_dict(sd)
        if self.verbose:
            self.print_profile_summary()
        return forward_model
run(modelpool)

Fuses the models in the model pool using a weighted average approach.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be fused.

Raises:

  • ValueError –

    If the number of weights does not match the number of models in the model pool.

Returns:

  • forward_model (torch.nn.Module) –

    The resulting model after fusion.

Source code in fusion_bench/method/weighted_average/weighted_average.py
@override
@torch.no_grad()
def run(self, modelpool: BaseModelPool):
    """
    Fuses the models in the model pool using a weighted average approach.

    Parameters
        modelpool (ModelPool): The pool of models to be fused.

    Raises
        ValueError: If the number of weights does not match the number of models in the model pool.

    Returns
        forward_model (torch.nn.Module): The resulting model after fusion.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    log.info("Fusing models using weighted average.")
    weights = np.asarray(self.weights)
    if len(weights) != len(modelpool.model_names):
        raise ValueError(
            "Number of weights must match the number of models.,"
            f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
            f"weights: {weights}, models: {modelpool.model_names}"
        )
    if self.normalize:
        weights = weights / np.sum(weights)
    if self.verbose:
        print(f"weights: {weights}, normalized: {self.normalize}")

    sd: Optional[StateDictType] = None
    forward_model = None

    for model_name, weight in zip(modelpool.model_names, weights):
        with self.profile("load_model"):
            model = modelpool.load_model(model_name)
        with self.profile("merge weights"):
            if sd is None:
                sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                forward_model = model
            else:
                sd = state_dict_add(
                    sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                )

    forward_model.load_state_dict(sd)
    if self.verbose:
        self.print_profile_summary()
    return forward_model
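Conceptually the merge is a convex combination of state dicts. A plain-PyTorch sketch of the same computation, with optional weight normalization mirroring the `normalize` flag:

```python
import torch.nn as nn

def weighted_average_state_dicts(state_dicts, weights, normalize=True):
    """Return sum_i w_i * theta_i over a list of state dicts."""
    if normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    merged = {}
    for key in state_dicts[0]:
        merged[key] = sum(w * sd[key] for w, sd in zip(weights, state_dicts))
    return merged

models = [nn.Linear(4, 2) for _ in range(3)]
merged_sd = weighted_average_state_dicts(
    [m.state_dict() for m in models], weights=[1.0, 2.0, 1.0]
)
models[0].load_state_dict(merged_sd)
```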

WeightedAverageForLLama

Bases: BaseAlgorithm

A class to perform weighted averaging of LlaMa/Mistral models.

Source code in fusion_bench/method/weighted_average/llama.py
class WeightedAverageForLLama(BaseAlgorithm):
    """
    A class to perform weighted averaging of LlaMa/Mistral models.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
        "backbone_only": "backbone_only",
        "merged_model_save_path": "merged_model_save_path",
        "save_tokenizer": "save_tokenizer",
        "push_to_hub": "push_to_hub",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        backbone_only: bool,
        merged_model_save_path: str,
        save_tokenizer: bool,
        push_to_hub: bool,
        **kwargs,
    ):
        """
        Initialize the WeightedAverageForLLama class with the given parameters.

        Args:
            normalize (bool): Whether to normalize the weights.
            weights (List[float]): The weights for averaging the models.
            backbone_only (bool): Whether to use only the backbone of the models.
            merged_model_save_path (str): The path to save the merged model.
            save_tokenizer (bool): Whether to save the tokenizer.
            push_to_hub (bool): Whether to push the model to the hub.
        """
        self.normalize = normalize
        self.weights = weights
        self.backbone_only = backbone_only
        self.merged_model_save_path = merged_model_save_path
        self.save_tokenizer = save_tokenizer
        self.push_to_hub = push_to_hub
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: CausalLMPool):
        """
        Executes the weighted averaging of models in the provided model pool.

        Args:
            modelpool (CausalLMPool): The pool of models to be averaged.

        Returns:
            base_model: The base model after merging the state dictionaries of the models in the pool.

        Raises:
            ValueError: If the number of weights does not match the number of models in the pool.
        """
        if modelpool.has_pretrained:
            base_model = modelpool.load_model("_pretrained_")
        else:
            base_model = modelpool.load_model(modelpool.model_names[0])

        weights = self.weights
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.normalize:
            weights = np.asarray(weights)
            weights = weights / np.sum(weights)

        merged_state_dict: StateDictType = None
        for model_name, weight in zip(modelpool.model_names, weights):
            model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
            sd = state_dict_mul(model.state_dict(), weight)
            if merged_state_dict is None:
                merged_state_dict = sd
            else:
                merged_state_dict = state_dict_add(merged_state_dict, sd)

        base_model.load_state_dict(
            merged_state_dict, strict=False if self.backbone_only else True
        )
        if self.merged_model_save_path is not None:
            with timeit_context(
                f"Saving the merged model to {self.merged_model_save_path}"
            ):
                modelpool.save_model(
                    base_model,
                    path=self.merged_model_save_path,
                    save_tokenizer=self.save_tokenizer,
                    push_to_hub=self.push_to_hub,
                )
        return base_model
__init__(normalize, weights, backbone_only, merged_model_save_path, save_tokenizer, push_to_hub, **kwargs)

Initialize the WeightedAverageForLLama class with the given parameters.

Parameters:

  • normalize (bool) –

    Whether to normalize the weights.

  • weights (List[float]) –

    The weights for averaging the models.

  • backbone_only (bool) –

    Whether to use only the backbone of the models.

  • merged_model_save_path (str) –

    The path to save the merged model.

  • save_tokenizer (bool) –

    Whether to save the tokenizer.

  • push_to_hub (bool) –

    Whether to push the model to the hub.

Source code in fusion_bench/method/weighted_average/llama.py
def __init__(
    self,
    normalize: bool,
    weights: List[float],
    backbone_only: bool,
    merged_model_save_path: str,
    save_tokenizer: bool,
    push_to_hub: bool,
    **kwargs,
):
    """
    Initialize the WeightedAverageForLLama class with the given parameters.

    Args:
        normalize (bool): Whether to normalize the weights.
        weights (List[float]): The weights for averaging the models.
        backbone_only (bool): Whether to use only the backbone of the models.
        merged_model_save_path (str): The path to save the merged model.
        save_tokenizer (bool): Whether to save the tokenizer.
        push_to_hub (bool): Whether to push the model to the hub.
    """
    self.normalize = normalize
    self.weights = weights
    self.backbone_only = backbone_only
    self.merged_model_save_path = merged_model_save_path
    self.save_tokenizer = save_tokenizer
    self.push_to_hub = push_to_hub
    super().__init__(**kwargs)
run(modelpool)

Executes the weighted averaging of models in the provided model pool.

Parameters:

  • modelpool (CausalLMPool) –

    The pool of models to be averaged.

Returns:

  • base_model

    The base model after merging the state dictionaries of the models in the pool.

Raises:

  • ValueError

    If the number of weights does not match the number of models in the pool.

Source code in fusion_bench/method/weighted_average/llama.py
@override
@torch.no_grad()
def run(self, modelpool: CausalLMPool):
    """
    Executes the weighted averaging of models in the provided model pool.

    Args:
        modelpool (CausalLMPool): The pool of models to be averaged.

    Returns:
        base_model: The base model after merging the state dictionaries of the models in the pool.

    Raises:
        ValueError: If the number of weights does not match the number of models in the pool.
    """
    if modelpool.has_pretrained:
        base_model = modelpool.load_model("_pretrained_")
    else:
        base_model = modelpool.load_model(modelpool.model_names[0])

    weights = self.weights
    if len(weights) != len(modelpool.model_names):
        raise ValueError(
            "Number of weights must match the number of models.,"
            f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
            f"weights: {weights}, models: {modelpool.model_names}"
        )
    if self.normalize:
        weights = np.asarray(weights)
        weights = weights / np.sum(weights)

    merged_state_dict: StateDictType = None
    for model_name, weight in zip(modelpool.model_names, weights):
        model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
        sd = state_dict_mul(model.state_dict(), weight)
        if merged_state_dict is None:
            merged_state_dict = sd
        else:
            merged_state_dict = state_dict_add(merged_state_dict, sd)

    base_model.load_state_dict(
        merged_state_dict, strict=False if self.backbone_only else True
    )
    if self.merged_model_save_path is not None:
        with timeit_context(
            f"Saving the merged model to {self.merged_model_save_path}"
        ):
            modelpool.save_model(
                base_model,
                path=self.merged_model_save_path,
                save_tokenizer=self.save_tokenizer,
                push_to_hub=self.push_to_hub,
            )
    return base_model
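A hedged usage sketch for the class above, assuming the import path from the source listing and an existing `CausalLMPool` named `modelpool`; with `backbone_only=True` the merged state dict is loaded with `strict=False`, so non-backbone parameters come from the base model:

```python
from fusion_bench.method.weighted_average.llama import WeightedAverageForLLama

algorithm = WeightedAverageForLLama(
    normalize=True,               # rescale weights so they sum to 1
    weights=[0.7, 0.3],           # one weight per model in the pool
    backbone_only=True,           # average the transformer backbone only
    merged_model_save_path=None,  # skip saving in this sketch
    save_tokenizer=False,
    push_to_hub=False,
)
merged_model = algorithm.run(modelpool)  # modelpool: a CausalLMPool instance
```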

Spherical Linear Interpolation (Slerp)

SlerpMergeAlgorithm

Bases: BaseAlgorithm

General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.

Source code in fusion_bench/method/slerp/slerp.py
class SlerpMergeAlgorithm(BaseAlgorithm):
    """
    General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "t": "t",
        "DOT_THRESHOLD": "DOT_THRESHOLD",
        "epsilon": "epsilon",
    }

    def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
        """
        Initialize the SlerpMergeAlgorithm.

        Args:
            t (float): The interpolation parameter. Must be in the range [0, 1].
            DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
            epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
        """
        self.t = t
        self.DOT_THRESHOLD = DOT_THRESHOLD
        self.epsilon = epsilon
        super().__init__()

    @override
    def run(self, modelpool: BaseModelPool):
        """
        Run the SlerpMergeAlgorithm on the given model pool.

        Args:
            modelpool (BaseModelPool): The pool of models to fuse.

        Returns:
            nn.Module: The fused model.
        """
        assert len(modelpool.all_model_names) == 2, "Slerp expects exactly 2 models"
        primary_model = modelpool.load_model(modelpool.all_model_names[0])
        secondary_model = modelpool.load_model(modelpool.all_model_names[1])

        with torch.no_grad():
            primary_state_dict = primary_model.state_dict()
            secondary_state_dict = secondary_model.state_dict()
            state_dict = slerp_on_state_dicts(
                self.t,
                primary_state_dict,
                secondary_state_dict,
                DOT_THRESHOLD=self.DOT_THRESHOLD,
                epsilon=self.epsilon,
            )

        primary_model.load_state_dict(state_dict)
        return primary_model
__init__(t, DOT_THRESHOLD=0.9995, epsilon=1e-08)

Initialize the SlerpMergeAlgorithm.

Parameters:

  • t (float) –

    The interpolation parameter. Must be in the range [0, 1].

  • DOT_THRESHOLD (float, default: 0.9995 ) –

    The threshold for the dot product of the two vectors. Defaults to 0.9995.

  • epsilon (float, default: 1e-08 ) –

    The epsilon value for numerical stability. Defaults to 1e-8.

Source code in fusion_bench/method/slerp/slerp.py
def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
    """
    Initialize the SlerpMergeAlgorithm.

    Args:
        t (float): The interpolation parameter. Must be in the range [0, 1].
        DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
        epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
    """
    self.t = t
    self.DOT_THRESHOLD = DOT_THRESHOLD
    self.epsilon = epsilon
    super().__init__()
run(modelpool)

Run the SlerpMergeAlgorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool) –

    The pool of models to fuse.

Returns:

  • nn.Module: The fused model.

Source code in fusion_bench/method/slerp/slerp.py
@override
def run(self, modelpool: BaseModelPool):
    """
    Run the SlerpMergeAlgorithm on the given model pool.

    Args:
        modelpool (BaseModelPool): The pool of models to fuse.

    Returns:
        nn.Module: The fused model.
    """
    assert len(modelpool.all_model_names) == 2, "Slerp expects exactly 2 models"
    primary_model = modelpool.load_model(modelpool.all_model_names[0])
    secondary_model = modelpool.load_model(modelpool.all_model_names[1])

    with torch.no_grad():
        primary_state_dict = primary_model.state_dict()
        secondary_state_dict = secondary_model.state_dict()
        state_dict = slerp_on_state_dicts(
            self.t,
            primary_state_dict,
            secondary_state_dict,
            DOT_THRESHOLD=self.DOT_THRESHOLD,
            epsilon=self.epsilon,
        )

    primary_model.load_state_dict(state_dict)
    return primary_model
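The core of the algorithm is spherical interpolation applied to each pair of flattened tensors, with a fallback to plain linear interpolation when the two directions are nearly parallel (the role of `DOT_THRESHOLD`). A minimal sketch of that formula on two vectors, written independently of `slerp_on_state_dicts`:

```python
import torch

def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor,
          DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8) -> torch.Tensor:
    """Spherical linear interpolation between two flattened tensors."""
    v0_n = v0 / (v0.norm() + epsilon)
    v1_n = v1 / (v1.norm() + epsilon)
    dot = torch.sum(v0_n * v1_n)
    if dot.abs() > DOT_THRESHOLD:
        # Nearly collinear directions: fall back to linear interpolation.
        return (1 - t) * v0 + t * v1
    omega = torch.acos(dot)        # angle between the two directions
    sin_omega = torch.sin(omega)
    return (torch.sin((1 - t) * omega) / sin_omega) * v0 + \
           (torch.sin(t * omega) / sin_omega) * v1

merged = slerp(0.5, torch.randn(10), torch.randn(10))
```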

Task Arithmetic

TaskArithmeticForLlama

Bases: TaskArithmeticAlgorithm, SimpleProfilerMixin

Examples:

fusion_bench \
    method=linear/task_arithmetic_for_llama \
    method.scaling_factor=0.3 \
    method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
    modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml

Source code in fusion_bench/method/linear/task_arithmetic_for_llama.py
class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
    R"""
    Examples:

    fusion_bench \
        method=linear/task_arithmetic_for_llama \
        method.scaling_factor=0.3 \
        method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
        modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
    """

    _config_mapping = TaskArithmeticAlgorithm._config_mapping | {
        "merge_backbone": "merge_backbone",
    }

    def __init__(
        self,
        scaling_factor: float,
        merge_backbone: bool,
        model_save_path: Optional[str] = None,
    ):
        self.merge_backbone = merge_backbone
        self.model_save_path = model_save_path
        super().__init__(scaling_factor=scaling_factor)

    @override
    def run(self, modelpool: CausalLMPool):
        if self.model_save_path:
            tokenizer = modelpool.load_tokenizer()

        if self.merge_backbone:
            assert modelpool.has_pretrained
            backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
            model = modelpool.load_model("_pretrained_")
            backbone_model = super().run(backbone_modelpool)
            model.model.layers = backbone_model
        else:
            model = super().run(modelpool)

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model
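Task arithmetic builds the merged model as the pre-trained weights plus a scaled sum of task vectors, $\theta = \theta_{pre} + \lambda \sum_i (\theta_i - \theta_{pre})$. A minimal plain-PyTorch sketch of that update (independent of `TaskArithmeticAlgorithm`):

```python
import torch.nn as nn

def task_arithmetic(pretrained_sd, finetuned_sds, scaling_factor: float):
    """theta = theta_pre + lambda * sum_i (theta_i - theta_pre)."""
    merged = {}
    for key, base in pretrained_sd.items():
        task_vector_sum = sum(sd[key] - base for sd in finetuned_sds)
        merged[key] = base + scaling_factor * task_vector_sum
    return merged

pretrained = nn.Linear(4, 2)
finetuned = [nn.Linear(4, 2) for _ in range(2)]
merged_sd = task_arithmetic(
    pretrained.state_dict(),
    [m.state_dict() for m in finetuned],
    scaling_factor=0.3,
)
pretrained.load_state_dict(merged_sd)
```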

Ties-Merging

TiesMergingAlgorithm

Bases: BaseAlgorithm, SimpleProfilerMixin

TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.

Attributes:

  • scaling_factor (float) –

    The scaling factor to apply to the merged task vector.

  • threshold (float) –

    The threshold for resetting values in the task vector.

  • remove_keys (List[str]) –

    List of keys to remove from the state dictionary.

  • merge_func (Literal['sum', 'mean', 'max']) –

    The merge function to use for disjoint merging.

Source code in fusion_bench/method/ties_merging/ties_merging.py
class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
    """
    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.

    Attributes:
        scaling_factor (float): The scaling factor to apply to the merged task vector.
        threshold (float): The threshold for resetting values in the task vector.
        remove_keys (List[str]): List of keys to remove from the state dictionary.
        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "scaling_factor": "scaling_factor",
        "threshold": "threshold",
        "remove_keys": "remove_keys",
        "merge_func": "merge_func",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: List[str],
        merge_func: Literal["sum", "mean", "max"],
        **kwargs,
    ):
        """
        Initialize the TiesMergingAlgorithm with the given parameters.

        Args:
            scaling_factor (float): The scaling factor to apply to the merged task vector.
            threshold (float): The threshold for resetting values in the task vector.
            remove_keys (List[str]): List of keys to remove from the state dictionary.
            merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
            **kwargs: Additional keyword arguments for the base class.
        """
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys
        self.merge_func = merge_func
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
        """
        Run the TIES merging algorithm to fuse models in the model pool.

        Args:
            modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the models to fuse.

        Returns:
            nn.Module: The fused model.
        """
        log.info("Fusing models using ties merging.")
        modelpool = to_modelpool(modelpool)
        remove_keys = self.config.get("remove_keys", [])
        merge_func = self.config.get("merge_func", "sum")
        scaling_factor = self.scaling_factor
        threshold = self.threshold

        with self.profile("loading models"):
            # Load the pretrained model
            pretrained_model = modelpool.load_model("_pretrained_")

            # Load the state dicts of the models
            ft_checks: List[StateDictType] = [
                modelpool.load_model(model_name).state_dict(keep_vars=True)
                for model_name in modelpool.model_names
            ]
            ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)

        with self.profile("merging models"):
            # Compute the task vectors
            flat_ft: Tensor = torch.vstack(
                [state_dict_to_vector(check, remove_keys) for check in ft_checks]
            )
            flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
            tv_flat_checks = flat_ft - flat_ptm

            # Perform TIES Merging
            merged_tv = ties_merging(
                tv_flat_checks,
                reset_thresh=threshold,
                merge_func=merge_func,
            )
            merged_check = flat_ptm + scaling_factor * merged_tv
            merged_state_dict = vector_to_state_dict(
                merged_check, ptm_check, remove_keys=remove_keys
            )

            # Load the merged state dict into the pretrained model
            pretrained_model.load_state_dict(merged_state_dict)

        self.print_profile_summary()
        return pretrained_model
__init__(scaling_factor, threshold, remove_keys, merge_func, **kwargs)

Initialize the TiesMergingAlgorithm with the given parameters.

Parameters:

  • scaling_factor (float) –

    The scaling factor to apply to the merged task vector.

  • threshold (float) –

    The threshold for resetting values in the task vector.

  • remove_keys (List[str]) –

    List of keys to remove from the state dictionary.

  • merge_func (Literal['sum', 'mean', 'max']) –

    The merge function to use for disjoint merging.

  • **kwargs

    Additional keyword arguments for the base class.

Source code in fusion_bench/method/ties_merging/ties_merging.py
def __init__(
    self,
    scaling_factor: float,
    threshold: float,
    remove_keys: List[str],
    merge_func: Literal["sum", "mean", "max"],
    **kwargs,
):
    """
    Initialize the TiesMergingAlgorithm with the given parameters.

    Args:
        scaling_factor (float): The scaling factor to apply to the merged task vector.
        threshold (float): The threshold for resetting values in the task vector.
        remove_keys (List[str]): List of keys to remove from the state dictionary.
        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
        **kwargs: Additional keyword arguments for the base class.
    """
    self.scaling_factor = scaling_factor
    self.threshold = threshold
    self.remove_keys = remove_keys
    self.merge_func = merge_func
    super().__init__(**kwargs)
run(modelpool, **kwargs)

Run the TIES merging algorithm to fuse models in the model pool.

Parameters:

  • modelpool (BaseModelPool | Dict[str, Module]) –

    The model pool containing the models to fuse.

Returns:

  • nn.Module: The fused model.

Source code in fusion_bench/method/ties_merging/ties_merging.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
    """
    Run the TIES merging algorithm to fuse models in the model pool.

    Args:
        modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the models to fuse.

    Returns:
        nn.Module: The fused model.
    """
    log.info("Fusing models using ties merging.")
    modelpool = to_modelpool(modelpool)
    remove_keys = self.config.get("remove_keys", [])
    merge_func = self.config.get("merge_func", "sum")
    scaling_factor = self.scaling_factor
    threshold = self.threshold

    with self.profile("loading models"):
        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # Load the state dicts of the models
        ft_checks: List[StateDictType] = [
            modelpool.load_model(model_name).state_dict(keep_vars=True)
            for model_name in modelpool.model_names
        ]
        ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)

    with self.profile("merging models"):
        # Compute the task vectors
        flat_ft: Tensor = torch.vstack(
            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
        )
        flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
        tv_flat_checks = flat_ft - flat_ptm

        # Perform TIES Merging
        merged_tv = ties_merging(
            tv_flat_checks,
            reset_thresh=threshold,
            merge_func=merge_func,
        )
        merged_check = flat_ptm + scaling_factor * merged_tv
        merged_state_dict = vector_to_state_dict(
            merged_check, ptm_check, remove_keys=remove_keys
        )

        # Load the merged state dict into the pretrained model
        pretrained_model.load_state_dict(merged_state_dict)

    self.print_profile_summary()
    return pretrained_model
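TIES merging operates on the stacked, flattened task vectors in three steps: trim small-magnitude entries, elect a per-coordinate sign, and merge only the entries that agree with the elected sign (`merge_func`). A simplified sketch of these steps, not the library's exact `ties_merging` routine; it assumes `reset_thresh` is interpreted as the percentage of largest-magnitude entries kept per task vector:

```python
import torch

def ties_merge(task_vectors: torch.Tensor, k: float = 20.0, merge_func: str = "sum"):
    """task_vectors: (num_models, num_params); k: percent of top-magnitude entries kept."""
    num_params = task_vectors.shape[1]
    num_keep = max(1, int(num_params * k / 100))

    # 1) Trim: zero out all but the top-k% largest-magnitude entries per task vector.
    magnitudes = task_vectors.abs()
    kth_smallest = magnitudes.kthvalue(num_params - num_keep + 1, dim=1, keepdim=True).values
    trimmed = task_vectors * (magnitudes >= kth_smallest)

    # 2) Elect a per-coordinate sign from the total trimmed mass.
    elected_sign = torch.sign(trimmed.sum(dim=0))

    # 3) Disjoint merge: aggregate only entries whose sign agrees with the elected sign.
    agree = torch.sign(trimmed) == elected_sign
    selected = trimmed * agree
    if merge_func == "mean":
        counts = agree.sum(dim=0).clamp(min=1)
        return selected.sum(dim=0) / counts
    return selected.sum(dim=0)

merged_tv = ties_merge(torch.randn(3, 100), k=20.0, merge_func="sum")
```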

Fisher Merging

FisherMergingForCLIPVisionModel

Bases: CLIPClassificationMixin, FisherMergingAlgorithm

Implements Fisher Merging for CLIP Vision Models.

This class extends the FisherMergingAlgorithm and CLIPClassificationMixin to handle the specifics of merging CLIP Vision models using Fisher weights.

Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
class FisherMergingForCLIPVisionModel(
    CLIPClassificationMixin,
    FisherMergingAlgorithm,
):
    """
    Implements Fisher Merging for CLIP Vision Models.

    This class extends the FisherMergingAlgorithm and CLIPClassificationMixin to handle
    the specifics of merging CLIP Vision models using Fisher weights.
    """

    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    _config_mapping = FisherMergingAlgorithm._config_mapping | {
        "zeroshot_weights_cache_dir": "zeroshot_weights_cache_dir",
        "_dataloader_kwargs": "dataloader_kwargs",
    }

    def __init__(
        self,
        *,
        exclude_param_names_regex,
        normalize_fisher_weight,
        minimal_fisher_weight,
        num_fisher_examples,
        dataloader_kwargs: DictConfig,
        zeroshot_weights_cache_dir=None,
        **kwargs,
    ):
        """
        Initialize the FisherMergingForCLIPVisionModel with the given configuration.

        Args:
            exclude_param_names_regex (list): List of regex patterns to exclude certain parameter names.
            normalize_fisher_weight (bool): Whether to normalize Fisher weights.
            minimal_fisher_weight (float): Minimal value for Fisher weights to avoid numerical issues.
            num_fisher_examples (int): Number of examples to compute Fisher weights.
            dataloader_kwargs (DictConfig): Configuration for the dataloader.
            zeroshot_weights_cache_dir (str, optional): Directory to cache zero-shot weights. Defaults to None.
            **kwargs: Additional keyword arguments.
        """
        super().__init__(
            exclude_param_names_regex=exclude_param_names_regex,
            normalize_fisher_weight=normalize_fisher_weight,
            minimal_fisher_weight=minimal_fisher_weight,
            num_fisher_examples=num_fisher_examples,
        )
        self._dataloader_kwargs = dataloader_kwargs
        self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
        for key, value in kwargs.items():
            log.warning(f"Unused argument: {key}={value}")
            setattr(self, key, value)

    def on_fisher_merging_start(self):
        """
        Setup the zero-shot classification head before starting the Fisher merging process.
        """
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module (Module): The model module.
            batch (tuple): A batch of data containing images and labels.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def get_fisher_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        """
        Compute the Fisher weights for the given model and training dataset.

        Args:
            model_name (str): The name of the model.
            model (Module): The model module.
            train_dataset: The training dataset.
            param_names_to_merge (List[str]): List of parameter names to merge.

        Returns:
            Dict[str, Tensor]: The computed Fisher weights for each parameter.
        """
        # setup dataloader
        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
        train_dataloader = DataLoader(train_dataset, **self._dataloader_kwargs)
        if self.fabric is not None:
            train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
            model = self.fabric.setup(model)
        num_fisher_examples = self.config.num_fisher_examples
        if num_fisher_examples % train_dataloader.batch_size != 0:
            print(
                "warning: the number of examples for computing fisher cannot be fully divided by the batch size for model, "
                "which may lead to a slightly different number of the actually used examples."
            )
        num_computed_examples = 0
        batches_fisher_weights_list = []
        for step, batch in tqdm(
            enumerate(train_dataloader),
            desc="computing fisher weights",
            total=num_fisher_examples // train_dataloader.batch_size,
        ):
            if num_computed_examples >= num_fisher_examples:
                break
            logits = self.compute_logits(model, batch, model_name)
            # Tensor, shape (batch_size, num_label_classes)

            # compute fisher weights for classification task
            # use detach() to detach from the computation graph
            # Tensor, shape (batch_size, num_label_classes)
            labels_probabilities = torch.softmax(logits, dim=-1).detach()
            labels_log_probabilities = torch.log_softmax(logits, dim=-1)
            # sqrt labels_probabilities, since torch.sqrt(labels_probabilities) would be squared in the following squared gradients
            labels_expectations = (
                torch.sqrt(labels_probabilities) * labels_log_probabilities
            )
            # sum over label classes and batch dimension
            sum_labels_expectations = labels_expectations.sum(dim=-1).sum(dim=0)
            model.zero_grad()
            sum_labels_expectations.backward()
            # dict, fisher weights of a batch
            batch_fisher_weights = get_param_squared_gradients(
                model=model, param_names_to_merge=param_names_to_merge
            )

            # move fisher weights to cpu to save GPU memory
            for key, weights in batch_fisher_weights.items():
                batch_fisher_weights[key] = weights.detach().cpu()

            batches_fisher_weights_list.append(batch_fisher_weights)
            num_computed_examples += batch[0].size(0)

        model_to_merge_fisher_weights = {}
        for batch_fisher_weights in batches_fisher_weights_list:
            for key in batch_fisher_weights:
                if key not in model_to_merge_fisher_weights:
                    model_to_merge_fisher_weights[key] = batch_fisher_weights[key]
                else:
                    model_to_merge_fisher_weights[key] += batch_fisher_weights[key]

        # mean over batches
        for key in model_to_merge_fisher_weights:
            model_to_merge_fisher_weights[key] /= num_computed_examples
            model_to_merge_fisher_weights[key] = (
                model_to_merge_fisher_weights[key].detach().cpu()
            )
        return model_to_merge_fisher_weights
__init__(*, exclude_param_names_regex, normalize_fisher_weight, minimal_fisher_weight, num_fisher_examples, dataloader_kwargs, zeroshot_weights_cache_dir=None, **kwargs)

Initialize the FisherMergingForCLIPVisionModel with the given configuration.

Parameters:

  • exclude_param_names_regex (list) –

    List of regex patterns to exclude certain parameter names.

  • normalize_fisher_weight (bool) –

    Whether to normalize Fisher weights.

  • minimal_fisher_weight (float) –

    Minimal value for Fisher weights to avoid numerical issues.

  • num_fisher_examples (int) –

    Number of examples to compute Fisher weights.

  • dataloader_kwargs (DictConfig) –

    Configuration for the dataloader.

  • zeroshot_weights_cache_dir (str, default: None ) –

    Directory to cache zero-shot weights. Defaults to None.

  • **kwargs

    Additional keyword arguments.

Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
def __init__(
    self,
    *,
    exclude_param_names_regex,
    normalize_fisher_weight,
    minimal_fisher_weight,
    num_fisher_examples,
    dataloader_kwargs: DictConfig,
    zeroshot_weights_cache_dir=None,
    **kwargs,
):
    """
    Initialize the FisherMergingForCLIPVisionModel with the given configuration.

    Args:
        exclude_param_names_regex (list): List of regex patterns to exclude certain parameter names.
        normalize_fisher_weight (bool): Whether to normalize Fisher weights.
        minimal_fisher_weight (float): Minimal value for Fisher weights to avoid numerical issues.
        num_fisher_examples (int): Number of examples to compute Fisher weights.
        dataloader_kwargs (DictConfig): Configuration for the dataloader.
        zeroshot_weights_cache_dir (str, optional): Directory to cache zero-shot weights. Defaults to None.
        **kwargs: Additional keyword arguments.
    """
    super().__init__(
        exclude_param_names_regex=exclude_param_names_regex,
        normalize_fisher_weight=normalize_fisher_weight,
        minimal_fisher_weight=minimal_fisher_weight,
        num_fisher_examples=num_fisher_examples,
    )
    self._dataloader_kwargs = dataloader_kwargs
    self.zeroshot_weights_cache_dir = zeroshot_weights_cache_dir
    for key, value in kwargs.items():
        log.warning(f"Unused argument: {key}={value}")
        setattr(self, key, value)
compute_logits(module, batch, task)

Compute the logits for the given images and task.

Parameters:

  • module (Module) –

    The model module.

  • batch (tuple) –

    A batch of data containing images and labels.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
def compute_logits(self, module, batch, task: str) -> Tensor:
    """
    Compute the logits for the given images and task.

    Args:
        module (Module): The model module.
        batch (tuple): A batch of data containing images and labels.
        task (str): The name of the task.

    Returns:
        Tensor: The computed logits.
    """
    images, _ = batch
    text_embeds = self.zeroshot_weights[task]

    image_embeds = module(images)[1]
    image_embeds = self.visual_projection(image_embeds)

    # normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity
    logits_per_text = (
        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
    )
    logits_per_image = logits_per_text.t()

    return logits_per_image
get_fisher_weights(model_name, model, train_dataset, param_names_to_merge)

Compute the Fisher weights for the given model and training dataset.

Parameters:

  • model_name (str) –

    The name of the model.

  • model (Module) –

    The model module.

  • train_dataset

    The training dataset.

  • param_names_to_merge (List[str]) –

    List of parameter names to merge.

Returns:

  • Dict[str, Tensor]

    Dict[str, Tensor]: The computed Fisher weights for each parameter.

Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
def get_fisher_weights(
    self,
    model_name: str,
    model: Module,
    train_dataset,
    param_names_to_merge: List[str],
) -> Dict[str, Tensor]:
    """
    Compute the Fisher weights for the given model and training dataset.

    Args:
        model_name (str): The name of the model.
        model (Module): The model module.
        train_dataset: The training dataset.
        param_names_to_merge (List[str]): List of parameter names to merge.

    Returns:
        Dict[str, Tensor]: The computed Fisher weights for each parameter.
    """
    # setup dataloader
    train_dataset = CLIPDataset(train_dataset, self.clip_processor)
    train_dataloader = DataLoader(train_dataset, **self._dataloader_kwargs)
    if self.fabric is not None:
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)
    num_fisher_examples = self.config.num_fisher_examples
    if num_fisher_examples % train_dataloader.batch_size != 0:
        print(
            "warning: the number of examples for computing fisher cannot be fully divided by the batch size for model, "
            "which may lead to a slightly different number of the actually used examples."
        )
    num_computed_examples = 0
    batches_fisher_weights_list = []
    for step, batch in tqdm(
        enumerate(train_dataloader),
        desc="computing fisher weights",
        total=num_fisher_examples // train_dataloader.batch_size,
    ):
        if num_computed_examples >= num_fisher_examples:
            break
        logits = self.compute_logits(model, batch, model_name)
        # Tensor, shape (batch_size, num_label_classes)

        # compute fisher weights for classification task
        # use detach() to detach from the computation graph
        # Tensor, shape (batch_size, num_label_classes)
        labels_probabilities = torch.softmax(logits, dim=-1).detach()
        labels_log_probabilities = torch.log_softmax(logits, dim=-1)
        # sqrt labels_probabilities, since torch.sqrt(labels_probabilities) would be squared in the following squared gradients
        labels_expectations = (
            torch.sqrt(labels_probabilities) * labels_log_probabilities
        )
        # sum over label classes and batch dimension
        sum_labels_expectations = labels_expectations.sum(dim=-1).sum(dim=0)
        model.zero_grad()
        sum_labels_expectations.backward()
        # dict, fisher weights of a batch
        batch_fisher_weights = get_param_squared_gradients(
            model=model, param_names_to_merge=param_names_to_merge
        )

        # move fisher weights to cpu to save GPU memory
        for key, weights in batch_fisher_weights.items():
            batch_fisher_weights[key] = weights.detach().cpu()

        batches_fisher_weights_list.append(batch_fisher_weights)
        num_computed_examples += batch[0].size(0)

    model_to_merge_fisher_weights = {}
    for batch_fisher_weights in batches_fisher_weights_list:
        for key in batch_fisher_weights:
            if key not in model_to_merge_fisher_weights:
                model_to_merge_fisher_weights[key] = batch_fisher_weights[key]
            else:
                model_to_merge_fisher_weights[key] += batch_fisher_weights[key]

    # mean over batches
    for key in model_to_merge_fisher_weights:
        model_to_merge_fisher_weights[key] /= num_computed_examples
        model_to_merge_fisher_weights[key] = (
            model_to_merge_fisher_weights[key].detach().cpu()
        )
    return model_to_merge_fisher_weights
on_fisher_merging_start()

Setup the zero-shot classification head before starting the Fisher merging process.

Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
def on_fisher_merging_start(self):
    """
    Setup the zero-shot classification head before starting the Fisher merging process.
    """
    self.setup_zero_shot_classification_head()
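Once per-parameter Fisher weights are available (the accumulated squared gradients computed above), the merge itself is a Fisher-weighted average of the models' parameters, $\theta^* = \sum_i F_i \odot \theta_i / \sum_i F_i$. A simplified sketch of that final step, not the library's exact implementation, using `minimal_fisher_weight` as a numerical floor:

```python
import torch

def fisher_weighted_average(state_dicts, fisher_weights, minimal_fisher_weight: float = 1e-6):
    """theta*[p] = sum_i F_i[p] * theta_i[p] / sum_i F_i[p], element-wise per parameter."""
    merged = {}
    for key in state_dicts[0]:
        fishers = [f[key] + minimal_fisher_weight for f in fisher_weights]
        numerator = sum(f * sd[key] for f, sd in zip(fishers, state_dicts))
        denominator = sum(fishers)
        merged[key] = numerator / denominator
    return merged

# Toy example: two "models" with a single 2x2 weight and equal Fisher estimates.
sds = [{"w": torch.ones(2, 2)}, {"w": 3 * torch.ones(2, 2)}]
fishers = [{"w": torch.ones(2, 2)}, {"w": torch.ones(2, 2)}]
merged = fisher_weighted_average(sds, fishers)  # every entry is 2.0
```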

Drop And REscale (DARE)

DareSimpleAverage

Bases: BaseAlgorithm

Source code in fusion_bench/method/dare/simple_average.py
class DareSimpleAverage(BaseAlgorithm):

    def __init__(
        self,
        sparsity_ratio: float,
        only_on_linear_weights: bool,
        rescale: bool = True,
        **kwargs,
    ):
        self.sparsity_ratio = sparsity_ratio
        self.only_on_linear_weight = only_on_linear_weights
        self.rescale = rescale
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        return DareTaskArithmetic(
            scaling_factor=1 / len(modelpool),
            sparsity_ratio=self.sparsity_ratio,
            only_on_linear_weights=self.only_on_linear_weight,
            rescale=self.rescale,
        ).run(modelpool)

DareTaskArithmetic

Bases: BaseAlgorithm

Implementation of Task Arithmetic w/ DARE.

  • Yu et al. Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch. 2023. http://arxiv.org/abs/2311.03099
Source code in fusion_bench/method/dare/task_arithmetic.py
class DareTaskArithmetic(BaseAlgorithm):
    """
    Implementation of Task Arithmetic w/ DARE.

    - Yu et al. Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch. 2023. http://arxiv.org/abs/2311.03099
    """

    def __init__(
        self,
        scaling_factor: float,
        sparsity_ratio: float,
        only_on_linear_weights: bool,
        rescale: bool = True,
        **kwargs,
    ):
        self.scaling_factor = scaling_factor
        self.sparsity_ratio = sparsity_ratio
        self.only_on_linear_weights = only_on_linear_weights
        self.rescale = rescale
        super().__init__(**kwargs)

    def _load_task_vector(
        self,
        modelpool: BaseModelPool,
        model_name: str,
        pretrained_model: nn.Module,
    ):
        finetuned_model = modelpool.load_model(model_name)
        task_vector = module_sub_(finetuned_model, pretrained_model)
        return task_vector

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        assert (
            self.sparsity_ratio >= 0 and self.sparsity_ratio <= 1
        ), "Sparsity ratio must be between 0 and 1"
        pretrained_model = modelpool.load_pretrained_model()

        # load task vectors
        task_vectors = {
            model_name: self._load_task_vector(modelpool, model_name, pretrained_model)
            for model_name in modelpool.model_names
        }

        # drop and rescale task vectors
        for model_name, tv in task_vectors.items():
            if self.only_on_linear_weights:
                for module_name, module in tv.named_modules():
                    if isinstance(module, nn.Linear):
                        print(f"pruning model: `{model_name}`, layer: {module_name}.")
                        param_random_drop_(
                            module.weight, self.sparsity_ratio, rescale=self.rescale
                        )
            else:
                print(f"pruning model: `{model_name}`")
                module_random_drop_(tv, self.sparsity_ratio, rescale=self.rescale)

        # merge task vectors
        task_vector_sum = state_dict_sum(
            [trainable_state_dict(tv) for tv in task_vectors.values()]
        )

        # scale the task vector and add it to the pretrained model
        for name, delta in task_vector_sum.items():
            delta = delta * self.scaling_factor
            pretrained_model.get_parameter(name).data.add_(delta)

        return pretrained_model
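
For intuition, the drop-and-rescale step applied to a single task-vector tensor can be sketched as below. This is a minimal stand-alone illustration under our own naming, not the library's param_random_drop_ / module_random_drop_ implementation.

import torch

def dare_drop_and_rescale(
    delta: torch.Tensor, sparsity_ratio: float, rescale: bool = True
) -> torch.Tensor:
    """Randomly zero out a fraction `sparsity_ratio` of entries and optionally rescale the survivors."""
    keep_prob = 1.0 - sparsity_ratio
    # Bernoulli mask: 1 keeps an entry, 0 drops it
    mask = torch.bernoulli(torch.full_like(delta, keep_prob))
    dropped = delta * mask
    if rescale:
        # divide by (1 - sparsity_ratio) so each entry keeps its expected value (assumes sparsity_ratio < 1)
        dropped = dropped / keep_prob
    return dropped

# example: drop 90% of the entries of a toy task vector
task_vector = torch.randn(4, 4)
print(dare_drop_and_rescale(task_vector, sparsity_ratio=0.9))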

DareTiesMerging

Bases: BaseAlgorithm

Source code in fusion_bench/method/dare/ties_merging.py
class DareTiesMerging(BaseAlgorithm):
    def __init__(
        self,
        # DARE parameters
        sparsity_ratio: float,
        only_on_linear_weights: bool,
        rescale: bool,
        # Ties merging parameters
        scaling_factor: float,
        threshold: int,
        remove_keys: list[str],
        merge_func: Literal["sum", "mean", "max"],
        **kwargs,
    ):
        self.sparsity_ratio = sparsity_ratio
        self.only_on_linear_weights = only_on_linear_weights
        self.rescale = rescale
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys
        self.merge_func = merge_func
        super().__init__(**kwargs)

    @torch.no_grad()
    def _load_task_vector(
        self,
        modelpool: BaseModelPool,
        model_name: str,
        pretrained_model: nn.Module,
    ):
        finetuned_model = modelpool.load_model(model_name)
        task_vector = module_sub_(finetuned_model, pretrained_model)
        return task_vector

    def run(self, modelpool: BaseModelPool):
        assert (
            self.sparsity_ratio >= 0 and self.sparsity_ratio <= 1
        ), "Sparsity ratio must be between 0 and 1"
        pretrained_model = modelpool.load_pretrained_model()

        # load task vectors
        task_vectors = {
            model_name: self._load_task_vector(modelpool, model_name, pretrained_model)
            for model_name in modelpool.model_names
        }

        # drop and rescale task vectors
        for model_name, tv in task_vectors.items():
            if self.only_on_linear_weights:
                for module_name, module in tv.named_modules():
                    if isinstance(module, nn.Linear):
                        print(f"pruning model: `{model_name}`, layer: {module_name}.")
                        param_random_drop_(
                            module.weight, self.sparsity_ratio, rescale=self.rescale
                        )
            else:
                print(f"pruning model: `{model_name}`")
                module_random_drop_(tv, self.sparsity_ratio, rescale=self.rescale)

        ptm_check = pretrained_model.state_dict()
        flat_ptm = state_dict_to_vector(ptm_check, self.remove_keys)
        tv_flat_checks = torch.vstack(
            [
                state_dict_to_vector(check.state_dict(), self.remove_keys)
                for check in task_vectors.values()
            ]
        )
        del task_vectors

        # Perform TIES Merging
        merged_tv = ties_merging(
            tv_flat_checks,
            reset_thresh=self.threshold,
            merge_func=self.merge_func,
        )
        merged_check = flat_ptm + self.scaling_factor * merged_tv
        merged_state_dict = vector_to_state_dict(
            merged_check, ptm_check, remove_keys=self.remove_keys
        )

        pretrained_model.load_state_dict(merged_state_dict)
        return pretrained_model
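
A hedged usage sketch follows; the hyperparameter values and the `modelpool` variable are placeholders chosen only to show the expected types, not recommended settings.

from fusion_bench.method.dare.ties_merging import DareTiesMerging

algorithm = DareTiesMerging(
    # DARE parameters
    sparsity_ratio=0.9,
    only_on_linear_weights=False,
    rescale=True,
    # TIES merging parameters
    scaling_factor=0.3,
    threshold=20,
    remove_keys=[],
    merge_func="sum",
)
merged_model = algorithm.run(modelpool)  # `modelpool` is a BaseModelPool with a `_pretrained_` entry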

Model Extrapolation (ExPO)

ExPOAlgorithm

Bases: BaseAlgorithm

ExPO merge algorithm.

This algorithm merges a pretrained model with a finetuned model.

\[\theta_{merged} = \theta_{sft} + \alpha (\theta_{rlhf} - \theta_{sft})\]

where \(\theta_{merged}\) is the merged model, \(\theta_{rlhf}\) is the finetuned model (medium-aligned model), \(\theta_{sft}\) is the pretrained model (base model), and \(\alpha\) is the extrapolation factor.

In the configuration, the SFT model should have the name _pretrained_ and the RLHF model's name can be set arbitrarily.
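
A minimal numeric sketch of this update on plain tensors (not using the library's state_dict_* helpers; the values and \(\alpha\) below are illustrative):

import torch

alpha = 2.0  # illustrative extrapolation factor

theta_sft = {"w": torch.tensor([1.0, 2.0])}   # base (SFT) weights
theta_rlhf = {"w": torch.tensor([1.5, 2.5])}  # aligned (RLHF) weights

# theta_merged = theta_sft + alpha * (theta_rlhf - theta_sft)
theta_merged = {
    k: theta_sft[k] + alpha * (theta_rlhf[k] - theta_sft[k]) for k in theta_sft
}
print(theta_merged["w"])  # tensor([2., 3.]) -- extrapolated past the RLHF weights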

Source code in fusion_bench/method/linear/expo.py
class ExPOAlgorithm(BaseAlgorithm):
    R"""
    ExPO merge algorithm.

    This algorithm merges a pretrained model with a finetuned model.

    $$\theta_{merged} = \theta_{sft} + \alpha (\theta_{rlhf} - \theta_{sft})$$

    where $\theta_{merged}$ is the merged model, $\theta_{rlhf}$ is the finetuned model (medium-aligned model),
    $\theta_{sft}$ is the pretrained model (base model), and $\alpha$ is the extrapolation factor.

    In the configuration, the SFT model should have the name `_pretrained_` and the RLHF model's name can be set arbitrarily.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "extrapolation_factor": "extrapolation_factor"
    }

    def __init__(self, extrapolation_factor: float, **kwargs):
        self.extrapolation_factor = extrapolation_factor
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        """
        Run the ExPO merge algorithm.

        Args:
            modelpool (BaseModelPool): The pool of models to merge.

        Returns:
            nn.Module: The merged model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        assert len(modelpool.model_names) >= 1, "ExPO requires at least one model."
        assert modelpool.has_pretrained, "ExPO requires pretrained models (base model)."

        sft_model = modelpool.load_pretrained_model()
        if len(modelpool) == 1:
            rlhf_model = modelpool.load_model(modelpool.model_names[0])
        else:
            # if there are multiple RLHF models, use simple average to merge them before running ExPO
            log.info(
                f"There are {len(modelpool)} models in the model pool, averaging them first..."
            )
            rlhf_model = SimpleAverageAlgorithm().run(modelpool)

        # merge the pretrained model and the finetuned model
        delta_parameters = state_dict_sub(
            rlhf_model.state_dict(), sft_model.state_dict()
        )
        merged_sd = state_dict_add(
            rlhf_model.state_dict(),
            state_dict_mul(delta_parameters, scalar=self.extrapolation_factor),
        )

        rlhf_model.load_state_dict(merged_sd)
        return rlhf_model
run(modelpool)

Run the ExPO merge algorithm.

Parameters:

  • modelpool (BaseModelPool) –

    The pool of models to merge.

Returns:

  • nn.Module: The merged model.

Source code in fusion_bench/method/linear/expo.py
def run(self, modelpool: BaseModelPool):
    """
    Run the ExPO merge algorithm.

    Args:
        modelpool (BaseModelPool): The pool of models to merge.

    Returns:
        nn.Module: The merged model.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    assert len(modelpool.model_names) >= 1, "ExPO requires at least one model."
    assert modelpool.has_pretrained, "ExPO requires pretrained models (base model)."

    sft_model = modelpool.load_pretrained_model()
    if len(modelpool) == 1:
        rlhf_model = modelpool.load_model(modelpool.model_names[0])
    else:
        # if there are multiple RLHF models, use simple average to merge them before running ExPO
        log.info(
            f"There are {len(modelpool)} models in the model pool, averaging them first..."
        )
        rlhf_model = SimpleAverageAlgorithm().run(modelpool)

    # merge the pretrained model and the finetuned model
    delta_parameters = state_dict_sub(
        rlhf_model.state_dict(), sft_model.state_dict()
    )
    merged_sd = state_dict_add(
        rlhf_model.state_dict(),
        state_dict_mul(delta_parameters, scalar=self.extrapolation_factor),
    )

    rlhf_model.load_state_dict(merged_sd)
    return rlhf_model

ExPOAlgorithmForLlama

Bases: BaseAlgorithm

Source code in fusion_bench/method/linear/llama_expo.py
class ExPOAlgorithmForLlama(BaseAlgorithm):

    def __init__(
        self,
        extrapolation_factor: float,
        attention_scaling_factor: float = 0.5,
        only_on_backbone: bool = True,
        on_linear_weights: bool = True,
        on_linear_bias: bool = False,
        on_embedding: bool = False,
        fix_last_n_layers: int = 0,
        fix_first_n_layers: int = 0,
        magnitude_sparsity_ratio: Optional[float] = None,
        **kwargs,
    ):
        self.extrapolation_factor = extrapolation_factor
        self.attention_scaling_factor = attention_scaling_factor
        self.only_on_backbone = only_on_backbone
        self.on_linear_weights = on_linear_weights
        self.on_linear_bias = on_linear_bias
        self.on_embedding = on_embedding
        self.fix_last_n_layers = fix_last_n_layers
        self.fix_first_n_layers = fix_first_n_layers
        self.magnitude_sparsity_ratio = magnitude_sparsity_ratio
        super().__init__(**kwargs)

    def load_models(self, modelpool: BaseModelPool):
        sft_model: LlamaForCausalLM = modelpool.load_pretrained_model()
        if len(modelpool) == 1:
            rlhf_model = modelpool.load_model(modelpool.model_names[0])
        else:
            # if there are multiple RLHF models, use simple average to merge them before running ExPO
            log.info(
                f"There are {len(modelpool)} models in the model pool, averaging them first..."
            )
            rlhf_model = SimpleAverageAlgorithm().run(modelpool)
        rlhf_model = cast(LlamaForCausalLM, rlhf_model)
        return sft_model, rlhf_model

    def run(self, modelpool: BaseModelPool):
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        assert len(modelpool.model_names) >= 1, "ExPO requires at least one model."
        assert modelpool.has_pretrained, "ExPO requires pretrained models (base model)."

        sft_model, rlhf_model = self.load_models(modelpool)

        if not self.on_linear_bias:
            for name, module in sft_model.named_modules():
                if isinstance(module, nn.Linear):
                    module.bias = rlhf_model.get_submodule(name).bias
        if not self.on_linear_weights:
            for name, module in sft_model.named_modules():
                if isinstance(module, nn.Linear):
                    module.weight = rlhf_model.get_submodule(name).weight

        if not self.only_on_backbone:
            expo_(sft_model.lm_head, rlhf_model.lm_head, self.extrapolation_factor)

        # expo on the backbone
        self._expo_lm_model_(
            sft_model.model, rlhf_model.model, self.extrapolation_factor
        )
        return rlhf_model

    def _expo_lm_model_(
        self,
        sft_model: LlamaModel,
        rlhf_model: LlamaModel,
        extrapolation_factor: float,
    ):
        if self.on_embedding:
            expo_(sft_model.embed_tokens, rlhf_model.embed_tokens, extrapolation_factor)

        if self.fix_first_n_layers == "half":
            self.fix_first_n_layers = len(sft_model.layers) // 2
        if self.fix_last_n_layers == "half":
            self.fix_last_n_layers = len(sft_model.layers) // 2

        for layer_idx in range(
            self.fix_first_n_layers, len(sft_model.layers) - self.fix_last_n_layers
        ):
            sft_layer = sft_model.layers[layer_idx]
            expo_linear_modules_(
                sft_layer.self_attn,
                rlhf_model.layers[layer_idx].self_attn,
                extrapolation_factor=extrapolation_factor
                * self.attention_scaling_factor,
                merge_dtype=torch.float32,
                magnitude_sparsity_ratio=self.magnitude_sparsity_ratio,
            )
            expo_linear_modules_(
                sft_layer.mlp,
                rlhf_model.layers[layer_idx].mlp,
                extrapolation_factor=extrapolation_factor,
                merge_dtype=torch.float32,
                magnitude_sparsity_ratio=self.magnitude_sparsity_ratio,
            )

DOGE

DOGE_TA_Algorithm

Bases: BaseAlgorithm, SimpleProfilerMixin, LightningFabricMixin

Task Arithmetic Algorithm for model fusion with learnable delta.

This class extends the Task Arithmetic method to include a learnable delta for task vectors, optimized to maximize cosine similarity among the task vectors.

Attributes:

  • scaling_factor (int) –

    The factor by which the task vectors will be scaled before merging.

  • delta (StateDictType) –

    A learnable parameter to adjust task vectors, initialized as zeros.
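
A hedged usage sketch (the hyperparameter values below are placeholders, not recommended settings):

from fusion_bench.method.doge_ta.doge_ta import DOGE_TA_Algorithm

algorithm = DOGE_TA_Algorithm(
    subspace=4,  # controls the size of the projection subspace (placeholder value)
    K=30,        # top-K percentage used when masking task vectors (placeholder value)
    lamda=0.07,  # norm-based scaling coefficient for task vectors (placeholder value)
)
merged_model = algorithm.run(modelpool)  # `modelpool` must contain `_pretrained_` and the fine-tuned models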

Source code in fusion_bench/method/doge_ta/doge_ta.py
class DOGE_TA_Algorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
    LightningFabricMixin,
):
    """
    Task Arithmetic Algorithm for model fusion with learnable delta.

    This class extends the Task Arithmetic method to include a learnable delta
    for task vectors, optimized to maximize cosine similarity among the task vectors.

    Attributes:
        scaling_factor (int): The factor by which the task vectors will be scaled before merging.
        delta (StateDictType): A learnable parameter to adjust task vectors, initialized as zeros.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "subspace": "subspace",
        "K": "K",
        "lamda": "lamda",
    }

    def __init__(self, subspace, K, lamda):
        self.delta = None  # Initialize delta as None; will be set during run
        self.subspace = subspace
        self.K = K
        self.lamda = lamda
        super().__init__()

    @property
    def device(self) -> torch.device:
        return self.fabric.device

    @torch.no_grad()
    def compute_task_vectors(
        self, modelpool: BaseModelPool, pretrained_model: nn.Module
    ) -> List[StateDictType]:
        """
        Computes task vectors for each model in the model pool relative to the pretrained model.
        """
        task_vectors = []
        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        filtered_keys = [
            k
            for k in pretrained_sd.keys()
            if ("encoder" in k and "layer_norm" not in k and "weight" in k)
        ]  # Flan T5: "layer_norm" not in k and ("q.weight" in k or "v.weight" in k)

        for model_name in modelpool.model_names:
            model = modelpool.load_model(model_name)
            model_sd = model.state_dict(keep_vars=True)

            filtered_task_vector = {
                k: (model_sd[k] - pretrained_sd[k]) for k in filtered_keys
            }
            task_vectors.append(filtered_task_vector)

        return task_vectors

    def taskvector_loss(self, layer_vectors, layer_delta, layer_lamdas) -> torch.Tensor:
        """
        Computes the loss based on delta and task vectors for a specific layer.
        """
        total_loss = 0.0

        layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
        sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

        layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
        sum_over_delta = layer_delta_scale.sum(dim=0)

        # Iterate through each vector and calculate the loss one by one
        for v_j in layer_vectors:
            part1 = -v_j * sum_over_num_vectors
            part2 = -v_j * sum_over_delta
            part3 = v_j * v_j

            expression = part1 + part2 + part3
            layer_loss = expression.sum(dim=1).pow(2).sum()

            # Cumulative total loss
            total_loss += layer_loss
        return total_loss

    @torch.enable_grad()
    def optimize_delta(self, task_vectors: List[StateDictType]) -> None:
        """
        Optimizes the delta based on the loss of task vectors.
        """
        if self.delta is None:
            self.delta = {
                k: nn.Parameter(torch.zeros_like(v, device=self.device).detach())
                for k, v in task_vectors[0].items()
            }

        optimizer = torch.optim.Adam(self.delta.values(), lr=1e-4)
        initial_mem = torch.cuda.memory_allocated()
        start_time = time.time()
        for layer_name in task_vectors[0].keys():
            layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors]).to(
                self.device
            )
            layer_lamdas = torch.stack(
                [lamdas[layer_name] for lamdas in self.lamdas]
            ).to(self.device)
            for _ in range(400):
                optimizer.zero_grad()
                loss = self.taskvector_loss(
                    layer_vectors, self.delta[layer_name], layer_lamdas
                )
                self.fabric.backward(loss)
                grad_proj = (
                    self.projection[layer_name] @ self.delta[layer_name].grad.detach()
                )
                self.delta[layer_name].grad.data = self.delta[
                    layer_name
                ].grad.data.sub_(grad_proj)
                optimizer.step()
                self.delta[layer_name].grad = None
        end_time = time.time()
        print(f"Running time: {end_time - start_time} s")
        final_mem = torch.cuda.memory_allocated()
        print(f"Memory usage: {(final_mem - initial_mem) / (1024 ** 2)} MB")
        print("Optimization completed.")

    @torch.no_grad()
    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
        """
        Runs the Algorithm with learnable delta to fuse models in the given model pool.

        Args:
            modelpool (Union[BaseModelPool, Dict[str, nn.Module]]): The pool of models to fuse.

        Returns:
            nn.Module: The pre-trained model with the merged task vectors after optimizing delta.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info("Fusing models using DOGE_TA with learnable delta.")
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")

        task_vectors = self.compute_task_vectors(modelpool, pretrained_model)

        self.lamdas = self.compute_layer_lamdas(task_vectors)
        self.projection = {}
        for layer_name in task_vectors[0].keys():
            for i, vector in enumerate(task_vectors):
                layer_vector = vector[layer_name].to(self.device)
                u, s, v = torch.linalg.svd(layer_vector, full_matrices=False)
                if i == 0:
                    print(f"Computed SVD for {layer_name}...")
                    sum_u = torch.zeros_like(u, device=layer_vector.device)
                    sum_s = torch.zeros_like(s, device=layer_vector.device)
                    sum_v = torch.zeros_like(v, device=layer_vector.device)

                reduced_index_s = int(s.shape[0] / len(task_vectors))

                # select only the first reduced_index_s columns of u and place them
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                    :reduced_index_s
                ]
                # select only the first reduced_index_s rows of v and place them
                sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                    :reduced_index_s, :
                ]
            u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            layer_proj = torch.matmul(
                u_u[:, : int(s.shape[0] / self.config.subspace)],
                u_u[:, : int(s.shape[0] / self.config.subspace)].T,
            )
            self.projection[layer_name] = layer_proj

        self.optimize_delta(task_vectors)

        del self.projection
        self.delta = {key: param.detach().cpu() for key, param in self.delta.items()}
        self.lamdas = [
            {key: param.cpu() for key, param in lamdas.items()}
            for lamdas in self.lamdas
        ]
        task_vectors = [
            {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
        ]
        flat_vectors = []
        vector_masks = []
        for idx, task_vector in enumerate(task_vectors):
            flat_vector = self.state_dict_to_vector(task_vector)
            vector_mask = self.topk_values_mask(flat_vector, K=self.config.K)
            flat_vectors.append(flat_vector)
            vector_masks.append(vector_mask)
        flat_delta = self.state_dict_to_vector(self.delta)

        adjusted_vectors = [
            self.vector_to_state_dict(
                (flat_vector + flat_delta) * vector_mask, self.delta
            )
            for flat_vector, vector_mask in zip(flat_vectors, vector_masks)
        ]

        for layer_name in adjusted_vectors[0].keys():
            layer_vectors = torch.stack(
                [vec[layer_name] for vec in adjusted_vectors], dim=0
            )
            layer_lamdas = torch.stack(
                [lamdas[layer_name] for lamdas in self.lamdas], dim=0
            )
            layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
            task_vectors[0][layer_name] = layer_vectors_scale.sum(dim=0)

        final_state_dict = {}
        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        for k, v in pretrained_sd.items():
            if k in task_vectors[0]:
                final_state_dict[k] = v + task_vectors[0][k]
            else:
                final_state_dict[k] = v

        pretrained_model.load_state_dict(final_state_dict)

        self.print_profile_summary()
        return pretrained_model

    def compute_lamdas(self, vectors: List[StateDictType]) -> torch.Tensor:
        lamdas = []
        for vec in vectors:
            norm_vec = torch.norm(
                torch.cat([param.flatten() for param in vec.values()])
            )
            # norm_vec = sum([torch.norm(param) for param in vec.values()])
            lamdas.append(self.config.lamda / norm_vec)
        print(lamdas)
        return lamdas

    def compute_layer_lamdas(self, vectors: List[StateDictType]) -> torch.Tensor:
        lamdas = []
        for vec in vectors:
            tmp = {}
            for layer_name in vec.keys():
                norm_vec = torch.norm(vec[layer_name])
                tmp[layer_name] = self.config.lamda / norm_vec
            lamdas.append(tmp)
        return lamdas

    def topk_values_mask(self, M, K):
        if K > 1:
            K /= 100

        original_shape = M.shape
        if M.dim() == 1:
            M = M.unsqueeze(0)

        n, d = M.shape
        k = int(d * K)
        k = d - k  # Keep top k elements instead of bottom k elements

        # Find the k-th smallest element by magnitude for each row
        kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
        # Create a mask tensor with True for the top k elements in each row
        mask = M.abs() >= kth_values
        final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

        return final_mask

    def state_dict_to_vector(self, state_dict, remove_keys=[]):
        """
        Convert a state dictionary to a vector, removing specified keys.

        Args:
            state_dict (dict): The state dictionary to convert.
            remove_keys (list): List of keys to remove from the state dictionary.

        Returns:
            Tensor: A vector representation of the state dictionary.
        """
        shared_state_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in shared_state_dict:
                del shared_state_dict[key]
        sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
        return nn.utils.parameters_to_vector(
            [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
        )

    def vector_to_state_dict(self, vector, state_dict, remove_keys=[]):
        """
        Convert a vector back to a state dictionary, removing specified keys.

        Args:
            vector (Tensor): The vector to convert.
            state_dict (dict): The reference state dictionary.
            remove_keys (list): List of keys to remove from the state dictionary.

        Returns:
            dict: A state dictionary representation of the vector.
        """
        # create a reference dict to define the order of the vector
        reference_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in reference_dict:
                del reference_dict[key]
        sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

        # create a shared state dict using the reference dict
        nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

        # add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in sorted_reference_dict:
            for key in remove_keys:
                sorted_reference_dict[key] = sorted_reference_dict[
                    "transformer.shared.weight"
                ]
        return sorted_reference_dict
compute_task_vectors(modelpool, pretrained_model)

Computes task vectors for each model in the model pool relative to the pretrained model.

Source code in fusion_bench/method/doge_ta/doge_ta.py
@torch.no_grad()
def compute_task_vectors(
    self, modelpool: BaseModelPool, pretrained_model: nn.Module
) -> List[StateDictType]:
    """
    Computes task vectors for each model in the model pool relative to the pretrained model.
    """
    task_vectors = []
    pretrained_sd = pretrained_model.state_dict(keep_vars=True)
    filtered_keys = [
        k
        for k in pretrained_sd.keys()
        if ("encoder" in k and "layer_norm" not in k and "weight" in k)
    ]  # Flan T5: "layer_norm" not in k and ("q.weight" in k or "v.weight" in k)

    for model_name in modelpool.model_names:
        model = modelpool.load_model(model_name)
        model_sd = model.state_dict(keep_vars=True)

        filtered_task_vector = {
            k: (model_sd[k] - pretrained_sd[k]) for k in filtered_keys
        }
        task_vectors.append(filtered_task_vector)

    return task_vectors
optimize_delta(task_vectors)

Optimizes the delta based on the loss of task vectors.

Source code in fusion_bench/method/doge_ta/doge_ta.py
@torch.enable_grad()
def optimize_delta(self, task_vectors: List[StateDictType]) -> None:
    """
    Optimizes the delta based on the loss of task vectors.
    """
    if self.delta is None:
        self.delta = {
            k: nn.Parameter(torch.zeros_like(v, device=self.device).detach())
            for k, v in task_vectors[0].items()
        }

    optimizer = torch.optim.Adam(self.delta.values(), lr=1e-4)
    initial_mem = torch.cuda.memory_allocated()
    start_time = time.time()
    for layer_name in task_vectors[0].keys():
        layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors]).to(
            self.device
        )
        layer_lamdas = torch.stack(
            [lamdas[layer_name] for lamdas in self.lamdas]
        ).to(self.device)
        for _ in range(400):
            optimizer.zero_grad()
            loss = self.taskvector_loss(
                layer_vectors, self.delta[layer_name], layer_lamdas
            )
            self.fabric.backward(loss)
            grad_proj = (
                self.projection[layer_name] @ self.delta[layer_name].grad.detach()
            )
            self.delta[layer_name].grad.data = self.delta[
                layer_name
            ].grad.data.sub_(grad_proj)
            optimizer.step()
            self.delta[layer_name].grad = None
    end_time = time.time()
    print(f"Running time: {end_time - start_time} s")
    final_mem = torch.cuda.memory_allocated()
    print(f"Memory usage: {(final_mem - initial_mem) / (1024 ** 2)} MB")
    print("Optimization completed.")
run(modelpool)

Runs the Algorithm with learnable delta to fuse models in the given model pool.

Parameters:

  • modelpool (Union[BaseModelPool, Dict[str, Module]]) –

    The pool of models to fuse.

Returns:

  • nn.Module: The pre-trained model with the merged task vectors after optimizing delta.

Source code in fusion_bench/method/doge_ta/doge_ta.py
@torch.no_grad()
def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
    """
    Runs the Algorithm with learnable delta to fuse models in the given model pool.

    Args:
        modelpool (Union[BaseModelPool, Dict[str, nn.Module]]): The pool of models to fuse.

    Returns:
        nn.Module: The pre-trained model with the merged task vectors after optimizing delta.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    log.info("Fusing models using DOGE_TA with learnable delta.")
    with self.profile("load model"):
        pretrained_model = modelpool.load_model("_pretrained_")

    task_vectors = self.compute_task_vectors(modelpool, pretrained_model)

    self.lamdas = self.compute_layer_lamdas(task_vectors)
    self.projection = {}
    for layer_name in task_vectors[0].keys():
        for i, vector in enumerate(task_vectors):
            layer_vector = vector[layer_name].to(self.device)
            u, s, v = torch.linalg.svd(layer_vector, full_matrices=False)
            if i == 0:
                print(f"Computed SVD for {layer_name}...")
                sum_u = torch.zeros_like(u, device=layer_vector.device)
                sum_s = torch.zeros_like(s, device=layer_vector.device)
                sum_v = torch.zeros_like(v, device=layer_vector.device)

            reduced_index_s = int(s.shape[0] / len(task_vectors))

            # select only the first reduced_index_s columns of u and place them
            sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                :, :reduced_index_s
            ]
            sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                :reduced_index_s
            ]
            # select only the first reduced_index_s rows of v and place them
            sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                :reduced_index_s, :
            ]
        u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
        layer_proj = torch.matmul(
            u_u[:, : int(s.shape[0] / self.config.subspace)],
            u_u[:, : int(s.shape[0] / self.config.subspace)].T,
        )
        self.projection[layer_name] = layer_proj

    self.optimize_delta(task_vectors)

    del self.projection
    self.delta = {key: param.detach().cpu() for key, param in self.delta.items()}
    self.lamdas = [
        {key: param.cpu() for key, param in lamdas.items()}
        for lamdas in self.lamdas
    ]
    task_vectors = [
        {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
    ]
    flat_vectors = []
    vector_masks = []
    for idx, task_vector in enumerate(task_vectors):
        flat_vector = self.state_dict_to_vector(task_vector)
        vector_mask = self.topk_values_mask(flat_vector, K=self.config.K)
        flat_vectors.append(flat_vector)
        vector_masks.append(vector_mask)
    flat_delta = self.state_dict_to_vector(self.delta)

    adjusted_vectors = [
        self.vector_to_state_dict(
            (flat_vector + flat_delta) * vector_mask, self.delta
        )
        for flat_vector, vector_mask in zip(flat_vectors, vector_masks)
    ]

    for layer_name in adjusted_vectors[0].keys():
        layer_vectors = torch.stack(
            [vec[layer_name] for vec in adjusted_vectors], dim=0
        )
        layer_lamdas = torch.stack(
            [lamdas[layer_name] for lamdas in self.lamdas], dim=0
        )
        layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
        task_vectors[0][layer_name] = layer_vectors_scale.sum(dim=0)

    final_state_dict = {}
    pretrained_sd = pretrained_model.state_dict(keep_vars=True)
    for k, v in pretrained_sd.items():
        if k in task_vectors[0]:
            final_state_dict[k] = v + task_vectors[0][k]
        else:
            final_state_dict[k] = v

    pretrained_model.load_state_dict(final_state_dict)

    self.print_profile_summary()
    return pretrained_model
state_dict_to_vector(state_dict, remove_keys=[])

Convert a state dictionary to a vector, removing specified keys.

Parameters:

  • state_dict (dict) –

    The state dictionary to convert.

  • remove_keys (list, default: [] ) –

    List of keys to remove from the state dictionary.

Returns:

  • Tensor

    A vector representation of the state dictionary.
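
The flattening itself is a plain round trip through torch.nn.utils. A minimal stand-alone sketch, independent of this class and with hypothetical keys:

import torch
from collections import OrderedDict
from torch import nn

sd = OrderedDict(sorted({
    "b.weight": torch.zeros(2, 2),
    "a.bias": torch.ones(3),
}.items()))  # key-sorted, as in state_dict_to_vector

# flatten: concatenate the key-sorted tensors into one 1-D vector
flat = nn.utils.parameters_to_vector([v.reshape(-1) for v in sd.values()])
print(flat.shape)  # torch.Size([7]) -- "a.bias" entries first, then "b.weight"

# unflatten: write the vector back into tensors with the reference shapes
ref = OrderedDict((k, torch.empty_like(v)) for k, v in sd.items())
nn.utils.vector_to_parameters(flat, ref.values())
print(ref["a.bias"])  # tensor([1., 1., 1.])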

Source code in fusion_bench/method/doge_ta/doge_ta.py
def state_dict_to_vector(self, state_dict, remove_keys=[]):
    """
    Convert a state dictionary to a vector, removing specified keys.

    Args:
        state_dict (dict): The state dictionary to convert.
        remove_keys (list): List of keys to remove from the state dictionary.

    Returns:
        Tensor: A vector representation of the state dictionary.
    """
    shared_state_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in shared_state_dict:
            del shared_state_dict[key]
    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
    return nn.utils.parameters_to_vector(
        [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
    )
taskvector_loss(layer_vectors, layer_delta, layer_lamdas)

Computes the loss based on delta and task vectors for a specific layer.

Source code in fusion_bench/method/doge_ta/doge_ta.py
def taskvector_loss(self, layer_vectors, layer_delta, layer_lamdas) -> torch.Tensor:
    """
    Computes the loss based on delta and task vectors for a specific layer.
    """
    total_loss = 0.0

    layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
    sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

    layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
    sum_over_delta = layer_delta_scale.sum(dim=0)

    # Iterate through each vector and calculate the loss one by one
    for v_j in layer_vectors:
        part1 = -v_j * sum_over_num_vectors
        part2 = -v_j * sum_over_delta
        part3 = v_j * v_j

        expression = part1 + part2 + part3
        layer_loss = expression.sum(dim=1).pow(2).sum()

        # Cumulative total loss
        total_loss += layer_loss
    return total_loss
vector_to_state_dict(vector, state_dict, remove_keys=[])

Convert a vector back to a state dictionary, removing specified keys.

Parameters:

  • vector (Tensor) –

    The vector to convert.

  • state_dict (dict) –

    The reference state dictionary.

  • remove_keys (list, default: [] ) –

    List of keys to remove from the state dictionary.

Returns:

  • dict

    A state dictionary representation of the vector.

Source code in fusion_bench/method/doge_ta/doge_ta.py
def vector_to_state_dict(self, vector, state_dict, remove_keys=[]):
    """
    Convert a vector back to a state dictionary, removing specified keys.

    Args:
        vector (Tensor): The vector to convert.
        state_dict (dict): The reference state dictionary.
        remove_keys (list): List of keys to remove from the state dictionary.

    Returns:
        dict: A state dictionary representation of the vector.
    """
    # create a reference dict to define the order of the vector
    reference_dict = copy.deepcopy(state_dict)
    for key in remove_keys:
        if key in reference_dict:
            del reference_dict[key]
    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

    # create a shared state dict using the reference dict
    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

    # add back the encoder and decoder embedding weights.
    if "transformer.shared.weight" in sorted_reference_dict:
        for key in remove_keys:
            sorted_reference_dict[key] = sorted_reference_dict[
                "transformer.shared.weight"
            ]
    return sorted_reference_dict

AdaMerging

CLIPTaskWiseAdaMergingAlgorithm

Bases: TaskWiseAdaMergingAlgorithm

A class for task-wise adaptive merging of CLIP models.

This class extends the TaskWiseAdaMergingAlgorithm to provide specific functionality for CLIP models, including loading datasets, constructing zero-shot classification heads, and computing logits.

Attributes:

  • modelpool (CLIPVisionModelPool) –

    The model pool containing CLIP models.

  • _clip_processor (CLIPProcessor) –

    The CLIP processor for preparing inputs.

  • zeroshot_weights (dict) –

    A dictionary to store zero-shot weights for each task.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    """
    A class for task-wise adaptive merging of CLIP models.

    This class extends the TaskWiseAdaMergingAlgorithm to provide specific
    functionality for CLIP models, including loading datasets, constructing
    zero-shot classification heads, and computing logits.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
    """

    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.

        Args:
            task (str): The name of the task.

        Returns:
            CLIPDataset: The test dataset for the task.
        """
        log.info(f"Loading test dataset: {task}")
        dataset = self.modelpool.load_test_dataset(task)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        """
        Get an iterator over the shuffled test DataLoader for the task.

        Args:
            task (str): The name of the task.

        Returns:
            iterator: An iterator over the shuffled test DataLoader.
        """
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Prepare for test-time adaptation.

        This method loads the CLIP processor and constructs the zero-shot
        classification head for each task.
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")
        pretrained_path = (
            clip_model_config.pretrained_model_name_or_path
            if hasattr(clip_model_config, "pretrained_model_name_or_path")
            else clip_model_config.path
        )

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale_exp = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(task)
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        This method computes the image embeddings, normalizes them, and calculates
        the cosine similarity with the text embeddings to produce classification logits.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

This method computes the image embeddings, normalizes them, and calculates the cosine similarity with the text embeddings to produce classification logits.

Parameters:

  • module (Module) –

    The model module.

  • batch (tuple) –

    A batch of input data.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The classification logits for the batch.
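
The cosine-similarity step can be sketched in isolation; the shapes and the temperature value below are illustrative, not taken from a real checkpoint.

import torch

num_classes, embed_dim, batch_size = 10, 512, 4
text_embeds = torch.nn.functional.normalize(torch.randn(num_classes, embed_dim), dim=-1)
image_embeds = torch.randn(batch_size, embed_dim)

# normalize image embeddings so that the dot product becomes a cosine similarity
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

logit_scale_exp = torch.tensor(100.0)  # exp of the learned logit scale (illustrative)
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale_exp  # (num_classes, batch_size)
logits_per_image = logits_per_text.t()  # (batch_size, num_classes)
print(logits_per_image.shape)  # torch.Size([4, 10])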

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
def compute_logits(self, module, batch, task: str) -> Tensor:
    """
    Compute the logits for the given batch and task.

    This method computes the image embeddings, normalizes them, and calculates
    the cosine similarity with the text embeddings to produce classification logits.

    Args:
        module (nn.Module): The model module.
        batch (tuple): A batch of input data.
        task (str): The name of the task.

    Returns:
        Tensor: The classification logits for the batch.
    """
    images, _ = batch
    text_embeds = self.zeroshot_weights[task]

    image_embeds = module(images)[1]
    image_embeds = self.visual_projection(image_embeds)

    # normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity
    logits_per_text = (
        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
    )
    logits_per_image = logits_per_text.t()

    return logits_per_image
get_shuffled_test_loader_iter(task) cached

Get an iterator over the shuffled test DataLoader for the task.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • iterator

    An iterator over the shuffled test DataLoader.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
@functools.cache
def get_shuffled_test_loader_iter(self, task: str):
    """
    Get an iterator over the shuffled test DataLoader for the task.

    Args:
        task (str): The name of the task.

    Returns:
        iterator: An iterator over the shuffled test DataLoader.
    """
    loader = DataLoader(
        self.get_test_dataset(task),
        batch_size=self.config.batch_size,
        shuffle=True,
        num_workers=self.config.num_workers,
        pin_memory=True,
    )
    if self._fabric is not None:
        loader = self._fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
get_test_dataset(task) cached

Load the test dataset for the task. This method is cached, so the dataset is loaded only once.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • CLIPDataset

    The test dataset for the task.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
@functools.cache
def get_test_dataset(self, task: str):
    """
    Load the test dataset for the task.
    This method is cached, so the dataset is loaded only once.

    Args:
        task (str): The name of the task.

    Returns:
        CLIPDataset: The test dataset for the task.
    """
    log.info(f"Loading test dataset: {task}")
    dataset = self.modelpool.load_test_dataset(task)
    dataset = CLIPDataset(dataset, self._clip_processor)
    return dataset
on_test_time_adaptation_start()

Prepare for test-time adaptation.

This method loads the CLIP processor and constructs the zero-shot classification head for each task.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Prepare for test-time adaptation.

    This method loads the CLIP processor and constructs the zero-shot
    classification head for each task.
    """
    clip_model_config = self.modelpool.get_model_config("_pretrained_")
    pretrained_path = (
        clip_model_config.pretrained_model_name_or_path
        if hasattr(clip_model_config, "pretrained_model_name_or_path")
        else clip_model_config.path
    )

    with timeit_context("Loading CLIP processor and pretrained CLIP model."):
        self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
        clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
        self.logit_scale_exp = clip_model.logit_scale.exp()
        if self._fabric is not None:
            self.visual_projection = self._fabric.to_device(self.visual_projection)
            self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

    for task in self.modelpool.model_names:
        cache_file = os.path.join(
            self.config.cache_dir,
            f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
        )
        if os.path.exists(cache_file):
            log.info(f"Loading cached zeroshot weights for task: {task}")
            zeroshot_weights = torch.load(cache_file, map_location="cpu")
        else:
            log.info(f"Construct zero shot classification head for task: {task}")
            classnames, templates = get_classnames_and_templates(task)
            clip_classifier.set_classification_task(classnames, templates)
            zeroshot_weights = clip_classifier.zeroshot_weights
            log.info(f"save zeroshot weights to {cache_file}")
            torch.save(zeroshot_weights, cache_file)
        self.zeroshot_weights[task] = zeroshot_weights
        if self._fabric is not None:
            self.zeroshot_weights[task] = self._fabric.to_device(
                self.zeroshot_weights[task]
            )

CLIPLayerWiseAdaMergingAlgorithm

Bases: CLIPClassificationMixin, LayerWiseAdaMergingAlgorithm

Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        return super().get_shuffled_test_loader_iter(
            task,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
        )
on_test_time_adaptation_start()

Here we load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Here we load the CLIP processor and construct the zero-shot classification head for each task.
    """
    self.setup_zero_shot_classification_head()

GPT2LayerWiseAdaMergingAlgorithm

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
class GPT2LayerWiseAdaMergingAlgorithm(
    BaseAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    scores: Dict[str, nn.Linear] = None

    def __init__(
        self,
        optimizer: DictConfig,
        dataloader_kwargs: DictConfig,
        init_values: float,
        max_steps: int,
        merging_weights_load_path: Optional[Union[str, Path]] = None,
        merging_weights_save_path: Optional[Union[str, Path]] = None,
        clamp_weights: bool = False,
        tie_weights: bool = True,
        strict: bool = False,
        cache_dir: str = "outputs/cache",
        variant: Optional[str] = None,
        **kwargs,
    ):
        self._optimizer = optimizer
        self.dataloader_kwargs = dataloader_kwargs
        self.init_values = init_values
        self.merging_weights_load_path = merging_weights_load_path
        self.merging_weights_save_path = merging_weights_save_path
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.max_steps = max_steps
        self.cache_dir = cache_dir
        self.variant = variant
        super().__init__(**kwargs)

    @torch.no_grad()
    def construct_layer_wise_merged_model(
        self, modelpool: GPT2ForSequenceClassificationPool
    ):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model: GPT2Model = modelpool.load_model("_pretrained_")
        finetuned_models: List[GPT2Model] = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.merging_weights_load_path is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.init_values,
            )
        else:
            if isinstance(self.merging_weights_load_path, str):
                # load the merging weights from a file
                layer_wise_weight = load_tensor_from_file(
                    self.merging_weights_load_path
                )
            else:
                raise ValueError(
                    f"Unsupported weights format: {self.merging_weights_load_path}"
                )

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.clamp_weights,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: GPT2ForSequenceClassificationPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.merging_weights_load_path is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.merging_weights_save_path is not None:
                self.save_merging_weights(
                    self.merging_weights_save_path, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Hook executed before test-time adaptation starts; here it loads the task-specific classification heads.
        """
        self.scores = {}
        for model_name in self.modelpool.model_names:
            score = cast(
                GPT2ForSequenceClassification,
                self.modelpool.load_classifier(model_name),
            ).score.requires_grad_(False)
            score = score.to(self.fabric.device)
            self.scores[model_name] = score

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of the test dataset used for test-time adaptation; labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        dataloader_kwargs = dict(self.dataloader_kwargs)
        dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

        dataset = self.modelpool.load_test_dataset(task)
        loader = DataLoader(dataset, **dataloader_kwargs)

        if self.fabric is not None:
            loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def compute_logits(self, module: GPT2Model, batch, task: str) -> Tensor:
        """
        Compute the classification logits for a batch from the given task.

        Args:
            module: The GPT-2 backbone module.
            batch: A batch dictionary containing `input_ids` and `attention_mask`.
            task (str): The name of the task, used to select the classification head.

        Returns:
            Tensor: The pooled logits of shape `(batch_size, num_labels)`.
        """
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        batch_size, _ = input_ids.shape[:2]
        pad_token_id = 50256

        transformer_outputs = module(
            input_ids,
            past_key_values=None,
            attention_mask=attention_mask,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True,
        )
        hidden_states = transformer_outputs[0]
        logits = self.scores[task](hidden_states)

        sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
        sequence_lengths = sequence_lengths % input_ids.shape[-1]
        sequence_lengths = sequence_lengths.to(logits.device)

        pooled_logits = logits[
            torch.arange(batch_size, device=logits.device), sequence_lengths
        ]

        assert pooled_logits.dim() == 2
        return pooled_logits

    def test_time_adaptation(self, module: LayerWiseMergedModel):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        optimizer = instantiate(self._optimizer, [module.merge_weight])
        module, optimizer = self.fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            if self.variant == "mgda":
                total_loss = self._compute_gradients_using_mgda(module)
            else:
                total_loss = 0
                for task in self.modelpool.model_names:
                    with self.profile("data loading"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        logits = logits.mean(dim=0, keepdim=True)
                        loss = entropy_loss(logits)
                        total_loss += loss
                    with self.profile("backward pass"):
                        self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": total_loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module

    def _compute_gradients_using_mgda(self, module: LayerWiseMergedModel):
        all_grads = []
        total_loss = 0
        # default behavior for first-order optimizers
        for task in self.modelpool.model_names:
            with self.profile("data loading"):
                batch = next(self.get_shuffled_test_loader_iter(task))
            with self.profile("forward pass"):
                logits = self.compute_logits(module, batch, task)
                logits = logits.mean(dim=0, keepdim=True)
                loss = entropy_loss(logits)
                total_loss += loss
            with self.profile("backward pass"):
                # self.fabric.backward(loss, retain_graph=True)
                _grads = torch.autograd.grad(
                    loss,
                    [module.merge_weight],
                    create_graph=False,
                    retain_graph=True,
                )
                all_grads.append(_grads[0].flatten().detach())
        sol, min_norm = MinNormSolver.find_min_norm_element(all_grads)
        if not isinstance(sol, torch.Tensor):
            sol = torch.from_numpy(sol)
        sol = sol.to(
            device=module.merge_weight.device,
            dtype=module.merge_weight.dtype,
        )
        grad = torch.stack(all_grads) * sol.view(-1, 1)
        module.merge_weight.grad = grad.sum(dim=0).view_as(module.merge_weight)
        return total_loss
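
When the "mgda" variant is selected, _compute_gradients_using_mgda combines the per-task gradients of the merging weights with the multiple-gradient descent algorithm (MGDA): MinNormSolver finds the convex combination of the task gradients g_t with the smallest norm, roughly

$$
\min_{\alpha \in \Delta} \Bigl\| \sum_{t} \alpha_t \, g_t \Bigr\|_2^2 ,
$$

and the resulting combination of gradients is written into merge_weight.grad before the optimizer step. (This formula is an editorial summary added for clarity; it is not part of the source file.)
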
compute_logits(module, batch, task)

Compute the classification logits for a batch from the given task.

Parameters:

  • module (GPT2Model) –

    The GPT-2 backbone module.

  • batch –

    A batch dictionary containing input_ids and attention_mask.

  • task (str) –

    The name of the task, used to select the classification head.

Returns:

  • Tensor ( Tensor ) –

    The pooled logits of shape (batch_size, num_labels).

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
def compute_logits(self, module: GPT2Model, batch, task: str) -> Tensor:
    """
    Compute the classification logits for a batch from the given task.

    Args:
        module: The GPT-2 backbone module.
        batch: A batch dictionary containing `input_ids` and `attention_mask`.
        task (str): The name of the task, used to select the classification head.

    Returns:
        Tensor: The pooled logits of shape `(batch_size, num_labels)`.
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    batch_size, _ = input_ids.shape[:2]
    pad_token_id = 50256

    transformer_outputs = module(
        input_ids,
        past_key_values=None,
        attention_mask=attention_mask,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    )
    hidden_states = transformer_outputs[0]
    logits = self.scores[task](hidden_states)

    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    sequence_lengths = sequence_lengths.to(logits.device)

    pooled_logits = logits[
        torch.arange(batch_size, device=logits.device), sequence_lengths
    ]

    assert pooled_logits.dim() == 2
    return pooled_logits
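
The pooling above picks the logits at the last non-padding position of each sequence (GPT-2 has no dedicated pad token, so the EOS id 50256 is used as padding here). A small self-contained illustration of the index computation, added for clarity and not part of the source file:

import torch

pad_token_id = 50256
input_ids = torch.tensor([
    [11, 22, 33, 50256, 50256],  # 3 real tokens -> pooled position 2
    [44, 55, 66, 77, 88],        # no padding    -> pooled position 4 (last)
])
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])
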
construct_layer_wise_merged_model(modelpool)

Constructs a wrapped layer-wise merged model from the model pool.

This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models. The merging is controlled by layer-wise weights, a torch.Tensor of shape (num_models, num_layers). The merging weights can be initialized based on a provided configuration or loaded from a file.

Parameters:

  • modelpool (ModelPool) –

    An object containing the pretrained model and fine-tuned models to be merged.

Returns:

  • LayerWiseMergedModel

    An instance of the merged model with layer-wise weights applied.

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
@torch.no_grad()
def construct_layer_wise_merged_model(
    self, modelpool: GPT2ForSequenceClassificationPool
):
    """
    Constructs a wrapped layer-wise merged model from the model pool.

    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, a `torch.Tensor` of shape `(num_models, num_layers)`.
    The merging weights can be initialized based on a provided configuration or loaded from a file.

    Args:
        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

    Returns:
        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
    """
    pretrained_model: GPT2Model = modelpool.load_model("_pretrained_")
    finetuned_models: List[GPT2Model] = [
        modelpool.load_model(name) for name in modelpool.model_names
    ]

    # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
    if self.merging_weights_load_path is None:
        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(
                    filter(lambda p: p.requires_grad, pretrained_model.parameters())
                )
            ),
            init_values=self.init_values,
        )
    else:
        if isinstance(self.merging_weights_load_path, str):
            # load the merging weights from a file
            layer_wise_weight = load_tensor_from_file(
                self.merging_weights_load_path
            )
        else:
            raise ValueError(
                f"Unsupported weights format: {self.merging_weights_load_path}"
            )

    module = LayerWiseMergedModel(
        layer_wise_weight=layer_wise_weight,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=self.clamp_weights,
        tie_weights=self.tie_weights,
        strict=self.strict,
    )
    print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
    return module
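
Conceptually, the wrapped LayerWiseMergedModel combines the pretrained weights with the fine-tuned models' task vectors using one coefficient per (model, layer) pair. A rough sketch of the merge, with notation added here for clarity (not part of the source file):

$$
\theta^{(l)}_{\text{merged}} = \theta^{(l)}_{\text{pre}} + \sum_{k=1}^{K} \lambda_{k,l} \, \bigl(\theta^{(l)}_{k} - \theta^{(l)}_{\text{pre}}\bigr) ,
$$

where K is the number of fine-tuned models, l indexes the trainable parameter tensors ("layers"), and lambda is the layer_wise_weight of shape (num_models, num_layers) initialized above.
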
get_shuffled_test_loader_iter(task) cached

Loader of the test dataset used for test-time adaptation; labels are not needed.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • DataLoader ( DataLoader ) –

    The data loader for the test dataset.

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
@functools.cache
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Loader of the test dataset used for test-time adaptation; labels are not needed.

    Args:
        task (str): The name of the task.

    Returns:
        DataLoader: The data loader for the test dataset.
    """
    dataloader_kwargs = dict(self.dataloader_kwargs)
    dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

    dataset = self.modelpool.load_test_dataset(task)
    loader = DataLoader(dataset, **dataloader_kwargs)

    if self.fabric is not None:
        loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
on_test_time_adaptation_start()

Hook executed before test-time adaptation starts; here it loads the task-specific classification heads.

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Hook executed before test-time adaptation starts; here it loads the task-specific classification heads.
    """
    self.scores = {}
    for model_name in self.modelpool.model_names:
        score = cast(
            GPT2ForSequenceClassification,
            self.modelpool.load_classifier(model_name),
        ).score.requires_grad_(False)
        score = score.to(self.fabric.device)
        self.scores[model_name] = score
run(modelpool, **kwargs)

Run the Layer-Wise AdaMerging Algorithm.

This method constructs the wrapped model and performs test-time adaptation if necessary.

Parameters:

  • modelpool (ModelPool) –

    The model pool containing the pretrained and fine-tuned models.

Returns:

  • LayerWiseMergedModel

    The merged model after test-time adaptation.

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
def run(self, modelpool: GPT2ForSequenceClassificationPool, **kwargs):
    """
    Run the Layer-Wise AdaMerging Algorithm.

    This method constructs the wrapped model and performs test-time adaptation if necessary.

    Args:
        modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

    Returns:
        LayerWiseMergedModel: The merged model after test-time adaptation.
    """
    log.info("Fusing models using layer-wise adaptive merging.")
    self.modelpool = modelpool

    with self.profile("construct the wrapped model"):
        module = self.construct_layer_wise_merged_model(modelpool)

    if self.merging_weights_load_path is not None:
        # skip the test-time adaptation
        return module.merge_and_unload()
    else:
        with self.profile("test-time adaptation"):
            module = self.test_time_adaptation(module)
        if self.merging_weights_save_path is not None:
            self.save_merging_weights(
                self.merging_weights_save_path, module.merge_weight
            )
        return module.merge_and_unload()
save_merging_weights(file_path, merging_weights)

Save the merging weights to a file.

Parameters:

  • file_path (str) –

    The path to save the merging weights.

  • merging_weights (Tensor) –

    The merging weights to save.

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
@rank_zero_only
def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
    """
    Save the merging weights to a file.

    Args:
        file_path (str): The path to save the merging weights.
        merging_weights (torch.Tensor): The merging weights to save.
    """
    if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
        if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
            # if the file path is not absolute or relative to current working directory, save it in the log directory
            save_path = os.path.join(self.log_dir, file_path)
        else:
            save_path = file_path
        log.info(f"saving merging weights to {save_path}.")
        if os.path.dirname(save_path):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(merging_weights.detach().cpu(), save_path)
test_time_adaptation(module)

Perform test-time adaptation on the merged model.

This method adapts the merging weights during test-time to improve performance.

Parameters:

  • module (LayerWiseMergedModel) –

    The merged model.

Returns:

  • LayerWiseMergedModel

    The adapted merged model.

Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
def test_time_adaptation(self, module: LayerWiseMergedModel):
    """
    Perform test-time adaptation on the merged model.

    This method adapts the merging weights during test-time to improve performance.

    Args:
        module (LayerWiseMergedModel): The merged model.

    Returns:
        LayerWiseMergedModel: The adapted merged model.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    optimizer = instantiate(self._optimizer, [module.merge_weight])
    module, optimizer = self.fabric.setup(module, optimizer)

    module.train()
    module.merge_weights()
    for step_idx in (
        pbar := tqdm(
            range(self.max_steps if not self.is_debug_mode else 1),
            ("[DEBUG MODE] " if self.is_debug_mode else "")
            + "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        )
    ):
        if self.variant == "mgda":
            total_loss = self._compute_gradients_using_mgda(module)
        else:
            total_loss = 0
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    logits = logits.mean(dim=0, keepdim=True)
                    loss = entropy_loss(logits)
                    total_loss += loss
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

        with self.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()
        with self.profile("merging weights"):
            module.merge_weights()

        metrics = {
            "train/loss": total_loss.item(),
            "train/weight_max": module.merge_weight.max().item(),
            "train/weight_min": module.merge_weight.min().item(),
            "train/weight_mean": module.merge_weight.mean().item(),
        }
        self.fabric.log_dict(metrics, step=step_idx)
        pbar.set_postfix(metrics)

    log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
    self.print_profile_summary()
    return module
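
The adaptation loop above follows the AdaMerging recipe: the merging weights are updated to minimize the entropy of the merged model's predictions on unlabeled test batches (the entropy_loss term). Roughly,

$$
\min_{\lambda} \sum_{t \in \text{tasks}} H\bigl(\mathrm{softmax}(\bar{z}_t)\bigr),
\qquad H(p) = -\sum_{c} p_c \log p_c ,
$$

where the batch-averaged pooled logits for task t play the role of z-bar (the code averages logits over the batch via logits.mean(dim=0, keepdim=True) before computing the entropy). This summary is editorial and not part of the source file.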

FlanT5LayerWiseAdaMergingAlgorithm

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
class FlanT5LayerWiseAdaMergingAlgorithm(
    BaseAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):

    def __init__(
        self,
        optimizer: DictConfig,
        dataloader_kwargs: DictConfig,
        init_values: float,
        max_steps: int,
        merging_weights_load_path: Optional[Union[str, Path]] = None,
        merging_weights_save_path: Optional[Union[str, Path]] = None,
        clamp_weights: bool = False,
        tie_weights: bool = True,
        strict: bool = False,
        cache_dir: str = "outputs/cache",
        variant: Optional[str] = None,
        **kwargs,
    ):
        self._optimizer = optimizer
        self.dataloader_kwargs = dataloader_kwargs
        self.init_values = init_values
        self.merging_weights_load_path = merging_weights_load_path
        self.merging_weights_save_path = merging_weights_save_path
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.max_steps = max_steps
        self.cache_dir = cache_dir
        self.variant = variant
        super().__init__(**kwargs)

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: Seq2SeqLMPool):
        """
        Constructs a wrapped layer-wise merged model from the model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, a `torch.Tensor` of shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.merging_weights_load_path is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.init_values,
            )
        else:
            if isinstance(self.merging_weights_load_path, str):
                # load the merging weights from a file
                layer_wise_weight = load_tensor_from_file(
                    self.merging_weights_load_path
                )
            else:
                raise ValueError(
                    f"Unsupported weights format: {self.merging_weights_load_path}"
                )

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.clamp_weights,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.merging_weights_load_path is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.merging_weights_save_path is not None:
                self.save_merging_weights(
                    self.merging_weights_save_path, module.merge_weight
                )
            return module.merge_and_unload()

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of the test dataset used for test-time adaptation; labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        dataloader_kwargs = dict(self.dataloader_kwargs)
        dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

        dataset = self.modelpool.load_test_dataset(task)
        loader = DataLoader(dataset, **dataloader_kwargs)

        if self.fabric is not None:
            loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def compute_logits(
        self,
        module: Union[T5ForConditionalGeneration, LayerWiseMergedModel],
        batch,
        task: str,
    ) -> Tensor:
        """
        Compute the logits for a batch from the given task.

        Args:
            module: The (merged) T5 model.
            batch: A batch dictionary containing `input_ids` and `attention_mask`.
            task (str): The name of the task.

        Returns:
            Tensor: The logits for the first decoded token, of shape `(batch_size, vocab_size)`.
        """
        input_ids: Tensor = batch["input_ids"]
        attention_mask: Tensor = batch["attention_mask"]

        # remove padding tokens from the input
        while attention_mask[:, -1].eq(0).all():
            input_ids = input_ids[:, :-1]
            attention_mask = attention_mask[:, :-1]

        outputs = module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=torch.ones(
                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
            ),
        )
        logits = outputs.logits[:, 0, :]
        return logits

    def on_test_time_adaptation_start(self):
        """
        Hook executed before test-time adaptation starts; no setup is required for this algorithm.
        """
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        optimizer = instantiate(self._optimizer, [module.merge_weight])
        module, optimizer = self.fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            if self.variant == "mgda":
                total_loss = self._compute_gradients_using_mgda(module)
            else:
                total_loss = 0
                for task in self.modelpool.model_names:
                    with self.profile("data loading"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        logits = logits.mean(dim=0, keepdim=True)
                        loss = entropy_loss(logits)
                        total_loss += loss
                    with self.profile("backward pass"):
                        self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": total_loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module

    def _compute_gradients_using_mgda(self, module: LayerWiseMergedModel):
        all_grads = []
        total_loss = 0
        # default behavior for first-order optimizers
        for task in self.modelpool.model_names:
            with self.profile("data loading"):
                batch = next(self.get_shuffled_test_loader_iter(task))
            with self.profile("forward pass"):
                logits = self.compute_logits(module, batch, task)
                logits = logits.mean(dim=0, keepdim=True)
                loss = entropy_loss(logits)
                total_loss += loss
            with self.profile("backward pass"):
                # self.fabric.backward(loss, retain_graph=True)
                _grads = torch.autograd.grad(
                    loss,
                    [module.merge_weight],
                    create_graph=False,
                    retain_graph=True,
                )
                all_grads.append(_grads[0].flatten().detach())
        sol, min_norm = MinNormSolver.find_min_norm_element(all_grads)
        if not isinstance(sol, torch.Tensor):
            sol = torch.from_numpy(sol)
        sol = sol.to(
            device=module.merge_weight.device,
            dtype=module.merge_weight.dtype,
        )
        grad = torch.stack(all_grads) * sol.view(-1, 1)
        module.merge_weight.grad = grad.sum(dim=0).view_as(module.merge_weight)
        return total_loss
compute_logits(module, batch, task)

Compute the logits for a batch from the given task.

Parameters:

  • module (Union[T5ForConditionalGeneration, LayerWiseMergedModel]) –

    The (merged) T5 model.

  • batch –

    A batch dictionary containing input_ids and attention_mask.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The logits for the first decoded token, of shape (batch_size, vocab_size).

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
def compute_logits(
    self,
    module: Union[T5ForConditionalGeneration, LayerWiseMergedModel],
    batch,
    task: str,
) -> Tensor:
    """
    Compute the logits for a batch from the given task.

    Args:
        module: The (merged) T5 model.
        batch: A batch dictionary containing `input_ids` and `attention_mask`.
        task (str): The name of the task.

    Returns:
        Tensor: The logits for the first decoded token, of shape `(batch_size, vocab_size)`.
    """
    input_ids: Tensor = batch["input_ids"]
    attention_mask: Tensor = batch["attention_mask"]

    # remove padding tokens from the input
    while attention_mask[:, -1].eq(0).all():
        input_ids = input_ids[:, :-1]
        attention_mask = attention_mask[:, :-1]

    outputs = module(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=torch.ones(
            input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
        ),
    )
    logits = outputs.logits[:, 0, :]
    return logits
construct_layer_wise_merged_model(modelpool)

Constructs a wrapped layer-wise merged model from the model pool.

This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models. The merging is controlled by layer-wise weights, a torch.Tensor of shape (num_models, num_layers). The merging weights can be initialized based on a provided configuration or loaded from a file.

Parameters:

  • modelpool (ModelPool) –

    An object containing the pretrained model and fine-tuned models to be merged.

Returns:

  • LayerWiseMergedModel

    An instance of the merged model with layer-wise weights applied.

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: Seq2SeqLMPool):
    """
    Constructs a wrapped layer-wise merged model from the model pool.

    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, a `torch.Tensor` of shape `(num_models, num_layers)`.
    The merging weights can be initialized based on a provided configuration or loaded from a file.

    Args:
        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

    Returns:
        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
    """
    pretrained_model = modelpool.load_model("_pretrained_")
    finetuned_models = [
        modelpool.load_model(name) for name in modelpool.model_names
    ]

    # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
    if self.merging_weights_load_path is None:
        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(
                    filter(lambda p: p.requires_grad, pretrained_model.parameters())
                )
            ),
            init_values=self.init_values,
        )
    else:
        if isinstance(self.merging_weights_load_path, str):
            # load the merging weights from a file
            layer_wise_weight = load_tensor_from_file(
                self.merging_weights_load_path
            )
        else:
            raise ValueError(
                f"Unsupported weights format: {self.merging_weights_load_path}"
            )

    module = LayerWiseMergedModel(
        layer_wise_weight=layer_wise_weight,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=self.clamp_weights,
        tie_weights=self.tie_weights,
        strict=self.strict,
    )
    print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
    return module
get_shuffled_test_loader_iter(task) cached

Loader of the test dataset used for test-time adaptation; labels are not needed.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • DataLoader ( DataLoader ) –

    The data loader for the test dataset.

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
@functools.cache
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Loader of the test dataset used for test-time adaptation; labels are not needed.

    Args:
        task (str): The name of the task.

    Returns:
        DataLoader: The data loader for the test dataset.
    """
    dataloader_kwargs = dict(self.dataloader_kwargs)
    dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

    dataset = self.modelpool.load_test_dataset(task)
    loader = DataLoader(dataset, **dataloader_kwargs)

    if self.fabric is not None:
        loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
on_test_time_adaptation_start()

Hook executed before test-time adaptation starts; no setup is required for this algorithm.

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Hook executed before test-time adaptation starts; no setup is required for this algorithm.
    """
    pass
run(modelpool, **kwargs)

Run the Layer-Wise AdaMerging Algorithm.

This method constructs the wrapped model and performs test-time adaptation if necessary.

Parameters:

  • modelpool (ModelPool) –

    The model pool containing the pretrained and fine-tuned models.

Returns:

  • LayerWiseMergedModel

    The merged model after test-time adaptation.

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
def run(self, modelpool: Seq2SeqLMPool, **kwargs):
    """
    Run the Layer-Wise AdaMerging Algorithm.

    This method constructs the wrapped model and performs test-time adaptation if necessary.

    Args:
        modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

    Returns:
        LayerWiseMergedModel: The merged model after test-time adaptation.
    """
    log.info("Fusing models using layer-wise adaptive merging.")
    self.modelpool = modelpool

    with self.profile("construct the wrapped model"):
        module = self.construct_layer_wise_merged_model(modelpool)

    if self.merging_weights_load_path is not None:
        # skip the test-time adaptation
        return module.merge_and_unload()
    else:
        with self.profile("test-time adaptation"):
            module = self.test_time_adaptation(module)
        if self.merging_weights_save_path is not None:
            self.save_merging_weights(
                self.merging_weights_save_path, module.merge_weight
            )
        return module.merge_and_unload()
save_merging_weights(file_path, merging_weights)

Save the merging weights to a file.

Parameters:

  • file_path (str) –

    The path to save the merging weights.

  • merging_weights (Tensor) –

    The merging weights to save.

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
@rank_zero_only
def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
    """
    Save the merging weights to a file.

    Args:
        file_path (str): The path to save the merging weights.
        merging_weights (torch.Tensor): The merging weights to save.
    """
    if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
        if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
            # if the file path is not absolute or relative to current working directory, save it in the log directory
            save_path = os.path.join(self.log_dir, file_path)
        else:
            save_path = file_path
        log.info(f"saving merging weights to {save_path}.")
        if os.path.dirname(save_path):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(merging_weights.detach().cpu(), save_path)
test_time_adaptation(module)

Perform test-time adaptation on the merged model.

This method adapts the merging weights during test-time to improve performance.

Parameters:

  • module (LayerWiseMergedModel) –

    The merged model.

Returns:

  • LayerWiseMergedModel

    The adapted merged model.

Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
def test_time_adaptation(self, module: LayerWiseMergedModel):
    """
    Perform test-time adaptation on the merged model.

    This method adapts the merging weights during test-time to improve performance.

    Args:
        module (LayerWiseMergedModel): The merged model.

    Returns:
        LayerWiseMergedModel: The adapted merged model.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    optimizer = instantiate(self._optimizer, [module.merge_weight])
    module, optimizer = self.fabric.setup(module, optimizer)

    module.train()
    module.merge_weights()
    for step_idx in (
        pbar := tqdm(
            range(self.max_steps if not self.is_debug_mode else 1),
            ("[DEBUG MODE] " if self.is_debug_mode else "")
            + "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        )
    ):
        if self.variant == "mgda":
            total_loss = self._compute_gradients_using_mgda(module)
        else:
            total_loss = 0
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    logits = logits.mean(dim=0, keepdim=True)
                    loss = entropy_loss(logits)
                    total_loss += loss
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

        with self.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()
        with self.profile("merging weights"):
            module.merge_weights()

        metrics = {
            "train/loss": total_loss.item(),
            "train/weight_max": module.merge_weight.max().item(),
            "train/weight_min": module.merge_weight.min().item(),
            "train/weight_mean": module.merge_weight.mean().item(),
        }
        self.fabric.log_dict(metrics, step=step_idx)
        pbar.set_postfix(metrics)

    log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
    self.print_profile_summary()
    return module
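
For orientation, a minimal usage sketch of this class is given below. The values are illustrative only, the optimizer config assumes the Hydra-style _target_ convention consumed by instantiate, and modelpool stands in for an already-constructed Seq2SeqLMPool; in practice these objects are usually assembled from FusionBench's YAML configs rather than built by hand.

from omegaconf import DictConfig

from fusion_bench.method.adamerging.flan_t5_layer_wise_adamerging import (
    FlanT5LayerWiseAdaMergingAlgorithm,
)

# Illustrative values only; real runs are typically driven by Hydra configs.
algorithm = FlanT5LayerWiseAdaMergingAlgorithm(
    optimizer=DictConfig({"_target_": "torch.optim.Adam", "lr": 1e-3}),
    dataloader_kwargs=DictConfig({"batch_size": 16, "num_workers": 2}),
    init_values=0.3,   # initial value for every layer-wise merging weight
    max_steps=1000,    # number of test-time adaptation steps
    merging_weights_save_path="merging_weights.pt",
)

# `modelpool` must provide a "_pretrained_" model, the fine-tuned models,
# and the unlabeled test datasets used for adaptation.
merged_model = algorithm.run(modelpool)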

Optimization-based Methods

RegMean

RegMeanAlgorithmForCLIP

Bases: RegMeanAlgorithm, CLIPClassificationMixin

Source code in fusion_bench/method/regmean/clip_regmean.py
class RegMeanAlgorithmForCLIP(
    RegMeanAlgorithm,
    CLIPClassificationMixin,
):
    _config_mapping = {
        "_dataloader_kwargs": "dataloader_kwargs",
    }

    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
        super().__init__(**kwargs)
        self._dataloader_kwargs = dataloader_kwargs

    def on_regmean_start(self):
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def get_regmean_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset: torch.utils.data.Dataset,
        linear_modules_to_merge: Dict[str, Module],
    ):
        # setup dataloader
        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
        train_dataloader = DataLoader(
            train_dataset, shuffle=True, **self._dataloader_kwargs
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)

        def compute_regmean_weights(module_name: str):
            """
            compute the regmean weights, a hook function to deal with each module's input
            :param module_name: str, module name
            :return:
            """

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # Tensor, shape (batch_size, sequence_length, hidden_dim)
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # Tensor, shape (batch_size * sequence_length, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # Tensor, shape (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                # store the averaged weights in regmean_weights
                if module_name not in regmean_weights.keys():
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    regmean_weights[module_name] = (
                        regmean_weights[module_name]
                        * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = []
        # dictionary, regmean matrices for each linear module inputs
        regmean_weights = {}
        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
        num_computed_examples = {}
        # dictionary, number of actual examples used for computing regmean matrices
        num_actual_examples = {}

        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
            # register a hook in the forward process
            handle = linear_module_to_merge.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            handles.append(handle)
        for step, batch in tqdm(
            enumerate(train_dataloader),
            desc=f"computing regmean weights for model {model_name}",
        ):
            if (
                len(num_actual_examples) > 0
                and list(num_actual_examples.values())[0] >= self.num_regmean_examples
            ):
                break
            logits = self.compute_logits(model, batch, model_name)  # noqa: F841

        # remove the added hook
        for handle in handles:
            handle.remove()

        for module_name in regmean_weights.keys():
            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()

        return regmean_weights
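
The forward hook above accumulates, for every linear module, a running average of the Gram matrix of its inputs, G_i = X_i^T X_i. The closed-form merge performed by the parent RegMeanAlgorithm (not shown on this page) then roughly solves

$$
W_{\text{merged}} = \Bigl(\sum_{i} X_i^{\top} X_i\Bigr)^{-1} \sum_{i} X_i^{\top} X_i \, W_i ,
$$

usually with the off-diagonal entries of each Gram matrix scaled down for regularization. A small self-contained sketch of this closed form (an illustrative helper, not a fusion_bench API):

import torch

def regmean_merge_linear(grams, weights, eps=1e-6):
    """Merge nn.Linear weights given per-model input Gram matrices.

    grams:   list of (d_in, d_in) tensors, X_i^T X_i
    weights: list of (d_out, d_in) tensors (nn.Linear convention: y = x W^T + b)
    """
    d_in = grams[0].shape[0]
    # a small ridge term keeps the summed Gram matrix invertible
    gram_sum = sum(grams) + eps * torch.eye(d_in)
    gram_weight_sum = sum(g @ w.t() for g, w in zip(grams, weights))
    merged = torch.linalg.solve(gram_sum, gram_weight_sum)  # shape (d_in, d_out)
    return merged.t()  # back to the (d_out, d_in) layout of nn.Linear.weight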

RegMeanAlgorithmForGPT2

Bases: RegMeanAlgorithm, LightningFabricMixin

Source code in fusion_bench/method/regmean/gpt2_regmean.py
class RegMeanAlgorithmForGPT2(
    RegMeanAlgorithm,
    LightningFabricMixin,
):
    _include_module_type = [Conv1D]
    classifiers = {}
    _config_mapping = RegMeanAlgorithm._config_mapping | {
        "cache_dir": "cache_dir",
        "batch_size": "batch_size",
        "num_workers": "num_workers",
    }

    def __init__(self, cache_dir: str, batch_size: int, num_workers: int, **kwargs):
        self.cache_dir = cache_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        super().__init__(**kwargs)

    def on_regmean_start(self):
        for model_name in self.modelpool.model_names:
            classifier = cast(
                GPT2ForSequenceClassification,
                self.modelpool.load_classifier(model_name),
            ).requires_grad_(False)
            classifier.transformer = None
            classifier = classifier.to(self.fabric.device)
            self.classifiers[model_name] = classifier

    def compute_logits(self, module: GPT2Model, batch, task: str) -> Tensor:
        self.classifiers[task].transformer = module
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        outputs = self.classifiers[task](input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        assert logits.dim() == 2
        return logits

    def get_regmean_weights(
        self,
        model_name: str,
        model: Module,
        train_dataset,
        linear_modules_to_merge: Dict[str, Module],
    ):
        # setup dataloader
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            collate_fn=default_data_collator,
            pin_memory=True,
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)

        def compute_regmean_weights(module_name: str):
            """
            compute the regmean weights, a hook function to deal with each module's input
            :param module_name: str, module name
            :return:
            """

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # Tensor, shape (batch_size, sequence_length, hidden_dim)
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # Tensor, shape (batch_size * sequence_length, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # Tensor, shape (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                # store the averaged weights in regmean_weights
                if module_name not in regmean_weights.keys():
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    regmean_weights[module_name] = (
                        regmean_weights[module_name]
                        * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = []
        # dictionary, regmean matrices for each linear module inputs
        regmean_weights = {}
        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
        num_computed_examples = {}
        # dictionary, number of actual examples used for computing regmean matrices
        num_actual_examples = {}

        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
            # register a hook in the forward process
            handle = linear_module_to_merge.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            handles.append(handle)
        for step, batch in tqdm(
            enumerate(train_dataloader),
            desc=f"computing regmean weights for model {model_name}",
        ):
            if (
                len(num_actual_examples) > 0
                and list(num_actual_examples.values())[0]
                >= self.config.num_regmean_examples
            ):
                break
            logits = self.compute_logits(model, batch, model_name)

        # remove the added hook
        for handle in handles:
            handle.remove()

        for module_name in regmean_weights.keys():
            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()

        return regmean_weights

RegMean++

RegMeanAlgorithmForCLIPPlusPlus

Bases: RegMeanAlgorithmPlusPlus, CLIPClassificationMixin

Source code in fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py
class RegMeanAlgorithmForCLIPPlusPlus(
    RegMeanAlgorithmPlusPlus,
    CLIPClassificationMixin,
):
    _config_mapping = {
        "_dataloader_kwargs": "dataloader_kwargs",
    }

    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
        super().__init__(**kwargs)
        self._dataloader_kwargs = dataloader_kwargs

    def on_regmean_start(self):
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def get_regmean_weights(
        self,
        model_name: str,
        layer: Module,
        batches_input: List[Tensor],
        linear_modules_to_merge: Dict[str, Module],
    ):
        layer = self.fabric.setup(layer)

        def compute_regmean_weights(module_name: str):
            """
            compute the regmean weights, a hook function to deal with each module's input
            :param module_name: str, module name
            :return:
            """

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # Tensor, shape (batch_size, sequence_length, hidden_dim)
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # Tensor, shape (batch_size * sequence_length, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # Tensor, shape (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                # store the averaged weights in regmean_weights
                if module_name not in regmean_weights.keys():
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    regmean_weights[module_name] = (
                        regmean_weights[module_name]
                        * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = []
        # dictionary, regmean matrices for each linear module inputs
        regmean_weights = {}
        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
        num_computed_examples = {}
        # dictionary, number of actual examples used for computing regmean matrices
        num_actual_examples = {}

        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
            # register a hook in the forward process
            handle = linear_module_to_merge.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            handles.append(handle)
        _ = self.layer_batches_forward(layer, batches_input)

        # remove the added hook
        for handle in handles:
            handle.remove()

        for module_name in regmean_weights.keys():
            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()

        return regmean_weights

    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
        models_to_merge_param_dict = defaultdict(list)

        # get the parameters of the embedding layer from each model
        for model_to_merge in models_to_merge_dict.values():
            model_to_merge_state_dict = model_to_merge.state_dict()

            param_dict = {}
            for name, param in model_to_merge_state_dict.items():
                if name.startswith("vision_model.embeddings") or name.startswith("vision_model.pre_layrnorm"):
                    param_dict[name] = param

            for param_name in param_dict.keys():
                models_to_merge_param_dict[param_name].append(
                    param_dict[param_name]
                )

        # merge the parameters of the embedding layer
        merged_params_dict = {}
        for param_name, param_list in models_to_merge_param_dict.items():
            merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)

        return merged_params_dict


    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
        # setup dataloader
        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
        train_dataloader = DataLoader(
            train_dataset, shuffle=True, **self._dataloader_kwargs
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)

        def compute_input(model, batch):
            images, _ = batch

            images = images.to(model.device)
            image_embeds = model.vision_model.embeddings(images)
            image_embeds = model.vision_model.pre_layrnorm(image_embeds)
            image_embeds = image_embeds.detach().cpu()

            return image_embeds

        num_computed_examples = 0
        num_regmean_examples = self.num_regmean_examples

        batches_input = []
        for batch in train_dataloader:
            if num_computed_examples >= num_regmean_examples:
                break
            batches_input.append(compute_input(model, batch))
            num_computed_examples += batch[0].size(0)

        return batches_input

    def get_layers(self, model: nn.Module):
        return model.vision_model.encoder.layers

    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
        for key, value in new_merged_params.items():
            key = f"vision_model.encoder.layers.{layer_idx}.{key}"
            merged_params_dict[key] = value

        return merged_params_dict

    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]) -> List[Tensor]:
        batches_output = []
        for batch in batches_input:
            device = next(layer.parameters()).device
            batch = batch.to(device)
            logits = layer(batch, attention_mask=None, causal_attention_mask=None)[0].detach().cpu()
            batches_output.append(logits)
        return batches_output
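
The Gram matrices collected by these forward hooks feed the RegMean closed-form solution for each linear layer: the merged weight is the least-squares fit that reproduces every model's outputs on its own inputs. Below is a minimal sketch of that final step for a single nn.Linear weight; regmean_merge_linear is a hypothetical helper written for illustration, not the exact API of this file.

import torch

def regmean_merge_linear(weights, grams):
    """Sketch of the RegMean closed-form merge for one linear layer (illustrative).

    weights: list of (out_features, in_features) weight tensors, one per model.
    grams:   list of (in_features, in_features) Gram matrices X^T X, as
             accumulated by the hooks above, one per model.
    """
    # sum_i G_i and sum_i G_i W_i^T
    sum_gram = sum(grams)
    sum_gram_w = sum(g @ w.t() for g, w in zip(grams, weights))
    # W*^T = (sum_i G_i)^{-1} (sum_i G_i W_i^T); use a solve instead of forming the inverse
    merged_wt = torch.linalg.solve(sum_gram, sum_gram_w)
    return merged_wt.t()

In practice, RegMean implementations typically also shrink the off-diagonal entries of each Gram matrix (scaling them by a factor close to 1) so the linear system stays well conditioned; that detail is omitted from the sketch above.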

Frank-Wolfe Merging

FrankWolfeSoftAlgorithm

Bases: CLIPClassificationMixin, ModelFusionAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/fw_merging/fw_soft.py
class FrankWolfeSoftAlgorithm(
    CLIPClassificationMixin,
    ModelFusionAlgorithm,
    SimpleProfilerMixin,
):
    def __init__(
        self,
        max_iters: int,
        dataset_size: int,
        ada_iters: int,
        ada_coeff: float,
        merge_fn: str,
        granularity: str = "task",
        max_num_models: int = 100,
        step_size: float = 0.3,
        tasks: List[str] = [],
        init_weight: str = "",
        ada_loss="entropy_loss",
        **kwargs,
    ):
        """
        Initializes the FrankWolfeSoftAlgorithm with the given hyperparameters.

        Args:
            step_size (float): The factor by which the task vectors will be scaled before merging.
        """
        self.merge_fn = merge_fn

        self.init_weight = init_weight
        self.max_iters = max_iters
        self.ada_iters = ada_iters
        self.ada_coeff = ada_coeff
        self.granularity = granularity
        self.tasks = tasks
        self.step_size = step_size
        self.dataset_size = dataset_size
        self.max_num_models = max_num_models
        self.ada_loss = ada_loss
        super().__init__(**kwargs)

    def on_frank_wolfe_iteration_start(self):
        self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_train_loader_iter(self, task: str, batch_size: int = 1):
        # get dataloader kwargs
        dataloader_kwargs = self._dataloader_kwargs.copy()
        dataloader_kwargs["shuffle"] = True
        dataloader_kwargs["batch_size"] = batch_size

        # get the test dataset
        clip_dataset = CLIPDataset(
            self.modelpool.load_train_dataset(task), self.clip_processor
        )
        # create the dataloader
        loader = DataLoader(clip_dataset, **dataloader_kwargs)
        loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str, batch_size: int = 1):
        return super().get_shuffled_test_loader_iter(task, batch_size=batch_size)

    def run_adamerging(self, module):
        use_entropy_loss = self.ada_loss == "entropy_loss"

        optimizer = torch.optim.Adam([module.merge_weight], lr=1e-3)
        module, optimizer = self.fabric.setup(module, optimizer)
        module.train()
        for step_idx in (
            pbar := tqdm(
                range(self.ada_iters),
                "AdaMerging (2/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            with self.profile("merge weights"):
                module.merge_weights()

            metrics = {}
            total_loss = None
            tasks = self.modelpool.model_names if self.tasks == [] else self.tasks
            if not use_entropy_loss:
                loss_fn = nn.CrossEntropyLoss()
            for task in tasks:
                with self.profile("data loading"):
                    if use_entropy_loss:
                        batch = next(
                            self.get_shuffled_test_loader_iter(task, batch_size=16)
                        )
                    else:
                        batch = next(
                            self.get_shuffled_train_loader_iter(task, batch_size=16)
                        )
                        # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    if use_entropy_loss:
                        loss = entropy_loss(logits)
                    else:
                        loss = loss_fn(logits, batch[1])
                    total_loss = loss if total_loss is None else total_loss + loss

            optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                optimizer.step()

            metrics.update({"train/loss": loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
        return module

    def frank_wolfe_iteration(self, merged_model, task):

        merged_model.train()
        # zero the gradients
        requires_grad_dict = {}
        for name, param in merged_model.named_parameters():
            requires_grad_dict[name] = param.requires_grad
            param.requires_grad = True
            param.grad = None

        loss_fn = nn.CrossEntropyLoss()
        avg_loss = defaultdict(list)
        log.info(f"Processing task {task}")
        for i in range(self.dataset_size):
            with self.profile("data loading"):
                batch = next(self.get_shuffled_train_loader_iter(task))
            with self.profile("forward pass"):
                logits = self.compute_logits(merged_model, batch[0], task)
                loss = loss_fn(logits, batch[1]) / (
                    self.dataset_size * len(self.modelpool.model_names)
                )
            with self.profile("backward pass"):
                loss.backward()
            avg_loss[task].append(loss.item())

        # calculate the loss
        avg_loss = {
            task: sum(losses) / len(losses) for task, losses in avg_loss.items()
        }
        log.info(
            f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}"
        )

        gradients = {
            name: param.grad.clone().to("cpu")
            for name, param in merged_model.named_parameters()
            if param.requires_grad
        }
        for name, param in merged_model.named_parameters():
            param.requires_grad = requires_grad_dict[name]
            param.grad = None
        merged_model.eval()

        return gradients

    def frank_wolfe_selection(
        self, gradients, checkpoints, model_to_merge_names=[], type="task"
    ):
        assert type in [
            "task",
            "layer",
        ], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']"
        min_inner_product = float("inf")
        min_model = None
        min_model_name = None
        log_dict = {}
        if type == "task":
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                inner_product_sum = 0
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    grad = gradients[param_name]
                    ckpt = model_to_merge[param_name]
                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
                        torch.norm(grad) * torch.norm(ckpt)
                    )
                    inner_product_sum += param_alignment
                log_dict[model_name] = inner_product_sum.item()
                if (
                    inner_product_sum < min_inner_product
                    and model_name not in model_to_merge_names
                ):
                    min_inner_product = inner_product_sum
                    min_model = deepcopy(model_to_merge)
                    min_model_name = model_name
        else:
            min_model = {}
            min_inner_product = {}
            min_idx = {}
            min_model_name = {}
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    grad = gradients[param_name]
                    ckpt = model_to_merge[param_name]
                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
                        torch.norm(grad) * torch.norm(ckpt)
                    )
                    if (
                        param_name not in min_inner_product
                        or param_alignment < min_inner_product[param_name]
                    ) and model_name not in model_to_merge_names[param_name]:
                        min_inner_product[param_name] = param_alignment
                        min_model[param_name] = param_value
                        min_idx[param_name] = model_name
                        min_model_name[param_name] = model_name
            min_inner_product = sum(min_inner_product.values())
            log_dict = {model_name: 0 for model_name in checkpoints.keys()}
            for k in min_idx.values():
                log_dict[k] += 1

        return min_model, min_model_name, min_inner_product, log_dict

    def run(self, modelpool: HuggingFaceClipVisionPool):
        log.info("Fusing models using FW merging.")
        self.modelpool = modelpool
        tasks = self.tasks if self.tasks else self.modelpool.model_names
        self.log_hyperparams(self.config)
        self.on_frank_wolfe_iteration_start()

        assert modelpool.has_pretrained, "Pretrained model is required."
        finetuned_models = {
            name: modelpool.load_model(name)
            for name in modelpool.model_names[: self.max_num_models]
        }

        if self.init_weight == "base" or self.init_weight == "":
            merged_model = modelpool.load_model("_pretrained_")
        else:
            log.info("Initializing the merged model with the initial weight")
            if isinstance(self.init_weight, str):
                # self.init_weight is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.init_weight)
            else:
                raise ValueError(f"Unsupported weights format: {self.init_weight}")

            pretrained_model = modelpool.load_model("_pretrained_")
            layerwise_merged_model = LayerWiseMergedModel(
                layer_wise_weight=layer_wise_weight,
                pretrained_model=pretrained_model,
                finetuned_models=list(finetuned_models.values())[: self.max_num_models],
                clamp_weights=False,
                tie_weights=True,
                strict=False,
            ).cuda()
            merged_model = layerwise_merged_model.merge_and_unload()

        initial_model = modelpool.load_model("_pretrained_")
        self.set_requires_grad(merged_model, initial_model)
        # initial_model.load_state_dict(deepcopy(merged_model.state_dict()))
        # finetuned_models['initial'] = initial_model
        for step_idx in (
            pbar := tqdm(
                range(self.max_iters if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "") + "Frank-Wolfe Merging",
                dynamic_ncols=True,
            )
        ):
            # Find the task vector with the most alignment to the gradient
            models_dict_to_merge = []
            model_to_merge_names = (
                []
                if self.granularity == "task"
                else {name: [] for name in merged_model.state_dict().keys()}
            )
            inner_products = []
            for task in tasks:
                torch.set_grad_enabled(True)
                torch.cuda.empty_cache()
                gradients = self.frank_wolfe_iteration(merged_model.cuda(), task)
                torch.set_grad_enabled(False)
                grad_norm = torch.norm(
                    torch.stack([torch.norm(g) for g in gradients.values()])
                )

                min_model, min_model_name, min_inner_product, log_dict = (
                    self.frank_wolfe_selection(
                        gradients,
                        finetuned_models,
                        model_to_merge_names,
                        type=self.granularity,
                    )
                )
                if self.granularity == "task":
                    model_to_merge_names.append(min_model_name)
                else:
                    for k, v in min_model_name.items():
                        model_to_merge_names[k].append(v)
                models_dict_to_merge.append(min_model)
                inner_products.append(min_inner_product)

                log.info(f"Task: {task}, Inner Products: {log_dict}")
                if (
                    len(models_dict_to_merge) >= len(self.modelpool.model_names)
                    or len(models_dict_to_merge) >= self.max_num_models
                ):
                    log.info(f"Breaking at {len(models_dict_to_merge)}")
                    break

            # print iteration information
            log.info(
                f"Iteration {step_idx+1}, Task Vector: {model_to_merge_names}, Gradient Norm: {grad_norm:.6f}, Inner Products: {inner_products}"
            )

            if self.merge_fn == "adamerging":
                models_to_merge = [
                    modelpool.load_model("_pretrained_")
                    for _ in range(len(models_dict_to_merge))
                ]
                layer_wise_weight = get_layer_wise_weights(
                    num_models=len(models_to_merge),
                    num_layers=len(
                        tuple(
                            filter(
                                lambda p: p.requires_grad,
                                models_to_merge[0].parameters(),
                            )
                        )
                    ),
                    init_values=self.ada_coeff if step_idx > 0 else 0.3,
                )
                for model_to_merge, model_to_merge_dict in zip(
                    models_to_merge, models_dict_to_merge
                ):
                    model_to_merge.load_state_dict(model_to_merge_dict)
                layerwise_merged_model = LayerWiseMergedModel(
                    layer_wise_weight=layer_wise_weight,
                    pretrained_model=merged_model.to("cpu"),
                    finetuned_models=models_to_merge,
                    clamp_weights=False,
                    tie_weights=True,
                    strict=False,
                ).cuda()
                torch.set_grad_enabled(True)
                layerwise_merged_model = self.run_adamerging(layerwise_merged_model)
                torch.set_grad_enabled(False)
                with torch.no_grad():
                    merged_model = layerwise_merged_model.merge_and_unload()
                    self.set_requires_grad(merged_model, initial_model)
                del (
                    models_to_merge,
                    layerwise_merged_model,
                    layer_wise_weight,
                    models_dict_to_merge,
                )
            else:
                step = 2 / (step_idx + 2) * self.step_size if step_idx > 0 else 1
                merged_model = task_arithmetic_merge(
                    merged_model.to("cpu"), models_dict_to_merge, 0.3 * step
                )
                del models_dict_to_merge

        torch.set_grad_enabled(False)
        merged_model = merged_model.cuda().eval()
        return merged_model

    def set_requires_grad(self, merged_model, initial_model):
        for name, param in initial_model.named_parameters():
            for n, p in merged_model.named_parameters():
                if name == n:
                    p.requires_grad = param.requires_grad
__init__(max_iters, dataset_size, ada_iters, ada_coeff, merge_fn, granularity='task', max_num_models=100, step_size=0.3, tasks=[], init_weight='', ada_loss='entropy_loss', **kwargs)

Initializes the FrankWolfeSoftAlgorithm with the given hyperparameters.

Parameters:

  • step_size (float, default: 0.3 ) –

    The factor by which the task vectors will be scaled before merging.

Source code in fusion_bench/method/fw_merging/fw_soft.py
def __init__(
    self,
    max_iters: int,
    dataset_size: int,
    ada_iters: int,
    ada_coeff: float,
    merge_fn: str,
    granularity: str = "task",
    max_num_models: int = 100,
    step_size: float = 0.3,
    tasks: List[str] = [],
    init_weight: str = "",
    ada_loss="entropy_loss",
    **kwargs,
):
    """
    Initializes the FrankWolfeSoftAlgorithm with the given hyperparameters.

    Args:
        step_size (float): The factor by which the task vectors will be scaled before merging.
    """
    self.merge_fn = merge_fn

    self.init_weight = init_weight
    self.max_iters = max_iters
    self.ada_iters = ada_iters
    self.ada_coeff = ada_coeff
    self.granularity = granularity
    self.tasks = tasks
    self.step_size = step_size
    self.dataset_size = dataset_size
    self.max_num_models = max_num_models
    self.ada_loss = ada_loss
    super().__init__(**kwargs)
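
Each outer Frank-Wolfe iteration in the class above (1) evaluates gradients of the task losses at the current merged model, (2) selects, per task or per layer, the checkpoint whose parameters are most anti-aligned with that gradient (the linear minimization oracle), and (3) moves the merged model toward the selection with a diminishing step size. The sketch below illustrates this update rule on plain state dicts, assuming a hypothetical compute_gradients callback; it is not the exact implementation above, which delegates the step to task_arithmetic_merge or AdaMerging.

import torch

def frank_wolfe_merge_sketch(merged_sd, checkpoints, compute_gradients,
                             max_iters=5, step_size=0.3):
    """Illustrative task-level Frank-Wolfe merging loop over state dicts."""
    for t in range(max_iters):
        # dict: parameter name -> gradient of the merging objective (hypothetical helper)
        grads = compute_gradients(merged_sd)

        # Linear minimization oracle: checkpoint most anti-aligned with the gradient
        best_name, best_score = None, float("inf")
        for name, sd in checkpoints.items():
            score = sum(
                torch.dot(grads[k].flatten(), sd[k].flatten())
                / (grads[k].norm() * sd[k].norm() + 1e-12)
                for k in grads
            )
            if score < best_score:
                best_name, best_score = name, score

        # Classical diminishing step size, scaled as in the code above
        gamma = 2 / (t + 2) * step_size
        target = checkpoints[best_name]
        merged_sd = {k: (1 - gamma) * v + gamma * target[k] for k, v in merged_sd.items()}
    return merged_sd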

FrankWolfeHardAlgorithm

Bases: CLIPClassificationMixin, ModelFusionAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/fw_merging/fw_hard.py
class FrankWolfeHardAlgorithm(
    CLIPClassificationMixin,
    ModelFusionAlgorithm,
    SimpleProfilerMixin,
):

    def __init__(
        self,
        merge_fn: str,
        step_size: float,
        max_iters: int,
        dataset_size: int,
        tasks: List[str] = [],
        granularity: str = "task",
        max_num_models: int = 100,
        loss_fn: str = "cross_entropy",
        init_weight: str = "",
        scaling_factor: float = 1.0,
        threshold: int = 20,
        **kwargs,
    ):
        """
        Initializes the FrankWolfeHardAlgorithm with the given hyperparameters.

        Args:
            scaling_factor (float): The factor by which the task vectors will be scaled before merging.
        """
        self.merger = merge_fn
        if merge_fn == "task_arithmetic":
            self.merge_fn = task_arithmetic_merge
        elif merge_fn == "ties":
            self.merge_fn = partial(ties_merge, threshold=threshold)
        # elif merge_fn == "concrete_ta":
        #     self.merge_fn = ConcreteTaskArithmeticAlgorithmForCLIP(
        #         instantiate(OmegaConf.load("config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml"))
        #     )
        else:
            raise ValueError(f"Unsupported merge_fn: {merge_fn}")
        self.scaling_factor = scaling_factor

        self.init_weight = init_weight
        self.step_size = step_size
        self.max_iters = max_iters
        self.granularity = granularity
        self.loss_fn = loss_fn
        self.tasks = tasks
        self.dataset_size = dataset_size
        self.max_num_models = max_num_models
        super().__init__(**kwargs)

    def on_frank_wolfe_iteration_start(self):
        self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_loader_iter(self, task: str):
        if self.loss_fn == "cross_entropy":
            # get dataloader kwargs
            dataloader_kwargs = self._dataloader_kwargs.copy()
            dataloader_kwargs["shuffle"] = True
            dataloader_kwargs["batch_size"] = 1

            # get the test dataset
            clip_dataset = CLIPDataset(
                self.modelpool.load_train_dataset(task), self.clip_processor
            )
            # create the dataloader
            loader = DataLoader(clip_dataset, **dataloader_kwargs)
            loader = self.fabric.setup_dataloaders(loader)
            return iter(InfiniteDataLoader(loader))
        elif self.loss_fn == "entropy":
            return super().get_shuffled_test_loader_iter(
                task,
                batch_size=1,
            )
        else:
            raise ValueError(f"Unsupported loss function: {self.loss_fn}")

    def frank_wolfe_iteration(self, merged_model):

        merged_model.train()
        # zero the gradients
        for name, param in merged_model.named_parameters():
            param.requires_grad = True
            param.grad = None

        if self.loss_fn == "cross_entropy":
            loss_fn = nn.CrossEntropyLoss()
        elif self.loss_fn == "entropy":
            loss_fn = entropy_loss
        avg_loss = defaultdict(list)
        tasks = self.tasks if self.tasks else self.modelpool.model_names
        for task in tasks:
            log.info(f"Processing task {task}")
            for _ in range(self.dataset_size):
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(merged_model, batch[0], task)
                    loss = loss_fn(logits, batch[1]) / (
                        self.dataset_size * len(self.modelpool.model_names)
                    )
                with self.profile("backward pass"):
                    # self.fabric.backward(loss, retain_graph=True)
                    loss.backward()
                avg_loss[task].append(loss.item())

        # calculate the loss
        avg_loss = {
            task: sum(losses) / len(losses) for task, losses in avg_loss.items()
        }
        log.info(
            f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}"
        )

        gradients = {
            name: param.grad.clone().to("cpu")
            for name, param in merged_model.named_parameters()
            if param.requires_grad
        }
        for name, param in merged_model.named_parameters():
            param.grad = None
        merged_model.eval()

        return gradients

    def frank_wolfe_selection(
        self, gradients, checkpoints, model_to_merge_names={}, type="task"
    ):
        assert type in [
            "task",
            "layer",
        ], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']"
        min_inner_product = float("inf")
        min_model = None
        min_model_name = None
        log_dict = {}
        if type == "task":
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                inner_product_sum = 0
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    grad = gradients[param_name]
                    ckpt = model_to_merge[param_name]
                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
                        torch.norm(grad) * torch.norm(ckpt)
                    )
                    inner_product_sum += param_alignment
                log_dict[model_name] = inner_product_sum.item()
                if (
                    inner_product_sum < min_inner_product
                    and model_name not in model_to_merge_names
                ):
                    min_inner_product = inner_product_sum
                    min_model = deepcopy(model_to_merge)
                    min_model_name = model_name
        else:
            min_model = {}
            min_inner_product = {}
            min_idx = {}
            min_model_name = {}
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    grad = gradients[param_name]
                    ckpt = model_to_merge[param_name]
                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
                        torch.norm(grad) * torch.norm(ckpt)
                    )
                    if (
                        param_name not in min_inner_product
                        or param_alignment < min_inner_product[param_name]
                    ) and model_name not in model_to_merge_names[param_name]:
                        min_inner_product[param_name] = param_alignment
                        # if min_inner_product[param_name] < 0:
                        min_model[param_name] = param_value
                        min_idx[param_name] = model_name
                        min_model_name[param_name] = model_name
                        # else:
                        # min_model[param_name] = torch.zeros_like(param_value)
            min_inner_product = sum(min_inner_product.values())
            log_dict = {model_name: 0 for model_name in checkpoints.keys()}
            for k in min_idx.values():
                log_dict[k] += 1

        return min_model, min_model_name, min_inner_product, log_dict

    def run(self, modelpool: HuggingFaceClipVisionPool):
        log.info("Fusing models using FW merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)
        self.on_frank_wolfe_iteration_start()

        assert modelpool.has_pretrained, "Pretrained model is required."
        finetuned_models = {
            name: modelpool.load_model(name)
            for name in modelpool.model_names[: self.max_num_models]
        }
        pretrained_model = modelpool.load_model("_pretrained_")

        if self.init_weight:
            if self.init_weight == "base":
                log.info("Initializing the merged model with the base model")
                merged_model = pretrained_model
            else:
                log.info("Initializing the merged model with the initial weight")
                if isinstance(self.init_weight, str):
                    # self.init_weight is a path to a saved tensor
                    layer_wise_weight = load_tensor_from_file(self.init_weight)
                else:
                    raise ValueError(f"Unsupported weights format: {self.init_weight}")

                merged_model = LayerWiseMergedModel(
                    layer_wise_weight=layer_wise_weight,
                    pretrained_model=modelpool.load_model("_pretrained_"),
                    finetuned_models=list(finetuned_models.values()),
                    clamp_weights=False,
                    tie_weights=True,
                    strict=False,
                ).cuda()
                merged_model = merged_model.merge_and_unload()
        else:
            log.info("Initializing the merged model with merge function")
            merged_model = self.merge_fn(
                pretrained_model=modelpool.load_model("_pretrained_"),
                finetuned_models=list(finetuned_models.values()),
                scaling_factor=self.scaling_factor,
            ).cuda()
        # merged_model = self.fabric.setup(merged_model)

        initial_model = modelpool.load_model("_pretrained_")
        initial_model.load_state_dict(deepcopy(merged_model.state_dict()))
        finetuned_models["initial"] = initial_model
        for step_idx in (
            pbar := tqdm(
                range(self.max_iters if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "") + "Frank-Wolfe Merging",
                dynamic_ncols=True,
            )
        ):
            torch.cuda.empty_cache()
            torch.set_grad_enabled(True)
            gradients = self.frank_wolfe_iteration(merged_model.cuda())
            torch.set_grad_enabled(False)
            grad_norm = torch.norm(
                torch.stack([torch.norm(g) for g in gradients.values()])
            )

            model_to_merge_names = (
                []
                if self.granularity == "task"
                else {name: [] for name in merged_model.state_dict().keys()}
            )
            min_model, min_model_name, min_alignment, chosen_model = (
                self.frank_wolfe_selection(
                    gradients,
                    finetuned_models,
                    model_to_merge_names=model_to_merge_names,
                    type=self.granularity,
                )
            )

            # Determine step size
            step = 2 / (step_idx + 2) * self.step_size

            # print iteration information
            log.info(
                f"Iteration {step_idx+1}, Task Vector: {min_model_name}, Gradient Norm: {grad_norm:.6f}, Inner Products: {min_alignment:.6f}, Chosen Model: {chosen_model}"
            )

            merged_model = self.merge_fn(
                pretrained_model=merged_model.to("cpu"),
                finetuned_models=[min_model],
                scaling_factor=step * self.scaling_factor,
            )

        torch.set_grad_enabled(False)
        merged_model = merged_model.cuda().eval()
        return merged_model
__init__(merge_fn, step_size, max_iters, dataset_size, tasks=[], granularity='task', max_num_models=100, loss_fn='cross_entropy', init_weight='', scaling_factor=1.0, threshold=20, **kwargs)

Initializes the FrankWolfeHardAlgorithm with the given hyperparameters.

Parameters:

  • scaling_factor (float, default: 1.0 ) –

    The factor by which the task vectors will be scaled before merging.

Source code in fusion_bench/method/fw_merging/fw_hard.py
def __init__(
    self,
    merge_fn: str,
    step_size: float,
    max_iters: int,
    dataset_size: int,
    tasks: List[str] = [],
    granularity: str = "task",
    max_num_models: int = 100,
    loss_fn: str = "cross_entropy",
    init_weight: str = "",
    scaling_factor: float = 1.0,
    threshold: int = 20,
    **kwargs,
):
    """
    Initializes the FrankWolfeHardAlgorithm with the given hyperparameters.

    Args:
        scaling_factor (float): The factor by which the task vectors will be scaled before merging.
    """
    self.merger = merge_fn
    if merge_fn == "task_arithmetic":
        self.merge_fn = task_arithmetic_merge
    elif merge_fn == "ties":
        self.merge_fn = partial(ties_merge, threshold=threshold)
    # elif merge_fn == "concrete_ta":
    #     self.merge_fn = ConcreteTaskArithmeticAlgorithmForCLIP(
    #         instantiate(OmegaConf.load("config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml"))
    #     )
    else:
        raise ValueError(f"Unsupported merge_fn: {merge_fn}")
    self.scaling_factor = scaling_factor

    self.init_weight = init_weight
    self.step_size = step_size
    self.max_iters = max_iters
    self.granularity = granularity
    self.loss_fn = loss_fn
    self.tasks = tasks
    self.dataset_size = dataset_size
    self.max_num_models = max_num_models
    super().__init__(**kwargs)
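
The hard variant merges exactly one selected checkpoint per outer iteration, scaling its task vector by the classical diminishing Frank-Wolfe schedule (2 / (step_idx + 2)) * step_size * scaling_factor, so later selections contribute progressively less. A quick illustration with assumed values (step_size = 0.3 and scaling_factor = 1.0, chosen only for this example):

step_size, scaling_factor = 0.3, 1.0  # illustrative values only
for t in range(5):
    gamma = 2 / (t + 2) * step_size
    print(t, round(gamma * scaling_factor, 4))
# prints: 0 0.3, 1 0.2, 2 0.15, 3 0.12, 4 0.1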

Subspace-based Methods

Concrete Subspace

ConcreteTaskArithmeticAlgorithmForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, ModelFusionAlgorithm

ConcreteTaskArithmeticAlgorithmForCLIP is a class for performing task arithmetic on CLIP models with learned masking.

This class extends the CLIPClassificationMixin, SimpleProfilerMixin, and ModelFusionAlgorithm classes. It provides methods for setting up models, training masks, and running the task arithmetic algorithm.

Attributes:

  • merge_dtype (dtype) –

    The data type for merging weights.

  • modelpool (HuggingFaceClipVisionPool) –

    The model pool containing the pretrained and fine-tuned models.

Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
class ConcreteTaskArithmeticAlgorithmForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """
    ConcreteTaskArithmeticAlgorithmForCLIP is a class for performing task arithmetic on CLIP models with learned masking.

    This class extends the CLIPClassificationMixin, SimpleProfilerMixin, and ModelFusionAlgorithm classes.
    It provides methods for setting up models, training masks, and running the task arithmetic algorithm.

    Attributes:
        merge_dtype (torch.dtype): The data type for merging weights.
        modelpool (HuggingFaceClipVisionPool): The model pool containing the pretrained and fine-tuned models.
    """

    @torch.no_grad()
    def setup_models(self):
        """
        Set up the pretrained model, fine-tuned models, and mask model.

        This method loads the pretrained model, constructs the PGE mask model, and loads the fine-tuned models.
        It also creates a wrapped model with task-wise weights.

        Returns:
            Tuple[TaskWiseMergedModel, MaskModel]: The wrapped model and mask model.
        """
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # construct PGE mask model
        mask_model = MaskModel(
            pretrained_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        # for param in mask_model.parameters():
        #     param.data = param + 0.1 * torch.randn_like(param)
        print("Summary of mask model:")
        print_parameters(mask_model)

        # Load the fine-tuned models
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        task_wise_weight = get_task_wise_weights(
            num_models=len(modelpool.model_names),
            init_values=self.config.scaling_factor,
        )

        # create a wrapped model
        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
            task_vector_dtype=self.merge_dtype,
        )

        return module, mask_model

    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
        """
        Train the mask model using the provided module.

        This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.

        Args:
            module (TaskWiseMergedModel): The wrapped model with task-wise weights.
            mask_model (MaskModel): The mask model to be trained.
        """
        config = self.config
        # mask_model: MaskModel = self.fabric.to_device(mask_model)

        # configure optimizer
        lr_scheduler = None
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, mask_model.parameters()),
                lr=self.config.lr,
            )
            print(f"{optimizer=}")
            # TODO: ablation study for the learning rate scheduler. It should yield similar results.
            # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            #     optimizer, self.config.max_steps, eta_min=0.1
            # )
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
        elif self.config.optimizer == "sgd":
            optimizer = torch.optim.SGD(mask_model.parameters(), lr=self.config.lr)
            print(f"{optimizer=}")
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, self.config.max_steps, eta_min=0.1
            )
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.to(mask_model.device)
        module.requires_grad_(False)

        mask_model.train()
        optimizer.zero_grad()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete Task Arithmetic Test-Time Adaptation",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)
                module.merge_weights(task_vector_mask=mask)

            # ------ inner optimization goes here ------
            # NOTE:
            #   Because the algorithmic parameters of task arithmetic are assumed to be chosen on a validation
            #   set, we do not need to perform inner optimization here, so the inner optimization step is skipped.
            # ------------------------------------------

            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0].to(dtype=self.merge_dtype)
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

                if lr_scheduler is not None:
                    lr_scheduler.step()

            metrics.update({"train/loss": loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a symbolic link to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """
        Run the Concrete Task Arithmetic algorithm.

        This method sets up the models, trains the mask model if necessary, and performs the final merging of weights.

        Args:
            modelpool (HuggingFaceClipVisionPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            torch.nn.Module: The final merged model.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            module, mask_model = self.setup_models()
            self.setup_zero_shot_classification_head()

        if config.mask_checkpoint is None:
            if not config.skip_training:
                torch.cuda.empty_cache()
                self.train_mask(module=module, mask_model=mask_model)
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
            model = module.merge_and_unload(mask)
        return model.to(dtype=torch.float32)
run(modelpool)

Run the Concrete Task Arithmetic algorithm.

This method sets up the models, trains the mask model if necessary, and performs the final merging of weights.

Parameters:

  • modelpool (HuggingFaceClipVisionPool) –

    The model pool containing the pretrained and fine-tuned models.

Returns:

  • torch.nn.Module: The final merged model.

Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
def run(self, modelpool: HuggingFaceClipVisionPool):
    """
    Run the Concrete Task Arithmetic algorithm.

    This method sets up the models, trains the mask model if necessary, and performs the final merging of weights.

    Args:
        modelpool (HuggingFaceClipVisionPool): The model pool containing the pretrained and fine-tuned models.

    Returns:
        torch.nn.Module: The final merged model.
    """
    self.modelpool = to_modelpool(modelpool)
    config = self.config
    self.log_hyperparams(config, filename="method_config.yaml")

    with self.profile("setup models"):
        module, mask_model = self.setup_models()
        self.setup_zero_shot_classification_head()

    if config.mask_checkpoint is None:
        if not config.skip_training:
            torch.cuda.empty_cache()
            self.train_mask(module=module, mask_model=mask_model)
    else:
        if self.fabric.is_global_zero:
            print("loading mask from checkpoint", config.mask_checkpoint)
        self.fabric.load(config.mask_checkpoint, {"model": mask_model})

    with torch.no_grad():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        mask = mask_model.sample_mask(
            mask_type=config.eval_mask_type,
            temperature=config.temperature,
        )
        # rescale mask
        for name, m in mask.items():
            mask[name] = m / torch.mean(m)
        model = module.merge_and_unload(mask)
    return model.to(dtype=torch.float32)
setup_models()

Set up the pretrained model, fine-tuned models, and mask model.

This method loads the pretrained model, constructs the PGE mask model, and loads the fine-tuned models. It also creates a wrapped model with task-wise weights.

Returns:

  • Tuple[TaskWiseMergedModel, MaskModel]: The wrapped model and mask model.

Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
@torch.no_grad()
def setup_models(self):
    """
    Set up the pretrained model, fine-tuned models, and mask model.

    This method loads the pretrained model, constructs the PGE mask model, and loads the fine-tuned models.
    It also creates a wrapped model with task-wise weights.

    Returns:
        Tuple[TaskWiseMergedModel, MaskModel]: The wrapped model and mask model.
    """
    config = self.config
    self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
    modelpool = self.modelpool

    # Load the pretrained model
    pretrained_model = modelpool.load_model("_pretrained_")

    # construct PGE mask model
    mask_model = MaskModel(
        pretrained_model,
        ignore_untrained_params=True,
        parameter_type="logits",
    )
    if self.merge_dtype is not None:
        mask_model.to(self.merge_dtype)
    mask_model.fill_(self.config.initial_logits)
    # TODO: ablation study for the initialization of mask model
    # for param in mask_model.parameters():
    #     param.data = param + 0.1 * torch.randn_like(param)
    print("Summary of mask model:")
    print_parameters(mask_model)

    # Load the fine-tuned models
    finetuned_models = [
        modelpool.load_model(name) for name in modelpool.model_names
    ]

    task_wise_weight = get_task_wise_weights(
        num_models=len(modelpool.model_names),
        init_values=self.config.scaling_factor,
    )

    # create a wrapped model
    module = TaskWiseMergedModel(
        task_wise_weight=task_wise_weight,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=self.config.clamp_weights,
        tie_weights=self.config.tie_weights,
        strict=self.config.strict,
        task_vector_dtype=self.merge_dtype,
    )

    return module, mask_model
train_mask(module, mask_model)

Train the mask model using the provided module.

This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.

Parameters:

  • module (TaskWiseMergedModel) –

    The wrapped model with task-wise weights.

  • mask_model (MaskModel) –

    The mask model to be trained.

Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
    """
    Train the mask model using the provided module.

    This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.

    Args:
        module (TaskWiseMergedModel): The wrapped model with task-wise weights.
        mask_model (MaskModel): The mask model to be trained.
    """
    config = self.config
    # mask_model: MaskModel = self.fabric.to_device(mask_model)

    # configure optimizer
    lr_scheduler = None
    if self.config.optimizer == "adam":
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, mask_model.parameters()),
            lr=self.config.lr,
        )
        print(f"{optimizer=}")
        # TODO: ablation study for the learning rate scheduler. It should yield similar results.
        # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     optimizer, self.config.max_steps, eta_min=0.1
        # )
        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
    elif self.config.optimizer == "sgd":
        optimizer = torch.optim.SGD(mask_model.parameters(), lr=self.config.lr)
        print(f"{optimizer=}")
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.config.max_steps, eta_min=0.1
        )
        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
    else:
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    module.to(mask_model.device)
    module.requires_grad_(False)

    mask_model.train()
    optimizer.zero_grad()
    for step_idx in (
        pbar := tqdm(
            range(self.config.max_steps if not self.is_debug_mode else 5),
            ("[DEBUG MODE] " if self.is_debug_mode else "")
            + "Concrete Task Arithmetic Test-Time Adaptation",
            dynamic_ncols=True,
            disable=not self.fabric.is_global_zero,
        )
    ):
        metrics = {}
        # sample a shared mask and merge weights
        with self.profile("sample mask"):
            mask = mask_model.sample_mask(
                mask_type="continuous", temperature=config.temperature
            )
            metrics["train/sparsity"] = mask_sparsity(mask)
        with self.profile("merge weights"):
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
            module.merge_weights(task_vector_mask=mask)

        # ------ inner optimization goes here ------
        # NOTE:
        #   Because the algorithmic parameters of task arithmetic are assumed to be chosen on a validation
        #   set, we do not need to perform inner optimization here, so the inner optimization step is skipped.
        # ------------------------------------------

        total_loss = None
        for task in self.modelpool.model_names:
            with self.profile("data loading"):
                batch = next(self.get_shuffled_test_loader_iter(task))
                # NOTE: The labels are not allowed to be used during test-time adaptation
                images = batch[0].to(dtype=self.merge_dtype)
            with self.profile("forward pass"):
                logits = self.compute_logits(module, images, task)
                loss = entropy_loss(logits)
                total_loss = loss if total_loss is None else total_loss + loss

        with self.profile("compute grad"):
            self.fabric.backward(total_loss)

        with self.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()

            if lr_scheduler is not None:
                lr_scheduler.step()

        metrics.update({"train/loss": loss.item()})
        self.fabric.log_dict(metrics, step=step_idx)
        pbar.set_postfix(metrics)

        if (step_idx + 1) % self.config.save_interval == 0:
            with self.profiler.profile("save checkpoint"):
                save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir, exist_ok=True)
                save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                print(f"saving checkpoint to {save_path}")
                state = {"model": mask_model}
                self.fabric.save(save_path, state)

                # Create or update a symbolic link to the latest checkpoint
                if self.fabric.is_global_zero:
                    symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                    if os.path.exists(symlink_path):
                        os.remove(symlink_path)
                    os.link(os.path.abspath(save_path), symlink_path)

            self.print_profile_summary()
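
Both concrete-subspace variants hinge on the same two steps visible throughout the code above: sample a continuous (concrete / Gumbel-sigmoid) mask from per-parameter logits, then rescale each mask tensor by its mean so the masked task vectors keep roughly the same average magnitude. The sketch below shows that sampling and rescaling in isolation; the plain logits dictionary stands in for MaskModel and is an assumption made for illustration.

import torch

def sample_concrete_mask(logits_dict, temperature=0.5):
    """Relaxed-Bernoulli (concrete) mask sampling with mean rescaling (illustrative)."""
    mask = {}
    for name, logits in logits_dict.items():
        # Gumbel-sigmoid relaxation of a Bernoulli variable with the given logits
        u = torch.rand_like(logits).clamp_(1e-6, 1 - 1e-6)
        logistic_noise = torch.log(u) - torch.log1p(-u)
        m = torch.sigmoid((logits + logistic_noise) / temperature)
        # Rescale to mean 1, mirroring `mask[name] = m / torch.mean(m)` above
        mask[name] = m / m.mean()
    return mask

# Hypothetical usage: one logits tensor per task-vector parameter
logits_dict = {"vision_model.encoder.layers.0.mlp.fc1.weight": torch.zeros(16, 16)}
mask = sample_concrete_mask(logits_dict)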

ConcreteTaskWiseAdaMergingForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, ModelFusionAlgorithm

Source code in fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py
class ConcreteTaskWiseAdaMergingForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    @torch.no_grad()
    def setup_models(self):
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # construct PGE mask model
        mask_model = MaskModel(
            pretrained_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        # for param in mask_model.parameters():
        #     param.data = param + 0.1 * torch.randn_like(param)
        print("Summary of mask model:")
        print_parameters(mask_model)

        # Load the fine-tuned models
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        task_wise_weight = get_task_wise_weights(
            num_models=len(modelpool.model_names),
            init_values=self.config.scaling_factor,
        )
        self.init_task_wise_weight = deepcopy(task_wise_weight)

        # create a wrapped model
        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
            task_vector_dtype=self.merge_dtype,
        )
        return module, mask_model

    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
        config = self.config
        self.init_task_wise_weight = self.to_device(self.init_task_wise_weight)

        # configure optimizer
        lr_scheduler = None
        if self.config.optimizer == "adam":
            base_optimizer = torch.optim.Adam(
                [module.merge_weight], lr=self.config.base_lr
            )
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            print(f"{optimizer=}")
            # TODO: ablation study for the learning rate scheduler. It should yield similar results.
            # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            #     optimizer, self.config.max_steps, eta_min=0.1
            # )
            module, base_optimizer = self.fabric.setup(module, base_optimizer)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        mask_model.train()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete AdaMerging Meta-Learn Mask (1/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)

                # for inner optimization, we do not optimize the mask, so we detach it
                module.merge_weights(
                    task_vector_mask={name: m.detach() for name, m in mask.items()}
                )

            # ------ inner optimization goes here ------
            module.merge_weight.data = deepcopy(self.init_task_wise_weight)
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            base_optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                base_optimizer.step()

            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            # ------------------------------------------

            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()

            metrics.update({"train/loss": total_loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link (os.link) to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()

    def run_adamerging(self, module: TaskWiseMergedModel, mask):
        module.merge_weight.data = deepcopy(self.init_task_wise_weight)
        optimizer = torch.optim.Adam(
            [module.merge_weight], lr=self.config.adamerging_lr
        )
        module, optimizer = self.fabric.setup(module, optimizer)
        module.train()
        for step_idx in (
            pbar := tqdm(
                range(
                    self.config.max_adamerging_steps if not self.is_debug_mode else 5
                ),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete AdaMerging AdaMerging (2/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            step_idx = step_idx + self.config.max_steps
            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            metrics = {}
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                optimizer.step()

            metrics.update({"train/loss": total_loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"merge_weight": module.merge_weight}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link (os.link) to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(
                            save_dir, "merge_weight_latest_checkpoint.pt"
                        )
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()
        return module

    def run(self, modelpool: HuggingFaceClipVisionPool):
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            module, mask_model = self.setup_models()
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            module: TaskWiseMergedModel = self.fabric.to_device(module)
            self.setup_zero_shot_classification_head()

        if config.mask_checkpoint is None:
            self.train_mask(module=module, mask_model=mask_model)
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        # run adamerging
        with torch.no_grad():
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
        module = self.run_adamerging(module, mask=mask)

        with torch.no_grad():
            model = module.merge_and_unload(mask)
        return model

ConcreteLayerWiseAdaMergingForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, ModelFusionAlgorithm

Source code in fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py
class ConcreteLayerWiseAdaMergingForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    @torch.no_grad()
    def setup_models(self):
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # construct PGE mask model
        mask_model = MaskModel(
            pretrained_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        # for param in mask_model.parameters():
        #     param.data = param + 0.1 * torch.randn_like(param)
        print("Summary of mask model:")
        print_parameters(mask_model)

        # Load the fine-tuned models
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(filter(lambda p: p.requires_grad, pretrained_model.parameters()))
            ),
            init_values=self.config.scaling_factor,
        )
        self.init_layer_wise_weight = deepcopy(layer_wise_weight)

        # create a wrapped model
        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
            layer_vector_dtype=self.merge_dtype,
        )
        return module, mask_model

    def train_mask(self, module: LayerWiseMergedModel, mask_model: MaskModel):
        config = self.config
        self.init_layer_wise_weight = self.to_device(self.init_layer_wise_weight)

        # configure optimizer
        lr_scheduler = None
        if self.config.optimizer == "adam":
            base_optimizer = torch.optim.Adam(
                [module.merge_weight], lr=self.config.base_lr
            )
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            print(f"{optimizer=}")
            # TODO: ablation study for the learning rate scheduler. It should yield similar results.
            # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            #     optimizer, self.config.max_steps, eta_min=0.1
            # )
            module, base_optimizer = self.fabric.setup(module, base_optimizer)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        mask_model.train()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete AdaMerging Meta-Learn Mask (1/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)

                # for inner optimization, we do not optimize the mask, so we detach it
                module.merge_weights(
                    task_vector_mask={name: m.detach() for name, m in mask.items()}
                )

            # ------ inner optimization goes here ------
            module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            base_optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                base_optimizer.step()

            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            # ------------------------------------------

            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()

                if lr_scheduler is not None:
                    lr_scheduler.step()

            metrics.update({"train/loss": total_loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link (os.link) to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()

    def run_adamerging(self, module: LayerWiseMergedModel, mask):
        module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
        optimizer = torch.optim.Adam(
            [module.merge_weight], lr=self.config.adamerging_lr
        )
        module, optimizer = self.fabric.setup(module, optimizer)
        module.train()
        for step_idx in (
            pbar := tqdm(
                range(
                    self.config.max_adamerging_steps if not self.is_debug_mode else 5
                ),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete AdaMerging AdaMerging (2/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            step_idx = step_idx + self.config.max_steps
            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            metrics = {}
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                optimizer.step()

            metrics.update({"train/loss": total_loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"merge_weight": module.merge_weight}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link (os.link) to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(
                            save_dir, "merge_weight_latest_checkpoint.pt"
                        )
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()
        return module

    def run(self, modelpool: HuggingFaceClipVisionPool):
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            module, mask_model = self.setup_models()
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            module: LayerWiseMergedModel = self.fabric.to_device(module)
            self.setup_zero_shot_classification_head()

        if config.mask_checkpoint is None:
            self.train_mask(module=module, mask_model=mask_model)
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        # run adamerging
        with torch.no_grad():
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
        module = self.run_adamerging(module, mask=mask)

        with torch.no_grad():
            model = module.merge_and_unload(mask)
        return model

Task Singular Vector Merging (TSVM)

TaskSingularVectorMerging

Bases: BaseAlgorithm, LightningFabricMixin

Task Singular Vector Merging (TSVM) Algorithm

This class implements a model merging technique that leverages Singular Value Decomposition (SVD) to identify and combine the most important directions in the task vector space. The algorithm is particularly effective for merging multiple models fine-tuned on different tasks while preserving their essential capabilities.

Key Concepts:

  • Task Vector: The difference between a fine-tuned model and its pretrained base model, representing the knowledge gained during fine-tuning for a specific task.

  • Singular Value Decomposition: A matrix factorization technique used to find the principal components (most important directions) in the space of task vectors.

  • Model Merging: The process of combining multiple models into a single unified model that retains capabilities from all constituent models.

Algorithm Steps:

  1. Extract task vectors from all fine-tuned models by subtracting the pretrained model
  2. Apply SVD to the matrix of task vectors to find principal components
  3. Reconstruct task vectors using only the most significant singular vectors
  4. Merge the reconstructed task vectors (either individually scaled or as a sum)
  5. Add the final merged task vector to the pretrained model to create the unified model

see docs/algorithms/task_singular_vector.md for comprehensive algorithmic details.
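
For orientation, a minimal usage sketch follows. The model pool construction, the model names, and the _pretrained_ key convention are illustrative placeholders; only the TaskSingularVectorMerging constructor arguments and the run() call come from the code documented below.

# Minimal usage sketch; pool construction and model names are placeholders.
import fusion_bench as fb
from fusion_bench.method.task_singular_vector.TSVM import TaskSingularVectorMerging

# A pool holding the pretrained base model and the fine-tuned models to merge.
modelpool = fb.BaseModelPool(
    {
        "_pretrained_": pretrained_model,  # base model, nn.Module (placeholder)
        "task_a": finetuned_model_a,       # fine-tuned on task A (placeholder)
        "task_b": finetuned_model_b,       # fine-tuned on task B (placeholder)
    }
)

# Scale the merged task vector by 0.8 and keep classifier heads out of the SVD.
algorithm = TaskSingularVectorMerging(
    alpha=0.8,
    exclude_keys=["classifier.weight", "classifier.bias"],
)
merged_model = algorithm.run(modelpool)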

Source code in fusion_bench/method/task_singular_vector/TSVM.py
class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
    """
    Task Singular Vector Merging (TSVM) Algorithm

    This class implements a model merging technique that leverages Singular Value
    Decomposition (SVD) to identify and combine the most important directions in the task vector
    space. The algorithm is particularly effective for merging multiple models fine-tuned on
    different tasks while preserving their essential capabilities.

    Key Concepts:
    - Task Vector: The difference between a fine-tuned model and its pretrained base model,
      representing the knowledge gained during fine-tuning for a specific task.
    - Singular Value Decomposition: A matrix factorization technique used to find the principal
      components (most important directions) in the space of task vectors.
    - Model Merging: The process of combining multiple models into a single unified model that
      retains capabilities from all constituent models.

    Algorithm Steps:
    1. Extract task vectors from all fine-tuned models by subtracting the pretrained model
    2. Apply SVD to the matrix of task vectors to find principal components
    3. Reconstruct task vectors using only the most significant singular vectors
    4. Merge the reconstructed task vectors (either individually scaled or as a sum)
    5. Add the final merged task vector to the pretrained model to create the unified model

    see `docs/algorithms/task_singular_vector.md` for comprehensive algorithmic details.
    """

    def __init__(
        self,
        alpha: Optional[Union[float, Iterable[float]]] = None,
        exclude_keys: Optional[List[str]] = None,
        return_single_task_models: bool = False,
        **kwargs,
    ):
        """
        Initialize the Task Singular Vector Merging algorithm.

        Args:
            alpha (Union[float, Iterable[float]], optional): Scaling factor(s) for task vectors.
                This parameter controls the strength of the task-specific adaptations in the final model.

                - If a single float: Applied to the final merged task vector after SVD reconstruction.
                  This uniformly scales the entire merged adaptation.

                - If an iterable of floats: Applied to individual task vectors before SVD and merging.
                  Must have the same length as the number of models in the modelpool.
                  This allows for task-specific weighting (e.g., giving more importance to certain tasks).

                - If None: No scaling is applied (equivalent to alpha=1.0).

                Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.

            exclude_keys (Optional[List[str]], optional): List of parameter names to exclude from TSVM.
                These parameters will not participate in the SVD computation and merging process.
                Useful for excluding certain layers (e.g., task-specific heads, normalization layers)
                that should not be merged across tasks. Defaults to an empty list.

                Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.

            return_single_task_models (bool, optional): Whether to return individual transformed models.

                - If True: Returns a dictionary containing both individual models with their transformed
                  task vectors applied AND the final merged model. The dictionary has the structure:

                  >>> {'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}

                - If False: Returns only the final merged model.

                This is useful for analysis or when you need access to intermediate results.
                Defaults to False.

            **kwargs: Additional arguments passed to the parent BaseAlgorithm class.

        Note:
            The choice between single alpha vs. list of alphas affects the merging strategy:
            - Single alpha: SVD is applied first, then the result is scaled
            - List of alphas: Individual task vectors are scaled first, then SVD is applied
        """
        self.alpha = alpha
        self.exclude_keys = exclude_keys if exclude_keys is not None else []
        self.return_single_task_models = return_single_task_models
        super().__init__(**kwargs)

    def load_pretrained_model_and_task_vectors(self, modelpool: fb.BaseModelPool):
        """
        Load the pretrained base model and compute task vectors from all fine-tuned models.

        This method performs the initial step of the TSVM algorithm by:
        1. Loading the original pretrained model (before any task-specific fine-tuning)
        2. For each fine-tuned model in the pool:
           - Load the fine-tuned model
           - Compute the task vector (fine-tuned params - pretrained params)
           - Optionally apply individual scaling if alpha is provided as a list

        Task vectors represent the knowledge gained during fine-tuning and are the core
        data structure that TSVM operates on.

        Args:
            modelpool (fb.BaseModelPool): Pool containing the pretrained model and all
                fine-tuned models to be merged.

        Returns:
            tuple: A tuple containing:
                - pretrained_model: The original pretrained model (torch.nn.Module)
                - task_vectors: List of task vectors (List[StateDictType]), where each
                  task vector is a state dictionary representing the parameter differences
                  for one specific task
        """
        # Load the original pretrained model that serves as the base for all fine-tuned variants
        pretrained_model = modelpool.load_pretrained_model()

        # Initialize list to store computed task vectors
        task_vectors = []

        # Process each fine-tuned model in the modelpool
        for model_idx, model_name in enumerate(modelpool.model_names):
            # Load the current fine-tuned model
            finetuned_model = modelpool.load_model(model_name)

            # Compute task vector: difference between fine-tuned and pretrained parameters
            # This captures the task-specific adaptations learned during fine-tuning
            task_vector = state_dict_sub(
                finetuned_model.state_dict(), pretrained_model.state_dict()
            )
            task_vectors.append(task_vector)

            # Apply individual scaling to task vectors if alpha is provided as a list
            # This allows for task-specific weighting before the SVD computation
            if self.alpha is not None and isinstance(self.alpha, Iterable):
                # Ensure the number of alpha values matches the number of models
                assert len(self.alpha) == len(
                    modelpool.model_names
                ), f"Alpha list length ({len(self.alpha)}) must match number of models ({len(modelpool.model_names)})"

                # Scale the current task vector by its corresponding alpha value
                task_vectors[-1] = state_dict_mul(
                    state_dict=task_vectors[-1], scalar=self.alpha[model_idx]
                )

        return pretrained_model, task_vectors

    def run(self, modelpool: fb.BaseModelPool):
        """
        Execute the complete Task Singular Vector Merging algorithm.

        This is the main entry point that orchestrates the entire TSVM process:

        The algorithm leverages the mathematical insight that task vectors often lie in a
        lower-dimensional subspace, and SVD helps identify the most important directions
        in this subspace while filtering out noise and interference.

        Args:
            modelpool (fb.BaseModelPool): Pool of models to merge, including:
                - The pretrained base model
                - Multiple fine-tuned models (one per task)
                All models must have compatible architectures.

        Returns:
            Union[torch.nn.Module, Dict[str, torch.nn.Module]]:
                - If return_single_task_models=False: Returns the merged model
                - If return_single_task_models=True: Returns a dictionary with:
                  * Individual transformed models keyed by their original names
                  * Final merged model under the key 'merged'

        Raises:
            AssertionError: If alpha list length doesn't match the number of models
        """
        # Determine the compute device for SVD operations (GPU if available for faster computation)
        accelerator = self.fabric.device

        # Phase 1: Load pretrained model and compute task vectors from all fine-tuned models
        pretrained_model, task_vectors = self.load_pretrained_model_and_task_vectors(
            modelpool
        )

        # Phase 2: Apply SVD-based merging to the task vectors
        # This is the core of the TSVM algorithm where:
        # - Task vectors are organized into a matrix
        # - SVD finds the principal components (most important directions)
        # - Task vectors are reconstructed using only the most significant components
        # - The reconstructed vectors are merged (summed) to create a unified task vector
        new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
            task_vectors,
            exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
            accelerator=accelerator,  # Use GPU if available
            return_single_task_models=self.return_single_task_models,
        )

        # Handle the case where individual transformed task vectors are also returned
        if self.return_single_task_models:
            new_merged_tv, single_task_models = new_merged_tv

        # Phase 3: Apply global scaling to the merged task vector (if alpha is a single value)
        # This is different from individual scaling applied earlier - here we scale the
        # final merged result, which affects the overall strength of all merged adaptations
        if self.alpha is not None and isinstance(self.alpha, (float, int)):
            print(f"Scaling new merged task vector by alpha: {self.alpha}")
            new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)

        # Phase 4: Prepare individual transformed models if requested
        if self.return_single_task_models:
            models = {}
            # Create individual models by adding each transformed task vector to the pretrained base
            for model_idx, model_name in enumerate(modelpool.model_names):
                # Create a deep copy to avoid modifying the original pretrained model
                model = deepcopy(pretrained_model)
                # Apply the transformed task vector to get the individual model
                model.load_state_dict(
                    state_dict_add(model.state_dict(), single_task_models[model_idx])
                )
                models[model_name] = model

        # Phase 5: Create the final merged model by adding the merged task vector to pretrained model
        # This produces a single model that combines capabilities from all input models
        pretrained_model.load_state_dict(
            state_dict_add(new_merged_tv, pretrained_model.state_dict())
        )

        # Phase 6: Return results based on the requested output format
        if self.return_single_task_models:
            # Include the final merged model in the dictionary of results
            models["merged"] = pretrained_model
            return models
        else:
            # Return only the merged model
            return pretrained_model
__init__(alpha=None, exclude_keys=None, return_single_task_models=False, **kwargs)

Initialize the Task Singular Vector Merging algorithm.

Parameters:

  • alpha (Union[float, Iterable[float]], default: None ) –

    Scaling factor(s) for task vectors. This parameter controls the strength of the task-specific adaptations in the final model.

    • If a single float: Applied to the final merged task vector after SVD reconstruction. This uniformly scales the entire merged adaptation.

    • If an iterable of floats: Applied to individual task vectors before SVD and merging. Must have the same length as the number of models in the modelpool. This allows for task-specific weighting (e.g., giving more importance to certain tasks).

    • If None: No scaling is applied (equivalent to alpha=1.0).

    Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.

  • exclude_keys (Optional[List[str]], default: None ) –

    List of parameter names to exclude from TSVM. These parameters will not participate in the SVD computation and merging process. Useful for excluding certain layers (e.g., task-specific heads, normalization layers) that should not be merged across tasks. Defaults to an empty list.

    Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.

  • return_single_task_models (bool, default: False ) –

    Whether to return individual transformed models.

    • If True: Returns a dictionary containing both individual models with their transformed task vectors applied AND the final merged model. The dictionary has the structure:

    {'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}

    • If False: Returns only the final merged model.

    This is useful for analysis or when you need access to intermediate results. Defaults to False.

  • **kwargs

    Additional arguments passed to the parent BaseAlgorithm class.

Note

The choice between single alpha vs. list of alphas affects the merging strategy:

  • Single alpha: SVD is applied first, then the result is scaled
  • List of alphas: Individual task vectors are scaled first, then SVD is applied
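
As a concrete illustration of the two modes (the numeric values, and the assumption of a three-model pool in the second case, are purely illustrative):

from fusion_bench.method.task_singular_vector.TSVM import TaskSingularVectorMerging

# Single alpha: SVD-based merging is applied first, then the merged task
# vector is scaled once by 0.5 (illustrative value).
algo_global = TaskSingularVectorMerging(alpha=0.5)

# List of alphas: each task vector is scaled before SVD; the list length
# must equal the number of models in the pool (three assumed here).
algo_per_task = TaskSingularVectorMerging(alpha=[0.8, 1.2, 0.5])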

Source code in fusion_bench/method/task_singular_vector/TSVM.py
def __init__(
    self,
    alpha: Optional[Union[float, Iterable[float]]] = None,
    exclude_keys: Optional[List[str]] = None,
    return_single_task_models: bool = False,
    **kwargs,
):
    """
    Initialize the Task Singular Vector Merging algorithm.

    Args:
        alpha (Union[float, Iterable[float]], optional): Scaling factor(s) for task vectors.
            This parameter controls the strength of the task-specific adaptations in the final model.

            - If a single float: Applied to the final merged task vector after SVD reconstruction.
              This uniformly scales the entire merged adaptation.

            - If an iterable of floats: Applied to individual task vectors before SVD and merging.
              Must have the same length as the number of models in the modelpool.
              This allows for task-specific weighting (e.g., giving more importance to certain tasks).

            - If None: No scaling is applied (equivalent to alpha=1.0).

            Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.

        exclude_keys (Optional[List[str]], optional): List of parameter names to exclude from TSVM.
            These parameters will not participate in the SVD computation and merging process.
            Useful for excluding certain layers (e.g., task-specific heads, normalization layers)
            that should not be merged across tasks. Defaults to an empty list.

            Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.

        return_single_task_models (bool, optional): Whether to return individual transformed models.

            - If True: Returns a dictionary containing both individual models with their transformed
              task vectors applied AND the final merged model. The dictionary has the structure:

              >>> {'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}

            - If False: Returns only the final merged model.

            This is useful for analysis or when you need access to intermediate results.
            Defaults to False.

        **kwargs: Additional arguments passed to the parent BaseAlgorithm class.

    Note:
        The choice between single alpha vs. list of alphas affects the merging strategy:
        - Single alpha: SVD is applied first, then the result is scaled
        - List of alphas: Individual task vectors are scaled first, then SVD is applied
    """
    self.alpha = alpha
    self.exclude_keys = exclude_keys if exclude_keys is not None else []
    self.return_single_task_models = return_single_task_models
    super().__init__(**kwargs)
load_pretrained_model_and_task_vectors(modelpool)

Load the pretrained base model and compute task vectors from all fine-tuned models.

This method performs the initial step of the TSVM algorithm by:

  1. Loading the original pretrained model (before any task-specific fine-tuning)
  2. For each fine-tuned model in the pool:
     • Load the fine-tuned model
     • Compute the task vector (fine-tuned params - pretrained params)
     • Optionally apply individual scaling if alpha is provided as a list

Task vectors represent the knowledge gained during fine-tuning and are the core data structure that TSVM operates on.

Parameters:

  • modelpool (BaseModelPool) –

    Pool containing the pretrained model and all fine-tuned models to be merged.

Returns:

  • tuple

    A tuple containing:

    • pretrained_model: The original pretrained model (torch.nn.Module)
    • task_vectors: List of task vectors (List[StateDictType]), where each task vector is a state dictionary representing the parameter differences for one specific task
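
Schematically, and assuming the algorithm and modelpool objects from the usage sketch above, the returned values can be inspected as follows (an illustrative sketch, not part of the documented API beyond the method itself):

pretrained_model, task_vectors = algorithm.load_pretrained_model_and_task_vectors(modelpool)

# Each task vector is a plain state dict of parameter differences
# (finetuned_param - pretrained_param), already scaled by alpha[i] when
# alpha was passed as a list.
first_tv = task_vectors[0]
print(list(first_tv.keys())[:3])  # a few parameter names shared with the base model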

Source code in fusion_bench/method/task_singular_vector/TSVM.py
def load_pretrained_model_and_task_vectors(self, modelpool: fb.BaseModelPool):
    """
    Load the pretrained base model and compute task vectors from all fine-tuned models.

    This method performs the initial step of the TSVM algorithm by:
    1. Loading the original pretrained model (before any task-specific fine-tuning)
    2. For each fine-tuned model in the pool:
       - Load the fine-tuned model
       - Compute the task vector (fine-tuned params - pretrained params)
       - Optionally apply individual scaling if alpha is provided as a list

    Task vectors represent the knowledge gained during fine-tuning and are the core
    data structure that TSVM operates on.

    Args:
        modelpool (fb.BaseModelPool): Pool containing the pretrained model and all
            fine-tuned models to be merged.

    Returns:
        tuple: A tuple containing:
            - pretrained_model: The original pretrained model (torch.nn.Module)
            - task_vectors: List of task vectors (List[StateDictType]), where each
              task vector is a state dictionary representing the parameter differences
              for one specific task
    """
    # Load the original pretrained model that serves as the base for all fine-tuned variants
    pretrained_model = modelpool.load_pretrained_model()

    # Initialize list to store computed task vectors
    task_vectors = []

    # Process each fine-tuned model in the modelpool
    for model_idx, model_name in enumerate(modelpool.model_names):
        # Load the current fine-tuned model
        finetuned_model = modelpool.load_model(model_name)

        # Compute task vector: difference between fine-tuned and pretrained parameters
        # This captures the task-specific adaptations learned during fine-tuning
        task_vector = state_dict_sub(
            finetuned_model.state_dict(), pretrained_model.state_dict()
        )
        task_vectors.append(task_vector)

        # Apply individual scaling to task vectors if alpha is provided as a list
        # This allows for task-specific weighting before the SVD computation
        if self.alpha is not None and isinstance(self.alpha, Iterable):
            # Ensure the number of alpha values matches the number of models
            assert len(self.alpha) == len(
                modelpool.model_names
            ), f"Alpha list length ({len(self.alpha)}) must match number of models ({len(modelpool.model_names)})"

            # Scale the current task vector by its corresponding alpha value
            task_vectors[-1] = state_dict_mul(
                state_dict=task_vectors[-1], scalar=self.alpha[model_idx]
            )

    return pretrained_model, task_vectors
run(modelpool)

Execute the complete Task Singular Vector Merging algorithm.

This is the main entry point that orchestrates the entire TSVM process:

The algorithm leverages the mathematical insight that task vectors often lie in a lower-dimensional subspace, and SVD helps identify the most important directions in this subspace while filtering out noise and interference.

Parameters:

  • modelpool (BaseModelPool) –

    Pool of models to merge, including:

    • The pretrained base model
    • Multiple fine-tuned models (one per task)

    All models must have compatible architectures.

Returns:

  • Union[torch.nn.Module, Dict[str, torch.nn.Module]] –

    • If return_single_task_models=False: Returns the merged model
    • If return_single_task_models=True: Returns a dictionary with:
      • Individual transformed models keyed by their original names
      • Final merged model under the key 'merged'

Raises:

  • AssertionError

    If alpha list length doesn't match the number of models
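
When return_single_task_models=True, the returned dictionary can be unpacked as sketched below; "task_a" is a placeholder for whatever model names the pool defines:

from fusion_bench.method.task_singular_vector.TSVM import TaskSingularVectorMerging

algorithm = TaskSingularVectorMerging(return_single_task_models=True)
results = algorithm.run(modelpool)  # modelpool as in the earlier sketch

merged_model = results["merged"]  # the final merged model
task_a_model = results["task_a"]  # pretrained base + transformed task vector (placeholder key)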

Source code in fusion_bench/method/task_singular_vector/TSVM.py
def run(self, modelpool: fb.BaseModelPool):
    """
    Execute the complete Task Singular Vector Merging algorithm.

    This is the main entry point that orchestrates the entire TSVM process:

    The algorithm leverages the mathematical insight that task vectors often lie in a
    lower-dimensional subspace, and SVD helps identify the most important directions
    in this subspace while filtering out noise and interference.

    Args:
        modelpool (fb.BaseModelPool): Pool of models to merge, including:
            - The pretrained base model
            - Multiple fine-tuned models (one per task)
            All models must have compatible architectures.

    Returns:
        Union[torch.nn.Module, Dict[str, torch.nn.Module]]:
            - If return_single_task_models=False: Returns the merged model
            - If return_single_task_models=True: Returns a dictionary with:
              * Individual transformed models keyed by their original names
              * Final merged model under the key 'merged'

    Raises:
        AssertionError: If alpha list length doesn't match the number of models
    """
    # Determine the compute device for SVD operations (GPU if available for faster computation)
    accelerator = self.fabric.device

    # Phase 1: Load pretrained model and compute task vectors from all fine-tuned models
    pretrained_model, task_vectors = self.load_pretrained_model_and_task_vectors(
        modelpool
    )

    # Phase 2: Apply SVD-based merging to the task vectors
    # This is the core of the TSVM algorithm where:
    # - Task vectors are organized into a matrix
    # - SVD finds the principal components (most important directions)
    # - Task vectors are reconstructed using only the most significant components
    # - The reconstructed vectors are merged (summed) to create a unified task vector
    new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
        task_vectors,
        exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
        accelerator=accelerator,  # Use GPU if available
        return_single_task_models=self.return_single_task_models,
    )

    # Handle the case where individual transformed task vectors are also returned
    if self.return_single_task_models:
        new_merged_tv, single_task_models = new_merged_tv

    # Phase 3: Apply global scaling to the merged task vector (if alpha is a single value)
    # This is different from individual scaling applied earlier - here we scale the
    # final merged result, which affects the overall strength of all merged adaptations
    if self.alpha is not None and isinstance(self.alpha, (float, int)):
        print(f"Scaling new merged task vector by alpha: {self.alpha}")
        new_merged_tv = state_dict_mul(state_dict=new_merged_tv, scalar=self.alpha)

    # Phase 4: Prepare individual transformed models if requested
    if self.return_single_task_models:
        models = {}
        # Create individual models by adding each transformed task vector to the pretrained base
        for model_idx, model_name in enumerate(modelpool.model_names):
            # Create a deep copy to avoid modifying the original pretrained model
            model = deepcopy(pretrained_model)
            # Apply the transformed task vector to get the individual model
            model.load_state_dict(
                state_dict_add(model.state_dict(), single_task_models[model_idx])
            )
            models[model_name] = model

    # Phase 5: Create the final merged model by adding the merged task vector to pretrained model
    # This produces a single model that combines capabilities from all input models
    pretrained_model.load_state_dict(
        state_dict_add(new_merged_tv, pretrained_model.state_dict())
    )

    # Phase 6: Return results based on the requested output format
    if self.return_single_task_models:
        # Include the final merged model in the dictionary of results
        models["merged"] = pretrained_model
        return models
    else:
        # Return only the merged model
        return pretrained_model

Isotropic Merging

ISO_C_Merge = IsotropicMergingInCommonSubspace module-attribute

ISO_CTS_Merge = IsotropicMergingInCommonAndTaskSubspace module-attribute

IsotropicMergingInCommonSubspace

Bases: BaseAlgorithm, LightningFabricMixin

Isotropic Merging in Common Subspace (Iso-C)

Source code in fusion_bench/method/isotropic_merging/iso.py
class IsotropicMergingInCommonSubspace(BaseAlgorithm, LightningFabricMixin):
    """
    Isotropic Merging in Common Subspace (Iso-C)
    """

    def __init__(
        self,
        scaling_factor: float,
        exclude_keys: List[str] = None,
    ):
        self.scaling_factor = scaling_factor
        self.exclude_keys = exclude_keys
        super().__init__()

    def run(self, modelpool: BaseModelPool):
        # load the pretrained model and the task vectors of all the finetuned models
        with torch.no_grad():
            pretrained_model = modelpool.load_pretrained_model()
            task_vectors = []
            for model_name in modelpool.model_names:
                finetuned_model = modelpool.load_model(model_name)
                task_vectors.append(
                    state_dict_sub(
                        finetuned_model.state_dict(), pretrained_model.state_dict()
                    )
                )
                del finetuned_model  # free memory
            check_parameterNamesMatch(task_vectors)

        # compute the merged task vector
        merged_tv = iso_c(
            task_vectors,
            accelerator=self.fabric.device,
            exclude_keys=self.exclude_keys,
        )

        # merged_parameters = pretrained_parameters + scaling_factor * merged_task_vector
        pretrained_model.load_state_dict(
            state_dict_add(
                pretrained_model.state_dict(),
                state_dict_mul(merged_tv, self.scaling_factor),
            )
        )

        return pretrained_model

IsotropicMergingInCommonAndTaskSubspace

Bases: BaseAlgorithm, LightningFabricMixin

Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS)

Source code in fusion_bench/method/isotropic_merging/iso.py
class IsotropicMergingInCommonAndTaskSubspace(BaseAlgorithm, LightningFabricMixin):
    """
    Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS)
    """

    def __init__(
        self,
        scaling_factor: float,
        common_space_fraction: float,
        exclude_keys: List[str] = None,
    ):
        self.common_space_fraction = common_space_fraction
        self.scaling_factor = scaling_factor
        self.exclude_keys = exclude_keys
        super().__init__()

    def run(self, modelpool: BaseModelPool):
        # load the pretrained model and the task vectors of all the finetuned models
        with torch.no_grad():
            pretrained_model = modelpool.load_pretrained_model()
            task_vectors = []
            for model_name in modelpool.model_names:
                finetuned_model = modelpool.load_model(model_name)
                task_vectors.append(
                    state_dict_sub(
                        finetuned_model.state_dict(), pretrained_model.state_dict()
                    )
                )
                del finetuned_model  # free memory
            check_parameterNamesMatch(task_vectors)

        # compute the merged task vector
        merged_tv = iso_cts(
            task_vectors,
            common_space_fraction=self.common_space_fraction,
            accelerator=self.fabric.device,
            exclude_keys=self.exclude_keys,
        )

        # merged_parameters = pretrained_parameters + scaling_factor * merged_task_vector
        pretrained_model.load_state_dict(
            state_dict_add(
                pretrained_model.state_dict(),
                state_dict_mul(merged_tv, self.scaling_factor),
            )
        )

        return pretrained_model
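
A brief usage sketch for the two isotropic variants follows. The import path mirrors the source file location shown above; the scaling factor, the common-space fraction, and the reuse of a modelpool object are illustrative assumptions rather than recommended settings.

from fusion_bench.method.isotropic_merging.iso import (
    IsotropicMergingInCommonAndTaskSubspace,
    IsotropicMergingInCommonSubspace,
)

# Iso-C: merge task vectors in the common subspace, then add the merged
# task vector, scaled by scaling_factor, to the pretrained weights.
iso_c_algo = IsotropicMergingInCommonSubspace(scaling_factor=1.0)
merged_c = iso_c_algo.run(modelpool)  # modelpool: a BaseModelPool with a pretrained model (assumed)

# Iso-CTS: additionally reserve a fraction of the subspace for task-specific
# directions (0.8 is an illustrative value).
iso_cts_algo = IsotropicMergingInCommonAndTaskSubspace(
    scaling_factor=1.0,
    common_space_fraction=0.8,
)
merged_cts = iso_cts_algo.run(modelpool)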

Distributed Model Merging

Gossip

CLIPTaskWiseGossipAlgorithm

Bases: TaskWiseGossipAlgorithm

A class for task-wise adaptive merging of CLIP models.

This class extends the TaskWiseGossipAlgorithm to provide specific functionality for CLIP models, including loading datasets, constructing zero-shot classification heads, and computing logits.

Attributes:

  • modelpool (CLIPVisionModelPool) –

    The model pool containing CLIP models.

  • _clip_processor (CLIPProcessor) –

    The CLIP processor for preparing inputs.

  • zeroshot_weights (dict) –

    A dictionary to store zero-shot weights for each task.
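
Before the full source, here is a stripped-down sketch of the zero-shot classification step that this class performs in compute_logits. The function below is illustrative only and assumes the text embeddings (zero-shot class weights) are already L2-normalized.

import torch

def zero_shot_logits(
    image_embeds: torch.Tensor,    # (batch, dim) outputs of the visual projection
    text_embeds: torch.Tensor,     # (num_classes, dim) zero-shot class weights (assumed normalized)
    logit_scale_exp: torch.Tensor,
) -> torch.Tensor:
    # Normalize image embeddings, then compute scaled cosine similarity.
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale_exp
    return logits_per_text.t()     # (batch, num_classes) classification logits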

Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
class CLIPTaskWiseGossipAlgorithm(TaskWiseGossipAlgorithm):
    """
    A class for task-wise adaptive merging of CLIP models.

    This class extends the TaskWiseGossipAlgorithm to provide specific
    functionality for CLIP models, including loading datasets, constructing
    zero-shot classification heads, and computing logits.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
    """

    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.

        Args:
            task (str): The name of the task.

        Returns:
            CLIPDataset: The test dataset for the task.
        """
        log.info(f"Loading test dataset: {task}")
        dataset = self.modelpool.load_test_dataset(task)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        """
        Get an iterator over the shuffled test DataLoader for the task.

        Args:
            task (str): The name of the task.

        Returns:
            iterator: An iterator over the shuffled test DataLoader.
        """
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Prepare for test-time adaptation.

        This method loads the CLIP processor and constructs the zero-shot
        classification head for each task.
        """
        if self._clip_processor is not None and self.zeroshot_weights is not None:
            return  # this can be reused in Gossip

        clip_model_config = self.modelpool.get_model_config("_pretrained_")
        pretrained_path = (
            clip_model_config.pretrained_model_name_or_path
            if hasattr(clip_model_config, "pretrained_model_name_or_path")
            else clip_model_config.path
        )

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale_exp = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(task)
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        This method computes the image embeddings, normalizes them, and calculates
        the cosine similarity with the text embeddings to produce classification logits.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

This method computes the image embeddings, normalizes them, and calculates the cosine similarity with the text embeddings to produce classification logits.

Parameters:

  • module (Module) –

    The model module.

  • batch (tuple) –

    A batch of input data.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The classification logits for the batch.

Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
def compute_logits(self, module, batch, task: str) -> Tensor:
    """
    Compute the logits for the given batch and task.

    This method computes the image embeddings, normalizes them, and calculates
    the cosine similarity with the text embeddings to produce classification logits.

    Args:
        module (nn.Module): The model module.
        batch (tuple): A batch of input data.
        task (str): The name of the task.

    Returns:
        Tensor: The classification logits for the batch.
    """
    images, _ = batch
    text_embeds = self.zeroshot_weights[task]

    image_embeds = module(images)[1]
    image_embeds = self.visual_projection(image_embeds)

    # normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity
    logits_per_text = (
        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
    )
    logits_per_image = logits_per_text.t()

    return logits_per_image
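
To make the shapes concrete: the cached zeroshot_weights for a task form a (num_classes, embed_dim) matrix of text embeddings, and the logits are scaled cosine similarities between those rows and the normalized image embeddings. Below is a minimal, self-contained sketch with dummy tensors; the dimensions and the 100.0 scale are illustrative stand-ins, not values taken from the method above.

import torch

batch_size, embed_dim, num_classes = 4, 512, 10

# Stand-ins for the projected image embeddings and the cached zero-shot text weights.
image_embeds = torch.randn(batch_size, embed_dim)
text_embeds = torch.randn(num_classes, embed_dim)
logit_scale_exp = torch.tensor(100.0)  # plays the role of clip_model.logit_scale.exp()

# L2-normalize so the matrix product below is a cosine similarity.
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale_exp
logits_per_image = logits_per_text.t()
print(logits_per_image.shape)  # torch.Size([4, 10])
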
get_shuffled_test_loader_iter(task) cached

Get an iterator over the shuffled test DataLoader for the task.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • iterator

    An iterator over the shuffled test DataLoader.

Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
@functools.cache
def get_shuffled_test_loader_iter(self, task: str):
    """
    Get an iterator over the shuffled test DataLoader for the task.

    Args:
        task (str): The name of the task.

    Returns:
        iterator: An iterator over the shuffled test DataLoader.
    """
    loader = DataLoader(
        self.get_test_dataset(task),
        batch_size=self.config.batch_size,
        shuffle=True,
        num_workers=self.config.num_workers,
        pin_memory=True,
    )
    if self._fabric is not None:
        loader = self._fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
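
InfiniteDataLoader is used here so that next() can be called on the returned iterator indefinitely during test-time adaptation without hitting StopIteration. The following is a minimal sketch of such a wrapper, written under the assumption that restarting the underlying DataLoader on exhaustion is all it needs to do; the actual fusion_bench utility may differ in details.

from torch.utils.data import DataLoader


class InfiniteIterator:
    """Cycle over a DataLoader indefinitely by restarting it when exhausted."""

    def __init__(self, loader: DataLoader):
        self.loader = loader
        self._it = iter(loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._it)
        except StopIteration:
            # Start a fresh epoch (re-shuffled if shuffle=True) and continue.
            self._it = iter(self.loader)
            return next(self._it)
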
get_test_dataset(task) cached

Load the test dataset for the task. This method is cached, so the dataset is loaded only once.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • CLIPDataset

    The test dataset for the task.

Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
@functools.cache
def get_test_dataset(self, task: str):
    """
    Load the test dataset for the task.
    This method is cached, so the dataset is loaded only once.

    Args:
        task (str): The name of the task.

    Returns:
        CLIPDataset: The test dataset for the task.
    """
    log.info(f"Loading test dataset: {task}")
    dataset = self.modelpool.load_test_dataset(task)
    dataset = CLIPDataset(dataset, self._clip_processor)
    return dataset
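
CLIPDataset couples the raw task dataset with the CLIP processor so that each item yields a preprocessed pixel tensor and its label, which is what compute_logits expects from a batch. A hypothetical minimal wrapper along these lines is shown below; the real CLIPDataset may handle additional sample formats.

import torch
from torch.utils.data import Dataset


class MinimalCLIPDataset(Dataset):
    """Apply a CLIPProcessor to (image, label) pairs on the fly."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # `pixel_values` has shape (1, 3, H, W); drop the leading batch dim.
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs["pixel_values"][0], torch.tensor(label)
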
on_test_time_adaptation_start()

Prepare for test-time adaptation.

This method loads the CLIP processor and constructs the zero-shot classification head for each task.

Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
def on_test_time_adaptation_start(self):
    """
    Prepare for test-time adaptation.

    This method loads the CLIP processor and constructs the zero-shot
    classification head for each task.
    """
    if self._clip_processor is not None and self.zeroshot_weights is not None:
        return  # this can be reused in Gossip

    clip_model_config = self.modelpool.get_model_config("_pretrained_")
    pretrained_path = (
        clip_model_config.pretrained_model_name_or_path
        if hasattr(clip_model_config, "pretrained_model_name_or_path")
        else clip_model_config.path
    )

    with timeit_context("Loading CLIP processor and pretrained CLIP model."):
        self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
        clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
        self.logit_scale_exp = clip_model.logit_scale.exp()
        if self._fabric is not None:
            self.visual_projection = self._fabric.to_device(self.visual_projection)
            self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

    for task in self.modelpool.model_names:
        cache_file = os.path.join(
            self.config.cache_dir,
            f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
        )
        if os.path.exists(cache_file):
            log.info(f"Loading cached zeroshot weights for task: {task}")
            zeroshot_weights = torch.load(cache_file, map_location="cpu")
        else:
            log.info(f"Construct zero shot classification head for task: {task}")
            classnames, templates = get_classnames_and_templates(task)
            clip_classifier.set_classification_task(classnames, templates)
            zeroshot_weights = clip_classifier.zeroshot_weights
            log.info(f"save zeroshot weights to {cache_file}")
            torch.save(zeroshot_weights, cache_file)
        self.zeroshot_weights[task] = zeroshot_weights
        if self._fabric is not None:
            self.zeroshot_weights[task] = self._fabric.to_device(
                self.zeroshot_weights[task]
            )
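
For intuition, constructing a task's zero-shot classification head amounts to encoding each class name under a set of prompt templates with the CLIP text encoder, averaging over templates, and normalizing; the result is then cached to disk so it is computed only once per task. The sketch below illustrates that procedure with the Hugging Face CLIP API; it assumes templates are callables that format a class name into a prompt, and it is a conceptual stand-in rather than the HFCLIPClassifier implementation used above.

import torch
from transformers import CLIPModel, CLIPProcessor


@torch.no_grad()
def build_zeroshot_weights(model: CLIPModel, processor: CLIPProcessor, classnames, templates):
    """Return a (num_classes, embed_dim) matrix of normalized class text embeddings."""
    weights = []
    for classname in classnames:
        prompts = [template(classname) for template in templates]  # assumes callable templates
        inputs = processor(text=prompts, return_tensors="pt", padding=True)
        text_embeds = model.get_text_features(**inputs)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
        class_embed = text_embeds.mean(dim=0)
        weights.append(class_embed / class_embed.norm())
    return torch.stack(weights)
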

CLIPLayerWiseGossipAlgorithm

Bases: CLIPClassificationMixin, LayerWiseGossipAlgorithm

Source code in fusion_bench/method/gossip/clip_layer_wise_gossip.py
class CLIPLayerWiseGossipAlgorithm(
    CLIPClassificationMixin,
    LayerWiseGossipAlgorithm,
):
    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        if self.whether_setup_zero_shot_classification_head == False:
            self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        return super().get_shuffled_test_loader_iter(
            task,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
        )
on_test_time_adaptation_start()

Here we load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/gossip/clip_layer_wise_gossip.py
def on_test_time_adaptation_start(self):
    """
    Here we load the CLIP processor and construct the zero-shot classification head for each task.
    """
    if self.whether_setup_zero_shot_classification_head == False:
        self.setup_zero_shot_classification_head()

FlanT5LayerWiseGossipAlgorithm

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
class FlanT5LayerWiseGossipAlgorithm(
    BaseAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):

    def __init__(
        self,
        optimizer: DictConfig,
        dataloader_kwargs: DictConfig,
        init_values: float,
        max_steps: int,
        merging_weights_load_path: Optional[Union[str, Path]] = None,
        merging_weights_save_path: Optional[Union[str, Path]] = None,
        clamp_weights: bool = False,
        tie_weights: bool = True,
        strict: bool = False,
        cache_dir: str = "outputs/cache",
        variant: Optional[str] = None,
        **kwargs,
    ):
        self._optimizer = optimizer
        self.dataloader_kwargs = dataloader_kwargs
        self.init_values = init_values
        self.merging_weights_load_path = merging_weights_load_path
        self.merging_weights_save_path = merging_weights_save_path
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.max_steps = max_steps
        self.cache_dir = cache_dir
        self.variant = variant

        self.configs = SimpleNamespace(**kwargs)
        self.configs.init_values = init_values
        self.configs.clamp_weights = clamp_weights
        self.configs.tie_weights = tie_weights
        self.configs.strict = strict
        if isinstance(self.configs.accuracy_test_interval, ListConfig):
            self.configs.accuracy_test_interval = list(
                self.configs.accuracy_test_interval
            )
        elif isinstance(self.configs.accuracy_test_interval, int):
            pass
        else:
            log.warning(
                f"Unexpected type of accuracy_test_interval: {type(self.configs.accuracy_test_interval)}"
            )
        super().__init__(**kwargs)

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def free_gpu_memory(self, module: LayerWiseMergedModel):
        module.pretrained_model.to("cpu")
        for model in module.task_vectors:
            model.to("cpu")
        del module
        gc.collect()
        torch.cuda.empty_cache()
        log.info(get_memory_usage("after freeing memory, the memory usage of GPU is:"))

    def update_datasets(self, datasets):
        """
        For every epoch of local AdaMerging, only the datasets corresponding to the models involved in the fusion are used.
        """
        num_datasets = len(datasets)
        datasets_copy = datasets.copy()
        for i in range(num_datasets):
            datasets[i] = (
                datasets_copy[i]
                .union(datasets_copy[(i + 1) % num_datasets])
                .union(datasets_copy[(i - 1) % num_datasets])
            )
        return datasets

    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.num_finetuned_models = len(modelpool.model_names)
        datasets = [{dataset} for dataset in modelpool.model_names]

        with self.profile("construct the wrapped model"):
            model_scheduler = ModelScheduler(self.configs, self.modelpool)

        if self.merging_weights_load_path is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            for step_idx in tqdm(
                range(self.configs.gossip_max_steps),
                "Gossip merging",
                dynamic_ncols=True,
            ):
                datasets = self.update_datasets(datasets)
                log.info(f"Gossip merging step: {step_idx}")
                for model_id in tqdm(
                    range(self.num_finetuned_models),
                    "local admerging",
                    dynamic_ncols=True,
                ):
                    if self.configs.gossip_skip_adamerging == True:
                        # skip adamerging, only merge
                        with self.profile("construct the local wrapped model"):
                            module = model_scheduler(model_id)
                        log.info(
                            f"skip adamerging, only merge ({modelpool.model_names[model_id]})"
                        )
                        model_scheduler.store_model(module.merge_weights(), model_id)
                        self.free_gpu_memory(module)
                    else:
                        with self.profile("construct the local wrapped model"):
                            module = model_scheduler(model_id)

                        if self.configs.improve_dataset == True:
                            log.info(
                                f"improved datasets, the datasets used in this local merging is {datasets[model_id]}"
                            )
                        else:
                            log.info(
                                f"unimproved datasets, the datasets used in this local merging is {modelpool.model_names}"
                            )
                        with self.profile("test-time adaptation"):
                            module = self.test_time_adaptation(
                                module, datasets[model_id]
                            )
                        # if self.configs.get("save_merging_weights", False):
                        #     self.save_merging_weights(
                        #         self.configs.save_merging_weights, module.merge_weight
                        #     )
                        model_scheduler.store_model(module.merge_weights(), model_id)
                        log.info(
                            get_memory_usage(
                                f"after local merging ({modelpool.model_names[model_id]}), the memory usage of GPU is:"
                            )
                        )
                        self.free_gpu_memory(
                            module
                        )  # simulate distributed GPU memory usage as much as possible

                model_scheduler.update_models()
                do_evaluation = False  # whether to do evaluation after each Gossip step
                if isinstance(self.configs.accuracy_test_interval, list):
                    if (step_idx + 1) in self.configs.accuracy_test_interval:
                        do_evaluation = True
                elif isinstance(self.configs.accuracy_test_interval, int):
                    if (
                        self.configs.accuracy_test_interval != 0
                        and (step_idx + 1) % self.configs.accuracy_test_interval == 0
                    ):
                        do_evaluation = True
                if do_evaluation:
                    self._program.evaluate_merged_model(
                        self._program.taskpool, model_scheduler.get_final_models()
                    )
                    model_scheduler.move_to("cpu")

        return model_scheduler.get_final_models()

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of the test dataset for test-time adaptation. Labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        dataloader_kwargs = dict(self.dataloader_kwargs)
        dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

        dataset = self.modelpool.load_test_dataset(task)
        loader = DataLoader(dataset, **dataloader_kwargs)

        if self.fabric is not None:
            loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def compute_logits(
        self,
        module: Union[T5ForConditionalGeneration, LayerWiseMergedModel],
        batch,
        task: str,
    ) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module: The model module.
            batch (dict): The input batch containing "input_ids" and "attention_mask".
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        input_ids: Tensor = batch["input_ids"]
        attention_mask: Tensor = batch["attention_mask"]

        # remove padding tokens from the input
        while attention_mask[:, -1].eq(0).all():
            input_ids = input_ids[:, :-1]
            attention_mask = attention_mask[:, :-1]

        outputs = module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=torch.ones(
                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
            ),
        )
        logits = outputs.logits[:, 0, :]
        return logits

    def on_test_time_adaptation_start(self):
        """
        Hook for work to be done before test-time adaptation starts, such as setting up task-specific heads.
        """
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel, datasets):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        optimizer = instantiate(self._optimizer, [module.merge_weight])
        module, optimizer = self.fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            if self.variant == "mgda":
                total_loss = self._compute_gradients_using_mgda(module)
            else:
                total_loss = 0
                for task in self.modelpool.model_names:
                    with self.profile("data loading"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        logits = logits.mean(dim=0, keepdim=True)
                        loss = entropy_loss(logits)
                        total_loss += loss
                    with self.profile("backward pass"):
                        self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": total_loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        self.print_profile_summary()
        del optimizer
        gc.collect()
        torch.cuda.empty_cache()
        return module

    def _compute_gradients_using_mgda(self, module: LayerWiseMergedModel):
        all_grads = []
        total_loss = 0
        # default behavior for first-order optimizers
        for task in self.modelpool.model_names:
            with self.profile("data loading"):
                batch = next(self.get_shuffled_test_loader_iter(task))
            with self.profile("forward pass"):
                logits = self.compute_logits(module, batch, task)
                logits = logits.mean(dim=0, keepdim=True)
                loss = entropy_loss(logits)
                total_loss += loss
            with self.profile("backward pass"):
                # self.fabric.backward(loss, retain_graph=True)
                _grads = torch.autograd.grad(
                    loss,
                    [module.merge_weight],
                    create_graph=False,
                    retain_graph=True,
                )
                all_grads.append(_grads[0].flatten().detach())
        sol, min_norm = MinNormSolver.find_min_norm_element(all_grads)
        if not isinstance(sol, torch.Tensor):
            sol = torch.from_numpy(sol)
        sol = sol.to(
            device=module.merge_weight.device,
            dtype=module.merge_weight.dtype,
        )
        grad = torch.stack(all_grads) * sol.view(-1, 1)
        module.merge_weight.grad = grad.sum(dim=0).view_as(module.merge_weight)
        return total_loss
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

Parameters:

  • module (Union[T5ForConditionalGeneration, LayerWiseMergedModel]) –

    The model module.

  • batch (dict) –

    The input batch containing input_ids and attention_mask.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
def compute_logits(
    self,
    module: Union[T5ForConditionalGeneration, LayerWiseMergedModel],
    batch,
    task: str,
) -> Tensor:
    """
    Compute the logits for the given batch and task.

    Args:
        module: The model module.
        batch (dict): The input batch containing "input_ids" and "attention_mask".
        task (str): The name of the task.

    Returns:
        Tensor: The computed logits.
    """
    input_ids: Tensor = batch["input_ids"]
    attention_mask: Tensor = batch["attention_mask"]

    # remove padding tokens from the input
    while attention_mask[:, -1].eq(0).all():
        input_ids = input_ids[:, :-1]
        attention_mask = attention_mask[:, :-1]

    outputs = module(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=torch.ones(
            input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
        ),
    )
    logits = outputs.logits[:, 0, :]
    return logits
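
Because Flan-T5 is a seq2seq model, the zero-shot "logits" come from a single decoder step: trailing all-padding columns are trimmed, one constant decoder token is fed in, and the first-position vocabulary logits serve as the distribution whose entropy is minimized. The sketch below exercises that path with a tiny, randomly initialized T5 so it runs without downloading weights; all config values are illustrative, not Flan-T5's.

import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(
    vocab_size=128, d_model=32, d_ff=64, num_layers=2, num_decoder_layers=2, num_heads=2
)
model = T5ForConditionalGeneration(config)

batch = {
    "input_ids": torch.randint(0, 128, (4, 10)),
    "attention_mask": torch.ones(4, 10, dtype=torch.long),
}
batch["attention_mask"][:, -2:] = 0  # pretend the last two columns are padding

input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
# Trim columns that are padding for every example in the batch.
while attention_mask[:, -1].eq(0).all():
    input_ids = input_ids[:, :-1]
    attention_mask = attention_mask[:, :-1]

# One decoder step with a constant start token, as in the method above; the
# first-position vocabulary logits are used for entropy minimization.
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=torch.ones(input_ids.size(0), 1, dtype=torch.long),
)
logits = outputs.logits[:, 0, :]
print(logits.shape)  # torch.Size([4, 128])
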
get_shuffled_test_loader_iter(task) cached

Loader of the test dataset for test-time adaptation. Labels are not needed.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • DataLoader ( DataLoader ) –

    The data loader for the test dataset.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
@functools.cache
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Loader of the test dataset for test-time adaptation. Labels are not needed.

    Args:
        task (str): The name of the task.

    Returns:
        DataLoader: The data loader for the test dataset.
    """
    dataloader_kwargs = dict(self.dataloader_kwargs)
    dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

    dataset = self.modelpool.load_test_dataset(task)
    loader = DataLoader(dataset, **dataloader_kwargs)

    if self.fabric is not None:
        loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
on_test_time_adaptation_start()

Hook for work to be done before test-time adaptation starts, such as setting up task-specific heads.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
def on_test_time_adaptation_start(self):
    """
    Hook for work to be done before test-time adaptation starts, such as setting up task-specific heads.
    """
    pass
run(modelpool, **kwargs)

Run the Layer-Wise AdaMerging Algorithm.

This method constructs the wrapped model and performs test-time adaptation if necessary.

Parameters:

  • modelpool (ModelPool) –

    The model pool containing the pretrained and fine-tuned models.

Returns:

  • LayerWiseMergedModel

    The merged model after test-time adaptation.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
def run(self, modelpool: Seq2SeqLMPool, **kwargs):
    """
    Run the Layer-Wise AdaMerging Algorithm.

    This method constructs the wrapped model and performs test-time adaptation if necessary.

    Args:
        modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

    Returns:
        LayerWiseMergedModel: The merged model after test-time adaptation.
    """
    log.info("Fusing models using layer-wise adaptive merging.")
    self.modelpool = modelpool
    self.num_finetuned_models = len(modelpool.model_names)
    datasets = [{dataset} for dataset in modelpool.model_names]

    with self.profile("construct the wrapped model"):
        model_scheduler = ModelScheduler(self.configs, self.modelpool)

    if self.merging_weights_load_path is not None:
        # skip the test-time adaptation
        return module.merge_and_unload()
    else:
        for step_idx in tqdm(
            range(self.configs.gossip_max_steps),
            "Gossip merging",
            dynamic_ncols=True,
        ):
            datasets = self.update_datasets(datasets)
            log.info(f"Gossip merging step: {step_idx}")
            for model_id in tqdm(
                range(self.num_finetuned_models),
                "local admerging",
                dynamic_ncols=True,
            ):
                if self.configs.gossip_skip_adamerging == True:
                    # skip adamerging, only merge
                    with self.profile("construct the local wrapped model"):
                        module = model_scheduler(model_id)
                    log.info(
                        f"skip adamerging, only merge ({modelpool.model_names[model_id]})"
                    )
                    model_scheduler.store_model(module.merge_weights(), model_id)
                    self.free_gpu_memory(module)
                else:
                    with self.profile("construct the local wrapped model"):
                        module = model_scheduler(model_id)

                    if self.configs.improve_dataset == True:
                        log.info(
                            f"improved datasets, the datasets used in this local merging is {datasets[model_id]}"
                        )
                    else:
                        log.info(
                            f"unimproved datasets, the datasets used in this local merging is {modelpool.model_names}"
                        )
                    with self.profile("test-time adaptation"):
                        module = self.test_time_adaptation(
                            module, datasets[model_id]
                        )
                    # if self.configs.get("save_merging_weights", False):
                    #     self.save_merging_weights(
                    #         self.configs.save_merging_weights, module.merge_weight
                    #     )
                    model_scheduler.store_model(module.merge_weights(), model_id)
                    log.info(
                        get_memory_usage(
                            f"after local merging ({modelpool.model_names[model_id]}), the memory usage of GPU is:"
                        )
                    )
                    self.free_gpu_memory(
                        module
                    )  # simulate distributed GPU memory usage as much as possible

            model_scheduler.update_models()
            do_evaluation = False  # whether to do evaluation after each Gossip step
            if isinstance(self.configs.accuracy_test_interval, list):
                if (step_idx + 1) in self.configs.accuracy_test_interval:
                    do_evaluation = True
            elif isinstance(self.configs.accuracy_test_interval, int):
                if (
                    self.configs.accuracy_test_interval != 0
                    and (step_idx + 1) % self.configs.accuracy_test_interval == 0
                ):
                    do_evaluation = True
            if do_evaluation:
                self._program.evaluate_merged_model(
                    self._program.taskpool, model_scheduler.get_final_models()
                )
                model_scheduler.move_to("cpu")

    return model_scheduler.get_final_models()
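
The outer structure of run() follows a gossip pattern: each node performs a local AdaMerging round on its own neighborhood of datasets, and ModelScheduler.update_models() then propagates the locally merged weights between ring neighbors before the next round. The sketch below illustrates only the communication pattern with a plain ring average over parameters; it is a conceptual stand-in, not the ModelScheduler implementation.

import torch
import torch.nn as nn


def ring_gossip_step(models: list) -> None:
    """One synchronous gossip round: every node averages its parameters
    with its two ring neighbors (in place)."""
    n = len(models)
    snapshots = [{k: v.clone() for k, v in m.state_dict().items()} for m in models]
    for i, model in enumerate(models):
        left, right = snapshots[(i - 1) % n], snapshots[(i + 1) % n]
        averaged = {
            k: (snapshots[i][k] + left[k] + right[k]) / 3.0 for k in snapshots[i]
        }
        model.load_state_dict(averaged)


# Tiny usage example with three linear "nodes".
nodes = [nn.Linear(4, 2) for _ in range(3)]
ring_gossip_step(nodes)
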
save_merging_weights(file_path, merging_weights)

Save the merging weights to a file.

Parameters:

  • file_path (str) –

    The path to save the merging weights.

  • merging_weights (Tensor) –

    The merging weights to save.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
@rank_zero_only
def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
    """
    Save the merging weights to a file.

    Args:
        file_path (str): The path to save the merging weights.
        merging_weights (torch.Tensor): The merging weights to save.
    """
    if self.fabric.is_global_zero and self.merging_weights_save_path is not None:
        if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
            # if the file path is not absolute or relative to current working directory, save it in the log directory
            save_path = os.path.join(self.log_dir, file_path)
        else:
            save_path = file_path
        log.info(f"saving merging weights to {save_path}.")
        if os.path.dirname(save_path):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(merging_weights.detach().cpu(), save_path)
test_time_adaptation(module, datasets)

Perform test-time adaptation on the merged model.

This method adapts the merging weights during test-time to improve performance.

Parameters:

  • module (LayerWiseMergedModel) –

    The merged model.

Returns:

  • LayerWiseMergedModel

    The adapted merged model.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
def test_time_adaptation(self, module: LayerWiseMergedModel, datasets):
    """
    Perform test-time adaptation on the merged model.

    This method adapts the merging weights during test-time to improve performance.

    Args:
        module (LayerWiseMergedModel): The merged model.

    Returns:
        LayerWiseMergedModel: The adapted merged model.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    optimizer = instantiate(self._optimizer, [module.merge_weight])
    module, optimizer = self.fabric.setup(module, optimizer)

    module.train()
    module.merge_weights()
    for step_idx in (
        pbar := tqdm(
            range(self.max_steps if not self.is_debug_mode else 1),
            ("[DEBUG MODE] " if self.is_debug_mode else "")
            + "AdaMerging Test-time adaptation",
            dynamic_ncols=True,
        )
    ):
        if self.variant == "mgda":
            total_loss = self._compute_gradients_using_mgda(module)
        else:
            total_loss = 0
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    logits = logits.mean(dim=0, keepdim=True)
                    loss = entropy_loss(logits)
                    total_loss += loss
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

        with self.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()
        with self.profile("merging weights"):
            module.merge_weights()

        metrics = {
            "train/loss": total_loss.item(),
            "train/weight_max": module.merge_weight.max().item(),
            "train/weight_min": module.merge_weight.min().item(),
            "train/weight_mean": module.merge_weight.mean().item(),
        }
        self.fabric.log_dict(metrics, step=step_idx)
        pbar.set_postfix(metrics)

    self.print_profile_summary()
    del optimizer
    gc.collect()
    torch.cuda.empty_cache()
    return module
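
The training signal is unsupervised: entropy_loss measures the Shannon entropy of the model's predictive distribution on unlabeled test batches, and lowering it pushes the merging weights toward confident predictions. Below is a minimal stand-in, assuming entropy_loss reduces to the mean entropy of softmax(logits); the fusion_bench utility may differ in its exact reduction.

import torch


def entropy_of_softmax(logits: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy of softmax(logits) along the last dimension."""
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    return -(probs * log_probs).sum(dim=-1).mean()


# Example: a confident distribution has lower entropy than a uniform one.
confident = torch.tensor([[10.0, 0.0, 0.0]])
uniform = torch.tensor([[1.0, 1.0, 1.0]])
print(entropy_of_softmax(confident) < entropy_of_softmax(uniform))  # True
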
update_datasets(datasets)

For every epoch of local AdaMerging, only the datasets corresponding to the models involved in the fusion are used.

Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
def update_datasets(self, datasets):
    """
    For every epoch of local AdaMerging, only the datasets corresponding to the models involved in the fusion are used.
    """
    num_datasets = len(datasets)
    datasets_copy = datasets.copy()
    for i in range(num_datasets):
        datasets[i] = (
            datasets_copy[i]
            .union(datasets_copy[(i + 1) % num_datasets])
            .union(datasets_copy[(i - 1) % num_datasets])
        )
    return datasets
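
Concretely, each call grows node i's dataset set by the sets of its two ring neighbors, so after t Gossip steps a node has accumulated the datasets of every node within ring distance t. A small worked example with four tasks, reusing the method body above:

def update_datasets(datasets):
    num_datasets = len(datasets)
    datasets_copy = datasets.copy()
    for i in range(num_datasets):
        datasets[i] = (
            datasets_copy[i]
            .union(datasets_copy[(i + 1) % num_datasets])
            .union(datasets_copy[(i - 1) % num_datasets])
        )
    return datasets


datasets = [{"A"}, {"B"}, {"C"}, {"D"}]
datasets = update_datasets(datasets)
# Node 0 now holds {"A", "B", "D"}, node 1 holds {"A", "B", "C"}, and so on.
print(datasets)
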

Continual Model Merging

Orthogonal Projection-based Continual Merging (OPCM)

OPCMForCLIP

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/opcm/opcm.py
class OPCMForCLIP(
    BaseAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    def __init__(
        self,
        alpha: float,
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via SVD Projection.

        Args:
            alpha (float): the scaling factor for the SVD projection.
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.alpha = alpha
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        if self.seed is not None:
            L.seed_everything(self.seed)
        accelerator = self.fabric.device

        with self.profile("loading model"):
            pretrained_model = modelpool.load_pretrained_model()

        model_names = modelpool.model_names
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        self._test_datasets = deepcopy(self.taskpool._test_datasets)
        """Configuration for the test datasets"""

        # log the model names
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # get the average model
        with self.profile("loading model"):
            merged_model = modelpool.load_model(model_names[0])

        if self.evaluate_on_every_step:
            with self.profile("evaluating model"):
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {model_names[0]: self._test_datasets[model_names[0]]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                save_to_json(report, Path(self.log_dir) / "report_0.json")

        self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
        self.all_task_vector_norm = [self.avg_task_vector_norm]
        self.fabric.log("model/task_vector_norm", self.avg_task_vector_norm, step=0)
        self.fabric.log("model/avg_task_vector_norm", self.avg_task_vector_norm, step=0)
        self.fabric.log(
            "model/merged_task_vector_norm", self.avg_task_vector_norm, step=0
        )

        self.previous_lambda_t = 1
        self.lambda_t = None
        self.fabric.log("model/lambda_t", self.previous_lambda_t, step=0)
        self.fabric.log("empirical/lambda_t", 1, step=0)

        if self.save_on_every_step:
            self.save_merged_model(merged_model, 0)

        for model_idx, model_name in tqdm(
            enumerate(model_names[1:]), desc="Processing models"
        ):
            model_idx += 1
            with self.profile("loading model"):
                task_model = modelpool.load_model(model_name)

            with self.profile("merging model"):
                self.all_task_vector_norm.append(
                    get_task_vector_norm(task_model, pretrained_model)
                )
                self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
                self.fabric.log(
                    "model/task_vector_norm",
                    self.all_task_vector_norm[-1],
                    step=model_idx,
                )
                self.fabric.log(
                    "model/avg_task_vector_norm",
                    self.avg_task_vector_norm,
                    step=model_idx,
                )

                self.lambda_t = 1  # temporary value

                for module_name, module in tqdm(
                    list(merged_model.named_modules()),
                    desc=f"Processing {model_name}",
                    leave=False,
                ):
                    if not is_leaf_module(module):
                        continue

                    if isinstance(module, nn.Linear):
                        module.weight.data = self.merge_linear_weights(
                            module.weight,
                            pretrained_model.get_submodule(module_name).weight,
                            task_model.get_submodule(module_name).weight,
                            param_name=".".join([module_name, "weight"]),
                            alpha=self.alpha,
                            accelerator=accelerator,
                        )
                        if module.bias is not None:
                            module.bias.data = self.merge_other_parameters(
                                module.bias,
                                pretrained_model.get_submodule(module_name).bias,
                                task_model.get_submodule(module_name).bias,
                                param_name=".".join([module_name, "bias"]),
                                accelerator=accelerator,
                            )
                    else:
                        for param_name, param in module.named_parameters():
                            param.data = self.merge_other_parameters(
                                merged_W=param,
                                pretrained_W=pretrained_model.get_submodule(
                                    module_name
                                ).get_parameter(param_name),
                                task_W=task_model.get_submodule(
                                    module_name
                                ).get_parameter(param_name),
                                param_name=".".join([module_name, param_name]),
                                accelerator=accelerator,
                            )

                task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
                self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
                for param_name, param in merged_model.named_parameters():
                    param.data = pretrained_model.get_parameter(param_name) + (
                        param - pretrained_model.get_parameter(param_name)
                    ) * (self.avg_task_vector_norm / task_vector_norm)
                self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
                self.fabric.log(
                    "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
                )
                self.previous_lambda_t = self.lambda_t
                self.lambda_t = None

                self.fabric.log(
                    "model/merged_task_vector_norm",
                    get_task_vector_norm(merged_model, pretrained_model),
                    step=model_idx,
                )

            if self.save_on_every_step:
                with self.profile("saving model"):
                    self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                with self.profile("evaluating model"):
                    self.taskpool._is_setup = False
                    self.taskpool._test_datasets = DictConfig(
                        {
                            n: self._test_datasets[n]
                            for n in model_names[: model_idx + 1]
                        }
                    )
                    report = self.taskpool.evaluate(deepcopy(merged_model))
                    save_to_json(
                        report, Path(self.log_dir) / f"report_{model_idx}.json"
                    )

        self.print_profile_summary()
        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        merged_model.save_pretrained(
            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
        )

    def merge_linear_weights(
        self,
        merged_W: Tensor,
        pretrained_W: Tensor,
        task_W: Tensor,
        param_name: str,
        alpha: float,
        accelerator: str = "cpu",
    ):
        original_device = merged_W.device
        merged_W = merged_W.to(accelerator)
        pretrained_W = pretrained_W.to(accelerator)
        task_W = task_W.to(accelerator)

        previous_merged_tv = merged_W - pretrained_W
        task_tv = task_W - pretrained_W

        u, s, v = svd(previous_merged_tv)
        rank = s.size(0)
        split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()

        projected_task_tv = u.T @ task_tv @ v
        projected_task_tv.diagonal().fill_(0)

        projected_task_tv[:split_rank, :split_rank] = 0

        cleaned_task_tv = u @ projected_task_tv @ v.T

        previous_lambda_t = self.previous_lambda_t
        lambda_t = self.lambda_t
        new_merged_W = (
            pretrained_W
            + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t
        )
        return new_merged_W.to(original_device)

    def merge_other_parameters(
        self,
        merged_W: Tensor,
        pretrained_W: Tensor,
        task_W: Tensor,
        param_name: str,
        accelerator: str = "cpu",
    ):
        original_device = merged_W.device
        merged_W = merged_W.to(accelerator)
        pretrained_W = pretrained_W.to(accelerator)
        task_W = task_W.to(accelerator)

        previous_merged_tv = merged_W - pretrained_W
        task_tv = task_W - pretrained_W

        previous_lambda_t = self.previous_lambda_t
        lambda_t = self.lambda_t

        new_merged_W = (
            pretrained_W + (previous_lambda_t * previous_merged_tv + task_tv) / lambda_t
        )
        return new_merged_W.to(original_device)

    def compute_lambda_t(
        self, previous_merged_tv: Tensor, task_tv: Tensor, previous_lambda_t: float
    ):
        previous_merged_tv = torch.flatten(previous_merged_tv)
        task_tv = torch.flatten(task_tv)

        lambda_t = torch.linalg.vector_norm(
            previous_lambda_t * previous_merged_tv + task_tv
        ) / torch.linalg.vector_norm(previous_merged_tv)
        return lambda_t.item()
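
The heart of merge_linear_weights is an SVD projection: the incoming task vector is expressed in the singular basis of the previously merged task vector, and its diagonal plus its components along the top singular directions (those carrying roughly a fraction alpha of the singular-value mass) are zeroed before merging, so the new update avoids overwriting what has already been merged. A self-contained sketch on random matrices follows, assuming the svd helper returns (u, s, v) with previous_merged_tv = u @ diag(s) @ v.T; shapes and alpha are illustrative.

import torch

torch.manual_seed(0)
alpha = 0.5  # illustrative fraction of singular-value mass treated as occupied

previous_merged_tv = torch.randn(64, 64)   # task vector already merged in
task_tv = torch.randn(64, 64)              # task vector of the incoming model

# SVD of the previously merged task vector: previous_merged_tv = u @ diag(s) @ v.T
u, s, vh = torch.linalg.svd(previous_merged_tv, full_matrices=False)
v = vh.T

# First index at which the cumulative singular-value mass exceeds alpha.
split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()

# Express the new task vector in the old singular basis and zero the components
# that would interfere with the already-merged directions.
projected_task_tv = u.T @ task_tv @ v
projected_task_tv.diagonal().fill_(0)
projected_task_tv[:split_rank, :split_rank] = 0

cleaned_task_tv = u @ projected_task_tv @ v.T
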
__init__(alpha, shuffle_order=True, seed=None, save_on_every_step=True, evaluate_on_every_step=False, **kwargs)

Continual Model Merging via SVD Projection.

Parameters:

  • alpha (float) –

    the scaling factor for the SVD projection.

  • shuffle_order (bool, default: True ) –

    whether to shuffle the order of the models.

  • seed (Optional[int], default: None ) –

    the seed to use.

  • save_on_every_step (bool, default: True ) –

    whether to save the merged model on every step.

  • evaluate_on_every_step (bool, default: False ) –

    whether to evaluate the merged model on every step.

Source code in fusion_bench/method/opcm/opcm.py
def __init__(
    self,
    alpha: float,
    shuffle_order: bool = True,
    seed: Optional[int] = None,
    save_on_every_step: bool = True,
    evaluate_on_every_step: bool = False,
    **kwargs,
):
    """
    Continual Model Merging via SVD Projection.

    Args:
        alpha (float): the scaling factor for the SVD projection.
        shuffle_order (bool): whether to shuffle the order of the models.
        seed (Optional[int]): the seed to use.
        save_on_every_step (bool): whether to save the merged model on every step.
        evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
    """
    self.alpha = alpha
    self.shuffle_order = shuffle_order
    self.seed = seed
    self.save_on_every_step = save_on_every_step
    self.evaluate_on_every_step = evaluate_on_every_step
    super().__init__(**kwargs)
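
The rescaling factor lambda_t keeps the merged task vector's norm at the running average of the individual task-vector norms. When the task vectors are roughly orthogonal and of equal norm, lambda_t grows like sqrt(t), which matches the "empirical/lambda_t" reference curve (np.sqrt(model_idx + 1)) logged in run(). A small numeric check of that intuition, reusing the norm-ratio update from compute_lambda_t:

import torch

dim, num_tasks = 10_000, 8
# Equal-norm, approximately orthogonal task vectors (random high-dimensional vectors).
task_vectors = [torch.randn(dim) / dim**0.5 for _ in range(num_tasks)]

merged_tv = task_vectors[0]
previous_lambda_t = 1.0
for t, task_tv in enumerate(task_vectors[1:], start=2):
    # Same norm-ratio update as compute_lambda_t.
    lambda_t = torch.linalg.vector_norm(
        previous_lambda_t * merged_tv + task_tv
    ) / torch.linalg.vector_norm(merged_tv)
    merged_tv = (previous_lambda_t * merged_tv + task_tv) / lambda_t
    previous_lambda_t = lambda_t
    print(t, round(lambda_t.item(), 3), round(t**0.5, 3))  # lambda_t tracks sqrt(t)
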