fusion_bench.models
Task and Layer-wise Merging (AdaMerging)
layer_wise_fusion
LayerWiseMergedModel
Bases: Module, Generic[TorchModelType]
Source code in fusion_bench/models/wrappers/layer_wise_fusion.py
__init__(layer_wise_weight, pretrained_model, finetuned_models, clamp_weights=True, tie_weights=False, strict=True, sparsity_ratio=None, normalized_merging_weights=False)
This class wraps a pretrained model and a list of finetuned models, and merges the task vectors of the finetuned models into the pretrained model using learnable layer-wise weights.
Reference:
(ICLR 2024) Yang E, Wang Z, Shen L, et al. AdaMerging: Adaptive model merging for multi-task learning. https://arxiv.org/pdf/2310.02575
Parameters:
- layer_wise_weight (Tensor) – A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
- pretrained_model (Module) – The pretrained model to merge the weights into.
- finetuned_models (List[Module]) – A list of finetuned models to merge the weights from. They must have the same architecture as the pretrained model and are used to compute the task vectors.
- clamp_weights (bool, default: True) – If True, the layer-wise weights are clamped to [0, 1].
- tie_weights (bool, default: False) – Passed as the tie_weights argument to functional_call.
- strict (bool, default: True) – Passed as the strict argument to functional_call.
- sparsity_ratio (float, default: None) – If provided, the task vectors are pruned before merging. A high sparsity level reduces memory usage during merging.
- normalized_merging_weights (bool, default: False) – If True, the layer-wise weights are normalized per layer so that the weights across models sum to 1 for each layer.
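The clamp_weights and normalized_merging_weights options can be pictured as two transformations of the (num_models, num_layers) weight matrix. The sketch below uses nested Python lists instead of tensors so it runs without PyTorch; prepare_layer_wise_weights is a hypothetical name, and applying the clamp before the normalization is an assumption, not documented behavior.

```python
from typing import List

Matrix = List[List[float]]  # rows = models, columns = layers


def prepare_layer_wise_weights(
    weights: Matrix, clamp_weights: bool = True, normalized: bool = False
) -> Matrix:
    """Hypothetical helper: apply the clamp and normalization options to a weight matrix."""
    out = [list(row) for row in weights]  # copy so the caller's matrix is untouched
    if clamp_weights:
        # Restrict every weight to the interval [0, 1]
        out = [[min(max(w, 0.0), 1.0) for w in row] for row in out]
    if normalized:
        # Assumption: normalization happens after clamping
        for j in range(len(out[0])):
            col_sum = sum(row[j] for row in out)  # total weight of layer j across models
            if col_sum != 0:
                for row in out:
                    row[j] /= col_sum
    return out


w = [[0.25, 0.5], [0.25, 1.5]]
print(prepare_layer_wise_weights(w))  # [[0.25, 0.5], [0.25, 1.0]]
print(prepare_layer_wise_weights(w, normalized=True))  # each column now sums to 1
```

With normalization enabled, each column (one layer) becomes a convex combination over the models, so no single layer can drift far from the scale of the original weights.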
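merge_weights(task_vector_mask=None)
Merges the task vectors into the pretrained weights using the current layer-wise weights. Call this after each update step so that the merged weights reflect the latest values.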
fix_other_parts(module)
Sets all parameters in the module to not require gradients, except for the merge weights in LayerWiseMergedModel instances.
Parameters:
- module (Module) – The module to process.
Returns:
- nn.Module – The module with updated gradient requirements.
fuse_weights(layer_wise_weight, state_dicts)
Fuse the weights of multiple models using layer-wise fusion.
Parameters:
- layer_wise_weight (Tensor) – A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
- state_dicts (List[StateDict]) – A list of state dictionaries, one for each model.
Returns:
- StateDictType – A dictionary mapping each weight tensor key to the fused weight tensor.
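A plain-Python sketch of layer-wise fusion, assuming that each entry of the state dict counts as one "layer" in key order, so column j of the weight matrix scales the j-th parameter of every model. Lists of floats stand in for tensors, and fuse_layer_wise is a hypothetical name rather than the library function.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]  # parameter name -> flattened values


def fuse_layer_wise(weights: List[List[float]], state_dicts: List[StateDict]) -> StateDict:
    """Hypothetical sketch: column j of `weights` scales the j-th key of every state dict."""
    keys = list(state_dicts[0].keys())
    if len(weights) != len(state_dicts) or any(len(row) != len(keys) for row in weights):
        raise ValueError("weights must have shape (num_models, num_layers)")
    fused: StateDict = {}
    for j, key in enumerate(keys):
        acc = [0.0] * len(state_dicts[0][key])
        for i, sd in enumerate(state_dicts):
            for k, value in enumerate(sd[key]):
                acc[k] += weights[i][j] * value  # weight of model i for layer j
        fused[key] = acc
    return fused


sd_a = {"fc.weight": [1.0, 2.0], "fc.bias": [0.5]}
sd_b = {"fc.weight": [3.0, 4.0], "fc.bias": [1.5]}
print(fuse_layer_wise([[0.5, 1.0], [0.5, 0.0]], [sd_a, sd_b]))
# {'fc.weight': [2.0, 3.0], 'fc.bias': [0.5]}
```

Here the weight layer averages both models, while the bias layer takes model A only, which is exactly the per-layer freedom that distinguishes this scheme from a single task-wise coefficient.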
get_layer_wise_weights(num_models, num_layers, init_values=None, dtype=torch.float32)
Return a tensor of layer-wise weights for the given number of models and layers.
Parameters:
- num_models (int) – The number of models to fuse.
- num_layers (int) – The number of layers in each model.
- init_values (float, default: None) – The initial value for each weight. If None, 1.0 / num_models is used.
- dtype (dtype, default: float32) – The dtype of the weights. This should match the model's dtype.
Returns:
- Tensor – A tensor of shape (num_models, num_layers) containing the layer-wise weights.
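The documented default of 1.0 / num_models gives every model an equal share of each layer. A list-based sketch of this initialization follows; make_layer_wise_weights is a hypothetical name, and the library returns a torch.Tensor of the requested dtype instead of nested lists.

```python
from typing import List, Optional


def make_layer_wise_weights(
    num_models: int, num_layers: int, init_values: Optional[float] = None
) -> List[List[float]]:
    """Hypothetical sketch: a num_models x num_layers grid filled with one initial value."""
    if init_values is None:
        init_values = 1.0 / num_models  # documented default: equal share per model
    # Build independent rows so later in-place updates do not alias each other
    return [[init_values] * num_layers for _ in range(num_models)]


print(make_layer_wise_weights(num_models=4, num_layers=3))
# [[0.25, 0.25, 0.25], [0.25, 0.25, 0.25], [0.25, 0.25, 0.25], [0.25, 0.25, 0.25]]
```

With PyTorch, an equivalent grid can be built with torch.full((num_models, num_layers), value, dtype=dtype).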
merge_and_unload(module)
Merges and unloads all LayerWiseMergedModel instances within the given module.
Parameters:
- module (Module) – The module to process.
Returns:
- nn.Module – The updated module with merged weights.
merge_weights(module)
Merges the weights for all LayerWiseMergedModel instances within the given module.
Parameters:
- module (Module) – The module to process.
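task_wise_fusion
from fusion_bench.models.wrappers.task_wise_fusion import TaskWiseMergedModel, get_task_wise_weights
# pretrained_model, finetuned_models (same architecture) and inputs are assumed to be defined
num_models = len(finetuned_models)
# Get the initial task-wise weights (one scalar per fine-tuned model)
task_wise_weights = get_task_wise_weights(num_models)
# Initialize the TaskWiseMergedModel; task vectors are computed internally from the fine-tuned models
merged_model = TaskWiseMergedModel(
    task_wise_weight=task_wise_weights,
    pretrained_model=pretrained_model,
    finetuned_models=finetuned_models,
)
# Now you can use the merged_model like a regular PyTorch model
outputs = merged_model(inputs)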
TaskWiseMergedModel
Bases: Module, Generic[TorchModelType]
A PyTorch module that dynamically merges multiple fine-tuned models using learnable task-wise weights.
Each fine-tuned model is converted into a task vector (its parameter difference from the pretrained base model). The task vectors are scaled by per-task weights and added to the base model's parameters.
Because the merging weights are learnable parameters, they can be optimized during training, so the model learns how strongly each task vector should contribute.
Architecture
- Base pretrained model (frozen)
- Multiple task vectors (differences from pretrained model, frozen)
- Learnable task-wise weights (trainable parameters)
- Dynamic merging during forward pass
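The arithmetic behind this architecture fits in a few lines. In the sketch below each "state dict" maps a parameter name to a single float so the numbers are easy to follow; task_vector and task_wise_merge are hypothetical helper names, not part of the library. Setting one weight to 1 and the others to 0 recovers that fine-tuned model exactly, which is a quick sanity check on the formula.

```python
from typing import Dict, List

StateDict = Dict[str, float]  # one scalar per parameter keeps the arithmetic visible


def task_vector(finetuned: StateDict, pretrained: StateDict) -> StateDict:
    # tau = theta_finetuned - theta_pretrained, parameter by parameter
    return {k: finetuned[k] - pretrained[k] for k in pretrained}


def task_wise_merge(pretrained: StateDict, task_vectors: List[StateDict], weights: List[float]) -> StateDict:
    # theta_merged = theta_pretrained + sum_i lambda_i * tau_i
    merged = dict(pretrained)  # copy; the pretrained parameters stay frozen
    for lam, tau in zip(weights, task_vectors):
        for k in merged:
            merged[k] += lam * tau[k]
    return merged


pre = {"w": 1.0, "b": 0.0}
ft1 = {"w": 3.0, "b": 1.0}
ft2 = {"w": 0.0, "b": 2.0}
taus = [task_vector(ft1, pre), task_vector(ft2, pre)]
print(task_wise_merge(pre, taus, [1.0, 0.0]))  # {'w': 3.0, 'b': 1.0}, i.e. ft1
print(task_wise_merge(pre, taus, [0.5, 0.5]))  # {'w': 1.5, 'b': 1.5}
```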
Parameters:
- task_wise_weight (Tensor) – Initial weights for each task model. Shape: (num_models,). These become learnable parameters that control the contribution of each task vector.
- pretrained_model (TorchModelType) – The base pretrained model that serves as the foundation. This model is frozen and used as the starting point for merging.
- finetuned_models (List[TorchModelType]) – List of fine-tuned models for different tasks. These are converted to task vectors (differences from the pretrained model) and frozen.
- clamp_weights (bool, default: True) – Whether to clamp merge weights to the [0, 1] range. When True, the weights are non-negative and bounded.
- tie_weights (bool, default: False) – Whether to tie weights during the functional call. Passed to the underlying PyTorch functional_call.
- strict (bool, default: True) – Whether to enforce strict parameter matching. Passed to the underlying PyTorch functional_call.
- task_vector_dtype (Optional[dtype], default: None) – Data type for task vectors. Can be used to save memory (e.g., torch.float16).
Attributes:
- merge_weight (Parameter) – Learnable weights for merging task vectors.
- pretrained_model (TorchModelType) – The frozen base model.
- task_vectors (ModuleList) – List of frozen task vector models.
- _merged_state_dict (StateDictType) – Cached merged state dictionary.
Example
import torch
import torch.nn as nn
import torch.nn.functional as F
from fusion_bench.models.wrappers.task_wise_fusion import TaskWiseMergedModel
# Create example models (random stand-ins; real ones would be fine-tuned on each task)
pretrained_model = nn.Linear(10, 5)
finetuned_model1 = nn.Linear(10, 5)  # Fine-tuned on task 1
finetuned_model2 = nn.Linear(10, 5)  # Fine-tuned on task 2
# Initialize task-wise weights
task_weights = torch.tensor([0.3, 0.7])  # Initial weights for 2 tasks
# Create merged model
merged_model = TaskWiseMergedModel(
    task_wise_weight=task_weights,
    pretrained_model=pretrained_model,
    finetuned_models=[finetuned_model1, finetuned_model2],
    clamp_weights=True,
)
# Use like a regular PyTorch model
x = torch.randn(32, 10)
targets = torch.randn(32, 5)  # placeholder targets for illustration
output = merged_model(x)
# Train the merge weights (the only parameters that require gradients)
optimizer = torch.optim.Adam(
    [p for p in merged_model.parameters() if p.requires_grad]
)
optimizer.zero_grad()
loss = F.mse_loss(output, targets)
loss.backward()
optimizer.step()
# Get the final merged model
final_model = merged_model.merge_and_unload()
Training Workflow
- Initialization: Task vectors are computed as differences from pretrained model
- Forward Pass: Weights are dynamically merged based on current merge_weight values
- Loss Computation: Standard loss computation on model outputs
- Backpropagation: Gradients flow through merge_weight parameters
- Optimization: merge_weight parameters are updated to improve performance
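The workflow above can be reproduced on a toy problem: a scalar model y = theta * x whose parameter theta is the pretrained value plus weighted task vectors. To stay dependency-free, the gradient with respect to each weight is derived by hand with the chain rule instead of autograd; train_task_wise_weights is a hypothetical name, and plain gradient descent stands in for the Adam optimizer used in the example.

```python
from typing import List, Sequence, Tuple


def train_task_wise_weights(
    theta0: float,
    taus: Sequence[float],
    lambdas: Sequence[float],
    data: Sequence[Tuple[float, float]],
    lr: float = 0.02,
    steps: int = 100,
    clamp: bool = True,
) -> List[float]:
    """Toy sketch: learn task-wise weights for the scalar model y = theta * x."""
    lambdas = list(lambdas)
    for _ in range(steps):
        # Forward pass: merge with the current weights (theta0 and taus stay frozen)
        theta = theta0 + sum(lam * tau for lam, tau in zip(lambdas, taus))
        # dLoss/dtheta for the mean squared error over the data
        dloss_dtheta = sum(2.0 * (theta * x - y) * x for x, y in data) / len(data)
        # Chain rule: dtheta/dlambda_i = tau_i
        grads = [dloss_dtheta * tau for tau in taus]
        # Optimization step on the merge weights only
        lambdas = [lam - lr * g for lam, g in zip(lambdas, grads)]
        if clamp:
            lambdas = [min(max(lam, 0.0), 1.0) for lam in lambdas]
    return lambdas


data = [(1.0, 2.0), (2.0, 4.0)]  # generated by the target theta = 2
learned = train_task_wise_weights(theta0=1.0, taus=[2.0, -1.0], lambdas=[0.3, 0.7], data=data)
print([round(lam, 3) for lam in learned])  # [0.74, 0.48]
print(round(1.0 + 2.0 * learned[0] - 1.0 * learned[1], 6))  # merged theta: 2.0
```

Only the two weights move during training; the frozen theta0 and task vectors define the directions the merged parameter is allowed to take.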
Memory Efficiency
- Task vectors can use lower precision (task_vector_dtype)
- Base model and task vectors are frozen (no gradient computation)
- Only merge weights require gradients
Note
- The pretrained model and task vectors are frozen during training
- Only the merge weights (task_wise_weight) are trainable parameters
- Task vectors represent the difference between fine-tuned and pretrained models
- The merged state dict is cached and recomputed when merge weights change
Source code in fusion_bench/models/wrappers/task_wise_fusion.py
forward_model (property)
Get a functional model with merged parameters.
Returns a partial function that applies the pretrained model with the current merged state dictionary. This allows efficient forward passes without modifying the original model's parameters.
Returns:
- Callable – A partial function that can be called with (*args, **kwargs) to perform a forward pass with the merged parameters.
__init__(task_wise_weight, pretrained_model, finetuned_models, clamp_weights=True, tie_weights=False, strict=True, task_vector_dtype=None)
Initialize the TaskWiseMergedModel.
This constructor sets up the model by:
1. Converting fine-tuned models to task vectors (differences from the pretrained model)
2. Freezing the pretrained model and task vectors
3. Setting up learnable merge weights as parameters
4. Configuring merging behavior options
Parameters:
- task_wise_weight (Tensor) – Initial weights for each task model. Shape: (num_models,). These values become the starting point for the learnable parameters.
- pretrained_model (TorchModelType) – The base pretrained model. It is frozen and used as the foundation for merging.
- finetuned_models (List[TorchModelType]) – List of fine-tuned models. Must have the same architecture as pretrained_model.
- clamp_weights (bool, default: True) – Whether to clamp weights to [0, 1].
- tie_weights (bool, default: False) – Whether to tie weights in functional_call.
- strict (bool, default: True) – Whether to use strict parameter matching.
- task_vector_dtype (Optional[dtype], default: None) – Data type for task vectors. If None, the dtype of the original models is kept.
Raises:
- ValueError – If the number of task_wise_weights doesn't match the number of fine-tuned models.
- RuntimeError – If models have incompatible architectures.
forward(*args, **kwargs)
Forward pass through the dynamically merged model.
This method performs the forward pass by first ensuring the model parameters are merged according to the current merge weights, then applying the merged model to the input data.
The forward pass involves:
1. Checking whether the merged state dict is current (recomputing it if needed)
2. Applying the merged model to the inputs using functional_call
3. Returning the model outputs
Parameters:
- *args – Positional arguments to pass to the underlying model.
- **kwargs – Keyword arguments to pass to the underlying model.
Returns:
- Any – The output of the merged model, typically a torch.Tensor or a tuple of tensors.
Note
- The merged state dict is recomputed if merge weights have changed
- This allows for dynamic behavior during training as weights are updated
- The computation is efficient as merging only happens when needed
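One way to get the "merge only when needed" behavior is to remember the weights used for the last merge and rebuild the merged parameters only when they differ. The class below is a hypothetical illustration of that idea with scalar parameters; the documentation does not say how TaskWiseMergedModel detects a change, so comparing weight tuples is an assumption of this sketch.

```python
from typing import Dict, List, Optional, Tuple

Params = Dict[str, float]


class CachedMerger:
    """Hypothetical illustration: cache merged parameters, re-merge only after the weights change."""

    def __init__(self, pretrained: Params, task_vectors: List[Params], weights: List[float]):
        self.pretrained = pretrained
        self.task_vectors = task_vectors
        self.weights = list(weights)  # stands in for the learnable merge weights
        self._cache: Optional[Params] = None
        self._cache_key: Optional[Tuple[float, ...]] = None
        self.num_merges = 0  # how many times the merge was actually computed

    def merged_state(self) -> Params:
        key = tuple(self.weights)
        if self._cache is None or key != self._cache_key:
            merged = dict(self.pretrained)
            for lam, tau in zip(self.weights, self.task_vectors):
                for name in merged:
                    merged[name] += lam * tau[name]
            self._cache, self._cache_key = merged, key
            self.num_merges += 1
        return self._cache

    def forward(self, x: float) -> float:
        # Toy "model": y = w * x + b with the merged parameters
        p = self.merged_state()
        return p["w"] * x + p["b"]


m = CachedMerger({"w": 1.0, "b": 0.0}, [{"w": 2.0, "b": 1.0}], [0.5])
print(m.forward(1.0), m.forward(2.0), m.num_merges)  # 2.5 4.5 1
m.weights[0] = 1.0  # simulate an optimizer update
print(m.forward(1.0), m.num_merges)  # 4.0 2
```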
merge_and_unload(task_vector_mask=None)
Merge models and return the final merged model.
This method performs the merging operation and then loads the merged parameters into the pretrained model, returning a standard PyTorch model that can be used independently of the TaskWiseMergedModel wrapper.
Parameters:
- task_vector_mask (Optional[Dict[str, Tensor]], default: None) – Optional masks for selective parameter merging.
Returns:
- TorchModelType – The pretrained model with merged parameters loaded. This is a standalone model that can be used without the wrapper.
Warning
This method modifies the pretrained_model's parameters in-place. The original pretrained model parameters will be lost.
merge_weights(task_vector_mask=None)
Merge task vectors with the pretrained model using current merge weights.
This method computes the merged model parameters by combining the pretrained model with weighted task vectors. The resulting state dictionary represents a model that incorporates knowledge from all task-specific models.
The merging formula for each parameter is: merged_param = pretrained_param + Σ(weight_i * task_vector_i * mask_i)
Parameters:
- task_vector_mask (Optional[Dict[str, Tensor]], default: None) – Optional masks to selectively apply task vectors to specific parameters. Keys should match parameter names, and values should be tensors with the same shape as the corresponding parameters. Defaults to None (no masking).
Returns:
- StateDictType – The merged state dictionary containing combined parameters.
Example
# Basic merging
merged_state = model.merge_weights()
# Merging with parameter-specific masks
masks = {
'layer1.weight': torch.ones_like(model.pretrained_model.layer1.weight),
'layer2.weight': torch.zeros_like(model.pretrained_model.layer2.weight),
}
masked_state = model.merge_weights(task_vector_mask=masks)
Note
- If clamp_weights is True, merge weights are clamped to [0, 1] range
- The merged state dict is cached in _merged_state_dict
- Task vector masks allow fine-grained control over which parameters are affected
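The formula and notes above can be checked with a small sketch. It mirrors the documented parameter shape: a single optional mask dictionary keyed by parameter name, applied to every task vector, plus clamping of the weights to [0, 1]. Lists of floats replace tensors, merge_weights_sketch is a hypothetical name, and treating a parameter that has no mask entry as unmasked is a choice made for this sketch only.

```python
from typing import Dict, List, Optional

Params = Dict[str, List[float]]  # parameter name -> flattened values


def merge_weights_sketch(
    pretrained: Params,
    task_vectors: List[Params],
    weights: List[float],
    mask: Optional[Params] = None,
    clamp_weights: bool = True,
) -> Params:
    """merged = pretrained + sum_i(weight_i * task_vector_i * mask), elementwise."""
    if clamp_weights:
        weights = [min(max(w, 0.0), 1.0) for w in weights]
    merged = {name: list(values) for name, values in pretrained.items()}  # pretrained stays intact
    for w, tau in zip(weights, task_vectors):
        for name, values in merged.items():
            # Sketch-only choice: a parameter without a mask entry gets the full task vector
            m = mask.get(name) if mask is not None else None
            for j in range(len(values)):
                scale = 1.0 if m is None else m[j]
                values[j] += w * tau[name][j] * scale
    return merged


pre = {"layer1.weight": [1.0, 1.0], "layer2.weight": [0.0]}
taus = [{"layer1.weight": [2.0, 4.0], "layer2.weight": [8.0]}]
print(merge_weights_sketch(pre, taus, [1.5]))
# {'layer1.weight': [3.0, 5.0], 'layer2.weight': [8.0]}  (1.5 is clamped to 1.0)
print(merge_weights_sketch(pre, taus, [0.5], mask={"layer2.weight": [0.0]}))
# {'layer1.weight': [2.0, 3.0], 'layer2.weight': [0.0]}
```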
fuse_weights(task_wise_weight, state_dicts)
This function fuses the weights of the models and returns a state dictionary.
Parameters:
- task_wise_weight (Tensor) – The weights for each model, on either CUDA or CPU.
- state_dicts (List[StateDictType]) – The list of state dictionaries, stored on CPU.
Returns:
- StateDictType – The fused state dictionary.
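Task-wise fusion amounts to one weighted sum per parameter, fused[k] = Σ_i w_i · state_dicts[i][k]. The sketch uses lists of floats in place of CPU tensors; fuse_task_wise is a hypothetical name, and the CUDA/CPU device handling described above is not modeled.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]  # parameter name -> flattened values


def fuse_task_wise(weights: List[float], state_dicts: List[StateDict]) -> StateDict:
    """Hypothetical sketch: fused[k] = sum_i weights[i] * state_dicts[i][k] for every key k."""
    if len(weights) != len(state_dicts):
        raise ValueError("expected one weight per state dict")
    fused: StateDict = {}
    for key, first in state_dicts[0].items():
        acc = [0.0] * len(first)
        for w, sd in zip(weights, state_dicts):
            for j, value in enumerate(sd[key]):
                acc[j] += w * value  # the same scalar weight is used for every parameter of model i
        fused[key] = acc
    return fused


print(fuse_task_wise([0.25, 0.75], [{"w": [4.0, 8.0]}, {"w": [0.0, 4.0]}]))
# {'w': [1.0, 5.0]}
```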
get_task_wise_weights(num_models, init_values=None)
This function generates a tensor of weights for each model.
Parameters:
- num_models (int) – The number of models.
- init_values (float, default: None) – The initial value for each weight.
Returns:
- Tensor – A tensor of weights for each model.
Model Ensemble
ensemble
EnsembleModule
Bases: Module
Ensemble module that averages the outputs of multiple models.
Source code in fusion_bench/models/wrappers/ensemble.py
__init__(models)
Initializes the EnsembleModule with a list of models.
Parameters:
- models (List[Module]) – List of models to ensemble.
forward(*args, **kwargs)
Performs a forward pass by averaging the outputs of the models.
Parameters:
- *args – Variable length argument list.
- **kwargs – Arbitrary keyword arguments.
Returns:
- Aggregated output from the ensemble of models.
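Output averaging reduces to "run every member on the same input, then take the mean". In this sketch plain Python callables returning floats stand in for nn.Module members, and AveragingEnsemble is a hypothetical name. With tensor outputs the mean would usually be taken over a stacked dimension, for example torch.stack(outputs).mean(dim=0).

```python
from typing import Callable, List


class AveragingEnsemble:
    """Hypothetical sketch: average the outputs of several models on the same input."""

    def __init__(self, models: List[Callable[[float], float]]):
        self.models = models

    def __call__(self, x: float) -> float:
        outputs = [model(x) for model in self.models]  # one forward pass per member
        return sum(outputs) / len(outputs)


ensemble = AveragingEnsemble([lambda x: 2 * x, lambda x: x + 1, lambda x: x - 1])
print(ensemble(3.0))  # (6.0 + 4.0 + 2.0) / 3 = 4.0
```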
MaxModelPredictor
Bases: EnsembleModule
Ensemble module that selects the maximum output among multiple models.
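The documentation does not spell out how the "maximum" is taken. One plausible reading, assumed here, is an elementwise maximum over the members' outputs. The sketch uses lists of floats and the hypothetical name max_model_predict; for stacked tensors the analogue would be torch.stack(outputs).max(dim=0).values.

```python
from typing import Callable, List, Sequence

Model = Callable[[float], Sequence[float]]


def max_model_predict(models: List[Model], x: float) -> List[float]:
    """Assumed semantics: elementwise maximum over the members' output vectors."""
    outputs = [list(model(x)) for model in models]
    # zip(*outputs) groups the i-th output of every model together
    return [max(column) for column in zip(*outputs)]


models = [lambda x: [x, 0.0, -x], lambda x: [0.5 * x, 1.0, x]]
print(max_model_predict(models, 2.0))  # [2.0, 1.0, 2.0]
```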
WeightedEnsembleModule
Bases: Module
Ensemble module that computes a weighted average of the outputs from multiple models.
__init__(models, weights, normalize=True)
Initializes the WeightedEnsembleModule with models and their corresponding weights.
Parameters:
- models (List[Module]) – List of models to ensemble.
- weights (List[float] | Tensor | ndarray) – Weights for each model.
- normalize (bool, default: True) – If True, normalizes the weights.
forward(*args, **kwargs)
Performs a forward pass by computing the weighted average of the models' outputs.
Parameters:
- *args – Variable length argument list.
- **kwargs – Arbitrary keyword arguments.
Returns:
- Weighted aggregated output from the ensemble of models.
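A sketch of the weighted average, assuming normalize=True means dividing each weight by the sum of all weights so they add up to 1. WeightedEnsemble is a hypothetical name, and callables returning floats stand in for models.

```python
from typing import Callable, List


class WeightedEnsemble:
    """Hypothetical sketch: weighted average of model outputs."""

    def __init__(self, models: List[Callable[[float], float]], weights: List[float], normalize: bool = True):
        if len(models) != len(weights):
            raise ValueError("expected one weight per model")
        if normalize:
            total = sum(weights)
            weights = [w / total for w in weights]  # assumed: rescale so the weights sum to 1
        self.models = models
        self.weights = list(weights)

    def __call__(self, x: float) -> float:
        return sum(w * model(x) for w, model in zip(self.weights, self.models))


ens = WeightedEnsemble([lambda x: x, lambda x: 3 * x], weights=[1.0, 3.0])
print(ens.weights)  # [0.25, 0.75]
print(ens(2.0))  # 0.25 * 2.0 + 0.75 * 6.0 = 5.0
```

Without normalization the same weights would scale the output (here to 20.0), which is why normalizing is the safer default when the weights are arbitrary scores.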
aggregate_tensors(outputs, aggregate_fn)
Aggregates a list of outputs using the provided aggregation function.
This function handles different types of outputs:
- If the outputs are Tensors, it applies the aggregation function directly.
- If the outputs are dictionaries, it recursively aggregates each value.
- If the outputs are tuples or lists, it recursively aggregates each element.
- If all outputs are None, it returns None.
- If the outputs are of an unsupported type, it raises a ValueError.
Parameters:
- outputs (list) – A list of outputs to be aggregated. The outputs can be Tensors, dictionaries, tuples, lists, or None.
- aggregate_fn (callable) – A function to aggregate the outputs, typically something like torch.mean.
Returns:
- Tensor or dict or tuple or list or None – The aggregated output, matching the type of the input outputs.
Raises:
- ValueError – If the outputs are of an unsupported type.
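The recursion described above can be sketched without PyTorch by letting floats play the role of tensors: leaves are handed to aggregate_fn as a list, and containers are rebuilt with their original type. aggregate is a hypothetical name, and passing a plain list of leaves to aggregate_fn is a simplification; with real tensors they would typically be stacked first.

```python
from statistics import mean
from typing import Any, Callable, List


def aggregate(outputs: List[Any], aggregate_fn: Callable[[List[float]], float]) -> Any:
    """Hypothetical sketch of recursive aggregation, with floats standing in for tensors."""
    if all(o is None for o in outputs):
        return None
    first = outputs[0]
    if isinstance(first, (int, float)):
        return aggregate_fn(outputs)  # leaf: combine the values directly
    if isinstance(first, dict):
        return {k: aggregate([o[k] for o in outputs], aggregate_fn) for k in first}
    if isinstance(first, (tuple, list)):
        # zip(*outputs) groups the i-th element of every output together
        items = [aggregate(list(group), aggregate_fn) for group in zip(*outputs)]
        return type(first)(items)
    raise ValueError(f"Unsupported output type: {type(first)}")


outs = [{"logits": (1.0, 2.0), "aux": None}, {"logits": (3.0, 6.0), "aux": None}]
print(aggregate(outs, mean))  # {'logits': (2.0, 4.0), 'aux': None}
```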
Model Linearization (NTK)
LinearizedModelWraper
Bases: Module
Source code in fusion_bench/models/linearized/linearized_model_utils.py
__init__(model, init_model=None)
Initializes a linearized model.
Parameters:
- model (Module) – The underlying PyTorch model to be linearized.
- init_model (Module, default: None) – The initial PyTorch model used to compute the linearization parameters.
forward(*args, **kwargs)
Computes the linearized model output using a first-order Taylor expansion.
Parameters:
- *args – Positional arguments to be passed to the model.
- **kwargs – Keyword arguments to be passed to the model.
Returns:
- torch.Tensor – The output of the linearized model, computed using a first-order Taylor expansion.
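The first-order Taylor expansion behind the wrapper is f_lin(x; θ) = f(x; θ0) + ∇θ f(x; θ0) · (θ - θ0), where θ0 are the parameters of init_model. A one-parameter sketch with an analytic derivative shows the idea; for a real network the directional derivative is usually obtained with a Jacobian-vector product such as torch.func.jvp. The toy model sin(θx) and all function names below are illustrative assumptions.

```python
import math


def f(x: float, theta: float) -> float:
    # Toy nonlinear model with a single parameter
    return math.sin(theta * x)


def df_dtheta(x: float, theta: float) -> float:
    # Analytic derivative of f with respect to theta
    return x * math.cos(theta * x)


def f_linearized(x: float, theta: float, theta0: float) -> float:
    # f(x; theta0) + df/dtheta(x; theta0) * (theta - theta0)
    return f(x, theta0) + df_dtheta(x, theta0) * (theta - theta0)


x, theta0 = 0.5, 1.0
for theta in (1.0, 1.01, 1.5):
    exact, approx = f(x, theta), f_linearized(x, theta, theta0)
    print(f"theta={theta}: f={exact:.5f}  linearized={approx:.5f}")
```

The approximation is exact at θ = θ0 and degrades as θ moves away from it, so the linearized model is only faithful near the initialization.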
tuple_params_to_dict(tuple_params)
Converts a tuple of parameters to a dictionary with keys corresponding to the parameter names.
Parameters:
- tuple_params (Tuple[Tensor, ...]) – A tuple of parameters.
Returns:
- Dict[str, Tensor] – A dictionary with keys corresponding to the parameter names and values corresponding to the parameter values.
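Converting the tuple back to a dictionary is a positional pairing with the parameter names. In the sketch the names are passed explicitly because the wrapper's own source of names is not documented here; pair_params_with_names is a hypothetical helper, and floats stand in for tensors.

```python
from typing import Dict, Sequence, Tuple


def pair_params_with_names(names: Sequence[str], tuple_params: Tuple[float, ...]) -> Dict[str, float]:
    """Hypothetical helper: pair parameter names with values by position."""
    if len(names) != len(tuple_params):
        raise ValueError("expected one value per parameter name")
    return dict(zip(names, tuple_params))


print(pair_params_with_names(["linear.weight", "linear.bias"], (0.5, -1.0)))
# {'linear.weight': 0.5, 'linear.bias': -1.0}
```

The length check matters because zip silently truncates, which would otherwise drop parameters without any error.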
unload_linearized_modules_(module) (staticmethod)
Unloads the linearized module and returns the original module.
Parameters:
- module (Module) – The linearized module to be unloaded.
Returns:
- nn.Module – The original module.