PyTorch Utilities¶
Device Management¶
fusion_bench.utils.devices¶
cleanup_cuda()¶
Calls gc.collect(), empties the CUDA cache, and resets peak CUDA memory statistics.
clear_cuda_cache()¶
Clears the CUDA memory cache to free up GPU memory. Works only if CUDA is available.
Source code in fusion_bench/utils/devices.py
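The documented behavior of these two helpers can be sketched as follows. This is a minimal, hypothetical re-implementation for illustration only; the actual source lives in fusion_bench/utils/devices.py:

```python
import gc

import torch


def cleanup_cuda_sketch() -> None:
    # Hypothetical sketch of cleanup_cuda(): run the garbage collector,
    # then release cached CUDA memory and reset peak-memory statistics.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


def clear_cuda_cache_sketch() -> None:
    # Hypothetical sketch of clear_cuda_cache(): a no-op unless CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Both helpers are safe to call on CPU-only machines, since the CUDA calls are guarded by an availability check.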
get_current_device()¶
Gets the current available device for PyTorch operations. This is used for distributed training.
This function checks the availability of device types in the following order:

1. XPU (Intel's AI accelerator)
2. NPU (Neural Processing Unit)
3. MPS (Metal Performance Shaders, for Apple devices)
4. CUDA (NVIDIA's GPU)
5. CPU (Central Processing Unit, used as a fallback)

The function returns the first available device found in the above order. If none of the specialized devices are available, it defaults to the CPU.
Returns:

- device (torch.device) – The current available device for PyTorch operations.
Environment Variables
LOCAL_RANK: This environment variable is used to specify the device index for multi-device setups. If not set, it defaults to "0".
Example:
>>> device = get_current_device()
>>> print(device)
xpu:0 # or npu:0, mps:0, cuda:0, cpu depending on availability
Source code in fusion_bench/utils/devices.py
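The documented fallback order can be sketched as below. This is an illustrative re-implementation, not the library's actual source; attribute guards are used so it also runs on builds of PyTorch without XPU or NPU support:

```python
import os

import torch


def get_current_device_sketch() -> torch.device:
    # Hypothetical sketch of the documented probe order:
    # XPU -> NPU -> MPS -> CUDA -> CPU.
    # LOCAL_RANK selects the device index in multi-device setups.
    index = int(os.environ.get("LOCAL_RANK", "0"))
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu", index)
    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.device("npu", index)
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps", index)
    if torch.cuda.is_available():
        return torch.device("cuda", index)
    return torch.device("cpu")
```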
get_device(obj)¶
Get the device of a given object.
Parameters:

- obj (Any) – The object whose device is to be determined.

Returns:

- device (torch.device) – The device of the given object.

Raises:

- ValueError – If the object type is not supported.
Source code in fusion_bench/utils/devices.py
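A plausible sketch of this lookup, assuming the supported types are tensors and modules (the documentation does not enumerate them; the real implementation may handle more):

```python
import torch


def get_device_sketch(obj) -> torch.device:
    # Hypothetical sketch: read the device off a tensor directly, and off a
    # module via its first parameter; anything else raises ValueError.
    if isinstance(obj, torch.Tensor):
        return obj.device
    if isinstance(obj, torch.nn.Module):
        return next(obj.parameters()).device
    raise ValueError(f"Unsupported object type: {type(obj)}")
```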
get_device_capabilities(device)¶
Get capabilities information for a given device.
Parameters:

- device (torch.device) – The device for which to get capabilities information.

Returns:

- dict – A dictionary containing capabilities information for the given device.
Source code in fusion_bench/utils/devices.py
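The documentation does not specify which keys the returned dictionary contains. A minimal sketch, assuming it wraps torch.cuda.get_device_capability for CUDA devices and falls back to the device type otherwise:

```python
import torch


def get_device_capabilities_sketch(device: torch.device) -> dict:
    # Hypothetical sketch; the actual keys returned by the library may differ.
    if device.type == "cuda" and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device)
        return {"device_type": "cuda", "compute_capability": (major, minor)}
    return {"device_type": device.type}
```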
get_device_memory_info(device, reset_stats=True)¶
Get memory information for a given device.
Parameters:

- device (torch.device) – The device for which to get memory information.

Returns:

- dict – A dictionary containing memory information for the given device.
Source code in fusion_bench/utils/devices.py
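The exact dictionary keys are not documented. A sketch built on PyTorch's CUDA memory queries, assuming reset_stats clears the peak counters after they are read (the signature suggests this, but it is an assumption):

```python
import torch


def get_device_memory_info_sketch(device: torch.device, reset_stats: bool = True) -> dict:
    # Hypothetical sketch; key names are illustrative, not the library's.
    if device.type != "cuda" or not torch.cuda.is_available():
        return {}
    info = {
        "allocated_bytes": torch.cuda.memory_allocated(device),
        "reserved_bytes": torch.cuda.memory_reserved(device),
        "peak_allocated_bytes": torch.cuda.max_memory_allocated(device),
    }
    if reset_stats:
        # Assumed behavior: reset peak statistics after reading them.
        torch.cuda.reset_peak_memory_stats(device)
    return info
```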
num_devices(devices)¶
Return the number of devices.
Parameters:

- devices (Union[int, List[int], str]) – A single int specifying the number of devices, a list of device ids (e.g. [0, 1, 2, 3]), or a string of device ids (e.g. "0,1,2,3" or "[0, 1, 2]").

Returns:

- int – The number of devices.
Source code in fusion_bench/utils/devices.py
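The three accepted input forms can be handled as sketched below. This is an illustrative re-implementation of the documented contract, not the library's source:

```python
from typing import List, Union


def num_devices_sketch(devices: Union[int, List[int], str]) -> int:
    # Hypothetical sketch of the documented behavior:
    # an int is the count itself, a list is counted by length,
    # and a string of ids is split on commas (brackets are tolerated).
    if isinstance(devices, int):
        return devices
    if isinstance(devices, (list, tuple)):
        return len(devices)
    if isinstance(devices, str):
        ids = devices.strip().strip("[]")
        return len([d for d in ids.split(",") if d.strip()])
    raise ValueError(f"Unsupported type: {type(devices)}")
```

For example, num_devices_sketch("0,1,2,3") and num_devices_sketch([0, 1, 2, 3]) both return 4.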
print_memory_usage(print_fn=print)¶
Print the current GPU memory usage.
Returns:

- str – A string containing the allocated and cached memory in MB.
Source code in fusion_bench/utils/devices.py
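A sketch of the documented behavior, assuming "cached" refers to memory reserved by the CUDA caching allocator (torch.cuda.memory_reserved) and that 0 MB is reported when CUDA is unavailable:

```python
import torch


def print_memory_usage_sketch(print_fn=print) -> str:
    # Hypothetical sketch; the real message format may differ.
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        cached = torch.cuda.memory_reserved() / 1024**2
    else:
        allocated = cached = 0.0
    msg = f"Allocated: {allocated:.2f} MB, Cached: {cached:.2f} MB"
    print_fn(msg)
    return msg
```

Passing a custom print_fn (e.g. a logger method) redirects the message without changing the returned string.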
to_device(obj, device, copy_on_move=False, **kwargs)¶
Move a given object to the specified device.
This function recursively moves tensors, modules, lists, tuples, and dictionaries to the specified device. For unsupported types, the object is returned as is.
Parameters:

- obj (T) – The object to be moved to the device. This can be a torch.Tensor, torch.nn.Module, list, tuple, or dict.
- device (torch.device) – The target device to move the object to. This can be None.
- copy_on_move (bool, default: False) – Whether to force a copy operation when moving tensors to a different device. If True, tensors are copied when moved (copy=True is passed to tensor.to()). If False (the default), tensors are moved without forcing a copy, allowing PyTorch to optimize the operation. This parameter only affects torch.Tensor objects; modules and other types are unaffected.
- **kwargs (Any) – Additional keyword arguments passed to the to() method of torch.Tensor or torch.nn.Module, for example non_blocking=True or dtype=torch.float16. Note that if copy_on_move=True, the copy keyword argument is set automatically and should not be provided manually.

Returns:

- T – The object moved to the specified device. The type of the returned object matches the type of the input object.
Examples:
>>> tensor = torch.tensor([1, 2, 3])
>>> to_device(tensor, torch.device('cuda'))
tensor([1, 2, 3], device='cuda:0')
>>> model = torch.nn.Linear(2, 2)
>>> to_device(model, torch.device('cuda'))
Linear(..., device='cuda:0')
>>> data = [torch.tensor([1, 2]), torch.tensor([3, 4])]
>>> to_device(data, torch.device('cuda'))
[tensor([1, 2], device='cuda:0'), tensor([3, 4], device='cuda:0')]
>>> # Force copy when moving to different device
>>> tensor = torch.tensor([1, 2, 3], device='cpu')
>>> copied_tensor = to_device(tensor, torch.device('cuda'), copy_on_move=True)
>>> # tensor and copied_tensor will have different memory locations
Source code in fusion_bench/utils/devices.py
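The recursive dispatch described above can be sketched as follows. This is an illustrative re-implementation of the documented contract, not the library's actual source:

```python
import torch


def to_device_sketch(obj, device, copy_on_move: bool = False, **kwargs):
    # Hypothetical sketch: tensors and modules are moved directly,
    # containers are traversed recursively, everything else is returned as-is.
    if isinstance(obj, torch.Tensor):
        if copy_on_move:
            # copy_on_move only affects tensors, per the documentation.
            kwargs.setdefault("copy", True)
        return obj.to(device, **kwargs)
    if isinstance(obj, torch.nn.Module):
        return obj.to(device, **kwargs)
    if isinstance(obj, list):
        return [to_device_sketch(x, device, copy_on_move, **kwargs) for x in obj]
    if isinstance(obj, tuple):
        return tuple(to_device_sketch(x, device, copy_on_move, **kwargs) for x in obj)
    if isinstance(obj, dict):
        return {k: to_device_sketch(v, device, copy_on_move, **kwargs) for k, v in obj.items()}
    return obj  # unsupported types pass through unchanged
```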
Dtype¶
fusion_bench.utils.dtype¶
get_dtype(obj)¶
Get the data type (dtype) of a given object.
Returns:

- dtype (torch.dtype) – The data type of the given object.

Raises:

- ValueError – If the object type is not supported.
Source code in fusion_bench/utils/dtype.py
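By analogy with get_device, a plausible sketch, assuming the supported types are tensors and modules (an assumption; the documentation does not enumerate them):

```python
import torch


def get_dtype_sketch(obj) -> torch.dtype:
    # Hypothetical sketch: read the dtype off a tensor directly, and off a
    # module via its first parameter; other types raise ValueError.
    if isinstance(obj, torch.Tensor):
        return obj.dtype
    if isinstance(obj, torch.nn.Module):
        return next(obj.parameters()).dtype
    raise ValueError(f"Unsupported object type: {type(obj)}")
```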
infer_optim_dtype(model_dtype)¶
Infers the optimal dtype according to the model_dtype and device compatibility.
Source code in fusion_bench/utils/dtype.py
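The inference rule is not spelled out in the documentation. One common policy, sketched here purely as an illustration of what "optimal dtype according to device compatibility" might mean: prefer bfloat16 when the hardware supports it, fall back to float16 on CUDA, and otherwise use float32:

```python
import torch


def infer_optim_dtype_sketch(model_dtype: torch.dtype) -> torch.dtype:
    # Hypothetical policy, not the library's actual rule:
    # keep bfloat16 only on bf16-capable CUDA hardware,
    # downgrade half-precision requests to float16 on other CUDA devices,
    # and fall back to float32 everywhere else.
    if (
        model_dtype == torch.bfloat16
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
    ):
        return torch.bfloat16
    if model_dtype in (torch.bfloat16, torch.float16) and torch.cuda.is_available():
        return torch.float16
    return torch.float32
```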
parse_dtype(dtype)¶
Parses a string representation of a data type and returns the corresponding torch.dtype.
Parameters:

- dtype (Optional[str]) – The string representation of the data type. Can be one of "float32", "float", "float64", "double", "float16", "half", "bfloat16", or "bf16". If None, returns None.

Returns:

- Optional[torch.dtype] – The corresponding torch.dtype if the input is a valid string representation. If the input is already a torch.dtype, it is returned as is. If the input is None, returns None.

Raises:

- ValueError – If the input string does not correspond to a supported data type.
Source code in fusion_bench/utils/dtype.py
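The documented alias table maps directly onto a dictionary lookup. An illustrative sketch of the contract (not the library's source):

```python
from typing import Optional, Union

import torch

# Alias table taken from the documented set of accepted strings.
_DTYPE_ALIASES = {
    "float32": torch.float32, "float": torch.float32,
    "float64": torch.float64, "double": torch.float64,
    "float16": torch.float16, "half": torch.float16,
    "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
}


def parse_dtype_sketch(dtype: Optional[Union[str, torch.dtype]]) -> Optional[torch.dtype]:
    # Hypothetical sketch: pass None and torch.dtype through,
    # resolve known strings, reject everything else.
    if dtype is None or isinstance(dtype, torch.dtype):
        return dtype
    try:
        return _DTYPE_ALIASES[dtype]
    except KeyError:
        raise ValueError(f"Unsupported dtype string: {dtype!r}")
```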
set_default_dtype(dtype)¶
Context manager to set torch's default dtype.
Parameters:

- dtype (torch.dtype) – The desired default dtype inside the context manager.

Returns:

- ContextManager – Context manager for setting the default dtype.
Example:
>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1.0, 2.0, 3.0])
...     x.dtype
torch.bfloat16
Source code in fusion_bench/utils/dtype.py
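Such a context manager is typically a save/set/restore wrapper around torch.set_default_dtype. A minimal sketch (illustrative, not the library's source):

```python
import contextlib

import torch


@contextlib.contextmanager
def set_default_dtype_sketch(dtype: torch.dtype):
    # Hypothetical sketch: save the old default, install the new one,
    # and restore the old default even if the body raises.
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)
```

The try/finally ensures the previous default is restored even when an exception escapes the with block.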
validate_expected_param_dtype(named_params, dtype)¶
Validates that all input parameters have the expected dtype.
Parameters:

- named_params (Iterable[Tuple[str, Parameter]]) – Iterable of named parameters.
- dtype (torch.dtype) – Expected dtype.

Raises:

- ValueError – If any parameter has a dtype different from dtype.
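The documented check reduces to a loop over the named parameters. An illustrative sketch (not the library's source; the error message format is assumed):

```python
import torch


def validate_expected_param_dtype_sketch(named_params, dtype: torch.dtype) -> None:
    # Hypothetical sketch: fail fast on the first parameter whose dtype
    # does not match the expected one.
    for name, param in named_params:
        if param.dtype != dtype:
            raise ValueError(
                f"Parameter {name} has dtype {param.dtype}, expected {dtype}"
            )
```

A typical use is validate_expected_param_dtype_sketch(model.named_parameters(), torch.bfloat16) after casting a model for mixed-precision training.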