Weighted Averaging

Weighted averaging, also known as weight-ensembling, combines multiple models by averaging their parameters according to specified weights. This approach allows for non-uniform combination of models, where better-performing or more reliable models can be given higher weights in the final merged model.

In the context of model fusion, if we have \(n\) models with their respective parameters \(\theta_i\) and model-wise weights \(w_i\), the parameters of the final merged model \(\theta\) are computed as:

\[ \theta = \sum_{i=1}^{n} w_i \theta_i \]

where the weights \(w_i\) can optionally be normalized to sum to 1.
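As a concrete illustration, the sketch below applies this formula to plain PyTorch state dicts. The weighted_average helper is hypothetical and not part of FusionBench; the library's own implementations are used through the algorithms described in the examples that follow.

import torch

def weighted_average(state_dicts, weights, normalize=True):
    """Parameter-wise weighted average of several state dicts (illustrative helper)."""
    if normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    return {
        key: sum(w * sd[key] for w, sd in zip(weights, state_dicts))
        for key in state_dicts[0]
    }

# Toy example: merge two linear layers with weights 0.3 and 0.7.
models = [torch.nn.Linear(4, 2) for _ in range(2)]
merged = torch.nn.Linear(4, 2)
merged.load_state_dict(weighted_average([m.state_dict() for m in models], [0.3, 0.7]))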

Examples

CLI Usage

General PyTorch Models

The WeightedAverageAlgorithm works with general PyTorch models and performs weighted averaging of all model parameters.

Configuration template for the standard Weighted Averaging algorithm:

config/method/linear/weighted_average.yaml
_target_: fusion_bench.method.WeightedAverageAlgorithm
normalize: true # if true, the weights will be normalized before merging
weights: # List of weights for each model
  - 0.5
  - 0.5

Use the following command to run the Weighted Averaging algorithm:

fusion_bench method=linear/weighted_average ...

The following command merges eight CLIP-ViT models using a weighted average approach:

# Note: Since `method.normalize=true`, the weights are normalized to sum to 1, making this example equivalent to simple averaging.
fusion_bench \
    method=linear/weighted_average \
    method.normalize=true \
    method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8

Large Language Models

The WeightedAverageForLLama algorithm is specialized for large language models and adds several features:

  • Backbone-only merging option
  • Model saving capabilities
  • Hub integration support

Configuration template for LLaMA/Mistral model weighted averaging:

config/method/linear/weighted_average_for_llama.yaml
_target_: WeightedAverageForLLama
normalize: true # if true, the weights will be normalized before merging
weights: # List of weights for each model
  - 0.5
  - 0.5
# If true, only the backbone of the model is merged and the head is kept from the pre-trained model
# (if a pre-trained model is provided; otherwise the head of the first model is used).
# If false, the whole model is merged.
backbone_only: true
merged_model_save_path: null
save_tokenizer: true
push_to_hub: false
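
To make the backbone_only option concrete, the sketch below shows the general idea in plain PyTorch: backbone parameters are weight-averaged across the fine-tuned models, while head parameters (assumed here to live under a hypothetical lm_head. prefix) are copied from the pre-trained model. This is an illustration of the concept, not FusionBench's actual implementation.

import torch

def merge_backbone_only(pretrained_sd, finetuned_sds, weights, head_prefix="lm_head."):
    """Illustrative sketch: weight-average the backbone, keep the head from the pre-trained model."""
    total = sum(weights)
    weights = [w / total for w in weights]
    merged = {}
    for key in pretrained_sd:
        if key.startswith(head_prefix):
            # Head parameters are taken from the pre-trained model as-is.
            merged[key] = pretrained_sd[key].clone()
        else:
            # Backbone parameters are weight-averaged across the fine-tuned models.
            merged[key] = sum(w * sd[key] for w, sd in zip(weights, finetuned_sds))
    return merged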

Use the following command:

fusion_bench method=linear/weighted_average_for_llama ...

Example of merging LLaMA models with different weights:

fusion_bench \
    method=linear/weighted_average_for_llama \
    method.weights=[0.3,0.7] \
    method.normalize=true \
    method.backbone_only=true \
    method.merged_model_save_path=outputs/merged_llama_model \
    modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml \
    taskpool=dummy

API Usage

General PyTorch Models

from fusion_bench.method.weighted_average import WeightedAverageAlgorithm

# Create the algorithm with custom weights
algorithm = WeightedAverageAlgorithm(
    normalize=True,  # Normalize weights to sum to 1
    weights=[0.3, 0.5, 0.2],  # Custom weights for 3 models
    verbose=True
)

# Run on a model pool
merged_model = algorithm.run(modelpool)
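
Because the YAML configs use Hydra's _target_ convention, the same algorithm object can also be built from a config file rather than constructed directly. A minimal sketch, assuming the config path from the CLI section is available relative to your working directory:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Build the algorithm from the YAML config shown in the CLI section.
cfg = OmegaConf.load("config/method/linear/weighted_average.yaml")
algorithm = instantiate(cfg)

# Run on a model pool constructed as in the example above.
merged_model = algorithm.run(modelpool)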

Large Language Models

from fusion_bench.method import WeightedAverageForLLama

# Create the algorithm for LLaMA models
algorithm = WeightedAverageForLLama(
    normalize=True,
    weights=[0.4, 0.6],
    backbone_only=True,  # Only merge backbone, keep heads
    merged_model_save_path="./merged_model",
    save_tokenizer=True,
    push_to_hub=False
)

# Run on a CausalLMPool
merged_model = algorithm.run(causal_lm_pool)
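
If merged_model_save_path is set, the merged model (and, with save_tokenizer=True, its tokenizer) is written to that path. The save_tokenizer and push_to_hub options suggest a Hugging Face-compatible layout; assuming that is the case, the saved checkpoint can presumably be reloaded like any other:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the merged checkpoint from the path passed as `merged_model_save_path`.
model = AutoModelForCausalLM.from_pretrained("./merged_model")
tokenizer = AutoTokenizer.from_pretrained("./merged_model")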

Implementation Details