fusion_bench.method
Base Class
- BaseAlgorithm: The base class for all fusion algorithms in FusionBench.
BaseAlgorithm
Bases: BaseYAMLSerializable
Base class for model fusion algorithms.
This abstract class provides a standardized interface for implementing model fusion algorithms. It inherits from BaseYAMLSerializable to support configuration loading from YAML files.
The class follows a template method pattern where subclasses must implement the
core fusion logic in the run method, while optional lifecycle hooks allow for
setup and cleanup operations.
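A minimal, self-contained sketch of the template method pattern described above. It assumes a driver that calls on_run_start, then run, then on_run_end inside a finally block. SketchAlgorithm, LoggingAlgorithm, and execute are hypothetical stand-ins for illustration only. They are not FusionBench's API, and the YAML serialization that the real BaseAlgorithm inherits from BaseYAMLSerializable is omitted.

```python
from abc import ABC, abstractmethod


class SketchAlgorithm(ABC):
    """Hypothetical stand-in for BaseAlgorithm (YAML serialization omitted)."""

    def on_run_start(self):
        # Optional hook: the default does nothing.
        pass

    def on_run_end(self):
        # Optional hook: the default does nothing.
        pass

    @abstractmethod
    def run(self, modelpool):
        # Subclasses must supply the core fusion logic.
        ...


def execute(algorithm, modelpool):
    """Hypothetical driver: start hook, core logic, then end hook (always)."""
    algorithm.on_run_start()
    try:
        return algorithm.run(modelpool)
    finally:
        # Runs whether run() returned normally or raised.
        algorithm.on_run_end()


class LoggingAlgorithm(SketchAlgorithm):
    """Records the order in which the lifecycle methods are called."""

    def __init__(self):
        self.events = []

    def on_run_start(self):
        self.events.append("start")

    def run(self, modelpool):
        self.events.append("run")
        # Toy "fusion": the mean of a list of numbers.
        return sum(modelpool) / len(modelpool)

    def on_run_end(self):
        self.events.append("end")


algo = LoggingAlgorithm()
result = execute(algo, [1.0, 2.0, 3.0])
print(result, algo.events)  # 2.0 ['start', 'run', 'end']
```

Placing on_run_end in a finally block is one way to make cleanup run even when run raises, as described for on_run_end below.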
If a model has a _fusion_bench_target_modules attribute, the algorithm fuses only
the specified target modules. This is useful for models where only certain layers
should be fused (e.g., when classification heads on top of a shared backbone should not be merged).
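A sketch of target-module filtering. It assumes that _fusion_bench_target_modules holds a list of module-name prefixes and that parameters are addressed by dotted state-dict keys. Both the attribute's exact format and the is_target and merge_targets_only helpers are assumptions for illustration. Plain Python floats stand in for tensors so the selection logic is easy to follow.

```python
def is_target(param_name, target_modules):
    """True if param_name belongs to one of the target module prefixes (assumed format)."""
    return any(
        param_name == m or param_name.startswith(m + ".") for m in target_modules
    )


def merge_targets_only(state_dicts, target_modules):
    """Hypothetical helper: average target-module parameters, copy the rest from model 0."""
    merged = {}
    for name, value in state_dicts[0].items():
        if is_target(name, target_modules):
            # Simple average across all models for parameters inside target modules.
            merged[name] = sum(sd[name] for sd in state_dicts) / len(state_dicts)
        else:
            # Non-target parameters (e.g. a task-specific head) are left unmerged.
            merged[name] = value
    return merged


# Two toy "models": a shared backbone plus a classification head.
sd_a = {"backbone.layer1.weight": 1.0, "head.weight": 10.0}
sd_b = {"backbone.layer1.weight": 3.0, "head.weight": 20.0}

merged = merge_targets_only([sd_a, sd_b], target_modules=["backbone"])
print(merged)  # {'backbone.layer1.weight': 2.0, 'head.weight': 10.0}
```

Matching on name == m or a "m." prefix avoids accidentally selecting sibling modules whose names merely start with the same characters (e.g. backbone2).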
Attributes:
- _program: Optional program reference for algorithm execution context.
- _config_key (str): Configuration key used for YAML serialization; defaults to "method".
Examples:
Creating a simple averaging algorithm:
>>> class SimpleAverageAlgorithm(BaseAlgorithm):
...     def run(self, modelpool: BaseModelPool):
...         # Implementation of model averaging logic
...         return averaged_model
...
>>> algorithm = SimpleAverageAlgorithm()
>>> merged_model = algorithm.run(modelpool)
Loading algorithm from YAML configuration:
Note
Subclasses must implement the abstract run method to define the specific
fusion strategy (e.g., simple averaging, task arithmetic, etc.).
Source code in fusion_bench/method/base_algorithm.py
on_run_end()
Lifecycle hook called at the end of algorithm execution.
This method is invoked after the main run method completes, providing
an opportunity for subclasses to perform cleanup and finalization tasks such as:
- Logging execution statistics or results
- Cleaning up temporary resources
- Saving intermediate results or metrics
- Releasing computational resources
The method is called regardless of whether the run method succeeded or failed,
making it suitable for cleanup operations that should always occur.
The default implementation does nothing, allowing subclasses to override as needed for their specific requirements.
Examples:
>>> import time
>>> class MyAlgorithm(BaseAlgorithm):
...     def on_run_end(self):
...         super().on_run_end()
...         elapsed = time.time() - self.start_time
...         print(f"Fusion completed in {elapsed:.2f}s")
on_run_start()
Lifecycle hook called at the beginning of algorithm execution.
This method is invoked before the main run method executes, providing
an opportunity for subclasses to perform initialization tasks such as:
- Setting up logging or monitoring
- Initializing algorithm-specific state
- Validating prerequisites
- Preparing computational resources
The default implementation does nothing, allowing subclasses to override as needed for their specific requirements.
Examples:
>>> import time
>>> class MyAlgorithm(BaseAlgorithm):
...     def on_run_start(self):
...         super().on_run_start()
...         print("Starting model fusion...")
...         self.start_time = time.time()
run(modelpool)
abstractmethod
Execute the model fusion algorithm on the provided model pool.
This is the core method that must be implemented by all subclasses to define their specific fusion strategy. The method takes a pool of models and produces a fused result according to the algorithm's logic.
Parameters:
- modelpool (BaseModelPool): A collection of models to be fused. The modelpool provides access to individual models and their metadata, allowing the algorithm to iterate over models, access their parameters, and perform fusion operations.
Returns:
- The type of the return value depends on the specific algorithm implementation. Common return types include:
  - A single fused model (torch.nn.Module)
  - A dictionary of fused models for multi-task scenarios
  - Fusion results with additional metadata
  - Custom data structures specific to the algorithm
Raises:
- NotImplementedError: If called on the base class without an implementation.
- ValueError: If the modelpool is invalid or incompatible with the algorithm.
- RuntimeError: If fusion fails due to model incompatibilities or other issues.
Examples:
Simple averaging implementation:
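>>> import copy
>>> import torch
>>> def run(self, modelpool: BaseModelPool):
...     models = [model for model in modelpool]
...     # Read each state dict once instead of once per parameter
...     state_dicts = [model.state_dict() for model in models]
...     averaged_params = {}
...     for name in state_dicts[0]:
...         # Cast to float so integer buffers (e.g. BatchNorm counters) can be averaged
...         averaged_params[name] = torch.stack([
...             sd[name].float() for sd in state_dicts
...         ]).mean(dim=0)
...     fused_model = copy.deepcopy(models[0])
...     fused_model.load_state_dict(averaged_params)
...     return fused_model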
Task arithmetic implementation:
>>> def run(self, modelpool: BaseModelPool):
...     pretrained = modelpool.get_model('pretrained')
...     task_vectors = []
...     for model_name in modelpool.model_names:
...         if model_name != 'pretrained':
...             task_vector = self.compute_task_vector(
...                 modelpool.get_model(model_name), pretrained
...             )
...             task_vectors.append(task_vector)
...     return self.merge_task_vectors(pretrained, task_vectors)
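The compute_task_vector and merge_task_vectors helpers used in the task arithmetic example are not defined in this section. A minimal sketch of what they might compute, assuming the standard task arithmetic formulation: a task vector is the fine-tuned weights minus the pre-trained weights, and the merged model is the pre-trained weights plus scaling_factor times the sum of the task vectors. Plain dicts of floats stand in for state dicts, and the scaling_factor values are illustrative assumptions.

```python
def compute_task_vector(finetuned, pretrained):
    """Hypothetical helper: element-wise difference finetuned - pretrained."""
    return {name: finetuned[name] - pretrained[name] for name in pretrained}


def merge_task_vectors(pretrained, task_vectors, scaling_factor=0.3):
    """Hypothetical helper: add the scaled sum of task vectors onto the pretrained weights."""
    merged = {}
    for name, value in pretrained.items():
        total = sum(tv[name] for tv in task_vectors)
        merged[name] = value + scaling_factor * total
    return merged


pretrained = {"w": 1.0, "b": 0.0}
finetuned_a = {"w": 2.0, "b": 0.5}
finetuned_b = {"w": 0.0, "b": 1.5}

tvs = [compute_task_vector(m, pretrained) for m in (finetuned_a, finetuned_b)]
merged = merge_task_vectors(pretrained, tvs, scaling_factor=0.5)
print(merged)  # {'w': 1.0, 'b': 1.0}
```

Here the two task vectors cancel on w (+1.0 and -1.0) and reinforce on b (0.5 + 1.5), which shows how opposing task updates can interfere when summed.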
Note
- The modelpool iteration order may affect results for non-commutative operations
- Ensure model compatibility (architecture, parameter shapes) before fusion
- Consider memory constraints when loading multiple large models
- Use appropriate device placement for GPU/CPU computation
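As a sketch of the compatibility check recommended above, the hypothetical check_compatible helper below compares parameter names and shapes across models and raises ValueError on a mismatch, matching the ValueError case listed under Raises. Shapes are plain tuples here. With PyTorch models, they could come from tensor.shape for each entry in state_dict(), which is an assumption about how the check would be wired in.

```python
def check_compatible(shapes_per_model):
    """Raise ValueError unless every model has identical parameter names and shapes.

    shapes_per_model: dict mapping model name -> {param_name: shape_tuple}.
    """
    names = list(shapes_per_model)
    if not names:
        raise ValueError("modelpool is empty")
    reference_name = names[0]
    reference = shapes_per_model[reference_name]
    for other in names[1:]:
        shapes = shapes_per_model[other]
        if shapes.keys() != reference.keys():
            # Parameters present in one model but not the other.
            differing = sorted(set(reference) ^ set(shapes))
            raise ValueError(f"{other} vs {reference_name}: mismatched parameters {differing}")
        for param, shape in reference.items():
            if shapes[param] != shape:
                raise ValueError(f"{other}.{param} has shape {shapes[param]}, expected {shape}")


# Compatible models: passes silently.
check_compatible({
    "model_a": {"fc.weight": (10, 4), "fc.bias": (10,)},
    "model_b": {"fc.weight": (10, 4), "fc.bias": (10,)},
})

# Incompatible models: the shape mismatch is reported.
try:
    check_compatible({
        "model_a": {"fc.weight": (10, 4)},
        "model_b": {"fc.weight": (10, 8)},
    })
except ValueError as err:
    message = str(err)
print(message)  # model_b.fc.weight has shape (10, 8), expected (10, 4)
```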
BaseModelFusionAlgorithm
Bases: BaseAlgorithm
Alias for BaseAlgorithm class.
Deprecated
BaseModelFusionAlgorithm is deprecated and will be removed in a future version.
Use BaseAlgorithm instead.
This alias was provided for backward compatibility and semantic clarity. Both names refer to the same base class and can be used interchangeably, but BaseAlgorithm is now the preferred name for all implementations.
Examples:
Preferred (using BaseAlgorithm):
>>> class MyAlgorithm(BaseAlgorithm):
...     def run(self, modelpool): pass
Deprecated (using BaseModelFusionAlgorithm):
>>> class MyAlgorithm(BaseModelFusionAlgorithm):  # Will trigger deprecation warning
...     def run(self, modelpool): pass
Note
New implementations should use BaseAlgorithm exclusively.
The BaseModelFusionAlgorithm alias will be removed in a future release.
Warning
Using BaseModelFusionAlgorithm will trigger a DeprecationWarning.
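One common way to make subclassing a deprecated alias emit a DeprecationWarning is an __init_subclass__ hook. The sketch below assumes that approach for illustration. It is not necessarily how FusionBench implements the warning, and the two classes are local stand-ins rather than the library's own.

```python
import warnings


class BaseAlgorithm:
    """Stand-in for the preferred base class."""

    def run(self, modelpool):
        raise NotImplementedError


class BaseModelFusionAlgorithm(BaseAlgorithm):
    """Stand-in deprecated alias: warns whenever it is subclassed."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        warnings.warn(
            "BaseModelFusionAlgorithm is deprecated; use BaseAlgorithm instead.",
            DeprecationWarning,
            stacklevel=2,
        )


with warnings.catch_warnings(record=True) as caught:
    # DeprecationWarning is ignored by default unless raised from __main__,
    # so force it to be recorded here.
    warnings.simplefilter("always")

    class MyAlgorithm(BaseModelFusionAlgorithm):
        def run(self, modelpool):
            pass

print([str(w.message) for w in caught])
```

Because __init_subclass__ fires only for subclasses, defining the alias itself raises no warning. Subclasses of the alias remain subclasses of BaseAlgorithm, so existing code keeps working.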
Implemented Algorithms
In FusionBench, we categorize deep model fusion methods into three categories: Model Ensemble, Model Merging, and Model Mixing. Learn more here.