Weighted Averaging
Weighted averaging is also known as weight-ensembling.
For fully fine-tuned models, the parameters of the individual models are averaged according to per-model weights, which may reflect each model's performance. Concretely, given \(n\) models with parameters \(\theta_i\) and model-wise weights \(w_i\), the parameters of the merged model \(\theta\) are computed as:
\[ \theta = \sum_{i=1}^{n} w_i \theta_i \]
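The formula above can be sketched in plain Python. This is only an illustration, not the library's implementation: the `weighted_average` helper and the `StateDict` alias are hypothetical names, and each parameter is represented as a flat list of floats rather than a tensor.

```python
from typing import Dict, List, Sequence

# Hypothetical stand-in for a model state dict: parameter name -> flat list of values.
StateDict = Dict[str, List[float]]


def weighted_average(state_dicts: Sequence[StateDict], weights: Sequence[float]) -> StateDict:
    """Compute theta = sum_i w_i * theta_i, parameter by parameter."""
    if len(state_dicts) != len(weights):
        raise ValueError("Number of weights must match the number of models.")
    merged: StateDict = {}
    # All models are assumed to share the same parameter names and shapes.
    for name in state_dicts[0]:
        size = len(state_dicts[0][name])
        merged[name] = [
            sum(w * sd[name][k] for sd, w in zip(state_dicts, weights))
            for k in range(size)
        ]
    return merged


theta_1 = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}
theta_2 = {"layer.weight": [3.0, 6.0], "layer.bias": [4.0]}
theta = weighted_average([theta_1, theta_2], weights=[0.25, 0.75])
print(theta)  # {'layer.weight': [2.5, 5.0], 'layer.bias': [3.0]}
```

With weights 0.25 and 0.75, every merged value lies three quarters of the way from the first model's value to the second's.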
Examples
General Usage
Configuration template for the Weighted Averaging algorithm:
config/method/weighted_average.yaml
name: weighted_average
normalize: true # if true, the weights will be normalized before merging
weights: # List of weights for each model
- 0.5
- 0.5
Use the following command to run the Weighted Averaging algorithm:
fusion_bench method=weighted_average ...
Merge CLIP-ViT Models
The following command merges eight CLIP-ViT models using a weighted average. Because method.normalize is set to true, the weights are normalized to sum to 1. Since all eight weights are equal, the result is equivalent to a simple average.
fusion_bench \
method=weighted_average \
method.normalize=true \
method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
modelpool=clip-vit-base-patch32_TA8_model_only \
taskpool=clip-vit-classification_TA8
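As a quick check of the normalization in this command, write \(\tilde{w}_i\) for the normalized weight (a symbol introduced here only for illustration). With eight weights of 0.3 each:

\[ \tilde{w}_i = \frac{w_i}{\sum_{j=1}^{8} w_j} = \frac{0.3}{8 \times 0.3} = \frac{1}{8} = 0.125, \qquad \theta = \sum_{i=1}^{8} \tilde{w}_i \theta_i = \frac{1}{8} \sum_{i=1}^{8} \theta_i \]

Any other set of eight equal positive weights would give the same merged model once normalization is enabled.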
Merge Llama/Mistral Models
The following example merges two Llama models with the Weighted Averaging algorithm. The models are of type transformers.LlamaForCausalLM.
fusion_bench \
method=weighted_average_for_llama \
method.merged_model_save_path=outputs/test_merged_llama_model \
modelpool=llama_for_causallm \
taskpool=dummy
Alternatively, use the configuration file config/llama_weighted_average.yaml:
fusion_bench --config-name llama_weighted_average
config/llama_weighted_average.yaml
defaults:
  - example_config
  - override method: weighted_average_for_llama
  - override modelpool: llama_for_causallm
  - _self_

modelpool:
  models:
    # the pre-trained model (base model) is optional
    # if not provided, the first model will be used as the base model
    - name: _pretrained_
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_1
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_2
      path: meta-llama/Meta-Llama-3-8B-Instruct

method:
  normalize: true # if true, the weights will be normalized before merging
  weights: # List of weights for each model
    - 0.5
    - 0.5
  # if true, only the backbone of the model will be merged and the head will be kept from the pre-trained model (if the pre-trained model is provided, otherwise the head of the first model will be used)
  # if false, the whole model will be merged
  backbone_only: true
  merged_model_save_path: null
  save_tokenizer: true
  push_to_hub: false
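The effect of backbone_only described in the comments above can be illustrated with a minimal sketch. It is not the library's code: `merge_backbone_only` is a hypothetical helper, each parameter is reduced to a single float, and the assumption that head parameters are exactly those prefixed with `lm_head.` is made here for illustration.

```python
from typing import Dict, Sequence

# Hypothetical state dicts: parameter name -> scalar value, for brevity.
StateDict = Dict[str, float]

# Assumption: parameters under this prefix form the "head"; everything else is the backbone.
HEAD_PREFIX = "lm_head."


def merge_backbone_only(
    base: StateDict, experts: Sequence[StateDict], weights: Sequence[float]
) -> StateDict:
    """Average backbone parameters of the experts; keep head parameters from the base model."""
    merged = dict(base)  # start from the base model, so its head is kept as-is
    for name in base:
        if name.startswith(HEAD_PREFIX):
            continue  # head parameters are not merged
        merged[name] = sum(w * sd[name] for sd, w in zip(experts, weights))
    return merged


base = {"model.layers.0.weight": 1.0, "lm_head.weight": 10.0}
expert_1 = {"model.layers.0.weight": 2.0, "lm_head.weight": 20.0}
expert_2 = {"model.layers.0.weight": 4.0, "lm_head.weight": 40.0}

merged = merge_backbone_only(base, [expert_1, expert_2], weights=[0.5, 0.5])
print(merged)  # {'model.layers.0.weight': 3.0, 'lm_head.weight': 10.0}
```

The backbone value is the average of the two experts, while the head value is still the base model's.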
References
WeightedAverageAlgorithm
Bases: ModelFusionAlgorithm, SimpleProfilerMixin
Source code in fusion_bench/method/weighted_average/weighted_average.py
class WeightedAverageAlgorithm(ModelFusionAlgorithm, SimpleProfilerMixin):
    @override
    @torch.no_grad()
    def run(self, modelpool: ModelPool):
        """
        Fuses the models in the model pool using a weighted average approach.

        Parameters
            modelpool (ModelPool): The pool of models to be fused.

        Raises
            ValueError: If the number of weights does not match the number of models in the model pool.

        Returns
            forward_model (torch.nn.Module): The resulting model after fusion.
        """
        modelpool = to_modelpool(modelpool)
        log.info("Fusing models using weighted average.")
        weights = np.asarray(self.config.weights)
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.config.normalize:
            weights = weights / np.sum(weights)
        print(f"weights: {weights}, normalized: {self.config.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None

        for model_name, weight in zip(modelpool.model_names, weights):
            with self.profile("load_model"):
                model = modelpool.load_model(model_name)
            with self.profile("merge weights"):
                if sd is None:
                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                    forward_model = model
                else:
                    sd = state_dict_add(
                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                    )

        forward_model.load_state_dict(sd)
        self.print_profile_summary()
        return forward_model
run(modelpool)
Fuses the models in the model pool using a weighted average approach.
Parameters:
- modelpool (ModelPool) – The pool of models to be fused.
Raises:
- ValueError – If the number of weights does not match the number of models in the model pool.
Returns:
- forward_model (torch.nn.Module) – The resulting model after fusion.
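The weight handling that precedes the merge (the length check behind the ValueError, then optional normalization) can be sketched on its own. The `prepare_weights` helper below is hypothetical and uses plain Python lists instead of NumPy arrays; it is meant as a reading aid, not as the library's API.

```python
from typing import List, Sequence


def prepare_weights(weights: Sequence[float], num_models: int, normalize: bool) -> List[float]:
    """Validate the weight list and optionally rescale it to sum to 1."""
    if len(weights) != num_models:
        raise ValueError(
            "Number of weights must match the number of models, "
            f"but got {len(weights)} weights and {num_models} models."
        )
    if normalize:
        total = sum(weights)
        return [w / total for w in weights]
    return list(weights)


print(prepare_weights([1.0, 3.0], num_models=2, normalize=True))   # [0.25, 0.75]
print(prepare_weights([1.0, 3.0], num_models=2, normalize=False))  # [1.0, 3.0]

try:
    prepare_weights([0.5, 0.5], num_models=3, normalize=True)
except ValueError as err:
    print(f"ValueError: {err}")
```

Without normalization the weights are used as given, so the merged parameters are scaled by the sum of the weights.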
WeightedAverageForLLama
Bases: ModelFusionAlgorithm
A class to perform weighted averaging of models in a LLamaForCausalLMPool.
Attributes:
- config (DictConfig) – Configuration parameters for the weighted averaging process.
Methods:
- run(modelpool: LLamaForCausalLMPool) – Executes the weighted averaging of models in the provided model pool.
Source code in fusion_bench/method/weighted_average/llama.py
class WeightedAverageForLLama(ModelFusionAlgorithm):
    """
    A class to perform weighted averaging of models in a LLamaForCausalLMPool.

    Attributes:
        config (DictConfig): Configuration parameters for the weighted averaging process.

    Methods:
        run(modelpool: LLamaForCausalLMPool):
            Executes the weighted averaging of models in the provided model pool.
    """

    @torch.no_grad()
    @override
    def run(self, modelpool: LLamaForCausalLMPool):
        """
        Executes the weighted averaging of models in the provided model pool.

        Args:
            modelpool (LLamaForCausalLMPool): The pool of models to be averaged.

        Returns:
            base_model: The base model after merging the state dictionaries of the models in the pool.

        Raises:
            ValueError: If the number of weights does not match the number of models in the pool.
        """
        config = self.config
        if modelpool.has_pretrained:
            base_model = modelpool.load_model("_pretrained_")
        else:
            base_model = modelpool.load_model(modelpool.model_names[0])

        weights = config.weights
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models.,"
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models."
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.config.normalize:
            weights = np.asarray(weights)
            weights = weights / np.sum(weights)

        merged_state_dict = None
        for model_name, weight in zip(modelpool.model_names, weights):
            model = modelpool.load_model(model_name, backbone_only=config.backbone_only)
            sd = state_dict_mul(model.state_dict(), weight)
            if merged_state_dict is None:
                merged_state_dict = sd
            else:
                merged_state_dict = state_dict_add(merged_state_dict, sd)

        base_model.load_state_dict(
            merged_state_dict, strict=False if config.backbone_only else True
        )
        if config.merged_model_save_path is not None:
            with timeit_context(
                f"Saving the merged model to {config.merged_model_save_path}"
            ):
                modelpool.save_model(
                    base_model,
                    path=config.merged_model_save_path,
                    save_tokenizer=config.save_tokenizer,
                    push_to_hub=config.push_to_hub,
                )
        return base_model
run(modelpool)
Executes the weighted averaging of models in the provided model pool.
Parameters:
- modelpool (LLamaForCausalLMPool) – The pool of models to be averaged.
Returns:
- base_model – The base model after merging the state dictionaries of the models in the pool.
Raises:
- ValueError – If the number of weights does not match the number of models in the pool.