Model Ensemble
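
This page documents the output-space ensemble algorithms implemented in fusion_bench/method/ensemble.py. SimpleEnsembleAlgorithm wraps all member models in an EnsembleModule that combines their outputs (typically by averaging), WeightedEnsembleAlgorithm combines the outputs with user-supplied weights (uniform weights when none are given), and MaxModelPredictorAlgorithm combines predictions by, presumably, taking an element-wise maximum. Each algorithm accepts a model pool, or a plain list of nn.Module instances, and returns a single nn.Module that wraps the members.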

SimpleEnsembleAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
@auto_register_config
class SimpleEnsembleAlgorithm(BaseAlgorithm):
    def __init__(
        self,
        device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
        **kwargs,
    ):
        """
        Initializes the SimpleEnsembleAlgorithm with an optional device map.

        Args:
            device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
        """
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
        """
        Run the simple ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            EnsembleModule: The ensembled model.
        """
        log.info(f"Running ensemble algorithm with {len(modelpool)} models")
        models = [modelpool.load_model(m) for m in modelpool.model_names]

        log.info("creating ensemble module")
        ensemble = EnsembleModule(models=models, device_map=self.device_map)
        return ensemble

__init__(device_map=None, **kwargs)

Initializes the SimpleEnsembleAlgorithm with an optional device map.

Parameters:

  • device_map (Optional[Mapping[int, Union[str, device]]], default: None ) –

    A mapping from model index to device. Defaults to None.
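
A hedged construction sketch; the device indices and device strings below are placeholders, assuming device_map maps the i-th member model to the device it should live on:

algorithm = SimpleEnsembleAlgorithm(device_map={0: "cuda:0", 1: "cpu"})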

Source code in fusion_bench/method/ensemble.py
def __init__(
    self,
    device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
    **kwargs,
):
    """
    Initializes the SimpleEnsembleAlgorithm with an optional device map.

    Args:
        device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
    """
    super().__init__(**kwargs)

run(modelpool)

Run the simple ensemble algorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

  • EnsembleModule ( EnsembleModule ) –

    The ensembled model.

Source code in fusion_bench/method/ensemble.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
    """
    Run the simple ensemble algorithm on the given model pool.

    Args:
        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

    Returns:
        EnsembleModule: The ensembled model.
    """
    log.info(f"Running ensemble algorithm with {len(modelpool)} models")
    models = [modelpool.load_model(m) for m in modelpool.model_names]

    log.info("creating ensemble module")
    ensemble = EnsembleModule(models=models, device_map=self.device_map)
    return ensemble
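
A minimal usage sketch. The import paths below are assumptions based on the source location shown above, and the example assumes that BaseModelPool can be constructed directly from a list of modules (as the other algorithms on this page do internally) and that the returned EnsembleModule combines the member outputs when called like a regular module:

import torch
import torch.nn as nn

from fusion_bench.method.ensemble import SimpleEnsembleAlgorithm  # assumed import path
from fusion_bench import BaseModelPool  # assumed import path

# Two small member models with identical input/output shapes.
models = [nn.Linear(16, 4), nn.Linear(16, 4)]

# Wrap them in a model pool so the algorithm can load them by name.
modelpool = BaseModelPool(models=models)

algorithm = SimpleEnsembleAlgorithm()
ensemble = algorithm.run(modelpool)  # returns an EnsembleModule

x = torch.randn(8, 16)
y = ensemble(x)  # forward pass through every member; outputs are combined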

WeightedEnsembleAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
@auto_register_config
class WeightedEnsembleAlgorithm(BaseAlgorithm):

    def __init__(
        self,
        normalize: bool = True,
        weights: Optional[List[float]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> WeightedEnsembleModule:
        """
        Run the weighted ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            WeightedEnsembleModule: The weighted ensembled model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info(f"Running weighted ensemble algorithm with {len(modelpool)} models")

        models = [modelpool.load_model(m) for m in modelpool.model_names]
        if self.weights is None:
            weights = np.ones(len(models)) / len(models)
        else:
            weights = self.weights
        ensemble = WeightedEnsembleModule(
            models,
            weights=weights,
            normalize=self.config.get("normalize", True),
        )
        return ensemble

run(modelpool)

Run the weighted ensemble algorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

  • WeightedEnsembleModule ( WeightedEnsembleModule ) –

    The weighted ensembled model.

Source code in fusion_bench/method/ensemble.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> WeightedEnsembleModule:
    """
    Run the weighted ensemble algorithm on the given model pool.

    Args:
        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

    Returns:
        WeightedEnsembleModule: The weighted ensembled model.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(models=modelpool)

    log.info(f"Running weighted ensemble algorithm with {len(modelpool)} models")

    models = [modelpool.load_model(m) for m in modelpool.model_names]
    if self.weights is None:
        weights = np.ones(len(models)) / len(models)
    else:
        weights = self.weights
    ensemble = WeightedEnsembleModule(
        models,
        weights=weights,
        normalize=self.config.get("normalize", True),
    )
    return ensemble
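
A hedged sketch for the weighted variant, assuming the same import path convention as above. A plain list of modules is accepted here because run wraps it in a BaseModelPool itself; when weights is None, uniform weights are used, and with normalize=True the supplied weights are presumably rescaled to sum to one:

import torch
import torch.nn as nn

from fusion_bench.method.ensemble import WeightedEnsembleAlgorithm  # assumed import path

models = [nn.Linear(16, 4), nn.Linear(16, 4)]

# Give the first member twice the weight of the second.
algorithm = WeightedEnsembleAlgorithm(normalize=True, weights=[2.0, 1.0])
ensemble = algorithm.run(models)  # the list is wrapped in a BaseModelPool internally

x = torch.randn(8, 16)
y = ensemble(x)  # weighted combination of the member outputs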

MaxModelPredictorAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
class MaxModelPredictorAlgorithm(BaseAlgorithm):
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]) -> MaxModelPredictor:
        """
        Run the max model predictor algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            MaxModelPredictor: The max model predictor ensembled model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info(f"Running max predictor algorithm with {len(modelpool)} models")

        models = [modelpool.load_model(m) for m in modelpool.model_names]
        ensemble = MaxModelPredictor(models=models)
        return ensemble

run(modelpool)

Run the max model predictor algorithm on the given model pool.

Parameters:

  • modelpool (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

  • MaxModelPredictor ( MaxModelPredictor ) –

    The max model predictor ensembled model.

Source code in fusion_bench/method/ensemble.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool | List[nn.Module]) -> MaxModelPredictor:
    """
    Run the max model predictor algorithm on the given model pool.

    Args:
        modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

    Returns:
        MaxModelPredictor: The max model predictor ensembled model.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(models=modelpool)

    log.info(f"Running max predictor algorithm with {len(modelpool)} models")

    models = [modelpool.load_model(m) for m in modelpool.model_names]
    ensemble = MaxModelPredictor(models=models)
    return ensemble
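
A sketch for the max-predictor variant under the same assumed import path; as the name suggests, the returned MaxModelPredictor presumably combines the members' outputs by taking an element-wise maximum (e.g. over logits) rather than averaging them:

import torch
import torch.nn as nn

from fusion_bench.method.ensemble import MaxModelPredictorAlgorithm  # assumed import path

models = [nn.Linear(16, 4), nn.Linear(16, 4)]

algorithm = MaxModelPredictorAlgorithm()
predictor = algorithm.run(models)  # the list is wrapped in a BaseModelPool internally

x = torch.randn(8, 16)
logits = predictor(x)  # presumably the element-wise maximum over member outputs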