
RegMean++


Revisiting the RegMean Algorithm

Regression Mean (RegMean) [1], an approach that formulates model merging as a linear regression problem, aims to find the optimal weights for each linear layer of the merged model by minimizing the discrepancy between the predictions of the merged model and those of the candidate models. At a transformer layer \(l\), RegMean provides an exact closed-form solution for the merged weights \(W^{(l)}_{M}\) of a linear layer, given the corresponding weights \(W^{(l)}_i\) and input features \(X^{(l)}_i\) of the \(K\) candidate models:

\[W^{(l)}_{M} = \left[\sum_{i=1}^{K} (X^{(l)}_i)^{\top} X^{(l)}_i\right]^{-1} \sum_{i=1}^{K} (X^{(l)}_i)^{\top} X^{(l)}_i W^{(l)}_i.\]
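
To make the closed form concrete, here is a minimal PyTorch sketch for a single linear layer (illustrative code, not the FusionBench implementation; `regmean_merge` is a hypothetical helper, and weights are assumed to be stored as \([d_{in}, d_{out}]\) matrices so that the layer computes \(XW\)):

import torch

def regmean_merge(grams: list, weights: list) -> torch.Tensor:
    """Closed-form RegMean merge of one linear layer.

    grams:   per-model Gram matrices G_i = X_i^T X_i, each [d_in, d_in]
    weights: per-model weight matrices W_i, each [d_in, d_out]
    """
    sum_g = sum(grams)                                   # sum_i X_i^T X_i
    sum_gw = sum(g @ w for g, w in zip(grams, weights))  # sum_i X_i^T X_i W_i
    # solve sum_g @ W_M = sum_gw rather than forming an explicit inverse
    return torch.linalg.solve(sum_g, sum_gw)

Because the Gram matrices can be ill-conditioned, the implementation below additionally scales their non-diagonal entries by `reduce_non_diagonal_ratio` before solving.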

The Problem with RegMean and How RegMean++ Addresses It

RegMean merges each linear layer independently, overlooking how features from earlier layers propagate through the network and shape the final prediction of the merged model. RegMean++ [2] addresses this by explicitly incorporating both intra- and cross-layer dependencies between the merged model's layers into RegMean's objective.

Figure: Comparison between RegMean and RegMean++ for model merging. RegMean++ leverages representations from the merged model when merging each layer, enabling accurate alignment with the merged model's behavior.

The key difference between RegMean++ and RegMean lies in how the input features \(X^{(l)}_i\) of the linear layers at transformer layer \(l\) are obtained. For input features that are activations (the intermediate representations passed between transformer layers), RegMean++ computes \(X^{(l)}_i\) from the activations produced by the previous layer \(f_{M}^{(l-1)}\) of the merged model, that is, \(X^{(l)}_i = f_{M}^{(l-1)}(X^{(l-1)}_{i})\), whereas RegMean relies on the activations produced by the previous layer \(f_{i}^{(l-1)}\) of candidate model \(i\), that is, \(X^{(l)}_i = f_{i}^{(l-1)}(X^{(l-1)}_{i})\).
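
To see the whole procedure end to end, the toy sketch below merges a stack of two linear layers layer by layer (illustrative only: `make_model`, the dimensions, and the loop structure are hypothetical, and the `regmean_merge` helper from the sketch above is reused):

import torch
import torch.nn as nn

torch.manual_seed(0)

# toy setup: two candidate "models", each a stack of two bias-free linear layers
def make_model() -> nn.Sequential:
    return nn.Sequential(nn.Linear(8, 8, bias=False), nn.Linear(8, 8, bias=False))

models = [make_model(), make_model()]
merged = make_model()
xs = [torch.randn(32, 8) for _ in models]  # per-model probe inputs

with torch.no_grad():
    for l in range(len(merged)):
        grams, weights = [], []
        for model, x in zip(models, xs):
            grams.append(x.T @ x)              # G_i = X_i^T X_i
            weights.append(model[l].weight.T)  # nn.Linear stores weights as [d_out, d_in]
        merged[l].weight.copy_(regmean_merge(grams, weights).T)
        # RegMean++: inputs for the next layer are produced by the freshly
        # *merged* layer; plain RegMean would instead use each candidate's own
        # layer, i.e. xs = [model[l](x) for model, x in zip(models, xs)]
        xs = [merged[l](x) for x in xs]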

Code Integration

The following commands run and evaluate the RegMean++ algorithm on eight image classification tasks:

  • For CLIP-ViT-B/32 models:

    fusion_bench \
        method=regmean_plusplus/clip_regmean_plusplus \
        modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
        taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 \
        taskpool.base_model=openai/clip-vit-base-patch32
    

  • For CLIP-ViT-B/16 models:

    fusion_bench \
        method=regmean_plusplus/clip_regmean_plusplus \
        modelpool=CLIPVisionModelPool/clip-vit-base-patch16_TA8 \
        taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 \
        taskpool.base_model=openai/clip-vit-base-patch16
    

  • For CLIP-ViT-L/14 models:

    fusion_bench \
        method=regmean_plusplus/clip_regmean_plusplus \
        modelpool=CLIPVisionModelPool/clip-vit-large-patch14_TA8 \
        taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 \
        taskpool.base_model=openai/clip-vit-large-patch14
    

Code Implementation

RegMeanAlgorithmPlusPlus

Bases: BaseAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/regmean_plusplus/regmean_plusplus.py
class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
    _include_module_type = [nn.Linear]
    _config_mapping = {
        "num_regmean_examples": "num_regmean_examples",
        "exclude_param_names_regex": "exclude_param_names_regex",
        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
        "weight_transpose": "weight_transpose",
    }

    def __init__(
        self,
        *,
        num_regmean_examples: int,
        exclude_param_names_regex: list,
        reduce_non_diagonal_ratio: float,
        weight_transpose: bool,
        **kwargs,
    ):
        self.num_regmean_examples = num_regmean_examples
        self.exclude_param_names_regex = exclude_param_names_regex
        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
        self.weight_transpose = weight_transpose
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool, **kwargs):
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)
        self.modelpool = modelpool
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        models_to_merge_dict = {name: model.to(device) for name, model in modelpool.named_models()}
        self.on_regmean_start()

        # initialize the merged models as the pretrained model
        merged_model = modelpool.load_pretrained_model().to(device)
        merged_params_dict = {}

        # 1. merge embedding layer
        merged_embedding_dict = self.merge_embedding_layer(models_to_merge_dict=models_to_merge_dict)
        merged_model.load_state_dict(merged_embedding_dict, strict=False)

        with torch.no_grad():
            # 1.1. compute input for the first layer
            with (
                self.profile("merging models"),
                self.profile("computing first layer input"),
            ):
                batches_input_dict = defaultdict(list)
                for name in tqdm(models_to_merge_dict.keys(), desc="computing input for first layer"):
                    dataset = modelpool.load_train_dataset(name)

                    batches_input_dict[name] = self.get_input_for_first_layer(
                        merged_model,
                        dataset
                    )

            # 2. iteratively merge layer by layer with regmean algorithm
            backbone_layers = self.get_layers(merged_model)
            num_layers = len(backbone_layers)

            models_to_merge_layers_dict = defaultdict(list)
            for name, model in models_to_merge_dict.items():
                models_to_merge_layers_dict[name] = self.get_layers(model)

            param_names_to_merge = None
            for layer_idx, backbone_layer in tqdm(enumerate(backbone_layers), 
                                                  desc="merging layers", 
                                                  total=num_layers):
                # dictionary of list, where key is the parameter name,
                # value is a list of the corresponding parameters of all the models that need to be merged
                models_to_merge_param_dict = defaultdict(list)

                # list of dictionaries with length len(models_to_merge),
                # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
                models_to_merge_regmean_weights_list = []

                for name, layers_to_merge in models_to_merge_layers_dict.items():
                    layer_to_merge = layers_to_merge[layer_idx]
                    param_dict = layer_to_merge.state_dict()

                    # exclude parameter whose name matches element in exclude_param_names_regex
                    if param_names_to_merge is None:
                        param_names_to_merge = get_param_names_to_merge(
                            input_param_names=list(param_dict.keys()),
                            exclude_param_names_regex=self.config.get(
                                "exclude_param_names_regex", []
                            ),
                        )

                    for param_name in param_names_to_merge:
                        models_to_merge_param_dict[param_name].append(
                            param_dict[param_name]
                        )

                    linear_modules_to_merge = get_modules_to_merge(
                        model=layer_to_merge, include_module_types=self._include_module_type
                    )
                    assert len(linear_modules_to_merge) > 0, "No linear modules to merge"

                    # 2.1. compute regmean weights for each model
                    with (
                        self.profile("merging models"),
                        self.profile("computing regmean weights"),
                    ):
                        regmean_weights = self.get_regmean_weights(
                            name,
                            layer_to_merge,
                            batches_input=batches_input_dict[name],
                            linear_modules_to_merge=linear_modules_to_merge,
                        )

                        module_subset = get_param_names_to_merge(
                            input_param_names=list(param_dict.keys()),
                            exclude_param_names_regex=self.exclude_param_names_regex
                        )
                        module_subset = [name.replace(".weight", "").replace(".bias", "") for name in module_subset]
                        module_subset = list(set(module_subset))
                        regmean_weights = {module_name: regmean_weights[module_name] for module_name in module_subset if module_name in regmean_weights}

                        models_to_merge_regmean_weights_list.append(regmean_weights)

                # 2.2. merge parameters with regmean weights
                with self.profile("merging models"):
                    # merging with regmean weights
                    merged_layer_params = merging_with_regmean_weights(
                        models_to_merge_param_dict=models_to_merge_param_dict,
                        models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
                        reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
                        weight_transpose=self.config.get("weight_transpose", True),
                    )

                    merged_params_dict = self.update_merged_params_dict(
                        merged_params_dict=merged_params_dict,
                        new_merged_params=merged_layer_params,
                        layer_idx=layer_idx,
                    )

                # 2.3. compute input for the next layer
                with (
                    self.profile("merging models"),
                    self.profile("forwarding next layer"),
                ):
                    if layer_idx < num_layers - 1:
                        backbone_layer.load_state_dict(merged_layer_params, strict=False)
                        batches_output_dict = defaultdict(list)
                        for name in models_to_merge_dict.keys():
                            batches_output_dict[name] = self.layer_batches_forward(
                                backbone_layer, 
                                batches_input_dict[name]
                            )
                        batches_input_dict = batches_output_dict

            # 3. load state dict to the merged model
            merged_model.load_state_dict(merged_params_dict, strict=False)

        self.print_profile_summary()
        return merged_model

    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
        """
        Merge the embedding layer of the model with the merged model.
        This method should be implemented in subclasses if needed.
        """
        raise NotImplementedError()

    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
        raise NotImplementedError

    def get_layers(self, model: nn.Module):
        raise NotImplementedError

    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
        raise NotImplementedError

    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
        raise NotImplementedError

    def on_regmean_start(self):
        pass

    def get_regmean_weights(
        self,
        model_name: str,
        layer: nn.Module,
        batches_input: List[Tensor],
        linear_modules_to_merge: Dict[str, nn.Module],
    ):
        raise NotImplementedError
merge_embedding_layer(models_to_merge_dict)

Merge the embedding layer of the model with the merged model. This method should be implemented in subclasses if needed.

Source code in fusion_bench/method/regmean_plusplus/regmean_plusplus.py
def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
    """
    Merge the embedding layer of the model with the merged model.
    This method should be implemented in subclasses if needed.
    """
    raise NotImplementedError()

RegMeanAlgorithmForCLIPPlusPlus

Bases: RegMeanAlgorithmPlusPlus, CLIPClassificationMixin

Source code in fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py
class RegMeanAlgorithmForCLIPPlusPlus(
    RegMeanAlgorithmPlusPlus,
    CLIPClassificationMixin,
):
    _config_mapping = {
        "_dataloader_kwargs": "dataloader_kwargs",
    }

    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
        super().__init__(**kwargs)
        self._dataloader_kwargs = dataloader_kwargs

    def on_regmean_start(self):
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

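        # the vision model returns (last_hidden_state, pooled_output, ...); take the pooled output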
        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def get_regmean_weights(
        self,
        model_name: str,
        layer: Module,
        batches_input: List[Tensor],
        linear_modules_to_merge: Dict[str, Module],
    ):
        layer = self.fabric.setup(layer)

        def compute_regmean_weights(module_name: str):
            """
            compute the regmean weights, a hook function to deal with each module's input
            :param module_name: str, module name
            :return:
            """

            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
                # Tensor, shape (batch_size, sequence_length, hidden_dim)
                x = cast(Tensor, input[0]).detach()
                batch_num_actual_examples = x.shape[0]
                # Tensor, shape (batch_size * sequence_length, hidden_dim)
                x = x.reshape(-1, x.shape[-1])
                # Tensor, shape (hidden_dim, hidden_dim)
                xtx = torch.matmul(x.transpose(0, 1), x)
                # store the averaged weights in regmean_weights
                if module_name not in regmean_weights.keys():
                    regmean_weights[module_name] = xtx / x.shape[0]
                    num_computed_examples[module_name] = x.shape[0]
                    num_actual_examples[module_name] = batch_num_actual_examples
                else:
                    regmean_weights[module_name] = (
                        regmean_weights[module_name]
                        * num_computed_examples[module_name]
                        + xtx
                    ) / (num_computed_examples[module_name] + x.shape[0])
                    num_computed_examples[module_name] += x.shape[0]
                    num_actual_examples[module_name] += batch_num_actual_examples

            return hook

        handles = []
        # dictionary, regmean matrices for each linear module inputs
        regmean_weights = {}
        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
        num_computed_examples = {}
        # dictionary, number of actual examples used for computing regmean matrices
        num_actual_examples = {}

        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
            # register a hook in the forward process
            handle = linear_module_to_merge.register_forward_hook(
                compute_regmean_weights(module_name=module_name)
            )
            handles.append(handle)
        _ = self.layer_batches_forward(layer, batches_input)

        # remove the added hook
        for handle in handles:
            handle.remove()

        for module_name in regmean_weights.keys():
            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()

        return regmean_weights

    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
        models_to_merge_param_dict = defaultdict(list)

        # get the parameters of the embedding layer from each model
        for model_to_merge in models_to_merge_dict.values():
            model_to_merge_state_dict = model_to_merge.state_dict()

            param_dict = {}
            for name, param in model_to_merge_state_dict.items():
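                # "pre_layrnorm" (sic) is the attribute name defined by HF's
                # CLIPVisionTransformer, so it must be matched verbatim here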
                if name.startswith("vision_model.embeddings") or name.startswith("vision_model.pre_layrnorm"):
                    param_dict[name] = param

            for param_name in param_dict.keys():
                models_to_merge_param_dict[param_name].append(
                    param_dict[param_name]
                )

        # merge the parameters of the embedding layer
        merged_params_dict = {}
        for param_name, param_list in models_to_merge_param_dict.items():
            merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)

        return merged_params_dict


    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
        # setup dataloader
        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
        train_dataloader = DataLoader(
            train_dataset, shuffle=True, **self._dataloader_kwargs
        )
        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
        model = self.fabric.setup(model)

        def compute_input(model, batch):
            images, _ = batch

            images = images.to(model.device)
            image_embeds = model.vision_model.embeddings(images)
            image_embeds = model.vision_model.pre_layrnorm(image_embeds)
            image_embeds = image_embeds.detach().cpu()

            return image_embeds

        num_computed_examples = 0
        num_regmean_examples = self.num_regmean_examples

        batches_input = []
        for batch in train_dataloader:
            if num_computed_examples >= num_regmean_examples:
                break
            batches_input.append(compute_input(model, batch))
            num_computed_examples += batch[0].size(0)

        return batches_input

    def get_layers(self, model: nn.Module):
        return model.vision_model.encoder.layers

    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
        for key, value in new_merged_params.items():
            key = f"vision_model.encoder.layers.{layer_idx}.{key}"
            merged_params_dict[key] = value

        return merged_params_dict

    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]) -> Tensor:
        batches_output = []
        for batch in batches_input:
            device = next(layer.parameters()).device
            batch = batch.to(device)
            logits = layer(batch, attention_mask=None, causal_attention_mask=None)[0].detach().cpu()
            batches_output.append(logits)
        return batches_output

References


  1. Xisen Jin, Xiang Ren, Daniel Preotiuc-Pietro, and Pengxiang Cheng. "Dataless Knowledge Fusion by Merging Weights of Language Models." In The Eleventh International Conference on Learning Representations (ICLR), 2023.

  2. The-Hai Nguyen, Huu-Tien Dang, Takeshi Suzuki, and Le-Minh Nguyen. "RegMean++: Enhancing Effectiveness and Generalization of Regression Mean for Model Merging." arXiv preprint arXiv:2508.03121 (2025).