
Model Compression

Task Vector Compression

BitDelta

BitDeltaAlgorithm

Bases: LightningFabricMixin, SimpleProfilerMixin, BaseAlgorithm

Source code in fusion_bench/method/bitdelta/bitdelta.py
@auto_register_config
class BitDeltaAlgorithm(
    LightningFabricMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    def __init__(
        self,
        save_dir: str,
        save_full_model: bool = False,
        lr: float = 1e-4,
        batch_size: int = 4,
        num_steps: int = 100,
        dataset_name: str = "c4",
        subset: str = "en",
        split: str = "train",
        max_length: int = 128,
        **kwargs,
    ):
        super().__init__(**kwargs)

    def run(self, modelpool: CausalLMPool):
        if self.save_dir is None:
            log.info(
                f"save_dir not set, using log_dir instead. log_dir: {self.log_dir}"
            )
            self.save_dir = self.log_dir

        with self.profile("model loading"):
            tokenizer = modelpool.load_tokenizer()
            base_model = modelpool.load_pretrained_model()
            finetuned_model = modelpool.load_model(modelpool.model_names[0])
            finetuned_compressed_model = modelpool.load_model(modelpool.model_names[0])

        with self.profile("model compression"):
            print(f"compressing diff...")
            compress_diff(base_model, finetuned_model, finetuned_compressed_model)

        # save untrained delta
        save_diff(
            finetuned_compressed_model, os.path.join(self.save_dir, "diff_untrained.pt")
        )

        optimizer = torch.optim.AdamW(
            finetuned_compressed_model.parameters(), lr=self.lr
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.num_steps
        )

        train_num_samples = self.batch_size * self.num_steps
        train_dataset = get_dataset(
            self.dataset_name,
            self.subset,
            "train",
            size=train_num_samples,
        )
        train_dataloader = get_dataloader(
            train_dataset,
            tokenizer,
            self.batch_size,
            num_workers=4,
            max_length=self.max_length,
        )

        bar = tqdm(train_dataloader)

        train_loss_list = []

        # Train loop
        for step, batch in enumerate(bar):
            batch1 = {k: v.to(finetuned_model.device) for k, v in batch.items()}
            with torch.inference_mode():
                finetuned_outputs = finetuned_model(**batch1)

            batch2 = {
                k: v.to(finetuned_compressed_model.device) for k, v in batch.items()
            }
            finetuned_compressed_outputs = finetuned_compressed_model(**batch2)

            loss = F.mse_loss(
                finetuned_outputs.logits.clone().to(
                    finetuned_compressed_outputs.logits.device
                ),
                finetuned_compressed_outputs.logits,
            )

            train_loss_list.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            bar.set_description(f"train loss: {loss.item()}")

        # save trained delta
        save_diff(finetuned_compressed_model, os.path.join(self.save_dir, "diff.pt"))

        if self.save_full_model:
            print("saving uncalibrated model")
            save_full_model(
                base_model,
                tokenizer,
                os.path.join(self.save_dir, "diff_untrained.pt"),
                os.path.join(self.save_dir, "uncalibrated_model"),
                device="cpu",
            )
            print("saving calibrated model")
            save_full_model(
                base_model,
                tokenizer,
                os.path.join(self.save_dir, "diff.pt"),
                os.path.join(self.save_dir, "calibrated_model"),
                device="cpu",
            )

        del base_model, finetuned_model, finetuned_compressed_model
        torch.cuda.empty_cache()
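
BitDelta compresses the weight difference between a fine-tuned model and its base model down to a single bit per weight plus a per-matrix scale, and then calibrates the scales by matching the fine-tuned model's logits (the MSE objective in the training loop above). The following is a minimal sketch of the 1-bit delta idea only, assuming a simple mean-absolute scale; binarize_delta and reconstruct are hypothetical helpers, not the compress_diff / save_diff utilities used in the source.

import torch

def binarize_delta(w_base: torch.Tensor, w_ft: torch.Tensor):
    """Return a per-matrix scale and a 1-bit sign mask of the weight delta."""
    delta = w_ft - w_base
    scale = delta.abs().mean()   # one scalar per weight matrix
    sign = torch.sign(delta)     # +1/-1 mask, storable with one bit per weight
    return scale, sign

def reconstruct(w_base: torch.Tensor, scale: torch.Tensor, sign: torch.Tensor):
    """Approximate the fine-tuned weight from the base weight and the 1-bit delta."""
    return w_base + scale * sign

w_base = torch.randn(256, 256)
w_ft = w_base + 0.02 * torch.randn(256, 256)
scale, sign = binarize_delta(w_base, w_ft)
w_approx = reconstruct(w_base, scale, sign)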

Parameter Pruning

Random Pruning

RandomPruningForLlama

Bases: BaseAlgorithm, SimpleProfilerMixin

A class to perform random pruning for Llama models.

Attributes:

  • prune_type (PruningType) –

    The type of pruning to be performed.

  • sparsity_ratio (float) –

    The ratio of weights to be pruned.

  • n (int) –

    The number of weights to be pruned in each group (for semistructured pruning).

  • m (int) –

    The total number of weights in each group (for semistructured pruning).

Source code in fusion_bench/method/pruning/llama_random_prune.py
class RandomPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
    """
    A class to perform random pruning for Llama models.

    Attributes:
        prune_type (PruningType): The type of pruning to be performed.
        sparsity_ratio (float): The ratio of weights to be pruned.
        n (int): The number of weights to be pruned in each group (for semistructured pruning).
        m (int): The total number of weights in each group (for semistructured pruning).
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "prune_type": "prune_type",
        "sparsity_ratio": "sparsity_ratio",
        "n": "n",
        "m": "m",
    }

    def __init__(
        self,
        *,
        prune_type: PruningType,
        sparsity_ratio: float,
        n: int,
        m: int,
        **kwargs,
    ):
        """
        Initialize the RandomPruningForLlama class.

        Args:
            prune_type (PruningType): The type of pruning to be performed.
            sparsity_ratio (float): The ratio of weights to be pruned.
            n (int): The number of weights to be pruned in each group (for semistructured pruning).
            m (int): The total number of weights in each group (for semistructured pruning).
            **kwargs: Additional keyword arguments.
        """
        self.prune_type = prune_type
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: CausalLMPool):
        """
        Run the pruning algorithm on the first model from the given model pool.

        Args:
            modelpool (CausalLMPool): The pool of models to be pruned.

        Returns:
            The pruned model.
        """
        # load pre-trained model or the first model in the pool
        base_model = modelpool.load_pretrained_or_first_model()

        if self.prune_type == PruningType.UNSTRUCTURED:
            unstructured_magnitude_prune_(base_model, self.sparsity_ratio)
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            semistructured_magnitude_prune_(base_model, self.n, self.m)
        else:
            raise ValueError(
                f"Invalid pruning type: {self.prune_type}. "
                "Choose from 'unstructured' or 'semistructured'"
            )

        return base_model
__init__(*, prune_type, sparsity_ratio, n, m, **kwargs)

Initialize the RandomPruningForLlama class.

Parameters:

  • prune_type (PruningType) –

    The type of pruning to be performed.

  • sparsity_ratio (float) –

    The ratio of weights to be pruned.

  • n (int) –

    The number of weights to be pruned in each group (for semistructured pruning).

  • m (int) –

    The total number of weights in each group (for semistructured pruning).

  • **kwargs

    Additional keyword arguments.

Source code in fusion_bench/method/pruning/llama_random_prune.py
def __init__(
    self,
    *,
    prune_type: PruningType,
    sparsity_ratio: float,
    n: int,
    m: int,
    **kwargs,
):
    """
    Initialize the RandomPruningForLlama class.

    Args:
        prune_type (PruningType): The type of pruning to be performed.
        sparsity_ratio (float): The ratio of weights to be pruned.
        n (int): The number of weights to be pruned in each group (for semistructured pruning).
        m (int): The total number of weights in each group (for semistructured pruning).
        **kwargs: Additional keyword arguments.
    """
    self.prune_type = prune_type
    self.sparsity_ratio = sparsity_ratio
    self.n = n
    self.m = m
    super().__init__(**kwargs)
run(modelpool)

Run the pruning algorithm on the first model from the given model pool.

Parameters:

  • modelpool (CausalLMPool) –

    The pool of models to be pruned.

Returns:

  • The pruned model.

Source code in fusion_bench/method/pruning/llama_random_prune.py
@torch.no_grad()
def run(self, modelpool: CausalLMPool):
    """
    Run the pruning algorithm on the first model from the given model pool.

    Args:
        modelpool (CausalLMPool): The pool of models to be pruned.

    Returns:
        The pruned model.
    """
    # load pre-trained model or the first model in the pool
    base_model = modelpool.load_pretrained_or_first_model()

    if self.prune_type == PruningType.UNSTRUCTURED:
        unstructured_magnitude_prune_(base_model, self.sparsity_ratio)
    elif self.prune_type == PruningType.SEMISTRUCTURED:
        semistructured_magnitude_prune_(base_model, self.n, self.m)
    else:
        raise ValueError(
            f"Invalid pruning type: {self.prune_type}. "
            "Choose from 'unstructured' or 'semistructured'"
        )

    return base_model
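
The sketch below illustrates semistructured random pruning on a single weight matrix, following the convention documented above (n weights pruned at random in every consecutive group of m). random_semistructured_prune_ is an illustrative helper, not the FusionBench function called in run.

import torch

def random_semistructured_prune_(weight: torch.Tensor, n: int, m: int) -> torch.Tensor:
    """Zero out n randomly chosen weights in every consecutive group of m (in place)."""
    groups = weight.view(-1, m)                   # a view, so edits modify `weight`
    scores = torch.rand_like(groups)              # random importance scores
    prune_idx = scores.topk(n, dim=1, largest=False).indices
    mask = torch.ones_like(groups)
    mask.scatter_(1, prune_idx, 0.0)
    groups.mul_(mask)
    return weight

w = torch.randn(8, 16)
random_semistructured_prune_(w, n=2, m=4)         # 2 of every 4 weights are zeroed
print((w == 0).float().mean())                    # ~0.5 sparsity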

Magnitude-based Pruning

MagnitudeDiffPruningAlgorithm

Bases: BaseAlgorithm, SimpleProfilerMixin

Implements magnitude-based pruning on the difference between pretrained and fine-tuned model parameters.

This class supports pruning the difference between the pretrained and fine-tuned model parameters based on their magnitude. It allows specifying the ratio of weights to prune and the names of parameters to extract for pruning.

Methods:

  • run(modelpool: BaseModelPool) -> nn.Module –

    Executes the pruning process on the model pool and returns the pruned model.

  • magnitude_prune(pretrained_model: nn.Module, finetuned_model: nn.Module, in_place: bool = True) -> nn.Module –

    Prunes the difference between the pretrained and fine-tuned model parameters.
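
Conceptually, for each matched parameter the algorithm forms the difference between the fine-tuned and pretrained weights, zeroes its smallest-magnitude entries (with prune_type="minor"), optionally rescales the survivors, and adds the pruned difference back onto the pretrained weights. Below is a minimal sketch of this step on one weight matrix; keep_top_magnitude is an illustrative helper, not the unstructured_magnitude_prune_ utility used in the source.

import torch

def keep_top_magnitude(delta: torch.Tensor, prune_ratio: float) -> torch.Tensor:
    """Zero the prune_ratio fraction of entries of delta with the smallest magnitude."""
    k = int(delta.numel() * prune_ratio)
    if k == 0:
        return delta
    threshold = delta.abs().flatten().kthvalue(k).values
    return torch.where(delta.abs() <= threshold, torch.zeros_like(delta), delta)

w_pre = torch.randn(4, 4)
w_ft = w_pre + 0.01 * torch.randn(4, 4)
pruned_delta = keep_top_magnitude(w_ft - w_pre, prune_ratio=0.75)  # keep ~25% largest diffs
w_new = w_pre + pruned_delta   # corresponds to `param.data = param + w_diff` in the source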

Source code in fusion_bench/method/pruning/magnitude_diff_pruning.py
class MagnitudeDiffPruningAlgorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
):
    """
    Implements magnitude-based pruning on the difference between pretrained and fine-tuned model parameters.

    This class supports pruning the difference between the pretrained and fine-tuned model parameters
    based on their magnitude. It allows specifying the ratio of weights to prune and the names of
    parameters to extract for pruning.

    Methods:
        run(modelpool: BaseModelPool) -> nn.Module:
            Executes the pruning process on the model pool and returns the pruned model.
        magnitude_prune(pretrained_model: nn.Module, finetuned_model: nn.Module, in_place: bool = True) -> nn.Module:
            Prunes the difference between the pretrained and fine-tuned model parameters.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "prune_ratio": "prune_ratio",
        "extract_names": "extract_names",
    }

    def __init__(
        self,
        prune_ratio: float,
        rescale: Optional[Union[bool, float]] = None,
        extract_names: List[str] = None,
        prune_type: Literal["minor", "major"] = "minor",
        **kwargs,
    ):
        """
        Initialize the MagnitudeDiffPruningAlgorithm with the given configuration.

        Args:
            prune_ratio (float): The ratio of weights to prune.
            extract_names (List[str], optional): List of regular expressions to match the parameter names for pruning. Defaults to None.
            **kwargs: Additional keyword arguments.
        """
        self.prune_ratio = prune_ratio
        self.rescale = rescale
        self.extract_names = extract_names
        self.prune_type = prune_type
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """
        Execute the pruning process on the model pool.

        This method loads the pretrained and fine-tuned models from the model pool,
        prunes the difference between their parameters, and returns the pruned model.

        Args:
            modelpool (BaseModelPool): The model pool containing the models to prune.

        Returns:
            nn.Module: The pruned model.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        assert (
            len(modelpool.model_names) == 1
        ), "Only one fine-tuned model is allowed in the model pool."
        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_model("_pretrained_")
        with self.profile("load fine-tuned model"):
            finetuned_model = modelpool.load_model(modelpool.model_names[0])

        with self.profile("prune model"):
            model = self.magnitude_prune(pretrained_model, finetuned_model)

        self.print_profile_summary()
        return model

    @torch.no_grad()
    def magnitude_prune(
        self,
        pretrained_model: nn.Module,
        finetuned_model: nn.Module,
        in_place: bool = True,
    ):
        """
        Prune the difference between the pretrained and fine-tuned model parameters.

        This method calculates the difference between the pretrained and fine-tuned model parameters,
        prunes the difference based on their magnitude, and updates the pretrained model parameters
        with the pruned difference.

        Args:
            pretrained_model (nn.Module): The pretrained model.
            finetuned_model (nn.Module): The fine-tuned model.
            in_place (bool, optional): Whether to perform the pruning in place. Defaults to True.

        Returns:
            nn.Module: The pruned model.
        """
        if in_place:
            model = pretrained_model
        else:
            model = deepcopy(pretrained_model)

        if self.extract_names is not None:
            extract_names: List[str] = (
                self.extract_names
            )  # regular expressions for the names of the parameters
        else:
            # extract the weight matrix of each linear layer
            extract_names = []
            for name, module in model.named_modules():
                if isinstance(module, nn.Linear):
                    extract_names.append(f"{name}.weight")

        ft_state_dict = finetuned_model.state_dict()
        for name, param in tqdm(
            model.named_parameters(),
            "Magnitude Pruning On Parameter Difference",
            total=len(tuple(model.named_parameters())),
        ):
            if not param.requires_grad:
                continue

            # Prune the diff parameter if its name matches
            if _is_name_matched(name, extract_names):
                w_diff = ft_state_dict[name] - param
                w_diff = unstructured_magnitude_prune_(
                    w_diff,
                    (
                        torch.abs
                        if self.prune_type == "minor"
                        else lambda x: -torch.abs(x)
                    ),
                    sparsity_ratio=self.prune_ratio,
                )
                if self.rescale is not None:
                    rescale = (
                        1 / self.prune_ratio if self.rescale == True else self.rescale
                    )
                    w_diff = w_diff * rescale
                param.data = param + w_diff

        return model
__init__(prune_ratio, rescale=None, extract_names=None, prune_type='minor', **kwargs)

Initialize the MagnitudeDiffPruningAlgorithm with the given configuration.

Parameters:

  • prune_ratio (float) –

    The ratio of weights to prune.

  • extract_names (List[str], default: None ) –

    List of regular expressions to match the parameter names for pruning. Defaults to None.

  • **kwargs

    Additional keyword arguments.

Source code in fusion_bench/method/pruning/magnitude_diff_pruning.py
def __init__(
    self,
    prune_ratio: float,
    rescale: Optional[Union[bool, float]] = None,
    extract_names: List[str] = None,
    prune_type: Literal["minor", "major"] = "minor",
    **kwargs,
):
    """
    Initialize the MagnitudeDiffPruningAlgorithm with the given configuration.

    Args:
        prune_ratio (float): The ratio of weights to prune.
        extract_names (List[str], optional): List of regular expressions to match the parameter names for pruning. Defaults to None.
        **kwargs: Additional keyword arguments.
    """
    self.prune_ratio = prune_ratio
    self.rescale = rescale
    self.extract_names = extract_names
    self.prune_type = prune_type
    super().__init__(**kwargs)
magnitude_prune(pretrained_model, finetuned_model, in_place=True)

Prune the difference between the pretrained and fine-tuned model parameters.

This method calculates the difference between the pretrained and fine-tuned model parameters, prunes the difference based on their magnitude, and updates the pretrained model parameters with the pruned difference.

Parameters:

  • pretrained_model (Module) –

    The pretrained model.

  • finetuned_model (Module) –

    The fine-tuned model.

  • in_place (bool, default: True ) –

    Whether to perform the pruning in place. Defaults to True.

Returns:

  • nn.Module: The pruned model.

Source code in fusion_bench/method/pruning/magnitude_diff_pruning.py
@torch.no_grad()
def magnitude_prune(
    self,
    pretrained_model: nn.Module,
    finetuned_model: nn.Module,
    in_place: bool = True,
):
    """
    Prune the difference between the pretrained and fine-tuned model parameters.

    This method calculates the difference between the pretrained and fine-tuned model parameters,
    prunes the difference based on their magnitude, and updates the pretrained model parameters
    with the pruned difference.

    Args:
        pretrained_model (nn.Module): The pretrained model.
        finetuned_model (nn.Module): The fine-tuned model.
        in_place (bool, optional): Whether to perform the pruning in place. Defaults to True.

    Returns:
        nn.Module: The pruned model.
    """
    if in_place:
        model = pretrained_model
    else:
        model = deepcopy(pretrained_model)

    if self.extract_names is not None:
        extract_names: List[str] = (
            self.extract_names
        )  # regular expressions for the names of the parameters
    else:
        # extract the weight matrix of each linear layer
        extract_names = []
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                extract_names.append(f"{name}.weight")

    ft_state_dict = finetuned_model.state_dict()
    for name, param in tqdm(
        model.named_parameters(),
        "Magnitude Pruning On Parameter Difference",
        total=len(tuple(model.named_parameters())),
    ):
        if not param.requires_grad:
            continue

        # Prune the diff parameter if its name matches
        if _is_name_matched(name, extract_names):
            w_diff = ft_state_dict[name] - param
            w_diff = unstructured_magnitude_prune_(
                w_diff,
                (
                    torch.abs
                    if self.prune_type == "minor"
                    else lambda x: -torch.abs(x)
                ),
                sparsity_ratio=self.prune_ratio,
            )
            if self.rescale is not None:
                rescale = (
                    1 / self.prune_ratio if self.rescale == True else self.rescale
                )
                w_diff = w_diff * rescale
            param.data = param + w_diff

    return model
run(modelpool)

Execute the pruning process on the model pool.

This method loads the pretrained and fine-tuned models from the model pool, prunes the difference between their parameters, and returns the pruned model.

Parameters:

  • modelpool (BaseModelPool) –

    The model pool containing the models to prune.

Returns:

  • nn.Module: The pruned model.

Source code in fusion_bench/method/pruning/magnitude_diff_pruning.py
@torch.no_grad()
def run(self, modelpool: BaseModelPool):
    """
    Execute the pruning process on the model pool.

    This method loads the pretrained and fine-tuned models from the model pool,
    prunes the difference between their parameters, and returns the pruned model.

    Args:
        modelpool (BaseModelPool): The model pool containing the models to prune.

    Returns:
        nn.Module: The pruned model.
    """
    if not isinstance(modelpool, BaseModelPool):
        modelpool = BaseModelPool(modelpool)

    assert (
        len(modelpool.model_names) == 1
    ), "Only one fine-tuned model is allowed in the model pool."
    with self.profile("load pretrained model"):
        pretrained_model = modelpool.load_model("_pretrained_")
    with self.profile("load fine-tuned model"):
        finetuned_model = modelpool.load_model(modelpool.model_names[0])

    with self.profile("prune model"):
        model = self.magnitude_prune(pretrained_model, finetuned_model)

    self.print_profile_summary()
    return model

MagnitudePruningForLlama

Bases: BaseAlgorithm, SimpleProfilerMixin

Implements magnitude-based pruning for LLama models.

This class supports both unstructured and semistructured pruning methods. It loads a pre-trained model or the first model in the pool and applies the specified pruning technique.

Methods:

  • run(modelpool: LLamaForCausalLMPool) -> nn.Module –

    Executes the pruning process on the model pool and returns the pruned model.

Source code in fusion_bench/method/pruning/llama_magnitude_prune.py
class MagnitudePruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
    """
    Implements magnitude-based pruning for LLama models.

    This class supports both unstructured and semistructured pruning methods.
    It loads a pre-trained model or the first model in the pool and applies the specified pruning technique.

    Methods:
        run(modelpool: LLamaForCausalLMPool) -> nn.Module:
            Executes the pruning process on the model pool and returns the pruned model.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "prune_type": "prune_type",
        "device": "device",
        "dtype": "dtype",
        "sparsity_ratio": "sparsity_ratio",
        "n": "n",
        "m": "m",
    }

    def __init__(
        self,
        *,
        prune_type: Literal["unstructured", "semistructured"],
        device: str,
        dtype: Optional[str],
        sparsity_ratio: float,
        n: int,
        m: int,
        **kwargs,
    ):
        self.prune_type = prune_type
        self.device = device
        self.dtype = dtype
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: CausalLMPool) -> LlamaForCausalLM:
        """
        Execute the pruning process on the first model from the given model pool.

        Args:
            modelpool (CausalLMPool): The model pool containing the models to prune.

        Returns:
            nn.Module: The pruned model.
        """
        config = self.config

        # load pre-trained model or the first model in the pool
        base_model = modelpool.load_pretrained_or_first_model()

        dtype = parse_dtype(config.dtype)
        device = torch.device(config.device)

        if config.prune_type == "unstructured":
            unstructured_magnitude_prune_(
                base_model, config.sparsity_ratio, dtype=dtype, device=device
            )
        elif config.prune_type == "semistructured":
            semistructured_magnitude_prune_(
                base_model, config.n, config.m, dtype=dtype, device=device
            )
        else:
            raise ValueError(
                f"Invalid pruning type: {config.prune_type}. "
                "Choose from 'unstructured' or 'semistructured'"
            )

        return base_model
run(modelpool)

Execute the pruning process on the first model from the given model pool.

Parameters:

  • modelpool (CausalLMPool) –

    The model pool containing the models to prune.

Returns:

  • LlamaForCausalLM

    nn.Module: The pruned model.

Source code in fusion_bench/method/pruning/llama_magnitude_prune.py
@torch.no_grad()
def run(self, modelpool: CausalLMPool) -> LlamaForCausalLM:
    """
    Execute the pruning process on the first model from the given model pool.

    Args:
        modelpool (CausalLMPool): The model pool containing the models to prune.

    Returns:
        nn.Module: The pruned model.
    """
    config = self.config

    # load pre-trained model or the first model in the pool
    base_model = modelpool.load_pretrained_or_first_model()

    dtype = parse_dtype(config.dtype)
    device = torch.device(config.device)

    if config.prune_type == "unstructured":
        unstructured_magnitude_prune_(
            base_model, config.sparsity_ratio, dtype=dtype, device=device
        )
    elif config.prune_type == "semistructured":
        semistructured_magnitude_prune_(
            base_model, config.n, config.m, dtype=dtype, device=device
        )
    else:
        raise ValueError(
            f"Invalid pruning type: {config.prune_type}. "
            "Choose from 'unstructured' or 'semistructured'"
        )

    return base_model
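
For the unstructured case, pruning a weight tensor to a given sparsity_ratio amounts to zeroing that fraction of its smallest-magnitude entries. Below is a minimal sketch on a single tensor; magnitude_prune_ is an illustrative helper, not the unstructured_magnitude_prune_ function used in the source (which also handles dtype and device placement).

import torch

@torch.no_grad()
def magnitude_prune_(weight: torch.Tensor, sparsity_ratio: float) -> torch.Tensor:
    """Zero the sparsity_ratio fraction of entries with the smallest magnitude (in place)."""
    k = int(weight.numel() * sparsity_ratio)
    if k > 0:
        threshold = weight.abs().flatten().kthvalue(k).values
        weight.masked_fill_(weight.abs() <= threshold, 0.0)
    return weight

w = torch.randn(1024, 1024)
magnitude_prune_(w, sparsity_ratio=0.5)
print((w == 0).float().mean())   # ~0.5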

Wanda

WandaPruningForLlama

Bases: BaseAlgorithm, SimpleProfilerMixin

Class for Wanda pruning for Llama models.

Source code in fusion_bench/method/pruning/llama_wanda_prune.py
class WandaPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
    """
    Class for Wanda pruning for Llama models.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "nsamples": "nsamples",
        "seed": "seed",
        "use_variant": "use_variant",
        "prune_type": "prune_type",
        "device": "device",
        "dtype": "dtype",
        "sparsity_ratio": "sparsity_ratio",
        "n": "n",
        "m": "m",
    }

    def __init__(
        self,
        *,
        nsamples: int,
        seed: int,
        use_variant: bool,
        prune_type: PruningType,
        device: str,
        dtype: str,
        sparsity_ratio: float,
        n: int,
        m: int,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the WandaPruningForLlama class.

        Args:
            nsamples (int): Number of samples for calibration.
            seed (int): Random seed.
            use_variant (bool): Whether to use a variant of the pruning method.
            prune_type (PruningType): Type of pruning to perform.
            device (str): Device to use for computation.
            dtype (str): Data type to use for computation.
            sparsity_ratio (float): Sparsity ratio for pruning.
            n (int): Number of elements to keep in semi-structured pruning.
            m (int): Number of elements in a group for semi-structured pruning.
            model_save_path (Optional[str]): Path to save the pruned model.
            **kwargs: Additional arguments.
        """
        super().__init__(**kwargs)
        self.nsamples = nsamples
        self.seed = seed
        self.use_variant = use_variant
        self.prune_type = prune_type
        self.device = device
        self.dtype = dtype
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        self.model_save_path = model_save_path

    def run(self, modelpool: CausalLMPool):
        """
        Run the pruning algorithm on the model pool.

        Args:
            modelpool (CausalLMPool): Pool of causal language models.

        Returns:
            LlamaForCausalLM: Pruned model.
        """

        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not be supported.")

        inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
            model, tokenizer
        )

        self.prune_using_calibration_data_(
            model,
            inps=inps,
            outs=outs,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        if self.model_save_path is not None:
            with timeit_context(f"Saving pruned model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model

    def _prepare_calibration_data(self, model, tokenizer):
        """
        Prepare calibration data for pruning.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """
        with timeit_context("loading calibration data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=self.nsamples,
                seed=self.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # collect input to the first layer
            inps, outs, attention_mask, position_ids = prepare_calibration_input(
                model, dataloader, self.device
            )
        return inps, outs, attention_mask, position_ids

    def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
        """
        Prepare calibration data for pruning with caching.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """

        @cache_to_disk(
            f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
        )
        def _prepare_calibration_data(model, tokenizer):
            return self._prepare_calibration_data(model, tokenizer)

        return _prepare_calibration_data(model, tokenizer)

    def prune_using_calibration_data_(
        self,
        model: LlamaForCausalLM,
        *,
        inps,
        outs,
        attention_mask,
        position_ids,
    ):
        """
        Prune the model using calibration data.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            inps: Calibration inputs.
            outs: Calibration outputs.
            attention_mask: Attention mask for calibration data.
            position_ids: Position IDs for calibration data.
        """
        layers = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            if (
                hasattr(model, "hf_device_map")
                and f"model.layers.{layer_idx}" in model.hf_device_map
            ):
                # handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
                dev = model.hf_device_map[f"model.layers.{layer_idx}"]
                inps, outs, attention_mask, position_ids = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev) if attention_mask is not None else None,
                    position_ids.to(dev) if position_ids is not None else None,
                )

            # collect the importance scores
            linear_layers = cast(
                Dict[str, nn.Linear],
                find_linear_layers(layer, layers=[nn.Linear]),
            )

            # register hooks to collect the importance scores
            def get_hook_fn(linear: nn.Linear):
                hook_fn = WandaHookFn(linear)
                return hook_fn

            hooks = {}
            handles: List[torch.utils.hooks.RemovableHandle] = []
            for name, linear in linear_layers.items():
                hook_fn = get_hook_fn(linear)
                hooks[name] = hook_fn
                handles.append(linear.register_forward_hook(hook_fn))

            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]

            # compute the importance scores and remove the hooks
            metrics = {}
            for name, hook in hooks.items():
                metrics[name] = hook.compute()
            for h in handles:
                h.remove()

            # prune the weights based on the importance scores
            if self.prune_type == PruningType.UNSTRUCTURED:
                for name, linear in linear_layers.items():
                    log.info(f"Pruning {name}")
                    unstructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        sparsity_ratio=self.sparsity_ratio,
                    )
                    self.check_sparsity(linear.weight)
            elif self.prune_type == PruningType.SEMISTRUCTURED:
                for name, linear in linear_layers.items():
                    log.info(f"Pruning {name}")
                    semistructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        n=self.n,
                        m=self.m,
                    )
                    self.check_sparsity(linear.weight)
            else:
                raise ValueError(f"Invalid pruning type: {self.prune_type}")

            # compute the input to the next layer
            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]
            inps, outs = outs, inps

    @torch.no_grad()
    def check_sparsity(self, weight: Tensor, tol: float = 0.01):
        """
        Check the sparsity of the weight tensor.

        Args:
            weight (Tensor): Weight tensor.
            tol (float): Tolerance for sparsity check.

        Raises:
            ValueError: If the pruning type is invalid.
        """
        if self.prune_type == PruningType.UNSTRUCTURED:
            assert (compute_sparsity(weight) - self.sparsity_ratio).abs() < tol
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            assert (compute_sparsity(weight) - self.n / self.m).abs() < tol
        else:
            raise ValueError(f"Invalid pruning type: {self.prune_type}")
__init__(*, nsamples, seed, use_variant, prune_type, device, dtype, sparsity_ratio, n, m, model_save_path=None, **kwargs)

Initialize the WandaPruningForLlama class.

Parameters:

  • nsamples (int) –

    Number of samples for calibration.

  • seed (int) –

    Random seed.

  • use_variant (bool) –

    Whether to use a variant of the pruning method.

  • prune_type (PruningType) –

    Type of pruning to perform.

  • device (str) –

    Device to use for computation.

  • dtype (str) –

    Data type to use for computation.

  • sparsity_ratio (float) –

    Sparsity ratio for pruning.

  • n (int) –

    Number of elements to keep in semi-structured pruning.

  • m (int) –

    Number of elements in a group for semi-structured pruning.

  • model_save_path (Optional[str], default: None ) –

    Path to save the pruned model.

  • **kwargs

    Additional arguments.

Source code in fusion_bench/method/pruning/llama_wanda_prune.py
def __init__(
    self,
    *,
    nsamples: int,
    seed: int,
    use_variant: bool,
    prune_type: PruningType,
    device: str,
    dtype: str,
    sparsity_ratio: float,
    n: int,
    m: int,
    model_save_path: Optional[str] = None,
    **kwargs,
):
    """
    Initialize the WandaPruningForLlama class.

    Args:
        nsamples (int): Number of samples for calibration.
        seed (int): Random seed.
        use_variant (bool): Whether to use a variant of the pruning method.
        prune_type (PruningType): Type of pruning to perform.
        device (str): Device to use for computation.
        dtype (str): Data type to use for computation.
        sparsity_ratio (float): Sparsity ratio for pruning.
        n (int): Number of elements to keep in semi-structured pruning.
        m (int): Number of elements in a group for semi-structured pruning.
        model_save_path (Optional[str]): Path to save the pruned model.
        **kwargs: Additional arguments.
    """
    super().__init__(**kwargs)
    self.nsamples = nsamples
    self.seed = seed
    self.use_variant = use_variant
    self.prune_type = prune_type
    self.device = device
    self.dtype = dtype
    self.sparsity_ratio = sparsity_ratio
    self.n = n
    self.m = m
    self.model_save_path = model_save_path
check_sparsity(weight, tol=0.01)

Check the sparsity of the weight tensor.

Parameters:

  • weight (Tensor) –

    Weight tensor.

  • tol (float, default: 0.01 ) –

    Tolerance for sparsity check.

Raises:

  • ValueError

    If the pruning type is invalid.

Source code in fusion_bench/method/pruning/llama_wanda_prune.py
@torch.no_grad()
def check_sparsity(self, weight: Tensor, tol: float = 0.01):
    """
    Check the sparsity of the weight tensor.

    Args:
        weight (Tensor): Weight tensor.
        tol (float): Tolerance for sparsity check.

    Raises:
        ValueError: If the pruning type is invalid.
    """
    if self.prune_type == PruningType.UNSTRUCTURED:
        assert (compute_sparsity(weight) - self.sparsity_ratio).abs() < tol
    elif self.prune_type == PruningType.SEMISTRUCTURED:
        assert (compute_sparsity(weight) - self.n / self.m).abs() < tol
    else:
        raise ValueError(f"Invalid pruning type: {self.prune_type}")
prepare_calibration_data(model, tokenizer)

Prepare calibration data for pruning with caching.

Parameters:

  • model (LlamaForCausalLM) –

    Model to be pruned.

  • tokenizer

    Tokenizer for the model.

Returns:

  • Tuple

    Calibration data (inputs, outputs, attention mask, position IDs).

Source code in fusion_bench/method/pruning/llama_wanda_prune.py
def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
    """
    Prepare calibration data for pruning with caching.

    Args:
        model (LlamaForCausalLM): Model to be pruned.
        tokenizer: Tokenizer for the model.

    Returns:
        Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
    """

    @cache_to_disk(
        f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
    )
    def _prepare_calibration_data(model, tokenizer):
        return self._prepare_calibration_data(model, tokenizer)

    return _prepare_calibration_data(model, tokenizer)
prune_using_calibration_data_(model, *, inps, outs, attention_mask, position_ids)

Prune the model using calibration data.

Parameters:

  • model (LlamaForCausalLM) –

    Model to be pruned.

  • inps

    Calibration inputs.

  • outs

    Calibration outputs.

  • attention_mask

    Attention mask for calibration data.

  • position_ids

    Position IDs for calibration data.

Source code in fusion_bench/method/pruning/llama_wanda_prune.py
def prune_using_calibration_data_(
    self,
    model: LlamaForCausalLM,
    *,
    inps,
    outs,
    attention_mask,
    position_ids,
):
    """
    Prune the model using calibration data.

    Args:
        model (LlamaForCausalLM): Model to be pruned.
        inps: Calibration inputs.
        outs: Calibration outputs.
        attention_mask: Attention mask for calibration data.
        position_ids: Position IDs for calibration data.
    """
    layers = model.model.layers
    for layer_idx, layer in tqdm(
        enumerate(layers),
        "Pruning Layers",
        total=len(layers),
        dynamic_ncols=True,
    ):
        if (
            hasattr(model, "hf_device_map")
            and f"model.layers.{layer_idx}" in model.hf_device_map
        ):
            # handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = model.hf_device_map[f"model.layers.{layer_idx}"]
            inps, outs, attention_mask, position_ids = (
                inps.to(dev),
                outs.to(dev),
                attention_mask.to(dev) if attention_mask is not None else None,
                position_ids.to(dev) if position_ids is not None else None,
            )

        # collect the importance scores
        linear_layers = cast(
            Dict[str, nn.Linear],
            find_linear_layers(layer, layers=[nn.Linear]),
        )

        # register hooks to collect the importance scores
        def get_hook_fn(linear: nn.Linear):
            hook_fn = WandaHookFn(linear)
            return hook_fn

        hooks = {}
        handles: List[torch.utils.hooks.RemovableHandle] = []
        for name, linear in linear_layers.items():
            hook_fn = get_hook_fn(linear)
            hooks[name] = hook_fn
            handles.append(linear.register_forward_hook(hook_fn))

        with torch.no_grad():
            for j in range(self.nsamples):
                outs[j] = layer(
                    inps[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]

        # compute the importance scores and remove the hooks
        metrics = {}
        for name, hook in hooks.items():
            metrics[name] = hook.compute()
        for h in handles:
            h.remove()

        # prune the weights based on the importance scores
        if self.prune_type == PruningType.UNSTRUCTURED:
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                unstructured_magnitude_prune_(
                    linear.weight.data,
                    metrics[name],
                    sparsity_ratio=self.sparsity_ratio,
                )
                self.check_sparsity(linear.weight)
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                semistructured_magnitude_prune_(
                    linear.weight.data,
                    metrics[name],
                    n=self.n,
                    m=self.m,
                )
                self.check_sparsity(linear.weight)
        else:
            raise ValueError(f"Invalid pruning type: {self.prune_type}")

        # compute the input to the next layer
        with torch.no_grad():
            for j in range(self.nsamples):
                outs[j] = layer(
                    inps[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]
        inps, outs = outs, inps
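
The loop above processes one decoder layer at a time: the cached hidden states are run through the current layer to collect importance statistics, the layer is pruned, the forward pass is repeated to produce the inputs of the next layer, and the two buffers are swapped (inps, outs = outs, inps). A toy sketch of this data flow, using stand-in linear layers instead of Llama decoder blocks:

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])  # stand-ins for decoder blocks
nsamples, seqlen, hidden = 4, 8, 16
inps = torch.randn(nsamples, seqlen, hidden)   # captured inputs to the first layer
outs = torch.zeros_like(inps)

for layer in layers:
    with torch.no_grad():
        for j in range(nsamples):
            # (collect importance scores and prune `layer` here, as in the loop above)
            outs[j] = layer(inps[j])
    inps, outs = outs, inps                    # this layer's outputs feed the next layer
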
run(modelpool)

Run the pruning algorithm on the model pool.

Parameters:

  • modelpool (CausalLMPool) –

    Pool of causal language models.

Returns:

  • LlamaForCausalLM

    Pruned model.

Source code in fusion_bench/method/pruning/llama_wanda_prune.py
def run(self, modelpool: CausalLMPool):
    """
    Run the pruning algorithm on the model pool.

    Args:
        modelpool (CausalLMPool): Pool of causal language models.

    Returns:
        LlamaForCausalLM: Pruned model.
    """

    # load pre-trained model or the first model in the pool
    with self.profile("load_model"):
        model = modelpool.load_pretrained_or_first_model()
        model.seqlen = model.config.max_position_embeddings
        tokenizer = modelpool.load_tokenizer(use_fast=False)

    if not isinstance(model, (LlamaForCausalLM,)):
        log.warning(f"Model type {type(model)} may not be supported.")

    inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
        model, tokenizer
    )

    self.prune_using_calibration_data_(
        model,
        inps=inps,
        outs=outs,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )

    if self.model_save_path is not None:
        with timeit_context(f"Saving pruned model to {self.model_save_path}"):
            tokenizer.save_pretrained(self.model_save_path)
            model.save_pretrained(self.model_save_path)
    return model
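
Wanda ranks each weight W[i, j] by |W[i, j]| multiplied by the L2 norm of the activations of input feature j over the calibration samples; the WandaHookFn hooks registered above gather the activation statistics needed for this score. Below is a minimal sketch of the metric on a single linear layer; wanda_score is an illustrative helper, not the FusionBench hook.

import torch
import torch.nn as nn

def wanda_score(linear: nn.Linear, calib_inputs: torch.Tensor) -> torch.Tensor:
    """calib_inputs: (num_tokens, in_features) activations seen by this layer."""
    act_norm = calib_inputs.norm(p=2, dim=0)   # ||X_j||_2 for each input feature j
    return linear.weight.abs() * act_norm      # broadcasts over the output dimension

layer = nn.Linear(16, 8)
x = torch.randn(256, 16)                       # calibration activations
scores = wanda_score(layer, x)                 # (8, 16) importance scores
# Unstructured pruning then zeroes the lowest-scoring weights at the target sparsity_ratio;
# the n:m variant keeps the highest-scoring weights within each group of m.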

SparseGPT

SparseGPTPruningForLlama

Bases: BaseAlgorithm, SimpleProfilerMixin

Source code in fusion_bench/method/pruning/llama_sparsegpt_prune.py
class SparseGPTPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
    def __init__(
        self,
        *,
        nsamples: int,
        seed: int,
        use_variant: bool,
        prune_type: PruningType,
        device: str,
        dtype: str,
        sparsity_ratio: float,
        n: int,
        m: int,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the SparseGPTPruningForLlama class.

        Args:
            nsamples (int): Number of samples for calibration.
            seed (int): Random seed.
            use_variant (bool): Whether to use a variant of the pruning method.
            prune_type (PruningType): Type of pruning to perform.
            device (str): Device to use for computation.
            dtype (str): Data type to use for computation.
            sparsity_ratio (float): Sparsity ratio for pruning.
            n (int): Number of elements to keep in semi-structured pruning.
            m (int): Number of elements in a group for semi-structured pruning.
            model_save_path (Optional[str]): Path to save the pruned model.
            **kwargs: Additional arguments.
        """
        super().__init__(**kwargs)
        self.nsamples = nsamples
        self.seed = seed
        self.use_variant = use_variant
        self.prune_type = prune_type
        self.device = device
        self.dtype = dtype
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        self.model_save_path = model_save_path

    def run(self, modelpool: CausalLMPool):
        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not be supported.")

        inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
            model, tokenizer
        )

        self.prune_using_calibration_data_(
            model,
            inps=inps,
            outs=outs,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        if self.model_save_path is not None:
            with timeit_context(f"Saving pruned model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model

    def _prepare_calibration_data(self, model, tokenizer):
        """
        Prepare calibration data for pruning.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """
        with timeit_context("loading calibration data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=self.nsamples,
                seed=self.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # collect input to the first layer
            inps, outs, attention_mask, position_ids = prepare_calibration_input(
                model, dataloader, self.device
            )
        return inps, outs, attention_mask, position_ids

    def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
        """
        Prepare calibration data for pruning with caching.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """

        @cache_to_disk(
            f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
        )
        def _prepare_calibration_data(model, tokenizer):
            return self._prepare_calibration_data(model, tokenizer)

        return _prepare_calibration_data(model, tokenizer)

    @torch.no_grad()
    def prune_using_calibration_data_(
        self,
        model: LlamaForCausalLM,
        *,
        inps,
        outs,
        attention_mask,
        position_ids,
    ):
        layers = model.model.layers
        for layer_indx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            layer = layers[layer_indx]
            if f"model.layers.{layer_indx}" in model.hf_device_map:
                dev = model.hf_device_map[f"model.layers.{layer_indx}"]
                print(f"layer {layer_indx} device {dev}")
                inps, outs, attention_mask, position_ids = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev),
                    position_ids.to(dev),
                )

            subset = find_linear_layers(layer, layers=[nn.Linear])

            gpts: Dict[str, SparseGPT] = {}
            for name in subset:
                gpts[name] = SparseGPT(subset[name])

            def add_batch(name):
                def tmp(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data)

                return tmp

            handles = []
            for name in gpts:
                handles.append(subset[name].register_forward_hook(add_batch(name)))

            for j in range(self.nsamples):
                outs[j] = layer(
                    inps[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]
            for h in handles:
                h.remove()

            for name in gpts:
                print(layer_indx, name)
                print("Pruning ...")

                gpts[name].fasterprune(
                    self.sparsity_ratio,
                    prune_n=self.n,
                    prune_m=self.m,
                    percdamp=0.01,
                    blocksize=128,
                )
                gpts[name].free()

            for j in range(self.nsamples):
                outs[j] = layer(
                    inps[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )[0]

            layers[layer_indx] = layer
            torch.cuda.empty_cache()

            inps, outs = outs, inps
__init__(*, nsamples, seed, use_variant, prune_type, device, dtype, sparsity_ratio, n, m, model_save_path=None, **kwargs)

Initialize the SparseGPTPruningForLlama class.

Parameters:

  • nsamples (int) –

    Number of samples for calibration.

  • seed (int) –

    Random seed.

  • use_variant (bool) –

    Whether to use a variant of the pruning method.

  • prune_type (PruningType) –

    Type of pruning to perform.

  • device (str) –

    Device to use for computation.

  • dtype (str) –

    Data type to use for computation.

  • sparsity_ratio (float) –

    Sparsity ratio for pruning.

  • n (int) –

    Number of elements to keep in semi-structured pruning.

  • m (int) –

    Number of elements in a group for semi-structured pruning.

  • model_save_path (Optional[str], default: None ) –

    Path to save the pruned model.

  • **kwargs

    Additional arguments.

Source code in fusion_bench/method/pruning/llama_sparsegpt_prune.py
def __init__(
    self,
    *,
    nsamples: int,
    seed: int,
    use_variant: bool,
    prune_type: PruningType,
    device: str,
    dtype: str,
    sparsity_ratio: float,
    n: int,
    m: int,
    model_save_path: Optional[str] = None,
    **kwargs,
):
    """
    Initialize the SparseGPTPruningForLlama class.

    Args:
        nsamples (int): Number of samples for calibration.
        seed (int): Random seed.
        use_variant (bool): Whether to use a variant of the pruning method.
        prune_type (PruningType): Type of pruning to perform.
        device (str): Device to use for computation.
        dtype (str): Data type to use for computation.
        sparsity_ratio (float): Sparsity ratio for pruning.
        n (int): Number of elements to keep in semi-structured pruning.
        m (int): Number of elements in a group for semi-structured pruning.
        model_save_path (Optional[str]): Path to save the pruned model.
        **kwargs: Additional arguments.
    """
    super().__init__(**kwargs)
    self.nsamples = nsamples
    self.seed = seed
    self.use_variant = use_variant
    self.prune_type = prune_type
    self.device = device
    self.dtype = dtype
    self.sparsity_ratio = sparsity_ratio
    self.n = n
    self.m = m
    self.model_save_path = model_save_path
prepare_calibration_data(model, tokenizer)

Prepare calibration data for pruning with caching.

Parameters:

  • model (LlamaForCausalLM) –

    Model to be pruned.

  • tokenizer

    Tokenizer for the model.

Returns:

  • Tuple

    Calibration data (inputs, outputs, attention mask, position IDs).

Source code in fusion_bench/method/pruning/llama_sparsegpt_prune.py
def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
    """
    Prepare calibration data for pruning with caching.

    Args:
        model (LlamaForCausalLM): Model to be pruned.
        tokenizer: Tokenizer for the model.

    Returns:
        Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
    """

    @cache_to_disk(
        f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
    )
    def _prepare_calibration_data(model, tokenizer):
        return self._prepare_calibration_data(model, tokenizer)

    return _prepare_calibration_data(model, tokenizer)
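
SparseGPT goes beyond plain magnitude pruning by using second-order information in the spirit of Optimal Brain Surgeon: with a layer-wise Hessian proxy H = X^T X built from the calibration inputs, the weight with the smallest saliency w_q^2 / [H^-1]_qq is removed and the remaining weights are updated to compensate. The toy sketch below removes a single weight from one row using an explicit inverse; fasterprune in the source works block-wise with a damped Cholesky factorization instead.

import torch

def obs_prune_one_weight(w: torch.Tensor, X: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
    """w: (d,) weight row; X: (n, d) calibration inputs feeding this row."""
    H = X.t() @ X                                        # layer-wise Hessian proxy
    H += percdamp * H.diag().mean() * torch.eye(len(w))  # dampening (cf. percdamp above)
    Hinv = torch.linalg.inv(H)
    saliency = w.pow(2) / Hinv.diag()
    q = int(saliency.argmin())                           # cheapest weight to remove
    w = w - (w[q] / Hinv[q, q]) * Hinv[:, q]             # compensate the remaining weights
    w[q] = 0.0
    return w

w = torch.randn(8)
X = torch.randn(32, 8)
w_pruned = obs_prune_one_weight(w, X)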

Pruning with Low-Rank Refinement

SparseLoForLlama

Bases: BaseAlgorithm, SimpleProfilerMixin

Zero-Shot SVD Algorithm
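
The LoSparseLinear modules introduced by the conversion carry a low-rank factor (lo_A, lo_B) alongside a dense weight; conceptually, a weight matrix can be decomposed into a truncated-SVD low-rank part plus a residual, with sparsity then applied to the dense part according to the selected variant. A minimal sketch of such a decomposition follows; split_low_rank_sparse is an illustrative helper, and the exact parameterization used by LoSparseLinear and extract_low_rank_part_ may differ.

import torch

def split_low_rank_sparse(weight: torch.Tensor, rank: int):
    """Split a weight into a rank-`rank` SVD approximation and a residual."""
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    lo_B = U[:, :rank] * S[:rank]     # (out_features, rank)
    lo_A = Vh[:rank]                  # (rank, in_features)
    residual = weight - lo_B @ lo_A   # candidate sparse part, pruned afterwards
    return lo_A, lo_B, residual

w = torch.randn(64, 128)
lo_A, lo_B, residual = split_low_rank_sparse(w, rank=8)
print((residual.norm() / w.norm()).item())  # energy left for the sparse part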

Source code in fusion_bench/method/sparselo/sparselo.py
class SparseLoForLlama(BaseAlgorithm, SimpleProfilerMixin):
    "Zero-Shot SVD Algorithm"

    _variants_requires_calibration_data = ["wanda"]
    _variants_hook_mapping = {"wanda": WandaHookFn}

    _config_mapping = BaseAlgorithm._config_mapping | {
        "nsamples": "nsamples",
        "seed": "seed",
        "rank": "rank",
        "sparsity_ratio": "sparsity_ratio",
        "prune_type": "prune_type",
        "n": "n",
        "m": "m",
        "device": "device",
        "variant": "variant",
    }

    def __init__(
        self,
        *,
        nsamples: int,
        variant: Literal["dense", "random", "wanda", "lowrank-only", "magnitude"],
        seed: int,
        rank: int,
        sparsity_ratio: float,
        prune_type: PruningType,
        n: int,
        m: int,
        device: Optional[str] = None,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.nsamples = nsamples
        self.variant = variant
        self.seed = seed
        self.rank = rank
        self.sparsity_ratio = sparsity_ratio
        self.prune_type = prune_type
        self.device = device
        self.model_save_path = model_save_path
        self.n = n
        self.m = m

    @override
    def run(self, modelpool: CausalLMPool):
        self.modelpool = modelpool
        if self.seed is not None:
            L.seed_everything(self.seed)

        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not supported.")

        if self.variant in self._variants_requires_calibration_data:
            inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
                model, tokenizer
            )

        model = convert_to_losparse_llama(model, rank=self.rank)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        for linear in find_linear_layers(model, layers=[LoSparseLinear]).values():
            linear = cast(LoSparseLinear, linear)
            linear.lo_A.data.zero_()
            linear.lo_B.data.zero_()
            linear.skip_lowrank = True

        match self.variant:
            case "dense":
                # this variant is a no-op, just for debugging the conversion
                pass
            case "lowrank-only":
                self.extract_low_rank_parts_(model)
                self.set_weights_to_zeros_(model)
            case "random":
                self.random_prune_(model)
            case "magnitude":
                self.magnitude_prune_(model)
            case variant if variant in self._variants_requires_calibration_data:
                self.prune_using_calibration_data_(
                    model,
                    inps=inps,
                    outs=outs,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )
            case _:
                raise ValueError(f"Invalid variant: {self.variant}")

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)

        return model

    def set_weights_to_zeros_(self, model):
        layers: nn.ModuleList = model.model.layers
        for layer in tqdm(
            list(layers),
            "Pruning Layers",
            dynamic_ncols=True,
        ):
            for name, losparse_linear in layer.named_modules():
                if isinstance(losparse_linear, LoSparseLinear):
                    log.info(f"Pruning {name}, set weights to zeros")
                    losparse_linear.weight.data.zero_()

    @torch.no_grad()
    def extract_low_rank_parts_(self, model):
        for layer in tqdm(
            list(model.model.layers),
            "Extract Low-Rank Parts (Layers)",
            dynamic_ncols=True,
        ):
            for losparse_linear in layer.modules():
                if isinstance(losparse_linear, LoSparseLinear):
                    if self.device is not None:
                        original_device = get_device(losparse_linear)
                        losparse_linear.to(self.device)
                    extract_low_rank_part_(losparse_linear, self.rank)
                    if self.device is not None:
                        losparse_linear.to(original_device)

    def _prepare_calibration_data(self, model, tokenizer):
        with timeit_context("loading calibration data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=self.nsamples,
                seed=self.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # collect input to the first layer
            inps, outs, attention_mask, position_ids = prepare_calibration_input(
                model, dataloader, self.device
            )
        return inps, outs, attention_mask, position_ids

    def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):

        @cache_to_disk(
            f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
        )
        def _prepare_calibration_data(model, tokenizer):
            return self._prepare_calibration_data(model, tokenizer)

        return _prepare_calibration_data(model, tokenizer)

    def random_prune_(self, model):
        layers: nn.ModuleList = model.model.layers
        for layer in tqdm(
            list(layers),
            "Pruning Layers",
            dynamic_ncols=True,
        ):
            for name, losparse_linear in layer.named_modules():
                if isinstance(losparse_linear, LoSparseLinear):
                    log.info(f"Pruning {name}, set weights to zeros")
                    if self.prune_type == PruningType.UNSTRUCTURED:
                        _, pruned_weights = unstructured_magnitude_prune_(
                            losparse_linear.weight.data,
                            metric_function_or_scores=torch.rand_like,
                            sparsity_ratio=self.sparsity_ratio,
                            return_pruned_weight=True,
                        )
                    elif self.prune_type == PruningType.SEMISTRUCTURED:
                        _, pruned_weights = semistructured_magnitude_prune_(
                            losparse_linear.weight.data,
                            metric_function_or_scores=torch.rand_like,
                            n=self.n,
                            m=self.m,
                            return_pruned_weight=True,
                        )
                    else:
                        raise ValueError(f"Invalid pruning type: {self.prune_type}")
                    self.check_sparsity(losparse_linear.weight)
                    self.extract_low_rank_part_using_pruned_(
                        losparse_linear, pruned_weights
                    )

    def magnitude_prune_(self, model):
        layers: nn.ModuleList = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers), "Pruning Layers", total=len(layers), dynamic_ncols=True
        ):
            for name, losparse_linear in layer.named_modules():
                if isinstance(losparse_linear, LoSparseLinear):
                    log.info(f"Magnitude Pruning {name}")
                    if self.prune_type == PruningType.UNSTRUCTURED:
                        _, pruned_weights = unstructured_magnitude_prune_(
                            losparse_linear.weight.data,
                            metric_function_or_scores=torch.abs,
                            sparsity_ratio=self.sparsity_ratio,
                            return_pruned_weight=True,
                        )
                    elif self.prune_type == PruningType.SEMISTRUCTURED:
                        _, pruned_weights = semistructured_magnitude_prune_(
                            losparse_linear.weight.data,
                            metric_function_or_scores=torch.abs,
                            n=self.n,
                            m=self.m,
                            return_pruned_weight=True,
                        )
                    else:
                        raise ValueError(f"Invalid pruning type: {self.prune_type}")
                    self.check_sparsity(losparse_linear.weight)
                    self.extract_low_rank_part_using_pruned_(
                        losparse_linear, pruned_weights
                    )

    def prune_using_calibration_data_(
        self,
        model: LoSparseLlamaForCausalLM,
        *,
        inps: Tensor,
        outs: Tensor,
        attention_mask: Optional[Tensor],
        position_ids: Optional[Tensor],
    ):
        layers = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            if (
                hasattr(model, "hf_device_map")
                and f"model.layers.{layer_idx}" in model.hf_device_map
            ):  # handle llama-30B and llama-65B, where the device map spans multiple GPUs
                dev = model.hf_device_map[f"model.layers.{layer_idx}"]
                inps, outs, attention_mask, position_ids = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev) if attention_mask is not None else None,
                    position_ids.to(dev) if position_ids is not None else None,
                )

            # collect the importance scores
            linear_layers = cast(
                Dict[str, LoSparseLinear],
                find_linear_layers(layer, layers=[LoSparseLinear]),
            )

            # register hooks to collect the importance scores
            def get_hook_fn(linear: LoSparseLinear):
                hook_fn = self._variants_hook_mapping[self.variant](linear)
                return hook_fn

            hooks = {}
            handles: List[torch.utils.hooks.RemovableHandle] = []
            for name, linear in linear_layers.items():
                hook_fn = get_hook_fn(linear)
                hooks[name] = hook_fn
                handles.append(linear.register_forward_hook(hook_fn))

            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]

            # compute the importance scores and remove the hooks
            metrics = {}
            for name, hook in hooks.items():
                metrics[name] = hook.compute()
            for h in handles:
                h.remove()

            # prune the weights based on the importance scores
            pruned_weights_dict = {}
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                if self.prune_type == PruningType.UNSTRUCTURED:
                    _, pruned_weights = unstructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        sparsity_ratio=self.sparsity_ratio,
                        return_pruned_weight=True,
                    )
                elif self.prune_type == PruningType.SEMISTRUCTURED:
                    _, pruned_weights = semistructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        n=self.n,
                        m=self.m,
                        return_pruned_weight=True,
                    )
                else:
                    raise ValueError(f"Invalid pruning type: {self.prune_type}")
                self.check_sparsity(linear.weight)
                pruned_weights_dict[name] = pruned_weights

            # compute the input to the next layer
            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]
            inps, outs = outs, inps

            # extract the low-rank parts
            for name, linear in linear_layers.items():
                log.info(f"Extracting low-rank part for {name}")
                self.extract_low_rank_part_using_pruned_(
                    linear, pruned_weights_dict[name]
                )
                linear.skip_lowrank = False

    @torch.no_grad()
    def extract_low_rank_part_using_pruned_(
        self, linear: LoSparseLinear, pruned_weight: Tensor
    ):
        assert isinstance(
            linear, LoSparseLinear
        ), f"Expected LoSparseLinear, got {type(linear)}"

        u, s, vh = cast(
            Tuple[Tensor, Tensor, Tensor],
            torch.linalg.svd(pruned_weight.float(), full_matrices=False),
        )
        v = vh.T
        uk = u[:, : self.rank]
        sk = s[: self.rank]
        vk = v[:, : self.rank]
        linear.lo_A.data = vk.T.to(linear.lo_A.dtype).contiguous()
        linear.lo_B.data = (uk * sk).to(linear.lo_B.dtype).contiguous()
        return linear

    @torch.no_grad()
    def check_sparsity(self, weight: Tensor, tol: float = 0.01):
        if self.prune_type == PruningType.UNSTRUCTURED:
            assert (compute_sparsity(weight) - self.sparsity_ratio).abs() < tol
        elif self.prune_type == PruningType.SEMISTRUCTURED:
            assert (compute_sparsity(weight) - self.n / self.m).abs() < tol
        else:
            raise ValueError(f"Invalid pruning type: {self.prune_type}")

PCPSparseLoForLlama

Bases: SparseLoForLlama

PCP with mask

Source code in fusion_bench/method/sparselo/sparselo.py
class PCPSparseLoForLlama(SparseLoForLlama):
    "PCP with mask"

    _config_mapping = SparseLoForLlama._config_mapping | {
        "num_iterations": "num_iterations",
    }

    def __init__(self, num_iterations: int, **kwargs):
        super().__init__(**kwargs)
        self.num_iterations = num_iterations

    @override
    def run(self, modelpool):
        if self.seed is not None:
            L.seed_everything(self.seed)

        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not supported.")

        if self.variant in self._variants_requires_calibration_data:
            inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
                model, tokenizer
            )

        model = convert_to_losparse_llama(model, rank=self.rank)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        for linear in find_linear_layers(model, layers=[LoSparseLinear]).values():
            linear = cast(LoSparseLinear, linear)
            linear.lo_A.data.zero_()
            linear.lo_B.data.zero_()
            linear.skip_lowrank = True

        match self.variant:
            case "dense":
                # this variant is a no-op, just for debugging the conversion
                pass
            case "lowrank-only":
                self.extract_low_rank_parts_(model)
                self.set_weights_to_zeros_(model)
            case "random":
                self.pcp_random_prune_(model)
            case "magnitude":
                self.pcp_magnitude_prune_(model)
            case variant if variant in self._variants_requires_calibration_data:
                self.pcp_prune_using_calibration_data_(
                    model,
                    inps=inps,
                    outs=outs,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )
            case _:
                raise ValueError(f"Invalid variant: {self.variant}")

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)

        return model

    @torch.no_grad()
    def pcp_random_prune_(self, model):
        layers: nn.ModuleList = model.model.layers
        for layer_idx, layer in tqdm(
            list(enumerate(layers)),
            "Pruning Layers",
            dynamic_ncols=True,
        ):
            for name, linear in layer.named_modules():
                if isinstance(linear, LoSparseLinear):
                    log.info(f"Pruning {name}, set weights to zeros")
                    W = linear.weight.data.clone()
                    if self.prune_type == PruningType.UNSTRUCTURED:
                        unstructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.rand_like,
                            sparsity_ratio=self.sparsity_ratio,
                        )
                    elif self.prune_type == PruningType.SEMISTRUCTURED:
                        semistructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.rand_like,
                            n=self.n,
                            m=self.m,
                        )
                    else:
                        raise ValueError(f"Invalid pruning type: {self.prune_type}")
                    self.check_sparsity(linear.weight)
                    mask = linear.weight != 0
                    linear.weight.data = PCP_search_with_mask(
                        W, mask, T_max=self.num_iterations
                    )
                    self.extract_low_rank_part_using_pruned_(linear, W - linear.weight)

    def pcp_magnitude_prune_(self, model):
        layers: nn.ModuleList = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers), "Pruning Layers", total=len(layers), dynamic_ncols=True
        ):
            for name, linear in layer.named_modules():
                if isinstance(linear, LoSparseLinear):
                    log.info(f"Magnitude Pruning {name}")
                    W = linear.weight.data.clone()
                    if self.prune_type == PruningType.UNSTRUCTURED:
                        unstructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.abs,
                            sparsity_ratio=self.sparsity_ratio,
                        )
                    elif self.prune_type == PruningType.SEMISTRUCTURED:
                        semistructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.abs,
                            n=self.n,
                            m=self.m,
                        )
                    else:
                        raise ValueError(f"Invalid pruning type: {self.prune_type}")
                    self.check_sparsity(linear.weight)
                    mask = linear.weight != 0
                    linear.weight.data = PCP_search_with_mask(
                        W, mask, T_max=self.num_iterations
                    )
                    self.extract_low_rank_part_using_pruned_(linear, W - linear.weight)

    def pcp_prune_using_calibration_data_(
        self,
        model: LoSparseLlamaForCausalLM,
        *,
        inps: Tensor,
        outs: Tensor,
        attention_mask: Optional[Tensor],
        position_ids: Optional[Tensor],
    ):
        layers = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            if (
                hasattr(model, "hf_device_map")
                and f"model.layers.{layer_idx}" in model.hf_device_map
            ):  # handle llama-30B and llama-65B, where the device map spans multiple GPUs
                dev = model.hf_device_map[f"model.layers.{layer_idx}"]
                inps, outs, attention_mask, position_ids = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev) if attention_mask is not None else None,
                    position_ids.to(dev) if position_ids is not None else None,
                )

            # collect the importance scores
            linear_layers = cast(
                Dict[str, LoSparseLinear],
                find_linear_layers(layer, layers=[LoSparseLinear]),
            )

            # register hooks to collect the importance scores
            def get_hook_fn(linear: LoSparseLinear):
                hook_fn = self._variants_hook_mapping[self.variant](linear)
                return hook_fn

            hooks = {}
            handles: List[torch.utils.hooks.RemovableHandle] = []
            for name, linear in linear_layers.items():
                hook_fn = get_hook_fn(linear)
                hooks[name] = hook_fn
                handles.append(linear.register_forward_hook(hook_fn))

            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]

            # compute the importance scores and remove the hooks
            metrics = {}
            for name, hook in hooks.items():
                metrics[name] = hook.compute()
            for h in handles:
                h.remove()

            # prune the weights based on the importance scores
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                W = linear.weight.data.clone()
                if self.prune_type == PruningType.UNSTRUCTURED:
                    _, pruned_weights = unstructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        sparsity_ratio=self.sparsity_ratio,
                        return_pruned_weight=True,
                    )
                elif self.prune_type == PruningType.SEMISTRUCTURED:
                    _, pruned_weights = semistructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        n=self.n,
                        m=self.m,
                        return_pruned_weight=True,
                    )
                else:
                    raise ValueError(f"Invalid pruning type: {self.prune_type}")
                self.check_sparsity(linear.weight)
                mask = linear.weight != 0
                linear.weight.data = PCP_search_with_mask(
                    W, mask, T_max=self.num_iterations
                )
                self.extract_low_rank_part_using_pruned_(linear, W - linear.weight)
                linear.skip_lowrank = False

            # compute the input to the next layer
            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]
            inps, outs = outs, inps
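
Compared with the base class, the PCP variants refine the surviving weights with PCP_search_with_mask before extracting the low-rank part, so that the residual W - weight handed to the SVD step is closer to genuinely low-rank, in the spirit of robust PCA (principal component pursuit). A minimal alternating-projection sketch of that general idea (illustrative only; masked_pcp_sketch below is not the PCP_search_with_mask implementation, and rank/T_max are hypothetical values):

import torch


def masked_pcp_sketch(W: torch.Tensor, mask: torch.Tensor, rank: int, T_max: int = 20):
    """Alternate between a rank-`rank` estimate of the residual and a mask-constrained
    sparse estimate, so that W ~= S + L with S supported on `mask` (illustrative)."""
    S = W * mask
    for _ in range(T_max):
        # Low-rank projection of whatever the sparse part does not explain.
        u, s, vh = torch.linalg.svd(W - S, full_matrices=False)
        L = (u[:, :rank] * s[:rank]) @ vh[:rank, :]
        # Sparse part: the remainder, restricted to the allowed support.
        S = (W - L) * mask
    return S, L


torch.manual_seed(0)
W = torch.randn(128, 128)
mask = torch.rand_like(W) > 0.5
S, L = masked_pcp_sketch(W, mask, rank=16)
print(((W - (S + L)).norm() / W.norm()).item())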

IterativeSparseLoForLlama

Bases: SparseLoForLlama

Iterative Weight Update

Source code in fusion_bench/method/sparselo/sparselo.py
class IterativeSparseLoForLlama(SparseLoForLlama):
    "Iterative Weight Update"

    _config_mapping = SparseLoForLlama._config_mapping | {
        "num_iterations": "num_iterations",
    }

    def __init__(
        self, num_iterations: int, use_reference_model: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_iterations = num_iterations
        self.use_reference_model = use_reference_model

    @override
    def run(self, modelpool):
        self.modelpool = modelpool
        if self.seed is not None:
            L.seed_everything(self.seed)

        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not supported.")

        if self.variant in self._variants_requires_calibration_data:
            inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
                model, tokenizer
            )

        model = convert_to_losparse_llama(model, rank=self.rank)
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        for linear in find_linear_layers(model, layers=[LoSparseLinear]).values():
            linear = cast(LoSparseLinear, linear)
            linear.lo_A.data.zero_()
            linear.lo_B.data.zero_()
            linear.skip_lowrank = True

        match self.variant:
            case "dense":
                # this variant is a no-op, just for debugging the conversion
                pass
            case "lowrank-only":
                self.extract_low_rank_parts_(model)
                self.set_weights_to_zeros_(model)
            case "random":
                self.iterative_random_prune_(model)
            case "magnitude":
                self.iterative_magnitude_prune_(model)
            case variant if variant in self._variants_requires_calibration_data:
                self.iterative_prune_using_calibration_data_(
                    model,
                    inps=inps,
                    outs=outs,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                )
            case _:
                raise ValueError(f"Invalid variant: {self.variant}")

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)

        return model

    @torch.no_grad()
    def iterative_random_prune_(self, model):
        layers: nn.ModuleList = model.model.layers
        for layer_idx, layer in tqdm(
            list(enumerate(layers)),
            "Pruning Layers",
            dynamic_ncols=True,
        ):
            for name, linear in layer.named_modules():
                if isinstance(linear, LoSparseLinear):
                    log.info(f"Pruning {name}, set weights to zeros")
                    W = linear.weight.data.clone()
                    if self.prune_type == PruningType.UNSTRUCTURED:
                        unstructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.rand_like,
                            sparsity_ratio=self.sparsity_ratio,
                        )
                    elif self.prune_type == PruningType.SEMISTRUCTURED:
                        semistructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.rand_like,
                            n=self.n,
                            m=self.m,
                        )
                    else:
                        raise ValueError(f"Invalid pruning type: {self.prune_type}")
                    self.check_sparsity(linear.weight)
                    mask = linear.weight != 0
                    for rank in tqdm(
                        np.linspace(1, self.rank, self.num_iterations, dtype=np.int64),
                        "Iterative Pruning",
                        leave=False,
                        dynamic_ncols=True,
                    ):
                        linear.weight.data, spectrum_ratio = iterative_weight_update(
                            W,
                            linear.weight,
                            mask,
                            rank=rank,
                        )
                        if spectrum_ratio > 0.99:
                            break
                    self.extract_low_rank_part_using_pruned_(linear, W - linear.weight)

    @torch.no_grad()
    def iterative_magnitude_prune_(self, model):
        layers: nn.ModuleList = model.model.layers
        if self.use_reference_model:
            reference_model = self.modelpool.load_model(
                "reference_model", torch_dtype="float16"
            )
            reference_layers: nn.ModuleList = reference_model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers), "Pruning Layers", total=len(layers), dynamic_ncols=True
        ):
            for name, linear in layer.named_modules():
                if isinstance(linear, LoSparseLinear):
                    log.info(f"Magnitude Pruning {name}")
                    W = (
                        linear.weight.data.clone()
                        if not self.use_reference_model
                        else reference_layers[layer_idx]
                        .get_submodule(name)
                        .weight.data.clone()
                        .to(linear.weight.data.device)
                    )
                    if self.prune_type == PruningType.UNSTRUCTURED:
                        unstructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.abs,
                            sparsity_ratio=self.sparsity_ratio,
                        )
                    elif self.prune_type == PruningType.SEMISTRUCTURED:
                        semistructured_magnitude_prune_(
                            linear.weight.data,
                            metric_function_or_scores=torch.abs,
                            n=self.n,
                            m=self.m,
                        )
                    else:
                        raise ValueError(f"Invalid pruning type: {self.prune_type}")
                    self.check_sparsity(linear.weight)
                    mask = linear.weight != 0
                    for rank in tqdm(
                        np.linspace(1, self.rank, self.num_iterations, dtype=np.int64),
                        "Iterative Pruning",
                        leave=False,
                        dynamic_ncols=True,
                    ):
                        linear.weight.data, spectrum_ratio = iterative_weight_update(
                            W,
                            linear.weight,
                            mask,
                            rank=rank,
                        )
                        if spectrum_ratio > 0.99:
                            break
                    self.extract_low_rank_part_using_pruned_(linear, W - linear.weight)

    @torch.no_grad()
    def iterative_prune_using_calibration_data_(
        self,
        model: LoSparseLlamaForCausalLM,
        *,
        inps: Tensor,
        outs: Tensor,
        attention_mask: Optional[Tensor],
        position_ids: Optional[Tensor],
    ):
        layers = model.model.layers
        for layer_idx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            if (
                hasattr(model, "hf_device_map")
                and f"model.layers.{layer_idx}" in model.hf_device_map
            ):  # handle llama-30B and llama-65B, where the device map spans multiple GPUs
                dev = model.hf_device_map[f"model.layers.{layer_idx}"]
                inps, outs, attention_mask, position_ids = (
                    inps.to(dev),
                    outs.to(dev),
                    attention_mask.to(dev) if attention_mask is not None else None,
                    position_ids.to(dev) if position_ids is not None else None,
                )

            # collect the importance scores
            linear_layers = cast(
                Dict[str, LoSparseLinear],
                find_linear_layers(layer, layers=[LoSparseLinear]),
            )

            # register hooks to collect the importance scores
            def get_hook_fn(linear: LoSparseLinear):
                hook_fn = self._variants_hook_mapping[self.variant](linear)
                return hook_fn

            hooks = {}
            handles: List[torch.utils.hooks.RemovableHandle] = []
            for name, linear in linear_layers.items():
                hook_fn = get_hook_fn(linear)
                hooks[name] = hook_fn
                handles.append(linear.register_forward_hook(hook_fn))

            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]

            # compute the importance scores and remove the hooks
            metrics = {}
            for name, hook in hooks.items():
                metrics[name] = hook.compute()
            for h in handles:
                h.remove()

            # prune the weights based on the importance scores
            for name, linear in linear_layers.items():
                log.info(f"Pruning {name}")
                W = linear.weight.data.clone()
                if self.prune_type == PruningType.UNSTRUCTURED:
                    _, pruned_weights = unstructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        sparsity_ratio=self.sparsity_ratio,
                        return_pruned_weight=True,
                    )
                elif self.prune_type == PruningType.SEMISTRUCTURED:
                    _, pruned_weights = semistructured_magnitude_prune_(
                        linear.weight.data,
                        metrics[name],
                        n=self.n,
                        m=self.m,
                        return_pruned_weight=True,
                    )
                else:
                    raise ValueError(f"Invalid pruning type: {self.prune_type}")
                self.check_sparsity(linear.weight)
                mask = linear.weight != 0
                for rank in tqdm(
                    np.linspace(1, self.rank, self.num_iterations, dtype=np.int64),
                    "Iterative Pruning",
                    leave=False,
                    dynamic_ncols=True,
                ):
                    linear.weight.data, spectrum_ratio = iterative_weight_update(
                        W,
                        linear.weight,
                        mask,
                        rank=rank,
                    )
                    if spectrum_ratio > 0.99:
                        break
                self.extract_low_rank_part_using_pruned_(linear, W - linear.weight)
                linear.skip_lowrank = False

            # compute the input to the next layer
            with torch.no_grad():
                for j in range(self.nsamples):
                    outs[j] = layer(
                        inps[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]
            inps, outs = outs, inps
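
The iterative variants ramp the rank up gradually: at each step iterative_weight_update re-fits the masked sparse weights against the dense reference W, and the loop stops early once the top-rank singular values already capture more than 99% of the residual's spectrum. One plausible form of such an update, sketched for illustration (iterative_weight_update itself may differ; the helper name and defaults below are hypothetical):

import numpy as np
import torch


def rank_ramped_update(W, W_sparse, mask, max_rank=64, num_iterations=8, tol=0.99):
    """Gradually raise the rank of the low-rank residual estimate and re-fit the
    masked sparse part against the dense reference W (illustrative sketch)."""
    for rank in np.linspace(1, max_rank, num_iterations, dtype=np.int64):
        u, s, vh = torch.linalg.svd(W - W_sparse, full_matrices=False)
        spectrum_ratio = s[:rank].sum() / s.sum()    # energy captured by the top-`rank` values
        L = (u[:, :rank] * s[:rank]) @ vh[:rank, :]  # current low-rank estimate of the residual
        W_sparse = (W - L) * mask                    # re-fit the sparse part on its support
        if spectrum_ratio > tol:
            break
    return W_sparse, float(spectrum_ratio)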

MoE Expert Pruning

DynamicSkippingPruningForMixtral

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
@auto_register_config
class DynamicSkippingPruningForMixtral(
    fb.BaseAlgorithm,
    fb.mixins.LightningFabricMixin,
    fb.mixins.SimpleProfilerMixin,
):
    modelpool: fb.modelpool.CausalLMPool

    def __init__(
        self,
        calib_set: str,
        max_block_size: int,
        n_blocks_for_stat: int,
        batch_size: int,
        num_workers: int,
        num_preserved_experts: int,
        seed: int = 42,
        model_save_path: str = R"{log_dir}/pruned_model",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_save_path = model_save_path
        self.calib_set = calib_set
        self.max_block_size = max_block_size
        self.n_blocks_for_stat = n_blocks_for_stat
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_preserved_experts = num_preserved_experts

    def run(self, modelpool: fb.modelpool.CausalLMPool):
        """
        Args:
            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
        """
        self.modelpool = modelpool
        # set random seed
        if self.seed is not None:
            L.seed_everything(self.seed)
        # parse model_save_path
        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

        with self.profile("load model"):
            model = modelpool.load_pretrained_or_first_model()
            tokenizer = modelpool.load_tokenizer()

        # Load the calibration data
        with self.profile("load calibration data"):
            calib_loader = build_calib_loader(
                self.calib_set,
                tokenizer=tokenizer,
                max_block_size=self.max_block_size,
                n_blocks_for_stat=self.n_blocks_for_stat,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                seed=self.seed,
            )

        with self.profile("prune model"):
            model, info = dynamic_skipping(
                model,
                calib_loader,
                batch_size=self.batch_size,
            )

        if self.model_save_path is not None:
            with self.profile("save model"):
                modelpool.save_model(
                    model,
                    path=self.model_save_path,
                    tokenizer=tokenizer,
                )
                torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

        self.print_profile_summary()
        return model
run(modelpool)

Parameters:

  • modelpool (CausalLMPool) – The model pool to run the algorithm on. Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml

Source code in fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
def run(self, modelpool: fb.modelpool.CausalLMPool):
    """
    Args:
        modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
            Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
    """
    self.modelpool = modelpool
    # set random seed
    if self.seed is not None:
        L.seed_everything(self.seed)
    # parse model_save_path
    self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

    with self.profile("load model"):
        model = modelpool.load_pretrained_or_first_model()
        tokenizer = modelpool.load_tokenizer()

    # Load the calibration data
    with self.profile("load calibration data"):
        calib_loader = build_calib_loader(
            self.calib_set,
            tokenizer=tokenizer,
            max_block_size=self.max_block_size,
            n_blocks_for_stat=self.n_blocks_for_stat,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            seed=self.seed,
        )

    with self.profile("prune model"):
        model, info = dynamic_skipping(
            model,
            calib_loader,
            batch_size=self.batch_size,
        )

    if self.model_save_path is not None:
        with self.profile("save model"):
            modelpool.save_model(
                model,
                path=self.model_save_path,
                tokenizer=tokenizer,
            )
            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

    self.print_profile_summary()
    return model
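
A hypothetical end-to-end invocation, using the constructor arguments shown above (the calibration settings and save path are illustrative values, and the model pool construction is omitted; in practice both the algorithm and the pool are typically instantiated from the referenced YAML configs, and a Lightning Fabric log directory must be available for log_dir-based paths):

from fusion_bench.method.expert_sparsity.mixtral.dynamic_skipping import (
    DynamicSkippingPruningForMixtral,
)

# A fusion_bench.modelpool.CausalLMPool for a Mixtral checkpoint, e.g. built from
# config/modelpool/CausalLMPool/mixtral-8x7b.yaml (construction omitted in this sketch).
modelpool = ...

algorithm = DynamicSkippingPruningForMixtral(
    calib_set="c4",                # calibration dataset name (assumed value)
    max_block_size=2048,           # assumed value
    n_blocks_for_stat=32,          # assumed value
    batch_size=4,
    num_workers=4,
    num_preserved_experts=6,       # stored by __init__ but not used by dynamic skipping itself
    seed=42,
    model_save_path="outputs/mixtral_dynamic_skipping",
)
pruned_model = algorithm.run(modelpool)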

ProgressivePruningForMixtral

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
@auto_register_config
class ProgressivePruningForMixtral(
    fb.BaseAlgorithm,
    fb.mixins.LightningFabricMixin,
    fb.mixins.SimpleProfilerMixin,
):
    modelpool: fb.modelpool.CausalLMPool

    def __init__(
        self,
        calib_set: str,
        max_block_size: int,
        n_blocks_for_stat: int,
        batch_size: int,
        num_workers: int,
        num_preserved_experts: int,
        seed: int = 42,
        model_save_path: str = R"{log_dir}/pruned_model",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_save_path = model_save_path
        self.calib_set = calib_set
        self.max_block_size = max_block_size
        self.n_blocks_for_stat = n_blocks_for_stat
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_preserved_experts = num_preserved_experts

    def run(self, modelpool: fb.modelpool.CausalLMPool):
        """
        Args:
            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
        """
        self.modelpool = modelpool
        # set random seed
        if self.seed is not None:
            L.seed_everything(self.seed)
        # parse model_save_path
        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

        with self.profile("load model"):
            model = modelpool.load_pretrained_or_first_model()
            tokenizer = modelpool.load_tokenizer()

        # Load the calibration data
        with self.profile("load calibration data"):
            calib_loader = build_calib_loader(
                self.calib_set,
                tokenizer=tokenizer,
                max_block_size=self.max_block_size,
                n_blocks_for_stat=self.n_blocks_for_stat,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                seed=self.seed,
            )

        with self.profile("prune model"):
            model, info = progressive_pruning(
                model,
                calib_loader,
                r=self.num_preserved_experts,
            )

        if self.model_save_path is not None:
            with self.profile("save model"):
                modelpool.save_model(
                    model,
                    path=self.model_save_path,
                    tokenizer=tokenizer,
                )
                torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

        self.print_profile_summary()
        return model
run(modelpool)

Parameters:

  • modelpool (CausalLMPool) – The model pool to run the algorithm on. Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml

Source code in fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
def run(self, modelpool: fb.modelpool.CausalLMPool):
    """
    Args:
        modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
            Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
    """
    self.modelpool = modelpool
    # set random seed
    if self.seed is not None:
        L.seed_everything(self.seed)
    # parse model_save_path
    self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

    with self.profile("load model"):
        model = modelpool.load_pretrained_or_first_model()
        tokenizer = modelpool.load_tokenizer()

    # Load the calibration data
    with self.profile("load calibration data"):
        calib_loader = build_calib_loader(
            self.calib_set,
            tokenizer=tokenizer,
            max_block_size=self.max_block_size,
            n_blocks_for_stat=self.n_blocks_for_stat,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            seed=self.seed,
        )

    with self.profile("prune model"):
        model, info = progressive_pruning(
            model,
            calib_loader,
            r=self.num_preserved_experts,
        )

    if self.model_save_path is not None:
        with self.profile("save model"):
            modelpool.save_model(
                model,
                path=self.model_save_path,
                tokenizer=tokenizer,
            )
            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

    self.print_profile_summary()
    return model

LayerWisePruningForMixtral

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
@auto_register_config
class LayerWisePruningForMixtral(
    fb.BaseAlgorithm,
    fb.mixins.LightningFabricMixin,
    fb.mixins.SimpleProfilerMixin,
):
    modelpool: fb.modelpool.CausalLMPool

    def __init__(
        self,
        calib_set: str,
        max_block_size: int,
        n_blocks_for_stat: int,
        batch_size: int,
        num_workers: int,
        num_preserved_experts: int,
        seed: int = 42,
        model_save_path: str = R"{log_dir}/pruned_model",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_save_path = model_save_path
        self.calib_set = calib_set
        self.max_block_size = max_block_size
        self.n_blocks_for_stat = n_blocks_for_stat
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.num_preserved_experts = num_preserved_experts

    def run(self, modelpool: fb.modelpool.CausalLMPool):
        """
        Args:
            modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
                Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
        """
        self.modelpool = modelpool
        # set random seed
        if self.seed is not None:
            L.seed_everything(self.seed)
        # parse model_save_path
        self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

        with self.profile("load model"):
            model = modelpool.load_pretrained_or_first_model()
            tokenizer = modelpool.load_tokenizer()

        # Load the calibration data
        with self.profile("load calibration data"):
            calib_loader = build_calib_loader(
                self.calib_set,
                tokenizer=tokenizer,
                max_block_size=self.max_block_size,
                n_blocks_for_stat=self.n_blocks_for_stat,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                seed=self.seed,
            )

        with self.profile("prune model"):
            model, info = layerwise_pruning(
                model,
                calib_loader,
                r=self.num_preserved_experts,
            )

        if self.model_save_path is not None:
            with self.profile("save model"):
                modelpool.save_model(
                    model,
                    path=self.model_save_path,
                    tokenizer=tokenizer,
                )
                torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

        self.print_profile_summary()
        return model
run(modelpool)

Parameters:

  • modelpool (CausalLMPool) – The model pool to run the algorithm on. Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml

Source code in fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
def run(self, modelpool: fb.modelpool.CausalLMPool):
    """
    Args:
        modelpool (fb.modelpool.CausalLMPool): The model pool to run the algorithm on.
            Example Config: config/modelpool/CausalLMPool/mixtral-8x7b.yaml
    """
    self.modelpool = modelpool
    # set random seed
    if self.seed is not None:
        L.seed_everything(self.seed)
    # parse model_save_path
    self.model_save_path = self.model_save_path.format(log_dir=self.log_dir)

    with self.profile("load model"):
        model = modelpool.load_pretrained_or_first_model()
        tokenizer = modelpool.load_tokenizer()

    # Load the calibration data
    with self.profile("load calibration data"):
        calib_loader = build_calib_loader(
            self.calib_set,
            tokenizer=tokenizer,
            max_block_size=self.max_block_size,
            n_blocks_for_stat=self.n_blocks_for_stat,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            seed=self.seed,
        )

    with self.profile("prune model"):
        model, info = layerwise_pruning(
            model,
            calib_loader,
            r=self.num_preserved_experts,
        )

    if self.model_save_path is not None:
        with self.profile("save model"):
            modelpool.save_model(
                model,
                path=self.model_save_path,
                tokenizer=tokenizer,
            )
            torch.save(info, os.path.join(self.log_dir, "pruning_info.pt"))

    self.print_profile_summary()
    return model
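
Progressive and layer-wise pruning both reduce each MoE layer to r = num_preserved_experts experts based on statistics gathered from the calibration loader. One simple statistic of this kind is routing frequency: count how often each expert appears in the router's top-k selection and keep the r most-used experts. The sketch below illustrates only that selection step (the actual criteria inside progressive_pruning and layerwise_pruning may differ; the function name and values are hypothetical):

import torch


def select_experts_by_routing_frequency(router_logits: torch.Tensor, r: int, top_k: int = 2):
    """router_logits: (num_tokens, num_experts) router scores collected on calibration data.
    Returns the indices of the `r` experts most frequently routed to (top-k per token)."""
    top_k_experts = router_logits.topk(top_k, dim=-1).indices  # (num_tokens, top_k)
    counts = torch.bincount(top_k_experts.flatten(), minlength=router_logits.shape[-1])
    return counts.topk(r).indices.sort().values


# Example: 8 experts, keep 4, Mixtral-style top-2 routing.
torch.manual_seed(0)
logits = torch.randn(1000, 8)
print(select_experts_by_routing_frequency(logits, r=4))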