(Diagonal) Fisher Merging
Fisher merging is a per-parameter weighted averaging method: each model's parameters are weighted by their Fisher information, estimated on some labeled data.
The Fisher information matrix \(F_\theta\) of a model with parameters \(\theta\) can be expressed as:
\[ F_\theta = \mathbb{E}_{x \sim p(x),\, y \sim p(y|x,\theta)} \left[ \nabla_\theta \log p(y|x, \theta) \, \nabla_\theta \log p(y|x, \theta)^T \right] \]
where \(p(x)\) is the data distribution, \(p(y|x, \theta)\) is the model's output distribution (for example, the softmax output of a classification model), and \(\nabla_\theta\) denotes the gradient with respect to the model's parameters \(\theta\).
The Fisher information matrix quantifies the importance of each parameter and can therefore be used to weight the models' parameters during merging. It can also be used to estimate the similarity between tasks, which is useful in auxiliary-task and multi-task learning scenarios.
As the full Fisher information matrix is computationally expensive to compute and memory-intensive to store, we approximate it with the diagonal Fisher information matrix, i.e., the diagonal of the full matrix.
The diagonal Fisher information matrix can be computed as:
\[ \hat{F}_\theta = \mathbb{E}_{x \sim p(x),\, y \sim p(y|x,\theta)} \left[ \left(\nabla_\theta \log p(y|x, \theta)\right)^2 \right] \]
where the square is taken element-wise.
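As a rough illustration (not the FusionBench implementation), the sketch below estimates the diagonal Fisher of a classifier by accumulating squared gradients of the log-likelihood over a few batches. The function name `diagonal_fisher`, the data-loader interface, and the normalization are assumptions made for illustration only.

```python
import torch
import torch.nn.functional as F


def diagonal_fisher(model, dataloader, num_examples=256):
    """Estimate the diagonal Fisher by accumulating squared gradients of
    log p(y|x, theta). Labels y are sampled from the model's own predictive
    distribution; using the observed labels instead gives the empirical Fisher."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    n_seen = 0
    for x, _ in dataloader:  # ground-truth labels are ignored here
        logits = model(x)
        probs = torch.softmax(logits, dim=-1)
        y = torch.multinomial(probs, num_samples=1).squeeze(-1)
        # negative mean log-likelihood of the sampled labels for this batch;
        # squaring the batch-averaged gradient matches the per-example Fisher
        # exactly only for batch size 1, so this is a coarse approximation
        loss = F.cross_entropy(logits, y)
        model.zero_grad()
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        n_seen += x.size(0)
        if n_seen >= num_examples:
            break
    # any uniform scaling cancels in the weighted average used for merging
    return {n: f / max(n_seen, 1) for n, f in fisher.items()}
```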
Assuming we have \(n\) models with parameters \(\theta_i\) and diagonal Fisher information matrices \(\hat{F}_{\theta_i}\), the Fisher merging algorithm computes the merged model's parameters \(\theta\) as follows:
\[ \theta^{(j)} = \frac{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)} \theta_i^{(j)}}{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)}} \]
where \(\theta_i\) and \(\hat{F}_{\theta_i}\) are the parameters and diagonal Fisher information matrix of the \(i\)-th model, and \(j\) indexes the individual parameters.
Fisher merging can thus be viewed as a per-parameter weighted average, where each parameter's weight is determined by its Fisher information in the corresponding model.
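The merging step itself reduces to an element-wise weighted average over the models' state dicts. The following sketch (illustrative only; `fisher_weighted_merge` is not a FusionBench API) applies the formula above, with a small `eps` playing the role of the minimal Fisher weight to avoid division by zero:

```python
import torch


def fisher_weighted_merge(state_dicts, fishers, eps=1e-6):
    # state_dicts: list of {param_name: tensor}, one per model to merge
    # fishers:     list of {param_name: tensor}, diagonal Fisher per model
    merged = {}
    for name in state_dicts[0]:
        weights = torch.stack([f[name] + eps for f in fishers])  # (n, *param_shape)
        params = torch.stack([sd[name] for sd in state_dicts])   # (n, *param_shape)
        # theta^(j) = sum_i F_i^(j) theta_i^(j) / sum_i F_i^(j)
        merged[name] = (weights * params).sum(dim=0) / weights.sum(dim=0)
    return merged
```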
Code Integration
Example of merging eight CLIP-ViT-B/32 models using Fisher merging:
fusion_bench method=clip_fisher_merging \
modelpool=clip-vit-base-patch32_TA8 \
taskpool=clip-vit-classification_TA8
Merge eight CLIP-ViT-L/14 models using Fisher merging:
fusion_bench \
method=clip_fisher_merging \
method.batch_size=8 method.num_workers=4 \
modelpool=clip-vit-large-patch14_TA8 \
taskpool=clip-vit-classification_TA8 \
taskpool.clip_model=openai/clip-vit-large-patch14
Merge GPT-2 models for text classification tasks:
fusion_bench \
method=gpt2_fisher_merging \
method.num_fisher_examples=512 method.batch_size=8 \
modelpool=gpt-2_glue \
taskpool=gpt-2_glue
References
FisherMergingAlgorithm
Bases: BaseAlgorithm
Implements the Fisher Merging Algorithm.
This class extends the BaseModelFusionAlgorithm to handle merging of models using Fisher weights.
It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.
Methods:
- run(modelpool: BaseModelPool) -> nn.Module: Executes the Fisher merging process on the model pool and returns the merged model.
Source code in fusion_bench/method/fisher_merging/fisher_merging.py
class FisherMergingAlgorithm(BaseAlgorithm):
"""
Implements the Fisher Merging Algorithm.
This class extends the BaseModelFusionAlgorithm to handle merging of models using Fisher weights.
It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.
Methods:
run(modelpool: BaseModelPool) -> nn.Module:
Executes the Fisher merging process on the model pool and returns the merged model.
"""
_config_mapping = BaseAlgorithm._config_mapping | {
"exclude_param_names_regex": "exclude_param_names_regex",
"normalize_fisher_weight": "normalize_fisher_weight",
"minimal_fisher_weight": "minimal_fisher_weight",
"num_fisher_examples": "num_fisher_examples",
}
def __init__(
self,
*,
exclude_param_names_regex: list,
normalize_fisher_weight: bool,
minimal_fisher_weight: float,
num_fisher_examples: int,
):
super().__init__()
self.exclude_param_names_regex = exclude_param_names_regex
self.normalize_fisher_weight = normalize_fisher_weight
self.minimal_fisher_weight = minimal_fisher_weight
self.num_fisher_examples = num_fisher_examples
def run(self, modelpool: BaseModelPool) -> nn.Module:
"""
Run the Fisher Merging Algorithm.
This method constructs the wrapped model and performs test-time adaptation if necessary.
Args:
modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.
Returns:
nn.Module: The merged model after test-time adaptation.
"""
log.info("Running Fisher Merging Algorithm")
if isinstance(modelpool, (dict, list, tuple)):
modelpool = BaseModelPool(modelpool)
assert len(modelpool) > 0, "model pool is empty"
assert (
modelpool.has_pretrained
), "no pretrained model (base model) in the model pool"
self.modelpool = modelpool
self.on_fisher_merging_start()
# dictionary of list, where key is the parameter name,
# value is a list of the corresponding parameters of all the models that need to be merged
models_to_merge_param_dict = defaultdict(list)
# list of dictionaries with length len(models_to_merge),
# each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
models_to_merge_fisher_weights_list = []
param_names_to_merge = None
for name, model in modelpool.named_models():
param_dict = model.state_dict()
if param_names_to_merge is None:
param_names_to_merge = get_param_names_to_merge(
input_param_names=list(param_dict.keys()),
exclude_param_names_regex=self.config.get(
"exclude_param_names_regex", []
),
)
for param_name in param_names_to_merge:
models_to_merge_param_dict[param_name].append(param_dict[param_name])
model_to_merge_fisher_weights = self.get_fisher_weights(
model_name=name,
model=model,
train_dataset=modelpool.load_train_dataset(name),
param_names_to_merge=param_names_to_merge,
)
models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)
merged_params = merging_with_fisher_weights(
models_to_merge_param_dict=models_to_merge_param_dict,
models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
)
merged_model = modelpool.load_model("_pretrained_")
merged_model.load_state_dict(merged_params, strict=False)
return merged_model
def get_fisher_weights(
self,
model_name: str,
model: nn.Module,
train_dataset,
param_names_to_merge: List[str],
) -> Dict[str, Tensor]:
"""
Compute the Fisher weights for the given model and training dataset.
Args:
model_name (str): The name of the model.
model (nn.Module): The model module.
train_dataset: The training dataset.
param_names_to_merge (List[str]): List of parameter names to merge.
Returns:
Dict[str, Tensor]: The computed Fisher weights for each parameter.
"""
# this function is used to compute fisher weights for a model
# it should be implemented in the subclass
raise NotImplementedError
def on_fisher_merging_start(self):
"""
Setup the zero-shot classification head before starting the Fisher merging process.
"""
# this function is used to initialize some variables before running fisher merging
pass
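Since get_fisher_weights is abstract in the base class, a concrete subclass must supply it; the CLIP and GPT-2 variants used in the examples above do so with task-specific data handling. The sketch below is not code shipped with FusionBench: it reuses the hypothetical `diagonal_fisher` helper from the sketch above, assumes the dataset yields `(inputs, labels)` batches, and assumes the class is importable from `fusion_bench.method.fisher_merging.fisher_merging` as the source path suggests.

```python
import torch
from torch.utils.data import DataLoader

from fusion_bench.method.fisher_merging.fisher_merging import FisherMergingAlgorithm


class ToyFisherMerging(FisherMergingAlgorithm):
    """Hypothetical subclass for illustration only."""

    def get_fisher_weights(self, model_name, model, train_dataset, param_names_to_merge):
        loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
        fisher = diagonal_fisher(model, loader, num_examples=self.num_fisher_examples)
        # keep only the parameters that will actually be merged
        return {n: f for n, f in fisher.items() if n in param_names_to_merge}


algorithm = ToyFisherMerging(
    exclude_param_names_regex=[],
    normalize_fisher_weight=True,
    minimal_fisher_weight=1e-6,
    num_fisher_examples=256,
)
# modelpool is assumed to be a BaseModelPool containing a "_pretrained_" base
# model plus the fine-tuned models, each with a training dataset attached
merged_model = algorithm.run(modelpool)
```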
get_fisher_weights(model_name, model, train_dataset, param_names_to_merge)
Compute the Fisher weights for the given model and training dataset.
Parameters:
- model_name (str): The name of the model.
- model (nn.Module): The model module.
- train_dataset: The training dataset.
- param_names_to_merge (List[str]): List of parameter names to merge.
Returns:
- Dict[str, Tensor]: The computed Fisher weights for each parameter.
Source code in fusion_bench/method/fisher_merging/fisher_merging.py
def get_fisher_weights(
self,
model_name: str,
model: nn.Module,
train_dataset,
param_names_to_merge: List[str],
) -> Dict[str, Tensor]:
"""
Compute the Fisher weights for the given model and training dataset.
Args:
model_name (str): The name of the model.
model (nn.Module): The model module.
train_dataset: The training dataset.
param_names_to_merge (List[str]): List of parameter names to merge.
Returns:
Dict[str, Tensor]: The computed Fisher weights for each parameter.
"""
# this function is used to compute fisher weights for a model
# it should be implemented in the subclass
raise NotImplementedError
on_fisher_merging_start()
Setup the zero-shot classification head before starting the Fisher merging process.
Source code in fusion_bench/method/fisher_merging/fisher_merging.py
def on_fisher_merging_start(self):
"""
Setup the zero-shot classification head before starting the Fisher merging process.
"""
# this function is used to initialize some variables before running fisher merging
pass
run(modelpool)
Run the Fisher Merging Algorithm.
This method constructs the wrapped model and performs test-time adaptation if necessary.
Parameters:
- modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.
Returns:
- nn.Module: The merged model after test-time adaptation.
Source code in fusion_bench/method/fisher_merging/fisher_merging.py
def run(self, modelpool: BaseModelPool) -> nn.Module:
"""
Run the Fisher Merging Algorithm.
This method constructs the wrapped model and performs test-time adaptation if necessary.
Args:
modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.
Returns:
nn.Module: The merged model after test-time adaptation.
"""
log.info("Running Fisher Merging Algorithm")
if isinstance(modelpool, (dict, list, tuple)):
modelpool = BaseModelPool(modelpool)
assert len(modelpool) > 0, "model pool is empty"
assert (
modelpool.has_pretrained
), "no pretrained model (base model) in the model pool"
self.modelpool = modelpool
self.on_fisher_merging_start()
# dictionary of list, where key is the parameter name,
# value is a list of the corresponding parameters of all the models that need to be merged
models_to_merge_param_dict = defaultdict(list)
# list of dictionaries with length len(models_to_merge),
# each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
models_to_merge_fisher_weights_list = []
param_names_to_merge = None
for name, model in modelpool.named_models():
param_dict = model.state_dict()
if param_names_to_merge is None:
param_names_to_merge = get_param_names_to_merge(
input_param_names=list(param_dict.keys()),
exclude_param_names_regex=self.config.get(
"exclude_param_names_regex", []
),
)
for param_name in param_names_to_merge:
models_to_merge_param_dict[param_name].append(param_dict[param_name])
model_to_merge_fisher_weights = self.get_fisher_weights(
model_name=name,
model=model,
train_dataset=modelpool.load_train_dataset(name),
param_names_to_merge=param_names_to_merge,
)
models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)
merged_params = merging_with_fisher_weights(
models_to_merge_param_dict=models_to_merge_param_dict,
models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
)
merged_model = modelpool.load_model("_pretrained_")
merged_model.load_state_dict(merged_params, strict=False)
return merged_model