fusion_bench.mixins

Class Definitions

References

HydraConfigMixin

A mixin class that provides configuration-based instantiation capabilities.

This mixin enables classes to be instantiated directly from Hydra configuration files, supporting both direct instantiation and target-based instantiation patterns. It's particularly useful in FusionBench for creating model pools, task pools, and fusion algorithms from YAML configurations.

The mixin handles:

  • Configuration loading and composition
  • Target class validation
  • Nested configuration group navigation
  • Object instantiation with proper error handling

Example:

class MyAlgorithm(HydraConfigMixin):
    def __init__(self, param1: str, param2: int = 10):
        self.param1 = param1
        self.param2 = param2

# Instantiate from config
algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
Note

This mixin requires Hydra to be properly initialized before use. Typically, this is handled by the main FusionBench CLI application.
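For scripts that run outside the FusionBench CLI, a minimal sketch of manual initialization is shown below; the relative config_path is a placeholder and MyAlgorithm refers to the example class above.

# Hedged sketch: initialize Hydra manually before calling from_config().
from hydra import initialize

with initialize(version_base=None, config_path="../config"):
    algorithm = MyAlgorithm.from_config(
        "algorithms/my_algorithm",
        overrides=["param2=42"],
    )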

Source code in fusion_bench/mixins/hydra_config.py
class HydraConfigMixin:
    R"""
    A mixin class that provides configuration-based instantiation capabilities.

    This mixin enables classes to be instantiated directly from Hydra configuration
    files, supporting both direct instantiation and target-based instantiation patterns.
    It's particularly useful in FusionBench for creating model pools, task pools,
    and fusion algorithms from YAML configurations.

    The mixin handles:
    - Configuration loading and composition
    - Target class validation
    - Nested configuration group navigation
    - Object instantiation with proper error handling

    Example:

    ```python
    class MyAlgorithm(HydraConfigMixin):
        def __init__(self, param1: str, param2: int = 10):
            self.param1 = param1
            self.param2 = param2

    # Instantiate from config
    algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
    ```

    Note:
        This mixin requires Hydra to be properly initialized before use.
        Typically, this is handled by the main FusionBench CLI application.
    """

    @classmethod
    def from_config(
        cls,
        config_name: Union[str, Path],
        overrides: Optional[List[str]] = None,
    ):
        """
        Create an instance of the class from a Hydra configuration.

        This method loads a Hydra configuration file and instantiates the class
        using the configuration parameters. It supports both direct parameter
        passing and target-based instantiation patterns.

        Args:
            config_name: The name/path of the configuration file to load.
                        Can be a string like "algorithms/simple_average" or
                        a Path object. The .yaml extension is optional.
            overrides: Optional list of configuration overrides in the format
                      ["key=value", "nested.key=value"]. These allow runtime
                      modification of configuration parameters.

        Returns:
            An instance of the class configured according to the loaded configuration.

        Raises:
            RuntimeError: If Hydra is not properly initialized.
            ImportError: If a target class specified in the config cannot be imported.
            ValueError: If required configuration parameters are missing.

        Example:
            ```python
            # Load with basic config
            obj = MyClass.from_config("my_config")

            # Load with overrides
            obj = MyClass.from_config(
                "my_config",
                overrides=["param1=new_value", "param2=42"]
            )

            # Load nested config
            obj = MyClass.from_config("category/subcategory/my_config")
            ```

        Note:
            The method automatically handles nested configuration groups by
            navigating through the configuration hierarchy based on the
            config_name path structure.
        """
        # Verify Hydra initialization
        if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
            raise RuntimeError(
                "Hydra is not initialized. Please ensure Hydra is properly "
                "initialized before calling from_config(). This is typically "
                "handled by the FusionBench CLI application."
            )
        else:
            # Compose the configuration with any provided overrides
            cfg = compose(config_name=config_name, overrides=overrides)

        # Navigate through nested configuration groups
        # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
        config_groups = config_name.split("/")[:-1]
        for config_group in config_groups:
            cfg = cfg[config_group]

        # Handle target-based instantiation
        if "_target_" in cfg:
            # Validate that the target class matches the calling class
            target_cls = import_object(cfg["_target_"])
            if target_cls != cls:
                log.warning(
                    f"Configuration target mismatch: config specifies "
                    f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
                    f"This may indicate a configuration error."
                )

            # Instantiate using the target pattern with function call logging disabled
            with set_print_function_call(False):
                obj = instantiate(cfg)
        else:
            # Direct instantiation using configuration as keyword arguments
            obj = cls(**cfg)

        return obj

from_config(config_name, overrides=None) classmethod

Create an instance of the class from a Hydra configuration.

This method loads a Hydra configuration file and instantiates the class using the configuration parameters. It supports both direct parameter passing and target-based instantiation patterns.

Parameters:

  • config_name (Union[str, Path]) –

    The name/path of the configuration file to load. Can be a string like "algorithms/simple_average" or a Path object. The .yaml extension is optional.

  • overrides (Optional[List[str]], default: None ) –

    Optional list of configuration overrides in the format ["key=value", "nested.key=value"]. These allow runtime modification of configuration parameters.

Returns:

  • An instance of the class configured according to the loaded configuration.

Raises:

  • RuntimeError

    If Hydra is not properly initialized.

  • ImportError

    If a target class specified in the config cannot be imported.

  • ValueError

    If required configuration parameters are missing.

Example
# Load with basic config
obj = MyClass.from_config("my_config")

# Load with overrides
obj = MyClass.from_config(
    "my_config",
    overrides=["param1=new_value", "param2=42"]
)

# Load nested config
obj = MyClass.from_config("category/subcategory/my_config")
Note

The method automatically handles nested configuration groups by navigating through the configuration hierarchy based on the config_name path structure.
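To illustrate the two instantiation patterns the method supports, a hedged sketch is shown below; the import path and parameter names are hypothetical.

from omegaconf import OmegaConf

# Target-based: `_target_` is validated against the calling class and the
# object is created via Hydra's instantiate().
cfg_target = OmegaConf.create({
    "_target_": "my_package.MyAlgorithm",  # hypothetical import path
    "param1": "hello",
    "param2": 42,
})

# Direct: without `_target_`, the composed config keys are passed to the
# class constructor as keyword arguments, i.e. cls(**cfg).
cfg_direct = OmegaConf.create({"param1": "hello", "param2": 42})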

Source code in fusion_bench/mixins/hydra_config.py
@classmethod
def from_config(
    cls,
    config_name: Union[str, Path],
    overrides: Optional[List[str]] = None,
):
    """
    Create an instance of the class from a Hydra configuration.

    This method loads a Hydra configuration file and instantiates the class
    using the configuration parameters. It supports both direct parameter
    passing and target-based instantiation patterns.

    Args:
        config_name: The name/path of the configuration file to load.
                    Can be a string like "algorithms/simple_average" or
                    a Path object. The .yaml extension is optional.
        overrides: Optional list of configuration overrides in the format
                  ["key=value", "nested.key=value"]. These allow runtime
                  modification of configuration parameters.

    Returns:
        An instance of the class configured according to the loaded configuration.

    Raises:
        RuntimeError: If Hydra is not properly initialized.
        ImportError: If a target class specified in the config cannot be imported.
        ValueError: If required configuration parameters are missing.

    Example:
        ```python
        # Load with basic config
        obj = MyClass.from_config("my_config")

        # Load with overrides
        obj = MyClass.from_config(
            "my_config",
            overrides=["param1=new_value", "param2=42"]
        )

        # Load nested config
        obj = MyClass.from_config("category/subcategory/my_config")
        ```

    Note:
        The method automatically handles nested configuration groups by
        navigating through the configuration hierarchy based on the
        config_name path structure.
    """
    # Verify Hydra initialization
    if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
        raise RuntimeError(
            "Hydra is not initialized. Please ensure Hydra is properly "
            "initialized before calling from_config(). This is typically "
            "handled by the FusionBench CLI application."
        )
    else:
        # Compose the configuration with any provided overrides
        cfg = compose(config_name=config_name, overrides=overrides)

    # Navigate through nested configuration groups
    # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
    config_groups = config_name.split("/")[:-1]
    for config_group in config_groups:
        cfg = cfg[config_group]

    # Handle target-based instantiation
    if "_target_" in cfg:
        # Validate that the target class matches the calling class
        target_cls = import_object(cfg["_target_"])
        if target_cls != cls:
            log.warning(
                f"Configuration target mismatch: config specifies "
                f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
                f"This may indicate a configuration error."
            )

        # Instantiate using the target pattern with function call logging disabled
        with set_print_function_call(False):
            obj = instantiate(cfg)
    else:
        # Direct instantiation using configuration as keyword arguments
        obj = cls(**cfg)

    return obj

YAMLSerializationMixin

Source code in fusion_bench/mixins/serialization.py
class YAMLSerializationMixin:
    _recursive_: bool = False
    _config_key: Optional[str] = None
    _config_mapping: Dict[str, str] = {
        "_recursive_": "_recursive_",
    }
    R"""
    `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.

    For example, if an algorithm class is defined as follows:

    ```python
    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
        hyper_parameter_1 = None
        hyper_parameter_2 = None

        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
            "hyper_parameter_1" : "hyper_param_1",
            "hyper_parameter_2" : "hyper_param_2",
        }
        def __init__(self, hyper_param_1: int, hyper_param_2: int):
            self.hyper_parameter_1 = hyper_param_1
            self.hyper_parameter_2 = hyper_param_2
            super().__init__()
    ```

    The algorithm will be converted to a DictConfig as follows:

    ```python
    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
    ```

    >>> algorithm.config
        DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})

    By default, the `_target_` key is set to the class name as `type(self).__name__`.
    """

    def __init__(
        self,
        _recursive_: bool = False,
        **kwargs,
    ) -> None:
        self._recursive_ = _recursive_
        for key, value in kwargs.items():
            log.warning(f"Unused argument: {key}={value}")

    @property
    def config(self):
        R"""
        Returns the configuration of the model pool as a DictConfig.

        This property calls the `to_config` method to convert the model pool
        instance into a dictionary configuration, which can be used for
        serialization or other purposes.

        Example:

        ```python
        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
        config = model.config
        print(config)
        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
        ```

        This is useful for serializing the object to a YAML file or for debugging.

        Returns:
            DictConfig: The configuration of the model pool.
        """
        return self.to_config()

    def to_yaml(self, path: Union[str, Path]):
        """
        Save the model pool to a YAML file.

        Args:
            path (Union[str, Path]): The path to save the model pool to.
        """
        config = self.to_config()
        OmegaConf.save(config, path, resolve=True)

    @classmethod
    def from_yaml(cls, path: Union[str, Path]):
        """
        Load a model pool from a YAML file.

        Args:
            path (Union[str, Path]): The path to load the model pool from.

        Returns:
            BaseModelPool: The loaded model pool.
        """
        config = OmegaConf.load(path)
        if cls._config_key is not None and cls._config_key in config:
            config = config[cls._config_key]
        target_cls = import_object(config["_target_"])
        if target_cls != cls:
            log.warning(
                f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
                f"Instantiating the class {target_cls.__name__} instead."
            )
        return instantiate(
            config,
            _recursive_=(
                cls._recursive_
                if config.get("_recursive_") is None
                else config.get("_recursive_")
            ),
        )

    def to_config(self):
        """
        Convert the model pool to a DictConfig.

        Returns:
            Dict: The model pool as a DictConfig.
        """
        config = {"_target_": type(self).__name__}
        for attr, key in self._config_mapping.items():
            if hasattr(self, attr):
                config[key] = getattr(self, attr)
        return OmegaConf.create(config)
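A brief usage sketch follows; the subclass and attribute names are illustrative and not part of the library.

class ScaledAverage(YAMLSerializationMixin):
    # map attribute names to the keys used in the serialized config
    _config_mapping = YAMLSerializationMixin._config_mapping | {
        "scaling_factor": "scaling_factor",
    }

    def __init__(self, scaling_factor: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.scaling_factor = scaling_factor

algorithm = ScaledAverage(scaling_factor=0.5)
print(algorithm.config)
# roughly: {'_target_': 'ScaledAverage', '_recursive_': False, 'scaling_factor': 0.5}
algorithm.to_yaml("scaled_average.yaml")  # persist the configuration to disk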

config property

Returns the configuration of the model pool as a DictConfig.

This property calls the to_config method to convert the model pool instance into a dictionary configuration, which can be used for serialization or other purposes.

Example:

model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
config = model.config
print(config)
# DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})

This is useful for serializing the object to a YAML file or for debugging.

Returns:

  • DictConfig

    The configuration of the model pool.

from_yaml(path) classmethod

Load a model pool from a YAML file.

Parameters:

  • path (Union[str, Path]) –

    The path to load the model pool from.

Returns:

  • BaseModelPool

    The loaded model pool.

Source code in fusion_bench/mixins/serialization.py
@classmethod
def from_yaml(cls, path: Union[str, Path]):
    """
    Load a model pool from a YAML file.

    Args:
        path (Union[str, Path]): The path to load the model pool from.

    Returns:
        BaseModelPool: The loaded model pool.
    """
    config = OmegaConf.load(path)
    if cls._config_key is not None and cls._config_key in config:
        config = config[cls._config_key]
    target_cls = import_object(config["_target_"])
    if target_cls != cls:
        log.warning(
            f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
            f"Instantiating the class {target_cls.__name__} instead."
        )
    return instantiate(
        config,
        _recursive_=(
            cls._recursive_
            if config.get("_recursive_") is None
            else config.get("_recursive_")
        ),
    )

to_config()

Convert the model pool to a DictConfig.

Returns:

  • Dict

    The model pool as a DictConfig.

Source code in fusion_bench/mixins/serialization.py
def to_config(self):
    """
    Convert the model pool to a DictConfig.

    Returns:
        Dict: The model pool as a DictConfig.
    """
    config = {"_target_": type(self).__name__}
    for attr, key in self._config_mapping.items():
        if hasattr(self, attr):
            config[key] = getattr(self, attr)
    return OmegaConf.create(config)

to_yaml(path)

Save the model pool to a YAML file.

Parameters:

  • path (Union[str, Path]) –

    The path to save the model pool to.

Source code in fusion_bench/mixins/serialization.py
def to_yaml(self, path: Union[str, Path]):
    """
    Save the model pool to a YAML file.

    Args:
        path (Union[str, Path]): The path to save the model pool to.
    """
    config = self.to_config()
    OmegaConf.save(config, path, resolve=True)

BaseYAMLSerializableModel

Bases: YAMLSerializationMixin

A base class for YAML-serializable classes with enhanced metadata support.

This class extends YAMLSerializationMixin to provide additional metadata fields commonly used in FusionBench classes, including usage information and version tracking. It serves as a foundation for all serializable model components in the framework.

The class automatically handles serialization of usage and version metadata alongside the standard configuration parameters, making it easier to track model provenance and intended usage patterns.

Attributes:

  • _usage_ (Optional[str]) –

    Description of the model's intended usage or purpose.

  • _version_ (Optional[str]) –

    Version information for the model or configuration.

Example
class MyAlgorithm(BaseYAMLSerializableModel):
    _config_mapping = BaseYAMLSerializableModel._config_mapping | {
        "model_name": "model_name",
        "num_layers": "num_layers",
    }

    def __init__(self, _usage_: str = None, _version_: str = None):
        super().__init__(_usage_=_usage_, _version_=_version_)

# Usage with metadata
model = MyAlgorithm(
    _usage_="Text classification fine-tuning",
    _version_="1.0.0"
)

# Serialization includes metadata
config = model.config
# DictConfig({
#     '_target_': 'MyAlgorithm',
#     '_usage_': 'Text classification fine-tuning',
#     '_version_': '1.0.0'
# })
Note

The underscore prefix in _usage_ and _version_ follows the convention for metadata fields that are not core model parameters but provide important contextual information for model management and tracking.

Source code in fusion_bench/mixins/serialization.py
class BaseYAMLSerializableModel(YAMLSerializationMixin):
    """
    A base class for YAML-serializable classes with enhanced metadata support.

    This class extends `YAMLSerializationMixin` to provide additional metadata
    fields commonly used in FusionBench classes, including usage information
    and version tracking. It serves as a foundation for all serializable
    model components in the framework.

    The class automatically handles serialization of usage and version metadata
    alongside the standard configuration parameters, making it easier to track
    model provenance and intended usage patterns.

    Attributes:
        _usage_ (Optional[str]): Description of the model's intended usage or purpose.
        _version_ (Optional[str]): Version information for the model or configuration.

    Example:
        ```python
        class MyAlgorithm(BaseYAMLSerializableModel):
            _config_mapping = BaseYAMLSerializableModel._config_mapping | {
                "model_name": "model_name",
                "num_layers": "num_layers",
            }

            def __init__(self, _usage_: str = None, _version_: str = None):
                super().__init__(_usage_=_usage_, _version_=_version_)

        # Usage with metadata
        model = MyAlgorithm(
            _usage_="Text classification fine-tuning",
            _version_="1.0.0"
        )

        # Serialization includes metadata
        config = model.config
        # DictConfig({
        #     '_target_': 'MyAlgorithm',
        #     '_usage_': 'Text classification fine-tuning',
        #     '_version_': '1.0.0'
        # })
        ```

    Note:
        The underscore prefix in `_usage_` and `_version_` follows the convention
        for metadata fields that are not core model parameters but provide
        important contextual information for model management and tracking.
    """

    _config_mapping = YAMLSerializationMixin._config_mapping | {
        "_usage_": "_usage_",
        "_version_": "_version_",
    }

    _usage_: Optional[str] = None
    _version_: Optional[str] = None

    def __init__(
        self,
        _usage_: Optional[str] = None,
        _version_: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize a base YAML-serializable model with metadata support.

        Args:
            _usage_ (Optional[str], optional): Description of the model's intended
                usage or purpose. This can include information about the training
                domain, expected input types, or specific use cases. Defaults to None.
            _version_ (Optional[str], optional): Version information for the model
                or configuration. Can be used to track model iterations, dataset
                versions, or compatibility information. Defaults to None.
            **kwargs: Additional keyword arguments passed to the parent class.
                Unused arguments will trigger warnings via the parent's initialization.

        Example:
            ```python
            model = BaseYAMLSerializableModel(
                _usage_="Image classification on CIFAR-10",
                _version_="2.1.0"
            )
            ```
        """
        super().__init__(**kwargs)
        if _usage_ is not None:
            self._usage_ = _usage_
        if _version_ is not None:
            self._version_ = _version_

__init__(_usage_=None, _version_=None, **kwargs)

Initialize a base YAML-serializable model with metadata support.

Parameters:

  • _usage_ (Optional[str], default: None ) –

    Description of the model's intended usage or purpose. This can include information about the training domain, expected input types, or specific use cases. Defaults to None.

  • _version_ (Optional[str], default: None ) –

    Version information for the model or configuration. Can be used to track model iterations, dataset versions, or compatibility information. Defaults to None.

  • **kwargs

    Additional keyword arguments passed to the parent class. Unused arguments will trigger warnings via the parent's initialization.

Example
model = BaseYAMLSerializableModel(
    _usage_="Image classification on CIFAR-10",
    _version_="2.1.0"
)
Source code in fusion_bench/mixins/serialization.py
def __init__(
    self,
    _usage_: Optional[str] = None,
    _version_: Optional[str] = None,
    **kwargs,
):
    """
    Initialize a base YAML-serializable model with metadata support.

    Args:
        _usage_ (Optional[str], optional): Description of the model's intended
            usage or purpose. This can include information about the training
            domain, expected input types, or specific use cases. Defaults to None.
        _version_ (Optional[str], optional): Version information for the model
            or configuration. Can be used to track model iterations, dataset
            versions, or compatibility information. Defaults to None.
        **kwargs: Additional keyword arguments passed to the parent class.
            Unused arguments will trigger warnings via the parent's initialization.

    Example:
        ```python
        model = BaseYAMLSerializableModel(
            _usage_="Image classification on CIFAR-10",
            _version_="2.1.0"
        )
        ```
    """
    super().__init__(**kwargs)
    if _usage_ is not None:
        self._usage_ = _usage_
    if _version_ is not None:
        self._version_ = _version_

SimpleProfilerMixin

A mixin class that provides simple profiling capabilities.

This mixin allows for easy profiling of code blocks using a context manager. It also provides methods to start and stop profiling actions, and to print a summary of the profiling results.

Examples:

class MyClass(SimpleProfilerMixin):
    def do_something(self):
        with self.profile("work"):
            # do some work here
            ...
        with self.profile("more work"):
            # do more work here
            ...

        # print the profiling summary
        self.print_profile_summary()

Attributes:

  • _profiler (SimpleProfiler) –

    An instance of the SimpleProfiler class used for profiling.

Source code in fusion_bench/mixins/simple_profiler.py
class SimpleProfilerMixin:
    """
    A mixin class that provides simple profiling capabilities.

    This mixin allows for easy profiling of code blocks using a context manager.
    It also provides methods to start and stop profiling actions, and to print
    a summary of the profiling results.

    Examples:

    ```python
    class MyClass(SimpleProfilerMixin):
        def do_something(self):
            with self.profile("work"):
                # do some work here
                ...
            with self.profile("more work"):
                # do more work here
                ...

            # print the profiling summary
            self.print_profile_summary()
    ```

    Attributes:
        _profiler (SimpleProfiler): An instance of the SimpleProfiler class used for profiling.
    """

    _profiler: SimpleProfiler = None

    @property
    def profiler(self):
        # Lazy initialization of the profiler instance
        if self._profiler is None:
            self._profiler = SimpleProfiler()
        return self._profiler

    @contextmanager
    def profile(self, action_name: str) -> Generator:
        """
        Context manager for profiling a code block

        Example:

        ```python
        with self.profile("work"):
            # do some work here
            ...
        ```
        """
        try:
            self.start_profile(action_name)
            yield action_name
        finally:
            self.stop_profile(action_name)

    def start_profile(self, action_name: str):
        self.profiler.start(action_name)

    def stop_profile(self, action_name: str):
        self.profiler.stop(action_name)

    @rank_zero_only
    def print_profile_summary(self, title: Optional[str] = None):
        if title is not None:
            print(title)
        print(self.profiler.summary())

    def __del__(self):
        if self._profiler is not None:
            del self._profiler
            self._profiler = None
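In addition to the context manager, the explicit start/stop API can be used directly; a brief sketch follows (action names are arbitrary strings, the workload is a placeholder):

class MyTask(SimpleProfilerMixin):
    def run(self):
        self.start_profile("load data")
        data = [i * i for i in range(1_000)]  # placeholder work
        self.stop_profile("load data")

        with self.profile("process data"):
            total = sum(data)

        self.print_profile_summary(title="MyTask timings")
        return total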

profile(action_name)

Context manager for profiling a code block

Example:

with self.profile("work"):
    # do some work here
    ...
Source code in fusion_bench/mixins/simple_profiler.py
@contextmanager
def profile(self, action_name: str) -> Generator:
    """
    Context manager for profiling a code block

    Example:

    ```python
    with self.profile("work"):
        # do some work here
        ...
    ```
    """
    try:
        self.start_profile(action_name)
        yield action_name
    finally:
        self.stop_profile(action_name)

LightningFabricMixin

A mixin class for integrating Lightning Fabric into a project.

This class provides methods to initialize and manage a Lightning Fabric instance for distributed computing, including setup with optional logging, device management for tensors and modules, and hyperparameter logging. It leverages the Lightning framework to facilitate distributed training and inference across multiple devices and nodes, with support for custom logging via TensorBoard.

Attributes:

  • _fabric_instance (L.Fabric): The Lightning Fabric instance used for distributed computing.

Note:

This mixin is designed to be used with classes that require distributed computing capabilities and wish to leverage the Lightning Fabric for this purpose. It assumes the presence of a config attribute or parameter in the consuming class for configuration.

Source code in fusion_bench/mixins/lightning_fabric.py
class LightningFabricMixin:
    """
    A mixin class for integrating Lightning Fabric into a project.

    This class provides methods to initialize and manage a Lightning Fabric instance for distributed computing,
    including setup with optional logging, device management for tensors and modules, and hyperparameter logging.
    It leverages the Lightning framework to facilitate distributed training and inference across multiple devices
    and nodes, with support for custom logging via TensorBoard.

    Attributes:

    - _fabric_instance (L.Fabric): The Lightning Fabric instance used for distributed computing.

    Note:

    This mixin is designed to be used with classes that require distributed computing capabilities and wish to
    leverage the Lightning Fabric for this purpose. It assumes the presence of a `config` attribute or parameter
    in the consuming class for configuration.
    """

    _fabric_instance: L.Fabric = None

    def setup_lightning_fabric(self, config: DictConfig):
        """
        Initializes and launches the Lightning Fabric with optional logging.

        This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric
        configuration is not found, it logs a warning and falls back to a default Fabric instance. Optionally, if a fabric
        logger configuration is provided, it initializes a TensorBoardLogger with the specified settings.

        Expected configuration keys:
        - fabric: The configuration for the Lightning Fabric.
        - fabric.loggers: The configuration for the TensorBoardLogger.
        """
        if self._fabric_instance is None:
            if config.get("fabric", None) is None:
                log.warning("No fabric configuration found. use default settings.")
                self._fabric_instance = L.Fabric()
            else:
                self._fabric_instance = instantiate(config.fabric)
            if not _is_using_cli():  # if not using cli, launch the fabric
                self._fabric_instance.launch()
            # Set the log directory in config if it is not already set
            if (
                self.log_dir is not None
                and hasattr(config, "log_dir")
                and config.get("log_dir", None) is None
            ):
                if self._fabric_instance.is_global_zero:
                    log.info(f"Setting log_dir to {self.log_dir}")
                config.log_dir = self.log_dir

    @property
    def fabric(self):
        if self._fabric_instance is None:
            self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
        return self._fabric_instance

    @property
    def log_dir(self):
        """
        Retrieves the log directory from the fabric's logger.
        """
        if self.fabric is not None and len(self.fabric._loggers) > 0:
            log_dir = self.fabric.logger.log_dir
            if self.fabric.is_global_zero and not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            return log_dir
        else:
            return None

    def to_device(self, obj: TensorOrModule) -> TensorOrModule:
        """
        Moves a tensor or module to the proper device.

        Args:
            obj (TensorOrModule): The tensor or module to move to the device.

        Returns:
            TensorOrModule: the same type of object as the input, moved to the device.
        """
        return self.fabric.to_device(obj)

    @rank_zero_only
    def log_hyperparams(
        self,
        config: Optional[DictConfig] = None,
        save_dir: Optional[str] = None,
        filename: str = "config.yaml",
    ):
        R"""
        Logs the hyperparameters and saves the configuration to a YAML file.
        The YAML file is saved in the log directory by default with the name `config.yaml`, or in the specified save directory `save_dir`.

        Args:
            config (Optional[DictConfig]): The configuration to log and save. If not provided, the class's `config` attribute is used.
            save_dir (Optional[str]): The directory in which to save the configuration file. If not provided, the log directory is used.
            filename (str): The name of the configuration file. Default is `config.yaml`.
        """
        if config is None:
            config = self.config
        if save_dir is None:
            save_dir = self.log_dir
        self.fabric.logger.log_hyperparams(
            OmegaConf.to_container(config, resolve=True, enum_to_str=True)
        )
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        OmegaConf.save(
            config,
            os.path.join(self.log_dir if save_dir is None else save_dir, filename),
        )

    @property
    def tensorboard_summarywriter(
        self,
    ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
        if isinstance(self.fabric.logger, TensorBoardLogger):
            return self.fabric.logger.experiment
        else:
            raise AttributeError("the logger is not a TensorBoardLogger.")

    @property
    def is_debug_mode(self):
        if hasattr(self, "config") and self.config.get("fast_dev_run", False):
            return True
        elif hasattr(self, "_program") and self._program.config.get(
            "fast_dev_run", False
        ):
            return True
        else:
            return False

    def log(self, name: str, value: Any, step: Optional[int] = None):
        """
        Logs the metric to the fabric's logger.
        """
        self.fabric.log(name, value, step=step)

    def log_dict(self, metrics: dict, step: Optional[int] = None):
        """
        Logs the metrics to the fabric's logger.
        """
        self.fabric.log_dict(metrics, step=step)

    def log_optimizer_lr(
        self,
        optimizer: torch.optim.Optimizer,
        step: Optional[int] = None,
        name_template: str = "train/lr_group_{0}",
    ):
        """
        Logs the learning rate of the optimizer to the fabric's logger.
        """
        for i, param_group in enumerate(optimizer.param_groups):
            self.fabric.log(name_template.format(i), param_group["lr"], step=step)

log_dir property

Retrieves the log directory from the fabric's logger.

log(name, value, step=None)

Logs the metric to the fabric's logger.

Source code in fusion_bench/mixins/lightning_fabric.py
def log(self, name: str, value: Any, step: Optional[int] = None):
    """
    Logs the metric to the fabric's logger.
    """
    self.fabric.log(name, value, step=step)

log_dict(metrics, step=None)

Logs the metrics to the fabric's logger.

Source code in fusion_bench/mixins/lightning_fabric.py
def log_dict(self, metrics: dict, step: Optional[int] = None):
    """
    Logs the metrics to the fabric's logger.
    """
    self.fabric.log_dict(metrics, step=step)

log_hyperparams(config=None, save_dir=None, filename='config.yaml')

Logs the hyperparameters and saves the configuration to a YAML file. The YAML file is saved in the log directory by default with the name config.yaml, or in the specified save directory save_dir.

Parameters:

  • config (Optional[DictConfig], default: None ) –

    The configuration to log and save. If not provided, the class's config attribute is used.

  • save_dir (Optional[str], default: None ) –

    The directory in which to save the configuration file. If not provided, the log directory is used.

  • filename (str, default: 'config.yaml' ) –

    The name of the configuration file. Default is config.yaml.

Source code in fusion_bench/mixins/lightning_fabric.py
@rank_zero_only
def log_hyperparams(
    self,
    config: Optional[DictConfig] = None,
    save_dir: Optional[str] = None,
    filename: str = "config.yaml",
):
    R"""
    Logs the hyperparameters and saves the configuration to a YAML file.
    The YAML file is saved in the log directory by default with the name `config.yaml`, or in the specified save directory `save_dir`.

    Args:
        config (Optional[DictConfig]): The configuration to log and save. If not provided, the class's `config` attribute is used.
        save_dir (Optional[str]): The directory in which to save the configuration file. If not provided, the log directory is used.
        filename (str): The name of the configuration file. Default is `config.yaml`.
    """
    if config is None:
        config = self.config
    if save_dir is None:
        save_dir = self.log_dir
    self.fabric.logger.log_hyperparams(
        OmegaConf.to_container(config, resolve=True, enum_to_str=True)
    )
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    OmegaConf.save(
        config,
        os.path.join(self.log_dir if save_dir is None else save_dir, filename),
    )

log_optimizer_lr(optimizer, step=None, name_template='train/lr_group_{0}')

Logs the learning rate of the optimizer to the fabric's logger.

Source code in fusion_bench/mixins/lightning_fabric.py
def log_optimizer_lr(
    self,
    optimizer: torch.optim.Optimizer,
    step: Optional[int] = None,
    name_template: str = "train/lr_group_{0}",
):
    """
    Logs the learning rate of the optimizer to the fabric's logger.
    """
    for i, param_group in enumerate(optimizer.param_groups):
        self.fabric.log(name_template.format(i), param_group["lr"], step=step)
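A hedged sketch showing the three logging helpers together inside a training step; the toy model, data, and optimizer are placeholders, and only the log_* calls come from the mixin.

import torch

class MyProgram(LightningFabricMixin):
    def train(self, steps: int = 10):
        model = torch.nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        for step in range(steps):
            x = torch.randn(8, 4)
            loss = model(x).pow(2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            self.log("train/loss", loss.item(), step=step)
            self.log_dict({"train/loss_x2": 2 * loss.item()}, step=step)
            self.log_optimizer_lr(optimizer, step=step)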

setup_lightning_fabric(config)

Initializes and launches the Lightning Fabric with optional logging.

This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric configuration is not found, it logs a warning and falls back to a default Fabric instance. Optionally, if a fabric logger configuration is provided, it initializes a TensorBoardLogger with the specified settings.

Expected configuration keys:

  • fabric: The configuration for the Lightning Fabric.
  • fabric.loggers: The configuration for the TensorBoardLogger.

Source code in fusion_bench/mixins/lightning_fabric.py
def setup_lightning_fabric(self, config: DictConfig):
    """
    Initializes and launches the Lightning Fabric with optional logging.

    This method sets up the Lightning Fabric for distributed computing based on the provided configuration. If a fabric
    configuration is not found, it logs a warning and falls back to a default Fabric instance. Optionally, if a fabric
    logger configuration is provided, it initializes a TensorBoardLogger with the specified settings.

    Expected configuration keys:
    - fabric: The configuration for the Lightning Fabric.
    - fabric.loggers: The configuration for the TensorBoardLogger.
    """
    if self._fabric_instance is None:
        if config.get("fabric", None) is None:
            log.warning("No fabric configuration found. use default settings.")
            self._fabric_instance = L.Fabric()
        else:
            self._fabric_instance = instantiate(config.fabric)
        if not _is_using_cli():  # if not using cli, launch the fabric
            self._fabric_instance.launch()
        # Set the log directory in config if it is not already set
        if (
            self.log_dir is not None
            and hasattr(config, "log_dir")
            and config.get("log_dir", None) is None
        ):
            if self._fabric_instance.is_global_zero:
                log.info(f"Setting log_dir to {self.log_dir}")
            config.log_dir = self.log_dir
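A hedged sketch of a minimal configuration accepted by this method is shown below; the keys and values are illustrative, so consult the FusionBench fabric configs for the exact schema used in practice.

from omegaconf import OmegaConf

config = OmegaConf.create({
    "fabric": {
        "_target_": "lightning.Fabric",
        "accelerator": "auto",
        "devices": 1,
        "loggers": {
            "_target_": "lightning.fabric.loggers.TensorBoardLogger",
            "root_dir": "outputs/logs",
        },
    },
    "log_dir": None,  # filled in from the logger's log_dir after setup
})

class MyProgram(LightningFabricMixin):
    def __init__(self, config):
        self.config = config

program = MyProgram(config)
fabric = program.fabric  # lazily calls setup_lightning_fabric(self.config)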

to_device(obj)

Moves a tensor or module to the proper device.

Parameters:

  • obj (TensorOrModule) –

    The tensor or module to move to the device.

Returns:

  • TensorOrModule ( TensorOrModule ) –

    the same type of object as the input, moved to the device.

Source code in fusion_bench/mixins/lightning_fabric.py
def to_device(self, obj: TensorOrModule) -> TensorOrModule:
    """
    Moves a tensor or module to the proper device.

    Args:
        obj (TensorOrModule): The tensor or module to move to the device.

    Returns:
        TensorOrModule: the same type of object as the input, moved to the device.
    """
    return self.fabric.to_device(obj)

CLIPClassificationMixin

Bases: LightningFabricMixin

This mixin provides methods to classify images using the CLIP model.

Attributes that need to be set by the inheriting class:

  • _dataloader_kwargs (Dict[str, Any]): Keyword arguments for the dataloader.
  • modelpool (CLIPVisionModelPool): The model pool containing the CLIP models.
  • zeroshot_weights_cache_dir (Optional[str]): The directory to cache the zero-shot weights.
Source code in fusion_bench/mixins/clip_classification.py
class CLIPClassificationMixin(LightningFabricMixin):
    """
    This mixin provides methods to classify images using the CLIP model.

    Attributes that need to be set by the inheriting class:

    - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
    - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
    - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
    """

    _dataloader_kwargs: Dict[str, Any] = {}
    # the modelpool is set by inheriting class
    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    # a dict of zeroshot weights for each task, each key is the task name
    zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
    zeroshot_weights: Dict[str, torch.Tensor] = {}
    whether_setup_zero_shot_classification_head = False

    @property
    def clip_processor(self):
        if self._clip_processor is None:
            assert self.modelpool is not None, "Model pool is not set"
            self._clip_processor = self.modelpool.load_processor()
        return self._clip_processor

    @functools.cache
    def get_shuffled_test_loader_iter(
        self,
        task: str,
        batch_size: Optional[int] = None,
        num_workers: Optional[int] = None,
        **loader_kwargs,
    ):
        """
        Get an iterator for a shuffled test DataLoader.

        This method creates a DataLoader for the test dataset of the specified task,
        with shuffling enabled. It allows for optional customization of batch size,
        number of workers, and other DataLoader keyword arguments.

        Args:
            task (str): The task identifier for which the test dataset is to be loaded.
            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.

        Returns:
            Iterator: An iterator over the shuffled test DataLoader.
        """
        # get dataloader kwargs
        dataloader_kwargs = self._dataloader_kwargs.copy()
        dataloader_kwargs["shuffle"] = True
        if batch_size is not None:
            dataloader_kwargs["batch_size"] = batch_size
        if num_workers is not None:
            dataloader_kwargs["num_workers"] = num_workers
        dataloader_kwargs.update(loader_kwargs)

        # get the test dataset
        clip_dataset = CLIPDataset(
            self.modelpool.load_test_dataset(task), self.clip_processor
        )
        # create the dataloader
        loader = DataLoader(clip_dataset, **dataloader_kwargs)
        loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    @torch.no_grad()
    def setup_zero_shot_classification_head(
        self,
        clip_processor: Optional[CLIPProcessor] = None,
        clip_model: Optional[CLIPModel] = None,
        task_names: Optional[List[str]] = None,
    ):
        self.whether_setup_zero_shot_classification_head = True
        if clip_model is None:
            if self.modelpool.has_pretrained:
                clip_model = self.modelpool.load_clip_model("_pretrained_")
            else:
                clip_model = self.modelpool.load_clip_model(
                    self.modelpool.model_names[0]
                )
        if clip_processor is None:
            clip_processor = self.clip_processor
        clip_classifier = HFCLIPClassifier(clip_model, clip_processor)
        self.visual_projection = deepcopy(clip_model.visual_projection)
        self.visual_projection.requires_grad_(False)
        self.logit_scale_exp = clip_model.logit_scale.data.clone().exp()
        self.visual_projection = self.fabric.to_device(self.visual_projection)
        self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)

        # get cache directory
        if self.modelpool.has_pretrained:
            model_name = self.modelpool.get_model_config("_pretrained_")
            if not isinstance(model_name, str):
                model_name = model_name.pretrained_model_name_or_path
        else:
            model_name = self.modelpool.get_model_config(self.modelpool.model_names[0])
            if not isinstance(model_name, str):
                model_name = model_name.pretrained_model_name_or_path
        cache_dir = os.path.join(
            self.zeroshot_weights_cache_dir,
            os.path.normpath(model_name.split("/")[-1]),
        )
        if not os.path.exists(cache_dir):
            log.info(
                f"Creating cache directory for zero-shot classification head at {cache_dir}"
            )
            os.makedirs(cache_dir)

        log.info(f"cache directory for zero-shot classification head: {cache_dir}")
        for task in tqdm(
            self.modelpool.model_names if task_names is None else task_names,
            "Setting up zero-shot classification head",
            disable=not self.fabric.is_global_zero,
        ):
            zeroshot_weights = None
            if self.fabric.is_global_zero:
                cache_file = os.path.join(
                    cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
                )
                if os.path.exists(cache_file):
                    zeroshot_weights = torch.load(
                        cache_file,
                        map_location="cpu",
                        weights_only=True,
                    ).detach()
                    log.info(
                        f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
                    )
                else:
                    log.info(
                        f"Construct zero shot classification head for task: {task}"
                    )
                    classnames, templates = get_classnames_and_templates(task)
                    clip_classifier.set_classification_task(classnames, templates)
                    zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
                    log.info(f"save zeroshot weights to {cache_file}")
                    torch.save(zeroshot_weights, cache_file)

            self.fabric.barrier()
            self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
            self.zeroshot_weights[task] = self.to_device(self.zeroshot_weights[task])
            self.fabric.barrier()

        del clip_classifier
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def compute_logits(
        self,
        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
        images: torch.Tensor,
        task: str,
        image_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the logits of the images for a given task.

        Args:
            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
            images (torch.Tensor): The images to compute the logits.
            task (str): The task to compute the logits.
            image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.

        Returns:
            torch.Tensor: The logits of the images.
        """
        text_embeds = self.zeroshot_weights[task]

        if image_embeds is None:
            image_embeds = module(images)[1]
        assert isinstance(
            image_embeds, torch.Tensor
        ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image

    def compute_features(
        self,
        module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
        images: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Extracts image features using CLIP's vision encoder and visual projection.

        Args:
            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
            images (torch.Tensor): Input image batch to process.
            normalize (bool): Whether to normalize the image embeddings.

        Returns:
            torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
        """
        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        if normalize:
            image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        return image_embeds
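A hedged end-to-end sketch follows; `algorithm` is assumed to be an instance of a class that mixes in CLIPClassificationMixin with `modelpool` and `_dataloader_kwargs` already set, and the task name "cifar10" is illustrative.

import torch

algorithm.setup_zero_shot_classification_head(task_names=["cifar10"])
loader_iter = algorithm.get_shuffled_test_loader_iter("cifar10", batch_size=32)

# load a CLIP vision encoder from the model pool (assumed to expose load_model)
vision_model = algorithm.to_device(algorithm.modelpool.load_model("_pretrained_"))

images, labels = next(loader_iter)
with torch.no_grad():
    logits = algorithm.compute_logits(vision_model, images, task="cifar10")
accuracy = (logits.argmax(dim=-1) == labels).float().mean()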

compute_features(module, images, normalize=True)

Extracts image features using CLIP's vision encoder and visual projection.

Parameters:

  • module (Union[Module, CLIPVisionModel, CLIPVisionTransformer]) –

    The CLIP vision encoder module.

  • images (Tensor) –

    Input image batch to process.

  • normalize (bool, default: True ) –

    Whether to normalize the image embeddings.

Returns:

  • Tensor

    torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (projection_dim in model config).

Source code in fusion_bench/mixins/clip_classification.py
def compute_features(
    self,
    module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
    images: torch.Tensor,
    normalize: bool = True,
) -> torch.Tensor:
    """
    Extracts image features using CLIP's vision encoder and visual projection.

    Args:
        module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The CLIP vision encoder module.
        images (torch.Tensor): Input image batch to process.
        normalize (bool): Whether to normalize the image embeddings.

    Returns:
        torch.Tensor: Normalized image embeddings with dimension matching CLIP's projection space (`projection_dim` in model config).
    """
    image_embeds = module(images)[1]
    image_embeds = self.visual_projection(image_embeds)

    if normalize:
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    return image_embeds

compute_logits(module, images, task, image_embeds=None)

Compute the logits of the images for a given task.

Parameters:

  • module (Union[Module, CLIPVisionModel, CLIPVisionTransformer]) –

    The module to compute the logits.

  • images (Tensor) –

    The images to compute the logits.

  • task (str) –

    The task to compute the logits.

  • image_embeds (Optional[Tensor], default: None ) –

    The precomputed image embeddings. If None, the image embeddings will be computed.

Returns:

  • Tensor

    torch.Tensor: The logits of the images.

Source code in fusion_bench/mixins/clip_classification.py
def compute_logits(
    self,
    module: Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"],
    images: torch.Tensor,
    task: str,
    image_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute the logits of the images for a given task.

    Args:
        module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The module to compute the logits.
        images (torch.Tensor): The images to compute the logits.
        task (str): The task to compute the logits.
        image_embeds (Optional[torch.Tensor]): The precomputed image embeddings. If None, the image embeddings will be computed.

    Returns:
        torch.Tensor: The logits of the images.
    """
    text_embeds = self.zeroshot_weights[task]

    if image_embeds is None:
        image_embeds = module(images)[1]
    assert isinstance(
        image_embeds, torch.Tensor
    ), f"`image_embeds` must be a tensor, but got {type(image_embeds)}"
    image_embeds = self.visual_projection(image_embeds)

    # normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity
    logits_per_text = (
        torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
    )
    logits_per_image = logits_per_text.t()

    return logits_per_image

get_shuffled_test_loader_iter(task, batch_size=None, num_workers=None, **loader_kwargs) cached

Get an iterator for a shuffled test DataLoader.

This method creates a DataLoader for the test dataset of the specified task, with shuffling enabled. It allows for optional customization of batch size, number of workers, and other DataLoader keyword arguments.

Parameters:

  • task (str) –

    The task identifier for which the test dataset is to be loaded.

  • batch_size (Optional[int], default: None ) –

    The batch size to use for the DataLoader. If None, the default batch size is used.

  • num_workers (Optional[int], default: None ) –

    The number of worker processes to use for data loading. If None, the default number of workers is used.

  • **loader_kwargs

    Additional keyword arguments to pass to the DataLoader.

Returns:

  • Iterator

    An iterator over the shuffled test DataLoader.

Source code in fusion_bench/mixins/clip_classification.py
@functools.cache
def get_shuffled_test_loader_iter(
    self,
    task: str,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    **loader_kwargs,
):
    """
    Get an iterator for a shuffled test DataLoader.

    This method creates a DataLoader for the test dataset of the specified task,
    with shuffling enabled. It allows for optional customization of batch size,
    number of workers, and other DataLoader keyword arguments.

    Args:
        task (str): The task identifier for which the test dataset is to be loaded.
        batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
        num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
        **loader_kwargs: Additional keyword arguments to pass to the DataLoader.

    Returns:
        Iterator: An iterator over the shuffled test DataLoader.
    """
    # get dataloader kwargs
    dataloader_kwargs = self._dataloader_kwargs.copy()
    dataloader_kwargs["shuffle"] = True
    if batch_size is not None:
        dataloader_kwargs["batch_size"] = batch_size
    if num_workers is not None:
        dataloader_kwargs["num_workers"] = num_workers
    dataloader_kwargs.update(loader_kwargs)

    # get the test dataset
    clip_dataset = CLIPDataset(
        self.modelpool.load_test_dataset(task), self.clip_processor
    )
    # create the dataloader
    loader = DataLoader(clip_dataset, **dataloader_kwargs)
    loader = self.fabric.setup_dataloaders(loader)
    return iter(InfiniteDataLoader(loader))