Linear Model Merging Methods¶
Linear model merging encompasses a family of methods that combine model parameters through linear operations -- interpolation, extrapolation, and weighted averaging. These methods form the foundation of model fusion, with more advanced techniques often building upon them.
Overview¶
Linear merging methods operate by computing a linear combination of model parameters. Given models with parameters \(\theta_1, \theta_2, \dots, \theta_K\), the merged model is:

$$\theta_{merged} = \sum_{i=1}^{K} w_i \theta_i$$

where \(\sum_i w_i = 1\) and the \(w_i\) are the merging weights. The specific choice of weights and the relationship between the models distinguish the variants.
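As a point of reference, this weighted sum can be written directly over PyTorch state dictionaries. The sketch below is illustrative only; the helper name `weighted_merge` is hypothetical and this is not FusionBench's internal code:

```python
import copy
from typing import List

from torch import nn


def weighted_merge(models: List[nn.Module], weights: List[float]) -> nn.Module:
    """Return a new model whose parameters are sum_i w_i * theta_i."""
    assert len(models) == len(weights), "one weight per model"
    state_dicts = [m.state_dict() for m in models]
    merged_state = {
        name: sum(w * sd[name] for sd, w in zip(state_dicts, weights))
        for name in state_dicts[0]
    }
    merged = copy.deepcopy(models[0])  # reuse the first model's architecture
    merged.load_state_dict(merged_state)
    return merged
```

With \(w_i = 1/K\) this reduces to simple averaging; with weights \((1 - t, t)\) over two models it reduces to linear interpolation.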
FusionBench implements several linear merging methods:
- Linear Interpolation: Interpolates between two models with a parameter \(t \in [0, 1]\).
- ExPO (Extrapolation): Extrapolates from a pretrained model through a fine-tuned model.
- ExPO for LLaMA: A LLaMA-specific variant of ExPO with layer-wise control.
- Simple Average for Causal LM: Uniform averaging with optional backbone-only merging.
Linear Interpolation¶
The simplest linear merge method interpolates between two models:

$$\theta_{merged} = (1 - t)\,\theta_1 + t\,\theta_2$$

where \(t \in [0, 1]\) controls the interpolation: \(t = 0\) yields \(\theta_1\), \(t = 1\) yields \(\theta_2\), and \(t = 0.5\) gives the simple average.
from fusion_bench.method import LinearInterpolationAlgorithm
algorithm = LinearInterpolationAlgorithm(t=0.5)
merged_model = algorithm.run(modelpool) # modelpool must have exactly 2 models
Configuration¶
# =============================================================================
# FusionBench Method Configuration: Linear Interpolation
# =============================================================================
# Interpolates between two models: (1 - t) * model0 + t * model1
#
# - t in [0,1]: 0 returns model0; 1 returns model1.
# - Only meaningful for two-model pools.
# =============================================================================
_target_: fusion_bench.method.LinearInterpolationAlgorithm
t: 0.5
CLI Usage¶
fusion_bench method=linear/linear_interpolation \
method.t=0.5 \
modelpool=CausalLMPool/two_models \
taskpool=...
ExPO (Extrapolation)¶
ExPO (Extrapolation from Pretrained to Optimized) extends the idea of linear interpolation to extrapolation. Given a pretrained (SFT) model \(\theta_{pre}\) and a fine-tuned (RLHF) model \(\theta_{ft}\), ExPO computes:

$$\theta_{merged} = \theta_{ft} + \alpha (\theta_{ft} - \theta_{pre})$$

where \(\alpha\) is the extrapolation factor. When \(\alpha > 0\), the merged model lies on the ray from \(\theta_{pre}\) through \(\theta_{ft}\), beyond \(\theta_{ft}\). This can amplify the alignment improvements introduced by fine-tuning.[^1]
General ExPO¶
For general nn.Module models, the ExPOAlgorithm class handles any model architecture:
from fusion_bench.method import ExPOAlgorithm
algorithm = ExPOAlgorithm(extrapolation_factor=0.1)
merged_model = algorithm.run(modelpool)
When multiple RLHF models are provided, ExPO first averages them via SimpleAverageAlgorithm, then extrapolates from the pretrained model through the averaged RLHF model.
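Conceptually, the multi-model case is an average followed by an extrapolation. The snippet below is a hand-rolled sketch of that two-step procedure over state dictionaries; the function name `average_then_extrapolate` is hypothetical, not part of the library's API:

```python
import copy
from typing import List

from torch import nn


def average_then_extrapolate(
    pretrained: nn.Module, rlhf_models: List[nn.Module], alpha: float
) -> nn.Module:
    """Average several RLHF models, then extrapolate away from the pretrained model."""
    pre_state = pretrained.state_dict()
    rlhf_states = [m.state_dict() for m in rlhf_models]

    merged_state = {}
    for name, pre_param in pre_state.items():
        # step 1: uniform average of the RLHF models
        avg = sum(sd[name] for sd in rlhf_states) / len(rlhf_states)
        # step 2: extrapolate, theta_avg + alpha * (theta_avg - theta_pre)
        merged_state[name] = avg + alpha * (avg - pre_param)

    merged = copy.deepcopy(pretrained)
    merged.load_state_dict(merged_state)
    return merged
```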
Configuration¶
# =============================================================================
# FusionBench Method Configuration: ExPO
# =============================================================================
# Extrapolates from pretrained to finetuned direction by a factor.
# =============================================================================
# This algorithm merges a pretrained model with a finetuned model.
#
# $$\theta_{merged} = \theta_{ft} + \alpha (\theta_{ft} - \theta_{pre})$$
#
# where $\theta_{merged}$ is the merged model, $\theta_{ft}$ is the finetuned model (medium-aligned model),
# $\theta_{pre}$ is the pretrained model (base model), and $\alpha$ is the extrapolation factor.
_target_: fusion_bench.method.ExPOAlgorithm
extrapolation_factor: 0.1
CLI Usage¶
fusion_bench method=linear/expo \
method.extrapolation_factor=0.1 \
modelpool=CausalLMPool/sft_and_rlhf \
taskpool=...
ExPO for LLaMA¶
The ExPOAlgorithmForLlama class provides fine-grained control over which parts of a LLaMA model are extrapolated. This is critical because different components (attention, MLP, embeddings, lm_head) may benefit from different treatment.
Key Parameters¶
- `extrapolation_factor`: The extrapolation coefficient \(\alpha\).
- `attention_scaling_factor`: Scales the extrapolation factor for attention layers separately. The effective factor for attention becomes `extrapolation_factor * attention_scaling_factor`.
- `only_on_backbone`: When `True`, only the backbone (transformer layers) is merged; the lm_head is kept from the RLHF model.
- `on_linear_weights` / `on_linear_bias`: Control whether linear weights and biases are extrapolated.
- `on_embedding`: Whether to extrapolate the token embedding layer.
- `fix_first_n_layers` / `fix_last_n_layers`: Skip extrapolation for the first/last N layers (supports `"half"` for half the layers).
- `magnitude_sparsity_ratio`: Optionally apply magnitude pruning to the delta vector before extrapolation.
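For example, a conservative run might extrapolate only the backbone's linear weights and halve the effect on attention layers. The values below are illustrative; the parameter names mirror the configuration shown further down:

```python
from fusion_bench.method import ExPOAlgorithmForLlama

algorithm = ExPOAlgorithmForLlama(
    extrapolation_factor=0.1,      # alpha
    attention_scaling_factor=0.5,  # attention layers use 0.1 * 0.5 = 0.05
    only_on_backbone=True,         # keep the lm_head from the RLHF model
    on_linear_weights=True,
    on_linear_bias=False,
    on_embedding=False,
    fix_first_n_layers=0,
    fix_last_n_layers=0,
    magnitude_sparsity_ratio=None,
)
merged_model = algorithm.run(modelpool)
```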
Mathematical Formulation¶
For each layer \(l\), the LLaMA-specific ExPO applies:

$$\theta_{merged}^{(l)} = \theta_{ft}^{(l)} + \alpha_l \left(\theta_{ft}^{(l)} - \theta_{pre}^{(l)}\right)$$

where \(\alpha_l = \alpha \cdot \alpha_{attn}\) for attention layers and \(\alpha_l = \alpha\) for MLP layers.
If magnitude_sparsity_ratio is set, the delta \(\delta = \theta_{ft} - \theta_{pre}\) is first pruned via unstructured magnitude pruning before scaling.
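The pruning step amounts to zeroing the smallest-magnitude entries of the delta before it is scaled and added back. The helper below is an illustrative stand-in for unstructured magnitude pruning, not the library's implementation:

```python
import torch


def magnitude_prune(delta: torch.Tensor, sparsity_ratio: float) -> torch.Tensor:
    """Zero out roughly `sparsity_ratio` of the entries with the smallest magnitude."""
    k = int(delta.numel() * sparsity_ratio)  # number of entries to drop
    if k <= 0:
        return delta
    # the k-th smallest absolute value acts as the pruning threshold
    threshold = delta.abs().flatten().kthvalue(k).values
    return torch.where(delta.abs() > threshold, delta, torch.zeros_like(delta))
```

ExPO then scales the pruned delta by \(\alpha_l\) and adds it to \(\theta_{ft}\) as in the formula above.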
ExPO with DARE for LLaMA¶
The ExPOWithDareForLLama variant first merges the RLHF models using DARE simple averaging (random drop and rescale), then applies ExPO extrapolation. This combines the benefits of DARE's interference reduction with ExPO's extrapolation:
from fusion_bench.method import ExPOWithDareForLLama
algorithm = ExPOWithDareForLLama(
    extrapolation_factor=0.1,
    dare_sparsity_ratio=0.5,
    dare_only_on_linear_weights=True,
    dare_rescale=True,
)
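DARE's "random drop and rescale" on a single delta tensor can be sketched as follows; this is an illustrative stand-in, assuming a drop probability `p` and the standard \(1/(1-p)\) rescaling:

```python
import torch


def dare_drop_and_rescale(delta: torch.Tensor, p: float) -> torch.Tensor:
    """Randomly zero a fraction `p` of the delta entries and rescale the rest by 1 / (1 - p)."""
    if p <= 0:
        return delta
    keep_mask = torch.bernoulli(torch.full_like(delta, 1.0 - p))
    return delta * keep_mask / (1.0 - p)
```

The rescaling keeps the expected delta unchanged, so the dropped-out deltas can still be averaged without systematically shrinking the update.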
Configuration¶
# =============================================================================
# FusionBench Method Configuration: ExPO for LLaMA
# =============================================================================
# LLaMA-specific ExPO with backbone-only and attention scaling options.
# =============================================================================
# This algorithm merges a pretrained model with a finetuned model.
#
# $$\theta_{merged} = \theta_{ft} + \alpha (\theta_{ft} - \theta_{pre})$$
#
# where $\theta_{merged}$ is the merged model, $\theta_{ft}$ is the finetuned model (medium-aligned model),
# $\theta_{pre}$ is the pretrained model (base model), and $\alpha$ is the extrapolation factor.
_target_: fusion_bench.method.ExPOAlgorithmForLlama
extrapolation_factor: 0.1
attention_scaling_factor: 1.0
only_on_backbone: true
on_linear_weights: true
on_linear_bias: false
on_embedding: false
fix_last_n_layers: 0
fix_first_n_layers: 0
magnitude_sparsity_ratio: null
CLI Usage¶
fusion_bench method=linear/llama_expo \
method.extrapolation_factor=0.1 \
method.attention_scaling_factor=1.0 \
method.only_on_backbone=true \
modelpool=CausalLMPool/sft_and_rlhf \
taskpool=...
Simple Average for Causal LM¶
The SimpleAverageForCausalLM class extends the basic simple average with Causal LM-specific features:
- `merge_backbone`: When `True`, only the backbone (transformer layers) is averaged, and the lm_head is taken from the pretrained model. This is useful when merging models with different heads (e.g., chat vs. generation).
- `model_save_path`: Save the merged model and tokenizer to the specified path.
- `show_pbar`: Show a progress bar during merging.
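For intuition, backbone-only merging averages the transformer layers across the expert models while copying the language-model head from the base model. The sketch below assumes Hugging Face LLaMA-style attribute names (`model.model` for the backbone, `model.lm_head` for the head) and is not the library's internal implementation:

```python
import copy
from typing import List


def merge_backbone_only(base_model, expert_models: List):
    """Average the transformer backbone across experts; keep the base model's lm_head."""
    merged = copy.deepcopy(base_model)
    expert_states = [m.model.state_dict() for m in expert_models]  # backbone only
    merged_backbone = {
        name: sum(sd[name] for sd in expert_states) / len(expert_states)
        for name in expert_states[0]
    }
    merged.model.load_state_dict(merged_backbone)
    # merged.lm_head is left untouched, i.e. it remains the base model's head
    return merged
```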
Configuration¶
# =============================================================================
# FusionBench Method Configuration: Simple Average (Causal LM)
# =============================================================================
# Uniformly averages causal LM weights with optional backbone-only.
# =============================================================================
_target_: fusion_bench.method.SimpleAverageForCausalLM
# set `merge_backbone` to true if you have a base model and only want to merge the backbone of the experts
# if `merge_backbone` is False, this is equivalent to `SimpleAverageAlgorithm`
merge_backbone: false
model_save_path: ${path.log_dir}/checkpoint
show_pbar: true
CLI Usage¶
fusion_bench method=linear/simple_average_for_causallm \
method.merge_backbone=false \
method.model_save_path=outputs/merged_model \
method.show_pbar=true \
modelpool=CausalLMPool/multiple_models \
taskpool=...
API Usage¶
from fusion_bench.method import SimpleAverageForCausalLM
algorithm = SimpleAverageForCausalLM(
    merge_backbone=False,
    model_save_path="outputs/merged",
    show_pbar=True,
)
merged_model = algorithm.run(modelpool)
Implementation Details¶
ExPO Merge Function¶
The core expo_merge() function implements the extrapolation at the parameter level:
def expo_merge(sft_model, rlhf_model, extrapolation_factor, inplace=True, enable_grad=False):
    # Walk the two models' parameters in parallel and update the RLHF model in place:
    #   theta_merged = theta_ft + alpha * (theta_ft - theta_pre)
    for (sft_name, sft_param), (rlhf_name, rlhf_param) in zip(
        sft_model.named_parameters(), rlhf_model.named_parameters()
    ):
        rlhf_param.data = rlhf_param.data + extrapolation_factor * (
            rlhf_param.data - sft_param.data
        )
    return rlhf_model
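For example, assuming two architecture-compatible causal LMs are already loaded as `sft_model` and `rlhf_model`, the extrapolated model is obtained directly:

```python
merged = expo_merge(sft_model, rlhf_model, extrapolation_factor=0.1)
```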
Linear Interpolation¶
The LinearInterpolationAlgorithm uses state_dict_weighted_sum to combine two state dictionaries:
state_dict = state_dict_weighted_sum(
    [primary_state_dict, secondary_state_dict], [1 - self.t, self.t]
)
Choosing a Method¶
| Scenario | Recommended Method |
|---|---|
| Two models, equal importance | Linear Interpolation (t=0.5) or Simple Average |
| Two models, unequal importance | Linear Interpolation with tuned \(t\) |
| Pretrained + aligned model | ExPO (general or LLaMA) |
| Multiple RLHF models + SFT | ExPO (auto-averages RLHF models) |
| Multiple RLHF + SFT, large models | ExPO with DARE for LLaMA |
| Causal LMs with different heads | Simple Average for Causal LM (merge_backbone=True) |
API Reference¶
- ExPOAlgorithm
- LinearInterpolationAlgorithm
- ExPOAlgorithmForLlama
- ExPOWithDareForLLama
- SimpleAverageForCausalLM
[^1]: Zheng et al. (2024). Weak-to-Strong Extrapolation Expedites Alignment. https://arxiv.org/abs/2404.12717