# RankOne MoE for Model Merging
RankOne-MoE (Mixture of Experts)[^1] is a model merging approach that upscales merged models into a Mixture-of-Experts architecture. Instead of computing a single set of merged weights, it creates a pool of "experts" -- low-rank (rank-one) perturbations derived from each task-specific model -- and a learned router that dynamically selects which experts to use for each input.
**Rank-One Expert Construction.** For each task-specific model, the algorithm computes a low-rank decomposition of the task-specific weight update. Given a task model with weight \(W_i\) and the pretrained weight \(W_0\), the update \(\Delta W_i = W_i - W_0\) is approximated by its top-\(k\) rank-one SVD components:

\[
\Delta W_i \approx \sum_{j=1}^{k} a_{i,j} b_{i,j}^\top,
\]

where \(a_{i,j}\) and \(b_{i,j}\) are vectors (the scaled left and right singular vectors of \(\Delta W_i\)). Each such rank-one pair \((a_{i,j}, b_{i,j})\) forms an "expert" for that task at that layer, so each task contributes `rank_k` experts to the pool.
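To make this concrete, here is a minimal sketch of the expert construction using `torch.linalg.svd`; the function `extract_rank_one_experts` and its signature are illustrative, not part of the fusion_bench API:

```python
import torch

def extract_rank_one_experts(w_task: torch.Tensor, w_pretrained: torch.Tensor, rank_k: int = 32):
    """Illustrative sketch: decompose a task update into rank-one expert pairs.

    Returns a list of (a, b) vector pairs such that
    delta_w ~= sum_j outer(a_j, b_j).
    """
    delta_w = w_task - w_pretrained              # task-specific weight update
    u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
    experts = []
    for j in range(min(rank_k, s.numel())):
        a = u[:, j] * s[j]                       # left singular vector, scaled by sigma_j
        b = vh[j, :]                             # right singular vector
        experts.append((a, b))                   # one rank-one expert
    return experts
```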
**Model Architecture.** The RankOne-MoE model replaces the MLP modules in each transformer layer with a `RankOneMoE` module. This module contains:

- A base model: the merged model via task arithmetic, kept frozen (see the formula after this list).
- A pool of experts: rank-one perturbations, with `rank_k` experts per task.
- A router: a small network with `router_hidden_layers` hidden layers that scores each expert and selects the top-\(k\).
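Concretely, the frozen base model is the standard task arithmetic merge, with `init_lambda` serving as the scaling coefficient \(\lambda\):

\[
W_{\text{base}} = W_0 + \lambda \sum_{i=1}^{n} \Delta W_i, \qquad \Delta W_i = W_i - W_0 .
\]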
**MoE Forward Pass.** For each input token \(x\), the router scores every expert in the pool, selects the top-\(k\), and adds the gated rank-one experts on top of the frozen base output:

\[
y = f_{\text{base}}(x) + \sum_{e \in \mathrm{TopK}(s(x))} s_e(x)\, f_e(x),
\]

where \(s_e(x)\) is the (suitably normalized) router score for expert \(e\) on input \(x\), \(f_e\) is the rank-one expert \(x \mapsto a_e (b_e^\top x)\), and \(f_{\text{base}}\) is the base merged model.
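A minimal PyTorch sketch of such a module follows; the class `RankOneMoESketch` and its dense top-k gating are illustrative simplifications, not the actual `RankOneMoE` implementation:

```python
import torch
import torch.nn as nn

class RankOneMoESketch(nn.Module):
    """Illustrative sketch of a RankOne-MoE layer (not the fusion_bench implementation)."""

    def __init__(self, base_mlp: nn.Module, experts_a: torch.Tensor,
                 experts_b: torch.Tensor, select_k: int = -1):
        super().__init__()
        self.base_mlp = base_mlp                        # frozen task-arithmetic merge
        # Expert pool: E rank-one pairs stacked as matrices (frozen).
        self.experts_a = nn.Parameter(experts_a, requires_grad=False)  # (E, d_out)
        self.experts_b = nn.Parameter(experts_b, requires_grad=False)  # (E, d_in)
        num_experts, d_in = experts_b.shape
        # Router with one hidden layer (router_hidden_layers=1); only it is trainable.
        self.router = nn.Sequential(
            nn.Linear(d_in, d_in), nn.ReLU(), nn.Linear(d_in, num_experts)
        )
        self.select_k = num_experts if select_k == -1 else select_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores = self.router(x)                         # (..., E) per-token expert scores
        topk = scores.topk(self.select_k, dim=-1)       # keep only the top-k experts
        gates = torch.zeros_like(scores).scatter(-1, topk.indices,
                                                 topk.values.softmax(dim=-1))
        proj = x @ self.experts_b.t()                   # (..., E): b_e^T x for each expert
        delta = (gates * proj) @ self.experts_a         # sum_e g_e (b_e^T x) a_e
        return self.base_mlp(x) + delta                 # base output + expert update
```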
## Test-Time Adaptation
After constructing the MoE model, the router parameters are fine-tuned via test-time adaptation on unlabeled test data. The optimization objective is entropy minimization:

\[
\min_{\theta_{\text{router}}} \; \mathbb{E}_{x}\left[ -\sum_{c} p(c \mid x) \log p(c \mid x) \right],
\]

where \(p(c \mid x)\) is the softmax probability of class \(c\) given input \(x\). Minimizing the prediction entropy encourages the router to produce confident predictions on the test distribution. The adaptation supports gradient accumulation across tasks (`use_grad_accumulate=true`) for memory efficiency.
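A minimal sketch of the adaptation loop, assuming an unlabeled test `DataLoader` and that only the router parameters require gradients; the helper `test_time_adapt` is illustrative, not a fusion_bench function:

```python
import torch

def test_time_adapt(moe_model, test_loader, lr=1e-4, max_steps=1000, device="cuda"):
    """Illustrative sketch: adapt router parameters by entropy minimization."""
    moe_model.to(device).train()
    # Only the router parameters are trainable; experts and base model stay frozen.
    router_params = [p for p in moe_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(router_params, lr=lr)

    step = 0
    while step < max_steps:
        for images, _ in test_loader:             # labels are ignored (unsupervised)
            logits = moe_model(images.to(device))
            probs = torch.softmax(logits, dim=-1)
            # Shannon entropy of the predictions, averaged over the batch
            entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()
            optimizer.zero_grad()
            entropy.backward()
            optimizer.step()
            step += 1
            if step >= max_steps:
                break
    return moe_model
```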
## Examples

### CLI Usage
The default method configuration (`config/method/rankone_moe/rankone_moe.yaml`):

```yaml
name: ???
# path for loading the model weights; if specified, the test-time adaptation training is skipped
checkpoint: False
# path for saving the model weights
save_checkpoint: False
router_hidden_layers: 1
init_lambda: 0.3
batch_reduce: true
# device on which to compute the SVD
svd_accelerator: cuda
rank_k: 32 # how many experts are added to the pool per task
# how many experts are selected from the pool to merge; range is (1, rank_k*task_num);
# -1 selects all experts in the pool
select_k: -1
# learning rate
lr: 1e-4
optimizer: adam
# this is overridden by `fabric.devices` if launched from the `fusion_bench` CLI
devices: 1
batch_size: 16
num_workers: 16
max_steps: 1000 # default: 1000
# if true, use gradient accumulation across tasks to save memory
use_grad_accumulate: true
cache_dir: outputs
fast_dev_run: ${fast_dev_run}
```
Run the method from the command line, overriding hyperparameters as needed:

```bash
fusion_bench \
    method=rankone_moe/rankone_moe \
    method.rank_k=32 \
    method.select_k=-1 \
    method.init_lambda=0.3 \
    method.lr=0.0001 \
    method.max_steps=1000 \
    method.batch_size=16 \
    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
```
### Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `rank_k` | int | 32 | Number of rank-one experts added per task to the expert pool. |
| `select_k` | int | -1 | Number of experts selected from the pool during inference; -1 means all experts. |
| `init_lambda` | float | 0.3 | Scaling factor for the initial task arithmetic merge (base model). |
| `lr` | float | 1e-4 | Learning rate for router test-time adaptation. |
| `max_steps` | int | 1000 | Number of test-time adaptation steps. |
| `batch_size` | int | 16 | Batch size for test-time adaptation. |
| `num_workers` | int | 16 | Number of DataLoader workers. |
| `router_hidden_layers` | int | 1 | Number of hidden layers in the router network. |
| `batch_reduce` | bool | true | Whether to use batch-reduce mode; set to false for sample-wise adaptation at inference. |
| `use_grad_accumulate` | bool | true | Use gradient accumulation across tasks to save memory. |
| `svd_accelerator` | str | "cuda" | Device for SVD computation during expert construction. |
| `checkpoint` | str/bool | False | Path to load a checkpoint; skips test-time adaptation when set. |
| `save_checkpoint` | str/bool | False | Path to save a checkpoint after adaptation. |
### API Usage
A sketch of programmatic usage, assuming the algorithm class accepts the method `DictConfig` directly and that a `modelpool` has already been constructed:

```python
from omegaconf import OmegaConf
from fusion_bench.method.rankone_moe import CLIPRankOneMoEAlgorithm

# The algorithm is configured via a DictConfig; see
# config/method/rankone_moe/rankone_moe.yaml for the full config.
config = OmegaConf.load("config/method/rankone_moe/rankone_moe.yaml")
algorithm = CLIPRankOneMoEAlgorithm(config)

# `modelpool` holds the pretrained backbone and the task-specific models.
moe_model = algorithm.run(modelpool)
# After construction, each layer's MLP is a RankOneMoE module
# with learnable router parameters.
```
## Output
The algorithm returns a RankOneMoE model in which each transformer layer's MLP has been replaced with a `RankOneMoE` module. The model supports both batch-reduced inference (routing weights shared across a batch) and sample-wise inference (per-sample routing weights).
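For example, switching an already-built model to sample-wise routing could look like the following; note that exposing `batch_reduce` as a module attribute is an assumption for illustration, not a documented API (normally you would simply set `batch_reduce: false` in the config):

```python
# Hypothetical toggle: assumes RankOneMoE modules expose the `batch_reduce`
# config flag as an attribute (an assumption, not a documented API).
for module in moe_model.modules():
    if hasattr(module, "batch_reduce"):
        module.batch_reduce = False  # per-sample routing instead of batch-reduced
```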
## Implementation Details

- [fusion_bench.method.rankone_moe.rankone_moe.RankOneMoEAlgorithm][]
- [fusion_bench.method.rankone_moe.clip_rankone_moe.CLIPRankOneMoEAlgorithm][]
- [fusion_bench.models.rankone_moe.RankOneMoE][]
[^1]: RankOne-MoE: Mixture of Experts for Multi-Task Model Merging, 2024. <https://github.com/EnnengYang/RankOne-MoE>