fusion_bench.mixins¶
The mixins module provides reusable functionality through mixin classes that can be combined with other classes to add specific capabilities. These mixins follow the composition-over-inheritance principle and are designed to be modular, flexible, and easy to integrate.
Basic Mixin Composition¶
from fusion_bench.mixins import (
    LightningFabricMixin,
    SimpleProfilerMixin,
    auto_register_config,
)
from fusion_bench import BaseAlgorithm

@auto_register_config
class MyAlgorithm(
    LightningFabricMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        super().__init__(**kwargs)

    def run(self, modelpool):
        # implement the fusion logic here
        pass
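As a standalone illustration of the composition pattern above (toy classes, not the actual fusion_bench implementations): Python's MRO resolves attributes left to right, so mixins listed before the base class participate first in cooperative `super().__init__()` calls.

```python
# Toy classes illustrating mixin composition; names shadow the
# fusion_bench ones for familiarity but are NOT the real implementations.
class BaseAlgorithm:
    def __init__(self, **kwargs):
        self.base_initialized = True

class TimerMixin:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # forward to the next class in the MRO
        self.timings = {}

class MyAlgorithm(TimerMixin, BaseAlgorithm):
    def __init__(self, learning_rate: float = 0.001, **kwargs):
        super().__init__(**kwargs)
        self.learning_rate = learning_rate

algo = MyAlgorithm()
print([c.__name__ for c in MyAlgorithm.__mro__])
# TimerMixin comes before BaseAlgorithm in the resolution order
```

Because each `__init__` forwards `**kwargs` upward, every class in the chain gets initialized exactly once, which is why the mixins are designed to be freely combinable.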
Class Definitions¶
Configuration and Instantiation¶
- fusion_bench.mixins.HydraConfigMixin: A mixin class that provides configuration-based instantiation capabilities.
- fusion_bench.mixins.auto_register_config: Decorator for automatically mapping constructor parameters to configuration keys.
Serialization and Persistence¶
- fusion_bench.mixins.YAMLSerializationMixin: Provides methods for serializing and deserializing objects to and from YAML format.
- fusion_bench.mixins.BaseYAMLSerializable: Base class for objects that support YAML serialization.
Distributed Computing and Training¶
- fusion_bench.mixins.LightningFabricMixin: Integrates with Lightning Fabric for automatic distributed environment and accelerator management.
- fusion_bench.mixins.FabricTrainingMixin: Extends Lightning Fabric integration with training-specific utilities.
Performance and Debugging¶
- fusion_bench.mixins.SimpleProfilerMixin: Provides simple profiling capabilities for measuring execution time.
- fusion_bench.mixins.PyinstrumentProfilerMixin: Offers advanced statistical profiling using the pyinstrument library.
Computer Vision¶
- fusion_bench.mixins.CLIPClassificationMixin: Supports CLIP-based image classification tasks.
Class Decorators¶
- fusion_bench.mixins.auto_register_config: Decorator for automatically registering constructor parameters in the configuration mapping.
References¶
HydraConfigMixin
¶
A mixin class that provides configuration-based instantiation capabilities.
This mixin enables classes to be instantiated directly from Hydra configuration files, supporting both direct instantiation and target-based instantiation patterns. It's particularly useful in FusionBench for creating model pools, task pools, and fusion algorithms from YAML configurations.
The mixin handles:
- Configuration loading and composition
- Target class validation
- Nested configuration group navigation
- Object instantiation with proper error handling
Example:
class MyAlgorithm(HydraConfigMixin):
    def __init__(self, param1: str, param2: int = 10):
        self.param1 = param1
        self.param2 = param2

# Instantiate from config
algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
Note
This mixin requires Hydra to be properly initialized before use. Typically, this is handled by the main FusionBench CLI application.
Source code in fusion_bench/mixins/hydra_config.py
from_config(config_name, overrides=None)
classmethod
¶
Create an instance of the class from a Hydra configuration.
This method loads a Hydra configuration file and instantiates the class using the configuration parameters. It supports both direct parameter passing and target-based instantiation patterns.
Parameters:
-
config_name(Union[str, Path]) –The name/path of the configuration file to load. Can be a string like "algorithms/simple_average" or a Path object. The .yaml extension is optional.
-
overrides(Optional[List[str]], default:None) –Optional list of configuration overrides in the format ["key=value", "nested.key=value"]. These allow runtime modification of configuration parameters.
Returns:
-
T–An instance of the class configured according to the loaded configuration.
Raises:
-
RuntimeError–If Hydra is not properly initialized.
-
ImportError–If a target class specified in the config cannot be imported.
-
ValueError–If required configuration parameters are missing.
Note
The method automatically handles nested configuration groups by navigating through the configuration hierarchy based on the config_name path structure.
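Because the real from_config depends on an initialized Hydra environment, here is a dependency-free sketch of the same pattern, with plain dicts standing in for YAML configs and "key=value" strings for overrides (illustrative only, not the actual implementation):

```python
# Toy stand-in for the Hydra-backed pattern: plain dicts replace YAML
# configs, and "key=value" strings replace Hydra overrides.
class ConfigurableBase:
    @classmethod
    def from_config(cls, config, overrides=None):
        params = dict(config)
        for item in overrides or []:
            key, _, value = item.partition("=")
            # coerce the override string to the type of the existing value
            params[key] = type(params[key])(value) if key in params else value
        return cls(**params)

class MyAlgorithm(ConfigurableBase):
    def __init__(self, param1: str, param2: int = 10):
        self.param1 = param1
        self.param2 = param2

algo = MyAlgorithm.from_config({"param1": "average", "param2": 10},
                               overrides=["param2=20"])
print(algo.param2)  # 20
```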
Source code in fusion_bench/mixins/hydra_config.py
YAMLSerializationMixin
¶
Source code in fusion_bench/mixins/serialization.py
config
property
¶
Returns the configuration of the model pool as a DictConfig.
This property converts the model pool instance into a dictionary configuration, which can be used for serialization or other purposes.
Returns:
-
DictConfig(DictConfig) –The configuration of the model pool.
from_yaml(path)
classmethod
¶
Load a model pool from a YAML file.
Parameters:
-
path(Union[str, Path]) –The path to load the model pool from.
Returns:
-
BaseModelPool–The loaded model pool.
Source code in fusion_bench/mixins/serialization.py
register_parameter_to_config(attr_name, param_name, value)
¶
Set an attribute value and register its config mapping.
This method allows dynamic setting of object attributes while simultaneously updating the configuration mapping that defines how the attribute should be serialized in the configuration output.
Parameters:
-
attr_name(str) –The name of the attribute to set on this object.
-
param_name(str) –The corresponding parameter name to use in the config serialization. This is how the attribute will appear in YAML output.
-
value–The value to assign to the attribute.
Source code in fusion_bench/mixins/serialization.py
to_yaml(path, resolve=True)
¶
Save the model pool to a YAML file.
Parameters:
-
path(Union[str, Path]) –The path to save the model pool to.
-
resolve(bool, default:True) –Whether to resolve interpolations in the configuration before saving. Defaults to True.
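The to_yaml / from_yaml round trip can be sketched without fusion_bench; the toy class below mirrors the config / save / load pattern, with json standing in for YAML so the example needs no third-party packages:

```python
import json
import os
import tempfile

# Toy stand-in for YAMLSerializationMixin; json replaces YAML here
# purely to keep the sketch dependency-free.
class ModelPool:
    _config_mapping = {"name": "name"}  # attribute name -> config key

    def __init__(self, name: str = "pool"):
        self.name = name

    @property
    def config(self):
        return {key: getattr(self, attr)
                for attr, key in self._config_mapping.items()}

    def to_file(self, path):
        with open(path, "w") as f:
            json.dump(self.config, f)

    @classmethod
    def from_file(cls, path):
        with open(path) as f:
            return cls(**json.load(f))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pool.json")
    ModelPool(name="clip-pool").to_file(path)
    restored = ModelPool.from_file(path)
print(restored.name)  # clip-pool
```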
Source code in fusion_bench/mixins/serialization.py
BaseYAMLSerializable
¶
Bases: YAMLSerializationMixin
A base class for YAML-serializable classes with enhanced metadata support.
This class extends YAMLSerializationMixin to provide additional metadata fields commonly used in FusionBench classes, including usage information and version tracking. It serves as a foundation for all serializable model components in the framework.
The class automatically handles serialization of usage and version metadata alongside the standard configuration parameters, making it easier to track model provenance and intended usage patterns.
Attributes:
-
_usage_(Optional[str]) –Description of the model's intended usage or purpose.
-
_version_(Optional[str]) –Version information for the model or configuration.
Example
class MyAlgorithm(BaseYAMLSerializable):
    _config_mapping = BaseYAMLSerializable._config_mapping | {
        "model_name": "model_name",
        "num_layers": "num_layers",
    }

    def __init__(self, _usage_: Optional[str] = None, _version_: Optional[str] = None):
        super().__init__(_usage_=_usage_, _version_=_version_)

# Usage with metadata
model = MyAlgorithm(
    _usage_="Text classification fine-tuning",
    _version_="1.0.0",
)

# Serialization includes metadata
config = model.config
# DictConfig({
#     '_target_': 'MyAlgorithm',
#     '_usage_': 'Text classification fine-tuning',
#     '_version_': '1.0.0'
# })
Note
The underscore prefix in _usage_ and _version_ follows the convention for metadata fields that are not core model parameters but provide important contextual information for model management and tracking.
Source code in fusion_bench/mixins/serialization.py
__init__(_recursive_=False, _usage_=None, _version_=FUSION_BENCH_VERSION, **kwargs)
¶
Initialize a base YAML-serializable model with metadata support.
Parameters:
-
_recursive_(bool, default:False) –Whether nested configuration objects should be instantiated recursively. Defaults to False.
-
_usage_(Optional[str], default:None) –Description of the model's intended usage or purpose. This can include information about the training domain, expected input types, or specific use cases. Defaults to None.
-
_version_(Optional[str], default:FUSION_BENCH_VERSION) –Version information for the model or configuration. Can be used to track model iterations, dataset versions, or compatibility information. Defaults to FUSION_BENCH_VERSION.
-
**kwargs–Additional keyword arguments passed to the parent class. Unused arguments will trigger warnings via the parent's initialization.
Source code in fusion_bench/mixins/serialization.py
LightningFabricMixin
¶
A mixin class for integrating Lightning Fabric into a project.
This class provides methods to initialize and manage a Lightning Fabric instance for distributed computing, including setup with optional logging, device management for tensors and modules, and hyperparameter logging. It leverages the Lightning framework to facilitate distributed training and inference across multiple devices and nodes, with support for custom logging via TensorBoard.
Attributes:
- _fabric (L.Fabric): The Lightning Fabric instance used for distributed computing.
Note:
This mixin is designed to be used with classes that require distributed computing capabilities and wish to leverage Lightning Fabric for this purpose. It assumes the presence of a config attribute or parameter in the consuming class for configuration.
Source code in fusion_bench/mixins/lightning_fabric.py
fabric
property
writable
¶
Get the Lightning Fabric instance, initializing it if necessary.
Returns:
-
Fabric–The Lightning Fabric instance for distributed computing.
is_debug_mode
property
¶
Check if the program is running in debug mode (fast_dev_run).
Returns:
-
bool–True if fast_dev_run is enabled, False otherwise.
log_dir
property
¶
Retrieves the log directory from the fabric's logger.
tensorboard_summarywriter
property
¶
Get the TensorBoard SummaryWriter for detailed logging.
Returns:
-
SummaryWriter(SummaryWriter) –The TensorBoard SummaryWriter instance.
Raises:
-
AttributeError–If the logger is not a TensorBoardLogger.
__del__()
¶
Destructor to ensure proper cleanup of the Lightning Fabric instance.
finalize()
¶
Source code in fusion_bench/mixins/lightning_fabric.py
log(name, value, step=None)
¶
Logs a single metric to the fabric's logger.
Parameters:
-
name(str) –The name of the metric to log.
-
value(Any) –The value of the metric.
-
step(Optional[int], default:None) –Optional step number for the metric.
Source code in fusion_bench/mixins/lightning_fabric.py
log_artifact(local_path, artifact_path=None)
¶
Logs a file as an artifact to the fabric's logger.
Parameters:
-
local_path–The path to the file to log as an artifact.
-
artifact_path(str | None, default:None) –The directory within the logger's artifact storage to save the file.
Source code in fusion_bench/mixins/lightning_fabric.py
log_artifacts(local_dir, artifact_path=None)
¶
Logs a directory as artifacts to the fabric's logger.
Parameters:
-
local_dir(str) –The path to the directory to log as artifacts.
-
artifact_path(str | None, default:None) –The directory within the logger's artifact storage to save the files.
Source code in fusion_bench/mixins/lightning_fabric.py
log_dict(metrics, step=None)
¶
Logs multiple metrics to the fabric's logger.
Parameters:
-
metrics(Mapping[str, Any]) –Dictionary of metric names and values.
-
step(Optional[int], default:None) –Optional step number for the metrics.
Source code in fusion_bench/mixins/lightning_fabric.py
log_hyperparams(config=None, save_dir=None, filename='config.yaml')
¶
Logs the hyperparameters and saves the configuration to a YAML file.
The YAML file is saved in the log directory by default with the name config.yaml, or in the specified save directory save_dir.
Parameters:
-
config(Optional[DictConfig], default:None) –The configuration to log and save. If not provided, the class's config attribute is used.
-
save_dir(Optional[str], default:None) –The directory in which to save the configuration file. If not provided, the log directory is used.
-
filename(str, default:'config.yaml') –The name of the configuration file. Defaults to config.yaml.
Source code in fusion_bench/mixins/lightning_fabric.py
log_optimizer_lr(optimizer, step=None, name_template='train/lr_group_{0}')
¶
Logs the learning rate of each parameter group in the optimizer.
Parameters:
-
optimizer(Optimizer) –The optimizer whose learning rates should be logged.
-
step(Optional[int], default:None) –Optional step number for the log entry.
-
name_template(str, default:'train/lr_group_{0}') –Template string for the log name. Use {0} as placeholder for group index.
Source code in fusion_bench/mixins/lightning_fabric.py
setup_lightning_fabric(config)
¶
Initializes and launches the Lightning Fabric with optional logging.
This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric configuration is not found, it logs a warning and exits. Optionally, if a fabric logger configuration is provided, it initializes a TensorBoardLogger with the specified settings.
Expected configuration keys:
- fabric: The configuration for the Lightning Fabric.
- fabric.loggers: The configuration for the TensorBoardLogger.
Source code in fusion_bench/mixins/lightning_fabric.py
to_device(obj)
¶
Moves a tensor or module to the proper device.
Parameters:
-
obj(TensorOrModule) –The tensor or module to move to the device.
Returns:
-
TensorOrModule(TensorOrModule) –the same type of object as the input, moved to the device.
Source code in fusion_bench/mixins/lightning_fabric.py
FabricTrainingMixin
¶
Bases: LightningFabricMixin
This is a general-purpose mixin for training a model with Lightning Fabric.
Source code in fusion_bench/mixins/fabric_training.py
accumulate_grad_batches
instance-attribute
¶
The number of gradient accumulation steps. The effective global batch size is the batch size per device x the number of devices x the number of gradient accumulation steps.
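A quick worked example of the effective global batch size formula above (illustrative numbers):

```python
# Effective global batch size =
#   batch size per device x number of devices x gradient accumulation steps
batch_size_per_device = 32
num_devices = 4
accumulate_grad_batches = 2

global_batch_size = batch_size_per_device * num_devices * accumulate_grad_batches
print(global_batch_size)  # 256
```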
checkpoint_save_frequency
instance-attribute
¶
The frequency to save the model checkpoint.
checkpoint_save_interval
instance-attribute
¶
The interval to save the model checkpoint. Available options: 'step', 'epoch'.
epoch_idx
instance-attribute
¶
The epoch index, which is the number of epochs completed.
expected_total_steps
property
¶
The expected total number of steps for the entire training. You need to call the compute_expected_total_steps method to compute this value before accessing it.
Raises:
-
ValueError–If the expected total steps have not been computed.
global_step_idx
instance-attribute
¶
The global step index, which is the number of parameter update steps.
gradient_clip_algorithm
instance-attribute
¶
The algorithm to clip gradients. Available options: 'value', 'norm'.
gradient_clip_val
instance-attribute
¶
The value to clip gradients. If None, no clipping is applied.
is_training
instance-attribute
¶
Whether the training is in progress. If set to False, the training will stop.
lr_scheduler_frequency
instance-attribute
¶
The frequency to run the learning rate scheduler.
lr_scheduler_interval
instance-attribute
¶
The interval to run the learning rate scheduler. Available options: 'step', 'epoch'.
max_epochs
instance-attribute
¶
Max number of epochs of the entire training.
max_steps
instance-attribute
¶
Max number of parameter update steps of the entire training.
max_steps_per_epoch
instance-attribute
¶
Max number of parameter update steps per epoch.
clip_gradients_if_needed(model, optimizer)
¶
Clips gradients if the gradient clipping value is set.
Parameters:
-
model(Module) –The model whose gradients need to be clipped.
-
optimizer(Optimizer) –The optimizer used for training.
Source code in fusion_bench/mixins/fabric_training.py
compute_expected_total_steps(train_dataloader)
¶
Computes the expected total number of steps for the entire training.
Parameters:
-
train_dataloader(DataLoader) –The dataloader for the training data.
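The computation can be sketched as follows; this is an assumed reconstruction from the attributes documented above (using -1 to mean "no limit"), not the actual implementation:

```python
import math

# Assumed formula, reconstructed from the documented attributes:
# one optimizer step per `accumulate_grad_batches` batches, capped by
# max_steps_per_epoch and max_steps when they are set (> 0).
def expected_total_steps(num_batches, max_epochs, max_steps,
                         max_steps_per_epoch, accumulate_grad_batches):
    steps_per_epoch = math.ceil(num_batches / accumulate_grad_batches)
    if max_steps_per_epoch > 0:
        steps_per_epoch = min(steps_per_epoch, max_steps_per_epoch)
    total = steps_per_epoch * max_epochs
    if max_steps > 0:
        total = min(total, max_steps)
    return total

# 100 batches, 2 epochs, accumulate over 2 batches -> 50 steps/epoch
print(expected_total_steps(100, 2, -1, -1, 2))  # 100
```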
Source code in fusion_bench/mixins/fabric_training.py
conditional_checkpoint_save(stage, *args, **kwargs)
¶
Conditionally saves a checkpoint based on the current training stage.
Parameters:
-
stage(Literal['end_of_step', 'end_of_epoch', 'end_of_training']) –The current stage of training.
Source code in fusion_bench/mixins/fabric_training.py
save_checkpoint(path, **kwargs)
abstractmethod
¶
Saves a checkpoint of the model.
Parameters:
-
path(str) –The path where the checkpoint will be saved.
Raises:
-
NotImplementedError–If the method is not implemented.
Source code in fusion_bench/mixins/fabric_training.py
train(model, optimizer, lr_scheduler)
¶
Trains the model.
The global batch size is the batch size per device x the number of devices x the number of gradient accumulation steps.
Parameters:
-
model(Union[Module, _FabricModule]) –The model to be trained.
-
optimizer(Union[Optimizer, _FabricOptimizer]) –The optimizer used for training.
-
lr_scheduler(LRScheduler) –The learning rate scheduler.
Source code in fusion_bench/mixins/fabric_training.py
train_epoch(model, optimizer, lr_scheduler)
abstractmethod
¶
Trains the model for one epoch.
Parameters:
-
model(Union[Module, _FabricModule]) –The model to be trained.
-
optimizer(Union[Optimizer, _FabricOptimizer]) –The optimizer used for training.
-
lr_scheduler(LRScheduler) –The learning rate scheduler.
Raises:
-
NotImplementedError–If the method is not implemented.
Source code in fusion_bench/mixins/fabric_training.py
SimpleProfilerMixin
¶
A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.
This mixin allows for easy profiling of code blocks using a context manager or manual start/stop methods. It measures the execution time of named actions and provides a summary of the profiling results. Unlike statistical profilers, this provides precise timing measurements for specific code blocks.
Note
This mixin uses Lightning's SimpleProfiler which measures wall-clock time for named actions. It's suitable for timing discrete operations rather than detailed function-level profiling.
Examples:
class MyClass(SimpleProfilerMixin):
    def do_something(self):
        with self.profile("data_loading"):
            # Load data here
            data = load_data()
        with self.profile("model_training"):
            # Train model here
            model.train(data)

        # Print the profiling summary
        self.print_profile_summary("Training Profile")
Attributes:
-
_profiler(SimpleProfiler) –An instance of the SimpleProfiler class used for profiling.
Source code in fusion_bench/mixins/simple_profiler.py
profiler
property
¶
Get the SimpleProfiler instance, creating it if necessary.
Returns:
-
SimpleProfiler(SimpleProfiler) –The profiler instance used for timing measurements.
__del__()
¶
Cleanup when the object is destroyed.
Ensures that the profiler instance is properly cleaned up to prevent memory leaks when the mixin instance is garbage collected.
Source code in fusion_bench/mixins/simple_profiler.py
print_profile_summary(title=None)
¶
Print a summary of all profiled actions.
This method outputs a formatted summary showing the timing information for all actions that have been profiled. The output includes action names and their execution times.
Parameters:
-
title(Optional[str], default:None) –Optional title to print before the profiling summary. If provided, this will be printed as a header.
Note
This method is decorated with @rank_zero_only, meaning it will only execute on the main process in distributed training scenarios.
Source code in fusion_bench/mixins/simple_profiler.py
profile(action_name)
¶
Context manager for profiling a code block.
This context manager automatically starts profiling when entering the block and stops profiling when exiting the block (even if an exception occurs).
Parameters:
-
action_name(str) –A descriptive name for the action being profiled. This name will appear in the profiling summary.
Yields:
-
str(Generator) –The action name that was provided.
Source code in fusion_bench/mixins/simple_profiler.py
start_profile(action_name)
¶
Start profiling for a named action.
This method begins timing for the specified action. You must call stop_profile() with the same action name to complete the measurement.
Parameters:
-
action_name(str) –A descriptive name for the action being profiled. This name will appear in the profiling summary.
Source code in fusion_bench/mixins/simple_profiler.py
stop_profile(action_name)
¶
Stop profiling for a named action.
This method ends timing for the specified action that was previously started with start_profile().
Parameters:
-
action_name(str) –The name of the action to stop profiling. Must match the name used in start_profile().
Source code in fusion_bench/mixins/simple_profiler.py
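Taken together, the start/stop and context-manager APIs above can be sketched with a toy profiler (TinyProfiler is illustrative, not Lightning's SimpleProfiler):

```python
import time
from contextlib import contextmanager

# Toy stand-in for Lightning's SimpleProfiler: accumulates wall-clock
# time per named action.
class TinyProfiler:
    def __init__(self):
        self.durations = {}

    def start_profile(self, action_name):
        self._start = time.perf_counter()

    def stop_profile(self, action_name):
        elapsed = time.perf_counter() - self._start
        self.durations[action_name] = self.durations.get(action_name, 0.0) + elapsed

    @contextmanager
    def profile(self, action_name):
        self.start_profile(action_name)
        try:
            yield action_name
        finally:
            self.stop_profile(action_name)  # stops even if the block raises

prof = TinyProfiler()
with prof.profile("sleep"):
    time.sleep(0.01)
```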
PyinstrumentProfilerMixin
¶
A mixin class that provides statistical profiling capabilities using pyinstrument.
This mixin allows for easy profiling of code blocks using a context manager. It provides methods to start and stop profiling actions, save profiling results to files, and print profiling summaries.
Note
This mixin requires the pyinstrument package to be installed.
If not available, an ImportError will be raised when importing this module.
Examples:
class MyClass(PyinstrumentProfilerMixin):
    def do_something(self):
        with self.profile("work"):
            # do some work here
            ...

        # save the profiling results
        self.save_profile_report("profile_report.html")
        # or print the summary
        self.print_profile_summary()
Attributes:
-
_profiler(Profiler) –An instance of the pyinstrument Profiler class.
Source code in fusion_bench/mixins/pyinstrument.py
profiler
property
¶
Get the profiler instance, creating it if necessary.
__del__()
¶
print_profile_summary(title=None, unicode=True, color=True)
¶
Print a summary of the profiling results.
Parameters:
-
title(Optional[str], default:None) –Optional title to print before the summary.
-
unicode(bool, default:True) –Whether to use unicode characters in the output.
-
color(bool, default:True) –Whether to use color in the output.
Source code in fusion_bench/mixins/pyinstrument.py
profile(action_name=None)
¶
Context manager for profiling a code block.
Parameters:
-
action_name(Optional[str], default:None) –Optional name for the profiling action (for logging purposes).
Source code in fusion_bench/mixins/pyinstrument.py
reset_profile()
¶
save_profile_report(output_path='profile_report.html', format='html', title=None)
¶
Save the profiling results to a file.
Parameters:
-
output_path(Union[str, Path], default:'profile_report.html') –Path where to save the profiling report.
-
format(str, default:'html') –Output format ('html', or 'text').
-
title(Optional[str], default:None) –Optional title for the report.
Source code in fusion_bench/mixins/pyinstrument.py
start_profile(action_name=None)
¶
Start profiling.
Parameters:
-
action_name(Optional[str], default:None) –Optional name for the profiling action.
Source code in fusion_bench/mixins/pyinstrument.py
stop_profile(action_name=None)
¶
Stop profiling.
Parameters:
-
action_name(Optional[str], default:None) –Optional name for the profiling action.
Source code in fusion_bench/mixins/pyinstrument.py
CLIPClassificationMixin
¶
Bases: LightningFabricMixin
This mixin provides methods to classify images using the CLIP model.
Attributes that need to be set by the inheriting class:
- _dataloader_kwargs (Dict[str, Any]): Keyword arguments for the dataloader.
- modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
Source code in fusion_bench/mixins/clip_classification.py
clip_processor
property
¶
Get the CLIP processor, loading it from the model pool if necessary.
Returns:
-
CLIPProcessor–The CLIP processor for image and text preprocessing.
Raises:
-
AssertionError–If the model pool is not set.
compute_features(module, images, normalize=True)
¶
Extracts image features using CLIP's vision encoder and visual projection.
Parameters:
-
module(Union[Module, CLIPVisionModel, CLIPVisionTransformer]) –The CLIP vision encoder module.
-
images(Tensor) –Input image batch to process.
-
normalize(bool, default:True) –Whether to normalize the image embeddings.
Returns:
-
Tensor–torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (projection_dim in the model config).
Source code in fusion_bench/mixins/clip_classification.py
compute_logits(module, images, task, image_embeds=None)
¶
Computes the classification logits for a batch of images for a specific task.
This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
The similarity scores are then scaled by the CLIP model's logit_scale to produce the final logits.
Parameters:
-
module(Union[Module, CLIPVisionModel, CLIPVisionTransformer]) –The vision encoder part of the CLIP model.
-
images(Tensor) –A batch of images to classify.
-
task(str) –The name of the classification task.
-
image_embeds(Optional[Tensor], default:None) –Pre-computed image embeddings. If provided, the method skips the image encoding step.
Returns:
-
Tensor–torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
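The cosine-similarity logit computation can be sketched in pure Python with toy 2-D vectors (hypothetical helper names; the real implementation operates on CLIP tensors):

```python
import math

# Toy sketch of zero-shot logits: cosine similarity between one image
# embedding and per-class text embeddings, scaled by logit_scale.
def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def compute_logits(image_embed, text_embeds, logit_scale=100.0):
    img = normalize(image_embed)
    return [
        logit_scale * sum(a * b for a, b in zip(img, normalize(t)))
        for t in text_embeds
    ]

logits = compute_logits([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
print(logits)  # [100.0, 0.0]
```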
Source code in fusion_bench/mixins/clip_classification.py
get_shuffled_test_loader_iter(task, batch_size=None, num_workers=None, **loader_kwargs)
cached
¶
Get an iterator for a shuffled test DataLoader.
This method creates a DataLoader for the test dataset of the specified task, with shuffling enabled. It allows for optional customization of batch size, number of workers, and other DataLoader keyword arguments.
Parameters:
-
task(str) –The task identifier for which the test dataset is to be loaded.
-
batch_size(Optional[int], default:None) –The batch size to use for the DataLoader. If None, the default batch size is used.
-
num_workers(Optional[int], default:None) –The number of worker processes to use for data loading. If None, the default number of workers is used.
-
**loader_kwargs–Additional keyword arguments to pass to the DataLoader.
Returns:
-
Iterator(Iterator) –An iterator over the shuffled test DataLoader.
Source code in fusion_bench/mixins/clip_classification.py
setup_zero_shot_classification_head(clip_processor=None, clip_model=None, task_names=None)
¶
Initializes a zero-shot classification head.
This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
These embeddings function as the weights of the classification layer. The method also extracts the visual_projection and logit_scale
from the provided CLIP model, which are necessary for calculating the final logits.
Parameters:
-
clip_processor(Optional[CLIPProcessor], default:None) –The processor for the CLIP model. If not provided, it is loaded from the model pool.
-
clip_model(Optional[CLIPModel], default:None) –The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
-
task_names(Optional[List[str]], default:None) –A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
Source code in fusion_bench/mixins/clip_classification.py
auto_register_config(cls)
¶
Decorator to automatically register init parameters in _config_mapping.
This decorator enhances classes that inherit from YAMLSerializationMixin by automatically mapping constructor parameters to configuration keys and dynamically setting instance attributes based on provided arguments.
The decorator performs the following operations:
1. Inspects the class's __init__ method signature
2. Automatically populates the _config_mapping dictionary with parameter names
3. Wraps the __init__ method to handle both positional and keyword arguments
4. Sets instance attributes for all constructor parameters
5. Applies default values when parameters are not provided
Parameters:
-
cls(YAMLSerializationMixin) –The class to be decorated. Must inherit from YAMLSerializationMixin to ensure proper serialization capabilities.
Returns:
-
YAMLSerializationMixin–The decorated class with enhanced auto-registration functionality and modified init behavior.
Behavior
- Parameter Registration: All non-variadic parameters (excluding *args, **kwargs) from the __init__ method are automatically added to _config_mapping
- Positional Arguments: Handled in order and mapped to corresponding parameter names
- Keyword Arguments: Processed after positional arguments, overriding any conflicts
- Default Values: Applied when parameters are not provided via arguments
- Attribute Setting: All parameters become instance attributes accessible via dot notation
Note
- The decorator wraps the original __init__ method while preserving its signature for IDE support
- Parameters with *args or **kwargs signatures are ignored during registration
- The attributes are auto-registered first, then the original __init__ method is called
- Type hints, method name, and other metadata are preserved using functools.wraps
- This decorator is designed to work seamlessly with the YAML serialization system
Raises:
-
AttributeError–If the class does not have the required _config_mapping attribute infrastructure (should inherit from YAMLSerializationMixin)
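A minimal standalone sketch of this behavior (a simplified reimplementation for illustration, not the fusion_bench code):

```python
import functools
import inspect

# Simplified reimplementation of the auto-registration idea.
def auto_register(cls):
    sig = inspect.signature(cls.__init__)
    params = [p for p in sig.parameters.values()
              if p.name != "self"
              and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)]
    # 1) record constructor parameters in the config mapping
    cls._config_mapping = {p.name: p.name for p in params}
    original_init = cls.__init__

    @functools.wraps(original_init)
    def __init__(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()  # 2) fill in defaults for missing params
        for p in params:
            setattr(self, p.name, bound.arguments[p.name])  # 3) auto-set attributes
        original_init(self, *args, **kwargs)  # 4) then run the original __init__

    cls.__init__ = __init__
    return cls

@auto_register
class MyAlgorithm:
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32):
        pass  # attributes are set by the decorator, not the body

algo = MyAlgorithm(batch_size=64)
print(algo.learning_rate, algo.batch_size)  # 0.001 64
```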
Source code in fusion_bench/mixins/serialization.py