
Weight-Ensembling Mixture of Experts (Data-Adaptive Model Merging)


Figure: (a) Framework overview. This figure shows the overall framework of our proposed method for merging the pre-trained model and the fine-tuned task-specific models. We merge the weights in the Transformer layers except for the MLPs. For the MLPs, we upcycle them into weight-ensembling MoE modules. (b) Weight-Ensembling Mixture of Experts (MoE) module. Here we outline the detailed structure of the Weight-Ensembling MoE module, composed of the router, the pre-trained MLP weights, and a collection of task vectors. Collaboration between the shared weights and the task vectors is employed to create input-conditioned weights dynamically. In this way, we separate shared information from task-specific knowledge, and then combine them dynamically based on the input.

This method is designed to handle a wide range of tasks by segregating shared information and task-specific knowledge. It dynamically combines these elements based on the input samples.

The Weight-Ensembling MoE module consists of three main components: the router, the pre-trained MLP weights, and a collection of task vectors. The router, which is an MLP, processes the input data and generates routing weights. These weights determine how the knowledge from different tasks is combined. The pre-trained MLP weights are crucial as they have been trained to recognize a wide range of data patterns. The task vectors represent the differences between the MLPs that have been fine-tuned for specific tasks and the pre-trained ones, capturing the unique adjustments made to optimize them for specific tasks. The routing weights are averaged across the input tokens, and these weights are used to select task vectors from a dictionary matrix. These task vectors are then added to the pre-trained MLP weights to create input-conditioned weights.
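
To make the data flow concrete, below is a minimal, self-contained sketch of this forward pass in PyTorch. It is illustrative only (not the fusion_bench implementation); the flattened weight layout and the two-layer router are assumptions made for the example.

```python
import torch
import torch.nn as nn

class WEMoESketch(nn.Module):
    """Illustrative weight-ensembling MoE layer (not the fusion_bench implementation)."""

    def __init__(self, hidden_size: int, pretrained_weights: torch.Tensor, task_vectors: torch.Tensor):
        super().__init__()
        # pretrained_weights: flattened pre-trained MLP weights, shape (d,)
        # task_vectors: one flattened (fine-tuned - pre-trained) difference per task, shape (num_tasks, d)
        self.register_buffer("w0", pretrained_weights)
        self.register_buffer("task_vectors", task_vectors)
        num_tasks = task_vectors.shape[0]
        # Router: an MLP mapping token representations to routing weights (one per task vector).
        self.router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, num_tasks)
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (batch, seq_len, hidden_size), the tokens entering the MLP block.
        routing_weights = self.router(h).mean(dim=1)   # average routing weights over tokens
        delta = routing_weights @ self.task_vectors    # (batch, d): weighted sum of task vectors
        weights = self.w0 + delta                      # input-conditioned MLP weights per sample
        # In the full method, `weights` is reshaped back into the MLP's weight matrices
        # and applied to `h`; here we simply return the merged weights.
        return weights
```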

Algorithm Requirements:

| Method          | Access to labeled task data               | Access to labeled validation data | Test-time adaptation |
| --------------- | ----------------------------------------- | --------------------------------- | -------------------- |
| Fisher Merging  | Yes (estimate Fisher information matrix)  | No                                | No                   |
| RegMean         | Yes (compute Gram matrix)                 | No                                | No                   |
| Task Arithmetic | No                                        | Yes (select scaling factor)       | No                   |
| Ties-Merging    | No                                        | Yes (select scaling factor)       | No                   |
| AdaMerging      | No                                        | No                                | Yes                  |
| Ours            | No                                        | No                                | Yes                  |

WEMoE V2: E-WEMoE

L. Shen, A. Tang, E. Yang, et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging. Oct 2024. [3]


Figure: (a) Overview of the Efficient Weight-Ensembling Mixture of Experts (E-WEMoE) framework. It merges all non-MLP modules through task arithmetic and upgrades the MLP modules into an efficient E-WEMoE module. (b) E-WEMoE module. The module includes a router shared across all Transformer blocks, the pre-trained MLP module, and a set of sparse task-specific vectors w.r.t. the MLP modules.
Figure: Comparison of (a) trainable parameters and (b) total parameters between WEMoE and E-WEMoE-90%.
Figure: Comparison of the relationship between parameter count and performance across various model merging methods.

Parameters Comparison

Tip for reducing the parameter count

Here we present the parameter count for the method outlined in the original paper [1]. An effective strategy to reduce the number of parameters is to compress the task vectors with Singular Value Decomposition (SVD). This significantly cuts down the number of parameters while only marginally impacting performance. For additional information, please refer to the Twin-Merging paper [2], which not only reduces the number of parameters but also conducts extensive experiments demonstrating the effectiveness of data-adaptive merging in the language domain.
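
As a concrete illustration of this tip, here is a minimal sketch of low-rank (SVD) compression applied to a single 2-D task-vector weight matrix; the function name and the choice of rank are illustrative, not part of the original method.

```python
import torch

def compress_task_vector(delta_w: torch.Tensor, rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Compress a 2-D task-vector weight matrix with a truncated SVD."""
    U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
    # Keep only the top-`rank` singular directions.
    A = U[:, :rank] * S[:rank]   # (out_dim, rank)
    B = Vh[:rank, :]             # (rank, in_dim)
    return A, B                  # delta_w is approximated by A @ B with far fewer parameters

# Example: compress the task vector of a 3072 x 768 MLP weight to rank 16.
delta_w = torch.randn(3072, 768)
A, B = compress_task_vector(delta_w, rank=16)
print(A.shape, B.shape)  # torch.Size([3072, 16]) torch.Size([16, 768])
```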

Here is the number of parameters compared to a single pre-trained model (OpenCLIP CLIP-ViT-B/32):

| Method                   | Trainable Parameters | Total Parameters | Parameters Reduced by Merging |
| ------------------------ | -------------------- | ---------------- | ----------------------------- |
| Single Pre-trained       | 113.45M (100%)       | 113.45M          | -                             |
| WEMoE (2-layer, 1 task)  | 7.10M (4.00%)        | 177.21M          | -                             |
| WEMoE (2-layer, 2 tasks) | 7.11M (3.04%)        | 233.89M          | 2*113.45-233.89=-6.99M        |
| WEMoE (2-layer, 3 tasks) | 7.11M (2.45%)        | 290.57M          | 3*113.45-290.57=49.78M        |
| WEMoE (2-layer, 4 tasks) | 7.12M (2.02%)        | 347.25M          | 4*113.45-347.25=106.55M       |
| WEMoE (2-layer, 5 tasks) | 7.13M (1.77%)        | 403.93M          | 5*113.45-403.93=163.32M       |
| WEMoE (2-layer, 6 tasks) | 7.14M (1.55%)        | 460.61M          | 6*113.45-460.61=220.09M       |
| WEMoE (2-layer, 7 tasks) | 7.15M (1.38%)        | 517.28M          | 7*113.45-517.28=276.87M       |
| WEMoE (2-layer, 8 tasks) | 7.16M (1.25%)        | 573.96M          | 8*113.45-573.96=333.64M       |

The parameter count of the HuggingFace CLIP vision models (of type transformers.models.clip.modeling_clip.CLIPVisionModel) differs from that of the OpenCLIP models downloaded from the task arithmetic repository, because the OpenCLIP models (of type src.modeling.ImageEncoder) include the embedding layer for text tokens, while the HuggingFace CLIP vision models do not. Therefore, the relative parameter count of the up-scaled model is larger when using the HuggingFace CLIP vision models than when using the OpenCLIP models.

ImageEncoder( # (1)
  (model): CLIP(
    (visual): VisualTransformer( # (2)
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): ModuleList(
          (0-11): 12 x ResidualAttentionBlock(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_attn): Identity()
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (ln): Identity()
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
          )
        )
      )
      (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (token_embedding): Embedding(49408, 512) # (3)
    (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)
  1. trainable params: 113.45M || all params: 113.45M || trainable%: 100.0000
  2. trainable params: 87.85M || all params: 87.85M || trainable%: 100.0000
  3. trainable params: 25.30M || all params: 25.30M || trainable%: 100.0000
CLIPVisionModel( # (1)
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (position_embedding): Embedding(50, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
  1. trainable params: 87.85M || all params: 87.85M || trainable%: 100.0000
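
The parameter counts annotated above can be reproduced with a small helper like the following sketch; the checkpoint identifier is the standard HuggingFace one assumed here.

```python
import torch
from transformers import CLIPVisionModel

def count_params_m(module: torch.nn.Module) -> float:
    """Total number of parameters, in millions."""
    return sum(p.numel() for p in module.parameters()) / 1e6

vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
print(f"CLIPVisionModel: {count_params_m(vision_model):.2f}M parameters")  # ~87.85M
# The OpenCLIP ImageEncoder additionally carries the 49408 x 512 token embedding
# (~25.30M parameters), which accounts for its ~113.45M total.
```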

Loss Landscape Visualization

Figure: Visualization of the joint loss \(\mathcal{L}_1 + \mathcal{L}_2\) for five task pairs of CLIP-ViT-B/32 in the loss landscape. We interpolate between the pre-trained weights and two fine-tuned weights on a 2D plane in weight space using \(\theta=\theta_0 + \lambda_1 \tau_1 + \lambda_2 \tau_2\), where \(\theta_0\) denotes the pre-trained weights, \(\tau_i=\theta_i -\theta_0\) are the two task vectors, and each \(\lambda_i\) ranges over \([-1, 1]\).
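
A minimal sketch of this 2D weight-space interpolation is shown below; it assumes the three checkpoints are available as PyTorch state dicts with matching keys (the function name is illustrative).

```python
import copy
import torch

def interpolate_weights(pretrained: dict, finetuned_1: dict, finetuned_2: dict,
                        lambda_1: float, lambda_2: float) -> dict:
    """theta = theta_0 + lambda_1 * tau_1 + lambda_2 * tau_2, with tau_i = theta_i - theta_0."""
    theta = copy.deepcopy(pretrained)
    for key, w0 in pretrained.items():
        if not torch.is_floating_point(w0):
            continue  # skip integer buffers such as position ids
        tau_1 = finetuned_1[key] - w0
        tau_2 = finetuned_2[key] - w0
        theta[key] = w0 + lambda_1 * tau_1 + lambda_2 * tau_2
    return theta

# Evaluating the joint loss L_1 + L_2 on a grid of (lambda_1, lambda_2) in [-1, 1]^2
# reproduces one panel of the figure above.
```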

Hyperparameter Tuning

In the figure below, we show the performance of the merged models with a varying number of training steps. In Figure (a), we merge CLIP-ViT-B/32 models with different learning rate configurations; Figure (b) compares the merged WEMoE models based on CLIP-ViT-B/32 and CLIP-ViT-L/14. We observe that the performance of the merged model trends upward as the number of training steps increases, and it converges rapidly, reaching a high accuracy level within just 200 steps. Furthermore, the influence of different learning rates is not significant, suggesting that our method is insensitive to the learning rate. This is a desirable property, as it reduces the need for hyperparameter tuning.

Figure: The performance of the merged models with a varying number of steps.
(a) CLIP-ViT-B/32 model with different learning rates.
(b) Comparison of CLIP-ViT-B/32 and CLIP-ViT-L/14.

Ablations of Router Depth

Table: Parameter comparison of WEMoE (1-layer) and WEMoE (2-layer) on CLIP-ViT-B/32 models (OpenCLIP).

| Method                  | Number of Trainable Parameters |
| ----------------------- | ------------------------------ |
| AdaMerging (layer-wise) | 1.3K                           |
| WEMoE (1-layer)         | 73.8K (0.01%)                  |
| WEMoE (2-layer)         | 7.16M (1.25%)                  |

Table: Ablation study of the router depth on the performance of the up-scaled CLIP-ViT-B/32 models (OpenCLIP).

| Method                  | SUN397 | CARS | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD  | Avg. |
| ----------------------- | ------ | ---- | -------- | ------- | ---- | ----- | ----- | ---- | ---- |
| AdaMerging (layer-wise) | 66.6   | 68.3 | 82.4     | 92.5    | 86.5 | 93.7  | 97.7  | 61.1 | 80.9 |
| WEMoE (1-layer)         | 73.2   | 76.7 | 93.8     | 98.6    | 95.7 | 98.6  | 99.5  | 74.5 | 88.3 |
| WEMoE (2-layer)         | 74.1   | 77.4 | 93.7     | 99.1    | 96.2 | 98.9  | 99.6  | 76.4 | 89.4 |

To explore the influence of router depth on the performance of the scaled-up model, we perform an ablation study where the router depth is varied. In WEMoE modules, the router is implemented as a multi-layer perceptron (MLP).

  • WEMoE (0-layer) functions as a bias-only model, a special case of an MLP with no hidden layers. It produces a constant routing weight for all inputs, expressed as \(r(h) = b_0\), meaning that it does not adapt to the input. When we only up-scale the MLP modules of the vision Transformer to MoE modules, WEMoE (0-layer) can be considered a partial implementation of AdaMerging; and when we up-scale the vision Transformer layer-wise, WEMoE (0-layer) is equivalent to AdaMerging. For WEMoE (0-layer), the MoE modules can be unloaded, so no additional parameters or inference cost are introduced.
  • For WEMoE (1-layer), each router is a one-layer MLP that takes the input sample \(h\) and outputs the routing weight \(r(h)\), which adapts to the input. The routing weight is calculated as \(r(h) = W_1 h + b_1\).
  • For WEMoE (2-layer), each router is a two-layer MLP, and the routing weight is calculated as \(r(h) = W_2 \,\mathrm{ReLU}(W_1 h + b_1) + b_2\). The three variants are sketched in the code below.
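
Below is a minimal sketch of the three router variants in PyTorch; the factory function and class names are illustrative, not the fusion_bench implementation.

```python
import torch
import torch.nn as nn

def make_router(depth: int, hidden_size: int, num_tasks: int) -> nn.Module:
    """Illustrative routers for the three depths discussed above."""
    if depth == 0:
        # Bias-only: a constant routing weight r(h) = b_0, independent of the input.
        class ConstantRouter(nn.Module):
            def __init__(self):
                super().__init__()
                self.b0 = nn.Parameter(torch.zeros(num_tasks))

            def forward(self, h: torch.Tensor) -> torch.Tensor:
                return self.b0.expand(*h.shape[:-1], num_tasks)

        return ConstantRouter()
    if depth == 1:
        # r(h) = W_1 h + b_1
        return nn.Linear(hidden_size, num_tasks)
    # r(h) = W_2 ReLU(W_1 h + b_1) + b_2
    return nn.Sequential(
        nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, num_tasks)
    )
```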

In the above two Tables, we present additional findings to support our argument. We compare the number of trainable parameters and performance between WEMoE (1-layer) and WEMoE (2-layer). The data reveal that WEMoE (1-layer) possesses 73.8K trainable parameters, which constitute only 0.01% of the total parameters in the merged model. Notably, the performance of WEMoE (1-layer) is significantly better than AdaMerging and nearly matches that of WEMoE (2-layer) across all tasks. This evidence underscores our claim that the MoE design is crucial for performance enhancement.

Code Integration

Multi-task model fusion experiment on eight image classification tasks:

# merge eight CLIP-ViT-B/32 models using WE MoE
fusion_bench \
  method=weight_ensembling_moe \
    method.name=clip_weight_ensembling_moe \
    method.use_grad_accumulate=false \
    method.save_checkpoint=outputs/clip-vit-base-patch32_TA8_weight_ensembling_moe_checkpoint.ckpt \
  modelpool=clip-vit-base-patch32_TA8 \
  taskpool=clip-vit-classification_TA8

Merge eight CLIP-ViT-L/14 models:

# merge eight CLIP-ViT-L/14 models using WE MoE, fine-tune the routers
fusion_bench print_config=false \
  method=weight_ensembling_moe \
    method.name=clip_weight_ensembling_moe \
    method.use_grad_accumulate=true \
    method.save_checkpoint=outputs/clip-vit-large-patch14_TA8_weight_ensembling_moe_checkpoint.ckpt \
    method.batch_size=4 method.devices=4 \
  modelpool=clip-vit-large-patch14_TA8 \
  taskpool=dummy &&

# load the checkpoint and evaluate the model
fusion_bench \
  method=weight_ensembling_moe \
    method.name=clip_weight_ensembling_moe \
    method.checkpoint=outputs/clip-vit-large-patch14_TA8_weight_ensembling_moe_checkpoint.ckpt \
  modelpool=clip-vit-large-patch14_TA8 \
  taskpool=clip-vit-classification_TA8 \
    taskpool.clip_model=openai/clip-vit-large-patch14
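
For programmatic use, the same algorithm can be driven from Python. The sketch below follows the configuration keys read by we_moe.py and clip_we_moe.py (shown in the reference section); the specific values and the model-pool construction are illustrative assumptions, not verified defaults.

```python
from omegaconf import DictConfig
from fusion_bench.method.we_moe.clip_we_moe import CLIPWeightEnsemblingMoEAlgorithm

# Configuration keys mirror those accessed in the source code below; values are illustrative.
config = DictConfig({
    "name": "clip_weight_ensembling_moe",
    "devices": 1,
    "optimizer": "adam",
    "lr": 1e-4,
    "max_steps": 1000,
    "use_grad_accumulate": False,
    "init_lambda": 0.3,
    "router_hidden_layers": 2,
    "batch_reduce": True,
    "batch_size": 16,
    "num_workers": 8,
    "save_checkpoint": "outputs/wemoe_checkpoint.ckpt",
})

algorithm = CLIPWeightEnsemblingMoEAlgorithm(config)
# `modelpool` should hold the pre-trained CLIP vision model ("_pretrained_")
# plus the fine-tuned task-specific models (e.g. a CLIPVisionModelPool).
# merged_model = algorithm.run(modelpool)
```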

Reference

we_moe

WeightEnsemblingMoEAlgorithm

Bases: ModelFusionAlgorithm

Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).

This class provides methods for constructing the MoE model, performing test-time adaptation, and running the fusion process.

Attributes:

  • _fabric (Fabric) –

    The fabric for distributed training.

  • modelpool (ModelPool) –

    The pool of models to be fused.

  • profiler (SimpleProfiler) –

    The profiler for measuring performance.

Source code in fusion_bench/method/we_moe/we_moe.py
class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
    """
    Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).

    This class provides methods for constructing the MoE model, performing test-time adaptation,
    and running the fusion process.

    Attributes:
        _fabric (L.Fabric): The fabric for distributed training.
        modelpool (ModelPool): The pool of models to be fused.
        profiler (SimpleProfiler): The profiler for measuring performance.
    """

    _fabric: L.Fabric = None
    modelpool: ModelPool = None

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(
                devices=self.config.get("devices", 1),
            )
            self._fabric.launch()
        else:
            assert "No CUDA device available."
        self.profiler = SimpleProfiler(
            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
        )

    @abstractmethod
    def load_checkpoint(self, model, checkpoint):
        """
        Load the checkpoint file.

        Args:
            model: The model to load the checkpoint into.
            checkpoint: The checkpoint file to load.
        """
        pass

    @abstractmethod
    def save_checkpoint(self, model, checkpoint):
        """
        Save the checkpoint file.

        Args:
            model: The model to save the checkpoint from.
            checkpoint: The checkpoint file to save.
        """
        pass

    @abstractmethod
    def construct_moe_model(self) -> WeightEnsemblingMoE:
        """
        Construct the Mixture of Experts model using the models in the model pool.

        Returns:
            WeightEnsemblingMoE: The constructed MoE model.
        """
        pass

    def on_test_time_adaptation_start(self):
        """
        Hook method called at the start of test-time adaptation.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Get an iterator for the shuffled test data loader for a specific task.

        Args:
            task (str): The task for which to get the test data loader.

        Returns:
            DataLoader: The shuffled test data loader iterator.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Compute the logits for a given batch and task.

        Args:
            module: The model module to use for computing logits.
            batch: The batch of data.
            task: The task for which to compute logits.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: WeightEnsemblingMoE):
        """
        Perform test-time adaptation for the given module.

        Args:
            module (WeightEnsemblingMoE): The MoE module to adapt.

        Returns:
            WeightEnsemblingMoE: The adapted MoE module.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            pbar = tqdm(
                range(1),
                "Test-time adaptation",
                dynamic_ncols=True,
            )
        else:
            pbar = tqdm(
                range(self.config.max_steps),
                "Test-time adaptation",
                dynamic_ncols=True,
            )
        for step_idx in pbar:
            if self.config.use_grad_accumulate:
                for task in self.modelpool.model_names:
                    with self.profiler.profile("data time"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profiler.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        assert (
                            logits.dim() == 2
                        ), f"Expected logits to be 2D, got {logits.dim()}"
                        loss = entropy_loss(logits)
                    # .backward() accumulates when .zero_grad() wasn't called
                    # this can save memory
                    with self.profiler.profile("backward pass"):
                        self._fabric.backward(loss, retain_graph=True)
            else:
                loss = 0
                for task in self.modelpool.model_names:
                    with self.profiler.profile("data time"):
                        batch = next(self.get_shuffled_test_loader_iter(task))
                    with self.profiler.profile("forward pass"):
                        logits = self.compute_logits(module, batch, task)
                        assert (
                            logits.dim() == 2
                        ), f"Expected logits to be 2D, got {logits.dim()}"
                        loss = loss + entropy_loss(logits)
                with self.profiler.profile("backward pass"):
                    self._fabric.backward(loss, retain_graph=True)

            with self.profiler.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

        return module

    def run(self, modelpool: ModelPool):
        """
        Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

        Args:
            modelpool (ModelPool): The pool of models to be fused.

        Returns:
            WeightEnsemblingMoE: The fused MoE model.
        """
        log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
        self.modelpool = modelpool

        with timeit_context("upscaling models to a weight-ensembling MoE model"):
            moe_model = self.construct_moe_model()
            print_parameters(moe_model)

        if self.config.get("checkpoint", False):
            log.info(
                f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
            )
            self.load_checkpoint(moe_model, self.config.checkpoint)
        else:
            with self.profiler.profile("test-time adaptation"):
                moe_model = self.test_time_adaptation(moe_model)
            if self.config.get("save_checkpoint", False):
                log.info(f"save checkpoint to {self.config.save_checkpoint}")
                self.save_checkpoint(moe_model, self.config.save_checkpoint)

            if lightning.fabric.wrappers.is_wrapped(moe_model):
                moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

        # enable sample-wise adaptation
        moe_model.batch_reduce = False
        print(self.profiler.summary())
        return moe_model
__init__(algorithm_config)

Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.

Parameters:

  • algorithm_config (DictConfig) –

    The configuration for the algorithm.

Source code in fusion_bench/method/we_moe/we_moe.py
def __init__(self, algorithm_config: DictConfig):
    """
    Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.

    Args:
        algorithm_config (DictConfig): The configuration for the algorithm.
    """
    super().__init__(algorithm_config)

    if self._fabric is None and torch.cuda.is_available():
        self._fabric = L.Fabric(
            devices=self.config.get("devices", 1),
        )
        self._fabric.launch()
    else:
        assert "No CUDA device available."
    self.profiler = SimpleProfiler(
        self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
    )
compute_logits(module, batch, task) abstractmethod

Compute the logits for a given batch and task.

Parameters:

  • module

    The model module to use for computing logits.

  • batch

    The batch of data.

  • task

    The task for which to compute logits.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/we_moe/we_moe.py
@abstractmethod
def compute_logits(self, module, batch, task) -> Tensor:
    """
    Compute the logits for a given batch and task.

    Args:
        module: The model module to use for computing logits.
        batch: The batch of data.
        task: The task for which to compute logits.

    Returns:
        Tensor: The computed logits.
    """
    pass
construct_moe_model() abstractmethod

Construct the Mixture of Experts model using the models in the model pool.

Returns:

  • WeightEnsemblingMoE ( WeightEnsemblingMoE ) –

    The constructed MoE model.

Source code in fusion_bench/method/we_moe/we_moe.py
@abstractmethod
def construct_moe_model(self) -> WeightEnsemblingMoE:
    """
    Construct the Mixture of Experts model using the models in the model pool.

    Returns:
        WeightEnsemblingMoE: The constructed MoE model.
    """
    pass
get_shuffled_test_loader_iter(task) abstractmethod

Get an iterator for the shuffled test data loader for a specific task.

Parameters:

  • task (str) –

    The task for which to get the test data loader.

Returns:

  • DataLoader ( DataLoader ) –

    The shuffled test data loader iterator.

Source code in fusion_bench/method/we_moe/we_moe.py
@abstractmethod
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Get an iterator for the shuffled test data loader for a specific task.

    Args:
        task (str): The task for which to get the test data loader.

    Returns:
        DataLoader: The shuffled test data loader iterator.
    """
    pass
load_checkpoint(model, checkpoint) abstractmethod

Load the checkpoint file.

Parameters:

  • model

    The model to load the checkpoint into.

  • checkpoint

    The checkpoint file to load.

Source code in fusion_bench/method/we_moe/we_moe.py
@abstractmethod
def load_checkpoint(self, model, checkpoint):
    """
    Load the checkpoint file.

    Args:
        model: The model to load the checkpoint into.
        checkpoint: The checkpoint file to load.
    """
    pass
on_test_time_adaptation_start()

Hook method called at the start of test-time adaptation.

Source code in fusion_bench/method/we_moe/we_moe.py
def on_test_time_adaptation_start(self):
    """
    Hook method called at the start of test-time adaptation.
    """
    pass
run(modelpool)

Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be fused.

Returns:

  • WeightEnsemblingMoE

    The fused MoE model.

Source code in fusion_bench/method/we_moe/we_moe.py
def run(self, modelpool: ModelPool):
    """
    Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

    Args:
        modelpool (ModelPool): The pool of models to be fused.

    Returns:
        WeightEnsemblingMoE: The fused MoE model.
    """
    log.info("Fusing models using WeightEnsembling Mixture of Experts modules.")
    self.modelpool = modelpool

    with timeit_context("upscaling models to a weight-ensembling MoE model"):
        moe_model = self.construct_moe_model()
        print_parameters(moe_model)

    if self.config.get("checkpoint", False):
        log.info(
            f"load checkpoint from {self.config.checkpoint}, test-time adaptation will be skipped."
        )
        self.load_checkpoint(moe_model, self.config.checkpoint)
    else:
        with self.profiler.profile("test-time adaptation"):
            moe_model = self.test_time_adaptation(moe_model)
        if self.config.get("save_checkpoint", False):
            log.info(f"save checkpoint to {self.config.save_checkpoint}")
            self.save_checkpoint(moe_model, self.config.save_checkpoint)

        if lightning.fabric.wrappers.is_wrapped(moe_model):
            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

    # enable sample-wise adaptation
    moe_model.batch_reduce = False
    print(self.profiler.summary())
    return moe_model
save_checkpoint(model, checkpoint) abstractmethod

Save the checkpoint file.

Parameters:

  • model

    The model to save the checkpoint from.

  • checkpoint

    The checkpoint file to save.

Source code in fusion_bench/method/we_moe/we_moe.py
@abstractmethod
def save_checkpoint(self, model, checkpoint):
    """
    Save the checkpoint file.

    Args:
        model: The model to save the checkpoint from.
        checkpoint: The checkpoint file to save.
    """
    pass
test_time_adaptation(module)

Perform test-time adaptation for the given module.

Parameters:

  • module (WeightEnsemblingMoE) –

    The MoE module to adapt.

Returns:

  • WeightEnsemblingMoE

    The adapted MoE module.

Source code in fusion_bench/method/we_moe/we_moe.py
def test_time_adaptation(self, module: WeightEnsemblingMoE):
    """
    Perform test-time adaptation for the given module.

    Args:
        module (WeightEnsemblingMoE): The MoE module to adapt.

    Returns:
        WeightEnsemblingMoE: The adapted MoE module.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    if self.config.optimizer == "adam":
        optimizer = torch.optim.Adam(
            [p for p in module.parameters() if p.requires_grad], lr=self.config.lr
        )
    else:
        raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

    if self._fabric is not None:
        module, optimizer = self._fabric.setup(module, optimizer)

    module.train()

    if self.config.get("fast_dev_run", False):
        log.info("Running fast_dev_run, only one step")
        pbar = tqdm(
            range(1),
            "Test-time adaptation",
            dynamic_ncols=True,
        )
    else:
        pbar = tqdm(
            range(self.config.max_steps),
            "Test-time adaptation",
            dynamic_ncols=True,
        )
    for step_idx in pbar:
        if self.config.use_grad_accumulate:
            for task in self.modelpool.model_names:
                with self.profiler.profile("data time"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profiler.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    assert (
                        logits.dim() == 2
                    ), f"Expected logits to be 2D, got {logits.dim()}"
                    loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called
                # this can save memory
                with self.profiler.profile("backward pass"):
                    self._fabric.backward(loss, retain_graph=True)
        else:
            loss = 0
            for task in self.modelpool.model_names:
                with self.profiler.profile("data time"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profiler.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    assert (
                        logits.dim() == 2
                    ), f"Expected logits to be 2D, got {logits.dim()}"
                    loss = loss + entropy_loss(logits)
            with self.profiler.profile("backward pass"):
                self._fabric.backward(loss, retain_graph=True)

        with self.profiler.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()

    return module
entropy_loss(logits)

Compute the entropy loss of a set of logits.

Parameters:

  • logits
    (Tensor) –

    The logits to compute the entropy loss of.

Returns:

  • Tensor ( Tensor ) –

    The entropy loss of the logits.

Source code in fusion_bench/method/we_moe/we_moe.py
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()

clip_we_moe

CLIPWeightEnsemblingMoEAlgorithm

Bases: WeightEnsemblingMoEAlgorithm, CLIPClassificationMixin

CLIPWeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm for CLIP models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.

Attributes:

  • modelpool (CLIPVisionModelPool) –

    The model pool containing the CLIP models.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
class CLIPWeightEnsemblingMoEAlgorithm(
    WeightEnsemblingMoEAlgorithm,
    CLIPClassificationMixin,
):
    """
    CLIPWeightEnsemblingMoEAlgorithm is a class that implements the WeightEnsemblingMoEAlgorithm
    for CLIP models. It extends the WeightEnsemblingMoEAlgorithm and CLIPClassificationMixin classes.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
    """

    modelpool: CLIPVisionModelPool = None

    def load_checkpoint(self, model, checkpoint):
        """
        Load the checkpoint file.

        Args:
            model: The model to load the checkpoint into.
            checkpoint: The path to the checkpoint file.
        """
        state = {"model": model}
        self._fabric.load(checkpoint, state)

    def save_checkpoint(self, model, checkpoint):
        """
        Save the checkpoint file.

        Args:
            model: The model to save the checkpoint from.
            checkpoint: The path to the checkpoint file.
        """
        self._fabric.save(checkpoint, {"model": model})

    def construct_moe_model(self) -> WeightEnsemblingMoE:
        """
        Construct the Mixture of Experts (MoE) model using the models in the model pool.

        Returns:
            WeightEnsemblingMoE: The constructed MoE model.
        """
        base_model = self.modelpool.load_model("_pretrained_")
        expert_models = [
            self.modelpool.load_model(m) for m in self.modelpool.model_names
        ]

        # Merge the models using task arithmetic
        moe_model = task_arithmetic_merge(
            # This function modifies the model in place, so we need to pass a deepcopy
            deepcopy(base_model),
            expert_models,
            scaling_factor=self.config.init_lambda,
        ).requires_grad_(False)

        # Up-scale MLP modules
        base_encoder: CLIPEncoder = base_model.vision_model.encoder
        moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
        expert_encoders = [m.vision_model.encoder for m in expert_models]

        num_layers = len(base_encoder.layers)
        for layer_idx in range(num_layers):
            base_mlp = base_encoder.layers[layer_idx].mlp
            expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]

            moe_encoder.layers[layer_idx].mlp = WeightEnsemblingMoE(
                hidden_size=base_encoder.config.hidden_size,
                base_model=base_mlp,
                expert_models=expert_mlps,
                init_lambda=self.config.init_lambda,
                batch_first=True,  # For open_clip models this is False
                router_hidden_layers=self.config.router_hidden_layers,
                batch_reduce=self.config.batch_reduce,
            )

        return moe_model

    @functools.cache
    def get_shuffled_test_loader_iter(self, tta_dataset: str):
        """
        Get an iterator for the shuffled test data loader.

        Args:
            tta_dataset (str): The name of the test-time adaptation dataset.

        Returns:
            Iterator: An iterator for the shuffled test data loader.
        """
        dataset = self.modelpool.load_test_dataset(tta_dataset)
        dataset = CLIPDataset(dataset, processor=self.clip_processor)
        log.info("get_shuffled_test_loader_iter")
        loader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    def compute_logits(self, module, batch, task) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module: The model module.
            batch: The input batch.
            task: The task name.

        Returns:
            Tensor: The computed logits.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # Normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # Cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
compute_logits(module, batch, task)

Compute the logits for the given batch and task.

Parameters:

  • module

    The model module.

  • batch

    The input batch.

  • task

    The task name.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def compute_logits(self, module, batch, task) -> Tensor:
    """
    Compute the logits for the given batch and task.

    Args:
        module: The model module.
        batch: The input batch.
        task: The task name.

    Returns:
        Tensor: The computed logits.
    """
    images, _ = batch
    text_embeds = self.zeroshot_weights[task]

    image_embeds = module(images)[1]
    image_embeds = self.visual_projection(image_embeds)

    # Normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # Cosine similarity
    logits_per_text = (
        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
    )
    logits_per_image = logits_per_text.t()

    return logits_per_image
construct_moe_model()

Construct the Mixture of Experts (MoE) model using the models in the model pool.

Returns:

  • WeightEnsemblingMoE ( WeightEnsemblingMoE ) –

    The constructed MoE model.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def construct_moe_model(self) -> WeightEnsemblingMoE:
    """
    Construct the Mixture of Experts (MoE) model using the models in the model pool.

    Returns:
        WeightEnsemblingMoE: The constructed MoE model.
    """
    base_model = self.modelpool.load_model("_pretrained_")
    expert_models = [
        self.modelpool.load_model(m) for m in self.modelpool.model_names
    ]

    # Merge the models using task arithmetic
    moe_model = task_arithmetic_merge(
        # This function modifies the model in place, so we need to pass a deepcopy
        deepcopy(base_model),
        expert_models,
        scaling_factor=self.config.init_lambda,
    ).requires_grad_(False)

    # Up-scale MLP modules
    base_encoder: CLIPEncoder = base_model.vision_model.encoder
    moe_encoder: CLIPEncoder = moe_model.vision_model.encoder
    expert_encoders = [m.vision_model.encoder for m in expert_models]

    num_layers = len(base_encoder.layers)
    for layer_idx in range(num_layers):
        base_mlp = base_encoder.layers[layer_idx].mlp
        expert_mlps = [e.layers[layer_idx].mlp for e in expert_encoders]

        moe_encoder.layers[layer_idx].mlp = WeightEnsemblingMoE(
            hidden_size=base_encoder.config.hidden_size,
            base_model=base_mlp,
            expert_models=expert_mlps,
            init_lambda=self.config.init_lambda,
            batch_first=True,  # For open_clip models this is False
            router_hidden_layers=self.config.router_hidden_layers,
            batch_reduce=self.config.batch_reduce,
        )

    return moe_model
get_shuffled_test_loader_iter(tta_dataset) cached

Get an iterator for the shuffled test data loader.

Parameters:

  • tta_dataset (str) –

    The name of the test-time adaptation dataset.

Returns:

  • Iterator

    An iterator for the shuffled test data loader.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
@functools.cache
def get_shuffled_test_loader_iter(self, tta_dataset: str):
    """
    Get an iterator for the shuffled test data loader.

    Args:
        tta_dataset (str): The name of the test-time adaptation dataset.

    Returns:
        Iterator: An iterator for the shuffled test data loader.
    """
    dataset = self.modelpool.load_test_dataset(tta_dataset)
    dataset = CLIPDataset(dataset, processor=self.clip_processor)
    log.info("get_shuffled_test_loader_iter")
    loader = DataLoader(
        dataset,
        batch_size=self.config.batch_size,
        shuffle=True,
        num_workers=self.config.num_workers,
        pin_memory=True,
    )
    loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
load_checkpoint(model, checkpoint)

Load the checkpoint file.

Parameters:

  • model

    The model to load the checkpoint into.

  • checkpoint

    The path to the checkpoint file.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def load_checkpoint(self, model, checkpoint):
    """
    Load the checkpoint file.

    Args:
        model: The model to load the checkpoint into.
        checkpoint: The path to the checkpoint file.
    """
    state = {"model": model}
    self._fabric.load(checkpoint, state)
on_test_time_adaptation_start()

Load the CLIP processor and construct the zero-shot classification head for each task.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def on_test_time_adaptation_start(self):
    """
    Load the CLIP processor and construct the zero-shot classification head for each task.
    """
    self.setup_zero_shot_classification_head()
save_checkpoint(model, checkpoint)

Save the checkpoint file.

Parameters:

  • model

    The model to save the checkpoint from.

  • checkpoint

    The path to the checkpoint file.

Source code in fusion_bench/method/we_moe/clip_we_moe.py
def save_checkpoint(self, model, checkpoint):
    """
    Save the checkpoint file.

    Args:
        model: The model to save the checkpoint from.
        checkpoint: The path to the checkpoint file.
    """
    self._fabric.save(checkpoint, {"model": model})

  1. Anke Tang, et al. Merging Multi-Task Models via Weight-Ensembling Mixture of Experts. ICML 2024. http://arxiv.org/abs/2402.00433

  2. Z. Lu, C. Fan, W. Wei, X. Qu, D. Chen, and Y. Cheng. Twin-Merging: Dynamic Integration of Modular Expertise in Model Merging. NeurIPS 2024. doi: 10.48550/arXiv.2406.15479

  3. L. Shen, A. Tang, E. Yang, et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging. Oct 2024.