Skip to content

Model Ensemble

SimpleEnsembleAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
@auto_register_config
class SimpleEnsembleAlgorithm(BaseAlgorithm):
    """Fuse a pool of models into a plain (unweighted) ensemble."""

    def __init__(
        self,
        device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
        **kwargs,
    ):
        """
        Initializes the SimpleEnsembleAlgorithm with an optional device map.

        Args:
            device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
                (Stored on ``self`` by ``@auto_register_config``.)
        """
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
        """
        Run the simple ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            EnsembleModule: The ensembled model.
        """
        # FIX: the signature accepts a bare list of modules, but the code below
        # uses `model_names`/`load_model`; wrap a list into a BaseModelPool the
        # same way the weighted/max variants do.
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info(f"Running ensemble algorithm with {len(modelpool)} models")
        models = [modelpool.load_model(m) for m in modelpool.model_names]

        log.info("creating ensemble module")
        ensemble = EnsembleModule(models=models, device_map=self.device_map)
        return ensemble

__init__(device_map=None, **kwargs)

Initializes the SimpleEnsembleAlgorithm with an optional device map.

Parameters:

  • device_map (Optional[Mapping[int, Union[str, device]]], default: None ) –

    A mapping from model index to device. Defaults to None.

Source code in fusion_bench/method/ensemble.py
def __init__(
    self,
    device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
    **kwargs,
):
    """
    Initializes the SimpleEnsembleAlgorithm with an optional device map.

    Args:
        device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
    """
    # NOTE: `device_map` is not assigned explicitly here; the enclosing class
    # is decorated with @auto_register_config, which presumably stores the
    # constructor arguments on `self` (it is later read as `self.device_map`
    # in `run`) — confirm against the decorator's implementation.
    super().__init__(**kwargs)

run(modelpool)

Run the simple ensemble algorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

Source code in fusion_bench/method/ensemble.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
    """
    Run the simple ensemble algorithm on the given model pool.

    Args:
        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

    Returns:
        EnsembleModule: The ensembled model.
    """
    # NOTE(review): unlike the weighted/max variants, a bare list argument is
    # not wrapped into a BaseModelPool here, yet `model_names`/`load_model`
    # are used below — a plain List[nn.Module] input would fail. Confirm.
    log.info(f"Running ensemble algorithm with {len(modelpool)} models")
    # Materialize every registered member model.
    models = [modelpool.load_model(m) for m in modelpool.model_names]

    log.info("creating ensemble module")
    # `device_map` is registered on `self` by the class's @auto_register_config.
    ensemble = EnsembleModule(models=models, device_map=self.device_map)
    return ensemble

WeightedEnsembleAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
@auto_register_config
class WeightedEnsembleAlgorithm(BaseAlgorithm):
    """Fuse a pool of models into a weighted ensemble."""

    def __init__(
        self,
        normalize: bool = True,
        weights: Optional[List[float]] = None,
        **kwargs,
    ):
        """
        Initializes the WeightedEnsembleAlgorithm.

        Args:
            normalize (bool, optional): Whether the ensemble weights should be
                normalized. Defaults to True.
            weights (Optional[List[float]], optional): One weight per model in
                the pool; if None, a uniform weighting is used. Defaults to None.
        """
        # Arguments are stored on `self` by @auto_register_config.
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> WeightedEnsembleModule:
        """
        Run the weighted ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            WeightedEnsembleModule: The weighted ensembled model.
        """
        # Accept a plain list of modules by wrapping it into a BaseModelPool.
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info(f"Running weighted ensemble algorithm with {len(modelpool)} models")

        models = [modelpool.load_model(m) for m in modelpool.model_names]
        if self.weights is None:
            # No weights configured: fall back to a uniform weighting.
            weights = np.ones(len(models)) / len(models)
        else:
            weights = self.weights
        ensemble = WeightedEnsembleModule(
            models,
            weights=weights,
            # FIX: read the auto-registered attribute directly instead of
            # `self.config.get("normalize", True)` — same value (the
            # constructor default is also True), single source of truth, and
            # consistent with how `self.weights` is read above.
            normalize=self.normalize,
        )
        return ensemble

run(modelpool)

Run the weighted ensemble algorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

Source code in fusion_bench/method/ensemble.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> WeightedEnsembleModule:
    """
    Run the weighted ensemble algorithm on the given model pool.

    Args:
        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

    Returns:
        WeightedEnsembleModule: The weighted ensembled model.
    """
    # Accept a plain list of modules by wrapping it into a BaseModelPool.
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(models=modelpool)

    log.info(f"Running weighted ensemble algorithm with {len(modelpool)} models")

    models = [modelpool.load_model(m) for m in modelpool.model_names]
    if self.weights is None:
        # No weights configured: fall back to a uniform weighting.
        weights = np.ones(len(models)) / len(models)
    else:
        weights = self.weights
    ensemble = WeightedEnsembleModule(
        models,
        weights=weights,
        # NOTE(review): reads `normalize` from `self.config` while `weights`
        # is read from the auto-registered attribute above — presumably
        # equivalent under @auto_register_config, but confirm they stay in sync.
        normalize=self.config.get("normalize", True),
    )
    return ensemble

MaxModelPredictorAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
class MaxModelPredictorAlgorithm(BaseAlgorithm):
    """Fuse a pool of models into a max-logit predictor ensemble."""

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> MaxModelPredictor:
        """
        Run the max model predictor algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to
                ensemble; a plain list is wrapped into a :class:`BaseModelPool`.

        Returns:
            MaxModelPredictor: The max model predictor ensembled model.
        """
        # Normalize the input: a bare list of modules becomes a model pool.
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info(f"Running max predictor algorithm with {len(modelpool)} models")

        members = []
        for name in modelpool.model_names:
            members.append(modelpool.load_model(name))
        return MaxModelPredictor(models=members)

run(modelpool)

Run the max model predictor algorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

  • MaxModelPredictor ( MaxModelPredictor ) –

    The max model predictor ensembled model.

Source code in fusion_bench/method/ensemble.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> MaxModelPredictor:
    """
    Run the max model predictor algorithm on the given model pool.

    Args:
        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

    Returns:
        MaxModelPredictor: The max model predictor ensembled model.
    """
    # Accept a plain list of modules by wrapping it into a BaseModelPool.
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(models=modelpool)

    log.info(f"Running max predictor algorithm with {len(modelpool)} models")

    # Load every member and hand them to the max-predictor ensemble module.
    models = [modelpool.load_model(m) for m in modelpool.model_names]
    ensemble = MaxModelPredictor(models=models)
    return ensemble

DataAdaptiveWeightEnsemblingForCLIP

Bases: BaseAlgorithm, CLIPClassificationMixin

Source code in fusion_bench/method/dawe/dawe_for_clip.py
class DataAdaptiveWeightEnsemblingForCLIP(
    BaseAlgorithm,
    CLIPClassificationMixin,
):
    """
    Data-Adaptive Weight Ensembling (DAWE) for CLIP vision models.

    Wraps a pretrained base model and a set of expert models into a
    :class:`DataAdaptiveWeightEnsemblingCLIPVisionModel` whose merging weights
    are produced by a gating network, and (unless ``skip_training``) trains
    that gate on unlabeled test data via entropy minimization
    (test-time adaptation).
    """

    # Assigned in `run`; provides the base/expert models and the test datasets.
    modelpool: CLIPVisionModelPool
    _processor: CLIPProcessor

    def __init__(
        self,
        # merge options
        merge_mode: Literal["task_wise", "layer_wise"],
        init_lambda: float,
        batch_reduce: bool,
        # annotated Optional: `run` only applies it when it is not None
        eval_batch_reduce: Optional[bool],
        # model options
        dict_processor: DictConfig,
        dict_feature_extractor: DictConfig,
        hidden_size: Optional[int],
        gate_hidden_layers: int,
        task_vector_dtype: Optional[str | torch.dtype],
        task_vector_sparsity: float,
        # training & logging args
        max_steps: int,
        save_interval: int,
        learning_rate: float = 1e-5,
        skip_training: bool = False,
        resume_checkpoint_path: Optional[str] = None,
        # dataloader args
        batch_size: int = 4,
        num_workers: int = 0,
        pin_memory: bool = True,
        **kwargs,
    ):
        """
        Initialize the DAWE algorithm.

        Args:
            merge_mode: "task_wise" or "layer_wise"; passed to the DAWE model.
            init_lambda: Initial merging coefficient; passed to the DAWE model.
            batch_reduce: Batch-reduction flag passed to the DAWE model.
            eval_batch_reduce: If not None, overrides `batch_reduce` on the
                model after training (see `run`).
            dict_processor: Config instantiated into the dictionary processor.
            dict_feature_extractor: Config instantiated into the feature extractor.
            hidden_size: Gate hidden size; if None it is inferred from the
                feature extractor's config in `load_models`.
            gate_hidden_layers: Number of hidden layers in the gate network.
            task_vector_dtype: Optional dtype for task vectors.
            task_vector_sparsity: Sparsity level for task vectors.
            max_steps: Number of test-time adaptation steps.
            save_interval: Save a checkpoint every this many steps.
            learning_rate: Adam learning rate for the gate parameters.
            skip_training: If True, skip test-time adaptation entirely.
            resume_checkpoint_path: Optional checkpoint to load into the model.
            batch_size: Test dataloader batch size.
            num_workers: Test dataloader worker count.
            pin_memory: Test dataloader pin_memory flag.
            **kwargs: Forwarded to ``BaseAlgorithm.__init__``.
        """
        # merge options
        self.merge_mode = merge_mode
        self.init_lambda = init_lambda
        self.batch_reduce = batch_reduce
        self.eval_batch_reduce = eval_batch_reduce
        # model options
        self._dict_processor = dict_processor
        self._dict_feature_extractor = dict_feature_extractor
        self.hidden_size = hidden_size
        self.gate_hidden_layers = gate_hidden_layers
        self.task_vector_dtype = task_vector_dtype
        self.task_vector_sparsity = task_vector_sparsity
        # training & logging args
        self.max_steps = max_steps
        self.save_interval = save_interval
        self.learning_rate = learning_rate
        self.skip_training = skip_training
        self.resume_checkpoint_path = resume_checkpoint_path
        # dataloader args
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        super().__init__(**kwargs)

    def load_models(self):
        """Instantiate processors / feature extractor and build the DAWE model."""
        modelpool = self.modelpool

        dict_processor = instantiate(self._dict_processor)
        clip_processor = modelpool.load_processor()

        dict_feature_extractor: Union[PreTrainedModel, nn.Module] = instantiate(
            self._dict_feature_extractor
        )
        if self.hidden_size is None:
            # try to infer hidden size from feature extractor model
            self.hidden_size = dict_feature_extractor.config.hidden_sizes[-1]

        # initialize classification head
        self.setup_zero_shot_classification_head(
            clip_processor=clip_processor,
            task_names=modelpool.model_names,
        )
        model = DataAdaptiveWeightEnsemblingCLIPVisionModel(
            merge_mode=self.merge_mode,
            hidden_size=self.hidden_size,
            dict_processor=dict_processor,
            # Raw PIL images are converted to pixel values for the base/expert
            # models on the fly (the datasets store unprocessed images, see
            # load_datasets).
            model_processor=lambda images: clip_processor(
                images=images, return_tensors="pt"
            ).pixel_values,
            # Concatenate pooled outputs from the per-model forward passes.
            collate_fn=lambda outputs: torch.cat(
                [out.pooler_output for out in outputs], dim=0
            ),
            dict_feature_extractor=dict_feature_extractor,
            base_model=modelpool.load_model("_pretrained_"),
            expert_models=list(modelpool.models()),
            task_vector_dtype=self.task_vector_dtype,
            task_vector_sparsity=self.task_vector_sparsity,
            init_lambda=self.init_lambda,
            gate_hidden_layers=self.gate_hidden_layers,
            batch_reduce=self.batch_reduce,
        )

        if self.resume_checkpoint_path is not None:
            # Restore previously trained weights (e.g. a trained gate).
            self.fabric.load(self.resume_checkpoint_path, {"model": model})
        return model

    def load_datasets(self):
        """Build per-task test datasets and shuffled infinite dataloaders for TTA."""
        modelpool = self.modelpool
        self.test_datasets = {
            task_name: CLIPDataset(
                modelpool.load_test_dataset(task_name),
                processor=None,  # NOTE: processor is not used in CLIPDataset because feature extractor and model may have different processors, so we want to pass the image as is
            )
            for task_name in modelpool.model_names
        }

        # setup dataloaders for test-time adaptation training

        dataloader_kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
        }
        self.shuffled_test_loaders = {
            task_name: self.fabric.setup_dataloaders(
                DataLoader(
                    test_dataset,
                    **dataloader_kwargs,
                    collate_fn=raw_image_collate_fn,
                    shuffle=True,
                )
            )
            for task_name, test_dataset in self.test_datasets.items()
        }
        # Infinite wrappers so training can draw batches without epoch bookkeeping.
        self.shuffled_test_loader_iters = {
            task_name: InfiniteDataLoader(loader)
            for task_name, loader in self.shuffled_test_loaders.items()
        }

    def run(self, modelpool: CLIPVisionModelPool):
        """
        Build the DAWE model, optionally run test-time adaptation, and return it.

        Args:
            modelpool (CLIPVisionModelPool): Pool providing base/expert models
                and per-task test datasets.

        Returns:
            The (possibly adapted) DAWE model.
        """
        self.modelpool = modelpool
        with timeit_context("Loading models"):
            model = self.load_models()
        with timeit_context("Loading dataloaders"):
            self.load_datasets()

        # run test-time adaptation
        if not self.skip_training:
            model = self.test_time_adaptation_training(modelpool, model)

        # Optionally override the batch-reduce behavior for evaluation.
        if self.eval_batch_reduce is not None:
            model.batch_reduce = self.eval_batch_reduce
        return model

    def test_time_adaptation_training(self, modelpool, model):
        """
        Train the gate network by entropy minimization on unlabeled test batches.

        Only gate parameters are optimized; labels are used solely to log accuracy.
        Returns the adapted model.
        """
        optimizer = torch.optim.Adam(
            [p for p in model.gate.parameters() if p.requires_grad],
            lr=self.learning_rate,
        )
        model, optimizer = self.fabric.setup(model, optimizer)
        model.train()
        # NOTE(review): if max_steps == 0 the loop never runs and `step_idx`
        # below raises NameError — presumably max_steps >= 1 is assumed; confirm.
        for step_idx in tqdm(
            range(self.max_steps),
            desc="TTA Training",
            dynamic_ncols=True,
        ):
            log_metrics = {}
            # Sum of per-task entropy losses for this step.
            losses = 0
            for task_idx, task_name in enumerate(modelpool.model_names):
                # labels are used for logging acc, not involved in training
                images, labels = next(self.shuffled_test_loader_iters[task_name])
                logits = self.compute_logits(model, images=images, task=task_name)
                loss = entropy_loss(logits)
                losses += loss
                log_metrics[f"train/{task_name}_loss"] = loss.item()
                log_metrics[f"train/{task_name}_accuracy"] = (
                    logits.argmax(dim=-1).eq(labels).float().mean().item()
                )

            optimizer.zero_grad()
            self.fabric.backward(losses)
            optimizer.step()

            log_metrics["train/loss"] = losses.item()
            self.fabric.log_dict(log_metrics, step=step_idx)

            # Periodic checkpointing.
            if (step_idx + 1) % self.save_interval == 0:
                log.info(f"Saving model at step {step_idx}")
                self.fabric.save(
                    Path(self.log_dir) / "checkpoints" / f"model_{step_idx}.pt",
                    {"model": model},
                )

        if (step_idx + 1) % self.save_interval != 0:
            # if the last step was not saved, save it now
            self.fabric.save(
                Path(self.log_dir) / "checkpoints" / f"model_{step_idx}.pt",
                {"model": model},
            )

        return model

FlanT5WeightEnsemblingMoEAlgorithm

Bases: LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm

FlanT5WeightEnsemblingMoEAlgorithm implements the Weight-Ensembling Mixture-of-Experts (WEMoE) fusion algorithm for FlanT5 models. It derives from BaseAlgorithm with the LightningFabricMixin and SimpleProfilerMixin mixins.

Attributes:

  • modelpool (Seq2SeqLMPool) –

    The model pool containing the FlanT5 models.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
@auto_register_config
class FlanT5WeightEnsemblingMoEAlgorithm(
    LightningFabricMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Weight-Ensembling Mixture-of-Experts (WEMoE) fusion for FlanT5 models.

    Merges a pretrained base model with fine-tuned experts via task arithmetic,
    replaces the encoder/decoder MLPs with routed :class:`WeightEnsemblingMoE`
    modules, and adapts the routers by entropy minimization on unlabeled test
    data (unless a checkpoint is provided).

    Attributes:
        modelpool (Seq2SeqLMPool): The model pool containing the FlanT5 models.
    """

    modelpool: Seq2SeqLMPool = None

    def __init__(
        self,
        # False, or a checkpoint path to load (`run` then skips adaptation)
        checkpoint: Union[bool, str] = False,
        # False, or a path to save the adapted model to
        save_checkpoint: Union[bool, str] = False,
        router_hidden_layers: int = 2,
        init_lambda: float = 0.3,
        batch_reduce: bool = True,
        lr: float = 1e-4,
        optimizer: str = "adam",
        devices: int = 1,
        batch_size: int = 16,
        num_workers: int = 0,
        max_steps: int = 1000,
        use_grad_accumulate: bool = True,
        fast_dev_run: bool = False,
        **kwargs,
    ):
        """
        Initialize the FlanT5WeightEnsemblingMoEAlgorithm.

        Args:
            checkpoint: Path of a checkpoint to load, or False to run
                test-time adaptation instead (see `run`).
            save_checkpoint: Path to save the adapted model to, or False to skip.
            router_hidden_layers: Number of hidden layers in each MoE router.
            init_lambda: Scaling factor for the initial task-arithmetic merge.
            batch_reduce: Passed through to each WeightEnsemblingMoE module.
            lr: Learning rate for test-time adaptation.
            optimizer: Optimizer name; only "adam" is supported.
            batch_size: Batch size of the shuffled test dataloaders.
            num_workers: Worker count for the shuffled test dataloaders.
            max_steps: Number of test-time adaptation steps.
            devices: Not referenced in this class's code — presumably consumed
                by a mixin; verify.
            use_grad_accumulate: Not referenced in this class's code — verify.
            fast_dev_run: Not referenced in this class's code — verify.
            **kwargs: Forwarded to the parent constructors.
        """
        # Arguments are stored on `self` by @auto_register_config; no explicit
        # assignments are needed here.
        super().__init__(**kwargs)

    def construct_moe_model(self):
        """
        Construct the Mixture of Experts (MoE) model using the models in the model pool.

        Returns:
            WeightEnsemblingMoE: The constructed MoE model.
        """
        base_model = self.modelpool.load_model("_pretrained_")
        expert_models = [
            self.modelpool.load_model(name) for name in self.modelpool.model_names
        ]

        # Merge the models using task arithmetic
        moe_model = task_arithmetic_merge(
            # This function modifies the model in place, so we need to pass a deepcopy
            deepcopy(base_model),
            expert_models,
            scaling_factor=self.init_lambda,
        ).requires_grad_(False)

        # NOTE(review): debug print of the full module tree; consider log.debug.
        print(base_model)

        # Up-scale MLP modules
        # NOTE(review): `num_layer` is hard-coded to 12 — presumably sized for
        # flan-t5-base; other model sizes would need a different value. Confirm.
        num_layer = 12
        encoder_mlp_index = 1
        base_encoder = base_model.encoder
        moe_encoder = moe_model.encoder
        expert_encoders = [m.encoder for m in expert_models]

        # Replace each encoder block's MLP (layer index 1) with a routed MoE
        # over the base and expert MLPs.
        for layer_idx in range(num_layer):
            base_mlp = (
                base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
            )
            expert_mlps = [
                e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
                for e in expert_encoders
            ]

            moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
                WeightEnsemblingMoE(
                    hidden_size=base_encoder.config.hidden_size,
                    base_model=base_mlp,
                    expert_models=expert_mlps,
                    init_lambda=self.init_lambda,
                    batch_first=True,
                    router_hidden_layers=self.router_hidden_layers,
                    batch_reduce=self.batch_reduce,
                )
            )

        # Decoder blocks keep their MLP at layer index 2.
        decoder_mlp_index = 2
        base_decoder = base_model.decoder
        moe_decoder = moe_model.decoder
        expert_decoders = [m.decoder for m in expert_models]

        for layer_idx in range(num_layer):
            base_mlp = (
                base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
            )
            expert_mlps = [
                e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
                for e in expert_decoders
            ]

            moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
                WeightEnsemblingMoE(
                    hidden_size=base_decoder.config.hidden_size,
                    base_model=base_mlp,
                    expert_models=expert_mlps,
                    init_lambda=self.init_lambda,
                    batch_first=True,
                    router_hidden_layers=self.router_hidden_layers,
                    batch_reduce=self.batch_reduce,
                )
            )

        # NOTE(review): debug print; consider log.debug.
        print(moe_model)
        return moe_model

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
                NOTE(review): the function actually returns
                ``iter(InfiniteDataLoader(loader))`` — an infinite batch
                iterator, not a DataLoader; the annotation is looser than the
                real return type.
        """
        # NOTE(review): functools.cache on an instance method keys on `self`
        # and keeps the instance + dataloaders alive for the process lifetime
        # (ruff B019); presumably acceptable for a single-run algorithm — confirm.
        # dataloader_kwargs = dict(self.dataloader_kwargs)
        # dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

        dataset = self.modelpool.load_test_dataset(task)
        log.info("get_shuffled_test_loader_iter")
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=default_data_collator,
        )
        # loader = DataLoader(dataset, **dataloader_kwargs)
        if self.fabric is not None:
            loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def compute_logits(
        self,
        module: Union[T5ForConditionalGeneration],
        batch,
        task: str,
    ) -> Tensor:
        """
        Compute first-decoder-position logits for the given batch.

        Args:
            module: The (possibly Fabric-wrapped) seq2seq model.
            batch (dict): Tokenized inputs with "input_ids" and "attention_mask".
            task (str): The name of the task (unused here).

        Returns:
            Tensor: Logits at the first decoder position.
        """
        input_ids: Tensor = batch["input_ids"]
        attention_mask: Tensor = batch["attention_mask"]

        # remove padding tokens from the input
        # NOTE(review): assumes at least one column is not fully padded; an
        # all-padding batch would shrink to width zero and fail — confirm the
        # collator guarantees this.
        while attention_mask[:, -1].eq(0).all():
            input_ids = input_ids[:, :-1]
            attention_mask = attention_mask[:, :-1]

        outputs = module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # One decoder input token (id 1) per sample — presumably the
            # intended decoder start token here; verify against the model
            # config (T5 conventionally starts decoding with the pad token, id 0).
            decoder_input_ids=torch.ones(
                input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
            ),
        )
        logits = outputs.logits[:, 0, :]
        return logits

    def test_time_adaptation(self, module):
        """
        Perform test-time adaptation for the given module.

        Args:
            module (WeightEnsemblingMoE): The MoE module to adapt.

        Returns:
            WeightEnsemblingMoE: The adapted MoE module.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.optimizer == "adam":
            # Debug aid: show which parameters remain trainable (presumably the
            # routers — everything else was frozen in construct_moe_model).
            print([name for name, p in module.named_parameters() if p.requires_grad])
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.lr
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.optimizer}")

        module, optimizer = self.fabric.setup(module, optimizer)

        module.train()
        # module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "WEMoE Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # Accumulated per-task entropy loss; used only for logging — the
            # gradients come from the per-task backward calls below.
            total_loss = 0
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch, task)
                    logits = logits.mean(dim=0, keepdim=True)
                    loss = entropy_loss(logits)
                    total_loss += loss
                with self.profile("backward pass"):
                    # NOTE(review): retain_graph=True across per-task backward
                    # calls — presumably needed because the tasks share graph
                    # state; confirm, as it increases memory usage.
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

            metrics = {
                "train/loss": total_loss.item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
        """
        pass

    def run(self, modelpool: Seq2SeqLMPool, **kwargs):
        """
        Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

        Args:
            modelpool (ModelPool): The pool of models to be fused.

        Returns:
            WeightEnsemblingMoE: The fused MoE model.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool

        with timeit_context("upscaling models to a weight-ensembling MoE model"):
            moe_model = self.construct_moe_model()
            print_parameters(moe_model)

        # Any truthy `checkpoint` (e.g. a path string) skips adaptation.
        if self.checkpoint != False:
            log.info(
                f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
            )
            self.load_checkpoint(moe_model, self.checkpoint)
        else:
            with self.profile("test-time adaptation"):
                moe_model = self.test_time_adaptation(moe_model)
            if self.save_checkpoint != False:
                log.info(f"save checkpoint to {self.save_checkpoint}")
                # NOTE(review): `save_checkpoint` is both the config value
                # tested above and invoked here as a method — with
                # @auto_register_config storing the value on `self`, this call
                # looks like it would call the stored value rather than a
                # mixin method; confirm the attribute does not shadow the method.
                self.save_checkpoint(moe_model, self.save_checkpoint)

            # Unwrap the Fabric wrapper before returning the raw module.
            # NOTE(review): unwrapping only happens on this branch, not on the
            # load-from-checkpoint branch — confirm that is intended.
            if lightning.fabric.wrappers.is_wrapped(moe_model):
                moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

        # enable sample-wise adaptation
        moe_model.batch_reduce = False
        self.print_profile_summary()
        return moe_model

__init__(checkpoint=False, save_checkpoint=False, router_hidden_layers=2, init_lambda=0.3, batch_reduce=True, lr=0.0001, optimizer='adam', devices=1, batch_size=16, num_workers=0, max_steps=1000, use_grad_accumulate=True, fast_dev_run=False, **kwargs)

Initialize the WeightEnsemblingMoEAlgorithm with the given configuration.

Parameters:

  • checkpoint (bool | str, default: False ) –

    Path of a checkpoint to load (skipping test-time adaptation), or False to adapt from scratch. See the signature above for the remaining options (save_checkpoint, router_hidden_layers, init_lambda, batch_reduce, lr, optimizer, batch_size, num_workers, max_steps, ...).

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
def __init__(
    self,
    # False, or a checkpoint path to load (`run` then skips adaptation)
    checkpoint: Union[bool, str] = False,
    # False, or a path to save the adapted model to
    save_checkpoint: Union[bool, str] = False,
    router_hidden_layers: int = 2,
    init_lambda: float = 0.3,
    batch_reduce: bool = True,
    lr: float = 1e-4,
    optimizer: str = "adam",
    devices: int = 1,
    batch_size: int = 16,
    num_workers: int = 0,
    max_steps: int = 1000,
    use_grad_accumulate: bool = True,
    fast_dev_run: bool = False,
    **kwargs,
):
    """
    Initialize the FlanT5WeightEnsemblingMoEAlgorithm.

    Args:
        checkpoint: Path of a checkpoint to load, or False to run
            test-time adaptation instead.
        save_checkpoint: Path to save the adapted model to, or False to skip.
        router_hidden_layers: Number of hidden layers in each MoE router.
        init_lambda: Scaling factor for the initial task-arithmetic merge.
        batch_reduce: Passed through to each WeightEnsemblingMoE module.
        lr: Learning rate for test-time adaptation.
        optimizer: Optimizer name; only "adam" is supported.
        batch_size: Batch size of the shuffled test dataloaders.
        num_workers: Worker count for the shuffled test dataloaders.
        max_steps: Number of test-time adaptation steps.
        devices: Not referenced in this class's code — presumably consumed
            by a mixin; verify.
        use_grad_accumulate: Not referenced in this class's code — verify.
        fast_dev_run: Not referenced in this class's code — verify.
        **kwargs: Forwarded to the parent constructors.
    """
    # Arguments are stored on `self` by the enclosing class's
    # @auto_register_config decorator; no explicit assignments are needed.
    super().__init__(**kwargs)

compute_logits(module, batch, task)

Compute the logits for the given batch and task.

Parameters:

  • module (Union[T5ForConditionalGeneration]) –

    The model module.

  • batch (dict) –

    The tokenized input batch containing "input_ids" and "attention_mask" tensors.

  • task (str) –

    The name of the task.

Returns:

  • Tensor ( Tensor ) –

    The computed logits.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
def compute_logits(
    self,
    module: Union[T5ForConditionalGeneration],
    batch,
    task: str,
) -> Tensor:
    """
    Compute first-decoder-position logits for the given batch.

    Args:
        module: The (possibly Fabric-wrapped) seq2seq model.
        batch (dict): Tokenized inputs with "input_ids" and "attention_mask".
        task (str): The name of the task (unused here).

    Returns:
        Tensor: Logits at the first decoder position.
    """
    input_ids: Tensor = batch["input_ids"]
    attention_mask: Tensor = batch["attention_mask"]

    # remove padding tokens from the input
    # NOTE(review): assumes at least one column is not fully padded; an
    # all-padding batch would shrink to width zero and then fail on the
    # indexing — confirm the collator guarantees this.
    while attention_mask[:, -1].eq(0).all():
        input_ids = input_ids[:, :-1]
        attention_mask = attention_mask[:, :-1]

    outputs = module(
        input_ids=input_ids,
        attention_mask=attention_mask,
        # One decoder input token (id 1) per sample — presumably the intended
        # decoder start token here; verify against the model config (T5
        # conventionally starts decoding with the pad token, id 0).
        decoder_input_ids=torch.ones(
            input_ids.size(0), 1, dtype=torch.long, device=input_ids.device
        ),
    )
    logits = outputs.logits[:, 0, :]
    return logits

construct_moe_model()

Construct the Mixture of Experts (MoE) model using the models in the model pool.

Returns:

  • WeightEnsemblingMoE

    The constructed MoE model.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
def construct_moe_model(self):
    """
    Construct the Mixture of Experts (MoE) model using the models in the model pool.

    The pretrained base and all experts are first merged via task arithmetic
    (with everything frozen), then every encoder/decoder MLP is replaced by a
    routed WeightEnsemblingMoE over the corresponding base/expert MLPs.

    Returns:
        WeightEnsemblingMoE: The constructed MoE model.
    """
    base_model = self.modelpool.load_model("_pretrained_")
    expert_models = [
        self.modelpool.load_model(name) for name in self.modelpool.model_names
    ]

    # Merge the models using task arithmetic
    moe_model = task_arithmetic_merge(
        # This function modifies the model in place, so we need to pass a deepcopy
        deepcopy(base_model),
        expert_models,
        scaling_factor=self.init_lambda,
    ).requires_grad_(False)

    # NOTE(review): debug print of the full module tree; consider log.debug.
    print(base_model)

    # Up-scale MLP modules
    # NOTE(review): `num_layer` is hard-coded to 12 — presumably sized for
    # flan-t5-base; other model sizes would need a different value. Confirm.
    num_layer = 12
    encoder_mlp_index = 1
    base_encoder = base_model.encoder
    moe_encoder = moe_model.encoder
    expert_encoders = [m.encoder for m in expert_models]

    # Encoder blocks keep their MLP at layer index 1.
    for layer_idx in range(num_layer):
        base_mlp = (
            base_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
        )
        expert_mlps = [
            e.block[layer_idx].layer[encoder_mlp_index].DenseReluDense
            for e in expert_encoders
        ]

        moe_encoder.block[layer_idx].layer[encoder_mlp_index].DenseReluDense = (
            WeightEnsemblingMoE(
                hidden_size=base_encoder.config.hidden_size,
                base_model=base_mlp,
                expert_models=expert_mlps,
                init_lambda=self.init_lambda,
                batch_first=True,
                router_hidden_layers=self.router_hidden_layers,
                batch_reduce=self.batch_reduce,
            )
        )

    # Decoder blocks keep their MLP at layer index 2.
    decoder_mlp_index = 2
    base_decoder = base_model.decoder
    moe_decoder = moe_model.decoder
    expert_decoders = [m.decoder for m in expert_models]

    for layer_idx in range(num_layer):
        base_mlp = (
            base_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
        )
        expert_mlps = [
            e.block[layer_idx].layer[decoder_mlp_index].DenseReluDense
            for e in expert_decoders
        ]

        moe_decoder.block[layer_idx].layer[decoder_mlp_index].DenseReluDense = (
            WeightEnsemblingMoE(
                hidden_size=base_decoder.config.hidden_size,
                base_model=base_mlp,
                expert_models=expert_mlps,
                init_lambda=self.init_lambda,
                batch_first=True,
                router_hidden_layers=self.router_hidden_layers,
                batch_reduce=self.batch_reduce,
            )
        )

    # NOTE(review): debug print; consider log.debug.
    print(moe_model)
    return moe_model

get_shuffled_test_loader_iter(task) cached

Loader of test dataset for test-time adaptation. labels are not needed.

Parameters:

  • task (str) –

    The name of the task.

Returns:

  • DataLoader ( DataLoader ) –

    The data loader for the test dataset.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
@functools.cache
def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
    """
    Loader of test dataset for test-time adaptation. labels are not needed.

    Args:
        task (str): The name of the task.

    Returns:
        DataLoader: The data loader for the test dataset.
            NOTE(review): the function actually returns
            ``iter(InfiniteDataLoader(loader))`` — an infinite batch iterator,
            not a DataLoader; the annotation is looser than the real return type.
    """
    # NOTE(review): functools.cache on an instance method keys on `self` and
    # keeps the instance + dataloaders alive for the process lifetime (ruff
    # B019); presumably acceptable for a single-run algorithm — confirm.
    # dataloader_kwargs = dict(self.dataloader_kwargs)
    # dataloader_kwargs.update(dict(shuffle=True, collate_fn=default_data_collator))

    dataset = self.modelpool.load_test_dataset(task)
    log.info("get_shuffled_test_loader_iter")
    loader = DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        collate_fn=default_data_collator,
    )
    # loader = DataLoader(dataset, **dataloader_kwargs)
    if self.fabric is not None:
        loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))

on_test_time_adaptation_start()

Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
def on_test_time_adaptation_start(self):
    """Hook invoked right before test-time adaptation begins.

    Subclasses may override this to prepare task-specific state (such as
    setting up task-specific heads). The default implementation does nothing.
    """

run(modelpool, **kwargs)

Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be fused.

Returns:

  • WeightEnsemblingMoE

    The fused MoE model.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
def run(self, modelpool: Seq2SeqLMPool, **kwargs):
    """
    Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.

    Args:
        modelpool (ModelPool): The pool of models to be fused.

    Returns:
        WeightEnsemblingMoE: The fused MoE model.
    """
    log.info("Fusing models using layer-wise adaptive merging.")
    self.modelpool = modelpool

    with timeit_context("upscaling models to a weight-ensembling MoE model"):
        moe_model = self.construct_moe_model()
        print_parameters(moe_model)

    # Any truthy `checkpoint` (e.g. a path string) skips test-time adaptation.
    if self.checkpoint != False:
        log.info(
            f"load checkpoint from {self.checkpoint}, test-time adaptation will be skipped."
        )
        self.load_checkpoint(moe_model, self.checkpoint)
    else:
        with self.profile("test-time adaptation"):
            moe_model = self.test_time_adaptation(moe_model)
        if self.save_checkpoint != False:
            log.info(f"save checkpoint to {self.save_checkpoint}")
            # NOTE(review): `save_checkpoint` is both the config value tested
            # above and invoked here as a method — with @auto_register_config
            # storing the value on `self`, this call looks like it would call
            # the stored value rather than a mixin method; confirm the
            # attribute does not shadow the method.
            self.save_checkpoint(moe_model, self.save_checkpoint)

        # Unwrap the Fabric wrapper before returning the raw module.
        # NOTE(review): unwrapping only happens on this branch, not on the
        # load-from-checkpoint branch — confirm that is intended.
        if lightning.fabric.wrappers.is_wrapped(moe_model):
            moe_model = lightning.fabric.wrappers._unwrap_objects(moe_model)

    # enable sample-wise adaptation
    moe_model.batch_reduce = False
    self.print_profile_summary()
    return moe_model

test_time_adaptation(module)

Perform test-time adaptation for the given module.

Parameters:

  • module (WeightEnsemblingMoE) –

    The MoE module to adapt.

Returns:

  • WeightEnsemblingMoE

    The adapted MoE module.

Source code in fusion_bench/method/we_moe/flan_t5_we_moe.py
def test_time_adaptation(self, module):
    """
    Perform test-time adaptation for the given module.

    Trainable parameters are optimized with entropy minimization over one
    unlabeled test batch per task per step.

    Args:
        module (WeightEnsemblingMoE): The MoE module to adapt.

    Returns:
        WeightEnsemblingMoE: The adapted MoE module.
    """
    self.on_test_time_adaptation_start()

    # configure optimizer
    if self.optimizer == "adam":
        # Debug aid: show which parameters remain trainable (presumably the
        # routers — everything else was frozen in construct_moe_model).
        print([name for name, p in module.named_parameters() if p.requires_grad])
        optimizer = torch.optim.Adam(
            [p for p in module.parameters() if p.requires_grad], lr=self.lr
        )
    else:
        raise ValueError(f"Unsupported optimizer: {self.optimizer}")

    module, optimizer = self.fabric.setup(module, optimizer)

    module.train()
    # module.merge_weights()
    for step_idx in (
        pbar := tqdm(
            range(self.max_steps if not self.is_debug_mode else 1),
            ("[DEBUG MODE] " if self.is_debug_mode else "")
            + "WEMoE Test-time adaptation",
            dynamic_ncols=True,
        )
    ):
        # Accumulated per-task entropy loss; used only for logging — the
        # gradients come from the per-task backward calls below.
        total_loss = 0
        for task in self.modelpool.model_names:
            with self.profile("data loading"):
                batch = next(self.get_shuffled_test_loader_iter(task))
            with self.profile("forward pass"):
                logits = self.compute_logits(module, batch, task)
                logits = logits.mean(dim=0, keepdim=True)
                loss = entropy_loss(logits)
                total_loss += loss
            with self.profile("backward pass"):
                # NOTE(review): retain_graph=True across per-task backward
                # calls — presumably needed because the tasks share graph
                # state; confirm, as it increases memory usage.
                self.fabric.backward(loss, retain_graph=True)

        with self.profile("optimizer step"):
            optimizer.step()
            optimizer.zero_grad()

        metrics = {
            "train/loss": total_loss.item(),
        }
        self.fabric.log_dict(metrics, step=step_idx)
        pbar.set_postfix(metrics)

    log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
    self.print_profile_summary()
    return module