Skip to content

PyTorch Utilities

Device Management

fusion_bench.utils.devices

cleanup_cuda()

Run Python garbage collection, empty the CUDA cache, and reset the CUDA peak memory statistics.

Source code in fusion_bench/utils/devices.py
def cleanup_cuda():
    """
    Run garbage collection and, when CUDA is available, empty the CUDA cache and
    reset the peak memory statistics.

    On machines without CUDA only the garbage collection step runs, so this is
    safe to call unconditionally (e.g. at the end of a benchmark on a CPU box).
    """
    gc.collect()
    # `reset_peak_memory_stats` needs a usable CUDA backend and fails on
    # CPU-only setups, so both CUDA calls are skipped when CUDA is unavailable.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

clear_cuda_cache()

Runs garbage collection and clears the CUDA memory cache to free up GPU memory. If CUDA is not available, only garbage collection runs and a warning is logged.

Source code in fusion_bench/utils/devices.py
def clear_cuda_cache():
    """
    Collect garbage and release cached CUDA memory back to the driver.

    Garbage collection always runs. When CUDA is unavailable, a warning is logged
    instead of touching the CUDA allocator.
    """
    gc.collect()
    if not torch.cuda.is_available():
        log.warning("CUDA is not available. No cache to clear.")
        return
    torch.cuda.empty_cache()

get_current_device()

Gets the current available device for PyTorch operations. This is used for distributed training.

This function checks the availability of various types of devices in the following order:

1. XPU (Intel's AI accelerator)
2. NPU (Neural Processing Unit)
3. MPS (Metal Performance Shaders, for Apple devices)
4. CUDA (NVIDIA's GPU)
5. CPU (Central Processing Unit, used as a fallback)

The function returns the first available device found in the above order. If none of the specialized devices are available, it defaults to the CPU.

Returns:

  • device

    torch.device: The current available device for PyTorch operations.

Environment Variables

LOCAL_RANK: This environment variable is used to specify the device index for multi-device setups. If not set, it defaults to "0".

Example

>>> device = get_current_device()
>>> print(device)
xpu:0  # or npu:0, mps:0, cuda:0, cpu depending on availability

Source code in fusion_bench/utils/devices.py
def get_current_device() -> torch.device:
    R"""
    Return the accelerator this process should run on (used for distributed training).

    Backends are probed in a fixed priority order and the first available one wins:
    1. XPU (Intel's AI accelerator)
    2. NPU (Neural Processing Unit)
    3. MPS (Metal Performance Shaders, for Apple devices)
    4. CUDA (NVIDIA's GPU)
    5. CPU (fallback when none of the above is available)

    For every accelerator backend, the device index is taken from the
    ``LOCAL_RANK`` environment variable (defaulting to ``"0"``). The CPU
    fallback has no index.

    Returns:
        torch.device: The selected device, e.g. ``cuda:1`` when LOCAL_RANK=1.

    Environment Variables:
        LOCAL_RANK: Device index for multi-device setups. Defaults to "0".

    Example:
        >>> device = get_current_device()
        >>> print(device)
        xpu:0  # or npu:0, mps:0, cuda:0, cpu depending on availability
    """
    rank = os.environ.get("LOCAL_RANK", "0")

    # Probe lazily, in priority order; later probes only run if earlier ones fail.
    if is_torch_xpu_available():
        return torch.device(f"xpu:{rank}")
    if is_torch_npu_available():
        return torch.device(f"npu:{rank}")
    if is_torch_mps_available():
        return torch.device(f"mps:{rank}")
    if is_torch_cuda_available():
        return torch.device(f"cuda:{rank}")
    return torch.device("cpu")

get_device(obj)

Get the device of a given object.

Parameters:

  • obj

    The object whose device is to be determined.

Returns:

  • device

    torch.device: The device of the given object.

Raises:

  • ValueError

    If the object type is not supported.

Source code in fusion_bench/utils/devices.py
def get_device(obj) -> torch.device:
    """
    Get the device of a given object.

    Args:
        obj: A ``torch.Tensor``, a ``torch.nn.Module`` or a ``torch.device``.
            For modules, a ``device`` attribute is used if the module defines one.
            Otherwise the device of the first parameter is used, falling back to
            the first buffer for modules that hold no parameters.

    Returns:
        torch.device: The device of the given object.

    Raises:
        ValueError: If the object type is not supported, or if the object is a
            module with neither parameters nor buffers (its device is undefined).
    """
    if isinstance(obj, torch.Tensor):
        return obj.device
    if isinstance(obj, torch.nn.Module):
        if hasattr(obj, "device"):
            return obj.device
        # Parameters first; buffers cover parameter-free modules that still
        # hold tensors (e.g. fixed lookup tables).
        tensor = next(obj.parameters(), None)
        if tensor is None:
            tensor = next(obj.buffers(), None)
        if tensor is None:
            # Previously this leaked a bare StopIteration to the caller.
            raise ValueError(
                f"Cannot determine the device of module {type(obj).__name__}: "
                "it has no parameters or buffers."
            )
        return tensor.device
    if isinstance(obj, torch.device):
        return obj
    raise ValueError(f"Unsupported object type: {type(obj)}")

get_device_capabilities(device)

Get capabilities information for a given device.

Parameters:

  • device (device) –

    The device for which to get capabilities information.

Returns:

  • dict ( dict ) –

    A dictionary containing capabilities information for the given device.

Source code in fusion_bench/utils/devices.py
def get_device_capabilities(device: torch.device) -> dict:
    """
    Describe the hardware capabilities of a CUDA device.

    Args:
        device (torch.device): Device to query. Only devices whose ``type`` is
            ``"cuda"`` are supported.

    Returns:
        dict: A mapping with the keys ``name``, ``capability``, ``total_memory``
        and ``multi_processor_count``, as reported by ``torch.cuda``.

    Raises:
        ValueError: If ``device.type`` is not ``"cuda"``.
    """
    if device.type != "cuda":
        raise ValueError(
            f"Capabilities information not available for device type: {device.type}"
        )

    properties = torch.cuda.get_device_properties(device)
    return {
        "name": torch.cuda.get_device_name(device),
        "capability": torch.cuda.get_device_capability(device),
        "total_memory": properties.total_memory,
        "multi_processor_count": properties.multi_processor_count,
    }

get_device_memory_info(device, reset_stats=True)

Get memory information for a given device.

Parameters:

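  • device (device) –

    The device for which to get memory information.

  • reset_stats (bool, default True) –

    If True, reset the CUDA peak memory statistics after reading them.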

Returns:

  • dict ( dict ) –

    A dictionary containing memory information for the given device.

Source code in fusion_bench/utils/devices.py
def get_device_memory_info(device: torch.device, reset_stats: bool = True) -> dict:
    """
    Snapshot the memory usage of a CUDA device.

    Args:
        device (torch.device): Device to inspect. Only ``"cuda"`` devices are supported.
        reset_stats (bool): If True (default), reset the CUDA peak memory statistics
            after they have been read, so the next call measures a fresh window.

    Returns:
        dict: Byte counts under the keys ``total_memory``, ``reserved_memory``,
        ``allocated_memory``, ``peak_memory_active``, ``peak_memory_allocated``
        and ``peak_memory_reserved``. ``peak_memory_active`` is 0 when the
        allocator does not report ``active_bytes.all.peak``.

    Raises:
        ValueError: If ``device.type`` is not ``"cuda"``.
    """
    if device.type != "cuda":
        raise ValueError(
            f"Memory information not available for device type: {device.type}"
        )

    # Values are read in a fixed order, before any reset below.
    snapshot = {
        "total_memory": torch.cuda.get_device_properties(device).total_memory,
        "reserved_memory": torch.cuda.memory_reserved(device),
        "allocated_memory": torch.cuda.memory_allocated(device),
        "peak_memory_active": torch.cuda.memory_stats(device).get(
            "active_bytes.all.peak", 0
        ),
        "peak_memory_allocated": torch.cuda.max_memory_allocated(device),
        "peak_memory_reserved": torch.cuda.max_memory_reserved(device),
    }

    if reset_stats:
        torch.cuda.reset_peak_memory_stats(device)

    return snapshot

num_devices(devices)

Return the number of devices.

Parameters:

  • devices (Union[int, List[int], str]) –

    devices can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3], or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".

Returns:

  • int

    The number of devices.

Source code in fusion_bench/utils/devices.py
def num_devices(devices: Union[int, List[int], str]) -> int:
    """
    Return the number of devices.

    Args:
        devices: One of:
            - an int giving the number of devices directly;
            - a sequence of device ids (list, tuple, or any ``collections.abc.Sequence``
              such as an omegaconf ``ListConfig``), e.g. ``[0, 1, 2, 3]``;
            - a str of comma-separated device ids, optionally bracketed,
              e.g. ``"0,1,2,3"`` or ``"[0, 1, 2]"``. Empty entries (e.g. from a
              trailing comma) are ignored.

    Returns:
        The number of devices.

    Raises:
        TypeError: If ``devices`` is none of the supported types.
    """
    from collections.abc import Sequence

    if isinstance(devices, int):
        return devices
    if isinstance(devices, str):
        # Drop surrounding brackets so "[0, 1, 2]" and "0,1,2" parse alike, and
        # skip blank entries so "0,1," counts 2 rather than 3.
        entries = devices.strip().strip("[]()").split(",")
        return sum(1 for entry in entries if entry.strip())
    # str is handled above, so any remaining Sequence is a container of ids.
    if isinstance(devices, Sequence):
        return len(devices)
    raise TypeError(
        "devices must be an int, a sequence of device ids, or a str of device ids, "
        f"but got {type(devices)}"
    )

print_memory_usage(print_fn=print)

Print the current GPU memory usage.

Returns:

  • str

    A string containing the allocated and cached memory in MB.

Source code in fusion_bench/utils/devices.py
def print_memory_usage(print_fn=print):
    """
    Report the current GPU memory usage of the default CUDA device.

    Args:
        print_fn: Callable that receives the formatted report (defaults to ``print``);
            pass e.g. a logger method to redirect the output.

    Returns:
        str: The report, with allocated and cached (reserved) memory in MB,
        one per line.
    """
    bytes_per_mb = 1024**2  # convert bytes to MB
    report = "\n".join(
        [
            f"Allocated Memory: {torch.cuda.memory_allocated() / bytes_per_mb:.2f} MB",
            f"Cached Memory: {torch.cuda.memory_reserved() / bytes_per_mb:.2f} MB",
        ]
    )
    print_fn(report)
    return report

to_device(obj, device, **kwargs)

Move a given object to the specified device.

This function recursively moves tensors, modules, lists, tuples, and dictionaries to the specified device. For unsupported types, the object is returned as is.

Parameters:

  • obj

    The object to be moved to the device. This can be a torch.Tensor, torch.nn.Module, list, tuple, or dict.

  • device (device) –

    The target device to move the object to. This can be None.

  • **kwargs

    Additional keyword arguments to be passed to the to method of torch.Tensor or torch.nn.Module. For example, non_blocking=True, dtype=torch.float16.

Returns:

  • The object moved to the specified device. The type of the returned object matches the type of the input object.

Examples:

>>> tensor = torch.tensor([1, 2, 3])
>>> to_device(tensor, torch.device('cuda'))
tensor([1, 2, 3], device='cuda:0')
>>> model = torch.nn.Linear(2, 2)
>>> to_device(model, torch.device('cuda'))
Linear(..., device='cuda:0')
>>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
>>> to_device(data, torch.device('cuda'))
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
Source code in fusion_bench/utils/devices.py
def to_device(obj, device: Optional[torch.device], **kwargs):
    """
    Move a given object to the specified device.

    This function recursively moves tensors, modules, lists, tuples (including
    namedtuples) and dictionaries to the specified device. For unsupported types,
    the object is returned as is.

    Args:
        obj: The object to be moved to the device. This can be a torch.Tensor,
            torch.nn.Module, list, tuple, or dict (possibly nested).
        device (torch.device): The target device to move the object to. This can be `None`.
        **kwargs: Additional keyword arguments forwarded to the `to` method of every
            torch.Tensor / torch.nn.Module found, at any nesting depth. For example,
            `non_blocking=True`, `dtype=torch.float16`.

    Returns:
        The object moved to the specified device. Lists and tuples are rebuilt
        (namedtuples keep their type). Dicts are updated in place and the same
        dict object is returned.

    Examples:
        >>> tensor = torch.tensor([1, 2, 3])
        >>> to_device(tensor, torch.device('cuda'))
        tensor([1, 2, 3], device='cuda:0')

        >>> model = torch.nn.Linear(2, 2)
        >>> to_device(model, torch.device('cuda'))
        Linear(..., device='cuda:0')

        >>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        >>> to_device(data, torch.device('cuda'))
        [tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
    """
    if isinstance(obj, (torch.Tensor, torch.nn.Module)):
        return obj.to(device, **kwargs)
    if isinstance(obj, list):
        # kwargs must be forwarded, otherwise e.g. dtype= is silently ignored
        # for anything nested inside a container.
        return [to_device(o, device, **kwargs) for o in obj]
    if isinstance(obj, tuple):
        moved = [to_device(o, device, **kwargs) for o in obj]
        if hasattr(obj, "_fields"):
            # namedtuple: rebuild with its own type so field access keeps working.
            return type(obj)(*moved)
        return tuple(moved)
    if isinstance(obj, dict):
        # Updated in place (historical behavior callers may rely on).
        for key in obj:
            obj[key] = to_device(obj[key], device, **kwargs)
        return obj
    # the default behavior is to return the object as is
    return obj

Dtype

fusion_bench.utils.dtype

get_dtype(obj)

Get the data type (dtype) of a given object.

Returns:

  • dtype

    torch.dtype: The data type of the given object.

Raises:

  • ValueError

    If the object type is not supported.

Source code in fusion_bench/utils/dtype.py
def get_dtype(obj) -> torch.dtype:
    """
    Get the data type (dtype) of a given object.

    Args:
        obj: One of:
            - a ``torch.Tensor``: its dtype is returned;
            - a ``torch.nn.Module``: its ``dtype`` attribute if it defines one,
              otherwise the dtype of its first parameter;
            - a ``torch.dtype``: returned unchanged;
            - a ``str``: parsed with ``parse_dtype`` (e.g. ``"bf16"``).

    Returns:
        torch.dtype: The data type of the given object.

    Raises:
        ValueError: If the object type is not supported, or if the object is a
            module without parameters (its dtype is undefined).
    """
    if isinstance(obj, torch.Tensor):
        return obj.dtype
    if isinstance(obj, torch.nn.Module):
        if hasattr(obj, "dtype"):
            return obj.dtype
        param = next(obj.parameters(), None)
        if param is None:
            # Previously this leaked a bare StopIteration to the caller.
            raise ValueError(
                f"Cannot determine the dtype of module {type(obj).__name__}: "
                "it has no parameters."
            )
        return param.dtype
    # The original checked `torch.device` here, which was a typo for
    # `torch.dtype`: a dtype was rejected and a device crashed in parse_dtype.
    if isinstance(obj, torch.dtype):
        return obj
    if isinstance(obj, str):
        return parse_dtype(obj)
    raise ValueError(f"Unsupported object type: {type(obj)}")

infer_optim_dtype(model_dtype)

Infers the optimal dtype according to the model_dtype and device compatibility.

Source code in fusion_bench/utils/dtype.py
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    r"""
    Choose a compute dtype from the model's dtype and what the hardware supports.

    - ``torch.bfloat16`` if the model is already bf16 and bf16 is supported
      (per ``is_torch_bf16_gpu_available()`` or an NPU reporting bf16 support);
    - otherwise ``torch.float16`` if an NPU or CUDA device is available;
    - otherwise ``torch.float32``.

    Any exception raised while probing bf16 support is treated as "bf16 unsupported".
    """
    fp16_supported = is_torch_npu_available() or is_torch_cuda_available()
    try:
        bf16_supported = is_torch_bf16_gpu_available() or (
            is_torch_npu_available() and torch.npu.is_bf16_supported()
        )
    except Exception:
        # Probing can fail on setups where the backend module is missing.
        bf16_supported = False

    if bf16_supported and model_dtype == torch.bfloat16:
        return torch.bfloat16
    return torch.float16 if fp16_supported else torch.float32

parse_dtype(dtype)

Parses a string representation of a data type and returns the corresponding torch.dtype.

Parameters:

  • dtype (Optional[str]) –

    The string representation of the data type. Can be one of "float32", "float", "float64", "double", "float16", "half", "bfloat16", or "bf16". If None, returns None.

Returns:

  • torch.dtype: The corresponding torch.dtype if the input is a valid string representation. If the input is already a torch.dtype, it is returned as is. If the input is None, returns None.

Raises:

  • ValueError

    If the input string does not correspond to a supported data type.

Source code in fusion_bench/utils/dtype.py
def parse_dtype(dtype: Optional[str]):
    """
    Parses a string representation of a data type and returns the corresponding torch.dtype.

    Args:
        dtype (Optional[Union[str, torch.dtype]]): The data type to parse.
            A ``torch.dtype`` is returned unchanged, and ``None`` yields ``None``.
            A string is normalized by stripping surrounding whitespace and
            single/double quotes, so values from YAML or CLI overrides like
            ``'"bf16"'`` work. A ``torch.`` prefix (``"torch.float16"``) is also
            accepted when the full spelling is not itself a known key. The
            normalized string is looked up in ``PRECISION_STR_TO_DTYPE``,
            e.g. "float32", "float", "float64", "double", "float16", "half",
            "bfloat16", or "bf16".

    Returns:
        torch.dtype: The corresponding torch.dtype if the input is a valid string representation.
                     If the input is already a torch.dtype, it is returned as is.
                     If the input is None, returns None.

    Raises:
        ValueError: If the input string does not correspond to a supported data type.
        TypeError: If the input is neither a str, a torch.dtype, nor None.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    if dtype is None:
        return None

    if not isinstance(dtype, str):
        raise TypeError(
            f"dtype must be a str, a torch.dtype or None, but got {type(dtype)}"
        )

    key = dtype.strip().strip("\"'").strip()
    # Accept the spelling produced by str(torch.float16) == "torch.float16".
    if key not in PRECISION_STR_TO_DTYPE and key.startswith("torch."):
        key = key[len("torch.") :]
    if key not in PRECISION_STR_TO_DTYPE:
        raise ValueError(f"Unsupported dtype string: {key}")

    return PRECISION_STR_TO_DTYPE[key]

set_default_dtype(dtype)

Context manager to set torch's default dtype.

Parameters:

  • dtype (dtype) –

    The desired default dtype inside the context manager.

Returns:

  • ContextManager ( None ) –

    context manager for setting default dtype.

Example

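>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1.0, 2.0, 3.0])
>>> x.dtype
torch.bfloat16

(The default dtype only applies to floating-point values; `torch.tensor([1, 2, 3])` would still be `torch.int64`.)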

Source code in fusion_bench/utils/dtype.py
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Temporarily change torch's default floating-point dtype.

    The previous default is restored when the ``with`` block exits, including
    when the block raises.

    Args:
        dtype (torch.dtype): Default dtype to use inside the ``with`` block.

    Returns:
        ContextManager: context manager for setting default dtype.

    Example:
        >>> with set_default_dtype(torch.bfloat16):
        ...     x = torch.tensor([1.0, 2.0, 3.0])
        >>> x.dtype
        torch.bfloat16

        Note that integer literals are unaffected: ``torch.tensor([1, 2, 3])``
        is always ``torch.int64``.
    """
    saved = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    with contextlib.ExitStack() as restore:
        # Registered only after the switch succeeded, like the finally-clause
        # it replaces; runs on normal exit and on exceptions alike.
        restore.callback(torch.set_default_dtype, saved)
        yield

validate_expected_param_dtype(named_params, dtype)

Validates that all input parameters have the expected dtype.

Parameters:

  • named_params (Iterable[Tuple[str, Parameter]]) –

    Iterable of named parameters.

  • dtype (dtype) –

    Expected dtype.

Raises:

  • ValueError

    If any parameter has a different dtype than dtype.

Source code in fusion_bench/utils/dtype.py
def validate_expected_param_dtype(
    named_params: Iterable[Tuple[str, torch.nn.Parameter]], dtype: torch.dtype
) -> None:
    """
    Ensure every parameter has exactly the expected dtype.

    Iteration stops at the first mismatching parameter.

    Args:
        named_params (Iterable[Tuple[str, torch.nn.Parameter]]): ``(name, parameter)``
            pairs, e.g. from ``module.named_parameters()``.
        dtype (torch.dtype): The dtype every parameter must have.

    Raises:
        ValueError: Naming the first parameter whose dtype differs from `dtype`.
    """
    offender = next(
        ((pname, p) for pname, p in named_params if p.dtype != dtype), None
    )
    if offender is None:
        return
    bad_name, bad_param = offender
    raise ValueError(
        f"Parameter {bad_name} has dtype {bad_param.dtype}, but expected {dtype}"
    )