Expert Sparsity
Expert Sparsity provides a suite of methods for pruning and optimizing Mixture-of-Experts (MoE) language models, specifically targeting Mixtral architectures. The goal is to reduce the number of experts or the computation per token while maintaining model quality, enabling faster inference and lower memory usage.
The implementation follows the paper "Not All Experts are Equal: Efficient Expert Pruning and Skipping for Mixture-of-Experts Large Language Models" (ACL 2024), which presents three complementary techniques: Layer-Wise Pruning, Progressive Pruning, and Dynamic Skipping.
Layer-Wise Pruning
Layer-wise pruning selects a subset of experts for each layer independently. The algorithm works as follows:

- Wrapper Insertion: Each `MixtralSparseMoeBlock` is wrapped in a `PrunableMixtralSparseMoeBlockWrapper` that caches input activations (\(X\)) and intermediate outputs (\(Z\)).
- Calibration Forward Pass: Calibration data is forwarded through the model to accumulate activation statistics.
- Expert Enumeration: For each layer, the wrapper evaluates all possible combinations of preserving \(r\) experts out of the total, computing the reconstruction loss (the difference between the original MoE output and the pruned output) for each combination.
- Optimal Selection: The combination with the lowest reconstruction loss is selected, and the pruned MoE block retains only those experts.

Formally, the selected subset is

\[
S^{*} = \arg\min_{S,\; |S| = r} \mathcal{L}(S)
\]

where \(S\) is a subset of \(r\) experts and \(\mathcal{L}(S)\) is the reconstruction loss between the original MoE output and the output computed with only the experts in \(S\).
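The enumeration and selection steps can be sketched in a few lines of plain Python. This is a toy illustration, not the FusionBench implementation: `moe_output` and `select_experts` are hypothetical names, tokens and expert outputs are scalars, the loss is an L2 norm over the calibration tokens, and routing is assumed to be re-normalized over the top-k experts that remain in the candidate subset.

```python
import itertools
import math


def moe_output(x, router_logits, experts, keep, top_k=2):
    """Toy MoE forward for one scalar token, routing only among `keep`."""
    # Rank the candidate experts by router logit and keep the top-k.
    chosen = sorted(keep, key=lambda e: router_logits[e], reverse=True)[:top_k]
    # Softmax over the chosen logits so the routing weights sum to 1.
    peak = max(router_logits[e] for e in chosen)
    scores = {e: math.exp(router_logits[e] - peak) for e in chosen}
    total = sum(scores.values())
    return sum(scores[e] / total * experts[e](x) for e in chosen)


def select_experts(tokens, logits, experts, r, top_k=2):
    """Return the r-expert subset with the lowest reconstruction loss."""
    all_ids = range(len(experts))
    # Reference outputs of the unpruned block (these play the role of Z).
    reference = [moe_output(x, lg, experts, all_ids, top_k)
                 for x, lg in zip(tokens, logits)]
    best_subset, best_loss = None, float("inf")
    for subset in itertools.combinations(all_ids, r):
        pruned = [moe_output(x, lg, experts, subset, top_k)
                  for x, lg in zip(tokens, logits)]
        loss = math.sqrt(sum((a - b) ** 2 for a, b in zip(reference, pruned)))
        if loss < best_loss:
            best_subset, best_loss = subset, loss
    return best_subset, best_loss


# Four toy experts and two calibration tokens (all values are made up).
experts = [lambda x: x, lambda x: 2 * x, lambda x: -x, lambda x: 0.5 * x]
tokens = [1.0, 2.0]
logits = [[3.0, 2.0, 0.0, -1.0], [2.0, 3.0, -1.0, 0.0]]
subset, loss = select_experts(tokens, logits, experts, r=2)
print(subset, loss)
```

In this toy setup experts 0 and 1 are the top-2 choice for every token, so keeping them reproduces the original output and the loss is zero. The search is exhaustive, so its cost grows with the number of combinations: for a Mixtral 8x7B layer with 8 experts and `num_preserved_experts=4`, that is `math.comb(8, 4) == 70` candidates per layer.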
CLI Usage
Configuration template:

```yaml
_target_: fusion_bench.method.LayerWisePruningForMixtral
num_preserved_experts: 4
# c4 or math
# corresponding to the keys of `fusion_bench.method.expert_sparsity.utils.calibration_data.DATASETS`
calib_set: c4
# Maximal sequence length of each sample in calibration set
max_block_size: 2048
# Number of sequences in calibration set. If set to 0 or negative, the whole dataset will be used
n_blocks_for_stat: 128
# Batch size for model inference
batch_size: 8
# Number of workers in dataloader
num_workers: 8
# Random seed
seed: 42
# Path to save the pruned model
model_save_path: "{log_dir}/pruned_model"
```

Example command:

```shell
fusion_bench \
  method=expert_sparsity/mixtral \
  method._target_=fusion_bench.method.LayerWisePruningForMixtral \
  method.num_preserved_experts=4 \
  method.calib_set=c4 \
  method.n_blocks_for_stat=128 \
  method.batch_size=8 \
  modelpool=CausalLMPool/mixtral-8x7b
```
Progressive Pruning
Progressive pruning is a memory-efficient variant that prunes one layer at a time, replacing the wrapper with the pruned MoE block before moving to the next layer. This reduces peak memory usage:

- Z-activation Collection: A first pass collects only the intermediate expert outputs (\(Z\)) for all layers.
- Layer-by-Layer X-Collection: For each layer, a forward pass collects the input activations (\(X\)) for that layer only. After enumeration and pruning, the wrapper is replaced with the pruned MoE block, freeing memory.
- Result: The same optimal subset selection as layer-wise pruning, but with lower memory overhead.
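The control flow can be sketched on a toy stack of MoE layers. Everything here is an illustrative assumption rather than the FusionBench code: `moe_forward` and `progressive_prune` are hypothetical names, tokens and experts are scalars, each layer is a `(gates, experts)` pair, and there are no attention or residual paths. The point of the sketch is that \(Z\) is recorded once for every layer, while the inputs \(X\) exist for only one layer at a time and are produced by the already-pruned earlier layers.

```python
import itertools
import math


def moe_forward(x, gates, experts, keep, top_k=2):
    """Toy MoE layer on a scalar token; experts are scalar multipliers."""
    logits = {e: gates[e] * x for e in keep}
    chosen = sorted(keep, key=logits.get, reverse=True)[:top_k]
    peak = max(logits[e] for e in chosen)
    scores = {e: math.exp(logits[e] - peak) for e in chosen}
    total = sum(scores.values())
    return sum(scores[e] / total * experts[e] * x for e in chosen)


def progressive_prune(layers, tokens, r):
    # Pass 1: record every layer's original output Z (no inputs are kept).
    z_cache, hidden = [], list(tokens)
    for gates, experts in layers:
        hidden = [moe_forward(x, gates, experts, range(len(experts)))
                  for x in hidden]
        z_cache.append(hidden)

    # Pass 2: visit layers in order, caching X for the current layer only.
    kept, losses, hidden = [], [], list(tokens)
    for (gates, experts), z in zip(layers, z_cache):
        x_cache = hidden  # inputs produced by the already-pruned layers
        best, best_loss = None, float("inf")
        for subset in itertools.combinations(range(len(experts)), r):
            out = [moe_forward(x, gates, experts, subset) for x in x_cache]
            loss = math.sqrt(sum((a - b) ** 2 for a, b in zip(out, z)))
            if loss < best_loss:
                best, best_loss = subset, loss
        kept.append(best)
        losses.append(best_loss)
        # "Replace the wrapper": later layers now see the pruned output.
        hidden = [moe_forward(x, gates, experts, best) for x in x_cache]
    return kept, losses


# Two toy layers with three experts each (all values are made up).
layers = [
    ([3.0, 2.0, -1.0], [1.0, 0.5, -1.0]),
    ([-1.0, 2.0, 3.0], [2.0, 1.0, 0.5]),
]
kept, losses = progressive_prune(layers, tokens=[1.0, 2.0], r=2)
print(kept, losses)
```

Only `x_cache` for the current layer is alive inside the second loop, which is where the memory saving in this sketch comes from; a layer-wise variant would hold the inputs of all layers at once.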
CLI Usage
```shell
fusion_bench \
  method=expert_sparsity/mixtral \
  method._target_=fusion_bench.method.ProgressivePruningForMixtral \
  method.num_preserved_experts=4 \
  method.calib_set=c4 \
  method.n_blocks_for_stat=128 \
  method.batch_size=8 \
  modelpool=CausalLMPool/mixtral-8x7b
```
Dynamic Skipping
Dynamic skipping is a runtime optimization that analyzes the routing weight ratios across calibration data to determine per-layer beta parameters. These betas control how aggressively tokens can skip the second-ranked expert during inference:

- Router Logit Collection: The wrapper caches router logits, input activations (\(X\)), and expert outputs (\(Z\)).
- Ratio Analysis: For each token, the ratio of the second-highest routing weight to the highest is computed:

  \[
  \rho = \frac{w_{(2)}}{w_{(1)}}
  \]

  where \(w_{(1)}\) and \(w_{(2)}\) are the sorted routing weights (descending).
- Beta Computation: The median (and mean) of \(\rho\) across all tokens and positions is computed per layer. The median is stored as `model.config.betas[layer_idx]` and used at inference time to decide whether the second expert can be skipped.
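A minimal sketch of the ratio and beta computation, using only the standard library. The function names are hypothetical and the router logits are plain lists, not FusionBench objects. The rule in `should_skip_second` (skip when the ratio falls below the layer's beta) is an assumption about how the stored value is applied at inference time; the description above only states that beta decides whether the second expert can be skipped.

```python
import math
import statistics


def routing_ratio(router_logits):
    """Ratio of the second-largest to the largest routing weight of a token."""
    peak = max(router_logits)
    scores = [math.exp(v - peak) for v in router_logits]
    total = sum(scores)
    weights = sorted((s / total for s in scores), reverse=True)
    return weights[1] / weights[0]


def layer_beta(per_token_logits):
    """Median and mean of the ratio over one layer's calibration tokens."""
    ratios = [routing_ratio(lg) for lg in per_token_logits]
    return statistics.median(ratios), statistics.fmean(ratios)


def should_skip_second(router_logits, beta):
    """Assumed rule: skip the second expert when its relative weight is small."""
    return routing_ratio(router_logits) < beta


# Router logits for three calibration tokens of one layer (made-up values).
calib_logits = [
    [2.0, 1.0, 0.0, -1.0],
    [0.5, 0.4, 0.1, 0.0],
    [3.0, 0.0, -1.0, -2.0],
]
median_rho, mean_rho = layer_beta(calib_logits)
print(median_rho, mean_rho)
print(should_skip_second([3.0, 0.0, -1.0, -2.0], median_rho))
```

Because the ratio divides two softmax weights of the same token, the normalization cancels and the result depends only on the gap between the two largest logits. Under the assumed rule, a token whose router strongly prefers one expert (the third token here) would run only that expert.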
CLI Usage
```shell
fusion_bench \
  method=expert_sparsity/mixtral \
  method._target_=fusion_bench.method.DynamicSkippingPruningForMixtral \
  method.calib_set=c4 \
  method.n_blocks_for_stat=128 \
  method.batch_size=8 \
  modelpool=CausalLMPool/mixtral-8x7b
```
Calibration Data
All three methods use calibration data for analysis. Supported datasets:

- C4: The Common Crawl C4 corpus (English subset), downloaded from `allenai/c4` on the Hugging Face Hub.
- MATH: A math pretraining-style dataset from `tanganke/math_pretrain_style`.
Common Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_preserved_experts` | int | 4 | Number of experts to keep per layer (pruning methods). |
| `calib_set` | str | "c4" | Calibration dataset: "c4" or "math". |
| `max_block_size` | int | 2048 | Maximum sequence length per calibration sample. |
| `n_blocks_for_stat` | int | 128 | Number of sequence blocks for calibration. 0 or negative = use the entire dataset. |
| `batch_size` | int | 8 | Batch size for calibration forward passes. |
| `num_workers` | int | 8 | Number of DataLoader workers. |
| `seed` | int | 42 | Random seed for calibration data shuffling. |
| `model_save_path` | str | "{log_dir}/pruned_model" | Path to save the pruned model. |
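How the calibration parameters interact can be illustrated with a small stand-in for the data pipeline. This is not `build_calib_loader`, whose internals are not described here; `make_calib_batches` is a hypothetical helper that follows only the parameter descriptions: shuffle with `seed`, truncate each sample to `max_block_size` tokens, keep `n_blocks_for_stat` samples (all of them when the value is 0 or negative), and group them into batches of `batch_size`. The order of these steps is itself an assumption.

```python
import random


def make_calib_batches(samples, max_block_size=2048, n_blocks_for_stat=128,
                       batch_size=8, seed=42):
    """Hypothetical helper mirroring the documented parameter semantics."""
    rng = random.Random(seed)   # seeded shuffle for reproducibility
    order = list(range(len(samples)))
    rng.shuffle(order)
    if n_blocks_for_stat > 0:   # 0 or negative keeps every sample
        order = order[:n_blocks_for_stat]
    # Truncate each selected sample to at most max_block_size tokens.
    blocks = [samples[i][:max_block_size] for i in order]
    # Group the blocks into batches; the last batch may be smaller.
    return [blocks[i:i + batch_size] for i in range(0, len(blocks), batch_size)]


# Ten fake token sequences of increasing length (token ids are placeholders).
samples = [list(range(n)) for n in range(1, 11)]
batches = make_calib_batches(samples, max_block_size=4, n_blocks_for_stat=6,
                             batch_size=4, seed=42)
print([len(b) for b in batches])  # six blocks in batches of four: [4, 2]
```

With the documented defaults (`n_blocks_for_stat=128`, `batch_size=8`), a pipeline of this shape would run 16 calibration forward passes.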
Output
- Pruning methods: Return the pruned `MixtralForCausalLM` model with `num_experts` reduced to `num_preserved_experts`. Also save pruning info (loss history per layer) to `{log_dir}/pruning_info.pt`.
- Dynamic Skipping: Return the original model with `config.betas` set to the per-layer median routing ratios. Also save `(res_median, res_mean)` to `{log_dir}/pruning_info.pt`.
Implementation Details
- `fusion_bench.method.expert_sparsity.mixtral.layer_wise_pruning.LayerWisePruningForMixtral`
- `fusion_bench.method.expert_sparsity.mixtral.layer_wise_pruning.layerwise_pruning`
- `fusion_bench.method.expert_sparsity.mixtral.progressive_pruning.ProgressivePruningForMixtral`
- `fusion_bench.method.expert_sparsity.mixtral.progressive_pruning.progressive_pruning`
- `fusion_bench.method.expert_sparsity.mixtral.dynamic_skipping.DynamicSkippingPruningForMixtral`
- `fusion_bench.method.expert_sparsity.mixtral.dynamic_skipping.dynamic_skipping`
- `fusion_bench.method.expert_sparsity.utils.calibration_data.build_calib_loader`