
fusion_bench.models

Task and Layer-wise Merging (AdaMerging)

layer_wise_fusion

LayerWiseMergedModel

Bases: Module, Generic[TorchModelType]

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
    _merged_state_dict: StateDictType = None

    def __init__(
        self,
        layer_wise_weight: Tensor,
        pretrained_model: TorchModelType,
        finetuned_models: List[TorchModelType],
        clamp_weights: bool = True,
        tie_weights: bool = False,
        strict: bool = True,
        sparsity_ratio: Optional[float] = None,
        normalized_merging_weights: bool = False,
    ):
        R"""
        This class wraps a pretrained model and a list of finetuned models, and merges the weights of the finetuned models into the pretrained model using layer-wise fusion.

        Reference:

            (ICLR 2024) Yang E, Wang Z, Shen L, et al. Adamerging: Adaptive model merging for multi-task learning. https://arxiv.org/pdf/2310.02575

        Args:
            layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
            pretrained_model (nn.Module): The pretrained model to merge the weights into.
            finetuned_models (List[nn.Module]): A list of finetuned models to merge the weights from. These should have the same architecture as the pretrained model. We use these models to compute the task vectors.
            clamp_weights (bool, optional): If True, the layer-wise weights will be clamped to [0, 1]. Defaults to True.
            tie_weights (bool, optional): This option passes the `tie_weights` argument to the `functional_call` function. Defaults to False.
            strict (bool, optional): This option passes the `strict` argument to the `functional_call` function. Defaults to True.
            sparsity_ratio (float, optional): If `sparsity_ratio` is provided, the task vectors will be pruned before merging. A high sparsity ratio can reduce memory usage during merging.
            normalized_merging_weights (bool, optional): If True, the layer-wise weights will be normalized for each layer, so that the sum of weights across models for each layer is 1. Defaults to False.
        """
        super().__init__()
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.sparsity_ratio = sparsity_ratio
        self.normalized_merging_weights = normalized_merging_weights

        self.merge_weight = nn.Parameter(layer_wise_weight, requires_grad=True)

        for name, param in pretrained_model.named_parameters():
            if not param.requires_grad:
                for m in finetuned_models:
                    del_attr(m, name.split("."))
            else:
                for m in finetuned_models:
                    get_attr(m, name.split(".")).data = (
                        get_attr(m, name.split(".")) - param
                    )

        self.pretrained_model = pretrained_model.requires_grad_(False)
        for m in finetuned_models:
            m.requires_grad_(False)

        self.task_vectors = nn.ModuleList(finetuned_models)

        # if `sparsity_ratio` is given, prune the task vectors.
        if sparsity_ratio is not None:
            from fusion_bench.method.pruning.prune_utils import (
                unstructured_magnitude_prune_,
            )

            for name, param in self.task_vectors.named_parameters():
                if param.dim() != 2:
                    continue
                print(f"pruning {name}")
                pruned_param = unstructured_magnitude_prune_(
                    param.data.clone(), torch.abs, sparsity_ratio=sparsity_ratio
                )
                set_attr(
                    self.task_vectors,
                    name.split("."),
                    nn.Parameter(pruned_param.to_sparse(), requires_grad=False),
                )

    @property
    def forward_model(self):
        return functools.partial(
            functional_call,
            self.pretrained_model,
            self._merged_state_dict,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        self.merge_weights(task_vector_mask=task_vector_mask)
        self.pretrained_model.load_state_dict(self._merged_state_dict)
        return self.pretrained_model

    def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merges the weights of the model.
        Call this after each update step.
        """
        if self.clamp_weights:
            layer_wise_weight = self.merge_weight.clamp(0, 1)
        else:
            layer_wise_weight = self.merge_weight
        if self.normalized_merging_weights:
            # normalize the weights for each layer, so that the sum of weights across models for each layer is 1.
            layer_wise_weight = layer_wise_weight.softmax(dim=0)

        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        # shape of layer_wise_weight: (num_models, num_layers)
        for weight, task_vector in zip(layer_wise_weight, self.task_vectors):
            assert len(list(task_vector.named_parameters())) == weight.size(0)
            if task_vector_mask is not None:
                weight = [
                    w * task_vector_mask[name]
                    for w, (name, param) in zip(weight, task_vector.named_parameters())
                ]
            for w, (name, param) in zip(weight, task_vector.named_parameters()):
                state_dict[name] = state_dict[name] + param * w
        self._merged_state_dict = state_dict

        return state_dict

    def forward(self, *args, **kwargs):
        if self._merged_state_dict is None:
            self.merge_weights()
        return self.forward_model(args=args, kwargs=kwargs)
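
A minimal usage sketch, assuming the wrapper and its helper are importable from fusion_bench.models.wrappers.layer_wise_fusion (matching the source path above) and using toy nn.Linear models; each nn.Linear has two trainable parameter tensors, so num_layers is 2 here:

import torch
import torch.nn as nn
from fusion_bench.models.wrappers.layer_wise_fusion import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)

# Toy models: nn.Linear has two trainable tensors (weight, bias), so num_layers = 2
pretrained = nn.Linear(10, 5)
finetuned = [nn.Linear(10, 5), nn.Linear(10, 5)]

layer_wise_weight = get_layer_wise_weights(num_models=2, num_layers=2)
merged_model = LayerWiseMergedModel(
    layer_wise_weight=layer_wise_weight,
    pretrained_model=pretrained,
    finetuned_models=finetuned,
)

# Merge with the current weights and run a forward pass;
# only merged_model.merge_weight receives gradients.
merged_model.merge_weights()
output = merged_model(torch.randn(8, 10))

# Export a plain nn.Module with the merged parameters loaded.
final_model = merged_model.merge_and_unload()
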
__init__(layer_wise_weight, pretrained_model, finetuned_models, clamp_weights=True, tie_weights=False, strict=True, sparsity_ratio=None, normalized_merging_weights=False)

This class wraps a pretrained model and a list of finetuned models, and merges the weights of the finetuned models into the pretrained model using layer-wise fusion.

Reference:

(ICLR 2024) Yang E, Wang Z, Shen L, et al. Adamerging: Adaptive model merging for multi-task learning. https://arxiv.org/pdf/2310.02575

Parameters:

  • layer_wise_weight (Tensor) –

    A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.

  • pretrained_model (Module) –

    The pretrained model to merge the weights into.

  • finetuned_models (List[Module]) –

    A list of finetuned models to merge the weights from. These should have the same architecture as the pretrained model. We use these models to compute the task vectors.

  • clamp_weights (bool, default: True ) –

    If True, the layer-wise weights will be clamped to [0, 1]. Defaults to True.

  • tie_weights (bool, default: False ) –

    This option passes the tie_weights argument to the functional_call function. Defaults to False.

  • strict (bool, default: True ) –

    This option passes the strict argument to the functional_call function. Defaults to True.

  • sparsity_ratio (float, default: None ) –

    If sparsity_ratio is provided, the task vectors will be pruned before merging. A high sparsity ratio can reduce memory usage during merging.

  • normalized_merging_weights (bool, default: False ) –

    If True, the layer-wise weights will be normalized for each layer, so that the sum of weights across models for each layer is 1. Defaults to False.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def __init__(
    self,
    layer_wise_weight: Tensor,
    pretrained_model: TorchModelType,
    finetuned_models: List[TorchModelType],
    clamp_weights: bool = True,
    tie_weights: bool = False,
    strict: bool = True,
    sparsity_ratio: Optional[float] = None,
    normalized_merging_weights: bool = False,
):
    R"""
    This class wraps a pretrained model and a list of finetuned models, and merges the weights of the finetuned models into the pretrained model using layer-wise fusion.

    Reference:

        (ICLR 2024) Yang E, Wang Z, Shen L, et al. Adamerging: Adaptive model merging for multi-task learning. https://arxiv.org/pdf/2310.02575

    Args:
        layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
        pretrained_model (nn.Module): The pretrained model to merge the weights into.
        finetuned_models (List[nn.Module]): A list of finetuned models to merge the weights from. These should have the same architecture as the pretrained model. We use these models to compute the task vectors.
        clamp_weights (bool, optional): If True, the layer-wise weights will be clamped to [0, 1]. Defaults to True.
        tie_weights (bool, optional): This option passes the `tie_weights` argument to the `functional_call` function. Defaults to False.
        strict (bool, optional): This option passes the `strict` argument to the `functional_call` function. Defaults to True.
        sparsity_ratio (float, optional): If `sparsity_ratio` is provided, the task vectors will be pruned before merging. A high sparsity ratio can reduce memory usage during merging.
        normalized_merging_weights (bool, optional): If True, the layer-wise weights will be normalized for each layer, so that the sum of weights across models for each layer is 1. Defaults to False.
    """
    super().__init__()
    self.clamp_weights = clamp_weights
    self.tie_weights = tie_weights
    self.strict = strict
    self.sparsity_ratio = sparsity_ratio
    self.normalized_merging_weights = normalized_merging_weights

    self.merge_weight = nn.Parameter(layer_wise_weight, requires_grad=True)

    for name, param in pretrained_model.named_parameters():
        if not param.requires_grad:
            for m in finetuned_models:
                del_attr(m, name.split("."))
        else:
            for m in finetuned_models:
                get_attr(m, name.split(".")).data = (
                    get_attr(m, name.split(".")) - param
                )

    self.pretrained_model = pretrained_model.requires_grad_(False)
    for m in finetuned_models:
        m.requires_grad_(False)

    self.task_vectors = nn.ModuleList(finetuned_models)

    # if `sparsity_ratio` is given, prune the task vectors.
    if sparsity_ratio is not None:
        from fusion_bench.method.pruning.prune_utils import (
            unstructured_magnitude_prune_,
        )

        for name, param in self.task_vectors.named_parameters():
            if param.dim() != 2:
                continue
            print(f"pruning {name}")
            pruned_param = unstructured_magnitude_prune_(
                param.data.clone(), torch.abs, sparsity_ratio=sparsity_ratio
            )
            set_attr(
                self.task_vectors,
                name.split("."),
                nn.Parameter(pruned_param.to_sparse(), requires_grad=False),
            )
merge_weights(task_vector_mask=None)

Merges the weights of the model. Call this after each update step.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
    """
    Merges the weights of the model.
    Call this after each update step.
    """
    if self.clamp_weights:
        layer_wise_weight = self.merge_weight.clamp(0, 1)
    else:
        layer_wise_weight = self.merge_weight
    if self.normalized_merging_weights:
        # normalize the weights for each layer, so that the sum of weights across models for each layer is 1.
        layer_wise_weight = layer_wise_weight.softmax(dim=0)

    state_dict = self.pretrained_model.state_dict(keep_vars=True)
    # shape of layer_wise_weight: (num_models, num_layers)
    for weight, task_vector in zip(layer_wise_weight, self.task_vectors):
        assert len(list(task_vector.named_parameters())) == weight.size(0)
        if task_vector_mask is not None:
            weight = [
                w * task_vector_mask[name]
                for w, (name, param) in zip(weight, task_vector.named_parameters())
            ]
        for w, (name, param) in zip(weight, task_vector.named_parameters()):
            state_dict[name] = state_dict[name] + param * w
    self._merged_state_dict = state_dict

    return state_dict

fix_other_parts(module)

Sets all parameters in the module to not require gradients, except for the merge weights in LayerWiseMergedModel instances.

Parameters:

  • module (Module) –

    The module to process.

Returns:

  • nn.Module: The module with updated parameter requirements.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def fix_other_parts(module: nn.Module):
    """
    Sets all parameters in the module to not require gradients, except for the merge weights
    in `LayerWiseMergedModel` instances.

    Args:
        module (nn.Module): The module to process.

    Returns:
        nn.Module: The module with updated parameter requirements.
    """
    module.requires_grad_(False)
    for submodule in module.modules():
        if isinstance(submodule, LayerWiseMergedModel):
            submodule.merge_weight.requires_grad_(True)
    return module
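
A brief sketch (merged_wrapper is a hypothetical module that contains a LayerWiseMergedModel submodule); after fix_other_parts, only the merge weights remain trainable, so an optimizer can be built from the trainable parameters alone:

import torch

# Freeze everything except the learnable merge weights of nested LayerWiseMergedModel instances.
merged_wrapper = fix_other_parts(merged_wrapper)
optimizer = torch.optim.Adam(
    [p for p in merged_wrapper.parameters() if p.requires_grad], lr=1e-3
)
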

fuse_weights(layer_wise_weight, state_dicts)

Fuse the weights of multiple models using layer-wise fusion.

Parameters:

  • layer_wise_weight (Tensor) –

    A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.

  • state_dicts (List[StateDict]) –

    A list of state dictionaries, one for each model.

Returns:

  • StateDictType

    A dictionary mapping each weight tensor key to the fused weight tensor.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def fuse_weights(
    layer_wise_weight: Tensor, state_dicts: List[StateDictType]
) -> StateDictType:
    """
    Fuse the weights of multiple models using layer-wise fusion.

    Args:
        layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
        state_dicts (List[StateDict]): A list of state dictionaries, one for each model.

    Returns:
        A dictionary mapping each weight tensor key to the fused weight tensor.
    """
    num_models = len(state_dicts)
    num_layers = len(state_dicts[0])
    assert layer_wise_weight.shape == (
        num_models,
        num_layers,
    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})"
    return {
        k: _fuse_weights(
            layer_wise_weight[:, i], [state_dict[k] for state_dict in state_dicts]
        )
        for i, k in enumerate(state_dicts[0].keys())
    }
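
An illustrative sketch, assuming fuse_weights and get_layer_wise_weights are importable from fusion_bench.models.wrappers.layer_wise_fusion (matching the source path above); the state dicts must share identical keys, and num_layers must equal the number of entries in each state dict:

import torch.nn as nn
from fusion_bench.models.wrappers.layer_wise_fusion import fuse_weights, get_layer_wise_weights

state_dicts = [nn.Linear(10, 5).state_dict(), nn.Linear(10, 5).state_dict()]
layer_wise_weight = get_layer_wise_weights(num_models=2, num_layers=2)
fused_state_dict = fuse_weights(layer_wise_weight, state_dicts)  # keys: 'weight', 'bias'
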

get_layer_wise_weights(num_models, num_layers, init_values=None, dtype=torch.float32)

Return a tensor of layer-wise weights for the given number of models and layers.

Parameters:

  • num_models (int) –

    The number of models to fuse.

  • num_layers (int) –

    The number of layers in each model.

  • init_values (float, default: None ) –

    The initial value for each weight. Defaults to 1.0 / num_models.

  • dtype (dtype, default: float32 ) –

    The dtype of the weights. This should be the same as the model dtype.

Returns:

  • Tensor

    A tensor of shape (num_models, num_layers) containing the layer-wise weights.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def get_layer_wise_weights(
    num_models: int,
    num_layers: int,
    init_values: float = None,
    dtype: torch.dtype = torch.float32,
):
    """
    Return a tensor of layer-wise weights for the given number of models and layers.

    Args:
        num_models (int): The number of models to fuse.
        num_layers (int): The number of layers in each model.
        init_values (float, optional): The initial value for each weight. Defaults to 1.0 / num_models.
        dtype (torch.dtype): The dtype of the weights. This should be the same as the model dtype.

    Returns:
        Tensor: A tensor of shape (num_models, num_layers) containing the layer-wise weights.
    """
    assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
    assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}"
    if init_values is None:
        init_values = 1.0 / num_models
    return torch.full((num_models, num_layers), init_values, dtype=dtype)
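
For example, the default initialization fills every entry with 1.0 / num_models (import path assumed from the source location above):

from fusion_bench.models.wrappers.layer_wise_fusion import get_layer_wise_weights

weights = get_layer_wise_weights(num_models=2, num_layers=4)
print(weights.shape)  # torch.Size([2, 4])
print(weights[0, 0])  # tensor(0.5000)
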

merge_and_unload(module)

Merges and unloads all LayerWiseMergedModel instances within the given module.

Parameters:

  • module (Module) –

    The module to process.

Returns:

  • nn.Module: The updated module with merged weights.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def merge_and_unload(module: nn.Module):
    """
    Merges and unloads all `LayerWiseMergedModel` instances within the given module.

    Args:
        module (nn.Module): The module to process.

    Returns:
        nn.Module: The updated module with merged weights.
    """
    if isinstance(module, LayerWiseMergedModel):
        return module.merge_and_unload()
    else:
        for name, submodule in module.named_children():
            need_merge = isinstance(submodule, LayerWiseMergedModel)
            submodule = merge_and_unload(submodule)
            if need_merge:
                setattr(module, name, submodule)
        return module
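
A brief sketch (wrapper here is a hypothetical module that may contain LayerWiseMergedModel submodules at any depth):

# Recursively replace every LayerWiseMergedModel submodule with its merged plain nn.Module.
plain_model = merge_and_unload(wrapper)
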

merge_weights(module)

Merges the weights for all LayerWiseMergedModel instances within the given module.

Parameters:

  • module (Module) –

    The module to process.

Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
def merge_weights(module: nn.Module):
    """
    Merges the weights for all `LayerWiseMergedModel` instances within the given module.

    Args:
        module (nn.Module): The module to process.
    """
    if isinstance(module, LayerWiseMergedModel):
        module.merge_weights()
        return
    else:
        for submodule in module.children():
            merge_weights(submodule)

task_wise_fusion

# Get the task-wise weights
task_wise_weights = get_task_wise_weights(num_models)

# Load the pretrained model and the fine-tuned models (task vectors are computed inside the wrapper)
pretrained_model = ...
finetuned_models = [...]

# Initialize the TaskWiseMergedModel
merged_model = TaskWiseMergedModel(task_wise_weights, pretrained_model, finetuned_models)

# Now you can use the merged_model like a regular PyTorch model
outputs = merged_model(inputs)

TaskWiseMergedModel

Bases: Module, Generic[TorchModelType]

A PyTorch module that dynamically merges multiple fine-tuned models using learnable task-wise weights.

This class implements a sophisticated model fusion approach where multiple task-specific models are combined with a pretrained base model using learnable weights. The fusion is performed using task vectors (differences between fine-tuned and pretrained models) that are weighted and added to the base model's parameters.

The key innovation is that the merging weights are learnable parameters that can be optimized during training, allowing the model to automatically learn the optimal combination of different task-specific knowledge.

Architecture
  • Base pretrained model (frozen)
  • Multiple task vectors (differences from pretrained model, frozen)
  • Learnable task-wise weights (trainable parameters)
  • Dynamic merging during forward pass

Parameters:

  • task_wise_weight (Tensor) –

    Initial weights for each task model. Shape: (num_models,). These become learnable parameters that control the contribution of each task vector.

  • pretrained_model (TorchModelType) –

    The base pretrained model that serves as the foundation. This model is frozen and used as the starting point for merging.

  • finetuned_models (List[TorchModelType]) –

    List of fine-tuned models for different tasks. These are converted to task vectors (differences from pretrained model) and frozen.

  • clamp_weights (bool, default: True ) –

    Whether to clamp merge weights to [0, 1] range. Defaults to True. When True, ensures weights are non-negative and bounded.

  • tie_weights (bool, default: False ) –

    Whether to tie weights during functional call. Defaults to False. Used in the underlying PyTorch functional_call.

  • strict (bool, default: True ) –

    Whether to enforce strict parameter matching. Defaults to True. Used in the underlying PyTorch functional_call.

  • task_vector_dtype (Optional[dtype], default: None ) –

    Data type for task vectors. Defaults to None. Can be used to save memory (e.g., torch.float16).

Attributes:

  • merge_weight (Parameter) –

    Learnable weights for merging task vectors.

  • pretrained_model (TorchModelType) –

    The frozen base model.

  • task_vectors (ModuleList) –

    List of frozen task vector models.

  • _merged_state_dict (StateDictType) –

    Cached merged state dictionary.

Example
import torch
import torch.nn as nn

# Create example models
pretrained_model = nn.Linear(10, 5)
finetuned_model1 = nn.Linear(10, 5)  # Fine-tuned on task 1
finetuned_model2 = nn.Linear(10, 5)  # Fine-tuned on task 2

# Initialize task-wise weights
task_weights = torch.tensor([0.3, 0.7])  # Initial weights for 2 tasks

# Create merged model
merged_model = TaskWiseMergedModel(
    task_wise_weight=task_weights,
    pretrained_model=pretrained_model,
    finetuned_models=[finetuned_model1, finetuned_model2],
    clamp_weights=True
)

# Use like a regular PyTorch model
x = torch.randn(32, 10)
output = merged_model(x)

# Train the merge weights
optimizer = torch.optim.Adam(merged_model.parameters())
loss = some_loss_function(output, targets)
loss.backward()
optimizer.step()

# Get the final merged model
final_model = merged_model.merge_and_unload()
Training Workflow
  1. Initialization: Task vectors are computed as differences from pretrained model
  2. Forward Pass: Weights are dynamically merged based on current merge_weight values
  3. Loss Computation: Standard loss computation on model outputs
  4. Backpropagation: Gradients flow through merge_weight parameters
  5. Optimization: merge_weight parameters are updated to improve performance
Memory Efficiency
  • Task vectors can use lower precision (task_vector_dtype)
  • Base model and task vectors are frozen (no gradient computation)
  • Only merge weights require gradients
Note
  • The pretrained model and task vectors are frozen during training
  • Only the merge weights (task_wise_weight) are trainable parameters
  • Task vectors represent the difference between fine-tuned and pretrained models
  • The merged state dict is cached and recomputed when merge weights change
Source code in fusion_bench/models/wrappers/task_wise_fusion.py
class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
    """
    A PyTorch module that dynamically merges multiple fine-tuned models using learnable task-wise weights.

    This class implements a sophisticated model fusion approach where multiple task-specific models
    are combined with a pretrained base model using learnable weights. The fusion is performed
    using task vectors (differences between fine-tuned and pretrained models) that are weighted
    and added to the base model's parameters.

    The key innovation is that the merging weights are learnable parameters that can be optimized
    during training, allowing the model to automatically learn the optimal combination of different
    task-specific knowledge.

    Architecture:
        - Base pretrained model (frozen)
        - Multiple task vectors (differences from pretrained model, frozen)
        - Learnable task-wise weights (trainable parameters)
        - Dynamic merging during forward pass

    Args:
        task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
            These become learnable parameters that control the contribution of each task vector.
        pretrained_model (TorchModelType): The base pretrained model that serves as the foundation.
            This model is frozen and used as the starting point for merging.
        finetuned_models (List[TorchModelType]): List of fine-tuned models for different tasks.
            These are converted to task vectors (differences from pretrained model) and frozen.
        clamp_weights (bool, optional): Whether to clamp merge weights to [0, 1] range.
            Defaults to True. When True, ensures weights are non-negative and bounded.
        tie_weights (bool, optional): Whether to tie weights during functional call.
            Defaults to False. Used in the underlying PyTorch functional_call.
        strict (bool, optional): Whether to enforce strict parameter matching.
            Defaults to True. Used in the underlying PyTorch functional_call.
        task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
            Defaults to None. Can be used to save memory (e.g., torch.float16).

    Attributes:
        merge_weight (nn.Parameter): Learnable weights for merging task vectors.
        pretrained_model (TorchModelType): The frozen base model.
        task_vectors (nn.ModuleList): List of frozen task vector models.
        _merged_state_dict (StateDictType): Cached merged state dictionary.

    Example:
        ```python
        import torch
        import torch.nn as nn

        # Create example models
        pretrained_model = nn.Linear(10, 5)
        finetuned_model1 = nn.Linear(10, 5)  # Fine-tuned on task 1
        finetuned_model2 = nn.Linear(10, 5)  # Fine-tuned on task 2

        # Initialize task-wise weights
        task_weights = torch.tensor([0.3, 0.7])  # Initial weights for 2 tasks

        # Create merged model
        merged_model = TaskWiseMergedModel(
            task_wise_weight=task_weights,
            pretrained_model=pretrained_model,
            finetuned_models=[finetuned_model1, finetuned_model2],
            clamp_weights=True
        )

        # Use like a regular PyTorch model
        x = torch.randn(32, 10)
        output = merged_model(x)

        # Train the merge weights
        optimizer = torch.optim.Adam(merged_model.parameters())
        loss = some_loss_function(output, targets)
        loss.backward()
        optimizer.step()

        # Get the final merged model
        final_model = merged_model.merge_and_unload()
        ```

    Training Workflow:
        1. **Initialization**: Task vectors are computed as differences from pretrained model
        2. **Forward Pass**: Weights are dynamically merged based on current merge_weight values
        3. **Loss Computation**: Standard loss computation on model outputs
        4. **Backpropagation**: Gradients flow through merge_weight parameters
        5. **Optimization**: merge_weight parameters are updated to improve performance

    Memory Efficiency:
        - Task vectors can use lower precision (task_vector_dtype)
        - Base model and task vectors are frozen (no gradient computation)
        - Only merge weights require gradients

    Note:
        - The pretrained model and task vectors are frozen during training
        - Only the merge weights (task_wise_weight) are trainable parameters
        - Task vectors represent the difference between fine-tuned and pretrained models
        - The merged state dict is cached and recomputed when merge weights change
    """

    _merged_state_dict: StateDictType = None

    def __init__(
        self,
        task_wise_weight: Tensor,
        pretrained_model: TorchModelType,
        finetuned_models: List[TorchModelType],
        clamp_weights: bool = True,
        tie_weights: bool = False,
        strict: bool = True,
        task_vector_dtype: Optional[torch.dtype] = None,
    ):
        """
        Initialize the TaskWiseMergedModel.

        This constructor sets up the model by:
        1. Converting fine-tuned models to task vectors (differences from pretrained)
        2. Freezing the pretrained model and task vectors
        3. Setting up learnable merge weights as parameters
        4. Configuring merging behavior options

        Args:
            task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
                These values become the starting point for learnable parameters.
            pretrained_model (TorchModelType): The base pretrained model.
                Will be frozen and used as the foundation for merging.
            finetuned_models (List[TorchModelType]): List of fine-tuned models.
                Must have the same architecture as pretrained_model.
            clamp_weights (bool, optional): Whether to clamp weights to [0, 1]. Defaults to True.
            tie_weights (bool, optional): Whether to tie weights in functional_call. Defaults to False.
            strict (bool, optional): Whether to use strict parameter matching. Defaults to True.
            task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
                Defaults to None (same as original models).

        Raises:
            ValueError: If the number of task_wise_weights doesn't match the number of fine-tuned models.
            RuntimeError: If models have incompatible architectures.
        """
        super().__init__()
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.task_vector_dtype = task_vector_dtype

        self.merge_weight = nn.Parameter(task_wise_weight, requires_grad=True)

        for name, param in pretrained_model.named_parameters():
            if not param.requires_grad:
                for m in finetuned_models:
                    del_attr(m, name.split("."))
            else:
                for m in finetuned_models:
                    get_attr(m, name.split(".")).data = (
                        get_attr(m, name.split(".")) - param
                    )
        self.pretrained_model = pretrained_model.requires_grad_(False)
        for m in finetuned_models:
            m.requires_grad_(False)
        self.task_vectors = nn.ModuleList(finetuned_models)
        if self.task_vector_dtype is not None:
            self.task_vectors = self.task_vectors.to(self.task_vector_dtype)

    @property
    def forward_model(self):
        """
        Get a functional model with merged parameters.

        Returns a partial function that applies the pretrained model with the current
        merged state dictionary. This allows for efficient forward passes without
        modifying the original model's parameters.

        Returns:
            Callable: A partial function that can be called with (args, kwargs) to
                perform forward pass with merged parameters.

        Example:
            ```python
            # Internal usage during forward pass
            forward_fn = merged_model.forward_model
            output = forward_fn(args=(x,), kwargs={})
            ```
        """
        return functools.partial(
            functional_call,
            self.pretrained_model,
            self._merged_state_dict,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

    def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merge task vectors with the pretrained model using current merge weights.

        This method computes the merged model parameters by combining the pretrained
        model with weighted task vectors. The resulting state dictionary represents
        a model that incorporates knowledge from all task-specific models.

        The merging formula for each parameter is:
        merged_param = pretrained_param + Σ(weight_i * task_vector_i * mask_i)

        Args:
            task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
                to selectively apply task vectors to specific parameters. Keys should
                match parameter names, values should be tensors with the same shape
                as the corresponding parameters. Defaults to None (no masking).

        Returns:
            StateDictType: The merged state dictionary containing combined parameters.

        Example:
            ```python
            # Basic merging
            merged_state = model.merge_weights()

            # Merging with parameter-specific masks
            masks = {
                'layer1.weight': torch.ones_like(model.pretrained_model.layer1.weight),
                'layer2.weight': torch.zeros_like(model.pretrained_model.layer2.weight),
            }
            masked_state = model.merge_weights(task_vector_mask=masks)
            ```

        Note:
            - If clamp_weights is True, merge weights are clamped to [0, 1] range
            - The merged state dict is cached in _merged_state_dict
            - Task vector masks allow fine-grained control over which parameters are affected
        """
        if self.clamp_weights:
            merge_weight = self.merge_weight.clamp(0, 1)
        else:
            merge_weight = self.merge_weight

        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        for weight, task_vector in zip(merge_weight, self.task_vectors):
            for name, param in task_vector.named_parameters():
                if task_vector_mask is None:
                    w = weight
                else:
                    w = weight * task_vector_mask[name]
                state_dict[name] = state_dict[name] + param * w
        self._merged_state_dict = state_dict
        return state_dict

    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merge models and return the final merged model.

        This method performs the merging operation and then loads the merged parameters
        into the pretrained model, returning a standard PyTorch model that can be used
        independently of the TaskWiseMergedModel wrapper.

        Args:
            task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
                for selective parameter merging. Defaults to None.

        Returns:
            TorchModelType: The pretrained model with merged parameters loaded.
                This is a standalone model that can be used without the wrapper.

        Example:
            ```python
            # Train the merged model
            for epoch in range(num_epochs):
                # ... training loop ...
                pass

            # Get the final merged model
            final_model = merged_model.merge_and_unload()

            # Save or use the final model
            torch.save(final_model.state_dict(), 'merged_model.pth')
            output = final_model(new_input)
            ```

        Warning:
            This method modifies the pretrained_model's parameters in-place.
            The original pretrained model parameters will be lost.
        """
        self.merge_weights(task_vector_mask=task_vector_mask)
        self.pretrained_model.load_state_dict(self._merged_state_dict)
        return self.pretrained_model

    def forward(self, *args, **kwargs):
        """
        Forward pass through the dynamically merged model.

        This method performs the forward pass by first ensuring the model parameters
        are merged according to the current merge weights, then applying the merged
        model to the input data.

        The forward pass involves:
        1. Check if merged state dict is current (recompute if needed)
        2. Apply the merged model to inputs using functional_call
        3. Return the model outputs

        Args:
            *args: Positional arguments to pass to the underlying model.
            **kwargs: Keyword arguments to pass to the underlying model.

        Returns:
            Any: The output of the merged model, typically torch.Tensor or tuple of tensors.

        Example:
            ```python
            # Single input
            x = torch.randn(32, 784)
            output = merged_model(x)

            # Multiple inputs
            x1, x2 = torch.randn(32, 784), torch.randn(32, 100)
            output = merged_model(x1, x2)

            # With keyword arguments
            output = merged_model(input_ids=input_ids, attention_mask=attention_mask)
            ```

        Note:
            - The merged state dict is recomputed if merge weights have changed
            - This allows for dynamic behavior during training as weights are updated
            - The computation is efficient as merging only happens when needed
        """
        if self._merged_state_dict is None:
            self.merge_weights()
        return self.forward_model(args=args, kwargs=kwargs)
forward_model property

Get a functional model with merged parameters.

Returns a partial function that applies the pretrained model with the current merged state dictionary. This allows for efficient forward passes without modifying the original model's parameters.

Returns:

  • Callable

    A partial function that can be called with (args, kwargs) to perform forward pass with merged parameters.

Example
# Internal usage during forward pass
forward_fn = merged_model.forward_model
output = forward_fn(args=(x,), kwargs={})
__init__(task_wise_weight, pretrained_model, finetuned_models, clamp_weights=True, tie_weights=False, strict=True, task_vector_dtype=None)

Initialize the TaskWiseMergedModel.

This constructor sets up the model by:

  1. Converting fine-tuned models to task vectors (differences from pretrained)
  2. Freezing the pretrained model and task vectors
  3. Setting up learnable merge weights as parameters
  4. Configuring merging behavior options

Parameters:

  • task_wise_weight (Tensor) –

    Initial weights for each task model. Shape: (num_models,). These values become the starting point for learnable parameters.

  • pretrained_model (TorchModelType) –

    The base pretrained model. Will be frozen and used as the foundation for merging.

  • finetuned_models (List[TorchModelType]) –

    List of fine-tuned models. Must have the same architecture as pretrained_model.

  • clamp_weights (bool, default: True ) –

    Whether to clamp weights to [0, 1]. Defaults to True.

  • tie_weights (bool, default: False ) –

    Whether to tie weights in functional_call. Defaults to False.

  • strict (bool, default: True ) –

    Whether to use strict parameter matching. Defaults to True.

  • task_vector_dtype (Optional[dtype], default: None ) –

    Data type for task vectors. Defaults to None (same as original models).

Raises:

  • ValueError

    If the number of task_wise_weights doesn't match the number of fine-tuned models.

  • RuntimeError

    If models have incompatible architectures.

Source code in fusion_bench/models/wrappers/task_wise_fusion.py
def __init__(
    self,
    task_wise_weight: Tensor,
    pretrained_model: TorchModelType,
    finetuned_models: List[TorchModelType],
    clamp_weights: bool = True,
    tie_weights: bool = False,
    strict: bool = True,
    task_vector_dtype: Optional[torch.dtype] = None,
):
    """
    Initialize the TaskWiseMergedModel.

    This constructor sets up the model by:
    1. Converting fine-tuned models to task vectors (differences from pretrained)
    2. Freezing the pretrained model and task vectors
    3. Setting up learnable merge weights as parameters
    4. Configuring merging behavior options

    Args:
        task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
            These values become the starting point for learnable parameters.
        pretrained_model (TorchModelType): The base pretrained model.
            Will be frozen and used as the foundation for merging.
        finetuned_models (List[TorchModelType]): List of fine-tuned models.
            Must have the same architecture as pretrained_model.
        clamp_weights (bool, optional): Whether to clamp weights to [0, 1]. Defaults to True.
        tie_weights (bool, optional): Whether to tie weights in functional_call. Defaults to False.
        strict (bool, optional): Whether to use strict parameter matching. Defaults to True.
        task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
            Defaults to None (same as original models).

    Raises:
        ValueError: If the number of task_wise_weights doesn't match the number of fine-tuned models.
        RuntimeError: If models have incompatible architectures.
    """
    super().__init__()
    self.clamp_weights = clamp_weights
    self.tie_weights = tie_weights
    self.strict = strict
    self.task_vector_dtype = task_vector_dtype

    self.merge_weight = nn.Parameter(task_wise_weight, requires_grad=True)

    for name, param in pretrained_model.named_parameters():
        if not param.requires_grad:
            for m in finetuned_models:
                del_attr(m, name.split("."))
        else:
            for m in finetuned_models:
                get_attr(m, name.split(".")).data = (
                    get_attr(m, name.split(".")) - param
                )
    self.pretrained_model = pretrained_model.requires_grad_(False)
    for m in finetuned_models:
        m.requires_grad_(False)
    self.task_vectors = nn.ModuleList(finetuned_models)
    if self.task_vector_dtype is not None:
        self.task_vectors = self.task_vectors.to(self.task_vector_dtype)
forward(*args, **kwargs)

Forward pass through the dynamically merged model.

This method performs the forward pass by first ensuring the model parameters are merged according to the current merge weights, then applying the merged model to the input data.

The forward pass involves:

  1. Check if merged state dict is current (recompute if needed)
  2. Apply the merged model to inputs using functional_call
  3. Return the model outputs

Parameters:

  • *args

    Positional arguments to pass to the underlying model.

  • **kwargs

    Keyword arguments to pass to the underlying model.

Returns:

  • Any

    The output of the merged model, typically torch.Tensor or tuple of tensors.

Example
# Single input
x = torch.randn(32, 784)
output = merged_model(x)

# Multiple inputs
x1, x2 = torch.randn(32, 784), torch.randn(32, 100)
output = merged_model(x1, x2)

# With keyword arguments
output = merged_model(input_ids=input_ids, attention_mask=attention_mask)
Note
  • The merged state dict is recomputed if merge weights have changed
  • This allows for dynamic behavior during training as weights are updated
  • The computation is efficient as merging only happens when needed
Source code in fusion_bench/models/wrappers/task_wise_fusion.py
def forward(self, *args, **kwargs):
    """
    Forward pass through the dynamically merged model.

    This method performs the forward pass by first ensuring the model parameters
    are merged according to the current merge weights, then applying the merged
    model to the input data.

    The forward pass involves:
    1. Check if merged state dict is current (recompute if needed)
    2. Apply the merged model to inputs using functional_call
    3. Return the model outputs

    Args:
        *args: Positional arguments to pass to the underlying model.
        **kwargs: Keyword arguments to pass to the underlying model.

    Returns:
        Any: The output of the merged model, typically torch.Tensor or tuple of tensors.

    Example:
        ```python
        # Single input
        x = torch.randn(32, 784)
        output = merged_model(x)

        # Multiple inputs
        x1, x2 = torch.randn(32, 784), torch.randn(32, 100)
        output = merged_model(x1, x2)

        # With keyword arguments
        output = merged_model(input_ids=input_ids, attention_mask=attention_mask)
        ```

    Note:
        - The merged state dict is recomputed if merge weights have changed
        - This allows for dynamic behavior during training as weights are updated
        - The computation is efficient as merging only happens when needed
    """
    if self._merged_state_dict is None:
        self.merge_weights()
    return self.forward_model(args=args, kwargs=kwargs)
merge_and_unload(task_vector_mask=None)

Merge models and return the final merged model.

This method performs the merging operation and then loads the merged parameters into the pretrained model, returning a standard PyTorch model that can be used independently of the TaskWiseMergedModel wrapper.

Parameters:

  • task_vector_mask (Optional[Dict[str, Tensor]], default: None ) –

    Optional masks for selective parameter merging. Defaults to None.

Returns:

  • TorchModelType

    The pretrained model with merged parameters loaded. This is a standalone model that can be used without the wrapper.

Example
# Train the merged model
for epoch in range(num_epochs):
    # ... training loop ...
    pass

# Get the final merged model
final_model = merged_model.merge_and_unload()

# Save or use the final model
torch.save(final_model.state_dict(), 'merged_model.pth')
output = final_model(new_input)
Warning

This method modifies the pretrained_model's parameters in-place. The original pretrained model parameters will be lost.

Source code in fusion_bench/models/wrappers/task_wise_fusion.py
def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
    """
    Merge models and return the final merged model.

    This method performs the merging operation and then loads the merged parameters
    into the pretrained model, returning a standard PyTorch model that can be used
    independently of the TaskWiseMergedModel wrapper.

    Args:
        task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
            for selective parameter merging. Defaults to None.

    Returns:
        TorchModelType: The pretrained model with merged parameters loaded.
            This is a standalone model that can be used without the wrapper.

    Example:
        ```python
        # Train the merged model
        for epoch in range(num_epochs):
            # ... training loop ...
            pass

        # Get the final merged model
        final_model = merged_model.merge_and_unload()

        # Save or use the final model
        torch.save(final_model.state_dict(), 'merged_model.pth')
        output = final_model(new_input)
        ```

    Warning:
        This method modifies the pretrained_model's parameters in-place.
        The original pretrained model parameters will be lost.
    """
    self.merge_weights(task_vector_mask=task_vector_mask)
    self.pretrained_model.load_state_dict(self._merged_state_dict)
    return self.pretrained_model
merge_weights(task_vector_mask=None)

Merge task vectors with the pretrained model using current merge weights.

This method computes the merged model parameters by combining the pretrained model with weighted task vectors. The resulting state dictionary represents a model that incorporates knowledge from all task-specific models.

The merging formula for each parameter is: merged_param = pretrained_param + Σ(weight_i * task_vector_i * mask_i)

Parameters:

  • task_vector_mask (Optional[Dict[str, Tensor]], default: None ) –

    Optional masks to selectively apply task vectors to specific parameters. Keys should match parameter names, values should be tensors with the same shape as the corresponding parameters. Defaults to None (no masking).

Returns:

  • StateDictType

    The merged state dictionary containing combined parameters.

Example
# Basic merging
merged_state = model.merge_weights()

# Merging with parameter-specific masks
masks = {
    'layer1.weight': torch.ones_like(model.pretrained_model.layer1.weight),
    'layer2.weight': torch.zeros_like(model.pretrained_model.layer2.weight),
}
masked_state = model.merge_weights(task_vector_mask=masks)
Note
  • If clamp_weights is True, merge weights are clamped to [0, 1] range
  • The merged state dict is cached in _merged_state_dict
  • Task vector masks allow fine-grained control over which parameters are affected
Source code in fusion_bench/models/wrappers/task_wise_fusion.py
def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
    """
    Merge task vectors with the pretrained model using current merge weights.

    This method computes the merged model parameters by combining the pretrained
    model with weighted task vectors. The resulting state dictionary represents
    a model that incorporates knowledge from all task-specific models.

    The merging formula for each parameter is:
    merged_param = pretrained_param + Σ(weight_i * task_vector_i * mask_i)

    Args:
        task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
            to selectively apply task vectors to specific parameters. Keys should
            match parameter names, values should be tensors with the same shape
            as the corresponding parameters. Defaults to None (no masking).

    Returns:
        StateDictType: The merged state dictionary containing combined parameters.

    Example:
        ```python
        # Basic merging
        merged_state = model.merge_weights()

        # Merging with parameter-specific masks
        masks = {
            'layer1.weight': torch.ones_like(model.pretrained_model.layer1.weight),
            'layer2.weight': torch.zeros_like(model.pretrained_model.layer2.weight),
        }
        masked_state = model.merge_weights(task_vector_mask=masks)
        ```

    Note:
        - If clamp_weights is True, merge weights are clamped to [0, 1] range
        - The merged state dict is cached in _merged_state_dict
        - Task vector masks allow fine-grained control over which parameters are affected
    """
    if self.clamp_weights:
        merge_weight = self.merge_weight.clamp(0, 1)
    else:
        merge_weight = self.merge_weight

    state_dict = self.pretrained_model.state_dict(keep_vars=True)
    for weight, task_vector in zip(merge_weight, self.task_vectors):
        for name, param in task_vector.named_parameters():
            if task_vector_mask is None:
                w = weight
            else:
                w = weight * task_vector_mask[name]
            state_dict[name] = state_dict[name] + param * w
    self._merged_state_dict = state_dict
    return state_dict

fuse_weights(task_wise_weight, state_dicts)

This function fuses the weights of the models and returns a state dictionary.

Parameters:

  • task_wise_weight (Tensor) –

    The weights for each model, on CUDA or CPU.

  • state_dicts (List[StateDictType]) –

    The list of state dictionaries, on CPU.

Returns:

  • StateDictType ( StateDictType ) –

    The fused state dictionary.

Source code in fusion_bench/models/wrappers/task_wise_fusion.py
def fuse_weights(
    task_wise_weight: Tensor, state_dicts: List[StateDictType]
) -> StateDictType:
    """
    This function fuses the weights of the models and returns a state dictionary.

    Args:
        task_wise_weight (Tensor): The weights for each model, on CUDA or CPU.
        state_dicts (List[StateDictType]): The list of state dictionaries, on CPU.

    Returns:
        StateDictType: The fused state dictionary.
    """
    num_models = len(state_dicts)
    assert (
        task_wise_weight.dim() == 1
    ), f"task_wise_weight must be a 1D tensor, got {task_wise_weight.dim()}"
    assert num_models == task_wise_weight.size(
        0
    ), f"num_models must be equal to the number of state_dicts, got {num_models} and {task_wise_weight.size(0)}"
    return {
        k: _fuse_weights(task_wise_weight, [sd[k] for sd in state_dicts])
        for k in state_dicts[0].keys()
    }
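
An illustrative sketch, assuming fuse_weights is importable from fusion_bench.models.wrappers.task_wise_fusion (matching the source path above):

import torch
import torch.nn as nn
from fusion_bench.models.wrappers.task_wise_fusion import fuse_weights

state_dicts = [nn.Linear(10, 5).state_dict(), nn.Linear(10, 5).state_dict()]
task_wise_weight = torch.tensor([0.3, 0.7])  # one scalar weight per model
fused_state_dict = fuse_weights(task_wise_weight, state_dicts)
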

get_task_wise_weights(num_models, init_values=None)

This function generates a tensor of weights for each model.

Parameters:

  • num_models (int) –

    The number of models.

  • init_values (float, default: None ) –

    The initial value for each weight. Defaults to 1.0 / num_models.

Returns:

  • Tensor

    A tensor of weights for each model.

Source code in fusion_bench/models/wrappers/task_wise_fusion.py
def get_task_wise_weights(num_models: int, init_values: float = None):
    """
    This function generates a tensor of weights for each model.

    Args:
        num_models (int): The number of models.
        init_values (float, optional): The initial value for each weight. Defaults to 1.0 / num_models.

    Returns:
        Tensor: A tensor of weights for each model.
    """
    assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
    if init_values is None:
        init_values = 1.0 / num_models
    return torch.full((num_models,), init_values, dtype=torch.float32)
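
For example, the default initialization spreads the weight evenly across models (import path assumed from the source location above):

from fusion_bench.models.wrappers.task_wise_fusion import get_task_wise_weights

weights = get_task_wise_weights(num_models=4)
print(weights)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
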

Model Ensemble

ensemble

EnsembleModule

Bases: Module

Ensemble module that averages the outputs of multiple models.

Source code in fusion_bench/models/wrappers/ensemble.py
class EnsembleModule(nn.Module):
    """
    Ensemble module that averages the outputs of multiple models.
    """

    def __init__(self, models: List[nn.Module]):
        """
        Initializes the EnsembleModule with a list of models.

        Args:
            models (List[nn.Module]): List of models to ensemble.
        """
        super().__init__()
        # TODO: distribute models to devices
        self.model_list = nn.ModuleList(models)

    def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
        """
        Aggregates a list of tensors by computing their mean.

        Args:
            outputs (List[Tensor]): List of tensors to aggregate.

        Returns:
            Tensor: The mean tensor.
        """
        return torch.stack(outputs).mean(dim=0)

    def forward(self, *args, **kwargs):
        """
        Performs a forward pass by averaging the outputs of the models.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Aggregated output from the ensemble of models.
        """
        outputs = [model(*args, **kwargs) for model in self.model_list]
        return aggregate_tensors(outputs, self._aggregate_tensors)
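
A minimal usage sketch, assuming EnsembleModule is importable from fusion_bench.models.wrappers.ensemble (matching the source path above):

import torch
import torch.nn as nn
from fusion_bench.models.wrappers.ensemble import EnsembleModule

models = [nn.Linear(10, 5) for _ in range(3)]
ensemble = EnsembleModule(models)
output = ensemble(torch.randn(8, 10))  # element-wise mean of the three outputs, shape (8, 5)
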
__init__(models)

Initializes the EnsembleModule with a list of models.

Parameters:

  • models (List[Module]) –

    List of models to ensemble.

Source code in fusion_bench/models/wrappers/ensemble.py
def __init__(self, models: List[nn.Module]):
    """
    Initializes the EnsembleModule with a list of models.

    Args:
        models (List[nn.Module]): List of models to ensemble.
    """
    super().__init__()
    # TODO: distribute models to devices
    self.model_list = nn.ModuleList(models)
forward(*args, **kwargs)

Performs a forward pass by averaging the outputs of the models.

Parameters:

  • *args

    Variable length argument list.

  • **kwargs

    Arbitrary keyword arguments.

Returns:

  • Aggregated output from the ensemble of models.

Source code in fusion_bench/models/wrappers/ensemble.py
def forward(self, *args, **kwargs):
    """
    Performs a forward pass by averaging the outputs of the models.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        Aggregated output from the ensemble of models.
    """
    outputs = [model(*args, **kwargs) for model in self.model_list]
    return aggregate_tensors(outputs, self._aggregate_tensors)

MaxModelPredictor

Bases: EnsembleModule

Ensemble module that selects the maximum output among multiple models.

Source code in fusion_bench/models/wrappers/ensemble.py
class MaxModelPredictor(EnsembleModule):
    """
    Ensemble module that selects the maximum output among multiple models.
    """

    def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
        """
        Aggregates a list of tensors by selecting the maximum value at each position.

        Args:
            outputs (List[Tensor]): List of tensors to aggregate.

        Returns:
            Tensor: Tensor with the maximum values.
        """
        return torch.stack(outputs).max(dim=0).values
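
A short sketch (same assumed import path as above); this aggregation is typically useful when the models output logits or scores:

import torch
import torch.nn as nn
from fusion_bench.models.wrappers.ensemble import MaxModelPredictor

predictor = MaxModelPredictor([nn.Linear(10, 5) for _ in range(3)])
scores = predictor(torch.randn(8, 10))  # element-wise maximum over the three outputs
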

WeightedEnsembleModule

Bases: Module

Ensemble module that computes a weighted average of the outputs from multiple models.

Source code in fusion_bench/models/wrappers/ensemble.py
class WeightedEnsembleModule(nn.Module):
    """
    Ensemble module that computes a weighted average of the outputs from multiple models.
    """

    def __init__(
        self,
        models: List[nn.Module],
        weights: List[float] | Tensor | np.ndarray,
        normalize: bool = True,
    ):
        """
        Initializes the WeightedEnsembleModule with models and their corresponding weights.

        Args:
            models (List[nn.Module]): List of models to ensemble.
            weights (List[float] | Tensor | np.ndarray): Weights for each model.
            normalize (bool, optional): If True, normalizes the weights. Defaults to True.
        """
        super().__init__()
        self.model_list = nn.ModuleList(models)
        if isinstance(weights, (list, tuple, ListConfig)):
            weights = torch.tensor(weights)
        elif isinstance(weights, Tensor):
            weights = weights
        elif isinstance(weights, np.ndarray):
            weights = torch.from_numpy(weights)
        else:
            raise ValueError(f"Unsupported type for weights: {type(weights)=}")

        assert len(models) == len(weights) and weights.dim() == 1, (
            "weights must be a 1D tensor of the same length as models."
            f"But got {len(models)=}, {weights.dim()=}"
        )
        if normalize:
            weights = weights / weights.sum()
        self.register_buffer("weights", weights)

    def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
        """
        Aggregates a list of tensors using the provided weights.

        Args:
            outputs (List[Tensor]): List of tensors to aggregate.

        Returns:
            Tensor: The weighted sum of the tensors.
        """
        weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
        return (torch.stack(outputs) * weights).sum(dim=0)

    def forward(self, *args, **kwargs):
        """
        Performs a forward pass by computing the weighted average of the models' outputs.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            Weighted aggregated output from the ensemble of models.
        """
        outputs = [model(*args, **kwargs) for model in self.model_list]
        return aggregate_tensors(outputs, self._aggregate_tensors)
__init__(models, weights, normalize=True)

Initializes the WeightedEnsembleModule with models and their corresponding weights.

Parameters:

  • models (List[Module]) –

    List of models to ensemble.

  • weights (List[float] | Tensor | ndarray) –

    Weights for each model.

  • normalize (bool, default: True ) –

    If True, normalizes the weights. Defaults to True.

Source code in fusion_bench/models/wrappers/ensemble.py
def __init__(
    self,
    models: List[nn.Module],
    weights: List[float] | Tensor | np.ndarray,
    normalize: bool = True,
):
    """
    Initializes the WeightedEnsembleModule with models and their corresponding weights.

    Args:
        models (List[nn.Module]): List of models to ensemble.
        weights (List[float] | Tensor | np.ndarray): Weights for each model.
        normalize (bool, optional): If True, normalizes the weights. Defaults to True.
    """
    super().__init__()
    self.model_list = nn.ModuleList(models)
    if isinstance(weights, (list, tuple, ListConfig)):
        weights = torch.tensor(weights)
    elif isinstance(weights, Tensor):
        weights = weights
    elif isinstance(weights, np.ndarray):
        weights = torch.from_numpy(weights)
    else:
        raise ValueError(f"Unsupported type for weights: {type(weights)=}")

    assert len(models) == len(weights) and weights.dim() == 1, (
        "weights must be a 1D tensor of the same length as models."
        f"But got {len(models)=}, {weights.dim()=}"
    )
    if normalize:
        weights = weights / weights.sum()
    self.register_buffer("weights", weights)
forward(*args, **kwargs)

Performs a forward pass by computing the weighted average of the models' outputs.

Parameters:

  • *args

    Variable length argument list.

  • **kwargs

    Arbitrary keyword arguments.

Returns:

  • Weighted aggregated output from the ensemble of models.

Source code in fusion_bench/models/wrappers/ensemble.py
def forward(self, *args, **kwargs):
    """
    Performs a forward pass by computing the weighted average of the models' outputs.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        Weighted aggregated output from the ensemble of models.
    """
    outputs = [model(*args, **kwargs) for model in self.model_list]
    return aggregate_tensors(outputs, self._aggregate_tensors)
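
A minimal usage sketch (illustrative only), assuming WeightedEnsembleModule is importable from fusion_bench.models.wrappers.ensemble: with normalize=True the weights [1.0, 3.0] are rescaled to [0.25, 0.75] before the outputs are combined.

import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import WeightedEnsembleModule

torch.manual_seed(0)
models = [nn.Linear(4, 2) for _ in range(2)]
ensemble = WeightedEnsembleModule(models, weights=[1.0, 3.0], normalize=True)

x = torch.randn(8, 4)
with torch.no_grad():
    y = ensemble(x)
    y_ref = 0.25 * models[0](x) + 0.75 * models[1](x)
assert torch.allclose(y, y_ref)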

aggregate_tensors(outputs, aggregate_fn)

Aggregates a list of outputs using the provided aggregation function.

This function handles different types of outputs:

  • If the outputs are Tensors, it applies the aggregation function directly.

  • If the outputs are dictionaries, it recursively aggregates each value.

  • If the outputs are tuples or lists, it recursively aggregates each element.

  • If all outputs are None, it returns None.

  • If the outputs are of an unsupported type, it raises a ValueError.

Parameters:

  • outputs (list) –

    A list of outputs to be aggregated. The outputs can be Tensors, dictionaries, tuples, lists, or None.

  • aggregate_fn (callable) –

    A function to aggregate the outputs. Typically, this could be a function like torch.mean.

Returns:

  • Tensor | dict | tuple | list | None: The aggregated output, matching the type of the input outputs.

Raises:

  • ValueError

    If the outputs are of an unsupported type.

Source code in fusion_bench/models/wrappers/ensemble.py
def aggregate_tensors(outputs: List[Any], aggregate_fn: Callable) -> Tensor:
    """
    Aggregates a list of outputs using the provided aggregation function.

    This function handles different types of outputs:
    - If the outputs are Tensors, it applies the aggregation function directly.
    - If the outputs are dictionaries, it recursively aggregates each value.
    - If the outputs are tuples or lists, it recursively aggregates each element.
    - If all outputs are None, it returns None.
    - If the outputs are of an unsupported type, it raises a ValueError.

    Args:
        outputs (list): A list of outputs to be aggregated. The outputs can be Tensors, dictionaries, tuples, lists, or None.
        aggregate_fn (callable): A function to aggregate the outputs. Typically, this could be a function like `torch.mean`.

    Returns:
        Tensor or dict or tuple or list or None: The aggregated output, matching the type of the input outputs.

    Raises:
        ValueError: If the outputs are of an unsupported type.
    """
    # If the outputs are Tensors, apply the aggregation function directly
    if isinstance(outputs[0], torch.Tensor):
        return aggregate_fn(outputs)

    # If the outputs are dicts, aggregate each value recursively
    elif isinstance(outputs[0], Dict):
        result = type(outputs[0])()
        for key in outputs[0]:
            result[key] = aggregate_tensors(
                [output[key] for output in outputs], aggregate_fn
            )
        return result

    # If the outputs are tuples or lists, aggregate each element recursively
    elif isinstance(outputs[0], (tuple, list)):
        return tuple(
            aggregate_tensors([output[i] for output in outputs], aggregate_fn)
            for i in range(len(outputs[0]))
        )

    # If all outputs are None, return None
    elif all(output is None for output in outputs):
        return None

    # Otherwise the output type is unsupported
    else:
        raise ValueError("Unsupported type for outputs")

Model Linearization (NTK)

LinearizedModelWraper

Bases: Module

Source code in fusion_bench/models/linearized/linearized_model_utils.py
class LinearizedModelWraper(nn.Module):
    def __init__(self, model: nn.Module, init_model: Optional[nn.Module] = None):
        """
        Initializes a linearized model.

        Args:
            model (nn.Module): The underlying PyTorch model to be linearized.
            init_model (nn.Module): The initial PyTorch model used to compute the linearization parameters (default: None).
        """
        super().__init__()
        self.model = model
        if init_model is None:
            init_model = model
        assert not hasattr(self, "params0")
        params0 = deepcopy([(k, v.detach()) for k, v in init_model.named_parameters()])
        self.params0_keys = [k for k, v in params0]
        self.params0_values = nn.ParameterList([v for k, v in params0])
        for p in self.params0_values:
            p.requires_grad_(False)

    def tuple_params_to_dict(self, tuple_params):
        """
        Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.

        Args:
            tuple_params (Tuple[Tensor, ...]): A tuple of parameters.

        Returns:
            Dict[str, Tensor]: A dictionary with keys corresponding to the parameter names and values corresponding to the
            parameter values.
        """
        assert len(tuple_params) == len(self.params0_keys)
        state_dict = {}
        for k, p in zip(self.params0_keys, tuple_params):
            state_dict[k] = p
        return state_dict

    def forward(self, *args, **kwargs):
        """
        Computes the linearized model output using a first-order Taylor decomposition.

        Args:
            *args: Positional arguments to be passed to the model.
            **kwargs: Keyword arguments to be passed to the model.

        Returns:
            torch.Tensor: The output of the linearized model, computed using a first-order Taylor decomposition.
        """
        params0 = tuple(self.params0_values)
        params = dict_params_to_tuple(OrderedDict(self.model.named_parameters()))
        dparams = tuple(p - p0 for p, p0 in zip(params, params0))
        out, dp = jvp(
            lambda *param: functional_call(
                self.model, self.tuple_params_to_dict(param), args, kwargs
            ),
            params0,
            dparams,
        )
        return out + dp

    @staticmethod
    def unload_linearized_modules_(module: nn.Module):
        """
        Recursively replaces every LinearizedModelWraper child of the given module
        with the original module it wraps, in place. Nothing is returned.

        Args:
            module (nn.Module): The module containing linearized submodules to be unloaded.
        """
        for name, model in module.named_children():
            if isinstance(model, LinearizedModelWraper):
                setattr(module, name, model.model)
            else:
                LinearizedModelWraper.unload_linearized_modules_(model)

__init__(model, init_model=None)

Initializes a linearized model.

Parameters:

  • model (Module) –

    The underlying PyTorch model to be linearized.

  • init_model (Module, default: None ) –

    The initial PyTorch model used to compute the linearization parameters (default: None).

Source code in fusion_bench/models/linearized/linearized_model_utils.py
def __init__(self, model: nn.Module, init_model: Optional[nn.Module] = None):
    """
    Initializes a linearized model.

    Args:
        model (nn.Module): The underlying PyTorch model to be linearized.
        init_model (nn.Module): The initial PyTorch model used to compute the linearization parameters (default: None).
    """
    super().__init__()
    self.model = model
    if init_model is None:
        init_model = model
    assert not hasattr(self, "params0")
    params0 = deepcopy([(k, v.detach()) for k, v in init_model.named_parameters()])
    self.params0_keys = [k for k, v in params0]
    self.params0_values = nn.ParameterList([v for k, v in params0])
    for p in self.params0_values:
        p.requires_grad_(False)

forward(*args, **kwargs)

Computes the linearized model output using a first-order Taylor decomposition.

Parameters:

  • *args

    Positional arguments to be passed to the model.

  • **kwargs

    Keyword arguments to be passed to the model.

Returns:

  • torch.Tensor: The output of the linearized model, computed using a first-order Taylor decomposition.

Source code in fusion_bench/models/linearized/linearized_model_utils.py
def forward(self, *args, **kwargs):
    """
    Computes the linearized model output using a first-order Taylor decomposition.

    Args:
        *args: Positional arguments to be passed to the model.
        **kwargs: Keyword arguments to be passed to the model.

    Returns:
        torch.Tensor: The output of the linearized model, computed using a first-order Taylor decomposition.
    """
    params0 = tuple(self.params0_values)
    params = dict_params_to_tuple(OrderedDict(self.model.named_parameters()))
    dparams = tuple(p - p0 for p, p0 in zip(params, params0))
    out, dp = jvp(
        lambda *param: functional_call(
            self.model, self.tuple_params_to_dict(param), args, kwargs
        ),
        params0,
        dparams,
    )
    return out + dp
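
In symbols, with the current parameters of self.model written as theta and the frozen copies stored in params0_values written as theta_0, the forward pass returns the first-order Taylor expansion

$$ f_{\text{lin}}(x;\theta) = f(x;\theta_0) + J_\theta f(x;\theta_0)\,(\theta - \theta_0) $$

where out is the first term and dp is the Jacobian-vector product computed by jvp with tangent dparams = theta - theta_0.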

tuple_params_to_dict(tuple_params)

Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.

Parameters:

  • tuple_params (Tuple[Tensor, ...]) –

    A tuple of parameters.

Returns:

  • Dict[str, Tensor]: A dictionary with keys corresponding to the parameter names and values corresponding to the parameter values.

Source code in fusion_bench/models/linearized/linearized_model_utils.py
def tuple_params_to_dict(self, tuple_params):
    """
    Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.

    Args:
        tuple_params (Tuple[Tensor, ...]): A tuple of parameters.

    Returns:
        Dict[str, Tensor]: A dictionary with keys corresponding to the parameter names and values corresponding to the
        parameter values.
    """
    assert len(tuple_params) == len(self.params0_keys)
    state_dict = {}
    for k, p in zip(self.params0_keys, tuple_params):
        state_dict[k] = p
    return state_dict

unload_linearized_modules_(module) staticmethod

Recursively replaces every LinearizedModelWraper child of the given module with the original module it wraps, in place. Nothing is returned.

Parameters:

  • module (Module) –

    The module containing linearized submodules to be unloaded.


Source code in fusion_bench/models/linearized/linearized_model_utils.py
@staticmethod
def unload_linearized_modules_(module: nn.Module):
    """
    Recursively replaces every LinearizedModelWraper child of the given module
    with the original module it wraps, in place. Nothing is returned.

    Args:
        module (nn.Module): The module containing linearized submodules to be unloaded.
    """
    for name, model in module.named_children():
        if isinstance(model, LinearizedModelWraper):
            setattr(module, name, model.model)
        else:
            LinearizedModelWraper.unload_linearized_modules_(model)
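
As a closing illustration, a minimal, untested sketch of wrapping a submodule and later restoring it, assuming LinearizedModelWraper is importable from fusion_bench.models.linearized.linearized_model_utils (the source path shown above); the small nn.Sequential network is a stand-in for a real model.

import torch
from torch import nn

from fusion_bench.models.linearized.linearized_model_utils import LinearizedModelWraper

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Replace the first block with its linearized (first-order Taylor) version.
model[0] = LinearizedModelWraper(model[0])

x = torch.randn(5, 4)
y = model(x)      # the forward pass goes through the linearized block
print(y.shape)    # torch.Size([5, 2])

# Undo the wrapping in place: each LinearizedModelWraper child is replaced
# by the original module it wraps.
LinearizedModelWraper.unload_linearized_modules_(model)
print(type(model[0]))  # <class 'torch.nn.modules.linear.Linear'>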