
Weighted Averaging

Weighted averaging, also known as weight ensembling, merges fully fine-tuned models by taking a weighted average of their parameters. Concretely, given \(n\) models with parameters \(\theta_i\) and model-wise weights \(w_i\), the parameters of the merged model \(\theta\) are computed as:

\[ \theta = \sum_{i=1}^{n} w_i \theta_i \]
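
For intuition, here is a minimal, self-contained sketch of this computation using plain PyTorch state dicts (the helper name and toy tensors are illustrative, not part of FusionBench):

import torch

def weighted_average(state_dicts, weights):
    # Parameter-wise weighted sum over models sharing the same architecture.
    return {
        key: sum(w * sd[key] for w, sd in zip(weights, state_dicts))
        for key in state_dicts[0]
    }

# Toy example with two "models" consisting of a single 2x2 weight matrix.
theta_1 = {"linear.weight": torch.ones(2, 2)}
theta_2 = {"linear.weight": torch.zeros(2, 2)}
merged = weighted_average([theta_1, theta_2], weights=[0.5, 0.5])
print(merged["linear.weight"])  # every entry equals 0.5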

Examples

General Usage

Configuration template for the Weighted Averaging algorithm:

config/method/weighted_average.yaml
name: weighted_average
normalize: true # if true, the weights will be normalized before merging
weights: # List of weights for each model
  - 0.5
  - 0.5
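
When normalize is true, the weights are divided by their sum before merging, so they do not have to sum to 1 in the config. For example (hypothetical values, mirroring the normalization step in the source code shown below under References):

import numpy as np

weights = np.asarray([2.0, 1.0])
weights = weights / np.sum(weights)  # -> [0.6667, 0.3333]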

Use the following command to run the Weighted Averaging algorithm:

fusion_bench method=weighted_average ...

Merge CLIP-ViT Models

The following command merges eight CLIP-ViT models using a weighted average. Because method.normalize is set to true, the weights are normalized to sum to 1 (each 0.3 becomes 0.3/2.4 = 1/8), which makes the result equivalent to a simple average.

fusion_bench \
    method=weighted_average \
    method.normalize=true \
    method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
    modelpool=clip-vit-base-patch32_TA8_model_only \
    taskpool=clip-vit-classification_TA8

Merge Llama/Mistral Models

Here is an example of using the Weighted Averaging algorithm to merge two Llama models of type transformers.LlamaForCausalLM.

fusion_bench \
    method=weighted_average_for_llama \
    method.merged_model_save_path=outputs/test_merged_llama_model \
    modelpool=llama_for_causallm \
    taskpool=dummy
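
As with the CLIP example above, the model-wise weights can be overridden on the command line; the values below are hypothetical:

fusion_bench \
    method=weighted_average_for_llama \
    method.weights=[0.3,0.7] \
    method.merged_model_save_path=outputs/test_merged_llama_model \
    modelpool=llama_for_causallm \
    taskpool=dummy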

Alternatively, run with the configuration file config/llama_weighted_average.yaml:

fusion_bench --config-name llama_weighted_average
config/llama_weighted_average.yaml
defaults:
  - example_config
  - override method: weighted_average_for_llama
  - override modelpool: llama_for_causallm
  - _self_

modelpool:
  models:
    # the pre-trained model (base model) is optional
    # if not provided, the first model will be used as the base model
    - name: _pretrained_
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_1
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_2
      path: meta-llama/Meta-Llama-3-8B-Instruct

method:
  normalize: true # if true, the weights will be normalized before merging
  weights: # List of weights for each model
    - 0.5
    - 0.5
  # if true, only the backbone of the model is merged and the head is kept from the
  # pre-trained model (or from the first model if no pre-trained model is provided);
  # if false, the whole model is merged (illustrated in the sketch below this config)
  backbone_only: true

  merged_model_save_path: null
  save_tokenizer: true
  push_to_hub: false
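
Conceptually, backbone_only: true averages only the transformer backbone and keeps the head from the base model. The following rough sketch assumes expert_1, expert_2, and base_model are transformers.LlamaForCausalLM instances loaded elsewhere (in LlamaForCausalLM the backbone lives under .model and the head under .lm_head):

# Average only the backbone parameters of the two experts.
weights = [0.5, 0.5]
backbone_sds = [expert_1.model.state_dict(), expert_2.model.state_dict()]
merged = {
    key: sum(w * sd[key] for w, sd in zip(weights, backbone_sds))
    for key in backbone_sds[0]
}
base_model.model.load_state_dict(merged)  # lm_head keeps the base model's weights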

References

WeightedAverageAlgorithm

Bases: BaseAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/weighted_average/weighted_average.py
class WeightedAverageAlgorithm(BaseAlgorithm, SimpleProfilerMixin):

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        verbose: bool = True,
        **kwargs,
    ):
        self.normalize = normalize
        self.weights = weights
        self.verbose = verbose
        log.disabled = not self.verbose
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Fuses the models in the model pool using a weighted average approach.

        Parameters
            modelpool (ModelPool): The pool of models to be fused.

        Raises
            ValueError: If the number of weights does not match the number of models in the model pool.

        Returns
            forward_model (torch.nn.Module): The resulting model after fusion.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info("Fusing models using weighted average.")
        weights = np.asarray(self.weights)
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.normalize:
            weights = weights / np.sum(weights)
        if self.verbose:
            print(f"weights: {weights}, normalized: {self.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None

        for model_name, weight in zip(modelpool.model_names, weights):
            with self.profile("load_model"):
                model = modelpool.load_model(model_name)
            with self.profile("merge weights"):
                if sd is None:
                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                    forward_model = model
                else:
                    sd = state_dict_add(
                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                    )

        forward_model.load_state_dict(sd)
        if self.verbose:
            self.print_profile_summary()
        return forward_model
Attributes:

  • _config_mapping (class attribute) – BaseAlgorithm._config_mapping | {'normalize': 'normalize', 'weights': 'weights'}
  • normalize (instance attribute)
  • verbose (instance attribute)
  • weights (instance attribute)
__init__(normalize, weights, verbose=True, **kwargs)
Source code in fusion_bench/method/weighted_average/weighted_average.py
def __init__(
    self,
    normalize: bool,
    weights: List[float],
    verbose: bool = True,
    **kwargs,
):
    self.normalize = normalize
    self.weights = weights
    self.verbose = verbose
    log.disabled = not self.verbose
    super().__init__(**kwargs)
run(modelpool)

Fuses the models in the model pool using a weighted average approach.

Parameters:

  • modelpool (ModelPool) – The pool of models to be fused.

Raises:

  • ValueError – If the number of weights does not match the number of models in the model pool.

Returns:

  • forward_model (torch.nn.Module) – The resulting model after fusion.

Source code in fusion_bench/method/weighted_average/weighted_average.py
@override
@torch.no_grad()
def run(self, modelpool: BaseModelPool):
    """
    Fuses the models in the model pool using a weighted average approach.

    Parameters
        modelpool (ModelPool): The pool of models to be fused.

    Raises
        ValueError: If the number of weights does not match the number of models in the model pool.

    Returns
        forward_model (torch.nn.Module): The resulting model after fusion.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    log.info("Fusing models using weighted average.")
    weights = np.asarray(self.weights)
    if len(weights) != len(modelpool.model_names):
        raise ValueError(
            "Number of weights must match the number of models.,"
            f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
            f"weights: {weights}, models: {modelpool.model_names}"
        )
    if self.normalize:
        weights = weights / np.sum(weights)
    if self.verbose:
        print(f"weights: {weights}, normalized: {self.normalize}")

    sd: Optional[StateDictType] = None
    forward_model = None

    for model_name, weight in zip(modelpool.model_names, weights):
        with self.profile("load_model"):
            model = modelpool.load_model(model_name)
        with self.profile("merge weights"):
            if sd is None:
                sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                forward_model = model
            else:
                sd = state_dict_add(
                    sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                )

    forward_model.load_state_dict(sd)
    if self.verbose:
        self.print_profile_summary()
    return forward_model
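
For programmatic use, the algorithm can be instantiated directly and run on a model pool. A minimal sketch: the import path follows the source location above, and modelpool is assumed to be a BaseModelPool constructed elsewhere (e.g. from one of the YAML configs shown earlier):

from fusion_bench.method.weighted_average.weighted_average import WeightedAverageAlgorithm

algorithm = WeightedAverageAlgorithm(normalize=True, weights=[0.5, 0.5])
merged_model = algorithm.run(modelpool)  # modelpool: a BaseModelPool instance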

WeightedAverageForLLama

Bases: BaseAlgorithm

A class to perform weighted averaging of LlaMa/Mistral models.

Source code in fusion_bench/method/weighted_average/llama.py
class WeightedAverageForLLama(BaseAlgorithm):
    """
    A class to perform weighted averaging of LlaMa/Mistral models.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
        "backbone_only": "backbone_only",
        "merged_model_save_path": "merged_model_save_path",
        "save_tokenizer": "save_tokenizer",
        "push_to_hub": "push_to_hub",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        backbone_only: bool,
        merged_model_save_path: str,
        save_tokenizer: bool,
        push_to_hub: bool,
        **kwargs,
    ):
        """
        Initialize the WeightedAverageForLLama class with the given parameters.

        Args:
            normalize (bool): Whether to normalize the weights.
            weights (List[float]): The weights for averaging the models.
            backbone_only (bool): Whether to use only the backbone of the models.
            merged_model_save_path (str): The path to save the merged model.
            save_tokenizer (bool): Whether to save the tokenizer.
            push_to_hub (bool): Whether to push the model to the hub.
        """
        self.normalize = normalize
        self.weights = weights
        self.backbone_only = backbone_only
        self.merged_model_save_path = merged_model_save_path
        self.save_tokenizer = save_tokenizer
        self.push_to_hub = push_to_hub
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: CausalLMPool):
        """
        Executes the weighted averaging of models in the provided model pool.

        Args:
            modelpool (CausalLMPool): The pool of models to be averaged.

        Returns:
            base_model: The base model after merging the state dictionaries of the models in the pool.

        Raises:
            ValueError: If the number of weights does not match the number of models in the pool.
        """
        if modelpool.has_pretrained:
            base_model = modelpool.load_model("_pretrained_")
        else:
            base_model = modelpool.load_model(modelpool.model_names[0])

        weights = self.weights
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.normalize:
            weights = np.asarray(weights)
            weights = weights / np.sum(weights)

        merged_state_dict: StateDictType = None
        for model_name, weight in zip(modelpool.model_names, weights):
            model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
            sd = state_dict_mul(model.state_dict(), weight)
            if merged_state_dict is None:
                merged_state_dict = sd
            else:
                merged_state_dict = state_dict_add(merged_state_dict, sd)

        base_model.load_state_dict(
            merged_state_dict, strict=False if self.backbone_only else True
        )
        if self.merged_model_save_path is not None:
            with timeit_context(
                f"Saving the merged model to {self.merged_model_save_path}"
            ):
                modelpool.save_model(
                    base_model,
                    path=self.merged_model_save_path,
                    save_tokenizer=self.save_tokenizer,
                    push_to_hub=self.push_to_hub,
                )
        return base_model
__init__(normalize, weights, backbone_only, merged_model_save_path, save_tokenizer, push_to_hub, **kwargs)

Initialize the WeightedAverageForLLama class with the given parameters.

Parameters:

  • normalize (bool) – Whether to normalize the weights.
  • weights (List[float]) – The weights for averaging the models.
  • backbone_only (bool) – Whether to use only the backbone of the models.
  • merged_model_save_path (str) – The path to save the merged model.
  • save_tokenizer (bool) – Whether to save the tokenizer.
  • push_to_hub (bool) – Whether to push the model to the hub.

Source code in fusion_bench/method/weighted_average/llama.py
def __init__(
    self,
    normalize: bool,
    weights: List[float],
    backbone_only: bool,
    merged_model_save_path: str,
    save_tokenizer: bool,
    push_to_hub: bool,
    **kwargs,
):
    """
    Initialize the WeightedAverageForLLama class with the given parameters.

    Args:
        normalize (bool): Whether to normalize the weights.
        weights (List[float]): The weights for averaging the models.
        backbone_only (bool): Whether to use only the backbone of the models.
        merged_model_save_path (str): The path to save the merged model.
        save_tokenizer (bool): Whether to save the tokenizer.
        push_to_hub (bool): Whether to push the model to the hub.
    """
    self.normalize = normalize
    self.weights = weights
    self.backbone_only = backbone_only
    self.merged_model_save_path = merged_model_save_path
    self.save_tokenizer = save_tokenizer
    self.push_to_hub = push_to_hub
    super().__init__(**kwargs)
run(modelpool)

Executes the weighted averaging of models in the provided model pool.

Parameters:

  • modelpool (CausalLMPool) – The pool of models to be averaged.

Returns:

  • base_model – The base model after merging the state dictionaries of the models in the pool.

Raises:

  • ValueError – If the number of weights does not match the number of models in the pool.

Source code in fusion_bench/method/weighted_average/llama.py
@override
@torch.no_grad()
def run(self, modelpool: CausalLMPool):
    """
    Executes the weighted averaging of models in the provided model pool.

    Args:
        modelpool (CausalLMPool): The pool of models to be averaged.

    Returns:
        base_model: The base model after merging the state dictionaries of the models in the pool.

    Raises:
        ValueError: If the number of weights does not match the number of models in the pool.
    """
    if modelpool.has_pretrained:
        base_model = modelpool.load_model("_pretrained_")
    else:
        base_model = modelpool.load_model(modelpool.model_names[0])

    weights = self.weights
    if len(weights) != len(modelpool.model_names):
        raise ValueError(
            "Number of weights must match the number of models.,"
            f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
            f"weights: {weights}, models: {modelpool.model_names}"
        )
    if self.normalize:
        weights = np.asarray(weights)
        weights = weights / np.sum(weights)

    merged_state_dict: StateDictType = None
    for model_name, weight in zip(modelpool.model_names, weights):
        model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
        sd = state_dict_mul(model.state_dict(), weight)
        if merged_state_dict is None:
            merged_state_dict = sd
        else:
            merged_state_dict = state_dict_add(merged_state_dict, sd)

    base_model.load_state_dict(
        merged_state_dict, strict=False if self.backbone_only else True
    )
    if self.merged_model_save_path is not None:
        with timeit_context(
            f"Saving the merged model to {self.merged_model_save_path}"
        ):
            modelpool.save_model(
                base_model,
                path=self.merged_model_save_path,
                save_tokenizer=self.save_tokenizer,
                push_to_hub=self.push_to_hub,
            )
    return base_model
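
Likewise, WeightedAverageForLLama can be used programmatically. A minimal sketch: the import path follows the source location above, and modelpool is assumed to be a CausalLMPool configured elsewhere (e.g. as in config/llama_weighted_average.yaml):

from fusion_bench.method.weighted_average.llama import WeightedAverageForLLama

algorithm = WeightedAverageForLLama(
    normalize=True,
    weights=[0.5, 0.5],
    backbone_only=True,
    merged_model_save_path="outputs/test_merged_llama_model",
    save_tokenizer=True,
    push_to_hub=False,
)
merged_model = algorithm.run(modelpool)  # modelpool: a CausalLMPool instance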