Weighted Ensemble¶
A weighted ensemble is a machine learning technique that combines the predictions of multiple models to produce a final prediction. The idea is to leverage the strengths of each individual model to improve overall performance and robustness.
Formally, a weighted ensemble can be defined as follows:
Given a set of \(n\) models, each model \(f_i\) produces a prediction \(f_i(x)\) for an input \(x\). Each model \(i\) also has an associated weight \(w_i\). The final prediction \(F(x)\) of the weighted ensemble is a weighted sum of the individual model predictions:
The weights \(w_i\) are typically non-negative and sum to 1 (i.e., \(\sum_{i=1}^n w_i = 1\)), which ensures that the final prediction is a convex combination of the individual model predictions. The weights can be determined in various ways. They could be set based on the performance of the models on a validation set, or they could be learned as part of the training process. In some cases, all models might be given equal weight. The goal of a weighted ensemble is to produce a final prediction that is more accurate or robust than any individual model. This is particularly useful when the individual models have complementary strengths and weaknesses.
Examples¶
The following Python code snippet demonstrates how to use the WeightedEnsembleAlgorithm
class from the fusion_bench.method
module to create a weighted ensemble of PyTorch models.
from omegaconf import DictConfig
from fusion_bench.method import WeightedEnsembleAlgorithm
#Instantiate the algorithm
method_config = {'name': 'weighted_ensemble', 'weights': [0.3, 0.7]}
algorithm = WeightedEnsembleAlgorithm(DictConfig(method_config))
# Assume we have a list of PyTorch models (nn.Module instances) that we want to ensemble.
models = [...]
# Run the algorithm on the models.
merged_model = algorithm.run(models)
Here's a step-by-step explanation:
-
Instantiate the
WeightedEnsembleAlgorithm
:- A dictionary
method_config
is created with two keys:'name'
and'weights'
. The'name'
key is set to'weighted_ensemble'
indicating the type of ensemble method to use. The'weights'
key is set to a list of weights[0.3, 0.7]
indicating the weights assigned to each model in the ensemble. - The
method_config
dictionary is converted to aDictConfig
object, which is a configuration object used by theomegaconf
library. - The
WeightedEnsembleAlgorithm
is then instantiated with theDictConfig
object as an argument.
- A dictionary
-
Assume a list of PyTorch models that you want to ensemble. This list is assigned to the variable
models
. The actual models are not shown in this code snippet. -
Run the algorithm on the models: The
run
method of theWeightedEnsembleAlgorithm
instance is called with themodels
list as an argument. The result is a merged model that represents the weighted ensemble of the input models. This merged model is assigned to the variablemerged_model
.
Here we list the options for the weighted ensemble algorithm:
Option | Default | Description |
---|---|---|
weights |
A list of floats representing the weights for each model in the ensemble. | |
normalize |
True |
Whether to normalize the weights so that they sum to 1. Default is True . |
if normalize
is set to True
, the weights will be normalized so that they sum to 1. Mathematically, this means that the weights \(w_i\) will be divided by the sum of all weights, so that
Code Intergration¶
Configuration template for the weighted ensemble algorithm:
name: weighted_ensemble
# this should be a list of floats, one for each model in the ensemble
# If weights is null, the ensemble will use the default weights, which are equal weights for all models.
weights: null
nomalize: true
Construct a weighted ensemble using our CLI tool fusion_bench
:
References¶
WeightedEnsembleAlgorithm
¶
Bases: BaseAlgorithm
Source code in fusion_bench/method/ensemble.py
run(modelpool)
¶
Run the weighted ensemble algorithm on the given model pool.
Parameters:
-
modelpool
¶BaseModelPool | List[Module]
) –The pool of models to ensemble.
Returns:
-
WeightedEnsembleModule
–The weighted ensembled model.