(Diagonal) Fisher Merging

The Fisher merging algorithm [1] is a per-parameter weighted averaging method that weights each model's parameters by their Fisher information, estimated on some labeled data. The Fisher information matrix \(F_\theta\) of a model with parameters \(\theta\) can be expressed as:

\[ F_\theta = \mathbb{E}_{x \sim p(x)} \, \mathbb{E}_{y \sim p(y|x, \theta)} \left[ \nabla_\theta \log p(y|x, \theta) \nabla_\theta \log p(y|x, \theta)^T \right] \]

where \(p(x)\) is the data distribution, \(p(y|x, \theta)\) is the model's output distribution (for example, the softmax output of a classification model), and \(\nabla_\theta\) denotes the gradient with respect to the model's parameters \(\theta\). The Fisher information measures how sensitive the model's output distribution is to each parameter, so it can serve as an estimate of per-parameter importance and thus determine the merging weights. In addition, the Fisher information matrix can be used to estimate the similarity between tasks, which is useful in auxiliary-task and multi-task learning scenarios [2].

For a model with \(d\) parameters, the full Fisher information matrix is a \(d \times d\) matrix, which is expensive to compute and memory-intensive to store. We therefore approximate it with the diagonal Fisher information matrix, i.e., the diagonal of the full matrix, which can be computed as:

\[ \hat{F}_\theta = \mathbb{E}_{x \sim p(x)} \, \mathbb{E}_{y \sim p(y|x, \theta)} \left[ \left( \nabla_\theta \log p(y|x, \theta) \right)^2 \right] \]

The square is applied element-wise, so \(\hat{F}_\theta\) has the same shape as \(\theta\).
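In practice, \(\hat{F}_\theta\) is estimated by averaging squared per-example gradients of the log-likelihood, with labels sampled from the model's own output distribution (matching the inner expectation over \(y\) above). The following PyTorch sketch illustrates this for a classifier; the helper name diagonal_fisher and the data-loader interface are assumptions for illustration, not FusionBench's API:

import torch
import torch.nn.functional as F
from torch import nn


def diagonal_fisher(model: nn.Module, data_loader, max_examples: int = 512):
    """Estimate the diagonal Fisher information of a classifier.

    Illustrative helper, not part of FusionBench. Labels are sampled from
    the model's own output distribution, as in the definition above.
    """
    model.eval()
    fisher = {
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    n_seen = 0
    for inputs, _ in data_loader:
        for x in inputs:
            # per-example gradient, so the square is taken per sample
            model.zero_grad()
            logits = model(x.unsqueeze(0))
            # sample y ~ p(y|x, theta) from the model's softmax output
            y = torch.multinomial(torch.softmax(logits, dim=-1), 1).squeeze(-1)
            # gradient of the log-likelihood of the sampled label
            F.cross_entropy(logits, y).backward()
            for name, param in model.named_parameters():
                if param.grad is not None:
                    fisher[name] += param.grad.detach() ** 2
            n_seen += 1
            if n_seen >= max_examples:
                return {name: f / n_seen for name, f in fisher.items()}
    return {name: f / max(n_seen, 1) for name, f in fisher.items()}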

Given \(n\) models with parameters \(\theta_i\) and diagonal Fisher information matrices \(\hat{F}_{\theta_i}\), the Fisher merging algorithm computes each parameter \(\theta^{(j)}\) of the merged model as:

\[ \theta^{(j)} = \frac{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)} \theta_i^{(j)}}{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)}} \]

where \(\theta_i\) are the parameters of the individual models, \(\hat{F}_{\theta_i}\) are their diagonal Fisher information matrices, and \(j\) indexes the parameters. Fisher merging is thus a per-parameter weighted average in which each weight is the Fisher information of that parameter in the corresponding model.
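The merging step itself is an element-wise weighted average. A minimal sketch of the formula above, assuming each model is given as a state_dict-style mapping; the small eps plays the same role as the minimal_fisher_weight option in the implementation further down:

import torch


def fisher_merge(params_list, fishers_list, eps=1e-6):
    """Element-wise Fisher-weighted average of model parameters.

    Illustrative sketch. `params_list` and `fishers_list` are per-model
    dicts keyed by parameter name; `eps` avoids division by zero where
    the Fisher information vanishes.
    """
    merged = {}
    for name in params_list[0]:
        weights = torch.stack([fisher[name] + eps for fisher in fishers_list])
        values = torch.stack([params[name] for params in params_list])
        merged[name] = (weights * values).sum(dim=0) / weights.sum(dim=0)
    return merged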

Code Integration

Example of merging eight CLIP-ViT-B/32 models using Fisher merging:

fusion_bench method=clip_fisher_merging \
  modelpool=clip-vit-base-patch32_TA8 \
  taskpool=clip-vit-classification_TA8

Merge eight CLIP-ViT-L/14 models using Fisher merging:

fusion_bench \
  method=clip_fisher_merging \
    method.batch_size=8 method.num_workers=4 \
  modelpool=clip-vit-large-patch14_TA8 \
  taskpool=clip-vit-classification_TA8 \
    taskpool.clip_model=openai/clip-vit-large-patch14

Merge GPT-2 models for text classification tasks:

fusion_bench \
  method=gpt2_fisher_merging \
    method.num_fisher_examples=512 method.batch_size=8 \
  modelpool=gpt-2_glue \
  taskpool=gpt-2_glue

References

FisherMergingAlgorithm

Bases: ModelFusionAlgorithm

Source code in fusion_bench/method/fisher_merging/fisher_merging.py
class FisherMergingAlgorithm(ModelFusionAlgorithm):
    def run(self, modelpool: ModelPool):
        log.info("Running Fisher Merging Algorithm")
        modelpool = to_modelpool(modelpool)
        assert modelpool._model_names, "model pool is empty"
        assert (
            "_pretrained_" in modelpool._model_names
        ), "no pretrained model (base model) in the model pool"

        self.modelpool = modelpool
        self.on_fisher_merging_start()

        # dictionary of list, where key is the parameter name,
        # value is a list of the corresponding parameters of all the models that need to be merged
        models_to_merge_param_dict = defaultdict(list)

        # list of dictionaries with length len(models_to_merge),
        # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
        models_to_merge_fisher_weights_list = []

        param_names_to_merge = None

        for name, model in modelpool.named_models():
            param_dict = model.state_dict()
            if param_names_to_merge is None:
                param_names_to_merge = get_param_names_to_merge(
                    input_param_names=list(param_dict.keys()),
                    exclude_param_names_regex=self.config.get(
                        "exclude_param_names_regex", []
                    ),
                )

            for param_name in param_names_to_merge:
                models_to_merge_param_dict[param_name].append(param_dict[param_name])

            model_to_merge_fisher_weights = self.get_fisher_weights(
                model_name=name,
                model=model,
                train_dataset=modelpool.get_train_dataset(name),
                param_names_to_merge=param_names_to_merge,
            )

            models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)

        merged_params = merging_with_fisher_weights(
            models_to_merge_param_dict=models_to_merge_param_dict,
            models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
            fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
            normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
            minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
        )

        merged_model = modelpool.load_model("_pretrained_")
        merged_model.load_state_dict(merged_params, strict=False)
        return merged_model

    def get_fisher_weights(
        self,
        model_name: str,
        model: nn.Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        # this function is used to compute fisher weights for a model
        # it should be implemented in the subclass
        raise NotImplementedError

    def on_fisher_merging_start(self):
        # this function is used to initialize some variables before running fisher merging
        pass
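
Here, get_fisher_weights is the extension point: the concrete CLIP and GPT-2 algorithms used in the CLI examples above implement it for their respective model types. A hypothetical subclass, reusing the diagonal_fisher sketch from earlier on this page; the config keys mirror the method.batch_size and method.num_fisher_examples options shown above, but this subclass is illustrative, not part of FusionBench:

from typing import Dict, List

from torch import Tensor, nn
from torch.utils.data import DataLoader

from fusion_bench.method.fisher_merging.fisher_merging import FisherMergingAlgorithm


class MyFisherMerging(FisherMergingAlgorithm):
    """Hypothetical subclass: plugs a diagonal Fisher estimate into the base class."""

    def get_fisher_weights(
        self,
        model_name: str,
        model: nn.Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        loader = DataLoader(
            train_dataset,
            batch_size=self.config.get("batch_size", 8),
            shuffle=True,
        )
        # reuse the illustrative estimator sketched earlier on this page
        fisher = diagonal_fisher(
            model, loader, max_examples=self.config.get("num_fisher_examples", 512)
        )
        # keep only the parameters selected for merging
        return {name: fisher[name] for name in param_names_to_merge}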

  [1] M. Matena and C. Raffel, "Merging Models with Fisher-Weighted Averaging". http://arxiv.org/abs/2111.09832

  [2] C. Wu et al., "Pi-Tuning: Transferring Multimodal Foundation Models with Optimal Multi-task Interpolation". https://github.com/TencentARC/pi-Tuning