
AdaMerging

Figure: Task Vector, Task Arithmetic, and AdaMerging. Credit to [1].

AdaMerging is a method for adaptively merging the parameters of models fine-tuned on different tasks into a single multi-task model. Unlike fixed-coefficient methods such as Task Arithmetic, where one scaling coefficient is chosen by hand or by grid search, AdaMerging learns the merging coefficients automatically [1].

AdaMerging learns the merging coefficients either task-wise (one coefficient per task vector) or layer-wise (one coefficient per layer of each task vector). Since test labels are not available, it uses the prediction entropy of the merged model on unlabeled test samples as a surrogate objective and tunes the coefficients to minimize it.

Task-wise AdaMerging is formulated as:

\[ \theta = \theta_0 + \sum_{i=1}^{n} \lambda_i \tau_i \]

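where \(\theta_0\) denotes the pre-trained model parameters, \(\tau_i = \theta_i - \theta_0\) is the task vector of the \(i\)-th task (the difference between its fine-tuned parameters \(\theta_i\) and \(\theta_0\)), and \(\lambda_i\) is the learnable merging coefficient for the \(i\)-th task.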
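As a minimal sketch of the task-wise formula, the hypothetical helper below merges parameters stored as NumPy arrays in plain dictionaries. This representation is an assumption for illustration; the actual implementation operates on PyTorch modules through `TaskWiseMergedModel`. The coefficients are fixed inputs here, whereas AdaMerging would learn them.

```python
import numpy as np


def task_wise_merge(theta0, finetuned, lambdas):
    """Hypothetical helper: theta = theta0 + sum_i lambda_i * (theta_i - theta0).

    theta0:    dict mapping parameter name -> np.ndarray (pre-trained weights)
    finetuned: list of dicts with the same keys (fine-tuned weights theta_i)
    lambdas:   one scalar merging coefficient per fine-tuned model
    """
    if len(finetuned) != len(lambdas):
        raise ValueError("need exactly one coefficient per fine-tuned model")
    merged = {}
    for name, p0 in theta0.items():
        # tau_i = theta_i - theta0 restricted to this parameter; the same
        # lambda_i scales every parameter of task vector i.
        update = sum(lam * (ft[name] - p0) for lam, ft in zip(lambdas, finetuned))
        merged[name] = p0 + update
    return merged


theta0 = {"w": np.array([1.0, 1.0])}
finetuned = [{"w": np.array([3.0, 1.0])}, {"w": np.array([1.0, 5.0])}]
merged = task_wise_merge(theta0, finetuned, lambdas=[0.5, 0.25])
print(merged["w"])  # [2. 2.]
```

The task vectors here are `[2, 0]` and `[0, 4]`; scaling them by 0.5 and 0.25 and adding them to \(\theta_0 = [1, 1]\) gives `[2, 2]`.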

Layer-wise AdaMerging instead learns a separate coefficient for each layer:

\[ \theta^l = \theta_0^l + \sum_{i=1}^{n} \lambda^{l}_{i} \tau^{l}_{i} \]

where the merging coefficient \(\lambda^{l}_{i}\) and task vector \(\tau^{l}_{i}\) are specific to each layer \(l\) of the model.
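The layer-wise variant changes only the shape of the coefficients: a matrix with one row per task vector and one column per layer. In the `LayerWiseAdaMergingAlgorithm` source further down, the number of "layers" is the number of trainable parameter tensors of the pre-trained model. The sketch below keeps the same hypothetical dict-of-NumPy-arrays representation and treats each dictionary entry as one layer, in insertion order; the coefficient values are illustrative only.

```python
import numpy as np


def layer_wise_merge(theta0, finetuned, lambdas):
    """Hypothetical helper: theta^l = theta0^l + sum_i lambda_i^l * tau_i^l.

    lambdas has shape (num_models, num_layers); column l holds the coefficients
    applied to the l-th parameter tensor (taken in dict insertion order).
    """
    names = list(theta0)
    lambdas = np.asarray(lambdas, dtype=float)
    if lambdas.shape != (len(finetuned), len(names)):
        raise ValueError("lambdas must have shape (num_models, num_layers)")
    merged = {}
    for l, name in enumerate(names):
        p0 = theta0[name]
        # each (task i, layer l) pair has its own coefficient lambda_i^l
        merged[name] = p0 + sum(
            lambdas[i, l] * (ft[name] - p0) for i, ft in enumerate(finetuned)
        )
    return merged


theta0 = {"shallow": np.zeros(2), "deep": np.zeros(2)}
finetuned = [
    {"shallow": np.array([1.0, 0.0]), "deep": np.array([1.0, 0.0])},
    {"shallow": np.array([0.0, 1.0]), "deep": np.array([0.0, 1.0])},
]
# illustrative values: rows are task vectors, columns are layers
lambdas = [[0.1, 0.6], [0.2, 0.4]]
merged = layer_wise_merge(theta0, finetuned, lambdas)
print(merged["shallow"], merged["deep"])
```

Setting every column of `lambdas` to the same values recovers the task-wise formula.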

Learning the coefficients lets AdaMerging weight each task vector, or each layer of each task vector, differently instead of applying one shared coefficient to all of them. Entropy minimization needs no labels: the coefficients are adjusted so that the merged model makes confident (low-entropy) predictions on the test samples of every task, which the method uses as a proxy for accurate predictions.
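Written out, the surrogate objective can be sketched as follows, using the notation of the formulas above. Here \(f(x; \theta)\) denotes the merged model's softmax prediction over \(C\) classes and \(\mathcal{B}_i\) a batch of unlabeled test samples from task \(i\); both symbols are introduced only for this illustration, and \(\theta(\lambda)\) is given by the task-wise or layer-wise merging formula. This mirrors the `test_time_adaptation` loops in the source code further down, which accumulate gradients of the batch-averaged `entropy_loss` over all tasks before each optimizer step.

\[ \min_{\lambda} \; \sum_{i=1}^{n} \frac{1}{|\mathcal{B}_i|} \sum_{x \in \mathcal{B}_i} H\big(f(x; \theta(\lambda))\big), \qquad H(p) = -\sum_{c=1}^{C} p_c \log p_c \]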

AdaMerging Analysis

Task-wise Coefficients. The figure below shows how the merging coefficient of each task vector evolves during optimization in Task-wise AdaMerging and AdaMerging++, recorded every ten steps. The learned coefficients consistently differ from one task vector to another. When the number of tasks is large, grid-searching a coefficient for every task is impractical; AdaMerging avoids this manual search.

Figure: Model merging coefficients \(\{\lambda_k\}_{k=1}^K\) as a function of training steps on ViT-B/32:
(a) Task-wise AdaMerging; (b) Task-wise AdaMerging++. Each line shows how the coefficient \(\lambda_k\) of task vector \(T_k\) (\(k \in \{1, 2, \dots, K\}\)) changes.

Layer-wise Coefficients. The figure below shows the merging coefficients learned by Layer-wise AdaMerging and AdaMerging++ on ViT-B/32. We observe that:

  1. The coefficients differ across the layers of each task vector, which indicates that layers contribute unequally to the merged model.
  2. Shallow layers generally receive smaller coefficients than deep layers. Shallow layers therefore rely more on the pre-trained weights, whereas deep layers rely more on the weights contributed by the task vectors. A plausible explanation is that shallow layers learn general features shared across tasks, while deep layers learn task-specific features [2]. This finding is also consistent with the routing analysis in [3].
Figure: Learned model merging coefficients \(\{\lambda_k^l\}_{k=1,l=1}^{K,L}\) of Layer-wise AdaMerging (above) and AdaMerging++ (below) on ViT-B/32. The \(k\)-th row corresponds to the \(k\)-th task vector, the \(l\)-th column to the \(l\)-th layer, and each cell holds the coefficient \(\lambda_k^l\).

Code Integration

Merge CLIP-ViT-B/32 models fine-tuned on eight downstream image classification tasks using layer-wise AdaMerging:

fusion_bench \
    method=adamerging \
        method.name=clip_layer_wise_adamerging \
        method.save_merging_weights=merging_weights.pt \
    modelpool=clip-vit-base-patch32_TA8 \
    taskpool=clip-vit-classification_TA8 \
    fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
    fabric.loggers.name=clip_layer_wise_adamerging_adam
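Where does `merging_weights.pt` end up? In the `save_merging_weights` method of `LayerWiseAdaMergingAlgorithm` (source further down), a file name that does not start with `/` or `.` is joined onto the algorithm's log directory, while absolute or explicitly relative paths are used unchanged. The sketch below restates that branch as a standalone function. The function name and the log directory `outputs/logs/ViT-B-32/version_0` are hypothetical; the actual directory depends on the logger configuration.

```python
import os


def resolve_merging_weights_path(file_path, log_dir):
    """Hypothetical restatement of the path logic in save_merging_weights."""
    if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
        # bare names such as "merging_weights.pt" go into the log directory
        return os.path.join(log_dir, file_path)
    # absolute paths and paths starting with "." are kept as given
    return file_path


log_dir = "outputs/logs/ViT-B-32/version_0"  # hypothetical log directory
print(resolve_merging_weights_path("merging_weights.pt", log_dir))
print(resolve_merging_weights_path("./merging_weights.pt", log_dir))
```

To reuse learned coefficients and skip test-time adaptation, the `run` methods further down take the `weights` branch when the `weights` config entry is set (presumably via `method.weights=<path>` on the command line); the layer-wise version loads the file with `load_tensor_from_file`.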

Part of the output:

Profiler Report

----------------------------------------------------------------------------------------------------------------------------------
|  Action                       |  Mean duration (s)    |  Num calls            |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------
|  Total                        |  -                    |  26001                |  724.65               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------
|  backward pass                |  0.060172             |  8000                 |  481.38               |  66.429               |
|  forward pass                 |  0.016124             |  8000                 |  128.99               |  17.801               |
|  data loading                 |  0.0063443            |  8000                 |  50.754               |  7.004                |
|  merging weights              |  0.050735             |  1000                 |  50.735               |  7.0013               |
|  construct the wrapped model  |  7.2558               |  1                    |  7.2558               |  1.0013               |
|  optimizer step               |  0.00098186           |  1000                 |  0.98186              |  0.13549              |
----------------------------------------------------------------------------------------------------------------------------------
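The call counts in this report follow from the structure of `test_time_adaptation` (source further down): each optimization step loads one batch, runs one forward pass, and runs one backward pass per task, then performs one optimizer step and one re-merge. The quick check below assumes 8 tasks (the TA8 model pool) and 1,000 steps, as implied by the `optimizer step` count; it also confirms that the backward pass dominates the run time.

```python
# Per-action call counts implied by the test-time adaptation loop,
# assuming 8 tasks and 1000 optimization steps (inferred from the report).
num_tasks, num_steps = 8, 1000

calls = {
    "backward pass": num_steps * num_tasks,
    "forward pass": num_steps * num_tasks,
    "data loading": num_steps * num_tasks,
    "merging weights": num_steps,
    "construct the wrapped model": 1,
    "optimizer step": num_steps,
}
total_calls = sum(calls.values())
print(total_calls)  # 26001

# Share of total run time spent in backward passes (values from the report).
backward_share = 481.38 / 724.65 * 100
print(f"{backward_share:.1f}%")  # 66.4%
```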

Reference

Task-Wise AdaMerging

task_wise_adamerging

TaskWiseAdaMergingAlgorithm

Bases: ModelFusionAlgorithm

Source code in fusion_bench/method/adamerging/task_wise_adamerging.py
class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
            self._fabric.launch()

    @torch.no_grad()
    def construct_task_wise_merged_model(self, modelpool: ModelPool):
        if self.config.weights is None:
            task_wise_weight = get_task_wise_weights(
                num_models=len(modelpool.model_names),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a .npy or .pt file
                if self.config.weights.endswith(".pt"):
                    task_wise_weight = torch.load(
                        self.config.weights, map_location="cpu"
                    ).detach_()
                elif self.config.weights.endswith(".npy"):
                    # np.save always writes files with the .npy extension
                    task_wise_weight = torch.from_numpy(
                        np.load(self.config.weights)
                    ).detach_()
                else:
                    raise ValueError(f"Unsupported file format: {self.config.weights}")
            else:
                try:
                    task_wise_weight = torch.tensor(
                        list(self.config.weights), dtype=torch.float32
                    )
                except ValueError:
                    raise ValueError(
                        f"Unsupported weights format: {self.config.weights}"
                    )

        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def run(self, modelpool: ModelPool):
        log.info("Fusing models using task-wise adaptive merging.")
        self.modelpool = modelpool

        module = self.construct_task_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                torch.save(module.merge_weight, self.config.save_merging_weights)
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        pass

    @abstractmethod
    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            pbar = tqdm(
                range(1),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        else:
            pbar = tqdm(
                range(self.config.max_steps),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        for step_idx in pbar:
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = entropy_loss(logits)
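                # gradients accumulate across tasks until .zero_grad() is called;
                # backpropagating each task's loss separately can save memory
                if self._fabric is not None:
                    self._fabric.backward(loss, retain_graph=True)
                else:
                    # no Fabric instance is created on CPU-only machines (see __init__)
                    loss.backward(retain_graph=True)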

            optimizer.step()
            optimizer.zero_grad()
            module.merge_weights()

        return module
compute_logits(module, batch, task) abstractmethod

Compute the logits for the given batch and task.

Parameters:

  • module (Module) –

    The model module.

  • batch (tuple) –

    A batch of input data.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The classification logits for the batch.

entropy_loss(logits)

Compute the entropy loss of a set of logits.

Parameters:

  • logits (Tensor) –

    The logits to compute the entropy loss of.

Returns:

  • Tensor ( Tensor ) –

    The entropy loss of the logits.

Source code in fusion_bench/method/adamerging/task_wise_adamerging.py
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
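To see what `entropy_loss` rewards, the NumPy re-implementation below (a sketch for illustration, not part of FusionBench) applies the same formula: a softmax, the Shannon entropy with the same `1e-8` stabilizer, and the mean over the batch. Uniform logits give the maximum entropy \(\log C\), while a confident prediction gives an entropy near zero, so minimizing this loss pushes the merged model toward confident predictions.

```python
import numpy as np


def entropy_loss_np(logits):
    """NumPy sketch of entropy_loss: mean Shannon entropy of softmax(logits)."""
    logits = np.asarray(logits, dtype=np.float64)
    # subtract the row-wise max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    # same 1e-8 stabilizer as the PyTorch version, then average over the batch
    return float(-(probs * np.log(probs + 1e-8)).sum(axis=-1).mean())


uniform = np.zeros((2, 4))  # 2 samples, 4 classes, no class preferred
confident = np.array([[20.0, 0.0, 0.0, 0.0]])
print(entropy_loss_np(uniform))    # close to log(4), about 1.386
print(entropy_loss_np(confident))  # close to 0
```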

clip_task_wise_adamerging

CLIPTaskWiseAdaMergingAlgorithm

Bases: TaskWiseAdaMergingAlgorithm

A class for task-wise adaptive merging of CLIP models.

This class extends the TaskWiseAdaMergingAlgorithm to provide specific functionality for CLIP models, including loading datasets, constructing zero-shot classification heads, and computing logits.

Attributes:

  • modelpool (CLIPVisionModelPool) –

    The model pool containing CLIP models.

  • _clip_processor (CLIPProcessor) –

    The CLIP processor for preparing inputs.

  • zeroshot_weights (dict) –

    A dictionary to store zero-shot weights for each task.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    """
    A class for task-wise adaptive merging of CLIP models.

    This class extends the TaskWiseAdaMergingAlgorithm to provide specific
    functionality for CLIP models, including loading datasets, constructing
    zero-shot classification heads, and computing logits.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
    """

    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.

        Args:
            task (str): The name of the task.

        Returns:
            CLIPDataset: The test dataset for the task.
        """
        log.info(f"Loading test dataset: {task}")
        dataset = self.modelpool.load_test_dataset(task)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        """
        Get an iterator over the shuffled test DataLoader for the task.

        Args:
            task (str): The name of the task.

        Returns:
            iterator: An iterator over the shuffled test DataLoader.
        """
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Prepare for test-time adaptation.

        This method loads the CLIP processor and constructs the zero-shot
        classification head for each task.
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")
        pretrained_path = (
            clip_model_config.pretrained_model_name_or_path
            if hasattr(clip_model_config, "pretrained_model_name_or_path")
            else clip_model_config.path
        )

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale_exp = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(task)
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        This method computes the image embeddings, normalizes them, and calculates
        the cosine similarity with the text embeddings to produce classification logits.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

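        # the second output of the wrapped vision model is used as the pooled
        # image representation, then projected into CLIP's joint embedding space
        image_embeds = module(images)[1]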
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

This method computes the image embeddings, normalizes them, and calculates the cosine similarity with the text embeddings to produce classification logits.

Parameters:

  • module (Module) –

    The model module.

  • batch (tuple) –

    A batch of input data.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The classification logits for the batch.
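The same computation can be reproduced with NumPy on toy embeddings. The sketch below assumes, for illustration, that the cached zero-shot text embeddings are already L2-normalized (the method itself only normalizes the image embeddings) and uses an assumed logit scale of 100; the shapes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, batch_size, dim = 3, 4, 8

# Assumed inputs: one unit-norm text embedding per class (the zero-shot head)
text_embeds = rng.normal(size=(num_classes, dim))
text_embeds /= np.linalg.norm(text_embeds, axis=-1, keepdims=True)
# Projected image embeddings for a batch (not yet normalized)
image_embeds = rng.normal(size=(batch_size, dim))
logit_scale_exp = 100.0  # assumed value of exp(logit_scale)

# Normalize the image embeddings, as compute_logits does
image_embeds = image_embeds / np.linalg.norm(image_embeds, axis=-1, keepdims=True)

# (num_classes, dim) @ (dim, batch) -> (num_classes, batch), then transpose
logits_per_text = text_embeds @ image_embeds.T * logit_scale_exp
logits_per_image = logits_per_text.T
print(logits_per_image.shape)  # (4, 3)
```

Each entry is a cosine similarity scaled by `logit_scale_exp`, so the logits lie in \([-100, 100]\) under this assumption.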

get_shuffled_test_loader_iter(task) cached

Get an iterator over the shuffled test DataLoader for the task.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • iterator

    An iterator over the shuffled test DataLoader.
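Because `get_shuffled_test_loader_iter` is wrapped in `functools.cache`, every call with the same task returns the same iterator object, so `next(self.get_shuffled_test_loader_iter(task))` in the adaptation loop keeps advancing one persistent stream of batches rather than restarting from a fresh loader. The toy class below is hypothetical, with `range` standing in for the data loader:

```python
import functools


class LoaderCacheDemo:
    """Hypothetical stand-in for the algorithm class."""

    @functools.cache
    def get_iter(self, task):
        # Runs once per (self, task) pair; later calls return the cached iterator.
        return iter(range(100))


demo = LoaderCacheDemo()
first = next(demo.get_iter("task_a"))
second = next(demo.get_iter("task_a"))  # same iterator, so it advances
other = next(demo.get_iter("task_b"))   # a different task gets its own iterator
print(first, second, other)  # 0 1 0
```

One side effect of caching on a method is that the cache holds a reference to `self`, so the instance and its loaders stay alive as long as the cache does.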

get_test_dataset(task) cached

Load the test dataset for the task. This method is cached, so the dataset is loaded only once.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • CLIPDataset

    The test dataset for the task.

on_test_time_adaptation_start()

Prepare for test-time adaptation.

This method loads the CLIP processor and constructs the zero-shot classification head for each task.

InfiniteDataLoader

A wrapper class for DataLoader that creates an infinite data loader. This is useful when training is measured in optimization steps rather than epochs.

This class wraps a DataLoader and provides an iterator that resets when the end of the dataset is reached, creating an infinite loop.

Attributes:

  • data_loader (DataLoader) –

    The DataLoader to wrap.

  • data_iter (iterator) –

    An iterator over the DataLoader.

Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
class InfiniteDataLoader:
    """
    A wrapper class for DataLoader to create an infinite data loader.
    This is useful in case we are only interested in the number of steps and not the number of epochs.

    This class wraps a DataLoader and provides an iterator that resets
    when the end of the dataset is reached, creating an infinite loop.

    Attributes:
        data_loader (DataLoader): The DataLoader to wrap.
        data_iter (iterator): An iterator over the DataLoader.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data
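A design note: `itertools.cycle` would also produce an endless stream, but it replays items saved from the first pass, whereas `InfiniteDataLoader` calls `iter(data_loader)` again at the end of each pass, so a `DataLoader` with `shuffle=True` draws a fresh order every epoch. The sketch below makes the difference visible with a hypothetical `EpochTaggedLoader` that tags each item with the pass that produced it; `reset_on_exhaustion` is a compact generator restatement of the `InfiniteDataLoader.__next__` logic.

```python
import itertools


class EpochTaggedLoader:
    """Hypothetical loader: each new iteration pass counts as a new epoch."""

    def __init__(self, num_items):
        self.num_items = num_items
        self.epoch = -1

    def __iter__(self):
        self.epoch += 1
        return iter([(self.epoch, i) for i in range(self.num_items)])


def reset_on_exhaustion(loader):
    """Generator equivalent of InfiniteDataLoader for non-empty loaders."""
    data_iter = iter(loader)
    while True:
        try:
            yield next(data_iter)
        except StopIteration:
            data_iter = iter(loader)  # start a new pass, as InfiniteDataLoader does
            yield next(data_iter)


stream = reset_on_exhaustion(EpochTaggedLoader(2))
print(list(itertools.islice(stream, 4)))  # [(0, 0), (0, 1), (1, 0), (1, 1)]

cycled = itertools.cycle(EpochTaggedLoader(2))
print(list(itertools.islice(cycled, 4)))  # [(0, 0), (0, 1), (0, 0), (0, 1)]
```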

Layer-Wise AdaMerging

layer_wise_adamerging

LayerWiseAdaMergingAlgorithm

Bases: ModelFusionAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    """
    Implements the Layer-Wise AdaMerging Algorithm.

    This class merges the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
    """

    _program: "FabricModelFusionProgram"
    """The program that this algorithm is running on."""

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                self.save_merging_weights(
                    self.config.save_merging_weights, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Hook called before test-time adaptation starts, e.g. to set up task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Return a loader over the test dataset for test-time adaptation; labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module: The model module.
            images (Tensor): The input images.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights at test time by minimizing the entropy of the model's predictions on unlabeled test batches from each task.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers: accumulate the
            # entropy-loss gradients of all tasks, then take a single optimizer step
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    # only the inputs (batch[0]) are used; labels are ignored
                    logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    # the merged weights are shared by every task's graph, so keep it alive
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                # recompute the merged parameters from the updated merging weights
                module.merge_weights()

            metrics = {
                # entropy loss of the last task in the loop, not a sum over tasks
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage("after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module
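For reference, the objective optimized by the loop above can be written out as follows, assuming entropy_loss returns the Shannon entropy of the softmax predictions averaged over the batch (its definition is not shown here). \(B_k\) is the unlabeled batch drawn for task \(k\) at a given step and \(f_k\) denotes the logits that compute_logits returns for that task. Summing the per-task losses corresponds to accumulating their gradients before the single optimizer step.

\[ \min_{\lambda} \; \sum_{k=1}^{n} \frac{1}{|B_k|} \sum_{x \in B_k} H\big(\operatorname{softmax}(f_k(x; \theta(\lambda)))\big), \qquad H(p) = -\sum_{c} p_c \log p_c, \qquad \theta^l(\lambda) = \theta_0^l + \sum_{i=1}^{n} \lambda^{l}_{i} \tau^{l}_{i} \]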
__init__(algorithm_config)

Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

Parameters:

  • algorithm_config (DictConfig) –

    The configuration for the algorithm.

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
def __init__(self, algorithm_config: DictConfig):
    """
    Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

    Args:
        algorithm_config (DictConfig): The configuration for the algorithm.
    """
    super().__init__(algorithm_config)
compute_logits(module, images, task) abstractmethod

Compute the logits for the given images and task.

Parameters:

  • module

    The model module.

  • images (Tensor) –

    The input images.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.
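A concrete compute_logits is task-specific. For CLIP models, zero-shot classification typically compares unit-normalized image embeddings against unit-normalized text embeddings of the class prompts. The sketch below assumes the per-task zero-shot head is already available as a tensor of class text embeddings and uses a fixed logit scale of 100. The helper zero_shot_logits, the toy encoder, and the random embeddings are hypothetical illustrations, not the CLIPClassificationMixin implementation.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


def zero_shot_logits(
    image_encoder: nn.Module,
    images: Tensor,
    text_embeds: Tensor,
    logit_scale: float = 100.0,
) -> Tensor:
    """Hypothetical helper: cosine-similarity logits of a CLIP-style zero-shot head.

    text_embeds has shape (num_classes, embed_dim), one row per class prompt.
    """
    image_embeds = F.normalize(image_encoder(images), dim=-1)  # (batch, embed_dim), unit norm
    text_embeds = F.normalize(text_embeds, dim=-1)             # (num_classes, embed_dim), unit norm
    return logit_scale * image_embeds @ text_embeds.T          # (batch, num_classes)


# Toy usage: a linear "image encoder" over flattened 8x8 single-channel images.
torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(64, 16))
images = torch.randn(4, 1, 8, 8)
class_embeds = torch.randn(10, 16)  # stand-in for encoded class-name prompts
logits = zero_shot_logits(encoder, images, class_embeds)
print(logits.shape)  # torch.Size([4, 10])
```

Because both embeddings are unit-normalized, every logit lies in [-logit_scale, logit_scale], which keeps the softmax used by the entropy objective well conditioned.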

construct_layer_wise_merged_model(modelpool)

Constructs a wrapped layer-wise merged model from model pool.

This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models. The merging is controlled by layer-wise weights, a torch.Tensor of shape (num_models, num_layers), where num_layers is the number of trainable parameter tensors of the pretrained model. The merging weights are either initialized from init_values in the configuration or loaded from the file given by weights.

Parameters:

  • modelpool (ModelPool) –

    An object containing the pretrained model and fine-tuned models to be merged.

Returns:

  • LayerWiseMergedModel

    An instance of the merged model with layer-wise weights applied.

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
    """
    Constructs a wrapped layer-wise merged model from model pool.

    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
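    The merging is controlled by layer-wise weights, a `torch.Tensor` of shape `(num_models, num_layers)`, where `num_layers` is the number of trainable parameter tensors of the pretrained model.
    The merging weights are either initialized from `init_values` in the configuration or loaded from the file given by `weights`.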

    Args:
        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

    Returns:
        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
    """
    pretrained_model = modelpool.load_model("_pretrained_")
    finetuned_models = [
        modelpool.load_model(name) for name in modelpool.model_names
    ]

    # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
    if self.config.weights is None:
        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(
                    filter(lambda p: p.requires_grad, pretrained_model.parameters())
                )
            ),
            init_values=self.config.init_values,
        )
    else:
        if isinstance(self.config.weights, str):
            # self.config.weights is a path to a saved tensor
            layer_wise_weight = load_tensor_from_file(self.config.weights)
        else:
            raise ValueError(f"Unsupported weights format: {self.config.weights}")

    module = LayerWiseMergedModel(
        layer_wise_weight=layer_wise_weight,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=self.config.clamp_weights,
        tie_weights=self.config.tie_weights,
        strict=self.config.strict,
    )
    print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
    return module
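The shape of the merging weights follows from the code above: one row per fine-tuned model and one column per trainable parameter tensor of the pretrained model, so a "layer" here is a single weight or bias tensor. The sketch below builds such a tensor filled with a constant and applies the layer-wise rule θ^l = θ_0^l + Σ_i λ_i^l τ_i^l to plain parameter dicts. The helpers init_layer_wise_weights and merge_layer_wise are hypothetical and only illustrate the arithmetic; they are not the get_layer_wise_weights or LayerWiseMergedModel implementations, and 0.3 is an illustrative initial value.

```python
import copy

import torch
from torch import nn


def init_layer_wise_weights(num_models: int, num_layers: int, init_value: float) -> torch.Tensor:
    # One coefficient per (fine-tuned model, trainable parameter tensor) pair.
    return torch.full((num_models, num_layers), init_value)


@torch.no_grad()
def merge_layer_wise(pretrained: nn.Module, finetuned: list, weights: torch.Tensor) -> dict:
    # theta^l = theta_0^l + sum_i weights[i, l] * (theta_i^l - theta_0^l)
    names = [n for n, p in pretrained.named_parameters() if p.requires_grad]
    assert weights.shape == (len(finetuned), len(names)), "expected (num_models, num_layers)"
    theta_0 = dict(pretrained.named_parameters())
    finetuned_params = [dict(m.named_parameters()) for m in finetuned]
    merged = {}
    for layer, name in enumerate(names):
        merged[name] = theta_0[name] + sum(
            weights[i, layer] * (params[name] - theta_0[name])  # lambda_i^l * tau_i^l
            for i, params in enumerate(finetuned_params)
        )
    return merged


torch.manual_seed(0)
pretrained = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
finetuned = [copy.deepcopy(pretrained) for _ in range(3)]
with torch.no_grad():
    for m in finetuned:
        for p in m.parameters():
            p.add_(0.1 * torch.randn_like(p))  # simulate fine-tuning

# Same counting rule as construct_layer_wise_merged_model: trainable parameter tensors.
num_layers = sum(1 for p in pretrained.parameters() if p.requires_grad)  # 2 weights + 2 biases
weights = init_layer_wise_weights(len(finetuned), num_layers, init_value=0.3)
merged = merge_layer_wise(pretrained, finetuned, weights)
print(weights.shape)  # torch.Size([3, 4])
```

Setting all coefficients to zero recovers the pretrained parameters exactly, which is a quick sanity check when experimenting with custom initializations.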
get_shuffled_test_loader_iter(task) abstractmethod

Return an iterator over the shuffled test data of a task for test-time adaptation. Labels are not needed.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • DataLoader ( DataLoader ) –

    The data loader for the test dataset.

on_test_time_adaptation_start()

Hook called before test-time adaptation starts, e.g., to set up the task-specific heads.

run(modelpool, **kwargs)

Run the Layer-Wise AdaMerging Algorithm.

This method constructs the wrapped model and, unless pre-computed merging weights are provided via the weights option of the configuration, performs test-time adaptation.

Parameters:

  • modelpool (ModelPool) –

    The model pool containing the pretrained and fine-tuned models.

Returns:

  • LayerWiseMergedModel

    The merged model obtained via merge_and_unload(), after test-time adaptation if it was performed.

save_merging_weights(file_path, merging_weights)

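Save the merging weights to a file. Only the global rank-zero process writes, and only when save_merging_weights is set in the configuration; a file name that does not start with "/" or "." is placed in the log directory.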

Parameters:

  • file_path (str) –

    The path to save the merging weights.

  • merging_weights (Tensor) –

    The merging weights to save.

Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
@rank_zero_only
def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
    """
    Save the merging weights to a file.

    Args:
        file_path (str): The path to save the merging weights.
        merging_weights (torch.Tensor): The merging weights to save.
    """
    if self.fabric.is_global_zero and self.config.get(
        "save_merging_weights", False
    ):
        if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
            # if the file path is not absolute or relative to current working directory, save it in the log directory
            save_path = os.path.join(self.log_dir, file_path)
        else:
            save_path = file_path
        log.info(f"saving merging weights to {save_path}.")
        if os.path.dirname(save_path):
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(merging_weights.detach().cpu(), save_path)
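The path handling above can be isolated into a pure function: a string that does not start with "/" or "." is treated as a name relative to the log directory, and anything else is used as given. The name resolve_save_path is hypothetical; it mirrors the branch in save_merging_weights and is not a FusionBench API.

```python
import os


def resolve_save_path(file_path: str, log_dir: str) -> str:
    # Bare names such as "merging_weights.pt" are placed under the log directory;
    # absolute ("/...") and explicitly relative ("./...", "../...") paths are kept as given.
    if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
        return os.path.join(log_dir, file_path)
    return file_path


print(resolve_save_path("merging_weights.pt", "outputs/logs/ViT-B-32"))
# outputs/logs/ViT-B-32/merging_weights.pt  (on POSIX)
print(resolve_save_path("./weights/w.pt", "outputs/logs/ViT-B-32"))
# ./weights/w.pt
```

One consequence of the prefix test is that a name such as .cache/w.pt is kept relative to the working directory rather than the log directory.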
test_time_adaptation(module)

Perform test-time adaptation on the merged model.

This method adapts the merging weights at test time by minimizing the entropy of the model's predictions on unlabeled test batches from each task.

Parameters:

  • module (LayerWiseMergedModel) –

    The merged model.

Returns:

  • LayerWiseMergedModel

    The adapted merged model.


clip_layer_wise_adamerging

Example Usage:

fusion_bench \
    method=adamerging \
    method.name=clip_layer_wise_adamerging \
    method.save_merging_weights=merging_weights.pt \
    modelpool=clip-vit-base-patch32_TA8 \
    taskpool=clip-vit-classification_TA8 \
    fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
    fabric.loggers.name=clip_layer_wise_adamerging_adam
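With save_merging_weights set, the adapted coefficients are written with torch.save as a CPU tensor of shape (num_models, num_layers), whose rows follow the order of modelpool.model_names. Because the name merging_weights.pt has no "/" or "." prefix, it is placed in the run's log directory. Such a tensor can be inspected with plain PyTorch. The sketch below writes a random stand-in tensor (8 tasks by 10 parameter tensors, illustrative sizes only) to a temporary file so that it runs anywhere; with a real run, point path at the saved file instead.

```python
import os
import tempfile

import torch

# Stand-in for a saved merging-weights file (random values, illustrative sizes).
path = os.path.join(tempfile.mkdtemp(), "merging_weights.pt")
torch.save(torch.rand(8, 10), path)

weights = torch.load(path)                        # (num_models, num_layers)
per_task_mean = weights.mean(dim=1)               # average coefficient of each task vector
strongest_task_per_layer = weights.argmax(dim=0)  # which task dominates each parameter tensor
print(weights.shape)  # torch.Size([8, 10])
```

The same file can be supplied through the method's weights option; run then loads it and skips test-time adaptation.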
CLIPLayerWiseAdaMergingAlgorithm

Bases: CLIPClassificationMixin, LayerWiseAdaMergingAlgorithm

Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        return super().get_shuffled_test_loader_iter(
            task,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
        )
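The functools.cache decorator matters because test_time_adaptation calls next(self.get_shuffled_test_loader_iter(task)) at every step. Without caching, each call would construct a new iterator (re-shuffling and, for a DataLoader with workers, restarting them) only to take its first batch; with caching, one iterator per task is created and reused, so successive steps draw successive batches. The stdlib sketch below shows the effect, using a hypothetical infinite shuffled batch generator and a hypothetical Loaders class in place of the real data loader and algorithm.

```python
import functools
import random
from typing import Iterator


def infinite_shuffled_batches(data: list, batch_size: int, seed: int = 0) -> Iterator[list]:
    # Hypothetical stand-in for a shuffled, endlessly repeating test loader.
    rng = random.Random(seed)
    while True:
        order = data[:]
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            yield order[start:start + batch_size]


class Loaders:
    def __init__(self, datasets: dict):
        self.datasets = datasets
        self.created = 0  # counts how many iterators were built

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str) -> Iterator[list]:
        self.created += 1
        return infinite_shuffled_batches(self.datasets[task], batch_size=2)


loaders = Loaders({"mnist": list(range(6)), "svhn": list(range(10, 16))})
batches = [next(loaders.get_shuffled_test_loader_iter("mnist")) for _ in range(3)]
print(loaders.created)  # 1: the same iterator served all three steps
```

A side effect of functools.cache on a method is that the cache holds a reference to self, so the algorithm object and its iterators stay alive for the lifetime of the process; that is usually acceptable for a one-shot merging run.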
on_test_time_adaptation_start()

Here we load the CLIP processor and construct the zero-shot classification head for each task.


  1. AdaMerging: Adaptive Model Merging for Multi-Task Learning. ICLR 2024. https://openreview.net/pdf?id=nZP6NgD3QY

  2. Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. How transferable are features in deep neural networks? Advances in Neural Information Processing Systems 27, 2014.

  3. A. Tang, L. Shen, Y. Luo, N. Yin, L. Zhang, and D. Tao. Merging Multi-Task Models via Weight-Ensembling Mixture of Experts. ICML 2024. doi: 10.48550/arXiv.2402.00433.