Model Merging¶
Linear Interpolation¶
Simple Average¶
SimpleAverageAlgorithm
¶
Bases: BaseAlgorithm
, SimpleProfilerMixin
Source code in fusion_bench/method/simple_average.py
__init__(show_pbar=False)
¶
Parameters:
-
show_pbar
(bool
, default:False
) –If True, shows a progress bar during model loading and merging. Default is False.
run(modelpool)
¶
Fuse the models in the given model pool using simple averaging.
This method iterates over the names of the models in the model pool, loads each model, and appends it to a list. It then returns the simple average of the models in the list.
Parameters:
-
modelpool
(Union[BaseModelPool, Dict[str, Module]]
) –The pool of models to fuse.
Returns:
-
–
The fused model obtained by simple averaging.
Source code in fusion_bench/method/simple_average.py
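The averaging step described above can be sketched in a few lines. This is a minimal illustration, not the library's implementation: plain Python lists stand in for tensors, and the helper name `simple_average` and the `StateDict` alias are hypothetical.

```python
from typing import Dict, List

# Hypothetical stand-in for a PyTorch state dict: parameter name -> flat values.
StateDict = Dict[str, List[float]]


def simple_average(state_dicts: List[StateDict]) -> StateDict:
    """Return the element-wise mean of state dicts that share the same keys."""
    if not state_dicts:
        raise ValueError("need at least one state dict")
    n = len(state_dicts)
    merged: StateDict = {}
    for key in state_dicts[0]:
        # Group the i-th value of this parameter across all models.
        columns = zip(*(sd[key] for sd in state_dicts))
        merged[key] = [sum(col) / n for col in columns]
    return merged


model_a = {"layer.weight": [1.0, 2.0], "layer.bias": [0.0]}
model_b = {"layer.weight": [3.0, 6.0], "layer.bias": [1.0]}
averaged = simple_average([model_a, model_b])
```

Every model contributes equally here; the weighted variants further down differ only in the coefficient applied to each model.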
SimpleAverageForLlama
¶
Bases: BaseAlgorithm
A simple averaging algorithm for Llama models. If merge_backbone
is set to True
, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
Examples:
The following example demonstrates how to use the SimpleAverageForLlama
algorithm to merge Mistral models.
fusion_bench \
method=linear/simple_average_for_llama \
method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
Source code in fusion_bench/method/linear/simple_average_for_llama.py
Weighted Average¶
LinearInterpolationAlgorithm
¶
Bases: BaseAlgorithm
LinearInterpolationAlgorithm
performs linear interpolation between two models.
Returns a model with the state dict that is a linear interpolation of the state dicts of the two models.
\(\theta = (1-t) \theta_1 + t \theta_2\)
Source code in fusion_bench/method/linear/linear_interpolation.py
__init__(t, **kwargs)
¶
Initialize the LinearInterpolationAlgorithm
with the given interpolation parameter.
Parameters:
-
t
(float
) –The interpolation parameter, should be in the range [0, 1].
-
**kwargs
–Additional keyword arguments.
Source code in fusion_bench/method/linear/linear_interpolation.py
run(modelpool)
¶
Run the linear interpolation algorithm on the given model pool.
This method performs linear interpolation between two models in the model pool and returns a model with the interpolated state dict.
Parameters:
-
modelpool
(BaseModelPool
) –The pool of models to interpolate. Must contain exactly two models.
Returns:
-
–
nn.Module: The model with the interpolated state dict.
Source code in fusion_bench/method/linear/linear_interpolation.py
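The formula \(\theta = (1-t) \theta_1 + t \theta_2\) can be illustrated as follows. This is a sketch under simplifying assumptions: state dicts are modelled as dictionaries of flat float lists, and the function name `linear_interpolation` is hypothetical rather than part of the library.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]


def linear_interpolation(sd1: StateDict, sd2: StateDict, t: float) -> StateDict:
    """Compute (1 - t) * sd1 + t * sd2 for every parameter."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t should be in the range [0, 1]")
    if sd1.keys() != sd2.keys():
        raise ValueError("both models must have the same parameter names")
    return {
        key: [(1.0 - t) * a + t * b for a, b in zip(sd1[key], sd2[key])]
        for key in sd1
    }


theta_1 = {"w": [0.0, 4.0]}
theta_2 = {"w": [8.0, 0.0]}
quarter = linear_interpolation(theta_1, theta_2, t=0.25)
```

With `t=0` the result equals the first model and with `t=1` the second, which matches the requirement that the pool contain exactly two models.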
WeightedAverageAlgorithm
¶
Bases: BaseAlgorithm
, SimpleProfilerMixin
Source code in fusion_bench/method/weighted_average/weighted_average.py
run(modelpool)
¶
Fuses the models in the model pool using a weighted average approach.
Parameters:
-
modelpool
(ModelPool
) –The pool of models to be fused.
Raises:
-
ValueError
–If the number of weights does not match the number of models in the model pool.
Returns:
-
forward_model
(torch.nn.Module
) –The resulting model after fusion.
Source code in fusion_bench/method/weighted_average/weighted_average.py
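A weighted average, including the weight-count check that raises `ValueError`, might look like the sketch below. The function name `weighted_average` is hypothetical, lists stand in for tensors, and the `normalize` flag is assumed to rescale the weights so that they sum to one (the flag itself appears in the `WeightedAverageForLLama` parameter list further down).

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]


def weighted_average(
    state_dicts: List[StateDict], weights: List[float], normalize: bool = True
) -> StateDict:
    """Return sum_i weights[i] * state_dicts[i], parameter by parameter."""
    if len(weights) != len(state_dicts):
        raise ValueError("number of weights must match number of models")
    if normalize:
        # Assumption: normalization means dividing by the sum of the weights.
        total = sum(weights)
        weights = [w / total for w in weights]
    merged: StateDict = {}
    for key in state_dicts[0]:
        columns = zip(*(sd[key] for sd in state_dicts))
        merged[key] = [sum(w * v for w, v in zip(weights, col)) for col in columns]
    return merged


pool = [{"w": [1.0, 0.0]}, {"w": [3.0, 4.0]}]
fused = weighted_average(pool, weights=[1.0, 3.0])
```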
WeightedAverageForLLama
¶
Bases: BaseAlgorithm
A class to perform weighted averaging of Llama/Mistral models.
Source code in fusion_bench/method/weighted_average/llama.py
__init__(normalize, weights, backbone_only, merged_model_save_path, save_tokenizer, push_to_hub, **kwargs)
¶
Initialize the WeightedAverageForLLama class with the given parameters.
Parameters:
-
normalize
(bool
) –Whether to normalize the weights.
-
weights
(List[float]
) –The weights for averaging the models.
-
backbone_only
(bool
) –Whether to use only the backbone of the models.
-
merged_model_save_path
(str
) –The path to save the merged model.
-
save_tokenizer
(bool
) –Whether to save the tokenizer.
-
push_to_hub
(bool
) –Whether to push the model to the hub.
Source code in fusion_bench/method/weighted_average/llama.py
run(modelpool)
¶
Executes the weighted averaging of models in the provided model pool.
Parameters:
-
modelpool
(LLamaForCausalLMPool
) –The pool of models to be averaged.
Returns:
-
base_model
–The base model after merging the state dictionaries of the models in the pool.
Raises:
-
ValueError
–If the number of weights does not match the number of models in the pool.
Source code in fusion_bench/method/weighted_average/llama.py
Spherical Linear Interpolation (Slerp)¶
SlerpMergeAlgorithm
¶
Bases: BaseAlgorithm
General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
Source code in fusion_bench/method/slerp/slerp.py
__init__(t, DOT_THRESHOLD=0.9995, epsilon=1e-08)
¶
Initialize the SlerpMergeAlgorithm.
Parameters:
-
t
(float
) –The interpolation parameter. Must be in the range [0, 1].
-
DOT_THRESHOLD
(float
, default:0.9995
) –The threshold for the dot product of the two vectors. Defaults to 0.9995.
-
epsilon
(float
, default:1e-08
) –The epsilon value for numerical stability. Defaults to 1e-8.
Source code in fusion_bench/method/slerp/slerp.py
run(modelpool)
¶
Run the SlerpMergeAlgorithm on the given model pool.
Parameters:
-
modelpool
(BaseModelPool
) –The pool of models to fuse.
Returns:
-
–
nn.Module: The fused model.
Source code in fusion_bench/method/slerp/slerp.py
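To make the roles of `t`, `DOT_THRESHOLD`, and `epsilon` concrete, here is a sketch of spherical linear interpolation on two flat vectors. It is not the library's code: the function name `slerp` is hypothetical, and falling back to plain linear interpolation when the vectors are nearly colinear is an assumption about how the threshold is used.

```python
import math
from typing import List


def slerp(
    t: float,
    v0: List[float],
    v1: List[float],
    dot_threshold: float = 0.9995,
    epsilon: float = 1e-8,
) -> List[float]:
    """Interpolate along the arc between v0 and v1."""
    norm0 = math.sqrt(sum(x * x for x in v0))
    norm1 = math.sqrt(sum(x * x for x in v1))
    # epsilon guards against division by zero for all-zero vectors.
    u0 = [x / (norm0 + epsilon) for x in v0]
    u1 = [x / (norm1 + epsilon) for x in v1]
    dot = sum(a * b for a, b in zip(u0, u1))
    if abs(dot) > dot_threshold:
        # Nearly colinear: the angle is too small for a stable division
        # by sin(theta), so use ordinary linear interpolation instead.
        return [(1.0 - t) * a + t * b for a, b in zip(v0, v1)]
    theta = math.acos(dot)
    sin_theta = math.sin(theta)
    s0 = math.sin((1.0 - t) * theta) / sin_theta
    s1 = math.sin(t * theta) / sin_theta
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]


halfway = slerp(0.5, [1.0, 0.0], [0.0, 1.0])
```

For orthogonal unit vectors the midpoint keeps unit length, whereas a linear midpoint would shrink to a norm of about 0.71.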
Task Arithmetic¶
TaskArithmeticForLlama
¶
Bases: TaskArithmeticAlgorithm
, SimpleProfilerMixin
Examples:
fusion_bench \
method=linear/task_arithmetic_for_llama \
method.scaling_factor=0.3 \
method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
Source code in fusion_bench/method/linear/task_arithmetic_for_llama.py
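The effect of `scaling_factor` can be seen in a small sketch of task arithmetic, assuming the usual formulation in which each task vector is the difference between a fine-tuned model and the pretrained model, and the scaled sum of task vectors is added back to the pretrained weights. The function name `task_arithmetic` is hypothetical and lists stand in for tensors.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]


def task_arithmetic(
    pretrained: StateDict, finetuned: List[StateDict], scaling_factor: float
) -> StateDict:
    """Return pretrained + scaling_factor * sum_i (finetuned_i - pretrained)."""
    merged: StateDict = {}
    for key, base in pretrained.items():
        # Sum of task vectors for this parameter, element by element.
        summed = [
            sum(ft[key][i] - base[i] for ft in finetuned) for i in range(len(base))
        ]
        merged[key] = [b + scaling_factor * s for b, s in zip(base, summed)]
    return merged


base = {"w": [1.0, 1.0]}
experts = [{"w": [2.0, 1.0]}, {"w": [1.0, 3.0]}]
merged = task_arithmetic(base, experts, scaling_factor=0.5)
```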
Ties-Merging¶
TiesMergingAlgorithm
¶
Bases: BaseAlgorithm
, SimpleProfilerMixin
TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
Attributes:
-
scaling_factor
(float
) –The scaling factor to apply to the merged task vector.
-
threshold
(float
) –The threshold for resetting values in the task vector.
-
remove_keys
(List[str]
) –List of keys to remove from the state dictionary.
-
merge_func
(Literal['sum', 'mean', 'max']
) –The merge function to use for disjoint merging.
Source code in fusion_bench/method/ties_merging/ties_merging.py
__init__(scaling_factor, threshold, remove_keys, merge_func, **kwargs)
¶
Initialize the TiesMergingAlgorithm with the given parameters.
Parameters:
-
scaling_factor
(float
) –The scaling factor to apply to the merged task vector.
-
threshold
(float
) –The threshold for resetting values in the task vector.
-
remove_keys
(List[str]
) –List of keys to remove from the state dictionary.
-
merge_func
(Literal['sum', 'mean', 'max']
) –The merge function to use for disjoint merging.
-
**kwargs
–Additional keyword arguments for the base class.
Source code in fusion_bench/method/ties_merging/ties_merging.py
run(modelpool, **kwargs)
¶
Run the TIES merging algorithm to fuse models in the model pool.
Parameters:
-
modelpool
(BaseModelPool | Dict[str, Module]
) –The model pool containing the models to fuse.
Returns:
-
–
nn.Module: The fused model.
Source code in fusion_bench/method/ties_merging/ties_merging.py
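The parameters above map onto the three TIES steps (trim, elect sign, disjoint merge). The sketch below works on flat vectors and makes two assumptions that are not confirmed by this page: `keep_fraction` plays the role of `threshold` as the fraction of largest-magnitude entries kept per task vector, and the elected sign is the sign of the summed trimmed values. All function names are hypothetical.

```python
from typing import List


def trim(vector: List[float], keep_fraction: float) -> List[float]:
    """Zero out all but the largest-magnitude entries of a task vector."""
    k = max(1, round(len(vector) * keep_fraction))
    cutoff = sorted((abs(x) for x in vector), reverse=True)[k - 1]
    return [x if abs(x) >= cutoff else 0.0 for x in vector]


def ties_merge(
    pretrained: List[float],
    finetuned: List[List[float]],
    scaling_factor: float,
    keep_fraction: float,
    merge_func: str = "mean",
) -> List[float]:
    task_vectors = [
        trim([f - p for f, p in zip(ft, pretrained)], keep_fraction)
        for ft in finetuned
    ]
    merged = []
    for i, base in enumerate(pretrained):
        column = [tv[i] for tv in task_vectors]
        # Elect the sign with the larger total mass for this entry.
        sign = 1.0 if sum(column) >= 0 else -1.0
        # Disjoint merge: only entries agreeing with the elected sign count.
        agreeing = [x for x in column if x * sign > 0]
        if not agreeing:
            delta = 0.0
        elif merge_func == "sum":
            delta = sum(agreeing)
        elif merge_func == "mean":
            delta = sum(agreeing) / len(agreeing)
        elif merge_func == "max":
            delta = sign * max(abs(x) for x in agreeing)
        else:
            raise ValueError(f"unknown merge_func: {merge_func}")
        merged.append(base + scaling_factor * delta)
    return merged


fused = ties_merge(
    pretrained=[0.0, 0.0, 0.0, 0.0],
    finetuned=[[4.0, -1.0, 0.5, 2.0], [-1.0, -3.0, 0.2, 2.0]],
    scaling_factor=1.0,
    keep_fraction=0.5,
)
```

In this toy case the conflicting first entry (4.0 versus -1.0) resolves to 4.0, because the smaller value is trimmed before the sign election.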
Fisher Merging¶
FisherMergingForCLIPVisionModel
¶
Bases: CLIPClassificationMixin
, FisherMergingAlgorithm
Implements Fisher Merging for CLIP Vision Models.
This class extends the FisherMergingAlgorithm and CLIPClassificationMixin to handle the specifics of merging CLIP Vision models using Fisher weights.
Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
__init__(*, exclude_param_names_regex, normalize_fisher_weight, minimal_fisher_weight, num_fisher_examples, dataloader_kwargs, zeroshot_weights_cache_dir=None, **kwargs)
¶
Initialize the FisherMergingForCLIPVisionModel with the given configuration.
Parameters:
-
exclude_param_names_regex
(list
) –List of regex patterns to exclude certain parameter names.
-
normalize_fisher_weight
(bool
) –Whether to normalize Fisher weights.
-
minimal_fisher_weight
(float
) –Minimal value for Fisher weights to avoid numerical issues.
-
num_fisher_examples
(int
) –Number of examples to compute Fisher weights.
-
dataloader_kwargs
(DictConfig
) –Configuration for the dataloader.
-
zeroshot_weights_cache_dir
(str
, default:None
) –Directory to cache zero-shot weights. Defaults to None.
-
**kwargs
–Additional keyword arguments.
Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
compute_logits(module, batch, task)
¶
Compute the logits for the given images and task.
Parameters:
-
module
(Module
) –The model module.
-
batch
(tuple
) –A batch of data containing images and labels.
-
task
(str
) –The name of the task.
Returns:
-
Tensor
(Tensor
) –The computed logits.
Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
get_fisher_weights(model_name, model, train_dataset, param_names_to_merge)
¶
Compute the Fisher weights for the given model and training dataset.
Parameters:
-
model_name
(str
) –The name of the model.
-
model
(Module
) –The model module.
-
train_dataset
–The training dataset.
-
param_names_to_merge
(List[str]
) –List of parameter names to merge.
Returns:
-
Dict[str, Tensor]
–Dict[str, Tensor]: The computed Fisher weights for each parameter.
Source code in fusion_bench/method/fisher_merging/clip_fisher_merging.py
on_fisher_merging_start()
¶
Set up the zero-shot classification head before starting the Fisher merging process.
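The sketch below illustrates the idea behind Fisher-weighted merging: each parameter entry is averaged across models with a weight proportional to its Fisher information. It assumes a diagonal Fisher estimate computed as the mean squared per-example gradient, and it uses `minimal_fisher_weight` as a lower bound on the weights. The function names are hypothetical and lists stand in for tensors.

```python
from typing import List


def diagonal_fisher(per_example_grads: List[List[float]]) -> List[float]:
    """Estimate diagonal Fisher weights as the mean of squared gradients."""
    n = len(per_example_grads)
    return [sum(g * g for g in column) / n for column in zip(*per_example_grads)]


def fisher_merge(
    params: List[List[float]],
    fisher_weights: List[List[float]],
    minimal_fisher_weight: float = 1e-6,
) -> List[float]:
    """Average each entry across models, weighted by its Fisher weight."""
    merged = []
    for values, weights in zip(zip(*params), zip(*fisher_weights)):
        # Flooring the weights avoids dividing by zero when every model
        # has (near-)zero Fisher information for an entry.
        floored = [max(w, minimal_fisher_weight) for w in weights]
        merged.append(sum(w * v for w, v in zip(floored, values)) / sum(floored))
    return merged


fisher = diagonal_fisher([[1.0, 2.0], [3.0, 0.0]])
merged = fisher_merge(
    params=[[1.0, 0.0], [3.0, 2.0]],
    fisher_weights=[[1.0, 3.0], [1.0, 1.0]],
)
```

When all Fisher weights are equal, this reduces to a simple average.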
Drop And REscale (DARE)¶
DareSimpleAverage
¶
Bases: BaseAlgorithm
Source code in fusion_bench/method/dare/simple_average.py
DareTaskArithmetic
¶
Bases: BaseAlgorithm
Implementation of Task Arithmetic with DARE.
- Yu et al. Language Models are Super Mario: Absorbing Abilities from Homologous Models as a Free Lunch. 2023. http://arxiv.org/abs/2311.03099
Source code in fusion_bench/method/dare/task_arithmetic.py
DareTiesMerging
¶
Bases: BaseAlgorithm
Source code in fusion_bench/method/dare/ties_merging.py
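The "drop and rescale" operation that these three classes share can be sketched as follows, based on the description in Yu et al.: each entry of a task vector is dropped at random with probability `drop_rate`, and the survivors are rescaled by `1 / (1 - drop_rate)` so that the expected value is unchanged. The function name `dare` and its signature are hypothetical.

```python
import random
from typing import List


def dare(delta: List[float], drop_rate: float, rng: random.Random) -> List[float]:
    """Randomly zero entries of a task vector and rescale the rest."""
    if not 0.0 <= drop_rate < 1.0:
        raise ValueError("drop_rate must be in [0, 1)")
    keep_scale = 1.0 / (1.0 - drop_rate)
    return [0.0 if rng.random() < drop_rate else d * keep_scale for d in delta]


rng = random.Random(0)  # seeded so the sketch is reproducible
sparse_delta = dare([1.0] * 10000, drop_rate=0.9, rng=rng)
mean_after = sum(sparse_delta) / len(sparse_delta)  # stays close to 1.0
```

The sparsified task vectors can then be combined by averaging, task arithmetic, or TIES merging, which corresponds to the three class names above.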
Model Extrapolation (ExPO)¶
ExPOAlgorithm
¶
Bases: BaseAlgorithm
ExPO merge algorithm.
This algorithm merges a pretrained model with a fine-tuned model.
In the notation used here, \(\theta_{merged}\) is the merged model, \(\theta_{rlhf}\) is the fine-tuned model (medium-aligned model), \(\theta_{sft}\) is the pretrained model (base model), and \(\alpha\) is the extrapolation factor.
In the configuration, the SFT model must be named _pretrained_
; the RLHF model can be named arbitrarily.
Source code in fusion_bench/method/linear/expo.py
run(modelpool)
¶
Run the ExPO merge algorithm.
Parameters:
-
modelpool
(BaseModelPool
) –The pool of models to merge.
Returns:
-
–
nn.Module: The merged model.
Source code in fusion_bench/method/linear/expo.py
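A sketch of the extrapolation step follows. It assumes the rule from the ExPO paper, \(\theta_{merged} = \theta_{rlhf} + \alpha (\theta_{rlhf} - \theta_{sft})\); the exact form used by this class is not shown on this page, so treat the formula as an assumption. The function name `expo_merge` is hypothetical and lists stand in for tensors.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]


def expo_merge(sft: StateDict, rlhf: StateDict, alpha: float) -> StateDict:
    """Move past the RLHF model along the direction (rlhf - sft)."""
    return {
        key: [r + alpha * (r - s) for r, s in zip(rlhf[key], sft[key])]
        for key in rlhf
    }


# The SFT model is stored under the name "_pretrained_", as required above;
# the key "aligned" is an arbitrary example name.
pool = {"_pretrained_": {"w": [1.0, 2.0]}, "aligned": {"w": [2.0, 2.0]}}
extrapolated = expo_merge(pool["_pretrained_"], pool["aligned"], alpha=0.5)
```

With `alpha=0` the result is the RLHF model itself; larger values move further away from the SFT model.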
ExPOAlgorithmForLlama
¶
Bases: BaseAlgorithm
Source code in fusion_bench/method/linear/llama_expo.py
DOGE¶
DOGE_TA_Algorithm
¶
Bases: BaseAlgorithm
, SimpleProfilerMixin
, LightningFabricMixin
Task Arithmetic Algorithm for model fusion with learnable delta.
This class extends the Task Arithmetic method to include a learnable delta for task vectors, optimized to maximize cosine similarity among the task vectors.
Attributes:
-
scaling_factor
(int
) –The factor by which the task vectors will be scaled before merging.
-
delta
(StateDictType
) –A learnable parameter to adjust task vectors, initialized as zeros.
Source code in fusion_bench/method/doge_ta/doge_ta.py
compute_task_vectors(modelpool, pretrained_model)
¶
Computes task vectors for each model in the model pool relative to the pretrained model.
Source code in fusion_bench/method/doge_ta/doge_ta.py
optimize_delta(task_vectors)
¶
Optimizes the delta based on the loss of task vectors.
Source code in fusion_bench/method/doge_ta/doge_ta.py
run(modelpool)
¶
Runs the algorithm with a learnable delta to fuse the models in the given model pool.
Parameters:
-
modelpool
(Union[BaseModelPool, Dict[str, Module]]
) –The pool of models to fuse.
Returns:
-
–
nn.Module: The pre-trained model with the merged task vectors after optimizing delta.
Source code in fusion_bench/method/doge_ta/doge_ta.py
state_dict_to_vector(state_dict, remove_keys=[])
¶
Convert a state dictionary to a vector, removing specified keys.
Parameters:
-
state_dict
(dict
) –The state dictionary to convert.
-
remove_keys
(list
, default:[]
) –List of keys to remove from the state dictionary.
Returns:
-
Tensor
–A vector representation of the state dictionary.
Source code in fusion_bench/method/doge_ta/doge_ta.py
taskvector_loss(layer_vectors, layer_delta, layer_lamdas)
¶
Computes the loss based on delta and task vectors for a specific layer.
Source code in fusion_bench/method/doge_ta/doge_ta.py
vector_to_state_dict(vector, state_dict, remove_keys=[])
¶
Convert a vector back to a state dictionary, removing specified keys.
Parameters:
-
vector
(Tensor
) –The vector to convert.
-
state_dict
(dict
) –The reference state dictionary.
-
remove_keys
(list
, default:[]
) –List of keys to remove from the state dictionary.
Returns:
-
dict
–A state dictionary representation of the vector.
Source code in fusion_bench/method/doge_ta/doge_ta.py
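The two conversion helpers are inverses of each other, which the following sketch illustrates. It is a simplified stand-in rather than the library's code: parameters are flat float lists, and iterating over keys in sorted order is an assumption made here so that both directions agree on the layout.

```python
from typing import Dict, List, Sequence

StateDict = Dict[str, List[float]]


def state_dict_to_vector(
    state_dict: StateDict, remove_keys: Sequence[str] = ()
) -> List[float]:
    """Concatenate all parameters, except removed keys, into one flat list."""
    kept = {k: v for k, v in state_dict.items() if k not in remove_keys}
    return [x for key in sorted(kept) for x in kept[key]]


def vector_to_state_dict(
    vector: List[float], state_dict: StateDict, remove_keys: Sequence[str] = ()
) -> StateDict:
    """Split a flat list back into parameters shaped like the reference."""
    rebuilt: StateDict = {}
    offset = 0
    for key in sorted(k for k in state_dict if k not in remove_keys):
        size = len(state_dict[key])
        rebuilt[key] = list(vector[offset : offset + size])
        offset += size
    if offset != len(vector):
        raise ValueError("vector length does not match the reference state dict")
    return rebuilt


reference = {"b.weight": [3.0, 4.0], "a.bias": [1.0], "head.weight": [9.0]}
flat = state_dict_to_vector(reference, remove_keys=["head.weight"])
restored = vector_to_state_dict(flat, reference, remove_keys=["head.weight"])
```

The sketch uses an immutable tuple as the default for `remove_keys`, which avoids the shared-state pitfall of a mutable `[]` default in Python.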
AdaMerging¶
CLIPTaskWiseAdaMergingAlgorithm
¶
Bases: TaskWiseAdaMergingAlgorithm
A class for task-wise adaptive merging of CLIP models.
This class extends the TaskWiseAdaMergingAlgorithm to provide specific functionality for CLIP models, including loading datasets, constructing zero-shot classification heads, and computing logits.
Attributes:
-
modelpool
(CLIPVisionModelPool
) –The model pool containing CLIP models.
-
_clip_processor
(CLIPProcessor
) –The CLIP processor for preparing inputs.
-
zeroshot_weights
(dict
) –A dictionary to store zero-shot weights for each task.
Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
compute_logits(module, batch, task)
¶
Compute the logits for the given batch and task.
This method computes the image embeddings, normalizes them, and calculates the cosine similarity with the text embeddings to produce classification logits.
Parameters:
-
module
(Module
) –The model module.
-
batch
(tuple
) –A batch of input data.
-
task
(str
) –The name of the task.
Returns:
-
Tensor
(Tensor
) –The classification logits for the batch.
Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
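The logit computation described above (normalize the image embeddings, then take the cosine similarity with the text embeddings) can be sketched with plain lists. The function names and the `logit_scale` parameter are hypothetical; CLIP models typically multiply the similarities by a learned scale, which is represented here by a simple argument.

```python
import math
from typing import List


def normalize(vector: List[float]) -> List[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def cosine_logits(
    image_embeds: List[List[float]],
    text_embeds: List[List[float]],
    logit_scale: float = 1.0,
) -> List[List[float]]:
    """Return one row of class logits per image."""
    text_unit = [normalize(t) for t in text_embeds]
    logits = []
    for image in image_embeds:
        unit = normalize(image)
        logits.append(
            [logit_scale * sum(a * b for a, b in zip(unit, t)) for t in text_unit]
        )
    return logits


# One image embedding scored against two class (text) embeddings.
logits = cosine_logits([[3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]])
```

In the zero-shot setting, each text embedding acts as the classifier weight for one class, which is why the text side is referred to as the zero-shot classification head.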
get_shuffled_test_loader_iter(task)
cached
¶
Get an iterator over the shuffled test DataLoader for the task.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
iterator
–An iterator over the shuffled test DataLoader.
Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
get_test_dataset(task)
cached
¶
Load the test dataset for the task. This method is cached, so the dataset is loaded only once.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
CLIPDataset
–The test dataset for the task.
Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
on_test_time_adaptation_start()
¶
Prepare for test-time adaptation.
This method loads the CLIP processor and constructs the zero-shot classification head for each task.
Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
CLIPLayerWiseAdaMergingAlgorithm
¶
Bases: CLIPClassificationMixin
, LayerWiseAdaMergingAlgorithm
Source code in fusion_bench/method/adamerging/clip_layer_wise_adamerging.py
on_test_time_adaptation_start()
¶
Here we load the CLIP processor and construct the zero-shot classification head for each task.
GPT2LayerWiseAdaMergingAlgorithm
¶
Bases: BaseAlgorithm
, LightningFabricMixin
, SimpleProfilerMixin
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
compute_logits(module, batch, task)
¶
Compute the logits for the given batch and task.
Parameters:
-
module
(GPT2Model
) –The model module.
-
batch
–The batch of input data.
-
task
(str
) –The name of the task.
Returns:
-
Tensor
(Tensor
) –The computed logits.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
construct_layer_wise_merged_model(modelpool)
¶
Constructs a wrapped layer-wise merged model from model pool.
This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
The merging is controlled by layer-wise weights, a torch.Tensor of shape (num_models, num_layers).
The merging weights can be initialized based on a provided configuration or loaded from a file.
Parameters:
-
modelpool
(ModelPool
) –An object containing the pretrained model and fine-tuned models to be merged.
Returns:
-
LayerWiseMergedModel
–An instance of the merged model with layer-wise weights applied.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
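To show how a `(num_models, num_layers)` weight matrix controls the merge, the sketch below assumes the AdaMerging formulation in which each layer of the merged model is the pretrained layer plus a weighted sum of the per-model task vectors for that layer. The function name `layer_wise_merge` is hypothetical, and each model is represented as a list of layers holding flat float lists.

```python
from typing import List

Model = List[List[float]]  # one flat list of values per layer


def layer_wise_merge(
    pretrained: Model, finetuned: List[Model], weights: List[List[float]]
) -> Model:
    """Merge with weights[i][l] scaling model i's task vector at layer l."""
    num_models, num_layers = len(finetuned), len(pretrained)
    if len(weights) != num_models or any(len(row) != num_layers for row in weights):
        raise ValueError("weights must have shape (num_models, num_layers)")
    merged: Model = []
    for layer_idx, base in enumerate(pretrained):
        layer = list(base)
        for model_idx, model in enumerate(finetuned):
            for j, value in enumerate(model[layer_idx]):
                layer[j] += weights[model_idx][layer_idx] * (value - base[j])
        merged.append(layer)
    return merged


pretrained = [[0.0, 0.0], [1.0]]
finetuned = [[[2.0, 0.0], [1.0]], [[0.0, 4.0], [3.0]]]
weights = [[0.5, 1.0], [0.25, 0.5]]  # rows: models, columns: layers
merged = layer_wise_merge(pretrained, finetuned, weights)
```

During test-time adaptation, these weights are the quantities being optimized, while the underlying model parameters stay fixed.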
get_shuffled_test_loader_iter(task)
cached
¶
Get the loader of the test dataset for test-time adaptation. Labels are not needed.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
DataLoader
(DataLoader
) –The data loader for the test dataset.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
on_test_time_adaptation_start()
¶
Hook that runs before test-time adaptation starts, for example to set up the task-specific heads.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
run(modelpool, **kwargs)
¶
Run the Layer-Wise AdaMerging Algorithm.
This method constructs the wrapped model and performs test-time adaptation if necessary.
Parameters:
-
modelpool
(ModelPool
) –The model pool containing the pretrained and fine-tuned models.
Returns:
-
LayerWiseMergedModel
–The merged model after test-time adaptation.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
save_merging_weights(file_path, merging_weights)
¶
Save the merging weights to a file.
Parameters:
-
file_path
(str
) –The path to save the merging weights.
-
merging_weights
(Tensor
) –The merging weights to save.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
test_time_adaptation(module)
¶
Perform test-time adaptation on the merged model.
This method adapts the merging weights during test-time to improve performance.
Parameters:
-
module
(LayerWiseMergedModel
) –The merged model.
Returns:
-
LayerWiseMergedModel
–The adapted merged model.
Source code in fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py
FlanT5LayerWiseAdaMergingAlgorithm
¶
Bases: BaseAlgorithm
, LightningFabricMixin
, SimpleProfilerMixin
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
compute_logits(module, batch, task)
¶
Compute the logits for the given batch and task.
Parameters:
-
module
(Union[T5ForConditionalGeneration, LayerWiseMergedModel]
) –The model module.
-
batch
–The batch of input data.
-
task
(str
) –The name of the task.
Returns:
-
Tensor
(Tensor
) –The computed logits.
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
construct_layer_wise_merged_model(modelpool)
¶
Constructs a wrapped layer-wise merged model from model pool.
This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
The merging is controlled by layer-wise weights, a torch.Tensor of shape (num_models, num_layers).
The merging weights can be initialized based on a provided configuration or loaded from a file.
Parameters:
-
modelpool
(ModelPool
) –An object containing the pretrained model and fine-tuned models to be merged.
Returns:
-
LayerWiseMergedModel
–An instance of the merged model with layer-wise weights applied.
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
get_shuffled_test_loader_iter(task)
cached
¶
Get the loader of the test dataset for test-time adaptation. Labels are not needed.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
DataLoader
(DataLoader
) –The data loader for the test dataset.
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
on_test_time_adaptation_start()
¶
Hook that runs before test-time adaptation starts, for example to set up the task-specific heads.
run(modelpool, **kwargs)
¶
Run the Layer-Wise AdaMerging Algorithm.
This method constructs the wrapped model and performs test-time adaptation if necessary.
Parameters:
-
modelpool
(ModelPool
) –The model pool containing the pretrained and fine-tuned models.
Returns:
-
LayerWiseMergedModel
–The merged model after test-time adaptation.
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
save_merging_weights(file_path, merging_weights)
¶
Save the merging weights to a file.
Parameters:
-
file_path
(str
) –The path to save the merging weights.
-
merging_weights
(Tensor
) –The merging weights to save.
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
test_time_adaptation(module)
¶
Perform test-time adaptation on the merged model.
This method adapts the merging weights during test-time to improve performance.
Parameters:
-
module
(LayerWiseMergedModel
) –The merged model.
Returns:
-
LayerWiseMergedModel
–The adapted merged model.
Source code in fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py
Optimization-based Methods¶
RegMean¶
RegMeanAlgorithmForCLIP
¶
Bases: RegMeanAlgorithm
, CLIPClassificationMixin
Source code in fusion_bench/method/regmean/clip_regmean.py
RegMeanAlgorithmForGPT2
¶
Bases: RegMeanAlgorithm
, LightningFabricMixin
Source code in fusion_bench/method/regmean/gpt2_regmean.py
RegMean++¶
RegMeanAlgorithmForCLIPPlusPlus
¶
Bases: RegMeanAlgorithmPlusPlus
, CLIPClassificationMixin
Source code in fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py
Frank-Wolfe Merging¶
FrankWolfeSoftAlgorithm
¶
Bases: CLIPClassificationMixin
, ModelFusionAlgorithm
, SimpleProfilerMixin
Source code in fusion_bench/method/fw_merging/fw_soft.py
__init__(max_iters, dataset_size, ada_iters, ada_coeff, merge_fn, granularity='task', max_num_models=100, step_size=0.3, tasks=[], init_weight='', ada_loss='entropy_loss', **kwargs)
¶
Initializes the FrankWolfeSoftAlgorithm.
Parameters:
-
step_size
(float
, default:0.3
) –The factor by which the task vectors will be scaled before merging.
Source code in fusion_bench/method/fw_merging/fw_soft.py
FrankWolfeHardAlgorithm
¶
Bases: CLIPClassificationMixin
, ModelFusionAlgorithm
, SimpleProfilerMixin
Source code in fusion_bench/method/fw_merging/fw_hard.py
__init__(merge_fn, step_size, max_iters, dataset_size, tasks=[], granularity='task', max_num_models=100, loss_fn='cross_entropy', init_weight='', scaling_factor=1.0, threshold=20, **kwargs)
¶
Initializes the FrankWolfeHardAlgorithm with the given scaling factor.
Parameters:
-
scaling_factor
(float
, default:1.0
) –The factor by which the task vectors will be scaled before merging.
Source code in fusion_bench/method/fw_merging/fw_hard.py
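To make the role of scaling_factor concrete, here is a minimal sketch of scaling task vectors before merging. It is not the Frank-Wolfe procedure in fw_hard.py; it only illustrates the parameter description above, and the names `StateDict` and `scale_and_merge`, as well as the use of plain Python lists in place of tensors, are assumptions made for illustration.

```python
from typing import Dict, List

# Hypothetical stand-in for a real state dict: parameter name -> flat list of values.
StateDict = Dict[str, List[float]]


def scale_and_merge(
    pretrained: StateDict,
    finetuned: List[StateDict],
    scaling_factor: float = 1.0,
) -> StateDict:
    """Add the scaled sum of task vectors to the pretrained parameters."""
    merged: StateDict = {}
    for name, base in pretrained.items():
        # A task vector is (fine-tuned - pretrained); sum it over all models.
        summed = [
            sum(model[name][i] - base[i] for model in finetuned)
            for i in range(len(base))
        ]
        # Scale the summed task vector, then add it back to the base weights.
        merged[name] = [b + scaling_factor * s for b, s in zip(base, summed)]
    return merged
```

With scaling_factor=1.0 the task vectors are added unchanged; smaller values move the merged model less far from the pretrained weights.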
Subspace-based Methods¶
Concrete Subspace¶
ConcreteTaskArithmeticAlgorithmForCLIP
¶
Bases: CLIPClassificationMixin
, SimpleProfilerMixin
, ModelFusionAlgorithm
ConcreteTaskArithmeticAlgorithmForCLIP is a class for performing task arithmetic on CLIP models with learned masking.
This class extends the CLIPClassificationMixin, SimpleProfilerMixin, and ModelFusionAlgorithm classes. It provides methods for setting up models, training masks, and running the task arithmetic algorithm.
Attributes:
-
merge_dtype
(dtype
) –The data type for merging weights.
-
modelpool
(HuggingFaceClipVisionPool
) –The model pool containing the pretrained and fine-tuned models.
Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
run(modelpool)
¶
Run the Concrete Task Arithmetic algorithm.
This method sets up the models, trains the mask model if necessary, and performs the final merging of weights.
Parameters:
-
modelpool
(HuggingFaceClipVisionPool
) –The model pool containing the pretrained and fine-tuned models.
Returns:
-
–
torch.nn.Module: The final merged model.
Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
setup_models()
¶
Set up the pretrained model, fine-tuned models, and mask model.
This method loads the pretrained model, constructs the PGE mask model, and loads the fine-tuned models. It also creates a wrapped model with task-wise weights.
Returns:
-
–
Tuple[TaskWiseMergedModel, MaskModel]: The wrapped model and mask model.
Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
train_mask(module, mask_model)
¶
Train the mask model using the provided module.
This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.
Parameters:
-
module
(TaskWiseMergedModel
) –The wrapped model with task-wise weights.
-
mask_model
(MaskModel
) –The mask model to be trained.
Source code in fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py
ConcreteTaskWiseAdaMergingForCLIP
¶
Bases: CLIPClassificationMixin
, SimpleProfilerMixin
, ModelFusionAlgorithm
Source code in fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py
ConcreteLayerWiseAdaMergingForCLIP
¶
Bases: CLIPClassificationMixin
, SimpleProfilerMixin
, ModelFusionAlgorithm
Source code in fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py
Task Singular Vector Merging (TSVM)¶
TaskSingularVectorMerging
¶
Bases: BaseAlgorithm
, LightningFabricMixin
Task Singular Vector Merging (TSVM) Algorithm
This class implements a model merging technique that leverages Singular Value Decomposition (SVD) to identify and combine the most important directions in the task vector space. The algorithm is particularly effective for merging multiple models fine-tuned on different tasks while preserving their essential capabilities.
Key Concepts:
- Task Vector: The difference between a fine-tuned model and its pretrained base model, representing the knowledge gained during fine-tuning for a specific task.
- Singular Value Decomposition: A matrix factorization technique used to find the principal components (most important directions) in the space of task vectors.
- Model Merging: The process of combining multiple models into a single unified model that retains capabilities from all constituent models.
Algorithm Steps:
1. Extract task vectors from all fine-tuned models by subtracting the pretrained model
2. Apply SVD to the matrix of task vectors to find principal components
3. Reconstruct task vectors using only the most significant singular vectors
4. Merge the reconstructed task vectors (either individually scaled or as a sum)
5. Add the final merged task vector to the pretrained model to create the unified model
See docs/algorithms/task_singular_vector.md for comprehensive algorithmic details.
Source code in fusion_bench/method/task_singular_vector/TSVM.py
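The algorithm steps listed for this class can be sketched for a single weight matrix as follows. This is a simplified illustration under stated assumptions, not the implementation in TSVM.py: it truncates each task vector independently with `numpy.linalg.svd` and sums the results, and the helper names `low_rank` and `tsvm_sketch` and the `rank` argument are hypothetical.

```python
import numpy as np


def low_rank(matrix: np.ndarray, rank: int) -> np.ndarray:
    """Reconstruct a matrix from its `rank` largest singular directions."""
    u, s, vt = np.linalg.svd(matrix, full_matrices=False)
    return (u[:, :rank] * s[:rank]) @ vt[:rank, :]


def tsvm_sketch(pretrained, finetuned, rank, alpha=1.0):
    """Simplified per-matrix walk through the five listed steps."""
    # Step 1: task vectors are fine-tuned weights minus pretrained weights.
    task_vectors = [weights - pretrained for weights in finetuned]
    # Steps 2-3: SVD each task vector and keep the most significant directions.
    reconstructed = [low_rank(tv, rank) for tv in task_vectors]
    # Step 4: merge the reconstructed task vectors as a scaled sum.
    merged_task_vector = alpha * sum(reconstructed)
    # Step 5: add the merged task vector to the pretrained weights.
    return pretrained + merged_task_vector
```

When every task vector already has rank at most `rank`, the truncation is lossless and the sketch reduces to plain task-vector addition.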
__init__(alpha=None, exclude_keys=None, return_single_task_models=False, **kwargs)
¶
Initialize the Task Singular Vector Merging algorithm.
Parameters:
-
alpha
(Union[float, Iterable[float]]
, default:None
) –Scaling factor(s) for task vectors. This parameter controls the strength of the task-specific adaptations in the final model.
-
If a single float: Applied to the final merged task vector after SVD reconstruction. This uniformly scales the entire merged adaptation.
-
If an iterable of floats: Applied to individual task vectors before SVD and merging. Must have the same length as the number of models in the modelpool. This allows for task-specific weighting (e.g., giving more importance to certain tasks).
-
If None: No scaling is applied (equivalent to alpha=1.0).
Example: alpha=[0.8, 1.2, 0.5] would apply different weights to three different task vectors.
-
-
exclude_keys
(Optional[List[str]]
, default:None
) –List of parameter names to exclude from TSVM. These parameters will not participate in the SVD computation and merging process. Useful for excluding certain layers (e.g., task-specific heads, normalization layers) that should not be merged across tasks. Defaults to an empty list.
Example: exclude_keys=['classifier.weight', 'classifier.bias'] to skip classification heads.
-
return_single_task_models
(bool
, default:False
) –Whether to return individual transformed models.
- If True: Returns a dictionary containing both individual models with their transformed task vectors applied AND the final merged model. The dictionary has the structure:
{'model_name_1': transformed_model_1, ..., 'merged': final_merged_model}
- If False: Returns only the final merged model.
This is useful for analysis or when you need access to intermediate results. Defaults to False.
-
**kwargs
–Additional arguments passed to the parent BaseAlgorithm class.
Note
The choice between single alpha vs. list of alphas affects the merging strategy:
- Single alpha: SVD is applied first, then the result is scaled
- List of alphas: Individual task vectors are scaled first, then SVD is applied
Source code in fusion_bench/method/task_singular_vector/TSVM.py
load_pretrained_model_and_task_vectors(modelpool)
¶
Load the pretrained base model and compute task vectors from all fine-tuned models.
This method performs the initial step of the TSVM algorithm by:
1. Loading the original pretrained model (before any task-specific fine-tuning)
2. For each fine-tuned model in the pool:
   - Load the fine-tuned model
   - Compute the task vector (fine-tuned params - pretrained params)
   - Optionally apply individual scaling if alpha is provided as a list
Task vectors represent the knowledge gained during fine-tuning and are the core data structure that TSVM operates on.
Parameters:
-
modelpool
(BaseModelPool
) –Pool containing the pretrained model and all fine-tuned models to be merged.
Returns:
-
tuple
–A tuple containing:
- pretrained_model: The original pretrained model (torch.nn.Module)
- task_vectors: List of task vectors (List[StateDictType]), where each task vector is a state dictionary representing the parameter differences for one specific task
Source code in fusion_bench/method/task_singular_vector/TSVM.py
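The task-vector computation described for this method, including the optional per-task scaling, could look roughly like the sketch below. It uses plain dictionaries of float lists instead of real models; `StateDict` and `compute_task_vectors` are hypothetical names, and the length check mirrors the documented AssertionError rather than the library's exact code.

```python
from typing import Dict, List, Optional, Sequence, Union

# Hypothetical stand-in for a real state dict: parameter name -> flat list of values.
StateDict = Dict[str, List[float]]


def compute_task_vectors(
    pretrained: StateDict,
    finetuned: Dict[str, StateDict],
    alpha: Optional[Union[float, Sequence[float]]] = None,
) -> List[StateDict]:
    """Return one task vector (fine-tuned - pretrained) per fine-tuned model."""
    per_task = isinstance(alpha, (list, tuple))
    if per_task:
        # A list of alphas must provide exactly one scaling factor per model.
        assert len(alpha) == len(finetuned), "alpha must have one entry per model"
    task_vectors: List[StateDict] = []
    for index, model in enumerate(finetuned.values()):
        # Per-task scaling is applied here; a single float is left for later.
        scale = alpha[index] if per_task else 1.0
        task_vectors.append(
            {
                name: [scale * (p - q) for p, q in zip(model[name], pretrained[name])]
                for name in pretrained
            }
        )
    return task_vectors
```

A single float alpha is deliberately ignored in this sketch, since the documentation states that it is applied to the merged task vector after SVD reconstruction.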
run(modelpool)
¶
Execute the complete Task Singular Vector Merging algorithm.
This is the main entry point that orchestrates the entire TSVM process.
The algorithm leverages the mathematical insight that task vectors often lie in a lower-dimensional subspace, and SVD helps identify the most important directions in this subspace while filtering out noise and interference.
Parameters:
-
modelpool
(BaseModelPool
) –Pool of models to merge, including:
- The pretrained base model
- Multiple fine-tuned models (one per task)
All models must have compatible architectures.
Returns:
-
–
Union[torch.nn.Module, Dict[str, torch.nn.Module]]:
- If return_single_task_models=False: Returns the merged model
- If return_single_task_models=True: Returns a dictionary with:
  - Individual transformed models keyed by their original names
  - Final merged model under the key 'merged'
Raises:
-
AssertionError
–If alpha list length doesn't match the number of models
Source code in fusion_bench/method/task_singular_vector/TSVM.py
Isotropic Merging¶
ISO_C_Merge = IsotropicMergingInCommonSubspace
module-attribute
¶
ISO_CTS_Merge = IsotropicMergingInCommonAndTaskSubspace
module-attribute
¶
IsotropicMergingInCommonSubspace
¶
Bases: BaseAlgorithm
, LightningFabricMixin
Isotropic Merging in Common Subspace (Iso-C)
Source code in fusion_bench/method/isotropic_merging/iso.py
IsotropicMergingInCommonAndTaskSubspace
¶
Bases: BaseAlgorithm
, LightningFabricMixin
Isotropic Merging in Common and Task-Specific Subspaces (Iso-CTS)
Source code in fusion_bench/method/isotropic_merging/iso.py
Distributed Model Merging¶
Gossip¶
CLIPTaskWiseGossipAlgorithm
¶
Bases: TaskWiseGossipAlgorithm
A class for task-wise adaptive merging of CLIP models.
This class extends the TaskWiseGossipAlgorithm to provide specific functionality for CLIP models, including loading datasets, constructing zero-shot classification heads, and computing logits.
Attributes:
-
modelpool
(CLIPVisionModelPool
) –The model pool containing CLIP models.
-
_clip_processor
(CLIPProcessor
) –The CLIP processor for preparing inputs.
-
zeroshot_weights
(dict
) –A dictionary to store zero-shot weights for each task.
Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
compute_logits(module, batch, task)
¶
Compute the logits for the given batch and task.
This method computes the image embeddings, normalizes them, and calculates the cosine similarity with the text embeddings to produce classification logits.
Parameters:
-
module
(Module
) –The model module.
-
batch
(tuple
) –A batch of input data.
-
task
(str
) –The name of the task.
Returns:
-
Tensor
(Tensor
) –The classification logits for the batch.
Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
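The logit computation described for compute_logits (normalize the image embeddings, then take the cosine similarity with the text embeddings) can be sketched in plain Python as below. The function names and the `logit_scale` argument are assumptions for illustration; the actual method operates on tensors produced by the CLIP model.

```python
import math
from typing import List


def normalize(vector: List[float]) -> List[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def cosine_logits(
    image_embeds: List[List[float]],
    text_embeds: List[List[float]],
    logit_scale: float = 1.0,
) -> List[List[float]]:
    """Return one row of class logits per image."""
    # One text embedding per class; normalize them once up front.
    text_units = [normalize(t) for t in text_embeds]
    logits = []
    for image in image_embeds:
        unit = normalize(image)
        # The dot product of unit vectors is their cosine similarity.
        logits.append(
            [logit_scale * sum(a * b for a, b in zip(unit, t)) for t in text_units]
        )
    return logits
```

Each row has one entry per class, so the predicted class for an image is the index of the largest value in its row.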
get_shuffled_test_loader_iter(task)
cached
¶
Get an iterator over the shuffled test DataLoader for the task.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
iterator
–An iterator over the shuffled test DataLoader.
Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
get_test_dataset(task)
cached
¶
Load the test dataset for the task. This method is cached, so the dataset is loaded only once.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
CLIPDataset
–The test dataset for the task.
Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
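The "loaded only once" behaviour of a cached method can be illustrated with the standard library's `functools.lru_cache`. This is a sketch only: the class `DatasetProvider`, its `load_calls` counter, and the fake dataset it returns are hypothetical, and the real method may use a different caching decorator.

```python
from functools import lru_cache
from typing import List


class DatasetProvider:
    """Hypothetical holder that counts how often a dataset is really loaded."""

    def __init__(self) -> None:
        self.load_calls = 0

    @lru_cache(maxsize=None)
    def get_test_dataset(self, task: str) -> List[str]:
        # The body runs only on the first call for each (self, task) pair.
        self.load_calls += 1
        return [f"{task}-sample-{i}" for i in range(3)]
```

Repeated calls with the same task name return the very same object, so any expensive loading work happens once per task.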
on_test_time_adaptation_start()
¶
Prepare for test-time adaptation.
This method loads the CLIP processor and constructs the zero-shot classification head for each task.
Source code in fusion_bench/method/gossip/clip_task_wise_gossip.py
CLIPLayerWiseGossipAlgorithm
¶
Bases: CLIPClassificationMixin
, LayerWiseGossipAlgorithm
Source code in fusion_bench/method/gossip/clip_layer_wise_gossip.py
on_test_time_adaptation_start()
¶
Here we load the CLIP processor and construct the zero-shot classification head for each task.
Source code in fusion_bench/method/gossip/clip_layer_wise_gossip.py
FlanT5LayerWiseGossipAlgorithm
¶
Bases: BaseAlgorithm
, LightningFabricMixin
, SimpleProfilerMixin
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
compute_logits(module, batch, task)
¶
Compute the logits for the given batch and task.
Parameters:
-
module
(Union[T5ForConditionalGeneration, LayerWiseMergedModel]
) –The model module.
-
batch
(Tensor
) –The input batch.
-
task
(str
) –The name of the task.
Returns:
-
Tensor
(Tensor
) –The computed logits.
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
get_shuffled_test_loader_iter(task)
cached
¶
Loader of the test dataset for test-time adaptation. Labels are not needed.
Parameters:
-
task
(str
) –The name of the task.
Returns:
-
DataLoader
(DataLoader
) –The data loader for the test dataset.
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
on_test_time_adaptation_start()
¶
Hook that runs before test-time adaptation starts, for example to set up the task-specific heads.
run(modelpool, **kwargs)
¶
Run the Layer-Wise AdaMerging Algorithm.
This method constructs the wrapped model and performs test-time adaptation if necessary.
Parameters:
-
modelpool
(ModelPool
) –The model pool containing the pretrained and fine-tuned models.
Returns:
-
LayerWiseMergedModel
–The merged model after test-time adaptation.
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
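As a rough picture of what layer-wise merging weights mean, the sketch below assumes one scalar weight per (task, layer) pair that scales that task's task vector for that layer. This structure is an assumption based on the "layer-wise" naming and is not taken from LayerWiseMergedModel; `layer_wise_merge` and the nested-list layout are hypothetical.

```python
from typing import List

Layer = List[float]  # hypothetical: the flattened parameters of one layer


def layer_wise_merge(
    pretrained: List[Layer],
    task_vectors: List[List[Layer]],
    weights: List[List[float]],
) -> List[Layer]:
    """Merge layers using weights[task][layer] as the per-layer coefficient."""
    merged: List[Layer] = []
    for layer_index, base in enumerate(pretrained):
        layer = list(base)
        for task_index, task_vector in enumerate(task_vectors):
            w = weights[task_index][layer_index]
            # Add this task's contribution to the current layer.
            layer = [p + w * d for p, d in zip(layer, task_vector[layer_index])]
        merged.append(layer)
    return merged
```

Test-time adaptation would then adjust the entries of `weights` while the pretrained parameters and task vectors stay fixed.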
save_merging_weights(file_path, merging_weights)
¶
Save the merging weights to a file.
Parameters:
-
file_path
(str
) –The path to save the merging weights.
-
merging_weights
(Tensor
) –The merging weights to save.
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
test_time_adaptation(module, datasets)
¶
Perform test-time adaptation on the merged model.
This method adapts the merging weights during test-time to improve performance.
Parameters:
-
module
(LayerWiseMergedModel
) –The merged model.
Returns:
-
LayerWiseMergedModel
–The adapted merged model.
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
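The description above does not state which objective is optimized during test-time adaptation. One common label-free choice is the entropy of the model's predictions (an `entropy_loss` option appears as a default elsewhere on this page), so the sketch below assumes that objective. The function names are hypothetical and the code works on plain lists rather than tensors.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_loss(batch_logits: List[List[float]]) -> float:
    """Mean Shannon entropy of the predicted class distributions."""
    total = 0.0
    for logits in batch_logits:
        probs = softmax(logits)
        # Skip exact zeros so that log(0) is never evaluated.
        total += -sum(p * math.log(p) for p in probs if p > 0.0)
    return total / len(batch_logits)
```

Minimizing this quantity with respect to the merging weights pushes the merged model toward confident predictions on unlabeled test batches.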
update_datasets(datasets)
¶
For every epoch of local AdaMerging, only the datasets corresponding to the models involved in the fusion are used.
Source code in fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py
Continual Model Merging¶
Orthogonal Projection-based Continual Merging (OPCM)¶
OPCMForCLIP
¶
Bases: BaseAlgorithm
, LightningFabricMixin
, SimpleProfilerMixin
Source code in fusion_bench/method/opcm/opcm.py
__init__(alpha, shuffle_order=True, seed=None, save_on_every_step=True, evaluate_on_every_step=False, **kwargs)
¶
Continual Model Merging via SVD Projection.
Parameters:
-
alpha
(float
) –the scaling factor for the SVD projection.
-
shuffle_order
(bool
, default:True
) –whether to shuffle the order of the models.
-
seed
(Optional[int]
, default:None
) –the seed to use.
-
save_on_every_step
(bool
, default:True
) –whether to save the merged model on every step.
-
evaluate_on_every_step
(bool
, default:False
) –whether to evaluate the merged model on every step.
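How shuffle_order and seed might interact can be sketched with the standard library's `random.Random`. This is an assumption about the behaviour, not the code in opcm.py; `merge_order` is a hypothetical helper and the model names in the usage note are placeholders.

```python
import random
from typing import List, Optional


def merge_order(
    model_names: List[str],
    shuffle_order: bool = True,
    seed: Optional[int] = None,
) -> List[str]:
    """Return the order in which models would be merged one after another."""
    order = list(model_names)  # copy so the caller's list is left untouched
    if shuffle_order:
        # A dedicated generator keeps the shuffle reproducible for a fixed seed.
        random.Random(seed).shuffle(order)
    return order
```

For example, `merge_order(["task_a", "task_b", "task_c"], seed=0)` returns the same permutation on every run, while `shuffle_order=False` keeps the original order.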