
AdaMerging

Figure: Task Vector, Task Arithmetic, and AdaMerging. Credit to [1].

AdaMerging is a method for adaptively merging model parameters to optimize performance across multiple tasks. Unlike traditional fixed-coefficient approaches such as Task Arithmetic, it learns the merging coefficients autonomously rather than relying on manually chosen values [1].

The cornerstone of AdaMerging is its adaptive nature: the merging coefficients are learned either task-wise or layer-wise. This adaptation is driven by entropy minimization on unlabeled test samples, which acts as a surrogate objective for refining the merging coefficients.

Task-wise AdaMerging is formulated as:

\[ \theta = \theta_0 + \sum_{i=1}^{n} \lambda_i \tau_i \]

where \(\lambda_i\) represents the merging coefficient for the \(i\)-th task, and \(\tau_i\) denotes the task vector for the \(i\)-th task.
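
To make this concrete, here is a minimal sketch of task-wise merging applied to plain PyTorch state dicts. It is illustrative only, not the FusionBench implementation shown later in this page; pretrained_sd, finetuned_sds, and lambdas are hypothetical inputs.

def task_wise_merge(pretrained_sd: dict, finetuned_sds: list, lambdas: list) -> dict:
    """Sketch of task-wise merging: theta = theta_0 + sum_i lambda_i * tau_i.

    pretrained_sd: state dict of the pre-trained model (theta_0)
    finetuned_sds: state dicts of the fine-tuned models (theta_i)
    lambdas:       one merging coefficient per task (lambda_i)
    """
    merged = {}
    for name, theta_0 in pretrained_sd.items():
        # task vector tau_i = theta_i - theta_0, scaled by lambda_i and summed over tasks
        merged[name] = theta_0 + sum(
            lam * (sd[name] - theta_0) for lam, sd in zip(lambdas, finetuned_sds)
        )
    return merged

# e.g. merged = task_wise_merge(base.state_dict(), [m.state_dict() for m in models], [0.3] * len(models))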

On the other hand, Layer-wise AdaMerging is articulated as:

\[ \theta^l = \theta_0^l + \sum_{i=1}^{n} \lambda^{l}_{i} \tau^{l}_{i} \]

where the merging coefficient \(\lambda^{l}_{i}\) and task vector \(\tau^{l}_{i}\) are specific to each layer \(l\) of the model.
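
Layer-wise AdaMerging only changes how the coefficients are indexed: each task vector gets a separate coefficient for every layer (in practice, every trainable parameter tensor) instead of a single scalar. A minimal sketch, with the same hypothetical inputs as above:

def layer_wise_merge(pretrained_sd: dict, finetuned_sds: list, lambdas) -> dict:
    """Sketch of layer-wise merging: theta^l = theta_0^l + sum_i lambda_i^l * tau_i^l.

    lambdas[i][l] is the coefficient of the i-th task vector at the l-th parameter tensor.
    """
    merged = {}
    for l, (name, theta_0) in enumerate(pretrained_sd.items()):
        merged[name] = theta_0 + sum(
            lambdas[i][l] * (sd[name] - theta_0) for i, sd in enumerate(finetuned_sds)
        )
    return merged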

By learning the coefficients adaptively, AdaMerging generalizes better across tasks than a single fixed coefficient, and the layer-wise variant adds a further degree of freedom per layer. Because the surrogate objective is the entropy of the merged model's predictions on unlabeled test data, the coefficients are tuned to the task mixture actually encountered at test time, without requiring any labels.
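
The following self-contained toy sketch shows how this surrogate objective drives the adaptation: the merging coefficients are a trainable tensor, the merged parameters are recomputed at every step, and the entropy of the merged model's predictions on unlabeled inputs is minimized with Adam. The two-layer linear model and random data are purely illustrative; the actual FusionBench implementation is shown under Code Integration below.

import torch

torch.manual_seed(0)

# Toy setup: a two-layer linear model, one pretrained copy and two "fine-tuned" task vectors.
theta_0 = [torch.randn(16, 32), torch.randn(10, 16)]                              # pretrained parameters
task_vectors = [[0.1 * torch.randn_like(p) for p in theta_0] for _ in range(2)]   # tau_i per layer

# Layer-wise coefficients, initialized to a small constant such as 0.3.
lambdas = torch.full((2, len(theta_0)), 0.3, requires_grad=True)
optimizer = torch.optim.Adam([lambdas], lr=1e-3)

def merged_forward(x: torch.Tensor) -> torch.Tensor:
    # Merge parameters on the fly so gradients flow back into lambdas.
    merged = [
        theta_0[l] + sum(lambdas[i, l] * task_vectors[i][l] for i in range(2))
        for l in range(len(theta_0))
    ]
    return torch.relu(x @ merged[0].t()) @ merged[1].t()

for _ in range(100):                       # iterate over unlabeled "test" batches
    x = torch.randn(64, 32)                # no labels are used anywhere
    probs = torch.softmax(merged_forward(x), dim=-1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()
    optimizer.zero_grad()
    entropy.backward()
    optimizer.step()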

AdaMerging Analysis

Task-wise Coefficients. The figure below shows how the merging coefficient of each task vector changes during optimization in Task-wise AdaMerging and AdaMerging++ (plotted every ten steps). The learned coefficients consistently differ across task vectors, so a single shared coefficient is not ideal. When the number of tasks is large, grid-searching a separate coefficient for every task becomes impractical; AdaMerging avoids this manual search entirely.

Figure: Model merging coefficients \(\{\lambda_k\}_{k=1}^{K}\) change with respect to training steps on ViT-B/32:
(a) Task-wise AdaMerging; (b) Task-wise AdaMerging++. Each line shows the evolution of the coefficient \(\lambda_k\) of a task vector \(\tau_k\ (k \in \{1, 2, \dots, K\})\).

Layer-wise Coefficients. The figure below shows the merging coefficients learned by Layer-wise AdaMerging and AdaMerging++ on ViT-B/32. We observe that:

  1. The learned coefficients differ across layers and across task vectors, which shows that the layers are not equally important during merging.
  2. The coefficients learned for shallow layers are generally smaller than those for deep layers. This indicates that shallow layers rely more on the weights of the pre-trained model, while deep layers rely more on the weights provided by the task vectors. A likely explanation is that shallow layers learn general, cross-task features, whereas deep layers learn task-specific features [2]. This finding is also consistent with the routing analysis in [3].

Figure: Learned model merging coefficients \(\{\lambda_k^l\}_{k=1,l=1}^{K,L}\) of Layer-wise AdaMerging (top) and AdaMerging++ (bottom) on ViT-B/32. The \(k\)-th row corresponds to the \(k\)-th task vector, the \(l\)-th column to the \(l\)-th layer, and each cell shows the coefficient \(\lambda_k^l\).

Code Integration

Merge CLIP-ViT-B/32 models from eight downstream image classification tasks:

fusion_bench \
    method=adamerging \
        method.name=clip_layer_wise_adamerging \
        method.save_merging_weights=merging_weights.pt \
    modelpool=clip-vit-base-patch32_TA8 \
    taskpool=clip-vit-classification_TA8 \
    fabric_logger.root_dir=outputs/logs/ViT-B-32 \
    fabric_logger.name=clip_layer_wise_adamerging_adam

Part of the output:

Profiler Report

----------------------------------------------------------------------------------------------------------------------------------
|  Action                       |  Mean duration (s)    |  Num calls            |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------
|  Total                        |  -                    |  26001                |  724.65               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------
|  backward pass                |  0.060172             |  8000                 |  481.38               |  66.429               |
|  forward pass                 |  0.016124             |  8000                 |  128.99               |  17.801               |
|  data loading                 |  0.0063443            |  8000                 |  50.754               |  7.004                |
|  merging weights              |  0.050735             |  1000                 |  50.735               |  7.0013               |
|  construct the wrapped model  |  7.2558               |  1                    |  7.2558               |  1.0013               |
|  optimizer step               |  0.00098186           |  1000                 |  0.98186              |  0.13549              |
----------------------------------------------------------------------------------------------------------------------------------
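
The learned coefficients are saved because method.save_merging_weights=merging_weights.pt was set. As the source code below shows, the algorithm skips test-time adaptation when pre-learned weights are supplied via its weights option, so a later run can presumably reuse the saved file with a Hydra override along these lines (the exact location of the saved file depends on your logger configuration):

fusion_bench \
    method=adamerging \
        method.name=clip_layer_wise_adamerging \
        method.weights=/path/to/merging_weights.pt \
    modelpool=clip-vit-base-patch32_TA8 \
    taskpool=clip-vit-classification_TA8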

Reference

Task-Wise AdaMerging

task_wise_adamerging

entropy_loss(logits)

Compute the entropy loss of a set of logits.

Parameters:

  • logits (Tensor) –

    The logits to compute the entropy loss of.

Returns:

  • Tensor –

    The entropy loss of the logits.

Source code in fusion_bench/method/adamerging/task_wise_adamerging.py
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
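
As a quick sanity check (not part of the library documentation): uniform predictions give the maximum entropy log C, while confident predictions give near-zero entropy, which is what test-time adaptation drives the merged model toward. The import path below is assumed from the source file location above.

import torch
from fusion_bench.method.adamerging.task_wise_adamerging import entropy_loss

uniform = torch.zeros(4, 10)           # uniform over 10 classes -> entropy = log(10) ~ 2.30
confident = torch.full((4, 10), -10.0)
confident[:, 0] = 10.0                 # almost one-hot -> entropy close to 0

print(entropy_loss(uniform))           # tensor(2.3026)
print(entropy_loss(confident))         # tensor close to 0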

clip_task_wise_adamerging

CLIPTaskWiseAdaMergingAlgorithm

Bases: TaskWiseAdaMergingAlgorithm

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    modelpool: HuggingFaceClipVisionPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    def get_task_config(self, task):
        for task_config in self.modelpool.config.tta_datasets:
            if task_config.name == task:
                return task_config
        raise ValueError(f"Task {task} not found in config")

    def prepare_dataset_config(self, dataset_config: DictConfig):
        if not hasattr(dataset_config, "type"):
            with open_dict(dataset_config):
                dataset_config["type"] = self.modelpool.config.dataset_type
        return dataset_config

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.
        """
        dataset_config = self.get_task_config(task)["dataset"]
        dataset_config = self.prepare_dataset_config(dataset_config)
        log.info(f"Loading test dataset: {dataset_config.name}")
        dataset = load_dataset_from_config(dataset_config)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
            clip_model = CLIPModel.from_pretrained(clip_model_config.path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale = self._fabric.to_device(self.logit_scale)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(
                    self.get_task_config(task)["dataset"].name
                )
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
get_test_dataset(task) cached

Load the test dataset for the task. This method is cached, so the dataset is loaded only once.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
@functools.cache
def get_test_dataset(self, task: str):
    """
    Load the test dataset for the task.
    This method is cached, so the dataset is loaded only once.
    """
    dataset_config = self.get_task_config(task)["dataset"]
    dataset_config = self.prepare_dataset_config(dataset_config)
    log.info(f"Loading test dataset: {dataset_config.name}")
    dataset = load_dataset_from_config(dataset_config)
    dataset = CLIPDataset(dataset, self._clip_processor)
    return dataset
on_test_time_adaptation_start()

Here we load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Here we load the CLIP processor and construct the zero-shot classification head for each task.
    """
    clip_model_config = self.modelpool.get_model_config("_pretrained_")

    with timeit_context("Loading CLIP processor and pretrained CLIP model."):
        self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
        clip_model = CLIPModel.from_pretrained(clip_model_config.path)

        clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
        self.visual_projection = clip_model.visual_projection.requires_grad_(False)
        self.logit_scale = clip_model.logit_scale.exp()
        if self._fabric is not None:
            self.visual_projection = self._fabric.to_device(self.visual_projection)
            self.logit_scale = self._fabric.to_device(self.logit_scale)

    for task in self.modelpool.model_names:
        cache_file = os.path.join(
            self.config.cache_dir,
            f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
        )
        if os.path.exists(cache_file):
            log.info(f"Loading cached zeroshot weights for task: {task}")
            zeroshot_weights = torch.load(cache_file, map_location="cpu")
        else:
            log.info(f"Construct zero shot classification head for task: {task}")
            classnames, templates = get_classnames_and_templates(
                self.get_task_config(task)["dataset"].name
            )
            clip_classifier.set_classification_task(classnames, templates)
            zeroshot_weights = clip_classifier.zeroshot_weights
            log.info(f"save zeroshot weights to {cache_file}")
            torch.save(zeroshot_weights, cache_file)
        self.zeroshot_weights[task] = zeroshot_weights
        if self._fabric is not None:
            self.zeroshot_weights[task] = self._fabric.to_device(
                self.zeroshot_weights[task]
            )

Layer-Wise AdaMerging

layer_wise_adamerging

LayerWiseAdaMergingAlgorithm

Bases: ModelFusionAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: ModelPool):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool):
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                self.save_merging_weights(
                    self.config.save_merging_weights, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. labels are not needed.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel):
        self.on_test_time_adaptation_start()
        config = self.config

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        self.print_profile_summary()
        return module
construct_layer_wise_merged_model(modelpool)

Constructs a wrapped layer-wise merged model from model pool.

This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models. The merging is controlled by layer-wise weights, which is a torch.Tensor of the shape (num_models, num_layers). The merging weights can be initialized based on a provided configuration or loaded from a file.

Parameters:

  • modelpool (ModelPool) –

    An object containing the pretrained model and fine-tuned models to be merged.

Returns:

  • LayerWiseMergedModel

    An instance of the merged model with layer-wise weights applied.

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: ModelPool):
    """
    Constructs a wrapped layer-wise merged model from model pool.

    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
    The merging weights can be initialized based on a provided configuration or loaded from a file.

    Args:
        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

    Returns:
        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
    """
    pretrained_model = modelpool.load_model("_pretrained_")
    finetuned_models = [
        modelpool.load_model(name) for name in modelpool.model_names
    ]

    # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
    if self.config.weights is None:
        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(
                    filter(lambda p: p.requires_grad, pretrained_model.parameters())
                )
            ),
            init_values=self.config.init_values,
        )
    else:
        if isinstance(self.config.weights, str):
            # self.config.weights is a path to a saved tensor
            layer_wise_weight = load_tensor_from_file(self.config.weights)
        else:
            raise ValueError(f"Unsupported weights format: {self.config.weights}")

    module = LayerWiseMergedModel(
        layer_wise_weight=layer_wise_weight,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=self.config.clamp_weights,
        tie_weights=self.config.tie_weights,
        strict=self.config.strict,
    )
    print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
    return module
get_shuffled_test_loader_iter(task) abstractmethod

Loader of test dataset for test-time adaptation. labels are not needed.

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
@abstractmethod
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Loader of test dataset for test-time adaptation. labels are not needed.
    """
    pass
on_test_time_adaptation_start()

Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
    """
    pass

clip_layer_wise_adamerging

Example Usage:

fusion_bench \
    method=adamerging \
        method.name=clip_layer_wise_adamerging \
        method.save_merging_weights=merging_weights.pt \
    modelpool=clip-vit-base-patch32_TA8 \
    taskpool=clip-vit-classification_TA8 \
    fabric_logger.root_dir=outputs/logs/ViT-B-32 \
    fabric_logger.name=clip_layer_wise_adamerging_adam
CLIPLayerWiseAdaMergingAlgorithm

Bases: CLIPClassificationMixin, LayerWiseAdaMergingAlgorithm

Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()
on_test_time_adaptation_start()

Here we load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
def on_test_time_adaptation_start(self):
    """
    Here we load the CLIP processor and construct the zero-shot classification head for each task.
    """
    self.setup_zero_shot_classification_head()

  1. AdaMerging: Adaptive Model Merging for Multi-Task Learning. ICLR 2024. https://openreview.net/pdf?id=nZP6NgD3QY

  2. J. Yosinski, J. Clune, Y. Bengio, and H. Lipson. How transferable are features in deep neural networks? Advances in Neural Information Processing Systems, 27, 2014.

  3. A. Tang, L. Shen, Y. Luo, N. Yin, L. Zhang, and D. Tao. Merging Multi-Task Models via Weight-Ensembling Mixture of Experts. ICML 2024. doi: 10.48550/arXiv.2402.00433.