Skip to content

Model Utilities

Type Definitions

fusion_bench.utils.type

StateDictType = Dict[str, Tensor] module-attribute

BoolStateDictType = Dict[str, torch.BoolTensor] module-attribute

TorchModelType = TypeVar('TorchModelType', bound=(nn.Module)) module-attribute

Parameter Count and Manipulation

fusion_bench.utils.parameters

check_parameters_all_equal(list_of_param_names)

Checks if all models have the same parameters.

This function takes a list of parameter names or state dictionaries from different models. It checks if all models have the same parameters by comparing the parameter names. If any model has different parameters, it raises a ValueError with the differing parameters.

Parameters:

  • list_of_param_names (List[Union[StateDictType, Module, List[str]]]) –

    A list of parameter-name lists, state dictionaries, or modules, one entry per model.

Raises:

  • ValueError

    If any model has different parameters.

Returns:

  • None

    None

Source code in fusion_bench/utils/parameters.py
def check_parameters_all_equal(
    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
) -> None:
    """
    Checks if all models have the same parameters.

    Each entry describes one model and may be a list of parameter names, a
    state dictionary (any ``Mapping``), or an ``nn.Module``. The kind of the
    first entry decides how every entry is turned into a set of names. The
    name sets are then compared against the first one; order is ignored.

    Args:
        list_of_param_names (List[Union[StateDictType, nn.Module, List[str]]]):
            One entry per model. An empty list is accepted and passes trivially.

    Raises:
        ValueError: If any model has a different set of parameter names. The
            message contains the symmetric difference with the first model.

    Returns:
        None
    """
    if not list_of_param_names:
        # nothing to compare
        return

    first = list_of_param_names[0]
    if isinstance(first, Mapping):
        name_sets = [set(item.keys()) for item in list_of_param_names]
    elif isinstance(first, nn.Module):
        name_sets = [set(item.state_dict().keys()) for item in list_of_param_names]
    else:
        name_sets = [set(item) for item in list_of_param_names]

    # The comparison must run for every input kind. It used to live inside the
    # final `else` branch, so mappings and modules were never actually checked.
    parameter_names = name_sets[0]
    for current_parameter_names in name_sets[1:]:
        if current_parameter_names != parameter_names:
            raise ValueError(
                "Differing parameter names in models. "
                f"The different parameters are {parameter_names.symmetric_difference(current_parameter_names)}"
            )

count_parameters(module, non_zero_only=False)

Counts the number of trainable and total parameters in a PyTorch model.

Parameters:

  • module (Module) –

    The PyTorch model for which to count parameters.

  • non_zero_only (bool, default: False ) –

    If True, only non-zero parameters are counted. If False, all parameters are counted. Defaults to False.

Returns:

  • tuple ( tuple[int, int] ) –

    A tuple containing the number of trainable parameters and the total number of parameters.

Examples:

```python
# Count the parameters
trainable_params, all_params = count_parameters(model)
```
Source code in fusion_bench/utils/parameters.py
@torch.no_grad()
def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[int, int]:
    """
    Counts the number of trainable and total parameters in a PyTorch model.

    Args:
        module (nn.Module): The PyTorch model for which to count parameters.
        non_zero_only (bool, optional): Forwarded to `_numel`; if True, only non-zero
            elements are counted. Defaults to False.

    Returns:
        tuple: `(trainable_params, all_params)`, where the first entry only includes
            parameters with `requires_grad=True`.

    Examples:

        ```python
        # Count the parameters
        trainable_params, all_params = count_parameters(model)
        ```
    """
    total = 0
    trainable = 0

    for _, parameter in module.named_parameters():
        size = _numel(parameter, non_zero_only)
        total += size
        # frozen parameters contribute to the total only
        if parameter.requires_grad:
            trainable += size

    return trainable, total

get_parameter_statistics(module_or_state_dict, model_wise=False)

Get statistics of the parameters in a PyTorch model or state dictionary.

Parameters:

  • module_or_state_dict (Union[Module, StateDictType]) –

    The PyTorch model or state dictionary for which to get parameter statistics.

Returns:

  • dict ( dict ) –

    A dictionary containing the mean, standard deviation, min, and max of the parameters.

Source code in fusion_bench/utils/parameters.py
@torch.no_grad()
def get_parameter_statistics(
    module_or_state_dict: Union[nn.Module, StateDictType],
    model_wise: bool = False,
) -> dict:
    """
    Get statistics of the parameters in a PyTorch model or state dictionary.

    Args:
        module_or_state_dict (Union[nn.Module, StateDictType]): The model or state
            dictionary to analyse. A module is read through its `state_dict()`.
        model_wise (bool, optional): If True, all tensors are flattened into a single
            vector via `state_dict_to_vector` and reported under the single key
            `"model"`. If False, one entry per state-dict key is returned.
            Defaults to False.

    Returns:
        dict: Maps each name to a dict with the keys `"mean"`, `"std"`, `"min"` and
            `"max"`, each a Python scalar.
    """
    if isinstance(module_or_state_dict, nn.Module):
        tensors = module_or_state_dict.state_dict()
    else:
        tensors = module_or_state_dict

    if model_wise:
        # collapse everything into one vector so the statistics cover the whole model
        tensors = {"model": state_dict_to_vector(tensors)}

    summary = {}
    for key, tensor in tensors.items():
        values = tensor.data
        summary[key] = dict(
            mean=values.mean().item(),
            std=values.std().item(),
            min=values.min().item(),
            max=values.max().item(),
        )
    return summary

get_parameter_summary(module_or_state_dict, non_zero_only=False)

Get a summary of the parameters in a PyTorch model.

Source code in fusion_bench/utils/parameters.py
@torch.no_grad()
def get_parameter_summary(
    module_or_state_dict: Union[nn.Module, StateDictType], non_zero_only: bool = False
) -> dict:
    """
    Get a summary of the parameters in a PyTorch model.

    Args:
        module_or_state_dict (Union[nn.Module, StateDictType]): A module (read via
            `state_dict(keep_vars=True)`) or an already materialised state dictionary.
        non_zero_only (bool, optional): Forwarded to `_numel` for the parameter
            counts. The byte count always passes `non_zero_only=False`.
            Defaults to False.

    Returns:
        dict: A dictionary with the keys `"trainable_params"`, `"all_param"` and
            `"bytes"`.
    """
    if isinstance(module_or_state_dict, nn.Module):
        entries = module_or_state_dict.state_dict(keep_vars=True)
    else:
        entries = module_or_state_dict

    total_count = 0
    trainable_count = 0
    total_bytes = 0

    for _, tensor in entries.items():
        count = _numel(tensor, non_zero_only)
        # storage size is computed separately with non_zero_only=False
        total_bytes += _numel(tensor, non_zero_only=False) * tensor.element_size()

        total_count += count
        if tensor.requires_grad:
            trainable_count += count

    return {
        "trainable_params": trainable_count,
        "all_param": total_count,
        "bytes": total_bytes,
    }

human_readable(num)

Converts a number into a human-readable string with appropriate magnitude suffix.

Examples:

```python
print(human_readable(1500))
# Output: '1.50K'
print(human_readable(1500000))
# Output: '1.50M'
```

Parameters:

  • num (int) –

    The number to convert.

Returns:

  • str ( str ) –

    The human-readable string representation of the number.

Source code in fusion_bench/utils/parameters.py
def human_readable(num: int) -> str:
    """
    Converts a number into a human-readable string with appropriate magnitude suffix.

    Integers whose absolute value is below 1000 are returned unchanged via `str`.
    Everything else is scaled down by factors of 1000 and formatted with two
    decimals and one of the suffixes `K`, `M`, `B`, `T`, `P`. Values beyond the
    largest suffix keep the `P` suffix (e.g. `10**18` becomes `'1000.00P'`).

    Examples:

        ```python
        print(human_readable(1500))
        # Output: '1.50K'
        print(human_readable(1500000))
        # Output: '1.50M'
        print(human_readable(-5000))
        # Output: '-5.00K'
        ```

    Args:
        num (int): The number to convert.

    Returns:
        str: The human-readable string representation of the number.
    """
    suffixes = ("", "K", "M", "B", "T", "P")

    # Use the magnitude for the shortcut so large negative integers are scaled
    # as well (the scaling loop below already works on abs(num)).
    if isinstance(num, int) and abs(num) < 1000:
        return str(num)

    magnitude = 0
    # Stop at the last suffix instead of running off the end of the table.
    while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
        magnitude += 1
        num /= 1000.0
    return "%.2f%s" % (num, suffixes[magnitude])

print_parameters(module, is_human_readable=True, print_fn=print, non_zero_only=False)

Prints the number of trainable and total parameters in a PyTorch model.

Parameters:

  • module (Module) –

    The PyTorch model for which to print parameters.

  • is_human_readable (bool, default: True ) –

    If True, the parameter counts are converted to a human-readable format (e.g., '1.5M' instead of '1500000'). Defaults to True.

  • print_fn (Callable, default: print ) –

    Function used to print the message.

  • non_zero_only (bool, default: False ) –

    If True, only non-zero elements are counted. If False, all elements are counted. Defaults to False.

Prints

The number of trainable parameters, the total number of parameters, and the percentage of trainable parameters in the model.

Source code in fusion_bench/utils/parameters.py
def print_parameters(
    module: nn.Module,
    is_human_readable: bool = True,
    print_fn=print,
    non_zero_only: bool = False,
):
    """
    Prints the number of trainable and total parameters in a PyTorch model.

    Args:
        module (nn.Module): The PyTorch model for which to print parameters.
        is_human_readable (bool, optional): If True, the parameter counts are converted to a human-readable format (e.g., '1.50M' instead of '1500000'). Defaults to True.
        print_fn (Callable): Function used to print the message. Defaults to `print`.
        non_zero_only (bool, optional): If True, only non-zero elements are counted. If False, all elements are counted. Defaults to False.

    Prints:
        The number of trainable parameters, the total number of parameters, and the percentage of trainable parameters in the model.
        For a module without any parameters the percentage is reported as 0.
    """
    trainable_params, all_param = count_parameters(module, non_zero_only=non_zero_only)
    # A module without parameters (or with only zeros when non_zero_only=True)
    # would otherwise raise ZeroDivisionError here.
    if all_param > 0:
        trainable_ratio = 100 * trainable_params / all_param
    else:
        trainable_ratio = 0.0

    if is_human_readable:
        trainable_params = human_readable(trainable_params)
        all_param = human_readable(all_param)

    print_fn(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_ratio:.4f}"
    )

state_dict_to_vector(state_dict, remove_keys=None)

Convert a state dictionary to a vector.

Parameters:

  • state_dict (Union[dict[str, Tensor], Module]) –

    The state dictionary to convert.

  • remove_keys (list, default: None ) –

    List of keys to remove from the state dictionary. Defaults to [].

Returns:

  • torch.Tensor: The converted vector.

Source code in fusion_bench/utils/parameters.py
def state_dict_to_vector(
    state_dict: Union[StateDictType, nn.Module],
    remove_keys: Optional[List[str]] = None,
):
    """
    Convert a state dictionary to a vector.

    The entries are ordered by key (sorted), each tensor is flattened, and the
    pieces are concatenated with `nn.utils.parameters_to_vector`.

    Args:
        state_dict (Union[dict[str, torch.Tensor], nn.Module]): The state dictionary
            to convert. A module is read through its `state_dict()`.
        remove_keys (list, optional): Keys to leave out of the vector. Keys that are
            not present are ignored. Defaults to None (nothing is removed).

    Returns:
        torch.Tensor: The converted vector.
    """
    excluded = [] if remove_keys is None else remove_keys

    if isinstance(state_dict, nn.Module):
        entries = state_dict.state_dict()
    else:
        # shallow copy so the caller's dictionary keeps all of its keys
        entries = copy.copy(state_dict)

    for name in excluded:
        entries.pop(name, None)

    # sorting by key gives a deterministic layout of the vector
    flat_chunks = [tensor.reshape(-1) for _, tensor in sorted(entries.items())]
    return nn.utils.parameters_to_vector(flat_chunks)

trainable_state_dict(module, prefix='', keep_vars=False)

Returns the state dictionary of the module containing only the trainable parameters.

Parameters:

  • module (Module) –

    The neural network module.

  • prefix (str, default: '' ) –

    The prefix to add to the parameter names. Defaults to "".

  • keep_vars (bool, default: False ) –

    If True, the parameters are not detached. Defaults to False.

Returns:

  • StateDictType

    Dict[str, Tensor]: A dictionary containing the names and values of the trainable parameters.

Source code in fusion_bench/utils/parameters.py
def trainable_state_dict(
    module: nn.Module,
    prefix: str = "",
    keep_vars: bool = False,
) -> StateDictType:
    """
    Returns the state dictionary of the module containing only the trainable parameters.

    Only entries of `module.named_parameters()` with `requires_grad=True` are included.

    Args:
        module (nn.Module): The neural network module.
        prefix (str, optional): String prepended to every parameter name. Defaults to "".
        keep_vars (bool, optional): If True, the parameter objects themselves are returned;
            otherwise each value is `param.detach()`. Defaults to False.

    Returns:
        Dict[str, Tensor]: A dictionary containing the names and values of the trainable parameters.
    """
    selected = {}
    for name, param in module.named_parameters():
        if not param.requires_grad:
            # frozen parameters are left out
            continue
        selected[prefix + name] = param if keep_vars else param.detach()
    return selected

vector_to_state_dict(vector, state_dict, remove_keys=None)

Convert a vector to a state dictionary.

Parameters:

  • vector (Tensor) –

    The vector to convert.

  • state_dict (Union[dict[str, Tensor], Module]) –

    The reference state dictionary to define the order of the vector.

  • remove_keys (list, default: None ) –

    List of keys to remove from the reference state dictionary. Defaults to [].

Returns:

  • dict

    The converted state dictionary.

Source code in fusion_bench/utils/parameters.py
def vector_to_state_dict(
    vector: torch.Tensor,
    state_dict: Union[StateDictType, nn.Module],
    remove_keys: Optional[List[str]] = None,
):
    """
    Convert a vector to a state dictionary.

    The reference is stripped of `remove_keys`, sorted by key, and then filled
    from `vector` with `nn.utils.vector_to_parameters`.

    Args:
        vector (torch.Tensor): The vector to convert.
        state_dict (Union[dict[str, torch.Tensor], nn.Module]): The reference state
            dictionary that defines names, shapes and (sorted-key) order of the vector.
        remove_keys (list, optional): Keys to drop from the reference before filling.
            Defaults to None (nothing is removed).

    Returns:
        dict: An `OrderedDict` sorted by key. If it contains
            `"transformer.shared.weight"`, every key in `remove_keys` is added back
            pointing at that same tensor.
    """
    excluded = [] if remove_keys is None else remove_keys

    # the reference fixes the layout of the vector
    if isinstance(state_dict, nn.Module):
        reference = state_dict.state_dict()
    else:
        # Shallow copy: the tensor objects are still shared with the caller's dict.
        # NOTE(review): vector_to_parameters assigns to `.data` of those tensors,
        # so the caller's tensors are presumably overwritten as well - confirm.
        reference = copy.copy(state_dict)

    for name in excluded:
        reference.pop(name, None)

    ordered = OrderedDict(sorted(reference.items()))

    # write consecutive slices of the vector into the reference tensors
    nn.utils.vector_to_parameters(vector, ordered.values())

    # add back the encoder and decoder embedding weights.
    shared_key = "transformer.shared.weight"
    if shared_key in ordered:
        for name in excluded:
            ordered[name] = ordered[shared_key]
    return ordered

State Dict Arithmetic

fusion_bench.utils.state_dict_arithmetic

num_params_of_state_dict(state_dict)

Returns the number of parameters in a state dict.

Parameters:

  • state_dict (Dict[str, Tensor]) –

    The state dict to count the number of parameters in.

Returns:

  • int

    The number of parameters in the state dict.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def num_params_of_state_dict(state_dict: StateDictType):
    """
    Returns the number of parameters in a state dict.

    Args:
        state_dict (Dict[str, Tensor]): The state dict to count the number of parameters in.

    Returns:
        int: The sum of `numel()` over all tensors in the state dict (0 if it is empty).
    """
    total = 0
    for key in state_dict:
        total += state_dict[key].numel()
    return total

state_dict_add(a, b, strict=True, device=None, show_pbar=False)

Returns the sum of two state dicts.

Parameters:

  • a (Dict) –

    The first state dict.

  • b (Dict) –

    The second state dict.

  • strict (bool, default: True ) –

    Whether to check if the keys of the two state dicts are the same.

Returns:

  • Dict

    The sum of the two state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_add(
    a: StateDictType,
    b: StateDictType,
    strict: bool = True,
    device=None,
    show_pbar: bool = False,
):
    """
    Returns the sum of two state dicts.

    Args:
        a (Dict): The first state dict. Its keys drive the iteration.
        b (Dict): The second state dict.
        strict (bool): If True, `check_parameters_all_equal([a, b])` is called first and
            every key of `a` is summed. If False, keys of `a` that are missing in `b`
            are skipped.
        device (optional): If not None, the result is passed through `to_device`.
        show_pbar (bool): If True, iterate the keys through a `tqdm` progress bar.

    Returns:
        Dict: The sum of the two state dicts.
    """
    if strict:
        check_parameters_all_equal([a, b])

    keys = tqdm(tuple(a.keys())) if show_pbar else a

    total = {}
    for key in keys:
        # in strict mode every key is expected to be present in `b`
        if strict or key in b:
            total[key] = a[key] + b[key]

    if device is not None:
        total = to_device(total, device)
    return total

state_dict_avg(state_dicts)

Returns the average of a list of state dicts.

Parameters:

  • state_dicts (List[Dict[str, Tensor]]) –

    The list of state dicts to average.

Returns:

  • Dict

    The average of the state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_avg(state_dicts: List[StateDictType]):
    """
    Returns the average of a list of state dicts.

    For every key of the first state dict, a zero tensor shaped like the first
    entry is accumulated in place and then divided in place by the number of
    state dicts.

    Args:
        state_dicts (List[Dict[str, Tensor]]): The list of state dicts to average.
            Must be non-empty and all entries must have the same number of keys.

    Returns:
        Dict: An `OrderedDict` holding the average of the state dicts.
    """
    assert len(state_dicts) > 0, "The number of state_dicts must be greater than 0"
    assert all(
        [len(state_dicts[0]) == len(state_dict) for state_dict in state_dicts]
    ), "All state_dicts must have the same number of keys"

    reference = state_dicts[0]
    count = len(state_dicts)

    averaged = OrderedDict()
    for key in reference:
        accumulator = torch.zeros_like(reference[key])
        for member in state_dicts:
            accumulator += member[key]
        accumulator /= count
        averaged[key] = accumulator
    return averaged

state_dict_binary_mask(a, b, compare_fn='greater')

Returns the binary mask of elements in a compared to elements in b using the provided comparison function.

Parameters:

  • a (StateDictType) –

    The first state dict.

  • b (StateDictType) –

    The second state dict.

  • compare_fn (Union[Literal['greater', 'less', 'equal', 'not_equal'], Callable[[Tensor, Tensor], Tensor]], default: 'greater' ) –

    A function that takes two tensors and returns a boolean tensor. Defaults to greater than comparison (x > y).

Returns:

  • StateDictType ( BoolStateDictType ) –

    A dictionary containing binary masks (0 or 1) based on the comparison.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_binary_mask(
    a: StateDictType,
    b: StateDictType,
    compare_fn: Union[
        Literal["greater", "less", "equal", "not_equal"],
        Callable[[Tensor, Tensor], torch.BoolTensor],
    ] = "greater",
) -> BoolStateDictType:
    """
    Returns the binary mask of elements in a compared to elements in b using the provided comparison function.

    Args:
        a (StateDictType): The first state dict. Its keys drive the iteration.
        b (StateDictType): The second state dict.
        compare_fn (Union[Literal["greater", "less", "equal", "not_equal"], Callable[[Tensor, Tensor], Tensor]]):
            Either the name of a built-in comparison (`x > y`, `x < y`, `x == y`, `x != y`)
            or a function that takes two tensors and returns a boolean tensor.
            Defaults to greater than comparison (x > y).

    Raises:
        ValueError: If `compare_fn` is neither a string nor callable.

    Returns:
        StateDictType: An `OrderedDict` mapping each key of `a` to the result of the comparison.
    """
    if isinstance(compare_fn, str):
        named_comparisons = {
            "greater": lambda x, y: x > y,
            "less": lambda x, y: x < y,
            "equal": lambda x, y: x == y,
            "not_equal": lambda x, y: x != y,
        }
        compare_fn = named_comparisons[compare_fn]
    elif not callable(compare_fn):
        raise ValueError(
            f"compare_fn must be a string or a callable, but got {type(compare_fn)}"
        )

    return OrderedDict((key, compare_fn(a[key], b[key])) for key in a)

state_dict_diff_abs(a, b)

Returns the per-layer abs of the difference between two state dicts.

Parameters:

  • a (StateDictType) –

    The first state dict.

  • b (StateDictType) –

    The second state dict.

Returns:

  • StateDictType

    The absolute difference between the two state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_diff_abs(a: StateDictType, b: StateDictType):
    """
    Returns the per-layer abs of the difference between two state dicts.

    The difference is computed by `state_dict_sub(a, b)` with its default
    arguments, then `abs()` is applied to every entry.

    Args:
        a (StateDictType): The first state dict.
        b (StateDictType): The second state dict.

    Returns:
        StateDictType: The absolute difference between the two state dicts.
    """
    difference = state_dict_sub(a, b)
    magnitudes = {}
    for name in difference:
        magnitudes[name] = difference[name].abs()
    return magnitudes

state_dict_div(state_dict, scalar, show_pbar=False)

Returns the division of a state dict by a scalar.

Parameters:

  • state_dict (Dict) –

    The state dict to be divided.

  • scalar (float) –

    The scalar to divide the state dict by.

Returns:

  • Dict

    The division of the state dict by the scalar.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_div(state_dict: StateDictType, scalar: float, show_pbar: bool = False):
    """
    Returns the division of a state dict by a scalar.

    Args:
        state_dict (Dict): The state dict to be divided.
        scalar (float): The scalar to divide the state dict by.
        show_pbar (bool): If True, iterate the keys through a `tqdm` progress bar.

    Returns:
        Dict: An `OrderedDict` with every entry divided by the scalar.
    """
    keys = tqdm(tuple(state_dict.keys())) if show_pbar else state_dict
    return OrderedDict((name, state_dict[name] / scalar) for name in keys)

state_dict_flatten(state_dict)

Flattens a state dict.

Parameters:

  • state_dict (Dict[str, Tensor]) –

    The state dict to be flattened.

Returns:

  • Tensor

    The flattened state dict.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_flatten(state_dict: Dict[str, Tensor]):
    """
    Flattens a state dict.

    Tensors are flattened and concatenated in the iteration order of the
    dictionary (no sorting is applied).

    Args:
        state_dict (Dict[str, Tensor]): The state dict to be flattened.

    Returns:
        Tensor: The flattened state dict.
    """
    pieces = [state_dict[name].flatten() for name in state_dict]
    return torch.cat(pieces)

state_dict_hadmard_product(a, b)

Returns the Hadamard product of two state dicts, i.e. element-wise product.

Parameters:

  • a (StateDictType) –

    The first state dict.

  • b (StateDictType) –

    The second state dict.

Returns:

  • StateDictType ( StateDictType ) –

    The Hadamard product of the two state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_hadmard_product(a: StateDictType, b: StateDictType) -> StateDictType:
    """
    Returns the Hadamard product of two state dicts, i.e. element-wise product.

    Args:
        a (StateDictType): The first state dict. Its keys drive the iteration.
        b (StateDictType): The second state dict; must contain every key of `a`.

    Returns:
        StateDictType: An `OrderedDict` holding the Hadamard product of the two state dicts.
    """
    return OrderedDict((name, a[name] * b[name]) for name in a)

state_dict_interpolation(state_dicts, scalars)

Interpolates between a list of state dicts using a list of scalars.

Parameters:

  • state_dicts (List[Dict[str, Tensor]]) –

    The list of state dicts to interpolate between.

  • scalars (List[float]) –

    The list of scalars to use for interpolation.

Returns:

  • Dict

    The interpolated state dict.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_interpolation(
    state_dicts: List[Dict[str, Tensor]], scalars: List[float]
):
    """
    Interpolates between a list of state dicts using a list of scalars.

    For every key of the first state dict, the result is the sum over all state
    dicts of `scalar * tensor`, accumulated in place into a zero tensor shaped
    like the first state dict's entry.

    Args:
        state_dicts (List[Dict[str, Tensor]]): The list of state dicts to interpolate between.
        scalars (List[float]): The list of scalars to use for interpolation, one per state dict.

    Returns:
        Dict: The interpolated state dict.
    """
    assert len(state_dicts) == len(
        scalars
    ), "The number of state_dicts and scalars must be the same"
    assert len(state_dicts) > 0, "The number of state_dicts must be greater than 0"
    assert all(
        [len(state_dicts[0]) == len(state_dict) for state_dict in state_dicts]
    ), "All state_dicts must have the same number of keys"

    reference = state_dicts[0]
    blended = {}
    for key in reference:
        accumulator = torch.zeros_like(reference[key])
        for member, factor in zip(state_dicts, scalars):
            accumulator += factor * member[key]
        blended[key] = accumulator
    return blended

state_dict_mul(state_dict, scalar)

Returns the product of a state dict and a scalar.

Parameters:

  • state_dict (Dict) –

    The state dict to be multiplied.

  • scalar (float) –

    The scalar to multiply the state dict with.

Returns:

  • Dict

    The product of the state dict and the scalar.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_mul(state_dict: StateDictType, scalar: float):
    """
    Returns the product of a state dict and a scalar.

    Args:
        state_dict (Dict): The state dict to be multiplied.
        scalar (float): The scalar to multiply the state dict with.

    Returns:
        Dict: An `OrderedDict` with every entry multiplied by the scalar.
    """
    return OrderedDict((name, scalar * state_dict[name]) for name in state_dict)

state_dict_power(state_dict, p)

Returns the power of a state dict.

Parameters:

  • state_dict (Dict[str, Tensor]) –

    The state dict to be powered.

  • p (float) –

    The power to raise the state dict to.

Returns:

  • Dict[str, Tensor]: The powered state dict.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_power(state_dict: Dict[str, Tensor], p: float):
    """
    Returns the power of a state dict.

    Args:
        state_dict (Dict[str, Tensor]): The state dict to be powered.
        p (float): The exponent applied element-wise to every entry.

    Returns:
        Dict[str, Tensor]: The powered state dict.
    """
    return {name: state_dict[name] ** p for name in state_dict}

state_dict_sub(a, b, strict=True, device=None)

Returns the difference between two state dicts a-b.

Parameters:

  • a (StateDictType) –

    The first state dict.

  • b (StateDictType) –

    The second state dict.

  • strict (bool, default: True ) –

    Whether to check if the keys of the two state dicts are the same.

Returns:

  • StateDictType

    The difference between the two state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_sub(
    a: StateDictType, b: StateDictType, strict: bool = True, device=None
):
    """
    Returns the difference between two state dicts `a-b`.

    Args:
        a (StateDictType): The first state dict. Its keys drive the iteration.
        b (StateDictType): The second state dict.
        strict (bool): If True, assert that both state dicts have the same key set.
            If False, keys of `a` that are missing in `b` are skipped.
        device (optional): If not None, each difference is moved with
            `.to(device, non_blocking=True)`.

    Raises:
        AssertionError: If `strict` is True and the key sets differ.

    Returns:
        StateDictType: An `OrderedDict` holding the difference between the two state dicts.
    """
    if strict:
        assert set(a.keys()) == set(b.keys())

    difference = OrderedDict()
    for name in a:
        if name not in b:
            continue
        delta = a[name] - b[name]
        difference[name] = (
            delta if device is None else delta.to(device, non_blocking=True)
        )
    return difference

state_dict_sum(state_dicts)

Returns the sum of a list of state dicts.

Parameters:

  • state_dicts (List[Dict[str, Tensor]]) –

    The list of state dicts to sum.

Returns:

  • Dict

    The sum of the state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_sum(state_dicts: List[StateDictType]):
    """
    Returns the sum of a list of state dicts.

    Each entry starts from the integer 0 and is built up with out-of-place
    additions, so the inputs are not modified.

    Args:
        state_dicts (List[Dict[str, Tensor]]): The list of state dicts to sum.
            Must be non-empty and all entries must have the same number of keys.

    Returns:
        Dict: An `OrderedDict` holding the sum of the state dicts.
    """
    assert len(state_dicts) > 0, "The number of state_dicts must be greater than 0"
    assert all(
        [len(state_dicts[0]) == len(state_dict) for state_dict in state_dicts]
    ), "All state_dicts must have the same number of keys"

    totals = OrderedDict()
    for key in state_dicts[0]:
        running = 0
        for member in state_dicts:
            running = running + member[key]
        totals[key] = running
    return totals

state_dict_weighted_sum(state_dicts, weights, device=None)

Returns the weighted sum of a list of state dicts.

Parameters:

  • state_dicts (List[Dict[str, Tensor]]) –

    The list of state dicts to combine in the weighted sum.

  • weights (List[float]) –

    The list of weights to use for the weighted sum.

Returns:

  • Dict

    The weighted sum of the state dicts.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dict_weighted_sum(
    state_dicts: List[Dict[str, Tensor]], weights: List[float], device=None
):
    """
    Returns the weighted sum of a list of state dicts.

    Args:
        state_dicts (List[Dict[str, Tensor]]): The list of state dicts to combine.
            Must be non-empty and all entries must have the same number of keys.
        weights (List[float]): The list of weights to use for the weighted sum, one per state dict.
        device (optional): If not None, each summed tensor is moved with
            `.to(device, non_blocking=True)`.

    Returns:
        Dict: The weighted sum of the state dicts.
    """
    assert len(state_dicts) == len(
        weights
    ), "The number of state_dicts and weights must be the same"
    assert len(state_dicts) > 0, "The number of state_dicts must be greater than 0"
    assert all(
        [len(state_dicts[0]) == len(state_dict) for state_dict in state_dicts]
    ), "All state_dicts must have the same number of keys"

    reference = state_dicts[0]
    combined: Dict[str, Tensor] = {}
    for key in reference:
        # states dicts can be sparse matrices
        running = torch.zeros_like(reference[key]).to_dense()
        for member, factor in zip(state_dicts, weights):
            running = torch.add(running, factor * member[key])
        if device is not None:
            running = running.to(device, non_blocking=True)
        combined[key] = running
    return combined

state_dicts_check_keys(state_dicts)

Checks that the state dictionaries have the same keys.

Parameters:

  • state_dicts (List[Dict[str, Tensor]]) –

    A list of dictionaries containing the state of PyTorch models.

Raises:

  • ValueError

    If the state dictionaries have different keys.

Source code in fusion_bench/utils/state_dict_arithmetic.py
def state_dicts_check_keys(state_dicts: List[StateDictType]):
    """
    Checks that the state dictionaries have the same keys.

    Args:
        state_dicts (List[Dict[str, Tensor]]): A list of dictionaries containing the state of PyTorch models.

    Raises:
        AssertionError: If any state dictionary has a key set different from the first one.
            The check is an `assert` statement, not an explicit `raise`.
    """
    # the first state dictionary serves as the reference key set
    expected = set(state_dicts[0].keys())
    assert all(
        expected == set(candidate.keys()) for candidate in state_dicts
    ), "keys of state_dicts are not equal"

Lazy Model Loading

fusion_bench.utils.lazy_state_dict.LazyStateDict

Dictionary-like object that lazily loads a state dict from a checkpoint path.

Source code in fusion_bench/utils/lazy_state_dict.py
class LazyStateDict:
    """
    Dictionary-like object that lazily loads a state dict from a checkpoint path.

    Three checkpoint layouts are handled by `__init__`:

    - a sharded checkpoint described by an index JSON file: tensors are read
      from the shard named in the index when they are requested;
    - a single file whose name ends with `SAFE_WEIGHTS_NAME`: an index mapping
      every key to that file is built, tensors are read on demand;
    - a single file whose name ends with `WEIGHTS_NAME`: the whole state dict
      is loaded eagerly with `torch.load` and kept in `_state_dict_cache`.
    """

    _local_path: str
    """local path to the checkpoint."""
    _state_dict_cache: Optional[Dict]
    """Cache for the state dict, if enabled."""
    _index_filename: Optional[str]
    _checkpoint_files: Optional[List[str]]
    _index: Optional[Dict[str, str]]
    """Mapping of parameter names to checkpoint files."""

    def __init__(
        self,
        checkpoint: str,
        meta_module_class: Optional[Type[nn.Module]] = None,
        meta_module: Optional[nn.Module] = None,
        cache_state_dict: bool = False,
        torch_dtype: Optional[torch.dtype] = None,
        device: str = "cpu",
        hf_revision: Optional[str] = None,
        hf_cache_dir: Optional[str] = None,
        hf_proxies: Optional[Dict] = None,
    ):
        """
        Args:
            checkpoint (str): Path to the checkpoint file or directory.
            meta_module_class (Type[nn.Module], optional): Class of the meta module to instantiate.
                A string is also accepted and is passed to `import_object` to obtain the class.
            meta_module (nn.Module, optional): Pre-initialized meta module.
            cache_state_dict (bool): Whether to cache the state dict in memory.
            torch_dtype (torch.dtype, optional): The dtype to use for the tensors.
            device (str): The device to load the tensors onto.
            hf_revision (str, optional): The revision of the model to download from Hugging Face Hub.
            hf_cache_dir (str, optional): The cache directory for Hugging Face models.
            hf_proxies (Dict, optional): Proxies to use for downloading from Hugging Face Hub.

        Raises:
            ValueError: If both `meta_module_class` and `meta_module` are given, or if the
                checkpoint layout cannot be determined.
        """
        self.cache_state_dict = cache_state_dict
        self.meta_module_class = meta_module_class
        if isinstance(self.meta_module_class, str):
            self.meta_module_class = import_object(self.meta_module_class)
        self.meta_module = meta_module
        if self.meta_module_class is not None:
            if self.meta_module is not None:
                raise ValueError(
                    "Cannot provide both meta_module_class and meta_module, please provide only one."
                )
            # Build the module structure without allocating real weights.
            with init_empty_weights():
                self.meta_module = self.meta_module_class.from_pretrained(
                    checkpoint,
                    torch_dtype=torch_dtype,
                    revision=hf_revision,
                    cache_dir=hf_cache_dir,
                    proxies=hf_proxies,
                )

        self._checkpoint = checkpoint
        self._local_path = resolve_checkpoint_path(
            checkpoint,
            hf_revision=hf_revision,
            hf_cache_dir=hf_cache_dir,
            hf_proxies=hf_proxies,
        )

        self._index, self._index_filename, self._checkpoint_files = (
            self._resolve_checkpoint_files(self._local_path)
        )

        if self._index is not None:
            # if meta_module is provided, remove the keys that are not in the meta_module
            if self.meta_module is not None:
                meta_module_state_dict = self.meta_module.state_dict()
                for key in tuple(self._index.keys()):
                    if key not in meta_module_state_dict:
                        self._index.pop(key)
            if cache_state_dict:
                self._state_dict_cache = {}
            else:
                self._state_dict_cache = None
        elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
            SAFE_WEIGHTS_NAME
        ):
            # let the keys of self._index be the keys of the state dict, the values are the checkpoint file
            with safe_open(
                self._checkpoint_files[0], framework="pt", device=device
            ) as f:
                self._index = {key: self._checkpoint_files[0] for key in f.keys()}
                if cache_state_dict:
                    self._state_dict_cache = {}
                else:
                    self._state_dict_cache = None
        elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
            WEIGHTS_NAME
        ):
            log.info(f"Loading full state dict from {WEIGHTS_NAME}")
            # Honour `device` here as well, like every other load in this class;
            # without `map_location` the tensors are restored to the device they
            # were saved from, which fails on machines lacking that device.
            # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
            self._state_dict_cache = torch.load(
                self._checkpoint_files[0], map_location=device
            )
            # if meta_module is provided, remove the keys that are not in the meta_module
            if self.meta_module is not None:
                meta_module_state_dict = self.meta_module.state_dict()
                for key in tuple(self._state_dict_cache.keys()):
                    if key not in meta_module_state_dict:
                        self._state_dict_cache.pop(key)
        else:
            raise ValueError(
                f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
            )

        self._torch_dtype = parse_dtype(torch_dtype)
        self._device = device

    @property
    def checkpoint(self) -> str:
        """The checkpoint identifier exactly as passed to the constructor."""
        return self._checkpoint

    @property
    def config(self) -> "PretrainedConfig":
        """Configuration obtained from `AutoConfig.from_pretrained` on the original checkpoint."""
        return AutoConfig.from_pretrained(self._checkpoint)

    def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
        """
        Args:
            keep_vars (bool): Ignored, as LazyStateDict does not support keep_vars. Just for compatibility.
        """
        return self

    def _resolve_checkpoint_files(self, checkpoint: str):
        """
        Work out which files make up the checkpoint at `checkpoint`.

        Returns:
            A tuple `(index, index_filename, checkpoint_files)`. `index` and
            `index_filename` are None unless an index JSON file was found;
            `checkpoint_files` lists the weight files (full paths).

        Raises:
            ValueError: If `checkpoint` is neither a file nor a directory, or the
                directory holds no recognisable weights / more than one index file.
        """
        # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
        checkpoint_files = None
        index_filename = None
        if os.path.isfile(checkpoint):
            if str(checkpoint).endswith(".json"):
                index_filename = checkpoint
            else:
                checkpoint_files = [checkpoint]
        elif os.path.isdir(checkpoint):
            # check if the whole state dict is present
            potential_state_bin = [
                f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME
            ]
            potential_state_safetensor = [
                f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
            ]
            if len(potential_state_bin) == 1:
                checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
            elif len(potential_state_safetensor) == 1:
                checkpoint_files = [
                    os.path.join(checkpoint, potential_state_safetensor[0])
                ]
            else:
                # otherwise check for sharded checkpoints
                potential_index = [
                    f for f in os.listdir(checkpoint) if f.endswith(".index.json")
                ]
                if len(potential_index) == 0:
                    raise ValueError(
                        f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
                    )
                elif len(potential_index) == 1:
                    index_filename = os.path.join(checkpoint, potential_index[0])
                else:
                    raise ValueError(
                        f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
                    )
        else:
            raise ValueError(
                "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
                f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
            )

        if index_filename is not None:
            checkpoint_folder = os.path.split(index_filename)[0]
            with open(index_filename) as f:
                index = json.loads(f.read())

            if "weight_map" in index:
                index = index["weight_map"]
            checkpoint_files = sorted(list(set(index.values())))
            checkpoint_files = [
                os.path.join(checkpoint_folder, f) for f in checkpoint_files
            ]
        else:
            index = None
        return index, index_filename, checkpoint_files

    def _resolve_shard_path(self, filename: str) -> str:
        """
        Turn an index value into the path of the file to open.

        Index values read from an index JSON file are resolved against the
        folder containing that index file (the same rule
        `_resolve_checkpoint_files` uses). When no index file exists, the index
        was built in `__init__` from `_checkpoint_files`, whose entries are
        already complete paths, so the value is returned unchanged.
        """
        if self._index_filename is not None:
            return os.path.join(os.path.dirname(self._index_filename), filename)
        return filename

    def _load_tensor_from_checkpoint_file(
        self, checkpoint_file: str, key: str, update_cache: bool = True
    ) -> torch.Tensor:
        """
        Read the tensor stored under `key` in `checkpoint_file`.

        For `.safetensors` files only the requested tensor is read and it is
        cast to `self._torch_dtype` when that is set. For any other file the
        whole state dict is loaded with `torch.load` (no dtype cast is applied
        in this branch).

        Args:
            checkpoint_file (str): Path of the file to read from.
            key (str): Name of the tensor.
            update_cache (bool): Whether to store what was loaded in the cache
                (only has an effect when the cache is enabled).
        """
        if checkpoint_file.endswith(".safetensors"):
            with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
                tensor = f.get_tensor(key)
                if self._torch_dtype is not None:
                    tensor = tensor.to(self._torch_dtype)
                if update_cache and self._state_dict_cache is not None:
                    self._state_dict_cache[key] = tensor
                return tensor
        else:
            state_dict = torch.load(checkpoint_file, map_location=self._device)
            if update_cache:
                if self._state_dict_cache is not None:
                    # The whole file had to be read, so keep all of it.
                    self._state_dict_cache.update(state_dict)
                else:
                    log.warning(
                        f"Load full state dict from file {checkpoint_file}, but state dict cache is disabled."
                    )
            return state_dict[key]

    def __getitem__(self, key: str) -> torch.Tensor:
        """
        Return the tensor for `key`, from the cache if present, otherwise from disk.

        Raises:
            KeyError: If an index exists and `key` is not in it.
            FileNotFoundError: If the file that should hold `key` does not exist.
            RuntimeError: If there is no index and several checkpoint files.
        """
        if self._state_dict_cache is not None and key in self._state_dict_cache:
            return self._state_dict_cache[key]

        if self._index is None:
            if len(self._checkpoint_files) == 1 and os.path.isfile(
                self._checkpoint_files[0]
            ):
                checkpoint_file = self._checkpoint_files[0]
                tensor = self._load_tensor_from_checkpoint_file(
                    checkpoint_file, key, update_cache=True
                )
                return tensor
            else:
                if len(self._checkpoint_files) > 1:
                    raise RuntimeError(
                        "Get multiple checkpoint files, but index is not provided."
                    )
                if not os.path.isfile(self._checkpoint_files[0]):
                    raise FileNotFoundError(
                        f"Checkpoint file {self._checkpoint_files[0]} not found."
                    )
                raise RuntimeError("Unexpected error.")
        else:
            if key not in self._index:
                raise KeyError(f"Key {key} not found in index.")
            # Joining onto `_local_path` is wrong when that path is the index
            # file / weights file itself, or a relative directory that is
            # already part of the index value; resolve via the index location.
            checkpoint_file = self._resolve_shard_path(self._index[key])
            if not os.path.isfile(checkpoint_file):
                raise FileNotFoundError(f"Checkpoint file {checkpoint_file} not found.")
            tensor = self._load_tensor_from_checkpoint_file(
                checkpoint_file, key, update_cache=True
            )
            return tensor

    def __setitem__(self, key: str, value: torch.Tensor) -> None:
        """
        Set a tensor in the LazyStateDict. This will update the state dict cache if it is enabled.

        If the cache is disabled, a warning is logged and a cache holding only
        this entry is created so the value is not lost.
        """
        assert key in list(
            self.keys()
        ), "KeyError: Cannot set a tensor for a key that does not exist in the LazyStateDict."
        if self._state_dict_cache is not None:
            self._state_dict_cache[key] = value
        else:
            log.warning(
                "State dict cache is disabled, setting a tensor will not update the cache."
            )
            self._state_dict_cache = {key: value}

    def __contains__(self, key: str) -> bool:
        """
        Report whether `key` is known, using only in-memory information.

        A key is known when it is in the cache or in the index. No checkpoint
        file is opened: after construction either the index or the cache lists
        every available key, so loading a file just to test membership is
        unnecessary and could report keys that iteration does not yield.
        """
        if self._state_dict_cache is not None and key in self._state_dict_cache:
            return True
        if self._index is not None:
            return key in self._index
        return False

    def __len__(self) -> int:
        """
        Number of keys, consistent with what `__iter__` yields.

        Raises:
            RuntimeError: If neither an index, a cache nor a single readable
                checkpoint file is available.
        """
        if self._index is not None:
            return len(self._index)
        if self._state_dict_cache is not None:
            # Without an index the cache holds the state dict (see `__init__`);
            # count it instead of re-reading the whole file from disk.
            return len(self._state_dict_cache)
        if len(self._checkpoint_files) == 1 and os.path.isfile(
            self._checkpoint_files[0]
        ):
            checkpoint_file = self._checkpoint_files[0]
            if checkpoint_file.endswith(".safetensors"):
                with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
                    return len(tuple(f.keys()))
            else:
                return len(
                    tuple(torch.load(checkpoint_file, map_location="cpu").keys())
                )
        raise RuntimeError(
            "Unexpected error: cannot determine the number of keys in the state dict."
        )

    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys of the index if there is one, else over the cache."""
        if self._index is not None:
            return iter(self._index)
        elif self._state_dict_cache is not None:
            return iter(self._state_dict_cache)
        else:
            raise RuntimeError(
                "Unexpected error: cannot determine the keys in the state dict."
            )

    def keys(self) -> Iterator[str]:
        """Yield every key."""
        for key in self:
            yield key

    def values(self) -> Iterator[torch.Tensor]:
        """Yield every tensor, loading each one on demand."""
        for key in self:
            yield self[key]

    def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
        """Yield `(key, tensor)` pairs, loading each tensor on demand."""
        for key in self:
            yield key, self[key]

    def __repr__(self) -> str:
        if self._index is not None:
            return f"{self.__class__.__name__}(keys={list(self.keys())})"
        else:
            return (
                f"{self.__class__.__name__}(checkpoint_files={self._checkpoint_files})"
            )

    def get_parameter(self, target: str) -> torch.Tensor:
        """Return the tensor stored under `target` (same as `self[target]`)."""
        return self[target]

    def get_submodule(self, target: str) -> nn.Module:
        """
        Materialise the submodule `target` of the meta module with its parameters loaded.

        The submodule is deep-copied from `meta_module`, moved to `self._device`
        with `to_empty`, and its named parameters are filled from this state dict.

        Raises:
            RuntimeError: If no meta module is available.
        """
        if self.meta_module is not None:
            module: nn.Module = deepcopy(self.meta_module.get_submodule(target))
            module.to_empty(device=self._device)
            state_dict = {}
            for name, _ in module.named_parameters():
                state_dict[name] = self[f"{target}.{name}"]
            module.load_state_dict(state_dict)
            return module
        else:
            raise RuntimeError(
                "Cannot get submodule because meta_module is not provided."
            )

    def load_state_dict(
        self, state_dict: Dict[str, torch.Tensor], strict: bool = True
    ) -> None:
        """
        Load a state dict into this LazyStateDict.
        This method is only for compatibility with nn.Module and it overrides the cache of LazyStateDict.

        Args:
            state_dict (Dict[str, torch.Tensor]): The state dict to load.
            strict (bool): Whether to enforce that all keys in the state dict are present in this LazyStateDict.
                When False, keys that this LazyStateDict does not have are skipped.

        Raises:
            KeyError: If `strict` is True and `state_dict` has a key that is not present.
        """
        log.warning(
            "Loading state dict into LazyStateDict is not recommended, as it may lead to unexpected behavior. "
            "Use with caution."
        )
        if strict:
            for key in state_dict:
                if key not in self:
                    raise KeyError(f"Key {key} not found in LazyStateDict.")
        # `__setitem__` rejects keys that `keys()` does not yield, so in
        # non-strict mode those entries must be skipped rather than assigned.
        known_keys = None if strict else set(self.keys())
        for key, value in state_dict.items():
            if known_keys is not None and key not in known_keys:
                continue
            self[key] = value

__init__(checkpoint, meta_module_class=None, meta_module=None, cache_state_dict=False, torch_dtype=None, device='cpu', hf_revision=None, hf_cache_dir=None, hf_proxies=None)

Parameters:

  • checkpoint (str) –

    Path to the checkpoint file or directory.

  • meta_module_class (Type[Module], default: None ) –

    Class of the meta module to instantiate.

  • meta_module (Module, default: None ) –

    Pre-initialized meta module.

  • cache_state_dict (bool, default: False ) –

    Whether to cache the state dict in memory.

  • torch_dtype (dtype, default: None ) –

    The dtype to use for the tensors.

  • device (str, default: 'cpu' ) –

    The device to load the tensors onto.

  • hf_revision (str, default: None ) –

    The revision of the model to download from Hugging Face Hub.

  • hf_cache_dir (str, default: None ) –

    The cache directory for Hugging Face models.

  • hf_proxies (Dict, default: None ) –

    Proxies to use for downloading from Hugging Face Hub.

Source code in fusion_bench/utils/lazy_state_dict.py
def __init__(
    self,
    checkpoint: str,
    meta_module_class: Optional[Type[nn.Module]] = None,
    meta_module: Optional[nn.Module] = None,
    cache_state_dict: bool = False,
    torch_dtype: Optional[torch.dtype] = None,
    device: str = "cpu",
    hf_revision: Optional[str] = None,
    hf_cache_dir: Optional[str] = None,
    hf_proxies: Optional[Dict] = None,
):
    """
    Args:
        checkpoint (str): Path to the checkpoint file or directory.
        meta_module_class (Type[nn.Module], optional): Class of the meta module to instantiate.
            A string is also accepted and is passed to `import_object` to obtain the class.
        meta_module (nn.Module, optional): Pre-initialized meta module.
        cache_state_dict (bool): Whether to cache the state dict in memory.
        torch_dtype (torch.dtype, optional): The dtype to use for the tensors.
        device (str): The device to load the tensors onto.
        hf_revision (str, optional): The revision of the model to download from Hugging Face Hub.
        hf_cache_dir (str, optional): The cache directory for Hugging Face models.
        hf_proxies (Dict, optional): Proxies to use for downloading from Hugging Face Hub.

    Raises:
        ValueError: If both `meta_module_class` and `meta_module` are given, or if the
            checkpoint layout cannot be determined.
    """
    self.cache_state_dict = cache_state_dict
    self.meta_module_class = meta_module_class
    if isinstance(self.meta_module_class, str):
        self.meta_module_class = import_object(self.meta_module_class)
    self.meta_module = meta_module
    if self.meta_module_class is not None:
        if self.meta_module is not None:
            raise ValueError(
                "Cannot provide both meta_module_class and meta_module, please provide only one."
            )
        # Build the module structure without allocating real weights.
        with init_empty_weights():
            self.meta_module = self.meta_module_class.from_pretrained(
                checkpoint,
                torch_dtype=torch_dtype,
                revision=hf_revision,
                cache_dir=hf_cache_dir,
                proxies=hf_proxies,
            )

    self._checkpoint = checkpoint
    self._local_path = resolve_checkpoint_path(
        checkpoint,
        hf_revision=hf_revision,
        hf_cache_dir=hf_cache_dir,
        hf_proxies=hf_proxies,
    )

    self._index, self._index_filename, self._checkpoint_files = (
        self._resolve_checkpoint_files(self._local_path)
    )

    if self._index is not None:
        # if meta_module is provided, remove the keys that are not in the meta_module
        if self.meta_module is not None:
            meta_module_state_dict = self.meta_module.state_dict()
            for key in tuple(self._index.keys()):
                if key not in meta_module_state_dict:
                    self._index.pop(key)
        if cache_state_dict:
            self._state_dict_cache = {}
        else:
            self._state_dict_cache = None
    elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
        SAFE_WEIGHTS_NAME
    ):
        # let the keys of self._index be the keys of the state dict, the values are the checkpoint file
        with safe_open(
            self._checkpoint_files[0], framework="pt", device=device
        ) as f:
            self._index = {key: self._checkpoint_files[0] for key in f.keys()}
            if cache_state_dict:
                self._state_dict_cache = {}
            else:
                self._state_dict_cache = None
    elif len(self._checkpoint_files) == 1 and self._checkpoint_files[0].endswith(
        WEIGHTS_NAME
    ):
        log.info(f"Loading full state dict from {WEIGHTS_NAME}")
        # Honour `device` for this load too; without `map_location` the
        # tensors are restored to the device they were saved from, which
        # fails on machines lacking that device.
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        self._state_dict_cache = torch.load(
            self._checkpoint_files[0], map_location=device
        )
        # if meta_module is provided, remove the keys that are not in the meta_module
        if self.meta_module is not None:
            meta_module_state_dict = self.meta_module.state_dict()
            for key in tuple(self._state_dict_cache.keys()):
                if key not in meta_module_state_dict:
                    self._state_dict_cache.pop(key)
    else:
        raise ValueError(
            f"Cannot determine the type of checkpoint, please provide a checkpoint path to a file containing a whole state dict with file name {WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME}, or the index of a sharded checkpoint ending with `.index.json`."
        )

    self._torch_dtype = parse_dtype(torch_dtype)
    self._device = device

__setitem__(key, value)

Set a tensor in the LazyStateDict. This will update the state dict cache if it is enabled.

Source code in fusion_bench/utils/lazy_state_dict.py
def __setitem__(self, key: str, value: torch.Tensor) -> None:
    """
    Store `value` under an already-known `key`.

    The key must be one of those yielded by `self.keys()`. With an active
    cache the entry is written into it; with the cache disabled a warning is
    logged and a new cache containing just this entry is created.
    """
    assert key in list(
        self.keys()
    ), "KeyError: Cannot set a tensor for a key that does not exist in the LazyStateDict."

    # Cache disabled: warn, then start a cache so the assignment is kept.
    if self._state_dict_cache is None:
        log.warning(
            "State dict cache is disabled, setting a tensor will not update the cache."
        )
        self._state_dict_cache = {key: value}
        return

    self._state_dict_cache[key] = value

load_state_dict(state_dict, strict=True)

Load a state dict into this LazyStateDict. This method is only for compatibility with nn.Module and it overrides the cache of LazyStateDict.

Parameters:

  • state_dict (Dict[str, Tensor]) –

    The state dict to load.

  • strict (bool, default: True ) –

    Whether to enforce that all keys in the state dict are present in this LazyStateDict.

Source code in fusion_bench/utils/lazy_state_dict.py
def load_state_dict(
    self, state_dict: Dict[str, torch.Tensor], strict: bool = True
) -> None:
    """
    Load a state dict into this LazyStateDict.
    This method is only for compatibility with nn.Module and it overrides the cache of LazyStateDict.

    Args:
        state_dict (Dict[str, torch.Tensor]): The state dict to load.
        strict (bool): Whether to enforce that all keys in the state dict are present in this LazyStateDict.
            When False, keys that this LazyStateDict does not have are skipped.

    Raises:
        KeyError: If `strict` is True and `state_dict` has a key that is not present.
    """
    log.warning(
        "Loading state dict into LazyStateDict is not recommended, as it may lead to unexpected behavior. "
        "Use with caution."
    )
    if strict:
        for key in state_dict:
            if key not in self:
                raise KeyError(f"Key {key} not found in LazyStateDict.")
    # Item assignment rejects keys that `keys()` does not yield, so in
    # non-strict mode those entries must be skipped rather than assigned.
    known_keys = None if strict else set(self.keys())
    for key, value in state_dict.items():
        if known_keys is not None and key not in known_keys:
            continue
        self[key] = value

state_dict(keep_vars=False)

Parameters:

  • keep_vars (bool, default: False ) –

    Ignored, as LazyStateDict does not support keep_vars. Just for compatibility.

Source code in fusion_bench/utils/lazy_state_dict.py
def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
    """
    Return this object itself, so it can be used where a module's state dict is expected.

    Args:
        keep_vars (bool): Accepted only for signature compatibility; it has no effect.
    """
    # The object already is the (lazy) state dict; nothing is copied.
    return self