
MoE-based Model Upscaling (Sparse Upcycling)


Sparse upcycling is a technique for initializing a sparsely activated Mixture-of-Experts (MoE) model from a dense checkpoint. It reuses the training cost already sunk into dense pretraining to obtain a larger model at a fraction of the compute needed to train an MoE from scratch. In the process, dense Transformer blocks are partially replaced with MoE blocks: the MLP of a Transformer block is replaced by multiple experts, each initialized as a copy of the original MLP, and a router assigns tokens to experts based on routing probabilities. The upcycled MoE model is then trained further to recover, and eventually exceed, the dense model's performance. The method improves performance for both language and vision models while using only a fraction of the original dense pretraining cost [1].
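
Conceptually, each upcycled MoE layer is built by copying the pretrained dense MLP into every expert and adding a freshly initialized router. The sketch below illustrates this idea with generic PyTorch modules; the function and argument names are illustrative and are not part of the fusion_bench API.

import copy

import torch.nn as nn


def upcycle_mlp_to_moe(dense_mlp: nn.Module, hidden_size: int, num_experts: int):
    # Each expert starts as an exact copy of the pretrained dense MLP.
    experts = nn.ModuleList([copy.deepcopy(dense_mlp) for _ in range(num_experts)])
    # The router (gate) has no dense counterpart, so it is initialized from scratch.
    router = nn.Linear(hidden_size, num_experts, bias=False)
    return experts, router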

Examples

Here’s an example demonstrating how to upscale a pre-trained Mistral model to a Mixtral model:

import os

from transformers import MistralForCausalLM

from fusion_bench.method.mixture_of_experts.mixtral_upcycling import (
    MixtralForCausalLMUpscalingAlgorithm,
)
from fusion_bench.utils import print_parameters

# Load a pre-trained Mistral model
pretrained_model = MistralForCausalLM.from_pretrained(
    os.path.expanduser("path_to_mistral_model")
)
print("Pretrained model:")
print_parameters(pretrained_model)
# Output:
# Pretrained model:
# trainable params: 7.24B || all params: 7.24B || trainable%: 100.0000

# Define the configuration for Mixtral
config = {
    "num_experts": 4,  # Number of experts in each MoE layer
    "experts_per_token": 2,  # Number of experts activated per token
    "save_checkpoint": None,  # Optional path to save the upscaled model
}

# Initialize the upscaling algorithm
upscaling_for_causal_lm_algorithm = MixtralForCausalLMUpscalingAlgorithm(**config)

# Run the upscaling process to get a Mixtral model
mixtral_for_causal_lm_model = upscaling_for_causal_lm_algorithm.run(pretrained_model)

print("Mixtral model:")
print_parameters(mixtral_for_causal_lm_model)
# Output:
# Mixtral model:
# trainable params: 24.15B || all params: 24.15B || trainable%: 100.0000

# Save the upscaled Mixtral model
mixtral_for_causal_lm_model.save_pretrained("path_to_save_mixtral_model")

A Jupyter notebook example is also available in our repository.

Code Integration

This section shows how to use the fusion_bench command-line interface to upscale a Mistral model to a Mixtral model.

The first code block is the YAML configuration file for the upscaling method. The name field selects the upscaling method, num_experts sets the number of experts created in each MoE layer, experts_per_token sets how many experts are activated per token, and save_checkpoint, if provided, is the path where the upscaled model will be saved.

config/method/mixtral_moe_upscaling.yaml
name: mixtral_for_causal_lm_moe_upscaling # or "mixtral_moe_upscaling"

num_experts: 4
experts_per_token: 2
# path to save the upscaled model
save_checkpoint: null

The second code block is the YAML configuration file for the model pool. The type field specifies the type of model pool to use, and the models field lists the models in the pool. Each entry has a name and a path; the model is loaded from that path.

config/modelpool/mixtral_moe_upscaling.yaml
type: AutoModelForCausalLMPool
# each model should have a name and a path, and the model is loaded from the path
# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
models:
  - name: _pretrained_
    path: path_to_your_pretrained_model

Finally, the third code block is a bash command that runs the fusion_bench command-line interface with the specified method, model pool, and task pool. The method argument selects the upscaling method, modelpool selects the model pool, modelpool.models.0.path overrides the path to the pretrained model, and taskpool selects the task pool. Here, a dummy task pool is used that does nothing but print the parameter counts of the resulting model.

fusion_bench \
    method=mixtral_moe_upscaling \
    modelpool=mixtral_moe_upscaling \
        modelpool.models.0.path=path_to_your_pretrained_model \
    taskpool=dummy # this is a dummy taskpool that does nothing but print the parameter counts of the merged model
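
Once the upscaled model has been saved (by setting save_checkpoint, or by calling save_pretrained as in the example above), it can be reloaded with the standard transformers API. The path below is the placeholder used earlier.

from transformers import MixtralForCausalLM

# Reload the upscaled checkpoint for further training or evaluation.
mixtral_model = MixtralForCausalLM.from_pretrained("path_to_save_mixtral_model")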

References

mixtral_upcycling

MixtralForCausalLMUpscalingAlgorithm

Bases: BaseAlgorithm

This class is responsible for upscaling a model to a MixtralForCausalLM. It inherits from the BaseAlgorithm class.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
class MixtralForCausalLMUpscalingAlgorithm(BaseAlgorithm):
    """
    This class is responsible for upscaling a model to a MixtralForCausalLM.
    It inherits from the BaseAlgorithm class.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "num_experts": "num_experts",
        "experts_per_token": "experts_per_token",
        "save_checkpoint": "save_checkpoint",
    }

    def __init__(
        self,
        num_experts: int,
        experts_per_token: int,
        save_checkpoint: str,
        **kwargs,
    ):
        """
        Initialize the MixtralForCausalLMUpscalingAlgorithm.

        Args:
            num_experts (int): The number of experts in the Mixtral model.
            experts_per_token (int): The number of experts per token.
            save_checkpoint (str): The path to save the checkpoint.
            **kwargs: Additional keyword arguments.
        """
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.save_checkpoint = save_checkpoint
        super().__init__(**kwargs)

    @torch.no_grad()
    def _run(
        self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
    ) -> MixtralForCausalLM:
        """
        Internal method to run the upscaling process.

        Args:
            modelpool (BaseModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.

        Returns:
            MixtralForCausalLM: The upscaled model.
        """
        if isinstance(modelpool, BaseModelPool):
            assert modelpool.has_pretrained, "ModelPool must have pretrained model."
            pretrained_model = modelpool.load_model("_pretrained_")
        elif isinstance(modelpool, (LlamaForCausalLM, MistralForCausalLM)):
            pretrained_model = modelpool
        else:
            raise ValueError("Invalid modelpool type")

        mixtral_config = _convert_config_to_mixtral(
            pretrained_model.config,
            self.config.num_experts,
            self.config.experts_per_token,
        )

        with ContextManagers([no_init_weights(True)]):
            for _ in tqdm(range(1), desc="Initializing Mixtral model"):
                mixtral_model = MixtralForCausalLM(mixtral_config)
        upscale_to_mixtral_for_causal_lm(pretrained_model, mixtral_model)

        return mixtral_model

    @torch.no_grad()
    def run(
        self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
    ) -> MixtralForCausalLM:
        """
        Runs the upscaling process.

        Args:
            modelpool (ModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.

        Returns:
            MixtralForCausalLM: The upscaled model.
        """
        mixtral_model = self._run(modelpool)

        if self.config.get("save_checkpoint", None) is not None:
            mixtral_model.save_pretrained(self.config.save_checkpoint)
        return mixtral_model
__init__(num_experts, experts_per_token, save_checkpoint, **kwargs)

Initialize the MixtralForCausalLMUpscalingAlgorithm.

Parameters:

  • num_experts (int) –

    The number of experts in the Mixtral model.

  • experts_per_token (int) –

    The number of experts per token.

  • save_checkpoint (str) –

    The path to save the checkpoint.

  • **kwargs

    Additional keyword arguments.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
def __init__(
    self,
    num_experts: int,
    experts_per_token: int,
    save_checkpoint: str,
    **kwargs,
):
    """
    Initialize the MixtralForCausalLMUpscalingAlgorithm.

    Args:
        num_experts (int): The number of experts in the Mixtral model.
        experts_per_token (int): The number of experts per token.
        save_checkpoint (str): The path to save the checkpoint.
        **kwargs: Additional keyword arguments.
    """
    self.num_experts = num_experts
    self.experts_per_token = experts_per_token
    self.save_checkpoint = save_checkpoint
    super().__init__(**kwargs)
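
For reference, a minimal instantiation matching the signature above; the values are illustrative, and save_checkpoint can instead be set to a directory path to save the upscaled model automatically.

algorithm = MixtralForCausalLMUpscalingAlgorithm(
    num_experts=4,         # number of experts in each MoE layer
    experts_per_token=2,   # top-k experts activated per token
    save_checkpoint=None,  # or a directory path to save the upscaled model
)
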
run(modelpool)

Runs the upscaling process.

Parameters:

  • modelpool (ModelPool | LlamaForCausalLM | MistralForCausalLM) –

    The model to be upscaled.

Returns:

  • MixtralForCausalLM ( MixtralForCausalLM ) –

    The upscaled model.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
@torch.no_grad()
def run(
    self, modelpool: BaseModelPool | LlamaForCausalLM | MistralForCausalLM
) -> MixtralForCausalLM:
    """
    Runs the upscaling process.

    Args:
        modelpool (ModelPool | LlamaForCausalLM | MistralForCausalLM): The model to be upscaled.

    Returns:
        MixtralForCausalLM: The upscaled model.
    """
    mixtral_model = self._run(modelpool)

    if self.config.get("save_checkpoint", None) is not None:
        mixtral_model.save_pretrained(self.config.save_checkpoint)
    return mixtral_model
MixtralUpscalingAlgorithm

Bases: BaseAlgorithm

This class is responsible for upscaling a model to a MixtralModel. It inherits from the BaseAlgorithm class.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
class MixtralUpscalingAlgorithm(BaseAlgorithm):
    """
    This class is responsible for upscaling a model to a MixtralModel.
    It inherits from the BaseAlgorithm class.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "num_experts": "num_experts",
        "experts_per_token": "experts_per_token",
        "save_checkpoint": "save_checkpoint",
    }

    def __init__(
        self,
        num_experts: int,
        experts_per_token: int,
        save_checkpoint: str,
        **kwargs,
    ):
        """
        Initialize the MixtralUpscalingAlgorithm.

        Args:
            num_experts (int): The number of experts in the Mixtral model.
            experts_per_token (int): The number of experts per token.
            save_checkpoint (str): The path to save the checkpoint.
            **kwargs: Additional keyword arguments.
        """
        self.num_experts = num_experts
        self.experts_per_token = experts_per_token
        self.save_checkpoint = save_checkpoint
        super().__init__(**kwargs)

    @torch.no_grad()
    def _run(
        self, modelpool: BaseModelPool | LlamaModel | MistralModel
    ) -> MixtralModel:
        """
        Internal method to run the upscaling process.

        Args:
            modelpool (BaseModelPool | LlamaModel | MistralModel): The model to be upscaled.

        Returns:
            MixtralModel: The upscaled model.
        """
        if isinstance(modelpool, BaseModelPool):
            assert modelpool.has_pretrained, "ModelPool must have pretrained model."
            pretrained_model = modelpool.load_model("_pretrained_")
        elif isinstance(modelpool, (LlamaModel, MistralModel)):
            pretrained_model = modelpool
        else:
            raise ValueError("Invalid modelpool type")

        mixtral_config = _convert_config_to_mixtral(
            pretrained_model.config,
            self.config.num_experts,
            self.config.experts_per_token,
        )

        with ContextManagers([no_init_weights(True)]):
            for _ in tqdm(range(1), desc="Initializing Mixtral model"):
                mixtral_model = MixtralModel(mixtral_config)
        upscale_to_mixtral_model(pretrained_model, mixtral_model)

        return mixtral_model

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | LlamaModel | MistralModel) -> MixtralModel:
        """
        Runs the upscaling process.

        Args:
            modelpool (ModelPool | LlamaModel | MistralModel): The model to be upscaled.

        Returns:
            MixtralModel: The upscaled model.
        """
        mixtral_model = self._run(modelpool)

        if self.config.get("save_checkpoint", None) is not None:
            mixtral_model.save_pretrained(self.config.save_checkpoint)
        return mixtral_model
__init__(num_experts, experts_per_token, save_checkpoint, **kwargs)

Initialize the MixtralUpscalingAlgorithm.

Parameters:

  • num_experts (int) –

    The number of experts in the Mixtral model.

  • experts_per_token (int) –

    The number of experts per token.

  • save_checkpoint (str) –

    The path to save the checkpoint.

  • **kwargs

    Additional keyword arguments.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
def __init__(
    self,
    num_experts: int,
    experts_per_token: int,
    save_checkpoint: str,
    **kwargs,
):
    """
    Initialize the MixtralUpscalingAlgorithm.

    Args:
        num_experts (int): The number of experts in the Mixtral model.
        experts_per_token (int): The number of experts per token.
        save_checkpoint (str): The path to save the checkpoint.
        **kwargs: Additional keyword arguments.
    """
    self.num_experts = num_experts
    self.experts_per_token = experts_per_token
    self.save_checkpoint = save_checkpoint
    super().__init__(**kwargs)
run(modelpool)

Runs the upscaling process.

Parameters:

  • modelpool (ModelPool | LlamaModel | MistralModel) –

    The model to be upscaled.

Returns:

  • MixtralModel ( MixtralModel ) –

    The upscaled model.

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | LlamaModel | MistralModel) -> MixtralModel:
    """
    Runs the upscaling process.

    Args:
        modelpool (ModelPool | LlamaModel | MistralModel): The model to be upscaled.

    Returns:
        MixtralModel: The upscaled model.
    """
    mixtral_model = self._run(modelpool)

    if self.config.get("save_checkpoint", None) is not None:
        mixtral_model.save_pretrained(self.config.save_checkpoint)
    return mixtral_model
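
A minimal usage sketch for this backbone-only variant, assuming a local Mistral checkpoint (the path is a placeholder): it upscales a MistralModel, i.e. the Transformer backbone without the language-modeling head, to a MixtralModel.

from transformers import MistralModel

from fusion_bench.method.mixture_of_experts.mixtral_upcycling import (
    MixtralUpscalingAlgorithm,
)

# Load the dense backbone (no language-modeling head).
backbone = MistralModel.from_pretrained("path_to_mistral_model")

# Upscale it to a sparsely activated MixtralModel backbone.
algorithm = MixtralUpscalingAlgorithm(
    num_experts=4, experts_per_token=2, save_checkpoint=None
)
mixtral_backbone = algorithm.run(backbone)
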
upscale_to_mixtral_for_causal_lm(input_model, output_model)

A helper function.

Upscales a LlamaForCausalLM or MistralForCausalLM to a MixtralForCausalLM.

Parameters:

  • input_model
    (LlamaForCausalLM | MistralForCausalLM) –

    The input model to be upscaled.

  • output_model
    (MixtralForCausalLM) –

    The output model where the upscaled weights will be loaded.

Returns:

  • None

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
def upscale_to_mixtral_for_causal_lm(
    input_model: LlamaForCausalLM | MistralForCausalLM, output_model: MixtralForCausalLM
):
    """
    A helper function.

    Upscales a LlamaForCausalLM or MistralForCausalLM to a MixtralForCausalLM.

    Args:
        input_model (LlamaForCausalLM | MistralForCausalLM): The input model to be upscaled.
        output_model (MixtralForCausalLM): The output model where the upscaled weights will be loaded.

    Returns:
        None
    """
    output_model.lm_head.load_state_dict(input_model.lm_head.state_dict())
    upscale_to_mixtral_model(input_model.model, output_model.model)
upscale_to_mixtral_model(input_model, output_model)

A helper function.

Upscales a LlamaModel or MistralModel to a MixtralModel.

Parameters:

  • input_model
    (LlamaModel | MistralModel) –

    The input model to be upscaled.

  • output_model
    (MixtralModel) –

    The output model where the upscaled weights will be loaded.

Returns:

  • None

Source code in fusion_bench/method/mixture_of_experts/mixtral_upcycling.py
def upscale_to_mixtral_model(
    input_model: LlamaModel | MistralModel, output_model: MixtralModel
):
    """
    A helper function.

    Upscales a LlamaModel or MistralModel to a MixtralModel.

    Args:
        input_model (LlamaModel | MistralModel): The input model to be upscaled.
        output_model (MixtralModel): The output model where the upscaled weights will be loaded.

    Returns:
        None
    """
    # copy the weights from the pretrained model
    output_model.embed_tokens.load_state_dict(input_model.embed_tokens.state_dict())
    output_model.norm.load_state_dict(input_model.norm.state_dict())
    for input_layer, output_layer in tqdm(
        zip(input_model.layers, output_model.layers),
        desc="Upscaling layers",
        total=len(input_model.layers),
    ):
        _upscale_decoder_layer(input_layer, output_layer)

  1. Sparse Upcycling: Training Mixture-of-Experts from Dense Checkpoints. http://arxiv.org/abs/2212.05055