MoE-based Model Merging

Code Integration

Here we provide instructions on how to use the fusion_bench command-line interface to merge models using a Mixture of Experts (MoE) approach.

The first code block is a YAML configuration file for the merging method. The name field specifies the name of the merging method. The experts_per_token field specifies how many experts are activated per token (top-k routing). The save_checkpoint field specifies the path where the merged model will be saved; leave it as null to skip saving. Note that the number of experts is not set here: it is determined at runtime from the size of the model pool (the run method in the reference below sets num_experts = len(modelpool)).

config/method/mixtral_moe_merging.yaml
name: mixtral_for_causal_lm_moe_merging

experts_per_token: 2
# path to save the merged model, if provided
save_checkpoint: null
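
These options map onto the configuration of the merged Mixtral model: the number of expert models in the pool becomes the number of local experts, and experts_per_token becomes the top-k routing value. Below is an illustrative sketch of that correspondence using the Hugging Face MixtralConfig; the actual construction is handled inside fusion_bench, and the values assume the four-expert pool defined next.

from transformers import MixtralConfig

# Illustrative only: with four expert models in the pool and `experts_per_token: 2`,
# the merged model's configuration roughly corresponds to the following.
config = MixtralConfig(
    num_local_experts=4,    # one expert slot per expert model in the pool
    num_experts_per_tok=2,  # `experts_per_token` from the method config
)
print(config.num_local_experts, config.num_experts_per_tok)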

The second code block is another YAML configuration file, this time for the model pool. The type field specifies the type of model pool to use. The models field lists the models to include in the pool; each entry has a name and a path, and the model is loaded from the path (see the sketch after the configuration below). The _pretrained_ entry is the base model that is upscaled to the Mixtral architecture, and the expert_* entries are substituted in as its experts.

config/modelpool/mixtral_moe_merging.yaml
type: AutoModelForCausalLMPool
# each model should have a name and a path, and the model is loaded from the path
# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
models:
  - name: _pretrained_
    path: path_to_your_pretrained_model
  - name: expert_1
    path: path_to_your_expert_model_1
  - name: expert_2
    path: path_to_your_expert_model_2
  - name: expert_3
    path: path_to_your_expert_model_3
  - name: expert_4
    path: path_to_your_expert_model_4
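
As the comment above notes, loading a model from this pool is essentially equivalent to calling AutoModelForCausalLM.from_pretrained on its path. A minimal sketch using the placeholder paths from the configuration:

from transformers import AutoModelForCausalLM

# What the model pool does for each entry: load the model from its `path`.
pretrained = AutoModelForCausalLM.from_pretrained("path_to_your_pretrained_model")
expert_1 = AutoModelForCausalLM.from_pretrained("path_to_your_expert_model_1")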

Finally, the third code block is a bash command that runs the fusion_bench command-line interface with the specified method, model pool, and task pool. The method argument selects the merging method, the modelpool argument selects the model pool, and the taskpool argument selects the task pool. Individual configuration values can also be overridden on the command line with dot notation, for example modelpool.models.0.path=<path> to point the _pretrained_ entry at a different checkpoint. Here a dummy task pool is used that does nothing but print the parameter counts of the merged model.

fusion_bench \
    method=mixtral_moe_merging \
    modelpool=mixtral_moe_merging \
    taskpool=dummy # this is a dummy taskpool that does nothing but print the parameter counts of the merged model

This guide provides a step-by-step process for merging models using the fusion_bench command-line interface. By following these instructions, you can merge your own models and save them for future use.
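
If save_checkpoint is set to a directory in the method configuration, the merged model is written with save_pretrained (see the run method in the reference below) and can be reloaded with the standard transformers API. A minimal sketch, using a hypothetical checkpoint directory:

from transformers import AutoModelForCausalLM, AutoTokenizer

# "outputs/mixtral_merged" is a hypothetical path; use whatever you set as `save_checkpoint`.
checkpoint = "outputs/mixtral_merged"
model = AutoModelForCausalLM.from_pretrained(checkpoint)  # resolves to a MixtralForCausalLM

# The merging step saves only the model, not a tokenizer; reuse the base model's tokenizer.
tokenizer = AutoTokenizer.from_pretrained("path_to_your_pretrained_model")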

References

mixtral_merging

MixtralForCausalLMMergingAlgorithm

Bases: MixtralForCausalLMUpscalingAlgorithm

Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py
class MixtralForCausalLMMergingAlgorithm(MixtralForCausalLMUpscalingAlgorithm):
    @torch.no_grad()
    def run(self, modelpool: ModelPool) -> MixtralForCausalLM:
        """
        Runs the merging process. It first upscales the models to MixtralForCausalLM,
        then substitutes the experts of the MixtralForCausalLM with the models from the modelpool.

        Args:
            modelpool (ModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralForCausalLM` or `LlamaForCausalLM`.

        Returns:
            MixtralForCausalLM: The merged model.
        """
        with open_dict(self.config):
            self.config.num_experts = len(modelpool)

        # firstly, we upscale the models to MixtralForCausalLM
        mixtral_model = super()._run(modelpool)

        # then we substitute the experts of the MixtralForCausalLM with the models from the modelpool
        for model_idx, model_name in enumerate(modelpool.model_names):
            expert_model: MistralForCausalLM | LlamaForCausalLM = modelpool.load_model(
                model_name
            )
            _substitute_experts(model_idx, expert_model.model, mixtral_model.model)

        if self.config.get("save_checkpoint", None) is not None:
            mixtral_model.save_pretrained(self.config.save_checkpoint)
        return mixtral_model
run(modelpool)

Runs the merging process. It first upscales the models to MixtralForCausalLM, then substitutes the experts of the MixtralForCausalLM with the models from the modelpool.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a MistralForCausalLM or LlamaForCausalLM.

Returns:

  • MixtralForCausalLM –

    The merged model.

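The helper _substitute_experts used in the source above is not reproduced in this reference. Conceptually, for each decoder layer it copies the expert model's MLP weights into one expert slot of the corresponding Mixtral sparse-MoE block. The following is a rough sketch of that idea (not the actual fusion_bench helper), assuming the standard Hugging Face Mistral/Llama and Mixtral module layouts:

import torch

@torch.no_grad()
def substitute_experts_sketch(expert_idx, expert_model, mixtral_model):
    """Copy the MLP weights of `expert_model` into expert slot `expert_idx`
    of every sparse-MoE block in `mixtral_model` (conceptual sketch only)."""
    for expert_layer, moe_layer in zip(expert_model.layers, mixtral_model.layers):
        moe_expert = moe_layer.block_sparse_moe.experts[expert_idx]
        # Mixtral's w1/w3/w2 correspond to Mistral/Llama's gate/up/down projections.
        moe_expert.w1.weight.copy_(expert_layer.mlp.gate_proj.weight)
        moe_expert.w3.weight.copy_(expert_layer.mlp.up_proj.weight)
        moe_expert.w2.weight.copy_(expert_layer.mlp.down_proj.weight)

In the causal-LM variant above, the bare decoder stacks are passed in (expert_model.model and mixtral_model.model), which is why the sketch iterates over .layers directly.
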
MixtralMoEMergingAlgorithm

Bases: MixtralUpscalingAlgorithm

This class is responsible for merging models into a MixtralModel.

Source code in fusion_bench/method/mixture_of_experts/mixtral_merging.py
class MixtralMoEMergingAlgorithm(MixtralUpscalingAlgorithm):
    """
    This class is responsible for merging models into a MixtralModel.
    """

    @torch.no_grad()
    def run(self, modelpool: ModelPool) -> MixtralModel:
        """
        Runs the merging process.

        Args:
            modelpool (ModelPool): The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a `MistralModel` or `LlamaModel`.

        Returns:
            MixtralModel: The merged model.
        """
        with open_dict(self.config):
            self.config.num_experts = len(modelpool)

        # firstly, we upscale the models to MixtralModel
        mixtral_model = super()._run(modelpool)

        # then we substitute the experts of the MixtralModel with the models from the modelpool
        for model_idx, model_name in enumerate(modelpool.model_names):
            expert_model: MistralModel | LlamaModel = modelpool.load_model(model_name)
            _substitute_experts(model_idx, expert_model, mixtral_model)

        if self.config.get("save_checkpoint", None) is not None:
            mixtral_model.save_pretrained(self.config.save_checkpoint)
        return mixtral_model
run(modelpool)

Runs the merging process.

Parameters:

  • modelpool (ModelPool) –

    The pool of models to be merged. Each model in the pool will be treated as an expert, and should be a MistralModel or LlamaModel.

Returns:

  • MixtralModel –

    The merged model.
