Simple Ensemble

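Ensemble methods are a simple and effective way to improve the performance of machine learning models. Instead of relying on a single model, an ensemble combines the outputs of several models into one prediction, which is often more accurate and more robust than that of any individual member.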
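
As a library-independent illustration of "combining outputs", the sketch below averages the class scores returned by several models. It is a minimal example under stated assumptions, not the fusion_bench implementation: the name `ensemble_predict`, the toy models, and the choice of a plain arithmetic mean are all hypothetical and chosen only for illustration.

```python
from typing import Callable, List, Sequence

# A "model" here is any callable that maps an input to a list of class scores.
Model = Callable[[Sequence[float]], List[float]]


def ensemble_predict(models: Sequence[Model], x: Sequence[float]) -> List[float]:
    """Average the outputs of several models element-wise (hypothetical helper)."""
    if not models:
        raise ValueError("need at least one model")
    outputs = [model(x) for model in models]
    n_scores = len(outputs[0])
    if any(len(out) != n_scores for out in outputs):
        raise ValueError("all models must return outputs of the same length")
    return [sum(out[i] for out in outputs) / len(outputs) for i in range(n_scores)]


# Two toy "classifiers" that disagree on the most likely class.
def model_a(x: Sequence[float]) -> List[float]:
    return [2.0, 0.0]


def model_b(x: Sequence[float]) -> List[float]:
    return [0.0, 1.0]


scores = ensemble_predict([model_a, model_b], [0.0])
print(scores)  # [1.0, 0.5]
```

Averaging is only one possible aggregation rule; weighted means or majority voting follow the same pattern of calling every member and reducing the results.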

Examples

from fusion_bench.method.ensemble import SimpleEnsembleAlgorithm

# Instantiate the algorithm.
algorithm = SimpleEnsembleAlgorithm()

# Assume we have a list of PyTorch models (nn.Module instances) that we want to ensemble.
models = [...]

# Run the algorithm on the models. The result is an EnsembleModule wrapping them.
ensemble = algorithm.run(models)

Code Integration

Configuration template for the ensemble algorithm:

config/method/simple_ensemble.yaml
name: simple_ensemble

Create a simple ensemble of CLIP-ViT models for image classification tasks:

fusion_bench \
  method=ensemble/simple_ensemble \
  modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
  taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8 

References

SimpleEnsembleAlgorithm

Bases: BaseAlgorithm

Source code in fusion_bench/method/ensemble.py
class SimpleEnsembleAlgorithm(BaseAlgorithm):
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool | List[nn.Module]):
        """
        Run the simple ensemble algorithm on the given model pool.

        Args:
            modelpool (BaseModelPool | List[nn.Module]): The pool of models to ensemble.

        Returns:
            EnsembleModule: The ensembled model.
        """
        log.info(f"Running ensemble algorithm with {len(modelpool)} models")

        models = [modelpool.load_model(m) for m in modelpool.model_names]
        ensemble = EnsembleModule(models=models)
        return ensemble

run(modelpool)

Run the simple ensemble algorithm on the given model pool.

Parameters:

  • modelpool
    (BaseModelPool | List[Module]) –

    The pool of models to ensemble.

Returns:

  • EnsembleModule

    The ensembled model.
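
The class listing above reads `modelpool.model_names` and calls `modelpool.load_model(...)`, so whatever is passed to `run` has to provide both. The sketch below shows one way such an interface could look, including a helper that wraps a plain list in a pool-like object. `InMemoryPool`, `as_pool`, and `collect_models` are hypothetical names introduced for illustration; they are not part of fusion_bench and do not reproduce `BaseModelPool`.

```python
from typing import Any, Dict, List


class InMemoryPool:
    """Hypothetical stand-in for a model pool (not fusion_bench's BaseModelPool)."""

    def __init__(self, models: Dict[str, Any]):
        self._models = dict(models)

    @property
    def model_names(self) -> List[str]:
        # Names are returned in insertion order.
        return list(self._models)

    def load_model(self, name: str) -> Any:
        return self._models[name]

    def __len__(self) -> int:
        return len(self._models)


def as_pool(models_or_pool: Any) -> Any:
    """Wrap a plain list or tuple in a pool with generated names; pass pools through."""
    if isinstance(models_or_pool, (list, tuple)):
        return InMemoryPool({f"model_{i}": m for i, m in enumerate(models_or_pool)})
    return models_or_pool


def collect_models(models_or_pool: Any) -> List[Any]:
    """Mirror the loading step of the listing: one load_model call per name."""
    pool = as_pool(models_or_pool)
    return [pool.load_model(name) for name in pool.model_names]


# Strings stand in for nn.Module instances to keep the sketch dependency-free.
print(collect_models(["vit_a", "vit_b"]))  # ['vit_a', 'vit_b']
```

Normalizing the input once at the top keeps the rest of the logic identical for both accepted argument types.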
