DAWE (Data-Adaptive Weight Ensembling)¶
DAWE is a data-adaptive model ensembling method that learns a gating mechanism at inference time to dynamically route inputs through different expert models. Unlike static merging approaches that compute a fixed set of weights, DAWE uses a learned neural network gate that conditions on both the input data and task-specific feature embeddings to produce soft routing weights. This enables the merged model to selectively leverage the strengths of different expert models for different inputs.
Algorithm Overview¶
DAWE addresses a fundamental limitation of static model merging: a single fixed set of merge weights cannot optimally combine models for all possible inputs. Instead, DAWE learns a data-dependent routing function that produces different ensembling weights for each input sample.
Architecture¶
The DAWE system consists of three components:
- **Expert Models**: A set of fine-tuned models \(\{\theta_1, \theta_2, \dots, \theta_N\}\), each specialized for a particular task. A pretrained base model \(\theta_0\) serves as the reference point.
- **Feature Extractor**: A separate model (e.g., ResNet-18) that extracts task-discriminative features from the input. This feature extractor processes the raw input and provides a representation that captures which task the input belongs to.
- **Gating Network**: A small neural network that takes the feature extractor's output as input and produces routing weights for the expert models. The gate has a configurable number of hidden layers (`gate_hidden_layers`), and its parameters are learned during test-time adaptation.
Inference Process¶
At inference time, for an input image \(x\):

1. The CLIP vision model extracts visual features via its `pooler_output`.
2. The ResNet-based feature extractor processes the raw image to obtain task-discriminative features.
3. The gating network maps the feature extractor output to routing weights \(w \in \mathbb{R}^{N+1}\) (including the base model).
4. The final output is a weighted combination of the expert models' outputs.
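The steps above can be sketched end to end with stand-in components. The feature dimension, random gate, and expert logits below are illustrative placeholders, not FusionBench's actual modules:

```python
import numpy as np

rng = np.random.default_rng(0)

N = 8            # number of expert models (e.g. an 8-task pool); assumed for illustration
feat_dim = 512   # assumed feature-extractor output size

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Stand-ins for the real components:
features = rng.normal(size=feat_dim)            # task-discriminative features f(x)
gate_W = rng.normal(size=(feat_dim, N + 1))     # linear gate (base model included)

# Step 3: routing weights over the base model + N experts
w = softmax(features @ gate_W)                  # shape (N + 1,)

# Step 4: weighted combination of per-model outputs (here, toy logits)
expert_logits = rng.normal(size=(N + 1, 10))    # one row of logits per model
output = w @ expert_logits                      # input-dependent ensemble, shape (10,)
```

Because the weights are recomputed from `features` for every input, two different images can be routed through entirely different mixtures of experts.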
Merge Modes¶
DAWE supports two merging granularity modes:
- `task_wise`: a single routing weight per model (model-level mixing).
- `layer_wise`: per-layer routing weights (layer-level mixing).
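The difference between the two modes is simply the shape of the gate's output; a small sketch with assumed sizes (9 models including the base, 12 layers):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n_models = 9   # base + 8 experts (assumed)
n_layers = 12  # e.g. a ViT-B/32 encoder (assumed)

# task_wise: one weight vector per input, shared across all layers
task_wise_w = softmax(rng.normal(size=n_models))               # shape (9,)

# layer_wise: a separate weight vector for every layer
layer_wise_w = softmax(rng.normal(size=(n_layers, n_models)))  # shape (12, 9)
```

`layer_wise` is more expressive (different layers can favor different experts) at the cost of a larger gate output.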
Batch Reduction¶
The batch_reduce option enables reducing the routed outputs within a batch, which can be useful for generating batch-level aggregated predictions.
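One plausible reading of this reduction, sketched here as averaging the per-sample routing weights over the batch so that a single merged model serves the whole batch (the exact reduction in the implementation may differ):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
batch, n_models = 4, 9  # toy sizes

per_sample_w = softmax(rng.normal(size=(batch, n_models)))  # one weight vector per sample

# batch_reduce: collapse to a single weight vector for the whole batch
batch_w = per_sample_w.mean(axis=0)  # still sums to 1
```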
Mathematical Formulation¶
Task Vector Representation¶
Each expert model \(i\) is represented as a task vector relative to the pretrained model:

$$\tau_i = \theta_i - \theta_0$$
Gating Network¶
The gating network \(g_\phi\) is parameterized by learnable parameters \(\phi\). Given input features \(f\) from the feature extractor:

$$w = \operatorname{softmax}(g_\phi(f))$$

where \(w \in \mathbb{R}^{N+1}\) are the routing weights (one per expert plus the base model), and the softmax ensures they sum to 1.
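A minimal gate, sketched as a plain-NumPy MLP with ReLU hidden layers. The hidden width (128) and initialization scale are illustrative assumptions, not FusionBench defaults:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

class Gate:
    """Tiny MLP gate: feature dim -> hidden layers -> N+1 routing weights (sketch)."""

    def __init__(self, in_dim, n_models, hidden_layers=1, hidden_dim=128, seed=0):
        rng = np.random.default_rng(seed)
        dims = [in_dim] + [hidden_dim] * hidden_layers + [n_models]
        self.weights = [rng.normal(scale=0.02, size=(a, b))
                        for a, b in zip(dims[:-1], dims[1:])]

    def __call__(self, f):
        h = f
        for W in self.weights[:-1]:
            h = np.maximum(h @ W, 0.0)   # ReLU hidden layers
        return softmax(h @ self.weights[-1])  # routing weights summing to 1

gate = Gate(in_dim=512, n_models=9, hidden_layers=1)  # hidden_layers ~ gate_hidden_layers
w = gate(np.ones(512))
```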
Weighted Ensemble¶
The final merged parameters for input \(x\) are:

$$\theta(x) = \theta_0 + \sum_{i=1}^{N} w_i(x)\, \tau_i$$

where \(w_i(x) = g_\phi(f(x))_i\) are the input-dependent weights. The pretrained parameters \(\theta_0\) are added back so that each expert's contribution is expressed relative to the base model.
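In parameter space, this is a weighted sum of task vectors on top of the pretrained weights. A toy sketch with flattened parameters (dimensions are illustrative, and the base model's contribution is folded into \(\theta_0\) for simplicity):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_experts = 16, 3  # flattened parameter dim and expert count (toy sizes)

theta_0 = rng.normal(size=d)                 # pretrained parameters
task_vectors = rng.normal(size=(n_experts, d))  # rows are tau_i = theta_i - theta_0

w = np.array([0.2, 0.5, 0.3])                # input-dependent weights w_i(x)

# theta(x) = theta_0 + sum_i w_i(x) * tau_i
theta_x = theta_0 + w @ task_vectors
```

With `w` recomputed per input, a fresh merged parameter set \(\theta(x)\) is produced for every sample (or per batch, with batch reduction).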
Task Vector Sparsity¶
For efficiency, task vectors can be sparsified by keeping only the top-\(k\) largest-magnitude parameters:

$$\tilde{\tau}_i = \operatorname{TopK}(\tau_i, k)$$

The sparsity level is controlled by the `task_vector_sparsity` parameter.
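A magnitude-based top-\(k\) sparsification sketch. Here `sparsity` is read as the fraction of entries zeroed out; the implementation's exact convention may differ:

```python
import numpy as np

def sparsify(tau, sparsity):
    """Zero out all but the top (1 - sparsity) fraction of entries by magnitude."""
    if sparsity <= 0:
        return tau
    k = int(round((1.0 - sparsity) * tau.size))  # number of entries to keep
    keep = np.argsort(np.abs(tau))[-k:]          # indices of largest magnitudes
    out = np.zeros_like(tau)
    out[keep] = tau[keep]
    return out

tau = np.array([0.1, -2.0, 0.05, 1.5, -0.3])
sparse_tau = sparsify(tau, sparsity=0.6)  # keep the top 40% => 2 entries
```

Sparse task vectors cut the memory needed to hold all experts, since only the kept entries contribute to the weighted sum.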
Training Objective¶
The gate parameters \(\phi\) are optimized via entropy minimization on the model's predictions:

$$\mathcal{L}(\phi) = -\sum_{c} p_\phi(c \mid x) \log p_\phi(c \mid x)$$
The gradient flows through the gate, the expert outputs, and the routing weights, enabling end-to-end optimization.
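Entropy minimization rewards confident predictions without needing labels. A small NumPy stand-in for the loss (not the actual training loop):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def entropy_loss(logits):
    """Mean Shannon entropy of the predicted class distributions."""
    p = softmax(logits)
    return float(-(p * np.log(p + 1e-12)).sum(axis=-1).mean())

confident = np.array([[10.0, 0.0, 0.0]])  # peaked prediction -> low entropy
uncertain = np.array([[0.0, 0.0, 0.0]])   # uniform prediction -> high entropy

low, high = entropy_loss(confident), entropy_loss(uncertain)
```

Minimizing this loss pushes the gate toward routings whose merged model classifies the unlabeled test inputs decisively.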
Configuration¶
```yaml
_target_: fusion_bench.method.DataAdaptiveWeightEnsemblingForCLIP
_recursive_: false
merge_mode: task_wise
init_lambda: 0.3
batch_reduce: true
eval_batch_reduce: false
_dict_feature_extractor_path: microsoft/resnet-18
dict_processor:
  _target_: fusion_bench.method.dawe.dawe_for_clip.load_resnet_processor
  pretrained_model_name_or_path: ${.._dict_feature_extractor_path}
dict_feature_extractor:
  _target_: fusion_bench.method.dawe.dawe_for_clip.load_resnet_feature_extractor
  pretrained_model_name_or_path: ${.._dict_feature_extractor_path}
# dimension of the extracted embeddings; if null, inferred from the feature extractor model
hidden_size: null
# if task_vector_dtype is null, the task vector will have the same dtype as the pretrained model
task_vector_dtype: null
task_vector_sparsity: 0
gate_hidden_layers: 1
# training & logging args
max_steps: 1000
save_interval: 200
learning_rate: 1e-5
resume_checkpoint_path: null
skip_training: false
# dataloader args
batch_size: 1
num_workers: 0
pin_memory: true
```
Key configuration parameters:
| Parameter | Description | Default |
|---|---|---|
| `merge_mode` | Merging granularity: `task_wise` or `layer_wise` | `task_wise` |
| `init_lambda` | Initial merge weight for the gate | `0.3` |
| `batch_reduce` | Whether to reduce within batch | `true` |
| `_dict_feature_extractor_path` | Path to the feature extractor model | `microsoft/resnet-18` |
| `hidden_size` | Dimension of extracted features (inferred if `null`) | `null` |
| `gate_hidden_layers` | Number of hidden layers in the gate | `1` |
| `task_vector_sparsity` | Sparsity ratio for task vectors | `0` |
| `max_steps` | Number of training steps for the gate | `1000` |
| `learning_rate` | Learning rate for gate optimization | `1e-5` |
| `skip_training` | Skip gate training (use initial weights) | `false` |
Examples¶
CLI Usage¶
```bash
fusion_bench \
    method=dawe/dawe_for_clip \
    method.merge_mode=task_wise \
    method.max_steps=1000 \
    method.learning_rate=1e-5 \
    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
```
API Usage¶
```python
from fusion_bench.method.dawe.dawe_for_clip import DataAdaptiveWeightEnsemblingForCLIP
from fusion_bench.modelpool import CLIPVisionModelPool

# Create the algorithm
algorithm = DataAdaptiveWeightEnsemblingForCLIP(
    merge_mode="task_wise",
    init_lambda=0.3,
    batch_reduce=True,
    max_steps=1000,
    learning_rate=1e-5,
)

# Run on a model pool
modelpool = CLIPVisionModelPool(...)
merged_model = algorithm.run(modelpool)
```
Implementation Details¶
- `DataAdaptiveWeightEnsemblingCLIPVisionModel`: The core wrapper model that combines the CLIP vision model, feature extractor, and gating network. The forward pass routes inputs through experts based on gate predictions.
- `ResNetFeatureExtractor`: A wrapper around `ResNetForImageClassification` that removes the classification head and flattens the output to produce feature vectors.
- `load_resnet_processor`: Loads a ResNet processor for image preprocessing, handling RGB conversion.
- **Checkpoints**: During training, checkpoints are saved every `save_interval` steps to `log_dir/checkpoints/model_{step}.pt`.
References¶
- (ICLR 2024) DAWE: Data-Adaptive Weight Ensembling for Pre-Trained Models. [arXiv:2310.02575](http://arxiv.org/abs/2310.02575). Introduces the data-adaptive ensembling framework with learnable routing.