fusion_bench.mixins

The mixins module provides reusable functionality through mixin classes that can be combined with other classes to add specific capabilities. These mixins follow the composition-over-inheritance principle and are designed to be modular, flexible, and easy to integrate.

Basic Mixin Composition

from fusion_bench.mixins import (
    LightningFabricMixin,
    SimpleProfilerMixin,
    auto_register_config
)
from fusion_bench import BaseAlgorithm

@auto_register_config
class MyAlgorithm(
    LightningFabricMixin,
    SimpleProfilerMixin,
    BaseAlgorithm
):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        # learning_rate and batch_size are registered and stored
        # automatically by the @auto_register_config decorator.
        super().__init__(**kwargs)

    def run(self, modelpool):
        # implement the fusion logic here
        pass
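The cooperative-inheritance pattern above can be illustrated with a minimal, self-contained sketch. `TimerMixin` below is a hypothetical stand-in for a capability mixin such as `SimpleProfilerMixin`, and the `BaseAlgorithm` here is a stub, not the fusion_bench class:

```python
import time
from contextlib import contextmanager


class TimerMixin:
    """Hypothetical mixin that adds simple timing to any host class."""

    @contextmanager
    def timed(self, name: str):
        start = time.perf_counter()
        yield
        # Store timings on the instance so the host class can report them.
        self.timings = getattr(self, "timings", {})
        self.timings[name] = time.perf_counter() - start


class BaseAlgorithm:
    """Stub standing in for fusion_bench.BaseAlgorithm."""

    def run(self, modelpool):
        raise NotImplementedError


class MyAlgorithm(TimerMixin, BaseAlgorithm):
    def run(self, modelpool):
        # The mixin's capability is available via normal attribute lookup (MRO).
        with self.timed("fusion"):
            return sum(modelpool)  # placeholder for real fusion logic


algo = MyAlgorithm()
result = algo.run([1, 2, 3])
```

Because the mixin only adds methods and never overrides `run`, it composes with any algorithm class without changing its behavior.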

Class Definitions

Configuration and Instantiation

Serialization and Persistence

Distributed Computing and Training

Performance and Debugging

Computer Vision

Class Decorators

References

HydraConfigMixin

A mixin class that provides configuration-based instantiation capabilities.

This mixin enables classes to be instantiated directly from Hydra configuration files, supporting both direct instantiation and target-based instantiation patterns. It's particularly useful in FusionBench for creating model pools, task pools, and fusion algorithms from YAML configurations.

The mixin handles:

  • Configuration loading and composition
  • Target class validation
  • Nested configuration group navigation
  • Object instantiation with proper error handling

Example:

class MyAlgorithm(HydraConfigMixin):
    def __init__(self, param1: str, param2: int = 10):
        self.param1 = param1
        self.param2 = param2

# Instantiate from config
algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
Note

This mixin requires Hydra to be properly initialized before use. Typically, this is handled by the main FusionBench CLI application.
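For the `MyAlgorithm.from_config("algorithms/my_algorithm")` call in the example above, the corresponding configuration file might look like this (file path, module path, and values are illustrative, not part of fusion_bench):

```yaml
# Hypothetical configs/algorithms/my_algorithm.yaml
_target_: my_package.MyAlgorithm  # dotted import path of the class to instantiate
param1: hello
param2: 10
```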

Source code in fusion_bench/mixins/hydra_config.py
class HydraConfigMixin:
    R"""
    A mixin class that provides configuration-based instantiation capabilities.

    This mixin enables classes to be instantiated directly from Hydra configuration
    files, supporting both direct instantiation and target-based instantiation patterns.
    It's particularly useful in FusionBench for creating model pools, task pools,
    and fusion algorithms from YAML configurations.

    The mixin handles:
    - Configuration loading and composition
    - Target class validation
    - Nested configuration group navigation
    - Object instantiation with proper error handling

    Example:

    ```python
    class MyAlgorithm(HydraConfigMixin):
        def __init__(self, param1: str, param2: int = 10):
            self.param1 = param1
            self.param2 = param2

    # Instantiate from config
    algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
    ```

    Note:
        This mixin requires Hydra to be properly initialized before use.
        Typically, this is handled by the main FusionBench CLI application.
    """

    @classmethod
    def from_config(
        cls,
        config_name: Union[str, Path],
        overrides: Optional[List[str]] = None,
    ) -> T:
        """
        Create an instance of the class from a Hydra configuration.

        This method loads a Hydra configuration file and instantiates the class
        using the configuration parameters. It supports both direct parameter
        passing and target-based instantiation patterns.

        Args:
            config_name: The name/path of the configuration file to load.
                        Can be a string like "algorithms/simple_average" or
                        a Path object. The .yaml extension is optional.
            overrides: Optional list of configuration overrides in the format
                      ["key=value", "nested.key=value"]. These allow runtime
                      modification of configuration parameters.

        Returns:
            An instance of the class configured according to the loaded configuration.

        Raises:
            RuntimeError: If Hydra is not properly initialized.
            ImportError: If a target class specified in the config cannot be imported.
            ValueError: If required configuration parameters are missing.

        Example:
            ```python
            # Load with basic config
            obj = MyClass.from_config("my_config")

            # Load with overrides
            obj = MyClass.from_config(
                "my_config",
                overrides=["param1=new_value", "param2=42"]
            )

            # Load nested config
            obj = MyClass.from_config("category/subcategory/my_config")
            ```

        Note:
            The method automatically handles nested configuration groups by
            navigating through the configuration hierarchy based on the
            config_name path structure.
        """
        # Verify Hydra initialization
        if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
            raise RuntimeError(
                "Hydra is not initialized. Please ensure Hydra is properly "
                "initialized before calling from_config(). This is typically "
                "handled by the FusionBench CLI application."
            )
        else:
            # Compose the configuration with any provided overrides
            cfg = compose(config_name=config_name, overrides=overrides)

        # Navigate through nested configuration groups
        # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
        # str() handles both str and Path inputs
        config_groups = str(config_name).split("/")[:-1]
        for config_group in config_groups:
            cfg = cfg[config_group]

        # Handle target-based instantiation
        if "_target_" in cfg:
            # Validate that the target class matches the calling class
            target_cls = import_object(cfg["_target_"])
            if target_cls != cls:
                log.warning(
                    f"Configuration target mismatch: config specifies "
                    f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
                    f"This may indicate a configuration error."
                )

            # Instantiate using the target pattern with function call logging disabled
            with set_print_function_call(False):
                obj = instantiate(cfg)
        else:
            # Direct instantiation using configuration as keyword arguments
            obj = cls(**cfg)

        return obj

from_config(config_name, overrides=None) classmethod

Create an instance of the class from a Hydra configuration.

This method loads a Hydra configuration file and instantiates the class using the configuration parameters. It supports both direct parameter passing and target-based instantiation patterns.

Parameters:

  • config_name (Union[str, Path]) –

    The name/path of the configuration file to load. Can be a string like "algorithms/simple_average" or a Path object. The .yaml extension is optional.

  • overrides (Optional[List[str]], default: None ) –

    Optional list of configuration overrides in the format ["key=value", "nested.key=value"]. These allow runtime modification of configuration parameters.

Returns:

  • T

    An instance of the class configured according to the loaded configuration.

Raises:

  • RuntimeError

    If Hydra is not properly initialized.

  • ImportError

    If a target class specified in the config cannot be imported.

  • ValueError

    If required configuration parameters are missing.

Example
# Load with basic config
obj = MyClass.from_config("my_config")

# Load with overrides
obj = MyClass.from_config(
    "my_config",
    overrides=["param1=new_value", "param2=42"]
)

# Load nested config
obj = MyClass.from_config("category/subcategory/my_config")
Note

The method automatically handles nested configuration groups by navigating through the configuration hierarchy based on the config_name path structure.
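The group navigation described in the note reduces to descending through every path segment except the last (the last segment names the config file itself, whose contents `compose()` has already merged under the group key). A plain-Python sketch with an illustrative config:

```python
# Shape of the composed config: the file's contents sit under the group key.
cfg = {"algorithms": {"_target_": "my_package.SimpleAverage", "scale": 1.0}}

config_name = "algorithms/simple_average"

# Descend into each config group named by the path prefix.
node = cfg
for group in config_name.split("/")[:-1]:
    node = node[group]

# node is now the algorithm's own config, ready for instantiation.
```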

Source code in fusion_bench/mixins/hydra_config.py
@classmethod
def from_config(
    cls,
    config_name: Union[str, Path],
    overrides: Optional[List[str]] = None,
) -> T:
    """
    Create an instance of the class from a Hydra configuration.

    This method loads a Hydra configuration file and instantiates the class
    using the configuration parameters. It supports both direct parameter
    passing and target-based instantiation patterns.

    Args:
        config_name: The name/path of the configuration file to load.
                    Can be a string like "algorithms/simple_average" or
                    a Path object. The .yaml extension is optional.
        overrides: Optional list of configuration overrides in the format
                  ["key=value", "nested.key=value"]. These allow runtime
                  modification of configuration parameters.

    Returns:
        An instance of the class configured according to the loaded configuration.

    Raises:
        RuntimeError: If Hydra is not properly initialized.
        ImportError: If a target class specified in the config cannot be imported.
        ValueError: If required configuration parameters are missing.

    Example:
        ```python
        # Load with basic config
        obj = MyClass.from_config("my_config")

        # Load with overrides
        obj = MyClass.from_config(
            "my_config",
            overrides=["param1=new_value", "param2=42"]
        )

        # Load nested config
        obj = MyClass.from_config("category/subcategory/my_config")
        ```

    Note:
        The method automatically handles nested configuration groups by
        navigating through the configuration hierarchy based on the
        config_name path structure.
    """
    # Verify Hydra initialization
    if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
        raise RuntimeError(
            "Hydra is not initialized. Please ensure Hydra is properly "
            "initialized before calling from_config(). This is typically "
            "handled by the FusionBench CLI application."
        )
    else:
        # Compose the configuration with any provided overrides
        cfg = compose(config_name=config_name, overrides=overrides)

    # Navigate through nested configuration groups
    # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
    # str() handles both str and Path inputs
    config_groups = str(config_name).split("/")[:-1]
    for config_group in config_groups:
        cfg = cfg[config_group]

    # Handle target-based instantiation
    if "_target_" in cfg:
        # Validate that the target class matches the calling class
        target_cls = import_object(cfg["_target_"])
        if target_cls != cls:
            log.warning(
                f"Configuration target mismatch: config specifies "
                f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
                f"This may indicate a configuration error."
            )

        # Instantiate using the target pattern with function call logging disabled
        with set_print_function_call(False):
            obj = instantiate(cfg)
    else:
        # Direct instantiation using configuration as keyword arguments
        obj = cls(**cfg)

    return obj

YAMLSerializationMixin

Source code in fusion_bench/mixins/serialization.py
class YAMLSerializationMixin:
    _config_key: Optional[str] = None
    _config_mapping: MutableBidict[str, str] = bidict()
    R"""
    `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.

    >>> algorithm.config
        DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})

    By default, the `_target_` key is set to the class name as `type(self).__name__`.
    """

    def __init__(self, **kwargs) -> None:
        for key, value in kwargs.items():
            log.warning(f"Unused argument: {key}={value}")

    @property
    def config(self) -> DictConfig:
        R"""
        Returns the configuration of the model pool as a DictConfig.

        This property converts the model pool instance into a dictionary
        configuration, which can be used for serialization or other purposes.

        Returns:
            DictConfig: The configuration of the model pool.
        """
        config = {"_target_": f"{type(self).__module__}.{type(self).__qualname__}"}
        for attr, key in self._config_mapping.items():
            if hasattr(self, attr):
                config[key] = getattr(self, attr)

        try:
            return OmegaConf.create(config)
        except Exception as e:
            return OmegaConf.create(config, flags={"allow_objects": True})

    def to_yaml(self, path: Union[str, Path], resolve: bool = True):
        """
        Save the model pool to a YAML file.

        Args:
            path (Union[str, Path]): The path to save the model pool to.
        """
        OmegaConf.save(self.config, path, resolve=resolve)

    @classmethod
    def from_yaml(cls, path: Union[str, Path]):
        """
        Load a model pool from a YAML file.

        Args:
            path (Union[str, Path]): The path to load the model pool from.

        Returns:
            BaseModelPool: The loaded model pool.
        """
        config = OmegaConf.load(path)
        if cls._config_key is not None and cls._config_key in config:
            config = config[cls._config_key]
        target_cls = import_object(config["_target_"])
        if target_cls != cls:
            log.warning(
                f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
                f"Instantiating the class {target_cls.__name__} instead."
            )
        with set_print_function_call(False):
            return instantiate(config)

    def register_parameter_to_config(
        self,
        attr_name: str,
        param_name: str,
        value,
    ):
        """
        Set an attribute value and register its config mapping.

        This method allows dynamic setting of object attributes while simultaneously
        updating the configuration mapping that defines how the attribute should
        be serialized in the configuration output.

        Args:
            attr_name (str): The name of the attribute to set on this object.
            param_name (str): The corresponding parameter name to use in the config
                serialization. This is how the attribute will appear in YAML output.
            value: The value to assign to the attribute.

        """
        setattr(self, attr_name, value)
        self._config_mapping[attr_name] = param_name

config property

Returns the configuration of the model pool as a DictConfig.

This property converts the model pool instance into a dictionary configuration, which can be used for serialization or other purposes.

Returns:

  • DictConfig ( DictConfig ) –

    The configuration of the model pool.
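The `_config_mapping` mechanism behind this property can be sketched without fusion_bench or OmegaConf; plain dicts stand in for the bidict and the returned `DictConfig`:

```python
class ConfigMixin:
    # Maps instance attribute names to the keys used in the serialized config.
    _config_mapping: dict = {}

    @property
    def config(self) -> dict:
        # _target_ records the fully qualified class name for later re-instantiation.
        cfg = {"_target_": f"{type(self).__module__}.{type(self).__qualname__}"}
        for attr, key in self._config_mapping.items():
            if hasattr(self, attr):
                cfg[key] = getattr(self, attr)
        return cfg


class SomeModelFusionAlgorithm(ConfigMixin):
    _config_mapping = {"hp1": "hyper_param_1", "hp2": "hyper_param_2"}

    def __init__(self, hp1: int, hp2: int):
        self.hp1 = hp1
        self.hp2 = hp2


cfg = SomeModelFusionAlgorithm(1, 2).config
```

Attributes absent from the instance are simply skipped, so optional parameters never appear in the serialized config with dangling values.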

from_yaml(path) classmethod

Load a model pool from a YAML file.

Parameters:

  • path (Union[str, Path]) –

    The path to load the model pool from.

Returns:

  • BaseModelPool

    The loaded model pool.

Source code in fusion_bench/mixins/serialization.py
@classmethod
def from_yaml(cls, path: Union[str, Path]):
    """
    Load a model pool from a YAML file.

    Args:
        path (Union[str, Path]): The path to load the model pool from.

    Returns:
        BaseModelPool: The loaded model pool.
    """
    config = OmegaConf.load(path)
    if cls._config_key is not None and cls._config_key in config:
        config = config[cls._config_key]
    target_cls = import_object(config["_target_"])
    if target_cls != cls:
        log.warning(
            f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
            f"Instantiating the class {target_cls.__name__} instead."
        )
    with set_print_function_call(False):
        return instantiate(config)

register_parameter_to_config(attr_name, param_name, value)

Set an attribute value and register its config mapping.

This method allows dynamic setting of object attributes while simultaneously updating the configuration mapping that defines how the attribute should be serialized in the configuration output.

Parameters:

  • attr_name (str) –

    The name of the attribute to set on this object.

  • param_name (str) –

    The corresponding parameter name to use in the config serialization. This is how the attribute will appear in YAML output.

  • value

    The value to assign to the attribute.
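A caveat worth noting: `_config_mapping` is declared at class level, so item assignment in `register_parameter_to_config` mutates a mapping shared by every instance of that class. A sketch (hypothetical class names) of a per-instance variant that avoids the shared state:

```python
class ConfigRegistryMixin:
    # Class-level default mapping, shared by all instances.
    _config_mapping: dict = {}

    def __init__(self):
        # Copy the class-level mapping so per-instance registrations
        # do not leak into other instances of the same class.
        self._config_mapping = dict(type(self)._config_mapping)

    def register_parameter_to_config(self, attr_name: str, param_name: str, value):
        setattr(self, attr_name, value)
        self._config_mapping[attr_name] = param_name


a = ConfigRegistryMixin()
b = ConfigRegistryMixin()
a.register_parameter_to_config("lr", "learning_rate", 0.01)
```

After this call, `a` carries both the attribute and the mapping entry, while `b`'s mapping is untouched.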

Source code in fusion_bench/mixins/serialization.py
def register_parameter_to_config(
    self,
    attr_name: str,
    param_name: str,
    value,
):
    """
    Set an attribute value and register its config mapping.

    This method allows dynamic setting of object attributes while simultaneously
    updating the configuration mapping that defines how the attribute should
    be serialized in the configuration output.

    Args:
        attr_name (str): The name of the attribute to set on this object.
        param_name (str): The corresponding parameter name to use in the config
            serialization. This is how the attribute will appear in YAML output.
        value: The value to assign to the attribute.

    """
    setattr(self, attr_name, value)
    self._config_mapping[attr_name] = param_name

to_yaml(path, resolve=True)

Save the model pool to a YAML file.

Parameters:

  • path (Union[str, Path]) –

    The path to save the model pool to.

Source code in fusion_bench/mixins/serialization.py
def to_yaml(self, path: Union[str, Path], resolve: bool = True):
    """
    Save the model pool to a YAML file.

    Args:
        path (Union[str, Path]): The path to save the model pool to.
    """
    OmegaConf.save(self.config, path, resolve=resolve)

BaseYAMLSerializable

Bases: YAMLSerializationMixin

A base class for YAML-serializable classes with enhanced metadata support.

This class extends YAMLSerializationMixin to provide additional metadata fields commonly used in FusionBench classes, including usage information and version tracking. It serves as a foundation for all serializable model components in the framework.

The class automatically handles serialization of usage and version metadata alongside the standard configuration parameters, making it easier to track model provenance and intended usage patterns.

Attributes:

  • _usage_ (Optional[str]) –

    Description of the model's intended usage or purpose.

  • _version_ (Optional[str]) –

    Version information for the model or configuration.

Example
class MyAlgorithm(BaseYAMLSerializable):
    _config_mapping = BaseYAMLSerializable._config_mapping | {
        "model_name": "model_name",
        "num_layers": "num_layers",
    }

    def __init__(self, _usage_: str = None, _version_: str = None):
        super().__init__(_usage_=_usage_, _version_=_version_)

# Usage with metadata
model = MyAlgorithm(
    _usage_="Text classification fine-tuning",
    _version_="1.0.0"
)

# Serialization includes metadata
config = model.config
# DictConfig({
#     '_target_': 'MyModel',
#     '_usage_': 'Text classification fine-tuning',
#     '_version_': '1.0.0'
# })
Note

The underscore prefix in _usage_ and _version_ follows the convention for metadata fields that are not core model parameters but provide important contextual information for model management and tracking.
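The version handling in the `__init__` below reduces to a simple rule: if the serialized `_version_` differs from the running fusion-bench version, warn and fall back to the current version. A minimal sketch, with a hypothetical `CURRENT_VERSION` standing in for `FUSION_BENCH_VERSION`:

```python
import warnings

CURRENT_VERSION = "0.2.0"  # hypothetical stand-in for FUSION_BENCH_VERSION


def resolve_version(serialized_version: str) -> str:
    """Return the version to use, warning on a mismatch."""
    if serialized_version != CURRENT_VERSION:
        warnings.warn(
            f"Current version is {CURRENT_VERSION}, but the serialized "
            f"version is {serialized_version}. Using the current version."
        )
        return CURRENT_VERSION
    return serialized_version
```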

Source code in fusion_bench/mixins/serialization.py
@auto_register_config
class BaseYAMLSerializable(YAMLSerializationMixin):
    """
    A base class for YAML-serializable classes with enhanced metadata support.

    This class extends `YAMLSerializationMixin` to provide additional metadata
    fields commonly used in FusionBench classes, including usage information
    and version tracking. It serves as a foundation for all serializable
    model components in the framework.

    The class automatically handles serialization of usage and version metadata
    alongside the standard configuration parameters, making it easier to track
    model provenance and intended usage patterns.

    Attributes:
        _usage_ (Optional[str]): Description of the model's intended usage or purpose.
        _version_ (Optional[str]): Version information for the model or configuration.

    Example:
        ```python
        class MyAlgorithm(BaseYAMLSerializable):
            _config_mapping = BaseYAMLSerializable._config_mapping | {
                "model_name": "model_name",
                "num_layers": "num_layers",
            }

            def __init__(self, _usage_: str = None, _version_: str = None):
                super().__init__(_usage_=_usage_, _version_=_version_)

        # Usage with metadata
        model = MyAlgorithm(
            _usage_="Text classification fine-tuning",
            _version_="1.0.0"
        )

        # Serialization includes metadata
        config = model.config
        # DictConfig({
        #     '_target_': 'MyModel',
        #     '_usage_': 'Text classification fine-tuning',
        #     '_version_': '1.0.0'
        # })
        ```

    Note:
        The underscore prefix in `_usage_` and `_version_` follows the convention
        for metadata fields that are not core model parameters but provide
        important contextual information for model management and tracking.
    """

    def __init__(
        self,
        _recursive_: bool = False,
        _usage_: Optional[str] = None,
        _version_: Optional[str] = FUSION_BENCH_VERSION,
        **kwargs,
    ):
        """
        Initialize a base YAML-serializable model with metadata support.

        Args:
            _usage_ (Optional[str], optional): Description of the model's intended
                usage or purpose. This can include information about the training
                domain, expected input types, or specific use cases. Defaults to None.
            _version_ (Optional[str], optional): Version information for the model
                or configuration. Can be used to track model iterations, dataset
                versions, or compatibility information. Defaults to the
                current fusion-bench version (`FUSION_BENCH_VERSION`).
            **kwargs: Additional keyword arguments passed to the parent class.
                Unused arguments will trigger warnings via the parent's initialization.

        Example:
            ```python
            model = BaseYAMLSerializable(
                _usage_="Image classification on CIFAR-10",
                _version_="2.1.0"
            )
            ```
        """
        super().__init__(**kwargs)
        if _version_ != FUSION_BENCH_VERSION:
            log.warning(
                f"Current fusion-bench version is {FUSION_BENCH_VERSION}, but the serialized version is {_version_}. "
                "Attempting to use current version."
            )
            # override _version_ with current fusion-bench version
            self._version_ = FUSION_BENCH_VERSION

__init__(_recursive_=False, _usage_=None, _version_=FUSION_BENCH_VERSION, **kwargs)

Initialize a base YAML-serializable model with metadata support.

Parameters:

  • _usage_ (Optional[str], default: None ) –

    Description of the model's intended usage or purpose. This can include information about the training domain, expected input types, or specific use cases. Defaults to None.

  • _version_ (Optional[str], default: FUSION_BENCH_VERSION ) –

    Version information for the model or configuration. Can be used to track model iterations, dataset versions, or compatibility information. Defaults to the current fusion-bench version (FUSION_BENCH_VERSION).

  • **kwargs

    Additional keyword arguments passed to the parent class. Unused arguments will trigger warnings via the parent's initialization.

Example
model = BaseYAMLSerializable(
    _usage_="Image classification on CIFAR-10",
    _version_="2.1.0"
)
Source code in fusion_bench/mixins/serialization.py
def __init__(
    self,
    _recursive_: bool = False,
    _usage_: Optional[str] = None,
    _version_: Optional[str] = FUSION_BENCH_VERSION,
    **kwargs,
):
    """
    Initialize a base YAML-serializable model with metadata support.

    Args:
        _usage_ (Optional[str], optional): Description of the model's intended
            usage or purpose. This can include information about the training
            domain, expected input types, or specific use cases. Defaults to None.
        _version_ (Optional[str], optional): Version information for the model
            or configuration. Can be used to track model iterations, dataset
            versions, or compatibility information. Defaults to the
            current fusion-bench version (`FUSION_BENCH_VERSION`).
        **kwargs: Additional keyword arguments passed to the parent class.
            Unused arguments will trigger warnings via the parent's initialization.

    Example:
        ```python
        model = BaseYAMLSerializable(
            _usage_="Image classification on CIFAR-10",
            _version_="2.1.0"
        )
        ```
    """
    super().__init__(**kwargs)
    if _version_ != FUSION_BENCH_VERSION:
        log.warning(
            f"Current fusion-bench version is {FUSION_BENCH_VERSION}, but the serialized version is {_version_}. "
            "Attempting to use current version."
        )
        # override _version_ with current fusion-bench version
        self._version_ = FUSION_BENCH_VERSION

LightningFabricMixin

A mixin class for integrating Lightning Fabric into a project.

This class provides methods to initialize and manage a Lightning Fabric instance for distributed computing, including setup with optional logging, device management for tensors and modules, and hyperparameter logging. It leverages the Lightning framework to facilitate distributed training and inference across multiple devices and nodes, with support for custom logging via TensorBoard.

Attributes:

  • _fabric_instance (L.Fabric): The Lightning Fabric instance used for distributed computing.

Note:

This mixin is designed to be used with classes that require distributed computing capabilities and wish to leverage the Lightning Fabric for this purpose. It assumes the presence of a config attribute or parameter in the consuming class for configuration.
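A typical `fabric` section consumed by `setup_lightning_fabric` instantiates `lightning.Fabric` through Hydra's `_target_` mechanism; the field values below are illustrative, not defaults shipped with fusion_bench:

```yaml
fabric:
  _target_: lightning.Fabric
  accelerator: auto
  devices: 1
  loggers:
    _target_: lightning.fabric.loggers.TensorBoardLogger
    root_dir: outputs
```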

Source code in fusion_bench/mixins/lightning_fabric.py
class LightningFabricMixin:
    """
    A mixin class for integrating Lightning Fabric into a project.

    This class provides methods to initialize and manage a Lightning Fabric instance for distributed computing,
    including setup with optional logging, device management for tensors and modules, and hyperparameter logging.
    It leverages the Lightning framework to facilitate distributed training and inference across multiple devices
    and nodes, with support for custom logging via TensorBoard.

    Attributes:

    - _fabric_instance (L.Fabric): The Lightning Fabric instance used for distributed computing.

    Note:

    This mixin is designed to be used with classes that require distributed computing capabilities and wish to
    leverage the Lightning Fabric for this purpose. It assumes the presence of a `config` attribute or parameter
    in the consuming class for configuration.
    """

    _fabric_instance: L.Fabric = None

    def setup_lightning_fabric(self, config: DictConfig):
        """
        Initializes and launches the Lightning Fabric with optional logging.

        This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric
        configuration is not found, it logs a warning and exits. Optionally, if a fabric logger configuration is provided,
        it initializes a TensorBoardLogger with the specified settings.

        Expected configuration keys:
        - fabric: The configuration for the Lightning Fabric.
        - fabric.loggers: The configuration for the TensorBoardLogger.
        """
        if self._fabric_instance is None:
            if config.get("fabric", None) is None:
                log.warning(
                    "No fabric configuration found; using default settings "
                    "(a single device)."
                )
                self._fabric_instance = L.Fabric(devices=1)
            else:
                self._fabric_instance = instantiate(config.fabric)
            if not _is_using_cli():  # if not using cli, launch the fabric
                self._fabric_instance.launch()
            # Set the log directory in config if it is not already set
            if (
                self.log_dir is not None
                and hasattr(config, "log_dir")
                and config.get("log_dir", None) is None
            ):
                if self._fabric_instance.is_global_zero:
                    log.info(f"Setting log_dir to {self.log_dir}")
                config.log_dir = self.log_dir

    @property
    def fabric(self):
        """
        Get the Lightning Fabric instance, initializing it if necessary.

        Returns:
            L.Fabric: The Lightning Fabric instance for distributed computing.
        """
        if self._fabric_instance is None:
            self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
        return self._fabric_instance

    @fabric.setter
    def fabric(self, instance: L.Fabric):
        """
        Set the Lightning Fabric instance.

        Args:
            instance: The Lightning Fabric instance to use.
        """
        self._fabric_instance = instance

    @property
    def log_dir(self):
        """
        Retrieves the log directory from the fabric's logger.
        """
        if self.fabric is not None and len(self.fabric._loggers) > 0:
            if hasattr(self.fabric.logger, "log_dir"):
                log_dir = self.fabric.logger.log_dir
            else:
                log_dir = None

            # Special handling for SwanLabLogger to get the correct log directory
            if (
                log_dir is None
                and self.fabric.logger.__class__.__name__ == "SwanLabLogger"
            ):
                log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir

            if (
                log_dir is None
                and self.fabric.logger.__class__.__name__ == "MLFlowLogger"
            ):
                log_dir = self.fabric.logger.save_dir
                if log_dir is None:
                    try:
                        log_dir = self._program.config.path.log_dir
                    except Exception:
                        log.error(
                            "Failed to get log_dir from program config for MLFlowLogger."
                        )
                        log_dir = "outputs"

            assert log_dir is not None, "log_dir should not be None"
            if self.fabric.is_global_zero and not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            return log_dir
        else:
            return None

    def to_device(self, obj: TensorOrModule) -> TensorOrModule:
        """
        Moves a tensor or module to the proper device.

        Args:
            obj (TensorOrModule): The tensor or module to move to the device.

        Returns:
            TensorOrModule: the same type of object as the input, moved to the device.
        """
        return self.fabric.to_device(obj)

    @rank_zero_only
    def log_hyperparams(
        self,
        config: Optional[DictConfig] = None,
        save_dir: Optional[str] = None,
        filename: str = "config.yaml",
    ):
        R"""
        Logs the hyperparameters and saves the configuration to a YAML file.
        The YAML file is saved in the log directory by default with the name `config.yaml`, or in the specified save directory `save_dir`.

        Args:
            config (Optional[DictConfig]): The configuration to log and save. If not provided, the class's `config` attribute is used.
            save_dir (Optional[str]): The directory in which to save the configuration file. If not provided, the log directory is used.
            filename (str): The name of the configuration file. Default is `config.yaml`.
        """
        if config is None:
            config = self.config
        if save_dir is None:
            save_dir = self.log_dir
        self.fabric.logger.log_hyperparams(
            OmegaConf.to_container(config, resolve=True, enum_to_str=True)
        )
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        OmegaConf.save(config, os.path.join(save_dir, filename))

    @property
    def tensorboard_summarywriter(
        self,
    ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
        """
        Get the TensorBoard SummaryWriter for detailed logging.

        Returns:
            SummaryWriter: The TensorBoard SummaryWriter instance.

        Raises:
            AttributeError: If the logger is not a TensorBoardLogger.
        """
        if isinstance(self.fabric.logger, TensorBoardLogger):
            return self.fabric.logger.experiment
        else:
            raise AttributeError("the logger is not a TensorBoardLogger.")

    @property
    def is_debug_mode(self):
        """
        Check if the program is running in debug mode (fast_dev_run).

        Returns:
            bool: True if fast_dev_run is enabled, False otherwise.
        """
        return RuntimeConstants().debug

    def log(self, name: str, value: Any, step: Optional[int] = None):
        """
        Logs a single metric to the fabric's logger.

        Args:
            name: The name of the metric to log.
            value: The value of the metric.
            step: Optional step number for the metric.
        """
        self.fabric.log(name, value, step=step)

    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
        """
        Logs multiple metrics to the fabric's logger.

        Args:
            metrics: Dictionary of metric names and values.
            step: Optional step number for the metrics.
        """
        self.fabric.log_dict(metrics, step=step)

    def log_optimizer_lr(
        self,
        optimizer: torch.optim.Optimizer,
        step: Optional[int] = None,
        name_template: str = "train/lr_group_{0}",
    ):
        """
        Logs the learning rate of each parameter group in the optimizer.

        Args:
            optimizer: The optimizer whose learning rates should be logged.
            step: Optional step number for the log entry.
            name_template: Template string for the log name. Use {0} as placeholder for group index.
        """
        for i, param_group in enumerate(optimizer.param_groups):
            self.fabric.log(name_template.format(i), param_group["lr"], step=step)

    def log_artifact(self, local_path: str, artifact_path: str | None = None):
        """
        Logs a file as an artifact to the fabric's logger.

        Args:
            local_path: The path to the file to log as an artifact.
            artifact_path: The directory within the logger's artifact storage to save the file.
        """
        if _is_mlflow_logger(self.fabric):
            logger: "MLFlowLogger" = self.fabric.logger
            experiment: "MlflowClient" = logger.experiment
            experiment.log_artifact(
                logger.run_id,
                local_path=local_path,
                artifact_path=artifact_path,
            )

    def log_artifacts(self, local_dir: str, artifact_path: str | None = None):
        """
        Logs a directory as artifacts to the fabric's logger.

        Args:
            local_dir: The path to the directory to log as artifacts.
            artifact_path: The directory within the logger's artifact storage to save the files.
        """
        if _is_mlflow_logger(self.fabric):
            logger: "MLFlowLogger" = self.fabric.logger
            experiment: "MlflowClient" = logger.experiment
            experiment.log_artifacts(
                logger.run_id,
                local_dir=local_dir,
                artifact_path=artifact_path,
            )

    def finalize(self):
        """
        Finalizes the logger status and releases the Lightning Fabric instance.
        """
        if self._fabric_instance is None:
            return

        if _fabric_has_logger(self.fabric) and _is_mlflow_logger(self.fabric):
            if sys.exc_info()[0] is None:
                status = "success"
            else:
                status = "failed"
            self.fabric.logger.finalize(status)

        del self._fabric_instance
        self._fabric_instance = None

    def __del__(self):
        """
        Destructor to ensure proper cleanup of the Lightning Fabric instance.
        """
        self.finalize()

fabric property writable

Get the Lightning Fabric instance, initializing it if necessary.

Returns:

  • L.Fabric: The Lightning Fabric instance for distributed computing.

is_debug_mode property

Check if the program is running in debug mode (fast_dev_run).

Returns:

  • bool

    True if fast_dev_run is enabled, False otherwise.

log_dir property

Retrieves the log directory from the fabric's logger.

tensorboard_summarywriter property

Get the TensorBoard SummaryWriter for detailed logging.

Returns:

  • SummaryWriter ( SummaryWriter ) –

    The TensorBoard SummaryWriter instance.

Raises:

  • AttributeError

    If the logger is not a TensorBoardLogger.

__del__()

Destructor to ensure proper cleanup of the Lightning Fabric instance.

Source code in fusion_bench/mixins/lightning_fabric.py
def __del__(self):
    """
    Destructor to ensure proper cleanup of the Lightning Fabric instance.
    """
    self.finalize()

finalize()

Finalizes the logger status and releases the Lightning Fabric instance.

Source code in fusion_bench/mixins/lightning_fabric.py
def finalize(self):
    """
    Finalizes the logger status and releases the Lightning Fabric instance.
    """
    if self._fabric_instance is None:
        return

    if _fabric_has_logger(self.fabric) and _is_mlflow_logger(self.fabric):
        if sys.exc_info()[0] is None:
            status = "success"
        else:
            status = "failed"
        self.fabric.logger.finalize(status)

    del self._fabric_instance
    self._fabric_instance = None

log(name, value, step=None)

Logs a single metric to the fabric's logger.

Parameters:

  • name (str) –

    The name of the metric to log.

  • value (Any) –

    The value of the metric.

  • step (Optional[int], default: None ) –

    Optional step number for the metric.

Source code in fusion_bench/mixins/lightning_fabric.py
def log(self, name: str, value: Any, step: Optional[int] = None):
    """
    Logs a single metric to the fabric's logger.

    Args:
        name: The name of the metric to log.
        value: The value of the metric.
        step: Optional step number for the metric.
    """
    self.fabric.log(name, value, step=step)

log_artifact(local_path, artifact_path=None)

Logs a file as an artifact to the fabric's logger.

Parameters:

  • local_path (str) –

    The path to the file to log as an artifact.

  • artifact_path (str | None, default: None ) –

    The directory within the logger's artifact storage to save the file.

Source code in fusion_bench/mixins/lightning_fabric.py
def log_artifact(self, local_path: str, artifact_path: str | None = None):
    """
    Logs a file as an artifact to the fabric's logger.

    Args:
        local_path: The path to the file to log as an artifact.
        artifact_path: The directory within the logger's artifact storage to save the file.
    """
    if _is_mlflow_logger(self.fabric):
        logger: "MLFlowLogger" = self.fabric.logger
        experiment: "MlflowClient" = logger.experiment
        experiment.log_artifact(
            logger.run_id,
            local_path=local_path,
            artifact_path=artifact_path,
        )

log_artifacts(local_dir, artifact_path=None)

Logs a directory as artifacts to the fabric's logger.

Parameters:

  • local_dir (str) –

    The path to the directory to log as artifacts.

  • artifact_path (str | None, default: None ) –

    The directory within the logger's artifact storage to save the files.

Source code in fusion_bench/mixins/lightning_fabric.py
def log_artifacts(self, local_dir: str, artifact_path: str | None = None):
    """
    Logs a directory as artifacts to the fabric's logger.

    Args:
        local_dir: The path to the directory to log as artifacts.
        artifact_path: The directory within the logger's artifact storage to save the files.
    """
    if _is_mlflow_logger(self.fabric):
        logger: "MLFlowLogger" = self.fabric.logger
        experiment: "MlflowClient" = logger.experiment
        experiment.log_artifacts(
            logger.run_id,
            local_dir=local_dir,
            artifact_path=artifact_path,
        )

log_dict(metrics, step=None)

Logs multiple metrics to the fabric's logger.

Parameters:

  • metrics (Mapping[str, Any]) –

    Dictionary of metric names and values.

  • step (Optional[int], default: None ) –

    Optional step number for the metrics.

Source code in fusion_bench/mixins/lightning_fabric.py
def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
    """
    Logs multiple metrics to the fabric's logger.

    Args:
        metrics: Dictionary of metric names and values.
        step: Optional step number for the metrics.
    """
    self.fabric.log_dict(metrics, step=step)

log_hyperparams(config=None, save_dir=None, filename='config.yaml')

Logs the hyperparameters and saves the configuration to a YAML file. The YAML file is saved in the log directory by default with the name config.yaml, or in the specified save directory save_dir.

Parameters:

  • config (Optional[DictConfig], default: None ) –

    The configuration to log and save. If not provided, the class's config attribute is used.

  • save_dir (Optional[str], default: None ) –

    The directory in which to save the configuration file. If not provided, the log directory is used.

  • filename (str, default: 'config.yaml' ) –

    The name of the configuration file. Default is config.yaml.

Source code in fusion_bench/mixins/lightning_fabric.py
@rank_zero_only
def log_hyperparams(
    self,
    config: Optional[DictConfig] = None,
    save_dir: Optional[str] = None,
    filename: str = "config.yaml",
):
    R"""
    Logs the hyperparameters and saves the configuration to a YAML file.
    The YAML file is saved in the log directory by default with the name `config.yaml`, or in the specified save directory `save_dir`.

    Args:
        config (Optional[DictConfig]): The configuration to log and save. If not provided, the class's `config` attribute is used.
        save_dir (Optional[str]): The directory in which to save the configuration file. If not provided, the log directory is used.
        filename (str): The name of the configuration file. Default is `config.yaml`.
    """
    if config is None:
        config = self.config
    if save_dir is None:
        save_dir = self.log_dir
    self.fabric.logger.log_hyperparams(
        OmegaConf.to_container(config, resolve=True, enum_to_str=True)
    )
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    OmegaConf.save(config, os.path.join(save_dir, filename))
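
As a quick illustration, the following standalone sketch mirrors how `log_hyperparams` resolves the output file from `save_dir` and `filename` (the paths are illustrative, not produced by FusionBench):

```python
import os

# Mirror of log_hyperparams' save-path resolution (illustrative values).
log_dir = "outputs/run_0"   # stand-in for self.log_dir
save_dir = None             # caller did not pass save_dir
filename = "config.yaml"

if save_dir is None:
    save_dir = log_dir      # fall back to the log directory
config_path = os.path.join(save_dir, filename)
print(config_path)
```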

log_optimizer_lr(optimizer, step=None, name_template='train/lr_group_{0}')

Logs the learning rate of each parameter group in the optimizer.

Parameters:

  • optimizer (Optimizer) –

    The optimizer whose learning rates should be logged.

  • step (Optional[int], default: None ) –

    Optional step number for the log entry.

  • name_template (str, default: 'train/lr_group_{0}' ) –

    Template string for the log name. Use {0} as placeholder for group index.

Source code in fusion_bench/mixins/lightning_fabric.py
def log_optimizer_lr(
    self,
    optimizer: torch.optim.Optimizer,
    step: Optional[int] = None,
    name_template: str = "train/lr_group_{0}",
):
    """
    Logs the learning rate of each parameter group in the optimizer.

    Args:
        optimizer: The optimizer whose learning rates should be logged.
        step: Optional step number for the log entry.
        name_template: Template string for the log name. Use {0} as placeholder for group index.
    """
    for i, param_group in enumerate(optimizer.param_groups):
        self.fabric.log(name_template.format(i), param_group["lr"], step=step)
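
With the default template, the metric names expand per parameter group as follows (the learning-rate values are made up for illustration):

```python
# How name_template maps parameter-group indices to metric names (illustrative).
name_template = "train/lr_group_{0}"
param_group_lrs = [0.1, 0.01]  # pretend learning rates from two param groups
logged = {name_template.format(i): lr for i, lr in enumerate(param_group_lrs)}
print(logged)  # {'train/lr_group_0': 0.1, 'train/lr_group_1': 0.01}
```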

setup_lightning_fabric(config)

Initializes and launches the Lightning Fabric with optional logging.

This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric configuration is not found, it logs a warning and falls back to a default single-device Fabric. Optionally, if a fabric logger configuration is provided, it initializes a TensorBoardLogger with the specified settings.

Expected configuration keys: - fabric: The configuration for the Lightning Fabric. - fabric.loggers: The configuration for the TensorBoardLogger.

Source code in fusion_bench/mixins/lightning_fabric.py
def setup_lightning_fabric(self, config: DictConfig):
    """
    Initializes and launches the Lightning Fabric with optional logging.

    This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric
    configuration is not found, it logs a warning and falls back to a default single-device Fabric. Optionally, if a
    fabric logger configuration is provided, it initializes a TensorBoardLogger with the specified settings.

    Expected configuration keys:
    - fabric: The configuration for the Lightning Fabric.
    - fabric.loggers: The configuration for the TensorBoardLogger.
    """
    if self._fabric_instance is None:
        if config.get("fabric", None) is None:
            log.warning(
                "No fabric configuration found; using default settings (a single device)."
            )
            self._fabric_instance = L.Fabric(devices=1)
        else:
            self._fabric_instance = instantiate(config.fabric)
        if not _is_using_cli():  # if not using cli, launch the fabric
            self._fabric_instance.launch()
        # Set the log directory in config if it is not already set
        if (
            self.log_dir is not None
            and hasattr(config, "log_dir")
            and config.get("log_dir", None) is None
        ):
            if self._fabric_instance.is_global_zero:
                log.info(f"Setting log_dir to {self.log_dir}")
            config.log_dir = self.log_dir
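
A minimal fabric configuration fragment, shown here as the plain dictionary that OmegaConf would parse from YAML. The `_target_` paths and values below are illustrative assumptions, not taken from FusionBench's shipped configs:

```python
# Hypothetical config fragment consumed by setup_lightning_fabric.
# `_target_` paths and values are assumptions for illustration only.
config = {
    "fabric": {
        "_target_": "lightning.Fabric",  # resolved via Hydra's instantiate()
        "devices": 1,
        "loggers": {
            "_target_": "lightning.fabric.loggers.TensorBoardLogger",
            "root_dir": "outputs",
        },
    },
    "log_dir": None,  # filled in from the fabric logger after launch
}
print(config["fabric"]["devices"])  # 1
```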

to_device(obj)

Moves a tensor or module to the proper device.

Parameters:

  • obj (TensorOrModule) –

    The tensor or module to move to the device.

Returns:

  • TensorOrModule ( TensorOrModule ) –

    the same type of object as the input, moved to the device.

Source code in fusion_bench/mixins/lightning_fabric.py
def to_device(self, obj: TensorOrModule) -> TensorOrModule:
    """
    Moves a tensor or module to the proper device.

    Args:
        obj (TensorOrModule): The tensor or module to move to the device.

    Returns:
        TensorOrModule: the same type of object as the input, moved to the device.
    """
    return self.fabric.to_device(obj)

FabricTrainingMixin

Bases: LightningFabricMixin

This is a general purpose mixin for training a model with PyTorch Lightning.

Source code in fusion_bench/mixins/fabric_training.py
class FabricTrainingMixin(LightningFabricMixin):
    """
    This is a general purpose mixin for training a model with PyTorch Lightning.
    """

    _latest_saved_checkpoint_global_step: int = -1
    """The global step index of the latest saved checkpoint."""
    _expected_total_steps: int = None
    """The expected total number of steps of the entire training."""
    is_training: bool
    """Whether the training is in progress. If set to False, the training will stop."""
    epoch_idx: int
    """The epoch index, which is the number of epochs completed."""
    global_step_idx: int
    """The global step index, which is the number of parameter update steps."""
    max_epochs: int
    """Max number of epochs of the entire training."""
    max_steps: int
    """Max number of parameter update steps of the entire training."""
    max_steps_per_epoch: int
    """Max number of parameter update steps per epoch."""
    gradient_clip_algorithm: Literal["value", "norm"]
    """The algorithm to clip gradients. Available options: 'value', 'norm'."""
    gradient_clip_val: float
    """The value to clip gradients. If None, no clipping is applied."""
    accumulate_grad_batches: int
    """The number of gradient accumulation steps. The effective global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`."""
    lr_scheduler_interval: Literal["step", "epoch"]
    """The interval to run the learning rate scheduler. Available options: 'step', 'epoch'."""
    lr_scheduler_frequency: int
    """The frequency to run the learning rate scheduler."""
    checkpoint_save_interval: Literal["step", "epoch"]
    """The interval to save the model checkpoint. Available options: 'step', 'epoch'."""
    checkpoint_save_frequency: int
    """The frequency to save the model checkpoint."""

    def clip_gradients_if_needed(self, model, optimizer):
        """
        Clips gradients if the gradient clipping value is set.

        Args:
            model (nn.Module): The model whose gradients need to be clipped.
            optimizer (torch.optim.Optimizer): The optimizer used for training.
        """
        fabric = self.fabric

        if self.gradient_clip_val is not None:
            if self.gradient_clip_algorithm == "value":
                fabric.clip_gradients(model, optimizer, clip_val=self.gradient_clip_val)
            elif self.gradient_clip_algorithm == "norm":
                fabric.clip_gradients(model, optimizer, max_norm=self.gradient_clip_val)
            else:
                raise ValueError(
                    f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
                )

    def compute_expected_total_steps(
        self, train_dataloader: torch.utils.data.DataLoader
    ):
        """
        Computes the expected total number of steps for the entire training.

        Args:
            train_dataloader (torch.utils.data.DataLoader): The dataloader for the training data.
        """
        # compute expected total steps
        self._expected_total_steps = []
        if self.max_steps > 0:
            self._expected_total_steps.append(self.max_steps)
        if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
            self._expected_total_steps.append(
                self.max_steps_per_epoch * self.max_epochs
            )
        if self.max_epochs > 0:
            self._expected_total_steps.append(
                len(train_dataloader) * self.max_epochs // self.accumulate_grad_batches
            )
        self._expected_total_steps = min(self._expected_total_steps)
        log.info(f"Expected total steps: {self._expected_total_steps}")

    @property
    def expected_total_steps(self):
        """
        The expected total number of steps of the entire training. You need to run `compute_expected_total_steps` method to compute this value before accessing it.

        Raises:
            ValueError: If the expected total steps have not been computed.
        """
        if self._expected_total_steps is None:
            raise ValueError(
                "The expected total steps have not been computed. Run `compute_expected_total_steps` method."
            )
        else:
            return self._expected_total_steps

    def conditional_checkpoint_save(
        self,
        stage: Literal["end_of_step", "end_of_epoch", "end_of_training"],
        *args,
        **kwargs,
    ):
        """
        Conditionally saves a checkpoint based on the current training stage.

        Args:
            stage (Literal["end_of_step", "end_of_epoch", "end_of_training"]): The current stage of training.
        """
        if stage == "end_of_step":
            if (
                self.checkpoint_save_interval == "step"
                and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
            ):
                save_path = os.path.join(
                    self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
                )
                self.save_checkpoint(save_path, *args, **kwargs)
        elif stage == "end_of_epoch":
            if (
                self.checkpoint_save_interval == "epoch"
                and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
            ):
                save_path = os.path.join(
                    self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
                )
                self.save_checkpoint(save_path, *args, **kwargs)
        elif stage == "end_of_training":
            # if the checkpoint has not been saved yet, save it
            if self.global_step_idx > self._latest_saved_checkpoint_global_step:
                save_path = os.path.join(
                    self.log_dir,
                    "checkpoints",
                    f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
                )
                self.save_checkpoint(save_path, *args, **kwargs)
                try:
                    os.symlink(
                        src=save_path,
                        dst=os.path.join(
                            self.log_dir, "checkpoints", "latest_model.ckpt"
                        ),
                        target_is_directory=os.path.isdir(save_path),
                    )
                except Exception as e:
                    log.error(f"Failed to create symlink: {e}")
        else:
            raise ValueError(
                f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
            )

    @abstractmethod
    def save_checkpoint(self, path, **kwargs):
        """
        Saves a checkpoint of the model.

        Args:
            path (str): The path where the checkpoint will be saved.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError("save_checkpoint method is not implemented")

    def train(
        self,
        model: Union[nn.Module, "_FabricModule"],
        optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    ):
        """
        Trains the model.

        The global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`.

        Args:
            model (Union[nn.Module, "_FabricModule"]): The model to be trained.
            optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
        """
        fabric = self.fabric
        self.is_training = True
        # number of parameter update iterations, not the number of batches
        self.global_step_idx = 0
        model.train()
        optimizer.zero_grad()
        for epoch_idx in tqdm(
            range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
            "Training Epoch",
            dynamic_ncols=True,
            leave=False,
            disable=not fabric.is_global_zero,
        ):
            self.epoch_idx = epoch_idx
            self.train_epoch(model, optimizer, lr_scheduler)
            # run lr_scheduler at the end of the epoch if interval is set to "epoch"
            if (
                self.lr_scheduler_interval == "epoch"
                and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
            ):
                lr_scheduler.step()

            # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
            self.conditional_checkpoint_save(stage="end_of_epoch")

            if not self.is_training:
                break

        optimizer.zero_grad()
        # save the model at the end of training
        self.conditional_checkpoint_save(stage="end_of_training")
        return model

    @abstractmethod
    def train_epoch(
        self,
        model: Union[nn.Module, "_FabricModule"],
        optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    ):
        """
        Trains the model for one epoch.

        Args:
            model (Union[nn.Module, "_FabricModule"]): The model to be trained.
            optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.

        Raises:
            NotImplementedError: If the method is not implemented.
        """
        raise NotImplementedError(
            "Copy this as a template and implement your own train_epoch method"
        )
        fabric = self.fabric

        accumulated_loss = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
                output = self.compute_loss(batch)
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)
                accumulated_loss += loss.item()

            # 1. update the model parameters if not accumulating gradients
            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
            # 3. save the model if interval is set to "step" and frequency is met
            # 4. log metrics
            # 5. increase the global step index and reset the accumulated metrics
            if not is_accumulating:
                self.clip_gradients_if_needed(model, optimizer)

                # run lr_scheduler at the end of the step if interval is set to "step"
                if (
                    self.lr_scheduler_interval == "step"
                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
                ):
                    lr_scheduler.step()

                # update the model parameters and zero the gradients
                optimizer.step()
                optimizer.zero_grad()

                metrics = {
                    "train/loss": accumulated_loss,
                    "train/lr": optimizer.param_groups[0]["lr"],
                }

                fabric.log_dict(metrics, step=self.global_step_idx)
                pbar.set_postfix(metrics)

                # save the model at the end of the step if interval is set to "step" and frequency is met
                self.conditional_checkpoint_save(stage="end_of_step")

                # break if max_steps_per_epoch is set, and exit epoch
                if (
                    self.max_steps_per_epoch > 0
                    and step_idx + 1 >= self.max_steps_per_epoch
                ):
                    break
                # break if max_steps is set, and exit training
                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                    self.is_training = False
                    break

                self.global_step_idx += 1
                accumulated_loss = 0
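
The gradient-accumulation bookkeeping in the template above hinges on one modulo test: `(step_idx + 1) % accumulate_grad_batches == 0` marks the batches that trigger a parameter update. A standalone sketch (batch counts are illustrative):

```python
# Which batch indices trigger a parameter update in the train_epoch template?
accumulate_grad_batches = 4
update_steps = [
    step_idx
    for step_idx in range(8)
    if (step_idx + 1) % accumulate_grad_batches == 0
]
print(update_steps)  # [3, 7]
```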

accumulate_grad_batches instance-attribute

The number of gradient accumulation steps. The effective global batch size is the batch size per device x the number of devices x the number of gradient accumulation steps.
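
The arithmetic is simple enough to spell out (all numbers are illustrative):

```python
# Effective global batch size = per-device batch size x devices x accumulation steps.
batch_size_per_device = 8
num_devices = 4
accumulate_grad_batches = 2
effective_batch_size = batch_size_per_device * num_devices * accumulate_grad_batches
print(effective_batch_size)  # 64
```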

checkpoint_save_frequency instance-attribute

The frequency to save the model checkpoint.

checkpoint_save_interval instance-attribute

The interval to save the model checkpoint. Available options: 'step', 'epoch'.

epoch_idx instance-attribute

The epoch index, which is the number of epochs completed.

expected_total_steps property

The expected total number of steps of the entire training. You need to run compute_expected_total_steps method to compute this value before accessing it.

Raises:

  • ValueError

    If the expected total steps have not been computed.

global_step_idx instance-attribute

The global step index, which is the number of parameter update steps.

gradient_clip_algorithm instance-attribute

The algorithm to clip gradients. Available options: 'value', 'norm'.

gradient_clip_val instance-attribute

The value to clip gradients. If None, no clipping is applied.

is_training instance-attribute

Whether the training is in progress. If set to False, the training will stop.

lr_scheduler_frequency instance-attribute

The frequency to run the learning rate scheduler.

lr_scheduler_interval instance-attribute

The interval to run the learning rate scheduler. Available options: 'step', 'epoch'.

max_epochs instance-attribute

Max number of epochs of the entire training.

max_steps instance-attribute

Max number of parameter update steps of the entire training.

max_steps_per_epoch instance-attribute

Max number of parameter update steps per epoch.
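
The `checkpoint_save_interval` and `checkpoint_save_frequency` attributes above, together with the step and epoch counters, determine the checkpoint file names that `conditional_checkpoint_save` writes. A standalone sketch of the naming scheme (counter values are illustrative):

```python
# Checkpoint names produced at each save stage (illustrative counter values).
global_step_idx, epoch_idx = 99, 4
step_ckpt = f"step={global_step_idx}.ckpt"                     # end_of_step
epoch_ckpt = f"epoch={epoch_idx}.ckpt"                         # end_of_epoch
final_ckpt = f"epoch={epoch_idx}_step={global_step_idx}.ckpt"  # end_of_training
print(final_ckpt)  # epoch=4_step=99.ckpt
```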

clip_gradients_if_needed(model, optimizer)

Clips gradients if the gradient clipping value is set.

Parameters:

  • model (Module) –

    The model whose gradients need to be clipped.

  • optimizer (Optimizer) –

    The optimizer used for training.

Source code in fusion_bench/mixins/fabric_training.py
def clip_gradients_if_needed(self, model, optimizer):
    """
    Clips gradients if the gradient clipping value is set.

    Args:
        model (nn.Module): The model whose gradients need to be clipped.
        optimizer (torch.optim.Optimizer): The optimizer used for training.
    """
    fabric = self.fabric

    if self.gradient_clip_val is not None:
        if self.gradient_clip_algorithm == "value":
            fabric.clip_gradients(model, optimizer, clip_val=self.gradient_clip_val)
        elif self.gradient_clip_algorithm == "norm":
            fabric.clip_gradients(model, optimizer, max_norm=self.gradient_clip_val)
        else:
            raise ValueError(
                f"Unknown gradient clip algorithm: {self.gradient_clip_algorithm}. Available options: 'value', 'norm'"
            )
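
The two clipping modes behave differently: 'value' clamps each gradient element independently, while 'norm' rescales the whole gradient vector so its L2 norm stays within `max_norm`. A pure-Python sketch of the distinction (Fabric's `clip_gradients` performs the equivalent on tensors):

```python
import math

# 'value' clipping clamps each element to [-clip_val, clip_val].
grads = [3.0, -4.0]
clip_val = 1.0
clipped_by_value = [max(-clip_val, min(clip_val, g)) for g in grads]

# 'norm' clipping rescales the vector so its L2 norm is at most max_norm.
max_norm = 1.0
norm = math.sqrt(sum(g * g for g in grads))  # 5.0 here
scale = min(1.0, max_norm / norm)
clipped_by_norm = [g * scale for g in grads]

print(clipped_by_value)  # [1.0, -1.0]
print(clipped_by_norm)   # approximately [0.6, -0.8]
```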

compute_expected_total_steps(train_dataloader)

Computes the expected total number of steps for the entire training.

Parameters:

  • train_dataloader (DataLoader) –

    The dataloader for the training data.

Source code in fusion_bench/mixins/fabric_training.py
def compute_expected_total_steps(
    self, train_dataloader: torch.utils.data.DataLoader
):
    """
    Computes the expected total number of steps for the entire training.

    Args:
        train_dataloader (torch.utils.data.DataLoader): The dataloader for the training data.
    """
    # compute expected total steps
    self._expected_total_steps = []
    if self.max_steps > 0:
        self._expected_total_steps.append(self.max_steps)
    if self.max_steps_per_epoch > 0 and self.max_epochs > 0:
        self._expected_total_steps.append(
            self.max_steps_per_epoch * self.max_epochs
        )
    if self.max_epochs > 0:
        self._expected_total_steps.append(
            len(train_dataloader) * self.max_epochs // self.accumulate_grad_batches
        )
    self._expected_total_steps = min(self._expected_total_steps)
    log.info(f"Expected total steps: {self._expected_total_steps}")
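The bound computation above can be restated as a small pure-Python function (illustrative only; the names mirror the attributes the method reads, and `-1` stands in for "not set"):

```python
def expected_total_steps(max_steps, max_steps_per_epoch, max_epochs,
                         dataloader_len, accumulate_grad_batches):
    # The expected total is the tightest of the active upper bounds.
    candidates = []
    if max_steps > 0:
        candidates.append(max_steps)
    if max_steps_per_epoch > 0 and max_epochs > 0:
        candidates.append(max_steps_per_epoch * max_epochs)
    if max_epochs > 0:
        # one optimizer step per accumulate_grad_batches batches
        candidates.append(dataloader_len * max_epochs // accumulate_grad_batches)
    return min(candidates)

# A 1000-batch loader for 3 epochs with 4-step accumulation yields 750
# optimizer steps, but max_steps=500 is the tighter bound.
print(expected_total_steps(500, -1, 3, 1000, 4))  # -> 500
```

Note that at least one of `max_steps` and `max_epochs` must be positive, otherwise the candidate list is empty and `min` raises.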

conditional_checkpoint_save(stage, *args, **kwargs)

Conditionally saves a checkpoint based on the current training stage.

Parameters:

  • stage (Literal['end_of_step', 'end_of_epoch', 'end_of_training']) –

    The current stage of training.

Source code in fusion_bench/mixins/fabric_training.py
def conditional_checkpoint_save(
    self,
    stage: Literal["end_of_step", "end_of_epoch", "end_of_training"],
    *args,
    **kwargs,
):
    """
    Conditionally saves a checkpoint based on the current training stage.

    Args:
        stage (Literal["end_of_step", "end_of_epoch", "end_of_training"]): The current stage of training.
    """
    if stage == "end_of_step":
        if (
            self.checkpoint_save_interval == "step"
            and (self.global_step_idx + 1) % self.checkpoint_save_frequency == 0
        ):
            save_path = os.path.join(
                self.log_dir, "checkpoints", f"step={self.global_step_idx}.ckpt"
            )
            self.save_checkpoint(save_path, *args, **kwargs)
    elif stage == "end_of_epoch":
        if (
            self.checkpoint_save_interval == "epoch"
            and (self.epoch_idx + 1) % self.checkpoint_save_frequency == 0
        ):
            save_path = os.path.join(
                self.log_dir, "checkpoints", f"epoch={self.epoch_idx}.ckpt"
            )
            self.save_checkpoint(save_path, *args, **kwargs)
    elif stage == "end_of_training":
        # if the checkpoint has not been saved yet, save it
        if self.global_step_idx > self._latest_saved_checkpoint_global_step:
            save_path = os.path.join(
                self.log_dir,
                "checkpoints",
                f"epoch={self.epoch_idx}_step={self.global_step_idx}.ckpt",
            )
            self.save_checkpoint(save_path, *args, **kwargs)
            try:
                os.symlink(
                    src=save_path,
                    dst=os.path.join(
                        self.log_dir, "checkpoints", "latest_model.ckpt"
                    ),
                    target_is_directory=os.path.isdir(save_path),
                )
            except Exception as e:
                log.error(f"Failed to create symlink: {e}")
    else:
        raise ValueError(
            f"Unknown stage: {stage}. Available options: 'end_of_step', 'end_of_epoch', 'end_of_training'"
        )
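The frequency logic for `stage="end_of_step"` can be sketched in isolation (a hypothetical helper, not part of the mixin):

```python
def should_save(global_step_idx, interval, frequency, stage):
    # Mirrors the end-of-step branch of conditional_checkpoint_save:
    # save only when the interval matches and the (1-based) step count
    # is a multiple of the frequency.
    return (
        stage == "end_of_step"
        and interval == "step"
        and (global_step_idx + 1) % frequency == 0
    )

# With checkpoint_save_frequency=100, saves fire after steps 99, 199, 299, ...
saves = [s for s in range(300) if should_save(s, "step", 100, "end_of_step")]
print(saves)  # -> [99, 199, 299]
```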

save_checkpoint(path, **kwargs) abstractmethod

Saves a checkpoint of the model.

Parameters:

  • path (str) –

    The path where the checkpoint will be saved.

Raises:

  • NotImplementedError

    If the method is not implemented.

Source code in fusion_bench/mixins/fabric_training.py
@abstractmethod
def save_checkpoint(self, path, **kwargs):
    """
    Saves a checkpoint of the model.

    Args:
        path (str): The path where the checkpoint will be saved.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError("save_checkpoint method is not implemented")

train(model, optimizer, lr_scheduler)

Trains the model.

The global batch size is `the batch size per device` × `the number of devices` × `the number of gradient accumulation steps`.

Parameters:

  • model (Union[Module, _FabricModule]) –

    The model to be trained.

  • optimizer (Union[Optimizer, _FabricOptimizer]) –

    The optimizer used for training.

  • lr_scheduler (LRScheduler) –

    The learning rate scheduler.

Source code in fusion_bench/mixins/fabric_training.py
def train(
    self,
    model: Union[nn.Module, "_FabricModule"],
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
):
    """
    Trains the model.

    The global batch size is `the batch size per device` x `the number of devices` x `the number of gradient accumulation steps`.

    Args:
        model (Union[nn.Module, "_FabricModule"]): The model to be trained.
        optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
        lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.
    """
    fabric = self.fabric
    self.is_training = True
    # number of parameter update iterations, not the number of batches
    self.global_step_idx = 0
    model.train()
    optimizer.zero_grad()
    for epoch_idx in tqdm(
        range(self.max_epochs) if self.max_epochs > 0 else itertools.count(0),
        "Training Epoch",
        dynamic_ncols=True,
        leave=False,
        disable=not fabric.is_global_zero,
    ):
        self.epoch_idx = epoch_idx
        self.train_epoch(model, optimizer, lr_scheduler)
        # run lr_scheduler at the end of the epoch if interval is set to "epoch"
        if (
            self.lr_scheduler_interval == "epoch"
            and (epoch_idx + 1) % self.lr_scheduler_frequency == 0
        ):
            lr_scheduler.step()

        # save the model at the end of the epoch if interval is set to "epoch" and frequency is met
        self.conditional_checkpoint_save(stage="end_of_epoch")

        if not self.is_training:
            break

    optimizer.zero_grad()
    # save the model at the end of training
    self.conditional_checkpoint_save(stage="end_of_training")
    return model
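The global-batch-size relationship stated in the docstring is simple arithmetic. As an illustration (the function name and example values are hypothetical):

```python
def global_batch_size(per_device_batch_size, num_devices, accumulate_grad_batches):
    # Effective number of samples contributing to each optimizer step.
    return per_device_batch_size * num_devices * accumulate_grad_batches

# e.g. batch size 8 on 4 GPUs with 4-step gradient accumulation
print(global_batch_size(8, 4, 4))  # -> 128
```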

train_epoch(model, optimizer, lr_scheduler) abstractmethod

Trains the model for one epoch.

Parameters:

  • model (Union[Module, _FabricModule]) –

    The model to be trained.

  • optimizer (Union[Optimizer, _FabricOptimizer]) –

    The optimizer used for training.

  • lr_scheduler (LRScheduler) –

    The learning rate scheduler.

Raises:

  • NotImplementedError

    If the method is not implemented.

Source code in fusion_bench/mixins/fabric_training.py
@abstractmethod
def train_epoch(
    self,
    model: Union[nn.Module, "_FabricModule"],
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"],
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
):
    """
    Trains the model for one epoch.

    Args:
        model (Union[nn.Module, "_FabricModule"]): The model to be trained.
        optimizer (Union[torch.optim.Optimizer, "_FabricOptimizer"]): The optimizer used for training.
        lr_scheduler (torch.optim.lr_scheduler.LRScheduler): The learning rate scheduler.

    Raises:
        NotImplementedError: If the method is not implemented.
    """
    raise NotImplementedError(
        "Copy this as a template and implement your own train_epoch method"
    )
    fabric = self.fabric

    accumulated_loss = 0
    for step_idx, batch in enumerate(
        pbar := tqdm(
            self.train_dataloader,
            desc="Training Batches",
            dynamic_ncols=True,
            leave=False,
            disable=not fabric.is_global_zero,
        )
    ):
        is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

        # disable gradient synchronization if accumulating gradients across steps for improved performance
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            # use_cache=True is not compatible with gradient checkpointing, so we disable it here
            output = self.compute_loss(batch)
            loss = output["loss"] / self.accumulate_grad_batches

            fabric.backward(loss)
            accumulated_loss += loss.item()

        # 1. update the model parameters if not accumulating gradients
        # 2. step the lr_scheduler if interval is set to "step" and frequency is met
        # 3. save the model if interval is set to "step" and frequency is met
        # 4. log metrics
        # 5. increase the global step index and reset the accumulated metrics
        if not is_accumulating:
            self.clip_gradients_if_needed(model, optimizer)

            # run lr_scheduler at the end of the step if interval is set to "step"
            if (
                self.lr_scheduler_interval == "step"
                and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
            ):
                lr_scheduler.step()

            # update the model parameters and zero the gradients
            optimizer.step()
            optimizer.zero_grad()

            metrics = {
                "train/loss": accumulated_loss,
                "train/lr": optimizer.param_groups[0]["lr"],
            }

            fabric.log_dict(metrics, step=self.global_step_idx)
            pbar.set_postfix(metrics)

            # save the model at the end of the step if interval is set to "step" and frequency is met
            self.conditional_checkpoint_save(stage="end_of_step")

            # break if max_steps_per_epoch is set, and exit epoch
            if (
                self.max_steps_per_epoch > 0
                and step_idx + 1 >= self.max_steps_per_epoch
            ):
                break
            # break if max_steps is set, and exit training
            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                self.is_training = False
                break

            self.global_step_idx += 1
            accumulated_loss = 0
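The `is_accumulating` bookkeeping in the template determines which batches trigger a parameter update. A standalone sketch of just that logic (illustrative helper, not part of the mixin):

```python
def optimizer_step_batches(num_batches, accumulate_grad_batches):
    # Batch indices at which the template calls optimizer.step():
    # every accumulate_grad_batches-th batch (1-based), matching
    # is_accumulating = (step_idx + 1) % accumulate_grad_batches != 0.
    steps = []
    for step_idx in range(num_batches):
        is_accumulating = (step_idx + 1) % accumulate_grad_batches != 0
        if not is_accumulating:
            steps.append(step_idx)
    return steps

# With accumulate_grad_batches=4, parameters update on batches 3, 7, 11, ...
print(optimizer_step_batches(12, 4))  # -> [3, 7, 11]
```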

SimpleProfilerMixin

A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.

This mixin allows for easy profiling of code blocks using a context manager or manual start/stop methods. It measures the execution time of named actions and provides a summary of the profiling results. Unlike statistical profilers, this provides precise timing measurements for specific code blocks.

Note

This mixin uses Lightning's SimpleProfiler which measures wall-clock time for named actions. It's suitable for timing discrete operations rather than detailed function-level profiling.

Examples:

class MyClass(SimpleProfilerMixin):
    def do_something(self):
        with self.profile("data_loading"):
            # Load data here
            data = load_data()

        with self.profile("model_training"):
            # Train model here
            model.train(data)

        # Print the profiling summary
        self.print_profile_summary("Training Profile")

Attributes:

  • _profiler (SimpleProfiler) –

    An instance of the SimpleProfiler class used for profiling.

Source code in fusion_bench/mixins/simple_profiler.py
class SimpleProfilerMixin:
    """
    A mixin class that provides simple profiling capabilities using Lightning's SimpleProfiler.

    This mixin allows for easy profiling of code blocks using a context manager or manual
    start/stop methods. It measures the execution time of named actions and provides
    a summary of the profiling results. Unlike statistical profilers, this provides
    precise timing measurements for specific code blocks.

    Note:
        This mixin uses Lightning's SimpleProfiler which measures wall-clock time
        for named actions. It's suitable for timing discrete operations rather than
        detailed function-level profiling.

    Examples:
        ```python
        class MyClass(SimpleProfilerMixin):
            def do_something(self):
                with self.profile("data_loading"):
                    # Load data here
                    data = load_data()

                with self.profile("model_training"):
                    # Train model here
                    model.train(data)

                # Print the profiling summary
                self.print_profile_summary("Training Profile")
        ```

    Attributes:
        _profiler (SimpleProfiler): An instance of the SimpleProfiler class used for profiling.
    """

    _profiler: SimpleProfiler = None

    @property
    def profiler(self) -> SimpleProfiler:
        """
        Get the SimpleProfiler instance, creating it if necessary.

        Returns:
            SimpleProfiler: The profiler instance used for timing measurements.
        """
        # Lazy initialization of the profiler instance
        if self._profiler is None:
            self._profiler = SimpleProfiler()
        return self._profiler

    @contextmanager
    def profile(self, action_name: str) -> Generator:
        """
        Context manager for profiling a code block.

        This context manager automatically starts profiling when entering the block
        and stops profiling when exiting the block (even if an exception occurs).

        Args:
            action_name: A descriptive name for the action being profiled.
                        This name will appear in the profiling summary.

        Yields:
            str: The action name that was provided.

        Example:

        ```python
        with self.profile("data_processing"):
            # Process data here
            result = process_large_dataset()
        ```
        """
        try:
            self.start_profile(action_name)
            yield action_name
        finally:
            self.stop_profile(action_name)

    def start_profile(self, action_name: str):
        """
        Start profiling for a named action.

        This method begins timing for the specified action. You must call
        stop_profile() with the same action name to complete the measurement.

        Args:
            action_name: A descriptive name for the action being profiled.
                        This name will appear in the profiling summary.

        Example:
            ```python
            self.start_profile("model_inference")
            result = model.predict(data)
            self.stop_profile("model_inference")
            ```
        """
        self.profiler.start(action_name)

    def stop_profile(self, action_name: str):
        """
        Stop profiling for a named action.

        This method ends timing for the specified action that was previously
        started with start_profile().

        Args:
            action_name: The name of the action to stop profiling.
                        Must match the name used in start_profile().

        Example:
            ```python
            self.start_profile("data_loading")
            data = load_data()
            self.stop_profile("data_loading")
            ```
        """
        self.profiler.stop(action_name)

    @rank_zero_only
    def print_profile_summary(self, title: Optional[str] = None):
        """
        Print a summary of all profiled actions.

        This method outputs a formatted summary showing the timing information
        for all actions that have been profiled. The output includes action names
        and their execution times.

        Args:
            title: Optional title to print before the profiling summary.
                  If provided, this will be printed as a header.

        Note:
            This method is decorated with @rank_zero_only, meaning it will only
            execute on the main process in distributed training scenarios.

        Example:
            ```python
            # After profiling some actions
            self.print_profile_summary("Training Performance Summary")
            ```
        """
        if title is not None:
            print(title)
        print(self.profiler.summary())

    def __del__(self):
        """
        Cleanup when the object is destroyed.

        Ensures that the profiler instance is properly cleaned up to prevent
        memory leaks when the mixin instance is garbage collected.
        """
        if self._profiler is not None:
            del self._profiler
            self._profiler = None
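For environments without Lightning, the same start/stop/summary API shape can be approximated with the standard library. This is a toy stand-in for illustration, not the real `SimpleProfiler`:

```python
import time
from contextlib import contextmanager

class TinyProfiler:
    """Dependency-free sketch of SimpleProfilerMixin's API shape."""

    def __init__(self):
        self._starts = {}   # action name -> start timestamp
        self._totals = {}   # action name -> accumulated wall-clock seconds

    @contextmanager
    def profile(self, action_name):
        self.start_profile(action_name)
        try:
            yield action_name
        finally:
            self.stop_profile(action_name)

    def start_profile(self, action_name):
        self._starts[action_name] = time.perf_counter()

    def stop_profile(self, action_name):
        elapsed = time.perf_counter() - self._starts.pop(action_name)
        self._totals[action_name] = self._totals.get(action_name, 0.0) + elapsed

    def summary(self):
        return dict(self._totals)

p = TinyProfiler()
with p.profile("sleep"):
    time.sleep(0.01)
print(sorted(p.summary()))  # -> ['sleep']
```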

profiler property

Get the SimpleProfiler instance, creating it if necessary.

Returns:

  • SimpleProfiler ( SimpleProfiler ) –

    The profiler instance used for timing measurements.

__del__()

Cleanup when the object is destroyed.

Ensures that the profiler instance is properly cleaned up to prevent memory leaks when the mixin instance is garbage collected.

Source code in fusion_bench/mixins/simple_profiler.py
def __del__(self):
    """
    Cleanup when the object is destroyed.

    Ensures that the profiler instance is properly cleaned up to prevent
    memory leaks when the mixin instance is garbage collected.
    """
    if self._profiler is not None:
        del self._profiler
        self._profiler = None

print_profile_summary(title=None)

Print a summary of all profiled actions.

This method outputs a formatted summary showing the timing information for all actions that have been profiled. The output includes action names and their execution times.

Parameters:

  • title (Optional[str], default: None ) –

    Optional title to print before the profiling summary. If provided, this will be printed as a header.

Note

This method is decorated with @rank_zero_only, meaning it will only execute on the main process in distributed training scenarios.

Example
# After profiling some actions
self.print_profile_summary("Training Performance Summary")
Source code in fusion_bench/mixins/simple_profiler.py
@rank_zero_only
def print_profile_summary(self, title: Optional[str] = None):
    """
    Print a summary of all profiled actions.

    This method outputs a formatted summary showing the timing information
    for all actions that have been profiled. The output includes action names
    and their execution times.

    Args:
        title: Optional title to print before the profiling summary.
              If provided, this will be printed as a header.

    Note:
        This method is decorated with @rank_zero_only, meaning it will only
        execute on the main process in distributed training scenarios.

    Example:
        ```python
        # After profiling some actions
        self.print_profile_summary("Training Performance Summary")
        ```
    """
    if title is not None:
        print(title)
    print(self.profiler.summary())

profile(action_name)

Context manager for profiling a code block.

This context manager automatically starts profiling when entering the block and stops profiling when exiting the block (even if an exception occurs).

Parameters:

  • action_name (str) –

    A descriptive name for the action being profiled. This name will appear in the profiling summary.

Yields:

  • str ( Generator ) –

    The action name that was provided.

Example:

with self.profile("data_processing"):
    # Process data here
    result = process_large_dataset()
Source code in fusion_bench/mixins/simple_profiler.py
@contextmanager
def profile(self, action_name: str) -> Generator:
    """
    Context manager for profiling a code block.

    This context manager automatically starts profiling when entering the block
    and stops profiling when exiting the block (even if an exception occurs).

    Args:
        action_name: A descriptive name for the action being profiled.
                    This name will appear in the profiling summary.

    Yields:
        str: The action name that was provided.

    Example:

    ```python
    with self.profile("data_processing"):
        # Process data here
        result = process_large_dataset()
    ```
    """
    try:
        self.start_profile(action_name)
        yield action_name
    finally:
        self.stop_profile(action_name)

start_profile(action_name)

Start profiling for a named action.

This method begins timing for the specified action. You must call stop_profile() with the same action name to complete the measurement.

Parameters:

  • action_name (str) –

    A descriptive name for the action being profiled. This name will appear in the profiling summary.

Example
self.start_profile("model_inference")
result = model.predict(data)
self.stop_profile("model_inference")
Source code in fusion_bench/mixins/simple_profiler.py
def start_profile(self, action_name: str):
    """
    Start profiling for a named action.

    This method begins timing for the specified action. You must call
    stop_profile() with the same action name to complete the measurement.

    Args:
        action_name: A descriptive name for the action being profiled.
                    This name will appear in the profiling summary.

    Example:
        ```python
        self.start_profile("model_inference")
        result = model.predict(data)
        self.stop_profile("model_inference")
        ```
    """
    self.profiler.start(action_name)

stop_profile(action_name)

Stop profiling for a named action.

This method ends timing for the specified action that was previously started with start_profile().

Parameters:

  • action_name (str) –

    The name of the action to stop profiling. Must match the name used in start_profile().

Example
self.start_profile("data_loading")
data = load_data()
self.stop_profile("data_loading")
Source code in fusion_bench/mixins/simple_profiler.py
def stop_profile(self, action_name: str):
    """
    Stop profiling for a named action.

    This method ends timing for the specified action that was previously
    started with start_profile().

    Args:
        action_name: The name of the action to stop profiling.
                    Must match the name used in start_profile().

    Example:
        ```python
        self.start_profile("data_loading")
        data = load_data()
        self.stop_profile("data_loading")
        ```
    """
    self.profiler.stop(action_name)

PyinstrumentProfilerMixin

A mixin class that provides statistical profiling capabilities using pyinstrument.

This mixin allows for easy profiling of code blocks using a context manager. It provides methods to start and stop profiling actions, save profiling results to files, and print profiling summaries.

Note

This mixin requires the pyinstrument package to be installed. If not available, an ImportError will be raised when importing this module.

Examples:

class MyClass(PyinstrumentProfilerMixin):
    def do_something(self):
        with self.profile("work"):
            # do some work here
            ...

        # save the profiling results
        self.save_profile_report("profile_report.html")

        # or print the summary
        self.print_profile_summary()

Attributes:

  • _profiler (Profiler) –

    An instance of the pyinstrument Profiler class.

Source code in fusion_bench/mixins/pyinstrument.py
class PyinstrumentProfilerMixin:
    """
    A mixin class that provides statistical profiling capabilities using pyinstrument.

    This mixin allows for easy profiling of code blocks using a context manager.
    It provides methods to start and stop profiling actions, save profiling results
    to files, and print profiling summaries.

    Note:
        This mixin requires the `pyinstrument` package to be installed.
        If not available, an ImportError will be raised when importing this module.

    Examples:

    ```python
    class MyClass(PyinstrumentProfilerMixin):
        def do_something(self):
            with self.profile("work"):
                # do some work here
                ...

            # save the profiling results
            self.save_profile_report("profile_report.html")

            # or print the summary
            self.print_profile_summary()
    ```

    Attributes:
        _profiler (Profiler): An instance of the pyinstrument Profiler class.
    """

    _profiler: Optional[Profiler] = None
    _is_profiling: bool = False

    @property
    def profiler(self) -> Optional[Profiler]:
        """Get the profiler instance, creating it if necessary."""
        if self._profiler is None:
            self._profiler = Profiler()
        return self._profiler

    @contextmanager
    def profile(self, action_name: Optional[str] = None) -> Generator:
        """
        Context manager for profiling a code block.

        Args:
            action_name: Optional name for the profiling action (for logging purposes).

        Example:

        ```python
        with self.profile("expensive_operation"):
            # do some expensive work here
            expensive_function()
        ```
        """
        try:
            self.start_profile(action_name)
            yield action_name
        finally:
            self.stop_profile(action_name)

    def start_profile(self, action_name: Optional[str] = None):
        """
        Start profiling.

        Args:
            action_name: Optional name for the profiling action.
        """
        if self._is_profiling:
            return

        self.profiler.start()
        self._is_profiling = True
        if action_name:
            print(f"Started profiling: {action_name}")

    def stop_profile(self, action_name: Optional[str] = None):
        """
        Stop profiling.

        Args:
            action_name: Optional name for the profiling action.
        """
        if not self._is_profiling:
            return

        self.profiler.stop()
        self._is_profiling = False
        if action_name:
            print(f"Stopped profiling: {action_name}")

    @rank_zero_only
    def print_profile_summary(
        self, title: Optional[str] = None, unicode: bool = True, color: bool = True
    ):
        """
        Print a summary of the profiling results.

        Args:
            title: Optional title to print before the summary.
            unicode: Whether to use unicode characters in the output.
            color: Whether to use color in the output.
        """
        if self.profiler is None:
            print("No profiling data available.")
            return

        if title is not None:
            print(title)

        print(self.profiler.output_text(unicode=unicode, color=color))

    @rank_zero_only
    def save_profile_report(
        self,
        output_path: Union[str, Path] = "profile_report.html",
        format: str = "html",
        title: Optional[str] = None,
    ):
        """
        Save the profiling results to a file.

        Args:
            output_path: Path where to save the profiling report.
            format: Output format ('html', or 'text').
            title: Optional title for the report.
        """
        if self.profiler is None:
            print("No profiling data available.")
            return

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if format.lower() == "html":
            content = self.profiler.output_html()
        elif format.lower() == "text":
            content = self.profiler.output_text(unicode=True, color=False)
        else:
            raise ValueError(f"Unsupported format: {format}. Use 'html', or 'text'.")

        with open(output_path, "w", encoding="utf-8") as f:
            f.write(content)

        print(f"Profile report saved to: {output_path}")

    def reset_profile(self):
        """Reset the profiler to start fresh."""
        if self._is_profiling:
            self.stop_profile()

        self._profiler = None

    def __del__(self):
        """Cleanup when the object is destroyed."""
        if self._is_profiling:
            self.stop_profile()

        if self._profiler is not None:
            del self._profiler
            self._profiler = None
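The `_is_profiling` guard makes `start_profile`/`stop_profile` idempotent: a second start or stop is silently ignored. A minimal sketch of just that guard (hypothetical class, illustrative only):

```python
class GuardedProfiler:
    """Illustrates the _is_profiling guard used by PyinstrumentProfilerMixin."""

    def __init__(self):
        self.is_profiling = False
        self.events = []

    def start(self):
        if self.is_profiling:       # already running: ignore, as the mixin does
            return
        self.is_profiling = True
        self.events.append("start")

    def stop(self):
        if not self.is_profiling:   # never started: ignore
            return
        self.is_profiling = False
        self.events.append("stop")

g = GuardedProfiler()
g.start(); g.start(); g.stop(); g.stop()
print(g.events)  # -> ['start', 'stop']
```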

profiler property

Get the profiler instance, creating it if necessary.

__del__()

Cleanup when the object is destroyed.

Source code in fusion_bench/mixins/pyinstrument.py
def __del__(self):
    """Cleanup when the object is destroyed."""
    if self._is_profiling:
        self.stop_profile()

    if self._profiler is not None:
        del self._profiler
        self._profiler = None

print_profile_summary(title=None, unicode=True, color=True)

Print a summary of the profiling results.

Parameters:

  • title (Optional[str], default: None ) –

    Optional title to print before the summary.

  • unicode (bool, default: True ) –

    Whether to use unicode characters in the output.

  • color (bool, default: True ) –

    Whether to use color in the output.

Source code in fusion_bench/mixins/pyinstrument.py
@rank_zero_only
def print_profile_summary(
    self, title: Optional[str] = None, unicode: bool = True, color: bool = True
):
    """
    Print a summary of the profiling results.

    Args:
        title: Optional title to print before the summary.
        unicode: Whether to use unicode characters in the output.
        color: Whether to use color in the output.
    """
    if self.profiler is None:
        print("No profiling data available.")
        return

    if title is not None:
        print(title)

    print(self.profiler.output_text(unicode=unicode, color=color))

profile(action_name=None)

Context manager for profiling a code block.

Parameters:

  • action_name (Optional[str], default: None ) –

    Optional name for the profiling action (for logging purposes).

Example:

with self.profile("expensive_operation"):
    # do some expensive work here
    expensive_function()
Source code in fusion_bench/mixins/pyinstrument.py
@contextmanager
def profile(self, action_name: Optional[str] = None) -> Generator:
    """
    Context manager for profiling a code block.

    Args:
        action_name: Optional name for the profiling action (for logging purposes).

    Example:

    ```python
    with self.profile("expensive_operation"):
        # do some expensive work here
        expensive_function()
    ```
    """
    try:
        self.start_profile(action_name)
        yield action_name
    finally:
        self.stop_profile(action_name)

reset_profile()

Reset the profiler to start fresh.

Source code in fusion_bench/mixins/pyinstrument.py
def reset_profile(self):
    """Reset the profiler to start fresh."""
    if self._is_profiling:
        self.stop_profile()

    self._profiler = None

save_profile_report(output_path='profile_report.html', format='html', title=None)

Save the profiling results to a file.

Parameters:

  • output_path (Union[str, Path], default: 'profile_report.html' ) –

    Path where to save the profiling report.

  • format (str, default: 'html' ) –

    Output format ('html', or 'text').

  • title (Optional[str], default: None ) –

    Optional title for the report.

Source code in fusion_bench/mixins/pyinstrument.py
@rank_zero_only
def save_profile_report(
    self,
    output_path: Union[str, Path] = "profile_report.html",
    format: str = "html",
    title: Optional[str] = None,
):
    """
    Save the profiling results to a file.

    Args:
        output_path: Path where to save the profiling report.
        format: Output format ('html', or 'text').
        title: Optional title for the report.
    """
    if self.profiler is None:
        print("No profiling data available.")
        return

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    if format.lower() == "html":
        content = self.profiler.output_html()
    elif format.lower() == "text":
        content = self.profiler.output_text(unicode=True, color=False)
    else:
        raise ValueError(f"Unsupported format: {format}. Use 'html', or 'text'.")

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(content)

    print(f"Profile report saved to: {output_path}")

start_profile(action_name=None)

Start profiling.

Parameters:

  • action_name (Optional[str], default: None ) –

    Optional name for the profiling action.

Source code in fusion_bench/mixins/pyinstrument.py
def start_profile(self, action_name: Optional[str] = None):
    """
    Start profiling.

    Args:
        action_name: Optional name for the profiling action.
    """
    if self._is_profiling:
        return

    self.profiler.start()
    self._is_profiling = True
    if action_name:
        print(f"Started profiling: {action_name}")

stop_profile(action_name=None)

Stop profiling.

Parameters:

  • action_name (Optional[str], default: None ) –

    Optional name for the profiling action.

Source code in fusion_bench/mixins/pyinstrument.py
def stop_profile(self, action_name: Optional[str] = None):
    """
    Stop profiling.

    Args:
        action_name: Optional name for the profiling action.
    """
    if not self._is_profiling:
        return

    self.profiler.stop()
    self._is_profiling = False
    if action_name:
        print(f"Stopped profiling: {action_name}")

CLIPClassificationMixin

Bases: LightningFabricMixin

This mixin provides methods to classify images using the CLIP model.

Attributes that must be set by the inheriting class:

  • dataloader_kwargs (Dict[str, Any]): Keyword arguments for the dataloader.
  • modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
Source code in fusion_bench/mixins/clip_classification.py
class CLIPClassificationMixin(LightningFabricMixin):
    """
    This mixin provides methods to classify images using the CLIP model.

    Attributes that must be set by the inheriting class:

    - `dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
    - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
    """

    dataloader_kwargs: Dict[str, Any] = {}
    # the modelpool is set by inheriting class
    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    # a dict of zeroshot weights for each task, each key is the task name
    zeroshot_weights: Dict[str, torch.Tensor] = {}
    whether_setup_zero_shot_classification_head = False

    @property
    def clip_processor(self):
        """
        Get the CLIP processor, loading it from the model pool if necessary.

        Returns:
            CLIPProcessor: The CLIP processor for image and text preprocessing.

        Raises:
            AssertionError: If the model pool is not set.
        """
        if self._clip_processor is None:
            assert self.modelpool is not None, "Model pool is not set"
            self._clip_processor = self.modelpool.load_processor()
        return self._clip_processor

    @functools.cache
    def get_shuffled_test_loader_iter(
        self,
        task: str,
        batch_size: Optional[int] = None,
        num_workers: Optional[int] = None,
        **loader_kwargs,
    ) -> Iterator:
        """
        Get an iterator for a shuffled test DataLoader.

        This method creates a DataLoader for the test dataset of the specified task,
        with shuffling enabled. It allows for optional customization of batch size,
        number of workers, and other DataLoader keyword arguments.

        Args:
            task (str): The task identifier for which the test dataset is to be loaded.
            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.

        Returns:
            Iterator: An iterator over the shuffled test DataLoader.
        """
        # get dataloader kwargs
        dataloader_kwargs = self.dataloader_kwargs.copy()
        dataloader_kwargs["shuffle"] = True
        if batch_size is not None:
            dataloader_kwargs["batch_size"] = batch_size
        if num_workers is not None:
            dataloader_kwargs["num_workers"] = num_workers
        dataloader_kwargs.update(loader_kwargs)

        # get the test dataset
        clip_dataset = CLIPDataset(
            self.modelpool.load_test_dataset(task), self.clip_processor
        )
        # create the dataloader
        loader = DataLoader(clip_dataset, **dataloader_kwargs)
        loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    @torch.no_grad()
    def setup_zero_shot_classification_head(
        self,
        clip_processor: Optional[CLIPProcessor] = None,
        clip_model: Optional[CLIPModel] = None,
        task_names: Optional[List[str]] = None,
    ):
        """
        Initializes a zero-shot classification head.

        This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
        These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
        from the provided CLIP model, which are necessary for calculating the final logits.

        Args:
            clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
            clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
            task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
        """
        # make sure the task names are equal across all processes
        _task_names = self.fabric.broadcast(task_names, src=0)
        if not self.fabric.is_global_zero and task_names != _task_names:
            raise ValueError("The `task_names` must be the same across all processes.")

        self.whether_setup_zero_shot_classification_head = True
        # load clip model if not provided
        if clip_model is None:
            if self.modelpool.has_pretrained:
                clip_model = self.modelpool.load_clip_model("_pretrained_")
            else:
                log.warning(
                    f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
                )
                clip_model = self.modelpool.load_clip_model(
                    self.modelpool.model_names[0]
                )
        if clip_processor is None:
            clip_processor = self.clip_processor
        clip_classifier = HFCLIPClassifier(clip_model, clip_processor)
        self.visual_projection = deepcopy(clip_model.visual_projection)
        self.visual_projection.requires_grad_(False)
        self.logit_scale_exp = clip_model.logit_scale.data.clone().exp()
        self.visual_projection = self.fabric.to_device(self.visual_projection)
        self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)

        @cache_with_joblib()
        def construct_classification_head(task: str, model_name: str):
            log.info(
                f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
            )
            nonlocal clip_classifier

            classnames, templates = get_classnames_and_templates(task)
            clip_classifier.set_classification_task(classnames, templates)
            zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()

            return zeroshot_weights

        for task in tqdm(
            self.modelpool.model_names if task_names is None else task_names,
            "Setting up zero-shot classification head",
            disable=not self.fabric.is_global_zero,
        ):
            zeroshot_weights = None
            if self.fabric.is_global_zero:
                if hasattr(clip_model, "config") and hasattr(
                    clip_model.config, "_name_or_path"
                ):
                    model_name = clip_model.config._name_or_path
                else:
                    model_name = "unknown_model"
                    log.warning(
                        "CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
                    )
                zeroshot_weights = construct_classification_head(
                    task, model_name=model_name
                )

            self.fabric.barrier()
            self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
            self.zeroshot_weights[task] = self.to_device(self.zeroshot_weights[task])
            self.fabric.barrier()

        del clip_classifier
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def compute_logits(
        self,
        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
        images: torch.Tensor,
        task: str,
        image_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Computes the classification logits for a batch of images for a specific task.

        This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
        The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
        The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.

        Args:
            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
            images (torch.Tensor): A batch of images to classify.
            task (str): The name of the classification task.
            image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.

        Returns:
            torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
        """
        text_embeds = self.zeroshot_weights[task]

        if image_embeds is None:
            image_embeds = module(images)[1]
        assert isinstance(
            image_embeds, torch.Tensor
        ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def compute_features(
        self,
        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
        images: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Extracts image features using CLIP's vision encoder and visual projection.

        Args:
            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
            images (torch.Tensor): Input image batch to process.
            normalize (bool): Whether to normalize the image embeddings.

        Returns:
            torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
        """
        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        if normalize:
            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        return image_embeds

clip_processor property

Get the CLIP processor, loading it from the model pool if necessary.

Returns:

  • CLIPProcessor

    The CLIP processor for image and text preprocessing.

Raises:

  • AssertionError

    If the model pool is not set.

compute_features(module, images, normalize=True)

Extracts image features using CLIP's vision encoder and visual projection.

Parameters:

  • module (Union[Module, CLIPVisionModel, CLIPVisionTransformer]) –

    The CLIP vision encoder module.

  • images (Tensor) –

    Input image batch to process.

  • normalize (bool, default: True ) –

    Whether to normalize the image embeddings.

Returns:

  • Tensor

    torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (projection_dim in model config).

Source code in fusion_bench/mixins/clip_classification.py
def compute_features(
    self,
    module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
    images: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Extracts image features using CLIP's vision encoder and visual projection.

    Args:
        module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
        images (torch.Tensor): Input image batch to process.
        normalize (bool): Whether to normalize the image embeddings.

    Returns:
        torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
    """
    image_embeds = module(images)[1]
    image_embeds = self.visual_projection(image_embeds)

    if normalize:
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    return image_embeds

compute_logits(module, images, task, image_embeds=None)

Computes the classification logits for a batch of images for a specific task.

This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings. The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task. The similarity scores are then scaled by the CLIP model's logit_scale to produce the final logits.

Parameters:

  • module (Union[Module, CLIPVisionModel, CLIPVisionTransformer]) –

    The vision encoder part of the CLIP model.

  • images (Tensor) –

    A batch of images to classify.

  • task (str) –

    The name of the classification task.

  • image_embeds (Optional[Tensor], default: None ) –

    Pre-computed image embeddings. If provided, the method skips the image encoding step.

Returns:

  • Tensor

    torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).

Source code in fusion_bench/mixins/clip_classification.py
def compute_logits(
    self,
    module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
    images: torch.Tensor,
    task: str,
    image_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Computes the classification logits for a batch of images for a specific task.

    This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
    The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
    The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.

    Args:
        module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
        images (torch.Tensor): A batch of images to classify.
        task (str): The name of the classification task.
        image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.

    Returns:
        torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
    """
    text_embeds = self.zeroshot_weights[task]

    if image_embeds is None:
        image_embeds = module(images)[1]
    assert isinstance(
        image_embeds, torch.Tensor
    ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
    image_embeds = self.visual_projection(image_embeds)

    # normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity
    logits_per_text = (
        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
    )
    logits_per_image = logits_per_text.t()

    return logits_per_image
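Stripped of the framework plumbing, the computation above is: L2-normalize both embedding sets, take all pairwise dot products (cosine similarities, since the vectors are unit-length), and scale by `exp(logit_scale)`. A dependency-free sketch of that arithmetic (`zero_shot_logits` and `l2_normalize` are illustrative names, not FusionBench API; the default `logit_scale` of ln(1/0.07) is CLIP's initialization, an assumption here):

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def zero_shot_logits(image_embeds, text_embeds, logit_scale=math.log(1 / 0.07)):
    """Cosine-similarity logits as in CLIP: scaled dot products of unit vectors.

    Returns a (num_images, num_classes) nested list; `logit_scale` is the
    pre-exponentiation parameter, matching `logit_scale_exp` in the mixin.
    """
    scale = math.exp(logit_scale)
    imgs = [l2_normalize(v) for v in image_embeds]
    txts = [l2_normalize(v) for v in text_embeds]
    return [
        [scale * sum(a * b for a, b in zip(img, txt)) for txt in txts]
        for img in imgs
    ]
```

With `logit_scale=0` the scale factor is 1, and an image embedding aligned with a class embedding yields a logit of exactly 1 (the cosine of a zero angle).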

get_shuffled_test_loader_iter(task, batch_size=None, num_workers=None, **loader_kwargs) cached

Get an iterator for a shuffled test DataLoader.

This method creates a DataLoader for the test dataset of the specified task, with shuffling enabled. It allows for optional customization of batch size, number of workers, and other DataLoader keyword arguments.

Parameters:

  • task (str) –

    The task identifier for which the test dataset is to be loaded.

  • batch_size (Optional[int], default: None ) –

    The batch size to use for the DataLoader. If None, the default batch size is used.

  • num_workers (Optional[int], default: None ) –

    The number of worker processes to use for data loading. If None, the default number of workers is used.

  • **loader_kwargs

    Additional keyword arguments to pass to the DataLoader.

Returns:

  • Iterator ( Iterator ) –

    An iterator over the shuffled test DataLoader.

Source code in fusion_bench/mixins/clip_classification.py
@functools.cache
def get_shuffled_test_loader_iter(
    self,
    task: str,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    **loader_kwargs,
) -> Iterator:
    """
    Get an iterator for a shuffled test DataLoader.

    This method creates a DataLoader for the test dataset of the specified task,
    with shuffling enabled. It allows for optional customization of batch size,
    number of workers, and other DataLoader keyword arguments.

    Args:
        task (str): The task identifier for which the test dataset is to be loaded.
        batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
        num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
        **loader_kwargs: Additional keyword arguments to pass to the DataLoader.

    Returns:
        Iterator: An iterator over the shuffled test DataLoader.
    """
    # get dataloader kwargs
    dataloader_kwargs = self.dataloader_kwargs.copy()
    dataloader_kwargs["shuffle"] = True
    if batch_size is not None:
        dataloader_kwargs["batch_size"] = batch_size
    if num_workers is not None:
        dataloader_kwargs["num_workers"] = num_workers
    dataloader_kwargs.update(loader_kwargs)

    # get the test dataset
    clip_dataset = CLIPDataset(
        self.modelpool.load_test_dataset(task), self.clip_processor
    )
    # create the dataloader
    loader = DataLoader(clip_dataset, **dataloader_kwargs)
    loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))
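The returned iterator never raises `StopIteration`: `InfiniteDataLoader` (whose implementation is not shown on this page) wraps the loader so that a new pass begins whenever the current one is exhausted; with `shuffle=True`, each pass is freshly shuffled. A minimal sketch of that wrapper, under the assumption it restarts the source iterable on exhaustion (`InfiniteLoader` is an illustrative stand-in, not the actual class):

```python
class InfiniteLoader:
    """Sketch of an infinite wrapper: restart the source iterable on StopIteration.

    Hypothetical stand-in for fusion_bench's `InfiniteDataLoader`; works with any
    re-iterable (a list here, a DataLoader in practice).
    """

    def __init__(self, loader):
        self.loader = loader
        self._it = iter(loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self._it)
        except StopIteration:
            # Exhausted: start a fresh pass over the underlying loader.
            self._it = iter(self.loader)
            return next(self._it)
```

Because `get_shuffled_test_loader_iter` is decorated with `functools.cache`, repeated calls with the same hashable arguments return the same iterator object; callers should pull batches with `next(...)` for a fixed number of steps rather than looping to completion.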

setup_zero_shot_classification_head(clip_processor=None, clip_model=None, task_names=None)

Initializes a zero-shot classification head.

This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates. These embeddings function as the weights of the classification layer. The method also extracts the visual_projection and logit_scale from the provided CLIP model, which are necessary for calculating the final logits.

Parameters:

  • clip_processor (Optional[CLIPProcessor], default: None ) –

    The processor for the CLIP model. If not provided, it is loaded from the model pool.

  • clip_model (Optional[CLIPModel], default: None ) –

    The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.

  • task_names (Optional[List[str]], default: None ) –

    A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.

Source code in fusion_bench/mixins/clip_classification.py
@torch.no_grad()
def setup_zero_shot_classification_head(
    self,
    clip_processor: Optional[CLIPProcessor] = None,
    clip_model: Optional[CLIPModel] = None,
    task_names: Optional[List[str]] = None,
):
    """
    Initializes a zero-shot classification head.

    This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
    These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
    from the provided CLIP model, which are necessary for calculating the final logits.

    Args:
        clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
        clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
        task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
    """
    # make sure the task names are equal across all processes
    _task_names = self.fabric.broadcast(task_names, src=0)
    if not self.fabric.is_global_zero and task_names != _task_names:
        raise ValueError("The `task_names` must be the same across all processes.")

    self.whether_setup_zero_shot_classification_head = True
    # load clip model if not provided
    if clip_model is None:
        if self.modelpool.has_pretrained:
            clip_model = self.modelpool.load_clip_model("_pretrained_")
        else:
            log.warning(
                f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
            )
            clip_model = self.modelpool.load_clip_model(
                self.modelpool.model_names[0]
            )
    if clip_processor is None:
        clip_processor = self.clip_processor
    clip_classifier = HFCLIPClassifier(clip_model, clip_processor)
    self.visual_projection = deepcopy(clip_model.visual_projection)
    self.visual_projection.requires_grad_(False)
    self.logit_scale_exp = clip_model.logit_scale.data.clone().exp()
    self.visual_projection = self.fabric.to_device(self.visual_projection)
    self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)

    @cache_with_joblib()
    def construct_classification_head(task: str, model_name: str):
        log.info(
            f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
        )
        nonlocal clip_classifier

        classnames, templates = get_classnames_and_templates(task)
        clip_classifier.set_classification_task(classnames, templates)
        zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()

        return zeroshot_weights

    for task in tqdm(
        self.modelpool.model_names if task_names is None else task_names,
        "Setting up zero-shot classification head",
        disable=not self.fabric.is_global_zero,
    ):
        zeroshot_weights = None
        if self.fabric.is_global_zero:
            if hasattr(clip_model, "config") and hasattr(
                clip_model.config, "_name_or_path"
            ):
                model_name = clip_model.config._name_or_path
            else:
                model_name = "unknown_model"
                log.warning(
                    "CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
                )
            zeroshot_weights = construct_classification_head(
                task, model_name=model_name
            )

        self.fabric.barrier()
        self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
        self.zeroshot_weights[task] = self.to_device(self.zeroshot_weights[task])
        self.fabric.barrier()

    del clip_classifier
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
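The actual weight construction happens inside `HFCLIPClassifier` (not shown on this page). A common construction, following the original CLIP zero-shot recipe, embeds each class name under every prompt template, averages the normalized embeddings, and re-normalizes the mean to get one unit weight vector per class. The sketch below assumes that recipe; `build_zeroshot_weights` and `embed_text` are hypothetical (the real pipeline tokenizes with the CLIP processor and projects through `text_projection`):

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def build_zeroshot_weights(classnames, templates, embed_text):
    """One weight vector per class: the re-normalized mean of its template embeddings.

    `embed_text` stands in for the CLIP text encoder; `templates` are format
    strings such as "a photo of a {}.".
    """
    weights = []
    for name in classnames:
        # Embed the class name under every template and normalize each embedding.
        embs = [l2_normalize(embed_text(t.format(name))) for t in templates]
        # Average per dimension, then normalize the mean back to unit length.
        mean = [sum(col) / len(embs) for col in zip(*embs)]
        weights.append(l2_normalize(mean))
    return weights
```

The resulting vectors play the role of `zeroshot_weights[task]` in `compute_logits`: classification reduces to cosine similarity against these fixed text embeddings.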

auto_register_config(cls)

Decorator to automatically register `__init__` parameters in `_config_mapping`.

This decorator enhances classes that inherit from YAMLSerializationMixin by automatically mapping constructor parameters to configuration keys and dynamically setting instance attributes based on provided arguments.

The decorator performs the following operations:

 1. Inspects the class's `__init__` method signature
 2. Automatically populates the `_config_mapping` dictionary with parameter names
 3. Wraps the `__init__` method to handle both positional and keyword arguments
 4. Sets instance attributes for all constructor parameters
 5. Applies default values when parameters are not provided

Parameters:

  • cls (YAMLSerializationMixin) –

    The class to be decorated. Must inherit from YAMLSerializationMixin to ensure proper serialization capabilities.

Returns:

  • YAMLSerializationMixin

    The decorated class with enhanced auto-registration functionality and modified `__init__` behavior.

Behavior
  • Parameter Registration: All non-variadic parameters (excluding *args, **kwargs) from the `__init__` method are automatically added to _config_mapping
  • Positional Arguments: Handled in order and mapped to corresponding parameter names
  • Keyword Arguments: Processed after positional arguments, overriding any conflicts
  • Default Values: Applied when parameters are not provided via arguments
  • Attribute Setting: All parameters become instance attributes accessible via dot notation
Note
  • The decorator wraps the original `__init__` method while preserving its signature for IDE support
  • Parameters with *args or **kwargs signatures are ignored during registration
  • The attributes are auto-registered first, then the original `__init__` method is called
  • Type hints, method name, and other metadata are preserved using functools.wraps
  • This decorator is designed to work seamlessly with the YAML serialization system

Raises:

  • AttributeError

    If the class does not have the required _config_mapping attribute infrastructure (should inherit from YAMLSerializationMixin)

Source code in fusion_bench/mixins/serialization.py
def auto_register_config(cls):
    """
    Decorator to automatically register __init__ parameters in _config_mapping.

    This decorator enhances classes that inherit from YAMLSerializationMixin by
    automatically mapping constructor parameters to configuration keys and
    dynamically setting instance attributes based on provided arguments.

    The decorator performs the following operations:
    1. Inspects the class's __init__ method signature
    2. Automatically populates the _config_mapping dictionary with parameter names
    3. Wraps the __init__ method to handle both positional and keyword arguments
    4. Sets instance attributes for all constructor parameters
    5. Applies default values when parameters are not provided

    Args:
        cls (YAMLSerializationMixin): The class to be decorated. Must inherit from
            YAMLSerializationMixin to ensure proper serialization capabilities.

    Returns:
        YAMLSerializationMixin: The decorated class with enhanced auto-registration
            functionality and modified __init__ behavior.

    Behavior:
        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
            from the __init__ method are automatically added to _config_mapping
        - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
        - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
        - **Default Values**: Applied when parameters are not provided via arguments
        - **Attribute Setting**: All parameters become instance attributes accessible via dot notation

    Note:
        - The decorator wraps the original __init__ method while preserving its signature for IDE support
        - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
        - The attributes are auto-registered first, then the original __init__ method is called
        - Type hints, method name, and other metadata are preserved using functools.wraps
        - This decorator is designed to work seamlessly with the YAML serialization system

    Raises:
        AttributeError: If the class does not have the required _config_mapping attribute
            infrastructure (should inherit from YAMLSerializationMixin)
    """
    original_init = cls.__init__
    sig = inspect.signature(original_init)

    # Auto-register parameters in _config_mapping
    if "_config_mapping" not in cls.__dict__:
        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
    if not isinstance(cls._config_mapping, bidict):
        cls._config_mapping = bidict(cls._config_mapping)

    registered_parameters = tuple(cls._config_mapping.values())

    for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
        if (
            sig.parameters[param_name].kind
            not in [
                _ParameterKind.VAR_POSITIONAL,
                _ParameterKind.VAR_KEYWORD,
            ]
        ) and (param_name not in registered_parameters):
            cls._config_mapping[param_name] = param_name

    @wraps(original_init)
    def __init__(self, *args, **kwargs):
        log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
        # auto-register the attributes based on the signature
        sig = inspect.signature(original_init)
        param_names = list(sig.parameters.keys())[1:]  # Skip 'self'

        # Handle positional arguments
        for i, arg_value in enumerate(args):
            if i < len(param_names):
                param_name = param_names[i]
                if sig.parameters[param_name].kind not in [
                    _ParameterKind.VAR_POSITIONAL,
                    _ParameterKind.VAR_KEYWORD,
                ]:
                    _set_attr(self, param_name, arg_value)

        # Handle keyword arguments and defaults
        for param_name in param_names:
            if sig.parameters[param_name].kind not in [
                _ParameterKind.VAR_POSITIONAL,
                _ParameterKind.VAR_KEYWORD,
            ]:
                # Skip if already set by positional argument
                param_index = param_names.index(param_name)
                if param_index >= 0 and param_index < len(args):
                    continue

                if param_name in kwargs:
                    _set_attr(self, param_name, kwargs[param_name])
                else:
                    # Set default value if available and attribute doesn't exist
                    default_value = sig.parameters[param_name].default
                    if default_value is not Parameter.empty:
                        _set_attr(self, param_name, default_value)

        # Call the original __init__
        result = original_init(self, *args, **kwargs)
        return result

    # Replace the original __init__ method while preserving its signature
    cls.__init__ = __init__
    return cls
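The core of the decorator is signature-driven attribute binding. A simplified, self-contained sketch of the same idea, using `inspect.Signature.bind` instead of the manual positional/keyword bookkeeping, and omitting the `_config_mapping` registration entirely (`auto_attrs` is an illustrative name, not the FusionBench decorator):

```python
import inspect
from functools import wraps


def auto_attrs(cls):
    """Simplified sketch of auto_register_config's core: bind __init__ arguments
    (positional, keyword, or default) to their parameter names and set them as
    instance attributes before calling the original __init__."""
    original_init = cls.__init__
    sig = inspect.signature(original_init)

    @wraps(original_init)
    def __init__(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in list(bound.arguments.items())[1:]:  # skip 'self'
            param = sig.parameters[name]
            # Variadic parameters (*args, **kwargs) are not registered as attributes.
            if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                setattr(self, name, value)
        original_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls


@auto_attrs
class MyAlgorithm:
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32):
        pass  # no manual self.x = x assignments needed


algo = MyAlgorithm(0.01, batch_size=64)
```

After construction, `algo.learning_rate` and `algo.batch_size` are available as instance attributes, exactly the behavior the full decorator layers under its `_config_mapping` registration for YAML serialization.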