
Ties Merging

Ties-Merging. Credit to [1]

Ties-Merging [1] is a structured approach to consolidating multiple task-specific models into a single, efficient multi-task model. It merges task vectors through a sequence of deliberate steps, so that the final model integrates the strengths of each individual task-specific model while resolving interference between them.

The Ties-Merging algorithm operates through three primary steps:

  1. Trim: Each task vector is trimmed by resetting its smallest-magnitude entries to zero, keeping only the parameters that changed most during fine-tuning for that task.
  2. Elect Sign of Parameters: For each parameter, the algorithm elects the sign (positive or negative) with the larger total magnitude across tasks, resolving sign conflicts between task vectors.
  3. Disjoint Merge: Finally, for each parameter, only the task values whose sign agrees with the elected sign are averaged, producing a single merged task vector, denoted as \(\tau\).
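The three steps above can be sketched on flattened task vectors as follows. This is a minimal NumPy illustration of the idea, not FusionBench's `ties_merging` implementation; the function name and the parameter `k` (the fraction of entries kept when trimming) are assumptions made for this sketch:

```python
import numpy as np

def ties_merge(task_vectors, k=0.2):
    """Sketch of TIES-Merging over flattened task vectors."""
    tvs = np.stack(task_vectors).astype(float)

    # 1. Trim: keep only the top-k fraction of entries (by magnitude) per task.
    n_keep = max(1, int(k * tvs.shape[1]))
    for row in tvs:  # rows are views, so this zeroes entries in place
        cutoff = np.sort(np.abs(row))[-n_keep]
        row[np.abs(row) < cutoff] = 0.0

    # 2. Elect sign: per parameter, the sign with the larger total magnitude wins.
    elected = np.sign(tvs.sum(axis=0))

    # 3. Disjoint merge: average only the entries that agree with the elected sign.
    agrees = np.sign(tvs) == elected
    counts = np.maximum(agrees.sum(axis=0), 1)
    return (tvs * agrees).sum(axis=0) / counts
```

For two task vectors that disagree in sign on a coordinate, the disagreeing coordinate is dropped rather than averaged toward zero, which is the key difference from plain task-vector addition.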

Given the merged task vector \(\tau\), the final model is obtained in the same way as in task arithmetic:

\[ \theta = \theta_0 + \lambda \tau \]

where \(\lambda\) is a hyperparameter chosen based on the validation set to ensure the best-performing model.
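In code, this update is a parameter-wise linear combination. A minimal sketch over flat parameter dictionaries (FusionBench performs the same operation on flattened tensors and converts back with `vector_to_state_dict`; the function name here is hypothetical):

```python
def apply_merged_task_vector(theta0, tau, lam):
    """theta = theta_0 + lambda * tau, applied parameter-wise."""
    return {name: theta0[name] + lam * tau[name] for name in theta0}
```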

By following these structured steps, Ties-Merging effectively integrates multiple task-specific models into a unified multi-task model, balancing the contributions of each task to enhance overall performance. The process ensures that the final model retains the benefits of the pre-trained model while optimally incorporating the diverse knowledge contained within the individual task-specific models.

Hyperparameter Tuning

Task Arithmetic and Ties-Merging. Average performance of models merged using Task Arithmetic and Ties-Merging, with varying scaling coefficients. The subfigures represent different models: CLIP-ViT-B/32, CLIP-ViT-L/14, Flan-T5-base (LoRA fine-tuned), and Flan-T5-large (LoRA fine-tuned).

The figure above shows the average performance of Task Arithmetic and Ties-Merging merged models as the scaling coefficient varies. Subfigures (a), (b), (c), and (d) show the results for CLIP-ViT-B/32, CLIP-ViT-L/14, Flan-T5-base (LoRA fine-tuned), and Flan-T5-large (LoRA fine-tuned), respectively. The merged multi-task model peaks in average performance across tasks when the scaling coefficient is set to around 0.3, the value we selected empirically in our experiments. As the scaling coefficient grows beyond this point, average performance declines, eventually falling below the pre-trained model's original performance. An overly large scaling coefficient thus erodes the knowledge the pre-trained model initially possessed, underscoring the importance of calibrating \(\lambda\) so as not to diminish the model's existing strengths.
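Choosing \(\lambda\) on the validation set amounts to a small grid search over candidate coefficients. A minimal sketch, assuming a caller-supplied `evaluate` callback that scores a merged parameter dict on validation data (the function and callback names are hypothetical):

```python
def pick_scaling_coefficient(theta0, tau, evaluate,
                             candidates=(0.1, 0.3, 0.5, 0.7, 1.0)):
    """Grid-search lambda; `evaluate` returns a validation score (higher is better)."""
    best_lam, best_score = None, float("-inf")
    for lam in candidates:
        # theta = theta_0 + lambda * tau, applied parameter-wise
        merged = {k: theta0[k] + lam * tau[k] for k in theta0}
        score = evaluate(merged)
        if score > best_score:
            best_lam, best_score = lam, score
    return best_lam, best_score
```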

Code Integration

Configuration template for the Ties-Merging algorithm:

config/method/ties_merging.yaml
name: ties_merging
# Scaling factor $\lambda$
scaling_factor: 0.5
# Trim threshold, passed to `ties_merging` as `reset_thresh`
threshold: 0.5
# List of keys to remove from the state dict, default is empty
remove_keys: []
# Function to merge the models, default is sum. Options are 'sum', 'mean', and 'max'
merge_func: sum 

Use the following command to run the Ties-Merging algorithm:

fusion_bench method=ties_merging ...

Reference

TiesMergingAlgorithm

Bases: BaseAlgorithm

TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.

Attributes:

  • scaling_factor (float) –

    The scaling factor to apply to the merged task vector.

  • threshold (float) –

    The threshold for resetting values in the task vector.

  • remove_keys (List[str]) –

    List of keys to remove from the state dictionary.

  • merge_func (Literal['sum', 'mean', 'max']) –

    The merge function to use for disjoint merging.

Source code in fusion_bench/method/ties_merging/ties_merging.py
class TiesMergingAlgorithm(BaseAlgorithm):
    """
    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.

    Attributes:
        scaling_factor (float): The scaling factor to apply to the merged task vector.
        threshold (float): The threshold for resetting values in the task vector.
        remove_keys (List[str]): List of keys to remove from the state dictionary.
        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "scaling_factor": "scaling_factor",
        "threshold": "threshold",
        "remove_keys": "remove_keys",
        "merge_func": "merge_func",
    }

    def __init__(
        self,
        scaling_factor: float,
        threshold: float,
        remove_keys: List[str],
        merge_func: Literal["sum", "mean", "max"],
        **kwargs,
    ):
        """
        Initialize the TiesMergingAlgorithm with the given parameters.

        Args:
            scaling_factor (float): The scaling factor to apply to the merged task vector.
            threshold (float): The threshold for resetting values in the task vector.
            remove_keys (List[str]): List of keys to remove from the state dictionary.
            merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
            **kwargs: Additional keyword arguments for the base class.
        """
        self.scaling_factor = scaling_factor
        self.threshold = threshold
        self.remove_keys = remove_keys
        self.merge_func = merge_func
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
        """
        Run the TIES merging algorithm to fuse models in the model pool.

        Args:
            modelpool (BaseModelPool | Dict[str, nn.Module]): The model pool containing the models to fuse.

        Returns:
            nn.Module: The fused model.
        """
        log.info("Fusing models using ties merging.")
        modelpool = to_modelpool(modelpool)
        remove_keys = self.config.get("remove_keys", [])
        merge_func = self.config.get("merge_func", "sum")
        scaling_factor = self.scaling_factor
        threshold = self.threshold

        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # Load the state dicts of the models
        ft_checks: List[StateDictType] = [
            modelpool.load_model(model_name).state_dict(keep_vars=True)
            for model_name in modelpool.model_names
        ]
        ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)

        # Compute the task vectors
        flat_ft: Tensor = torch.vstack(
            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
        )
        flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
        tv_flat_checks = flat_ft - flat_ptm

        # Perform TIES Merging
        merged_tv = ties_merging(
            tv_flat_checks,
            reset_thresh=threshold,
            merge_func=merge_func,
        )
        merged_check = flat_ptm + scaling_factor * merged_tv
        merged_state_dict = vector_to_state_dict(
            merged_check, ptm_check, remove_keys=remove_keys
        )

        # Load the merged state dict into the pretrained model
        pretrained_model.load_state_dict(merged_state_dict)
        return pretrained_model
__init__(scaling_factor, threshold, remove_keys, merge_func, **kwargs)

Initialize the TiesMergingAlgorithm with the given parameters.

Parameters:

  • scaling_factor
    (float) –

    The scaling factor to apply to the merged task vector.

  • threshold
    (float) –

    The threshold for resetting values in the task vector.

  • remove_keys
    (List[str]) –

    List of keys to remove from the state dictionary.

  • merge_func
    (Literal['sum', 'mean', 'max']) –

    The merge function to use for disjoint merging.

  • **kwargs

    Additional keyword arguments for the base class.

run(modelpool, **kwargs)

Run the TIES merging algorithm to fuse models in the model pool.

Parameters:

  • modelpool
    (BaseModelPool | Dict[str, Module]) –

    The model pool containing the models to fuse.

Returns:

  • nn.Module: The fused model.


  1. (NeurIPS 2023) TIES-Merging: Resolving Interference When Merging Models. http://arxiv.org/abs/2306.01708