Weighted Ensemble¶
A weighted ensemble is a machine learning technique that combines the predictions of multiple models to produce a final prediction. The idea is to leverage the strengths of each individual model to improve overall performance and robustness.
Formally, a weighted ensemble can be defined as follows:
Given a set of \(n\) models, each model \(f_i\) produces a prediction \(f_i(x)\) for an input \(x\). Each model \(i\) also has an associated weight \(w_i\). The final prediction \(F(x)\) of the weighted ensemble is the weighted sum of the individual model predictions:

$$ F(x) = \sum_{i=1}^n w_i f_i(x) $$
The weights \(w_i\) are typically non-negative and sum to 1 (i.e., \(\sum_{i=1}^n w_i = 1\)), which ensures that the final prediction is a convex combination of the individual model predictions. The weights can be determined in various ways. They could be set based on the performance of the models on a validation set, or they could be learned as part of the training process. In some cases, all models might be given equal weight. The goal of a weighted ensemble is to produce a final prediction that is more accurate or robust than any individual model. This is particularly useful when the individual models have complementary strengths and weaknesses.
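The definition above can be sketched in a few lines of plain Python. This is an illustrative, framework-agnostic toy example, not the fusion_bench implementation; the "models" here are just functions.

```python
# Minimal sketch of a weighted ensemble (illustrative only,
# not the fusion_bench implementation).
# Each "model" is simply a function mapping an input to a prediction.

def weighted_ensemble_predict(models, weights, x):
    """Return the convex combination sum_i w_i * f_i(x)."""
    assert len(models) == len(weights)
    return sum(w * f(x) for f, w in zip(models, weights))

# Two toy regression "models" with complementary biases.
f1 = lambda x: 2.0 * x  # tends to overestimate
f2 = lambda x: 1.0 * x  # tends to underestimate

prediction = weighted_ensemble_predict([f1, f2], [0.3, 0.7], 10.0)
# 0.3 * 20.0 + 0.7 * 10.0 = 13.0
```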
Examples¶
CLI Usage¶
Configuration template for the weighted ensemble algorithm:
```yaml
_target_: fusion_bench.method.WeightedEnsembleAlgorithm
normalize: true
# This should be a list of floats, one for each model in the ensemble.
# If weights is null, the ensemble uses the default weights, which are
# equal weights for all models.
weights: null
```
Construct a weighted ensemble using our CLI tool `fusion_bench`:

```bash
# With specific weights, overridden via method.weights
fusion_bench method=ensemble/weighted_ensemble \
    method.weights=[0.3, 0.7] \
    modelpool=... \
    taskpool=...

# With equal weights (default)
fusion_bench method=ensemble/weighted_ensemble \
    modelpool=... \
    taskpool=...
```
API Usage¶
The following Python code snippet demonstrates how to use the `WeightedEnsembleAlgorithm` class from the `fusion_bench.method` module to create a weighted ensemble of PyTorch models.

```python
from fusion_bench.method import WeightedEnsembleAlgorithm

# Instantiate the algorithm
algorithm = WeightedEnsembleAlgorithm(weights=[0.3, 0.7], normalize=True)

# Assume we have a modelpool or a list of PyTorch models
# (nn.Module instances) that we want to ensemble.
models = [...]  # List of nn.Module instances

# Run the algorithm on the modelpool or models.
ensemble_model = algorithm.run(modelpool)  # or algorithm.run(models)
```
Here's a step-by-step explanation:

1. **Instantiate the `WeightedEnsembleAlgorithm`:**
    - The algorithm is instantiated with two parameters: `weights` (a list of floats representing the weights for each model) and `normalize` (whether to normalize the weights).
    - If `weights` is set to `None`, the algorithm will automatically assign equal weights to all models.
2. **Prepare your models:**
    - You can either use a `BaseModelPool` instance that contains your models, or directly provide a list of PyTorch `nn.Module` instances.
    - If you provide a list of models directly, the algorithm will automatically wrap them in a `BaseModelPool`.
3. **Run the algorithm:**
    - The `run` method processes the modelpool and returns a `WeightedEnsembleModule` that represents the weighted ensemble of the input models.
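To make the weighted-sum behavior concrete, here is a simplified sketch of what such an ensemble module could look like. This is illustrative only; the class name `SimpleWeightedEnsemble` is made up here, and the actual `WeightedEnsembleModule` returned by fusion_bench may be implemented differently.

```python
import torch
import torch.nn as nn

class SimpleWeightedEnsemble(nn.Module):
    """Illustrative sketch of a weighted-ensemble module
    (not fusion_bench's actual WeightedEnsembleModule)."""

    def __init__(self, models, weights):
        super().__init__()
        self.models = nn.ModuleList(models)
        # Store weights as a buffer so they move with the module's device.
        self.register_buffer("weights", torch.tensor(weights))

    def forward(self, x):
        # Stack the individual model outputs along a new leading dimension,
        # then take the weighted sum over that dimension.
        outputs = torch.stack([m(x) for m in self.models])  # (n, ...)
        w = self.weights.view(-1, *([1] * (outputs.dim() - 1)))
        return (w * outputs).sum(dim=0)

# Usage: two linear models over the same input space.
m1, m2 = nn.Linear(4, 2), nn.Linear(4, 2)
ensemble = SimpleWeightedEnsemble([m1, m2], [0.3, 0.7])
y = ensemble(torch.randn(8, 4))  # output shape: (8, 2)
```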
Here we list the options for the weighted ensemble algorithm:

| Option | Default | Description |
| --- | --- | --- |
| `weights` | `null` | A list of floats representing the weights for each model in the ensemble. If `null`, equal weights are automatically assigned to all models. |
| `normalize` | `True` | Whether to normalize the weights so that they sum to 1. |
If `normalize` is set to `True`, the weights are normalized so that they sum to 1. Mathematically, each weight \(w_i\) is divided by the sum of all weights:

$$ w_i \leftarrow \frac{w_i}{\sum_{j=1}^n w_j} $$
When `weights` is set to `null` (or `None` in Python), the algorithm automatically assigns equal weights to all models: \(w_i = \frac{1}{n}\), where \(n\) is the number of models.
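These two weight-handling rules can be sketched in a few lines of plain Python. This is an illustrative helper, not fusion_bench's internal implementation:

```python
def resolve_weights(weights, n_models, normalize=True):
    """Sketch of the weight-handling rules described above
    (illustrative only; not fusion_bench's internal code)."""
    if weights is None:
        # Equal weights: w_i = 1 / n
        weights = [1.0 / n_models] * n_models
    if normalize:
        # Divide each weight by the total so that the weights sum to 1.
        total = sum(weights)
        weights = [w / total for w in weights]
    return weights

print(resolve_weights([3.0, 1.0], 2))  # [0.75, 0.25]
print(resolve_weights(None, 4))        # [0.25, 0.25, 0.25, 0.25]
```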