
Weighted Averaging

Weighted averaging, also known as weight ensembling, merges models by taking a weighted average of their parameters. In the context of fully fine-tuned models, each model's parameters are scaled by a per-model coefficient and summed. Concretely, if we have \(n\) models with parameters \(\theta_i\) and per-model weights \(w_i\), the parameters of the final model \(\theta\) are computed as:

\[ \theta = \sum_{i=1}^{n} w_i \theta_i \]
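For example, merging two models with weights \(w_1 = 0.7\) and \(w_2 = 0.3\) amounts to a key-wise weighted sum over their state dictionaries. A minimal sketch in plain PyTorch (the toy tensors below are illustrative, not FusionBench's own helpers):

import torch

# Toy state dicts standing in for two fine-tuned models with identical architectures
theta_1 = {"linear.weight": torch.ones(2, 2)}
theta_2 = {"linear.weight": torch.full((2, 2), 3.0)}
w = [0.7, 0.3]

# theta = w_1 * theta_1 + w_2 * theta_2, computed key by key
merged = {key: w[0] * theta_1[key] + w[1] * theta_2[key] for key in theta_1}
print(merged["linear.weight"])  # every entry is 0.7 * 1.0 + 0.3 * 3.0 = 1.6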

Examples

General Usage

Configuration template for the Weighted Averaging algorithm:

config/method/weighted_average.yaml
name: weighted_average
normalize: true # if true, the weights will be normalized before merging
weights: # List of weights for each model
  - 0.5
  - 0.5

Use the following command to run the Weighted Averaging algorithm:

fusion_bench method=weighted_average ...
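Both options in the template can be overridden from the command line using Hydra-style dot notation. For example, to merge two models with illustrative weights 0.7 and 0.3 (the remaining modelpool/taskpool options are elided):

fusion_bench \
    method=weighted_average \
    method.normalize=true \
    method.weights=[0.7,0.3] \
    ...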

Merge CLIP-ViT Models

The following command merges eight CLIP-ViT models using weighted averaging. Because method.normalize is set to true, the weights are normalized to sum to 1, so the result is equivalent to a simple average.

fusion_bench \
    method=weighted_average \
    method.normalize=true \
    method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
    modelpool=clip-vit-base-patch32_TA8_model_only \
    taskpool=clip-vit-classification_TA8
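The effect of normalization here is easy to verify: eight identical weights of 0.3 sum to 2.4, and dividing each by that sum gives 0.125 = 1/8 per model, i.e. a uniform average. A quick check with NumPy:

import numpy as np

weights = np.asarray([0.3] * 8)
normalized = weights / np.sum(weights)
print(normalized)  # [0.125 0.125 0.125 0.125 0.125 0.125 0.125 0.125]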

Merge Llama/Mistral Models

Here is an example of using the Weighted Averaging algorithm to merge two Llama models, i.e., models of type transformers.LlamaForCausalLM.

fusion_bench \
    method=weighted_average_for_llama \
    method.merged_model_save_path=outputs/test_merged_llama_model \
    modelpool=llama_for_causallm \
    taskpool=dummy

Alternatively, use the configuration file config/llama_weighted_average.yaml:

fusion_bench --config-name llama_weighted_average
config/llama_weighted_average.yaml
defaults:
  - example_config
  - override method: weighted_average_for_llama
  - override modelpool: llama_for_causallm
  - _self_

modelpool:
  models:
    # the pre-trained model (base model) is optional
    # if not provided, the first model will be used as the base model
    - name: _pretrained_
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_1
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_2
      path: meta-llama/Meta-Llama-3-8B-Instruct

method:
  normalize: true # if true, the weights will be normalized before merging
  weights: # List of weights for each model
    - 0.5
    - 0.5
  # If true, only the backbone of the model is merged and the head is kept from
  # the pre-trained model (or from the first model, if no pre-trained model is provided).
  # If false, the whole model is merged.
  backbone_only: true

  merged_model_save_path: null
  save_tokenizer: true
  push_to_hub: false
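Once the merge finishes, the model saved under merged_model_save_path is a standard Hugging Face checkpoint and can be loaded back with the transformers API. A short sketch, assuming the CLI example above was run with save_tokenizer: true (the path is the one passed via method.merged_model_save_path):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged checkpoint produced by the CLI example above
model = AutoModelForCausalLM.from_pretrained("outputs/test_merged_llama_model")
tokenizer = AutoTokenizer.from_pretrained("outputs/test_merged_llama_model")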

References

WeightedAverageAlgorithm

Bases: ModelFusionAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/weighted_average/weighted_average.py
class WeightedAverageAlgorithm(ModelFusionAlgorithm, SimpleProfilerMixin):
    @override
    @torch.no_grad()
    def run(self, modelpool: ModelPool):
        """
        Fuses the models in the model pool using a weighted average approach.

        Parameters
            modelpool (ModelPool): The pool of models to be fused.

        Raises
            ValueError: If the number of weights does not match the number of models in the model pool.

        Returns
            forward_model (torch.nn.Module): The resulting model after fusion.
        """
        modelpool = to_modelpool(modelpool)
        log.info("Fusing models using weighted average.")
        weights = np.asarray(self.config.weights)
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.config.normalize:
            weights = weights / np.sum(weights)
        print(f"weights: {weights}, normalized: {self.config.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None

        for model_name, weight in zip(modelpool.model_names, weights):
            with self.profile("load_model"):
                model = modelpool.load_model(model_name)
            with self.profile("merge weights"):
                if sd is None:
                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                    forward_model = model
                else:
                    sd = state_dict_add(
                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                    )

        forward_model.load_state_dict(sd)
        self.print_profile_summary()
        return forward_model
run(modelpool)

Fuses the models in the model pool using a weighted average approach.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be fused.

Raises:

  • ValueError

    If the number of weights does not match the number of models in the model pool.

Returns:

  • forward_model (torch.nn.Module) –

    The resulting model after fusion.
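The algorithm can also be invoked programmatically. Below is a minimal sketch, assuming the constructor accepts a DictConfig with the same fields as config/method/weighted_average.yaml and that model_pool is an already-constructed ModelPool; the exact constructor signature may differ in your FusionBench version:

from omegaconf import DictConfig

from fusion_bench.method.weighted_average.weighted_average import WeightedAverageAlgorithm

# Config mirrors config/method/weighted_average.yaml (assumed constructor signature)
algo = WeightedAverageAlgorithm(DictConfig({"normalize": True, "weights": [0.5, 0.5]}))

# model_pool: a ModelPool with two models, constructed elsewhere
merged_model = algo.run(model_pool)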


WeightedAverageForLLama

Bases: ModelFusionAlgorithm

A class to perform weighted averaging of models in a LLamaForCausalLMPool.

Attributes:

  • config (DictConfig) –

    Configuration parameters for the weighted averaging process.

Methods:

  • run –

    Executes the weighted averaging of models in the provided model pool.

Source code in fusion_bench/method/weighted_average/llama.py
class WeightedAverageForLLama(ModelFusionAlgorithm):
    """
    A class to perform weighted averaging of models in a LLamaForCausalLMPool.

    Attributes:
        config (DictConfig): Configuration parameters for the weighted averaging process.

    Methods:
        run(modelpool: LLamaForCausalLMPool):
            Executes the weighted averaging of models in the provided model pool.
    """

    @torch.no_grad()
    @override
    def run(self, modelpool: LLamaForCausalLMPool):
        """
        Executes the weighted averaging of models in the provided model pool.

        Args:
            modelpool (LLamaForCausalLMPool): The pool of models to be averaged.

        Returns:
            base_model: The base model after merging the state dictionaries of the models in the pool.

        Raises:
            ValueError: If the number of weights does not match the number of models in the pool.
        """
        config = self.config
        if modelpool.has_pretrained:
            base_model = modelpool.load_model("_pretrained_")
        else:
            base_model = modelpool.load_model(modelpool.model_names[0])

        weights = config.weights
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.config.normalize:
            weights = np.asarray(weights)
            weights = weights / np.sum(weights)

        merged_state_dict = None
        for model_name, weight in zip(modelpool.model_names, weights):
            model = modelpool.load_model(model_name, backbone_only=config.backbone_only)
            sd = state_dict_mul(model.state_dict(), weight)
            if merged_state_dict is None:
                merged_state_dict = sd
            else:
                merged_state_dict = state_dict_add(merged_state_dict, sd)

        base_model.load_state_dict(merged_state_dict, strict=not config.backbone_only)
        if config.merged_model_save_path is not None:
            with timeit_context(
                f"Saving the merged model to {config.merged_model_save_path}"
            ):
                modelpool.save_model(
                    base_model,
                    path=config.merged_model_save_path,
                    save_tokenizer=config.save_tokenizer,
                    push_to_hub=config.push_to_hub,
                )
        return base_model
run(modelpool)

Executes the weighted averaging of models in the provided model pool.

Parameters:

  • modelpool (LLamaForCausalLMPool) –

    The pool of models to be averaged.

Returns:

  • base_model

    The base model after merging the state dictionaries of the models in the pool.

Raises:

  • ValueError

    If the number of weights does not match the number of models in the pool.
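A note on the strict flag in load_state_dict above: with backbone_only enabled, merged_state_dict contains no head parameters, and strict=False lets the base model keep its own head weights for the missing keys. A standalone demonstration of that PyTorch behavior on a toy module (not FusionBench code):

import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for a model whose "bias" plays the role of a head
partial_sd = {"weight": model.weight.data * 0.5}  # no "bias" key, like a backbone-only merge

# strict=False ignores the missing "bias" key and leaves the existing bias untouched
missing, unexpected = model.load_state_dict(partial_sd, strict=False)
print(missing)  # ['bias']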
