Weighted Averaging
Weighted averaging, also known as weight-ensembling, merges multiple models by taking a weighted combination of their parameters.
In the context of fully fine-tuned models, the parameters are averaged using one weight per model, typically chosen according to each model's importance or performance. Concretely, given \(n\) models with parameters \(\theta_i\) and model-wise weights \(w_i\), the parameters of the merged model \(\theta\) are computed as:
\[ \theta = \sum_{i=1}^{n} w_i \theta_i \]
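The same rule can be written directly over PyTorch state dicts. Below is a minimal, framework-agnostic sketch (not the FusionBench implementation itself) that assumes all models share an identical architecture:

import torch
from torch import nn

def weighted_average(models, weights):
    """Compute theta = sum_i w_i * theta_i over matching state-dict entries."""
    keys = models[0].state_dict().keys()
    return {k: sum(w * m.state_dict()[k] for w, m in zip(weights, models)) for k in keys}

# toy example: two linear layers standing in for fine-tuned checkpoints
model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)
merged_sd = weighted_average([model_a, model_b], [0.5, 0.5])

merged_model = nn.Linear(4, 2)
merged_model.load_state_dict(merged_sd)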
Examples
General Usage
Configuration template for the Weighted Averaging algorithm:
config/method/weighted_average.yaml
name: weighted_average
normalize: true # if true, the weights will be normalized before merging
weights: # List of weights for each model
  - 0.5
  - 0.5
Use the following command to run the Weighted Averaging algorithm:
fusion_bench method=weighted_average ...
Merge CLIP-ViT Models
The following command merges eight CLIP-ViT models using weighted averaging. Because method.normalize is set to true, the eight weights of 0.3 are rescaled to sum to 1 (0.125 each), which makes the merge equivalent to a simple average.
fusion_bench \
  method=weighted_average \
  method.normalize=true \
  method.weights=[0.3,0.3,0.3,0.3,0.3,0.3,0.3,0.3] \
  modelpool=clip-vit-base-patch32_TA8_model_only \
  taskpool=clip-vit-classification_TA8
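As a quick check of the normalization step (it mirrors the weights / np.sum(weights) line in the algorithm's source shown below):

import numpy as np

weights = np.asarray([0.3] * 8)
print(weights / weights.sum())  # eight entries of 0.125, i.e. a simple average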
Merge Llama/Mistral Models
Here is an example of using the Weighted Averaging algorithm to merge two Llama models of type transformers.LlamaForCausalLM.
fusion_bench \
  method=weighted_average_for_llama \
  method.merged_model_save_path=outputs/test_merged_llama_model \
  modelpool=llama_for_causallm \
  taskpool=dummy
Alternatively, use the configuration file config/llama_weighted_average.yaml:
fusion_bench --config-name llama_weighted_average
config/llama_weighted_average.yaml
defaults:
  - example_config
  - override method: weighted_average_for_llama
  - override modelpool: llama_for_causallm
  - _self_

modelpool:
  models:
    # the pre-trained model (base model) is optional;
    # if not provided, the first model will be used as the base model
    - name: _pretrained_
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_1
      path: meta-llama/Meta-Llama-3-8B
    - name: expert_2
      path: meta-llama/Meta-Llama-3-8B-Instruct

method:
  normalize: true # if true, the weights will be normalized before merging
  weights: # List of weights for each model
    - 0.5
    - 0.5
  # if true, only the backbone of the model will be merged and the head will be kept from the
  # pre-trained model (or from the first model if no pre-trained model is provided);
  # if false, the whole model will be merged
  backbone_only: true
  merged_model_save_path: null
  save_tokenizer: true
  push_to_hub: false
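Once the run finishes, the checkpoint written to method.merged_model_save_path is a regular Hugging Face model directory. A minimal loading sketch, assuming the path outputs/test_merged_llama_model from the command-line example above and that the tokenizer was saved alongside the model (save_tokenizer: true):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "outputs/test_merged_llama_model"  # value of method.merged_model_save_path
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("Weighted averaging merges models by", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))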
References
WeightedAverageAlgorithm
Bases: BaseAlgorithm, SimpleProfilerMixin
Source code in fusion_bench/method/weighted_average/weighted_average.py
class WeightedAverageAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        verbose: bool = True,
        **kwargs,
    ):
        self.normalize = normalize
        self.weights = weights
        self.verbose = verbose
        log.disabled = not self.verbose
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Fuses the models in the model pool using a weighted average approach.

        Parameters
            modelpool (ModelPool): The pool of models to be fused.

        Raises
            ValueError: If the number of weights does not match the number of models in the model pool.

        Returns
            forward_model (torch.nn.Module): The resulting model after fusion.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info("Fusing models using weighted average.")
        weights = np.asarray(self.weights)
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models, "
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models. "
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.normalize:
            weights = weights / np.sum(weights)
        if self.verbose:
            print(f"weights: {weights}, normalized: {self.normalize}")

        sd: Optional[StateDictType] = None
        forward_model = None

        for model_name, weight in zip(modelpool.model_names, weights):
            with self.profile("load_model"):
                model = modelpool.load_model(model_name)
            with self.profile("merge weights"):
                if sd is None:
                    # initialize the accumulator with the first (scaled) state dict
                    sd = state_dict_mul(model.state_dict(keep_vars=True), weight)
                    forward_model = model
                else:
                    # accumulate the weighted state dicts of the remaining models
                    sd = state_dict_add(
                        sd, state_dict_mul(model.state_dict(keep_vars=True), weight)
                    )

        forward_model.load_state_dict(sd)
        if self.verbose:
            self.print_profile_summary()
        return forward_model
Attributes:
- _config_mapping = BaseAlgorithm._config_mapping | {'normalize': 'normalize', 'weights': 'weights'} (class attribute)
- normalize (instance attribute)
- verbose (instance attribute)
- weights (instance attribute)
__init__(normalize, weights, verbose=True, **kwargs)
run(modelpool)
Fuses the models in the model pool using a weighted average approach.
Parameters:
- modelpool (ModelPool): The pool of models to be fused.
Raises:
- ValueError: If the number of weights does not match the number of models in the model pool.
Returns:
- forward_model (torch.nn.Module): The resulting model after fusion.
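For programmatic use outside the CLI, the algorithm can be constructed and run directly. A hedged sketch with toy modules; whether run() accepts a plain name-to-module mapping (it wraps its argument in a BaseModelPool, as shown in the source above) may depend on your FusionBench version:

from torch import nn
from fusion_bench.method.weighted_average.weighted_average import WeightedAverageAlgorithm

# two toy models with identical architectures standing in for fine-tuned checkpoints
model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)

algorithm = WeightedAverageAlgorithm(normalize=True, weights=[0.7, 0.3])
# run() wraps non-BaseModelPool inputs in a BaseModelPool (see the source above);
# passing a plain dict here is an assumption, not a documented API guarantee.
merged_model = algorithm.run({"model_a": model_a, "model_b": model_b})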
WeightedAverageForLLama
Bases: BaseAlgorithm
A class to perform weighted averaging of Llama/Mistral models.
Source code in fusion_bench/method/weighted_average/llama.py
class WeightedAverageForLLama(BaseAlgorithm):
    """
    A class to perform weighted averaging of Llama/Mistral models.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "normalize": "normalize",
        "weights": "weights",
        "backbone_only": "backbone_only",
        "merged_model_save_path": "merged_model_save_path",
        "save_tokenizer": "save_tokenizer",
        "push_to_hub": "push_to_hub",
    }

    def __init__(
        self,
        normalize: bool,
        weights: List[float],
        backbone_only: bool,
        merged_model_save_path: str,
        save_tokenizer: bool,
        push_to_hub: bool,
        **kwargs,
    ):
        """
        Initialize the WeightedAverageForLLama class with the given parameters.

        Args:
            normalize (bool): Whether to normalize the weights.
            weights (List[float]): The weights for averaging the models.
            backbone_only (bool): Whether to use only the backbone of the models.
            merged_model_save_path (str): The path to save the merged model.
            save_tokenizer (bool): Whether to save the tokenizer.
            push_to_hub (bool): Whether to push the model to the hub.
        """
        self.normalize = normalize
        self.weights = weights
        self.backbone_only = backbone_only
        self.merged_model_save_path = merged_model_save_path
        self.save_tokenizer = save_tokenizer
        self.push_to_hub = push_to_hub
        super().__init__(**kwargs)

    @override
    @torch.no_grad()
    def run(self, modelpool: CausalLMPool):
        """
        Executes the weighted averaging of models in the provided model pool.

        Args:
            modelpool (CausalLMPool): The pool of models to be averaged.

        Returns:
            base_model: The base model after merging the state dictionaries of the models in the pool.

        Raises:
            ValueError: If the number of weights does not match the number of models in the pool.
        """
        # use the pre-trained model as the base if available, otherwise the first expert
        if modelpool.has_pretrained:
            base_model = modelpool.load_model("_pretrained_")
        else:
            base_model = modelpool.load_model(modelpool.model_names[0])

        weights = self.weights
        if len(weights) != len(modelpool.model_names):
            raise ValueError(
                "Number of weights must match the number of models, "
                f"but got {len(weights)} weights and {len(modelpool.model_names)} models. "
                f"weights: {weights}, models: {modelpool.model_names}"
            )
        if self.normalize:
            weights = np.asarray(weights)
            weights = weights / np.sum(weights)

        # accumulate the weighted state dicts of all models in the pool
        merged_state_dict: StateDictType = None
        for model_name, weight in zip(modelpool.model_names, weights):
            model = modelpool.load_model(model_name, backbone_only=self.backbone_only)
            sd = state_dict_mul(model.state_dict(), weight)
            if merged_state_dict is None:
                merged_state_dict = sd
            else:
                merged_state_dict = state_dict_add(merged_state_dict, sd)

        # with backbone_only=True the merged state dict covers only the backbone,
        # so strict=False keeps the base model's head untouched
        base_model.load_state_dict(
            merged_state_dict, strict=False if self.backbone_only else True
        )

        if self.merged_model_save_path is not None:
            with timeit_context(
                f"Saving the merged model to {self.merged_model_save_path}"
            ):
                modelpool.save_model(
                    base_model,
                    path=self.merged_model_save_path,
                    save_tokenizer=self.save_tokenizer,
                    push_to_hub=self.push_to_hub,
                )
        return base_model
__init__(normalize, weights, backbone_only, merged_model_save_path, save_tokenizer, push_to_hub, **kwargs)
Initialize the WeightedAverageForLLama class with the given parameters.
Parameters:
- normalize (bool): Whether to normalize the weights.
- weights (List[float]): The weights for averaging the models.
- backbone_only (bool): Whether to use only the backbone of the models.
- merged_model_save_path (str): The path to save the merged model.
- save_tokenizer (bool): Whether to save the tokenizer.
- push_to_hub (bool): Whether to push the model to the hub.
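As a reference, a hedged Python construction that mirrors the YAML configuration shown earlier (the save path is illustrative):

from fusion_bench.method.weighted_average.llama import WeightedAverageForLLama

algorithm = WeightedAverageForLLama(
    normalize=True,
    weights=[0.5, 0.5],
    backbone_only=True,
    merged_model_save_path="outputs/test_merged_llama_model",  # illustrative path
    save_tokenizer=True,
    push_to_hub=False,
)
# algorithm.run(modelpool) expects a CausalLMPool, e.g. one built from the
# llama_for_causallm modelpool configuration above.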
run(modelpool)
Executes the weighted averaging of models in the provided model pool.
Parameters:
- modelpool (CausalLMPool): The pool of models to be averaged.
Returns:
- base_model: The base model after merging the state dictionaries of the models in the pool.
Raises:
- ValueError: If the number of weights does not match the number of models in the pool.