(Diagonal) Fisher Merging

The Fisher merging algorithm [1] is a per-parameter weighted averaging method that assigns weights to the models based on the Fisher information of their parameters, estimated on some labeled data. The Fisher information matrix \(F_\theta\) of a model with parameters \(\theta\) can be expressed as:

\[ F_\theta = \mathbb{E}_{x \sim p(x)} \mathbb{E}_{y \sim p(y|x, \theta)} \left[ \nabla_\theta \log p(y|x, \theta) \nabla_\theta \log p(y|x, \theta)^T \right] \]

where \(p(x)\) is the data distribution, \(p(y|x, \theta)\) is the model's output distribution (for example, the softmax output of a classification model), and \(\nabla_\theta\) denotes the gradient with respect to the model's parameters \(\theta\). Note that the expectation over \(y\) is taken under the model's own predictive distribution rather than the ground-truth labels. The Fisher information measures how sensitive the model's predictions are to each parameter, so it can be used to estimate the importance of each parameter and thus to weight the models during merging. In addition, the Fisher information matrix can be used to estimate the similarity between tasks, which is useful in auxiliary-task and multi-task learning scenarios [2].

As the full Fisher information matrix is expensive to compute and memory-intensive to store (its size is quadratic in the number of parameters), we approximate it with the diagonal Fisher information matrix, i.e., the diagonal of the full matrix. The diagonal Fisher information matrix can be computed as:

\[ \hat{F}_\theta = \mathbb{E}_{x \sim p(x)} \mathbb{E}_{y \sim p(y|x, \theta)} \left[ \left(\nabla_\theta \log p(y|x, \theta)\right)^2 \right] \]
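
In practice the expectation is approximated with a finite number of examples by accumulating squared gradients of the log-likelihood (the square is applied element-wise). The following is a minimal sketch of this estimate for a generic classifier, not the FusionBench implementation; model, data_loader, and num_fisher_examples are placeholder names.

import torch
import torch.nn.functional as F

def estimate_diagonal_fisher(model, data_loader, num_fisher_examples):
    """Approximate the diagonal Fisher by averaging squared log-likelihood gradients."""
    fisher = {
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    num_seen = 0
    for inputs, _ in data_loader:  # assumes batch_size=1 so each gradient is per-example
        if num_seen >= num_fisher_examples:
            break
        logits = model(inputs)
        # Sample labels from the model's own predictive distribution p(y|x, theta).
        probs = torch.softmax(logits, dim=-1)
        labels = torch.multinomial(probs, num_samples=1).squeeze(-1)
        loss = F.cross_entropy(logits, labels)  # negative log-likelihood of the sampled label
        model.zero_grad()
        loss.backward()
        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher[name] += param.grad.detach() ** 2
        num_seen += inputs.size(0)
    # Average the accumulated squared gradients over the processed examples.
    return {name: value / max(num_seen, 1) for name, value in fisher.items()}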

Assuming we have \(n\) models with parameters \(\theta_i\) and diagonal Fisher information matrices \(\hat{F}_{\theta_i}\), the Fisher merging algorithm computes the merged model's parameters \(\theta\) as follows:

\[ \theta^{(j)} = \frac{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)} \theta_i^{(j)}}{\sum_{i=1}^{n} \hat{F}_{\theta_i}^{(j)}} \]

where \(\theta_i\) are the parameters of the individual models, \(\hat{F}_{\theta_i}\) are their diagonal Fisher information matrices, and \(j\) indexes the individual parameters. Fisher merging can thus be viewed as a per-parameter weighted average in which each model's contribution to a given parameter is proportional to that parameter's estimated Fisher information.
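
Written as code, the merging rule is an element-wise weighted average over parameter tensors. Below is a minimal sketch assuming each model's parameters and diagonal Fisher estimates are given as dictionaries keyed by parameter name; the small eps plays a role similar to the minimal_fisher_weight option documented in the API reference below.

import torch

def fisher_weighted_average(params_list, fisher_list, eps=1e-6):
    """Merge models parameter-wise, weighting each entry by its Fisher information."""
    merged = {}
    for name in params_list[0]:
        fishers = torch.stack([f[name] + eps for f in fisher_list])  # shape (n, *param_shape)
        thetas = torch.stack([p[name] for p in params_list])         # shape (n, *param_shape)
        merged[name] = (fishers * thetas).sum(dim=0) / fishers.sum(dim=0)
    return merged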

Code Integration

Example of merging eight CLIP-ViT-B/32 models using Fisher merging:

fusion_bench method=clip_fisher_merging \
  modelpool=clip-vit-base-patch32_TA8 \
  taskpool=clip-vit-classification_TA8

Merge eight CLIP-ViT-L/14 models using Fisher merging:

fusion_bench \
  method=clip_fisher_merging \
    method.batch_size=8 method.num_workers=4 \
  modelpool=clip-vit-large-patch14_TA8 \
  taskpool=clip-vit-classification_TA8 \
    taskpool.clip_model=openai/clip-vit-large-patch14

Merge GPT-2 models for text classification tasks:

fusion_bench \
  method=gpt2_fisher_merging \
    method.num_fisher_examples=512 method.batch_size=8 \
  modelpool=gpt-2_glue \
  taskpool=gpt-2_glue

References

FisherMergingAlgorithm

Bases: BaseAlgorithm

Implements the Fisher Merging Algorithm.

This class extends BaseAlgorithm to handle the merging of models using Fisher weights. It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.

Methods:

  • run(modelpool: BaseModelPool) -> nn.Module: Executes the Fisher merging process on the model pool and returns the merged model.

Source code in fusion_bench/method/fisher_merging/fisher_merging.py
class FisherMergingAlgorithm(BaseAlgorithm):
    """
    Implements the Fisher Merging Algorithm.

    This class extends the BaseModelFusionAlgorithm to handle merging of models using Fisher weights.
    It supports excluding certain parameters, normalizing Fisher weights, and setting a minimal value for Fisher weights.

    Methods:
        run(modelpool: BaseModelPool) -> nn.Module:
            Executes the Fisher merging process on the model pool and returns the merged model.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "exclude_param_names_regex": "exclude_param_names_regex",
        "normalize_fisher_weight": "normalize_fisher_weight",
        "minimal_fisher_weight": "minimal_fisher_weight",
        "num_fisher_examples": "num_fisher_examples",
    }

    def __init__(
        self,
        *,
        exclude_param_names_regex: list,
        normalize_fisher_weight: bool,
        minimal_fisher_weight: float,
        num_fisher_examples: int,
    ):
        super().__init__()
        self.exclude_param_names_regex = exclude_param_names_regex
        self.normalize_fisher_weight = normalize_fisher_weight
        self.minimal_fisher_weight = minimal_fisher_weight
        self.num_fisher_examples = num_fisher_examples

    def run(self, modelpool: BaseModelPool) -> nn.Module:
        """
        Run the Fisher Merging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            nn.Module: The merged model after test-time adaptation.
        """
        log.info("Running Fisher Merging Algorithm")
        if isinstance(modelpool, (dict, list, tuple)):
            modelpool = BaseModelPool(modelpool)

        assert len(modelpool) > 0, "model pool is empty"
        assert (
            modelpool.has_pretrained
        ), "no pretrained model (base model) in the model pool"

        self.modelpool = modelpool
        self.on_fisher_merging_start()

        # dictionary of list, where key is the parameter name,
        # value is a list of the corresponding parameters of all the models that need to be merged
        models_to_merge_param_dict = defaultdict(list)

        # list of dictionaries with length len(models_to_merge),
        # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
        models_to_merge_fisher_weights_list = []

        param_names_to_merge = None

        for name, model in modelpool.named_models():
            param_dict = model.state_dict()
            if param_names_to_merge is None:
                param_names_to_merge = get_param_names_to_merge(
                    input_param_names=list(param_dict.keys()),
                    exclude_param_names_regex=self.config.get(
                        "exclude_param_names_regex", []
                    ),
                )

            for param_name in param_names_to_merge:
                models_to_merge_param_dict[param_name].append(param_dict[param_name])

            model_to_merge_fisher_weights = self.get_fisher_weights(
                model_name=name,
                model=model,
                train_dataset=modelpool.load_train_dataset(name),
                param_names_to_merge=param_names_to_merge,
            )

            models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)

        merged_params = merging_with_fisher_weights(
            models_to_merge_param_dict=models_to_merge_param_dict,
            models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
            fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
            normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
            minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
        )

        merged_model = modelpool.load_model("_pretrained_")
        merged_model.load_state_dict(merged_params, strict=False)
        return merged_model

    def get_fisher_weights(
        self,
        model_name: str,
        model: nn.Module,
        train_dataset,
        param_names_to_merge: List[str],
    ) -> Dict[str, Tensor]:
        """
        Compute the Fisher weights for the given model and training dataset.

        Args:
            model_name (str): The name of the model.
            model (nn.Module): The model module.
            train_dataset: The training dataset.
            param_names_to_merge (List[str]): List of parameter names to merge.

        Returns:
            Dict[str, Tensor]: The computed Fisher weights for each parameter.
        """
        # this function is used to compute fisher weights for a model
        # it should be implemented in the subclass
        raise NotImplementedError

    def on_fisher_merging_start(self):
        """
        Setup the zero-shot classification head before starting the Fisher merging process.
        """
        # this function is used to initialize some variables before running fisher merging
        pass
get_fisher_weights(model_name, model, train_dataset, param_names_to_merge)

Compute the Fisher weights for the given model and training dataset.

Parameters:

  • model_name
    (str) –

    The name of the model.

  • model
    (Module) –

    The model module.

  • train_dataset

    The training dataset.

  • param_names_to_merge
    (List[str]) –

    List of parameter names to merge.

Returns:

  • Dict[str, Tensor]

    Dict[str, Tensor]: The computed Fisher weights for each parameter.

Source code in fusion_bench/method/fisher_merging/fisher_merging.py
def get_fisher_weights(
    self,
    model_name: str,
    model: nn.Module,
    train_dataset,
    param_names_to_merge: List[str],
) -> Dict[str, Tensor]:
    """
    Compute the Fisher weights for the given model and training dataset.

    Args:
        model_name (str): The name of the model.
        model (nn.Module): The model module.
        train_dataset: The training dataset.
        param_names_to_merge (List[str]): List of parameter names to merge.

    Returns:
        Dict[str, Tensor]: The computed Fisher weights for each parameter.
    """
    # this function is used to compute fisher weights for a model
    # it should be implemented in the subclass
    raise NotImplementedError
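
FusionBench ships task-specific subclasses (used by the clip_fisher_merging and gpt2_fisher_merging entries above) that provide this method. Purely as an illustration, a hypothetical subclass for a plain classifier could reuse the estimate_diagonal_fisher sketch from earlier on this page; the data-loading details below are assumptions, not the actual FusionBench code.

from torch.utils.data import DataLoader

class ClassifierFisherMerging(FisherMergingAlgorithm):
    """Hypothetical subclass: diagonal Fisher weights for a generic classification model."""

    def get_fisher_weights(self, model_name, model, train_dataset, param_names_to_merge):
        # Assumes `train_dataset` yields (inputs, label) pairs the model can consume.
        loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
        fisher = estimate_diagonal_fisher(model, loader, self.num_fisher_examples)
        # Keep only the parameters that were selected for merging.
        return {name: fisher[name] for name in param_names_to_merge if name in fisher}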
on_fisher_merging_start()

Setup the zero-shot classification head before starting the Fisher merging process.

Source code in fusion_bench/method/fisher_merging/fisher_merging.py
def on_fisher_merging_start(self):
    """
    Setup the zero-shot classification head before starting the Fisher merging process.
    """
    # this function is used to initialize some variables before running fisher merging
    pass
run(modelpool)

Run the Fisher Merging Algorithm.

This method constructs the wrapped model and performs test-time adaptation if necessary.

Parameters:

  • modelpool
    (BaseModelPool) –

    The model pool containing the pretrained and fine-tuned models.

Returns:

  • Module

    nn.Module: The merged model after test-time adaptation.

Source code in fusion_bench/method/fisher_merging/fisher_merging.py
def run(self, modelpool: BaseModelPool) -> nn.Module:
    """
    Run the Fisher Merging Algorithm.

    This method constructs the wrapped model and performs test-time adaptation if necessary.

    Args:
        modelpool (BaseModelPool): The model pool containing the pretrained and fine-tuned models.

    Returns:
        nn.Module: The merged model after test-time adaptation.
    """
    log.info("Running Fisher Merging Algorithm")
    if isinstance(modelpool, (dict, list, tuple)):
        modelpool = BaseModelPool(modelpool)

    assert len(modelpool) > 0, "model pool is empty"
    assert (
        modelpool.has_pretrained
    ), "no pretrained model (base model) in the model pool"

    self.modelpool = modelpool
    self.on_fisher_merging_start()

    # dictionary of list, where key is the parameter name,
    # value is a list of the corresponding parameters of all the models that need to be merged
    models_to_merge_param_dict = defaultdict(list)

    # list of dictionaries with length len(models_to_merge),
    # each dictionary records the fisher weights (matrix or vector) of parameters for each model that needs to be merged
    models_to_merge_fisher_weights_list = []

    param_names_to_merge = None

    for name, model in modelpool.named_models():
        param_dict = model.state_dict()
        if param_names_to_merge is None:
            param_names_to_merge = get_param_names_to_merge(
                input_param_names=list(param_dict.keys()),
                exclude_param_names_regex=self.config.get(
                    "exclude_param_names_regex", []
                ),
            )

        for param_name in param_names_to_merge:
            models_to_merge_param_dict[param_name].append(param_dict[param_name])

        model_to_merge_fisher_weights = self.get_fisher_weights(
            model_name=name,
            model=model,
            train_dataset=modelpool.load_train_dataset(name),
            param_names_to_merge=param_names_to_merge,
        )

        models_to_merge_fisher_weights_list.append(model_to_merge_fisher_weights)

    merged_params = merging_with_fisher_weights(
        models_to_merge_param_dict=models_to_merge_param_dict,
        models_to_merge_fisher_weights_list=models_to_merge_fisher_weights_list,
        fisher_scaling_coefficients=torch.ones(len(modelpool)) / len(modelpool),
        normalize_fisher_weight=self.config.get("normalize_fisher_weight", True),
        minimal_fisher_weight=self.config.get("minimal_fisher_weight", 1e-6),
    )

    merged_model = modelpool.load_model("_pretrained_")
    merged_model.load_state_dict(merged_params, strict=False)
    return merged_model
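
For programmatic (non-CLI) use, a concrete subclass such as the hypothetical ClassifierFisherMerging sketched above can be constructed with the four options from __init__ and run on an existing model pool. This is a usage sketch only; modelpool is assumed to be an already-constructed BaseModelPool containing a "_pretrained_" model, the fine-tuned models, and their training datasets.

algorithm = ClassifierFisherMerging(
    exclude_param_names_regex=[],  # merge all parameters
    normalize_fisher_weight=True,
    minimal_fisher_weight=1e-6,
    num_fisher_examples=512,
)
merged_model = algorithm.run(modelpool)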

  1. M. Matena and C. Raffel. "Merging Models with Fisher-Weighted Averaging". http://arxiv.org/abs/2111.09832

  2. C. Wu et al. "Pi-Tuning: Transferring Multimodal Foundation Models with Optimal Multi-task Interpolation". https://github.com/TencentARC/pi-Tuning