
Magnitude Pruning

Examples

Pruning a Llama Model

Unstructured Magnitude Pruning

The following command prunes a Llama model with unstructured magnitude pruning at a sparsity ratio of 0.7, meaning that the 70% of weights with the smallest magnitudes are set to zero. The pruned model is saved to outputs/llama/magnitude_pruning/unstructured/0.7.

fusion_bench \
    --config-name llama_magnitude_pruning \
    method.prune_type=unstructured \
    method.sparsity_ratio=0.7 \
    modelpool.models.0.path=decapoda-research/llama-7b-hf \
    merged_model_save_path=outputs/llama/magnitude_pruning/unstructured/0.7
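To make the idea concrete, here is a minimal NumPy sketch of unstructured magnitude pruning on a single weight matrix. It is an illustration only, not the fusion_bench implementation: the function name `magnitude_prune_unstructured` is hypothetical, and the sketch assumes the threshold is computed over one matrix at a time.

```python
import numpy as np


def magnitude_prune_unstructured(weight: np.ndarray, sparsity_ratio: float) -> np.ndarray:
    """Hypothetical helper: return a copy of `weight` with the smallest-magnitude entries zeroed."""
    if not 0.0 <= sparsity_ratio <= 1.0:
        raise ValueError("sparsity_ratio must be in [0, 1]")
    # Number of entries to zero out (rounded to avoid floating-point truncation).
    num_to_prune = round(weight.size * sparsity_ratio)
    pruned = weight.copy()
    if num_to_prune == 0:
        return pruned
    # Flat indices of the `num_to_prune` entries with the smallest absolute value.
    flat_abs = np.abs(pruned).ravel()
    prune_idx = np.argpartition(flat_abs, num_to_prune - 1)[:num_to_prune]
    pruned.flat[prune_idx] = 0.0
    return pruned


rng = np.random.default_rng(0)
w = rng.standard_normal((10, 10))
pruned = magnitude_prune_unstructured(w, sparsity_ratio=0.7)
print("fraction of zeros:", np.mean(pruned == 0))
```

Because the zeros can fall anywhere in the matrix, this kind of sparsity reduces the number of non-zero parameters but does not by itself guarantee a hardware-friendly layout.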

Semi-Structured Magnitude Pruning

The following command prunes a Llama model with 2:4 semi-structured magnitude pruning: in every group of 4 consecutive weights, the 2 with the smallest magnitudes are set to zero. The pruned model is saved to outputs/llama/magnitude_pruning/semistructure/2_4.

fusion_bench \
    --config-name llama_magnitude_pruning \
    method.prune_type=semistructured \
    method.n=2 method.m=4 \
    modelpool.models.0.path=decapoda-research/llama-7b-hf \
    merged_model_save_path=outputs/llama/magnitude_pruning/semistructure/2_4
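The n:m pattern can be sketched on a single matrix as follows. This is a NumPy illustration under stated assumptions, not the library's code: the name `magnitude_prune_n_m` is hypothetical, groups are assumed to run along each row, and `n` is taken to be the number of weights zeroed per group of `m` (for 2:4 this coincides with the number kept).

```python
import numpy as np


def magnitude_prune_n_m(weight: np.ndarray, n: int, m: int) -> np.ndarray:
    """Hypothetical helper: zero the n smallest-magnitude entries in every
    group of m consecutive entries along each row."""
    rows, cols = weight.shape
    if cols % m != 0:
        raise ValueError("number of columns must be divisible by m")
    if not 0 <= n <= m:
        raise ValueError("n must be between 0 and m")
    # View each row as (cols // m) groups of m entries; copy so the input is untouched.
    groups = weight.reshape(rows, cols // m, m).copy()
    # Positions of the n smallest |w| inside each group.
    prune_idx = np.argsort(np.abs(groups), axis=-1)[..., :n]
    np.put_along_axis(groups, prune_idx, 0.0, axis=-1)
    return groups.reshape(rows, cols)


w = np.array([[0.1, -2.0, 0.3, 4.0, -5.0, 0.6, 7.0, -0.8]])
pruned = magnitude_prune_n_m(w, n=2, m=4)
print(pruned)  # [[ 0. -2.  0.  4. -5.  0.  7.  0.]]
```

Unlike the unstructured case, every group ends up with the same number of zeros, which is the regularity that sparse hardware kernels rely on.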

Below is an example of how to visualize the pruned weights. It plots the top-left 32x32 block of the query projection (q_proj) weight matrix in the first decoder layer of the pruned model.

from transformers import AutoModelForCausalLM
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Load the pruned model
model = AutoModelForCausalLM.from_pretrained("outputs/llama/magnitude_pruning/semistructure/2_4")

# Extract the tensor data
tensor_data = model.model.layers[0].self_attn.q_proj.weight[:32, :32]

# Convert to NumPy array
tensor_data_np = tensor_data.detach().cpu().numpy()

# Plot heatmap
plt.figure(figsize=(10, 8))
ax = sns.heatmap(tensor_data_np, center=0, cmap="coolwarm", annot=False)

# Add grid lines every 4 rows and every 4 columns to outline 4x4 cells
for i in range(0, tensor_data_np.shape[0], 4):
    ax.axhline(i, color="black", linewidth=0.5)
for j in range(0, tensor_data_np.shape[1], 4):
    ax.axvline(j, color="black", linewidth=0.5)

plt.title("Heatmap of q_proj.weight[:32, :32]")
plt.show()

The following image shows the resulting heatmap of the pruned q_proj weights in the first layer of the pruned model.

[Image: heatmap of q_proj.weight[:32, :32]]
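Besides inspecting the heatmap by eye, the pattern can be checked numerically. The sketch below is a hypothetical helper (`satisfies_n_m` is not part of fusion_bench) that tests whether a 2-D NumPy array, such as `tensor_data_np` above, has at least `n` zeros in every group of `m` consecutive entries. It assumes the groups run along each row; if the pattern runs along columns instead, pass the transposed array.

```python
import numpy as np


def satisfies_n_m(weights: np.ndarray, n: int = 2, m: int = 4) -> bool:
    """Hypothetical check: True if every group of m consecutive entries
    in each row contains at least n zeros."""
    rows, cols = weights.shape
    if cols % m != 0:
        raise ValueError("number of columns must be divisible by m")
    # Count exact zeros inside each group of m entries.
    zeros_per_group = (weights.reshape(rows, cols // m, m) == 0).sum(axis=-1)
    return bool((zeros_per_group >= n).all())


ok = satisfies_n_m(np.array([[0.0, -2.0, 0.0, 4.0, -5.0, 0.0, 7.0, 0.0]]))
bad = satisfies_n_m(np.array([[1.0, -2.0, 0.0, 4.0]]))
print(ok, bad)  # True False
```

With the visualization code above, the call would be `satisfies_n_m(tensor_data_np)`.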

References

MagnitudePruningForLlama

Bases: ModelFusionAlgorithm, SimpleProfilerMixin

Implements magnitude-based pruning for Llama models.

This class supports both unstructured and semistructured pruning. It loads the pre-trained model if the pool provides one, otherwise the first model in the pool, and applies the pruning method selected by prune_type.

Methods:

  • run(modelpool: LLamaForCausalLMPool) -> nn.Module: Executes the pruning process on the model pool and returns the pruned model.

Source code in fusion_bench/method/pruning/llama_magnitude_prune.py
class MagnitudePruningForLlama(ModelFusionAlgorithm, SimpleProfilerMixin):
    """
    Implements magnitude-based pruning for LLama models.

    This class supports both unstructured and semistructured pruning methods.
    It loads a pre-trained model or the first model in the pool and applies the specified pruning technique.

    Methods:
        run(modelpool: LLamaForCausalLMPool) -> nn.Module:
            Executes the pruning process on the model pool and returns the pruned model.
    """

    @torch.no_grad()
    def run(self, modelpool: LLamaForCausalLMPool):
        config = self.config

        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            if modelpool.has_pretrained:
                base_model = modelpool.load_model("_pretrained_")
            else:
                base_model = modelpool.load_model(modelpool.model_names[0])

        dtype = parse_dtype(config.dtype)
        device = torch.device(config.device)

        if config.prune_type == "unstructured":
            unstructured_magnitude_prune_(
                base_model, config.sparsity_ratio, dtype=dtype, device=device
            )
        elif config.prune_type == "semistructured":
            semistructured_magnitude_prune_(
                base_model, config.n, config.m, dtype=dtype, device=device
            )
        else:
            raise ValueError(
                f"Invalid pruning type: {config.prune_type}. "
                "Choose from 'unstructured' or 'semistructured'."
            )

        return base_model
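The helpers called in run end in an underscore, which in PyTorch conventionally marks an in-place operation, and run itself is wrapped in torch.no_grad(). As a rough illustration of that pattern, the sketch below prunes every nn.Linear of a toy model in place. It is not the fusion_bench implementation: `toy_magnitude_prune_` is a hypothetical name, and computing one threshold per layer is an assumption made for the example.

```python
import torch
from torch import nn


@torch.no_grad()
def toy_magnitude_prune_(model: nn.Module, sparsity_ratio: float) -> nn.Module:
    """Hypothetical helper: zero the smallest-magnitude weights of each nn.Linear in place."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight
            num_to_prune = round(weight.numel() * sparsity_ratio)
            if num_to_prune == 0:
                continue
            # The k-th smallest |w| serves as this layer's pruning threshold (assumed per layer).
            threshold = weight.abs().flatten().kthvalue(num_to_prune).values
            weight.masked_fill_(weight.abs() <= threshold, 0.0)
    return model


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
toy_magnitude_prune_(model, sparsity_ratio=0.5)
for name, param in model.named_parameters():
    if name.endswith("weight"):
        print(name, "zeros:", int((param == 0).sum()), "of", param.numel())
```

Biases are left untouched in this sketch, and ties at the threshold would zero slightly more than the requested fraction.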