
Large Language Models (Causal LMs)

The CausalLMPool class provides a unified interface for managing and loading causal language models from the Hugging Face Transformers library with flexible configuration options.

Configuration

The CausalLMPool can be configured using YAML files. Here are the main configuration options:

Basic Configuration

_target_: fusion_bench.modelpool.CausalLMPool # (1)
models:
  _pretrained_: path_to_pretrained_model # (2)
  model_a: path_to_model_a
  model_b: path_to_model_b
model_kwargs: # (3)
  torch_dtype: bfloat16  # or float16, float32, etc.
tokenizer: path_to_tokenizer # (4)
  1. _target_ indicates the modelpool class to be instantiated (see the Python sketch after this list).
  2. _pretrained_, model_a, and model_b are the names of the models in the pool. If a plain string is given as the value, it is passed to AutoModelForCausalLM.from_pretrained to load the model.
  3. model_kwargs is a dictionary of keyword arguments passed to AutoModelForCausalLM.from_pretrained; these can be overridden by passing kwargs to the modelpool.load_model function.
  4. tokenizer indicates the tokenizer to be loaded. If a plain string is given, it is passed to AutoTokenizer.from_pretrained.
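
For reference, the same pool can also be constructed directly in Python. A minimal sketch using the placeholder paths from the YAML above:

from omegaconf import DictConfig

from fusion_bench.modelpool import CausalLMPool

# The paths below are placeholders; replace them with real model/tokenizer paths.
modelpool = CausalLMPool(
    models=DictConfig(
        {
            "_pretrained_": "path_to_pretrained_model",
            "model_a": "path_to_model_a",
            "model_b": "path_to_model_b",
        }
    ),
    tokenizer="path_to_tokenizer",
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),
)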

Special Model Names in FusionBench

Names starting and ending with "_" are reserved for special purposes in FusionBench. For example, _pretrained_ is a special model name that specifies the pre-trained model; it can be loaded by calling modelpool.load_pretrained_model() or modelpool.load_model("_pretrained_").

Basic Usage

Information about the Model Pool

Get all the model names in the model pool except the special model names:

>>> modelpool.model_names
['model_a', 'model_b']

Check if a pre-trained model is in the model pool:

>>> modelpool.has_pretrained
True

Get all the model names in the model pool, including the special model names:

>>> modelpool.all_model_names
['_pretrained_', 'model_a', 'model_b']

Loading and Saving Models and Tokenizers

Load a model from the model pool by model name:

>>> model_a = modelpool.load_model("model_a")

Load a model from the model pool, passing or overriding additional keyword arguments for the model loader:

>>> model_a_fp32 = modelpool.load_model("model_a", torch_dtype="float32")

Load the pre-trained model from the model pool:

>>> pretrained_model = modelpool.load_pretrained_model()
# or equivalently
>>> pretrained_model = modelpool.load_model("_pretrained_")

Load the pre-trained model or the first model in the model pool:

# if there is a pre-trained model in the model pool, then it will be loaded
# otherwise, the first model in the model pool will be loaded
>>> model = modelpool.load_pretrained_or_first_model()

Load the tokenizer from the model pool:

>>> tokenizer = modelpool.load_tokenizer()

Save a model with tokenizer:

# Save model with tokenizer
>>> modelpool.save_model(
    model=model,
    path="path/to/save",
    save_tokenizer=True,
    push_to_hub=False
)
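
The save_model method also accepts a model_dtype argument (see the reference at the end of this page) to cast the model to a given dtype before saving. A minimal sketch:

# Cast the model to float16 before saving, and save the tokenizer alongside it
>>> modelpool.save_model(
    model=model,
    path="path/to/save",
    save_tokenizer=True,
    model_dtype="float16"
)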

Advanced Configuration

You can also use a more detailed configuration with explicit model and tokenizer settings:

_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_:
    _target_: transformers.AutoModelForCausalLM # (1)
    pretrained_model_name_or_path: path_to_pretrained_model
  model_a:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_model_a
tokenizer:
  _target_: transformers.AutoTokenizer # (2)
  pretrained_model_name_or_path: path_to_tokenizer
model_kwargs:
  torch_dtype: bfloat16
  1. _target_ specifies the class or callable used to load the model. By setting _target_, you can use a custom model class or function; for example, you can use load_peft_causal_lm to load a PEFT model (see the sketch after this list). If a plain string is given as the model entry instead, it is passed to AutoModelForCausalLM.from_pretrained as in the basic configuration.
  2. _target_ specifies the class or callable used to load the tokenizer. If a plain string is given as the tokenizer entry instead, it is passed to AutoTokenizer.from_pretrained.
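
Because _target_ may be any callable that returns a model, a model config can also be passed directly to load_model. A minimal sketch with placeholder paths, using the load_peft_causal_lm helper described in the next section (modelpool is an existing CausalLMPool instance):

from omegaconf import DictConfig

# Placeholder paths for illustration only.
peft_model = modelpool.load_model(
    DictConfig(
        {
            "_target_": "fusion_bench.modelpool.causal_lm.load_peft_causal_lm",
            "base_model_path": "path/to/base/model",
            "peft_model_path": "path/to/peft/model",
        }
    )
)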

Working with PEFT Models

from fusion_bench.modelpool.causal_lm import load_peft_causal_lm

# Load a PEFT model
model = load_peft_causal_lm(
    base_model_path="path/to/base/model",
    peft_model_path="path/to/peft/model",
    torch_dtype="bfloat16",
    is_trainable=True,
    merge_and_unload=False
)
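
If merge_and_unload=True, the PEFT adapter weights are merged into the base model and a plain Transformers model is returned. A minimal sketch with placeholder paths:

# Merge the adapter into the base model and save the merged weights
merged_model = load_peft_causal_lm(
    base_model_path="path/to/base/model",
    peft_model_path="path/to/peft/model",
    torch_dtype="bfloat16",
    merge_and_unload=True,
)
merged_model.save_pretrained("path/to/merged/model")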

Configuration Examples

Single Model Configuration

config/modelpool/CausalLMPool/single_llama_model.yaml
_target_: fusion_bench.modelpool.CausalLMPool
_recursive_: false
# each model should have a name and a path, and the model is loaded from the path
# this is equivalent to `AutoModelForCausalLM.from_pretrained(path)`
models:
  _pretrained_:
    _target_: transformers.LlamaForCausalLM.from_pretrained
    pretrained_model_name_or_path: ${...base_model}
model_kwargs:
  torch_dtype: float16
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${..base_model}
base_model: decapoda-research/llama-7b-hf
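
Note that ${...base_model} and ${..base_model} are OmegaConf relative interpolations; both resolve to the top-level base_model key, the model entry simply using one more dot because it is nested one level deeper. A quick sketch to verify the resolution (assuming it is run from the FusionBench repository root):

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/modelpool/CausalLMPool/single_llama_model.yaml")
# Both interpolations resolve to the top-level `base_model` value.
print(OmegaConf.select(cfg, "models._pretrained_.pretrained_model_name_or_path"))
# decapoda-research/llama-7b-hf
print(OmegaConf.select(cfg, "tokenizer.pretrained_model_name_or_path"))
# decapoda-research/llama-7b-hf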

Multiple Models Configuration

Here we use models from MergeBench as an example.

config/modelpool/CausalLMPool/mergebench/gemma-2-2b.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: google/gemma-2-2b
  instruction: MergeBench/gemma-2-2b_instruction
  math: MergeBench/gemma-2-2b_math
  coding: MergeBench/gemma-2-2b_coding
  multilingual: MergeBench/gemma-2-2b_multilingual
  safety: MergeBench/gemma-2-2b_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: google/gemma-2-2b
config/modelpool/CausalLMPool/mergebench/gemma-2-2b-it.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: google/gemma-2-2b-it
  instruction: MergeBench/gemma-2-2b-it_instruction
  math: MergeBench/gemma-2-2b-it_math
  coding: MergeBench/gemma-2-2b-it_coding
  multilingual: MergeBench/gemma-2-2b-it_multilingual
  safety: MergeBench/gemma-2-2b-it_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: google/gemma-2-2b-it
config/modelpool/CausalLMPool/mergebench/gemma-2-9b.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: google/gemma-2-9b
  instruction: MergeBench/gemma-2-9b_instruction
  math: MergeBench/gemma-2-9b_math
  coding: MergeBench/gemma-2-9b_coding
  multilingual: MergeBench/gemma-2-9b_multilingual
  safety: MergeBench/gemma-2-9b_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: google/gemma-2-9b
config/modelpool/CausalLMPool/mergebench/gemma-2-9b-it.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: google/gemma-2-9b-it
  instruction: MergeBench/gemma-2-9b-it_instruction
  math: MergeBench/gemma-2-9b-it_math
  coding: MergeBench/gemma-2-9b-it_coding
  multilingual: MergeBench/gemma-2-9b-it_multilingual
  safety: MergeBench/gemma-2-9b-it_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: google/gemma-2-9b-it
config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.1-8B
  instruction: MergeBench/Llama-3.1-8B_instruction
  math: MergeBench/Llama-3.1-8B_math
  coding: MergeBench/Llama-3.1-8B_coding
  multilingual: MergeBench/Llama-3.1-8B_multilingual
  safety: MergeBench/Llama-3.1-8B_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.1-8B
config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.1-8B-Instruct
  instruction: MergeBench/Llama-3.1-8B-Instruct_instruction
  math: MergeBench/Llama-3.1-8B-Instruct_math
  coding: MergeBench/Llama-3.1-8B-Instruct_coding
  multilingual: MergeBench/Llama-3.1-8B-Instruct_multilingual
  safety: MergeBench/Llama-3.1-8B-Instruct_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.1-8B-Instruct
config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.2-3B
  instruction: MergeBench/Llama-3.2-3B_instruction
  math: MergeBench/Llama-3.2-3B_math
  coding: MergeBench/Llama-3.2-3B_coding
  multilingual: MergeBench/Llama-3.2-3B_multilingual
  safety: MergeBench/Llama-3.2-3B_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.2-3B
config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.2-3B-Instruct
  instruction: MergeBench/Llama-3.2-3B-Instruct_instruction
  math: MergeBench/Llama-3.2-3B-Instruct_math
  coding: MergeBench/Llama-3.2-3B-Instruct_coding
  multilingual: MergeBench/Llama-3.2-3B-Instruct_multilingual
  safety: MergeBench/Llama-3.2-3B-Instruct_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.2-3B-Instruct

Merge Large Language Models with FusionBench

Merge gemma-2-2b models with simple average:

fusion_bench method=simple_average modelpool=CausalLMPool/mergebench/gemma-2-2b

Merge gemma-2-2b models with Task Arithmetic:

fusion_bench method=task_arithmetic modelpool=CausalLMPool/mergebench/gemma-2-2b

Merge Llama-3.1-8B models with Ties-Merging:

fusion_bench method=ties_merging modelpool=CausalLMPool/mergebench/Llama-3.1-8B

Merge Llama-3.1-8B-Instruct models with Dare-Ties, with 70% sparsity:

fusion_bench method=dare/ties_merging method.sparsity_ratio=0.7 modelpool=CausalLMPool/mergebench/Llama-3.1-8B-Instruct

Special Features

CausalLMBackbonePool

The CausalLMBackbonePool is a specialized version of CausalLMPool that returns only the transformer layers of the model. This is useful when you need to work with the model's backbone architecture directly.

from fusion_bench.modelpool import CausalLMBackbonePool

backbone_pool = CausalLMBackbonePool.from_config(config)
layers = backbone_pool.load_model("model_a")  # Returns model.layers

References

CausalLMPool

Bases: BaseModelPool

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
class CausalLMPool(BaseModelPool):
    _config_mapping = BaseModelPool._config_mapping | {
        "_tokenizer": "tokenizer",
        "_model_kwargs": "model_kwargs",
    }

    def __init__(
        self,
        models,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        # process `model_kwargs`
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )

    @override
    def load_model(
        self,
        model_name_or_config: str | DictConfig,
        *args,
        **kwargs,
    ) -> PreTrainedModel:
        """
        Example of YAML config:

        ```yaml
        models:
          _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
          model_a: path_to_model_a
          model_b: path_to_model_b
        ```

        or equivalently,

        ```yaml
        models:
          _pretrained_:
            _target_: transformers.AutoModelForCausalLM # any callable that returns a model
            pretrained_model_name_or_path: path_to_pretrained_model
          model_a:
            _target_: transformers.AutoModelForCausalLM
            pretrained_model_name_or_path: path_to_model_a
          model_b:
            _target_: transformers.AutoModelForCausalLM
            pretrained_model_name_or_path: path_to_model_b
        ```
        """
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)

        if isinstance(model_name_or_config, str):
            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
            if model_name_or_config in self._models.keys():
                model_config = self._models[model_name_or_config]
                if isinstance(model_config, str):
                    # model_config is a string
                    model = AutoModelForCausalLM.from_pretrained(
                        model_config,
                        *args,
                        **model_kwargs,
                    )
                    return model
        elif isinstance(model_name_or_config, (DictConfig, Dict)):
            model_config = model_name_or_config

        model = instantiate(model_config, *args, **model_kwargs)
        return model

    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
        """
        Example of YAML config:

        ```yaml
        tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained
        ```

        or equivalently,

        ```yaml
        tokenizer:
          _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
          pretrained_model_name_or_path: google/gemma-2-2b-it
        ```

        Returns:
            PreTrainedTokenizer: The tokenizer.
        """
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        log.info("Loading tokenizer.", stacklevel=2)
        if isinstance(self._tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
        else:
            tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer

    @override
    def save_model(
        self,
        model: PreTrainedModel,
        path: str,
        push_to_hub: bool = False,
        model_dtype: Optional[str] = None,
        save_tokenizer: bool = False,
        tokenizer_kwargs=None,
        **kwargs,
    ):
        """
        Save the model to the specified path.

        Args:
            model (PreTrainedModel): The model to be saved.
            path (str): The path where the model will be saved.
            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
            model_dtype (str, optional): If given, cast the model to this dtype before saving. Defaults to None.
            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
            tokenizer_kwargs (dict, optional): Additional keyword arguments passed to `load_tokenizer` when saving the tokenizer. Defaults to None.
            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
        """
        path = os.path.expanduser(path)
        if save_tokenizer:
            if tokenizer_kwargs is None:
                tokenizer_kwargs = {}
            # load the tokenizer
            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
            tokenizer.save_pretrained(
                path,
                push_to_hub=push_to_hub,
            )
        if model_dtype is not None:
            model.to(dtype=parse_dtype(model_dtype))
        model.save_pretrained(
            path,
            push_to_hub=push_to_hub,
            **kwargs,
        )
load_model(model_name_or_config, *args, **kwargs)

Example of YAML config:

models:
  _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
  model_a: path_to_model_a
  model_b: path_to_model_b

or equivalently,

models:
  _pretrained_:
    _target_: transformers.AutoModelForCausalLM # any callable that returns a model
    pretrained_model_name_or_path: path_to_pretrained_model
  model_a:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_model_a
  model_b:
    _target_: transformers.AutoModelForCausalLM
    pretrained_model_name_or_path: path_to_model_b
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@override
def load_model(
    self,
    model_name_or_config: str | DictConfig,
    *args,
    **kwargs,
) -> PreTrainedModel:
    """
    Example of YAML config:

    ```yaml
    models:
      _pretrained_: path_to_pretrained_model # if a plain string, it will be passed to AutoModelForCausalLM.from_pretrained
      model_a: path_to_model_a
      model_b: path_to_model_b
    ```

    or equivalently,

    ```yaml
    models:
      _pretrained_:
        _target_: transformers.AutoModelForCausalLM # any callable that returns a model
        pretrained_model_name_or_path: path_to_pretrained_model
      model_a:
        _target_: transformers.AutoModelForCausalLM
        pretrained_model_name_or_path: path_to_model_a
      model_b:
        _target_: transformers.AutoModelForCausalLM
        pretrained_model_name_or_path: path_to_model_b
    ```
    """
    model_kwargs = deepcopy(self._model_kwargs)
    model_kwargs.update(kwargs)

    if isinstance(model_name_or_config, str):
        log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
        if model_name_or_config in self._models.keys():
            model_config = self._models[model_name_or_config]
            if isinstance(model_config, str):
                # model_config is a string
                model = AutoModelForCausalLM.from_pretrained(
                    model_config,
                    *args,
                    **model_kwargs,
                )
                return model
    elif isinstance(model_name_or_config, (DictConfig, Dict)):
        model_config = model_name_or_config

    model = instantiate(model_config, *args, **model_kwargs)
    return model
load_tokenizer(*args, **kwargs)

Example of YAML config:

tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained

or equivalently,

tokenizer:
  _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
  pretrained_model_name_or_path: google/gemma-2-2b-it

Returns:

  • PreTrainedTokenizer ( PreTrainedTokenizer ) –

    The tokenizer.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
    """
    Example of YAML config:

    ```yaml
    tokenizer: google/gemma-2-2b-it # if a plain string, it will be passed to AutoTokenizer.from_pretrained
    ```

    or equivalently,

    ```yaml
    tokenizer:
      _target_: transformers.AutoTokenizer # any callable that returns a tokenizer
      pretrained_model_name_or_path: google/gemma-2-2b-it
    ```

    Returns:
        PreTrainedTokenizer: The tokenizer.
    """
    assert self._tokenizer is not None, "Tokenizer is not defined in the config"
    log.info("Loading tokenizer.", stacklevel=2)
    if isinstance(self._tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(self._tokenizer, *args, **kwargs)
    else:
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
    return tokenizer
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, **kwargs)

Save the model to the specified path.

Parameters:

  • model
    (PreTrainedModel) –

    The model to be saved.

  • path
    (str) –

    The path where the model will be saved.

  • push_to_hub
    (bool, default: False ) –

    Whether to push the model to the Hugging Face Hub. Defaults to False.

  • model_dtype
    (Optional[str], default: None ) –

    If given, cast the model to this dtype before saving.

  • save_tokenizer
    (bool, default: False ) –

    Whether to save the tokenizer along with the model. Defaults to False.

  • tokenizer_kwargs
    (default: None ) –

    Additional keyword arguments passed to load_tokenizer when saving the tokenizer.

  • **kwargs

    Additional keyword arguments passed to the save_pretrained method.

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
@override
def save_model(
    self,
    model: PreTrainedModel,
    path: str,
    push_to_hub: bool = False,
    model_dtype: Optional[str] = None,
    save_tokenizer: bool = False,
    tokenizer_kwargs=None,
    **kwargs,
):
    """
    Save the model to the specified path.

    Args:
        model (PreTrainedModel): The model to be saved.
        path (str): The path where the model will be saved.
        push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
        model_dtype (str, optional): If given, cast the model to this dtype before saving. Defaults to None.
        save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
        tokenizer_kwargs (dict, optional): Additional keyword arguments passed to `load_tokenizer` when saving the tokenizer. Defaults to None.
        **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
    """
    path = os.path.expanduser(path)
    if save_tokenizer:
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        # load the tokenizer
        tokenizer = self.load_tokenizer(**tokenizer_kwargs)
        tokenizer.save_pretrained(
            path,
            push_to_hub=push_to_hub,
        )
    if model_dtype is not None:
        model.to(dtype=parse_dtype(model_dtype))
    model.save_pretrained(
        path,
        push_to_hub=push_to_hub,
        **kwargs,
    )

CausalLMBackbonePool

Bases: CausalLMPool

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
class CausalLMBackbonePool(CausalLMPool):
    def load_model(
        self, model_name_or_config: str | DictConfig, *args, **kwargs
    ) -> Module:
        model: AutoModelForCausalLM = super().load_model(
            model_name_or_config, *args, **kwargs
        )
        return model.model.layers

load_peft_causal_lm(base_model_path, peft_model_path, torch_dtype='bfloat16', is_trainable=True, merge_and_unload=False)

Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
def load_peft_causal_lm(
    base_model_path: str,
    peft_model_path: str,
    torch_dtype: str = "bfloat16",
    is_trainable: bool = True,
    merge_and_unload: bool = False,
):
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch_dtype
    )
    model = peft.PeftModel.from_pretrained(
        base_model,
        peft_model_path,
        is_trainable=is_trainable,
    )
    if merge_and_unload:
        model = model.merge_and_unload()
    return model