SparseLo Pruning¶
SparseLo (LoSparse) is a structured pruning framework for Llama models that combines magnitude-based weight pruning with low-rank compensation. The core idea is to prune a large fraction of weights while compensating for the lost information by learning a low-rank decomposition of the pruned weights.
The LoSparse Linear Layer. Each linear layer in the model is converted to a LoSparseLinear module, which consists of two parts:
where: - \(W_{\text{sparse}}\) is the pruned (sparse) weight matrix. - \(B \in \mathbb{R}^{m \times r}\) and \(A \in \mathbb{R}^{r \times n}\) form a low-rank (rank \(r\)) compensation matrix \(B A^T\).
The low-rank part captures the information lost during pruning, preserving model accuracy with minimal additional parameters.
Pruning Variants. SparseLo supports multiple pruning strategies:
-
wanda: Activation-aware pruning using the WANDA method. Importance scores are computed as \(|w_i| \cdot \max_j |x_j|\) where \(x_j\) are calibration activations. -
magnitude: Standard magnitude-based pruning using absolute weight values \(|w_i|\) as importance scores. -
random: Random pruning as a baseline. -
lowrank-only: Extracts the low-rank part via SVD of the original weights and sets the sparse part to zero. -
dense: No pruning (debug mode).
Low-Rank Extraction. After pruning, the low-rank compensation is extracted from the pruned weight matrix via SVD:
where \(k = \text{rank}\). The sparse weight is then updated to \(W_{\text{sparse}} = W_{\text{original}} - B A^T\).
Advanced Variants¶
PCP SparseLo (PCPSparseLoForLlama)¶
The PCP (Principal Component Pursuit) variant applies an optimization-based refinement step after pruning. Given the original weight \(W\) and the pruning mask, it solves:
where \(\| \cdot \|\) denotes the nuclear norm and \(\| \cdot \|_1\) is the entrywise L1 norm, \(\lambda = 1/\sqrt{\max(|W|)}\), and \(m\) is the pruning mask. This is optimized via Adam for num_iterations steps with cosine annealing. The refined sparse weights replace the magnitude-pruned ones.
Iterative SparseLo (IterativeSparseLoForLlama)¶
The iterative variant progressively reconstructs the pruned weight matrix. For each iteration with increasing rank \(r\):
- Compute the difference \(W_{\text{diff}} = W_{\text{original}} - W_{\text{current}}\).
- Perform SVD of \(W_{\text{diff}}\), retaining only the tail components (rank \(r\) and above).
- Update: \(W_{\text{current}} = W_{\text{current}} + \text{mask} \odots (U_{\text{tail}} \Sigma_{\text{tail}} V_{\text{tail}}^T)\).
- Stop early if the spectrum ratio (fraction of singular value energy in the first \(r\) components) exceeds 0.99.
This iterative reconstruction can optionally use a reference model's weights as the original.
Examples¶
CLI Usage -- Standard SparseLo¶
_target_: fusion_bench.method.sparselo.sparselo.SparseLoForLlama
_recursive_: false
nsamples: 128
seed: 0
rank: 128
variant: wanda
# `prune_type` can be either `unstructured` or `semistructured`
prune_type: unstructured
# device and dtype to compute the pruning mask
device: null
# === options for unstructured pruning ===
# `sparsity_ratio` is the ratio of weights to be pruned, 1 means all weights are pruned
sparsity_ratio: 0.5
# === options for semistructured pruning ===
# 2:4 means 2 out of 4 weights are pruned
n: 2
m: 4
# string to specify the path to where the pruned model is saved
model_save_path: null
fusion_bench \
method=sparselo_pruning/llama_sparselo \
method.rank=128 \
method.sparsity_ratio=0.5 \
method.prune_type=unstructured \
method.variant=wanda \
method.nsamples=128 \
modelpool=CausalLMPool/meta-llama/Llama-2-7b-hf
CLI Usage -- Iterative SparseLo¶
_target_: fusion_bench.method.sparselo.sparselo.IterativeSparseLoForLlama
_recursive_: false
nsamples: 128
seed: 0
rank: 128
num_iterations: 10
variant: wanda
use_reference_model: false
# `prune_type` can be either `unstructured` or `semistructured`
prune_type: unstructured
# device and dtype to compute the pruning mask
device: null
# === options for unstructured pruning ===
# `sparsity_ratio` is the ratio of weights to be pruned, 1 means all weights are pruned
sparsity_ratio: 0.5
# === options for semistructured pruning ===
# 2:4 means 2 out of 4 weights are pruned
n: 2
m: 4
# string to specify the path to where the pruned model is saved
model_save_path: null
fusion_bench \
method=sparselo_pruning/llama_iterative_sparselo \
method.rank=128 \
method.num_iterations=10 \
method.sparsity_ratio=0.5 \
method.variant=wanda \
method.use_reference_model=false \
modelpool=CausalLMPool/meta-llama/Llama-2-7b-hf
CLI Usage -- PCP SparseLo¶
_target_: fusion_bench.method.sparselo.sparselo.PCPSparseLoForLlama
_recursive_: false
nsamples: 128
seed: 0
rank: 128
num_iterations: 10
variant: wanda
# `prune_type` can be either `unstructured` or `semistructured`
prune_type: unstructured
# device and dtype to compute the pruning mask
device: null
# === options for unstructured pruning ===
# `sparsity_ratio` is the ratio of weights to be pruned, 1 means all weights are pruned
sparsity_ratio: 0.5
# === options for semistructured pruning ===
# 2:4 means 2 out of 4 weights are pruned
n: 2
m: 4
# string to specify the path to where the pruned model is saved
model_save_path: null
fusion_bench \
method=sparselo_pruning/llama_pcp_sparselo \
method.rank=128 \
method.num_iterations=10 \
method.sparsity_ratio=0.5 \
method.variant=wanda \
modelpool=CausalLMPool/meta-llama/Llama-2-7b-hf
Key Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
rank |
int | 128 | Rank of the low-rank compensation matrix. |
variant |
str | "wanda" | Pruning variant: "dense", "random", "wanda", "lowrank-only", "magnitude". |
nsamples |
int | 128 | Number of calibration samples (for wanda variant). |
sparsity_ratio |
float | 0.5 | Fraction of weights to prune (unstructured). |
prune_type |
str | "unstructured" | "unstructured" or "semistructured". |
n, m |
int | 2, 4 | n:m ratio for semistructured pruning. |
num_iterations |
int | 10 | (Iterative/PCP only) Number of refinement iterations. |
use_reference_model |
bool | false | (Iterative only) Use a reference model's weights. |
model_save_path |
str | null | Path to save the pruned model. |
API Usage¶
from fusion_bench.method.sparselo.sparselo import SparseLoForLlama
algorithm = SparseLoForLlama(
nsamples=128,
variant="wanda",
seed=0,
rank=128,
sparsity_ratio=0.5,
prune_type="unstructured",
n=2,
m=4,
)
pruned_model = algorithm.run(modelpool)
Model Conversion¶
The algorithm converts a standard LlamaForCausalLM to a LoSparseLlamaForCausalLM by replacing all nn.Linear modules with LoSparseLinear modules. Each LoSparseLinear layer has three weight matrices: weight (sparse), lo_A, and lo_B (low-rank).
Implementation Details¶
- fusion_bench.method.sparselo.sparselo.SparseLoForLlama
- fusion_bench.method.sparselo.sparselo.PCPSparseLoForLlama
- fusion_bench.method.sparselo.sparselo.IterativeSparseLoForLlama
- [fusion_bench.models.modeling_losparse_llama.losparse_linear.LoSparseLinear][]
-
SparseLo: Sparse Models with Low-Rank Compensation for Efficient Language Model Compression. ↩