Data Utilities

Dataset Manipulation

fusion_bench.utils.data

InfiniteDataLoader

A wrapper class for DataLoader that creates an infinite data loader. This is useful when training is driven by a number of steps rather than a number of epochs.

This class wraps a DataLoader and provides an iterator that resets when the end of the dataset is reached, creating an infinite loop.

Attributes:

  • data_loader (DataLoader) –

    The DataLoader to wrap.

  • _data_iter (iterator) –

    An iterator over the DataLoader.

  • _iteration_count (int) –

    Number of complete iterations through the dataset.

Example

>>> train_loader = DataLoader(dataset, batch_size=32)
>>> infinite_loader = InfiniteDataLoader(train_loader)
>>> for i, batch in enumerate(infinite_loader):
...     if i >= 1000:  # Train for 1000 steps
...         break
...     train_step(batch)

Source code in fusion_bench/utils/data.py
class InfiniteDataLoader:
    """
    A wrapper class for DataLoader to create an infinite data loader.
    This is useful in case we are only interested in the number of steps and not the number of epochs.

    This class wraps a DataLoader and provides an iterator that resets
    when the end of the dataset is reached, creating an infinite loop.

    Attributes:
        data_loader (DataLoader): The DataLoader to wrap.
        _data_iter (iterator): An iterator over the DataLoader.
        _iteration_count (int): Number of complete iterations through the dataset.

    Example:
        >>> train_loader = DataLoader(dataset, batch_size=32)
        >>> infinite_loader = InfiniteDataLoader(train_loader)
        >>> for i, batch in enumerate(infinite_loader):
        ...     if i >= 1000:  # Train for 1000 steps
        ...         break
        ...     train_step(batch)
    """

    def __init__(self, data_loader: DataLoader, max_retries: int = 1):
        """
        Initialize the InfiniteDataLoader.

        Args:
            data_loader: The DataLoader to wrap.
            max_retries: Maximum number of retry attempts when resetting the data loader (default: 1).

        Raises:
            ValidationError: If data_loader is None or not a DataLoader instance.
        """
        if data_loader is None:
            raise ValidationError(
                "data_loader cannot be None", field="data_loader", value=data_loader
            )

        self.data_loader = data_loader
        self.max_retries = max_retries
        self._data_iter = iter(data_loader)
        self._iteration_count = 0

    def __iter__(self):
        """Reset the iterator to the beginning."""
        self._data_iter = iter(self.data_loader)
        self._iteration_count = 0
        return self

    def __next__(self):
        """
        Get the next batch, resetting to the beginning when the dataset is exhausted.

        Returns:
            The next batch from the data loader.

        Raises:
            RuntimeError: If the data loader consistently fails to produce data.
        """
        last_exception = None
        for attempt in range(self.max_retries):
            try:
                data = next(self._data_iter)
                return data
            except StopIteration:
                # Dataset exhausted or dataloader is empty, reset to beginning
                self._iteration_count += 1
                try:
                    self._data_iter = iter(self.data_loader)
                    data = next(self._data_iter)
                    return data
                except Exception as e:
                    last_exception = e
                    continue
            except Exception as e:
                # Handle other potential errors from the data loader
                raise RuntimeError(
                    f"Error retrieving data from data loader: [{type(e).__name__}]{e}"
                ) from e

        # If we get here, all attempts failed
        raise RuntimeError(
            f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
            f"Last error: [{type(last_exception).__name__}]{last_exception}. "
            + (
                f"The data loader appears to be empty."
                if isinstance(last_exception, StopIteration)
                else ""
            )
        ) from last_exception

    def reset(self):
        """Manually reset the iterator to the beginning of the dataset."""
        self._data_iter = iter(self.data_loader)
        self._iteration_count = 0

    @property
    def iteration_count(self) -> int:
        """Get the number of complete iterations through the dataset."""
        return self._iteration_count

    def __len__(self) -> int:
        """
        Return the length of the underlying data loader.

        Returns:
            The number of batches in one complete iteration.
        """
        return len(self.data_loader)
iteration_count property

Get the number of complete iterations through the dataset.

__init__(data_loader, max_retries=1)

Initialize the InfiniteDataLoader.

Parameters:

  • data_loader (DataLoader) –

    The DataLoader to wrap.

  • max_retries (int, default: 1 ) –

    Maximum number of retry attempts when resetting the data loader (default: 1).

Raises:

  • ValidationError

    If data_loader is None or not a DataLoader instance.

Source code in fusion_bench/utils/data.py
def __init__(self, data_loader: DataLoader, max_retries: int = 1):
    """
    Initialize the InfiniteDataLoader.

    Args:
        data_loader: The DataLoader to wrap.
        max_retries: Maximum number of retry attempts when resetting the data loader (default: 1).

    Raises:
        ValidationError: If data_loader is None or not a DataLoader instance.
    """
    if data_loader is None:
        raise ValidationError(
            "data_loader cannot be None", field="data_loader", value=data_loader
        )

    self.data_loader = data_loader
    self.max_retries = max_retries
    self._data_iter = iter(data_loader)
    self._iteration_count = 0
__iter__()

Reset the iterator to the beginning.

Source code in fusion_bench/utils/data.py
def __iter__(self):
    """Reset the iterator to the beginning."""
    self._data_iter = iter(self.data_loader)
    self._iteration_count = 0
    return self
__len__()

Return the length of the underlying data loader.

Returns:

  • int

    The number of batches in one complete iteration.

Source code in fusion_bench/utils/data.py
def __len__(self) -> int:
    """
    Return the length of the underlying data loader.

    Returns:
        The number of batches in one complete iteration.
    """
    return len(self.data_loader)
__next__()

Get the next batch, resetting to the beginning when the dataset is exhausted.

Returns:

  • The next batch from the data loader.

Raises:

  • RuntimeError

    If the data loader consistently fails to produce data.

Source code in fusion_bench/utils/data.py
def __next__(self):
    """
    Get the next batch, resetting to the beginning when the dataset is exhausted.

    Returns:
        The next batch from the data loader.

    Raises:
        RuntimeError: If the data loader consistently fails to produce data.
    """
    last_exception = None
    for attempt in range(self.max_retries):
        try:
            data = next(self._data_iter)
            return data
        except StopIteration:
            # Dataset exhausted or dataloader is empty, reset to beginning
            self._iteration_count += 1
            try:
                self._data_iter = iter(self.data_loader)
                data = next(self._data_iter)
                return data
            except Exception as e:
                last_exception = e
                continue
        except Exception as e:
            # Handle other potential errors from the data loader
            raise RuntimeError(
                f"Error retrieving data from data loader: [{type(e).__name__}]{e}"
            ) from e

    # If we get here, all attempts failed
    raise RuntimeError(
        f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
        f"Last error: [{type(last_exception).__name__}]{last_exception}. "
        + (
            f"The data loader appears to be empty."
            if isinstance(last_exception, StopIteration)
            else ""
        )
    ) from last_exception
reset()

Manually reset the iterator to the beginning of the dataset.

Source code in fusion_bench/utils/data.py
def reset(self):
    """Manually reset the iterator to the beginning of the dataset."""
    self._data_iter = iter(self.data_loader)
    self._iteration_count = 0
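
A minimal sketch of how iteration_count tracks restarts of the wrapped loader and how reset() clears it (the toy dataset and batch size are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

from fusion_bench.utils.data import InfiniteDataLoader

# Two batches per pass over the dataset.
dataset = TensorDataset(torch.arange(8))
loader = InfiniteDataLoader(DataLoader(dataset, batch_size=4))

for _ in range(5):
    batch = next(loader)

print(loader.iteration_count)  # 2: the underlying loader was restarted twice
loader.reset()
print(loader.iteration_count)  # 0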

load_tensor_from_file(file_path, device=None)

Loads a tensor from a file, which can be a .pt, .pth, or .np file. If the file is in none of these formats, it falls back to loading it as a pickle file.

Parameters:

  • file_path (str) –

    The path to the file to load.

  • device (Optional[Union[str, device]], default: None ) –

    The device to move the tensor to. By default the tensor is loaded on the CPU.

Returns:

  • Tensor

    torch.Tensor: The tensor loaded from the file.

Raises:

  • ValidationError

    If the file doesn't exist

  • ValueError

    If the file format is unsupported

Source code in fusion_bench/utils/data.py
def load_tensor_from_file(
    file_path: Union[str, Path], device: Optional[Union[str, torch.device]] = None
) -> torch.Tensor:
    """
    Loads a tensor from a file, which can be either a .pt, .pth or .np file.
    If the file is not one of these formats, it will try to load it as a pickle file.

    Args:
        file_path (str): The path to the file to load.
        device: The device to move the tensor to. By default the tensor is loaded on the CPU.

    Returns:
        torch.Tensor: The tensor loaded from the file.

    Raises:
        ValidationError: If the file doesn't exist
        ValueError: If the file format is unsupported
    """
    # Validate file exists
    validate_file_exists(file_path)

    if file_path.endswith(".np"):
        tensor = torch.from_numpy(np.load(file_path)).detach_()
    if file_path.endswith((".pt", ".pth")):
        tensor = torch.load(file_path, map_location="cpu").detach_()
    else:
        try:
            tensor = pickle.load(open(file_path, "rb"))
        except Exception:
            raise ValueError(f"Unsupported file format: {file_path}")

    # Move tensor to device
    assert isinstance(tensor, torch.Tensor), f"Expected tensor, got {type(tensor)}"
    if device is not None:
        tensor = tensor.to(device=device)
    return tensor
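
A short usage sketch, assuming a saved .pt checkpoint containing a single tensor (the file path is hypothetical):

import torch

from fusion_bench.utils.data import load_tensor_from_file

# Load onto the GPU if one is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
task_vector = load_tensor_from_file("task_vector.pt", device=device)
print(task_vector.shape, task_vector.device)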

train_validation_split(dataset, validation_fraction=0.1, validation_size=None, random_seed=None, return_split='all')

Split a dataset into a training and validation set.

Parameters:

  • dataset (Dataset) –

    The dataset to split.

  • validation_fraction (Optional[float], default: 0.1 ) –

    The fraction of the dataset to use for validation.

  • validation_size (Optional[int], default: None ) –

    The number of samples to use for validation. validation_fraction must be set to None if this is provided.

  • random_seed (Optional[int], default: None ) –

    The random seed to use for reproducibility.

  • return_split (Literal['all', 'train', 'val'], default: 'all' ) –

    The split to return: 'all' returns both splits; 'train' or 'val' returns only that subset.

Returns:

  • Union[Tuple[Dataset, Dataset], Dataset]

    Tuple[Dataset, Dataset]: The training and validation datasets when return_split='all'; otherwise the single requested split.

Source code in fusion_bench/utils/data.py
def train_validation_split(
    dataset: Dataset,
    validation_fraction: Optional[float] = 0.1,
    validation_size: Optional[int] = None,
    random_seed: Optional[int] = None,
    return_split: Literal["all", "train", "val"] = "all",
) -> Union[Tuple[Dataset, Dataset], Dataset]:
    """
    Split a dataset into a training and validation set.

    Args:
        dataset (Dataset): The dataset to split.
        validation_fraction (Optional[float]): The fraction of the dataset to use for validation.
        validation_size (Optional[int]): The number of samples to use for validation. `validation_fraction` must be set to `None` if this is provided.
        random_seed (Optional[int]): The random seed to use for reproducibility.
        return_split (Literal["all", "train", "val"]): The split to return.

    Returns:
        Tuple[Dataset, Dataset]: The training and validation datasets.
    """
    # Check the input arguments
    assert (
        validation_fraction is None or validation_size is None
    ), "Only one of validation_fraction and validation_size can be provided"
    assert (
        validation_fraction is not None or validation_size is not None
    ), "Either validation_fraction or validation_size must be provided"

    # Compute the number of samples for training and validation
    num_samples = len(dataset)
    if validation_size is None:
        assert (
            0 < validation_fraction < 1
        ), "Validation fraction must be between 0 and 1"
        num_validation_samples = int(num_samples * validation_fraction)
        num_training_samples = num_samples - num_validation_samples
    else:
        assert (
            validation_size < num_samples
        ), "Validation size must be less than num_samples"
        num_validation_samples = validation_size
        num_training_samples = num_samples - num_validation_samples

    # Split the dataset
    generator = (
        torch.Generator().manual_seed(random_seed) if random_seed is not None else None
    )
    training_dataset, validation_dataset = torch.utils.data.random_split(
        dataset, [num_training_samples, num_validation_samples], generator=generator
    )

    # return the split as requested
    if return_split == "all":
        return training_dataset, validation_dataset
    elif return_split == "train":
        return training_dataset
    elif return_split == "val":
        return validation_dataset
    else:
        raise ValueError(f"Invalid return_split: {return_split}")

train_validation_test_split(dataset, validation_fraction, test_fraction, random_seed=None, return_spilt='all')

Split a dataset into a training, validation and test set.

Parameters:

  • dataset (Dataset) –

    The dataset to split.

  • validation_fraction (float) –

    The fraction of the dataset to use for validation.

  • test_fraction (float) –

    The fraction of the dataset to use for test.

  • random_seed (Optional[int], default: None ) –

    The random seed to use for reproducibility.

  • return_spilt (Literal['all', 'train', 'val', 'test'], default: 'all' ) –

    The split to return.

Returns:

  • Union[Tuple[Dataset, Dataset, Dataset], Dataset]

    Tuple[Dataset, Dataset, Dataset]: The training, validation and test datasets.

Source code in fusion_bench/utils/data.py
def train_validation_test_split(
    dataset: Dataset,
    validation_fraction: float,
    test_fraction: float,
    random_seed: Optional[int] = None,
    return_spilt: Literal["all", "train", "val", "test"] = "all",
) -> Union[Tuple[Dataset, Dataset, Dataset], Dataset]:
    """
    Split a dataset into a training, validation and test set.

    Args:
        dataset (Dataset): The dataset to split.
        validation_fraction (float): The fraction of the dataset to use for validation.
        test_fraction (float): The fraction of the dataset to use for test.
        random_seed (Optional[int]): The random seed to use for reproducibility.
        return_spilt (Literal["all", "train", "val", "test"]): The split to return.

    Returns:
        Tuple[Dataset, Dataset, Dataset]: The training, validation and test datasets.
    """
    num_samples = len(dataset)
    assert 0 < validation_fraction < 1, "Validation fraction must be between 0 and 1"
    assert 0 < test_fraction < 1, "Test fraction must be between 0 and 1"
    generator = (
        torch.Generator().manual_seed(random_seed) if random_seed is not None else None
    )

    num_validation_samples = int(num_samples * validation_fraction)
    num_test_samples = int(num_samples * test_fraction)
    num_training_samples = num_samples - num_validation_samples - num_test_samples
    training_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [num_training_samples, num_validation_samples, num_test_samples],
        generator=generator,
    )

    # return the split as requested
    if return_spilt == "all":
        return training_dataset, validation_dataset, test_dataset
    elif return_spilt == "train":
        return training_dataset
    elif return_spilt == "val":
        return validation_dataset
    elif return_spilt == "test":
        return test_dataset
    else:
        raise ValueError(f"Invalid return_split: {return_spilt}")

Json Import/Export

fusion_bench.utils.json

load_from_json(path, filesystem=None)

Load an object from a JSON file.

Parameters:

  • path (Union[str, Path]) –

    the path to load the object

  • filesystem (FileSystem, default: None ) –

    PyArrow FileSystem to use for reading. If None, uses local filesystem via standard Python open(). Can also be an s3fs.S3FileSystem or fsspec filesystem.

Returns:

  • Union[dict, list]

    Union[dict, list]: the loaded object

Raises:

  • ValidationError

    If the file doesn't exist (when using local filesystem)

Source code in fusion_bench/utils/json.py
def load_from_json(
    path: Union[str, Path], filesystem: "FileSystem" = None
) -> Union[dict, list]:
    """load an object from a json file

    Args:
        path (Union[str, Path]): the path to load the object
        filesystem (FileSystem, optional): PyArrow FileSystem to use for reading.
            If None, uses local filesystem via standard Python open().
            Can also be an s3fs.S3FileSystem or fsspec filesystem.

    Returns:
        Union[dict, list]: the loaded object

    Raises:
        ValidationError: If the file doesn't exist (when using local filesystem)
    """
    if filesystem is not None:
        # Check if it's an fsspec-based filesystem (like s3fs)
        if hasattr(filesystem, "open"):
            # Direct fsspec/s3fs usage
            path_str = str(path)
            with filesystem.open(path_str, "r") as f:
                return json.load(f)
        else:
            # Use PyArrow filesystem
            path_str = str(path)
            with filesystem.open_input_stream(path_str) as f:
                json_data = f.read().decode("utf-8")
                return json.loads(json_data)
    else:
        # Use standard Python file operations
        validate_file_exists(path)
        with open(path, "r") as f:
            return json.load(f)
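
A usage sketch. The local path is illustrative; the remote variant assumes the optional s3fs package and a hypothetical bucket name:

import s3fs

from fusion_bench.utils.json import load_from_json

# Local file via the standard filesystem.
report = load_from_json("outputs/report.json")

# Remote file via an fsspec filesystem such as s3fs.
fs = s3fs.S3FileSystem()
remote_report = load_from_json("my-bucket/experiments/report.json", filesystem=fs)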

print_json(j, indent='  ', verbose=False, print_type=True)

Print an overview of a loaded JSON object.

Examples:

>>> print_json(json.load(open('path_to_json', 'r')))

Parameters:

  • j (dict) –

    The loaded JSON object.

  • indent (str, default: '  ' ) –

    The indentation string for each nesting level. Defaults to two spaces.

  • verbose (bool, default: False ) –

    If True, print every element of a list of dictionaries; otherwise print only the first element followed by a count of the remaining ones.

  • print_type (bool, default: True ) –

    If True, print each value with a summary that includes its type; otherwise print the raw value.

Source code in fusion_bench/utils/json.py
def print_json(j: dict, indent="  ", verbose: bool = False, print_type: bool = True):
    R"""print an overview of json file

    Examples:
        >>> print_json(json.load(open('path_to_json', 'r')))

    Args:
        j (dict): The loaded JSON object.
        indent (str, optional): Indentation string per nesting level. Defaults to "  ".
        verbose (bool, optional): If True, print every element of a list of dicts; otherwise only the first. Defaults to False.
        print_type (bool, optional): If True, print each value with a summary that includes its type. Defaults to True.
    """

    def _print_json(j: dict, level):
        def _sprint(s):
            return indent * level + s

        for k in j.keys():
            if isinstance(j[k], dict):
                print(_sprint(k) + ":")
                _print_json(j[k], level + 1)
            elif _is_list_of_dict(j[k]):
                if verbose:
                    print(_sprint(k) + ": [")
                    for i in range(len(j[k]) - 1):
                        _print_json(j[k][i], level + 2)
                        print(_sprint(f"{indent},"))
                    _print_json(j[k][-1], level + 2)
                    print(_sprint(f"{indent}]"))
                else:
                    print(_sprint(k) + ": [")
                    _print_json(j[k][0], level + 2)
                    print(_sprint(f"{indent}] ... {len(j[k]) - 1} more"))
            else:
                if print_type:
                    print(f"{_sprint(k)}: {_sprint_json_entry(j[k])}")
                else:
                    print(f"{_sprint(k)}: {j[k]}")

    _print_json(j, level=0)
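
A quick sketch showing the compact overview for a nested dictionary (the contents are illustrative):

from fusion_bench.utils.json import print_json

report = {
    "model": {"name": "clip-vit-base-patch32", "num_params": 151277313},
    "results": [
        {"task": "cifar10", "accuracy": 0.95},
        {"task": "mnist", "accuracy": 0.99},
    ],
}
# With verbose=False (the default), only the first entry of the "results"
# list is expanded, followed by a count of the remaining entries.
print_json(report, print_type=False)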

save_to_json(obj, path, filesystem=None)

Save an object to a JSON file.

Parameters:

  • obj (Any) –

    the object to save

  • path (Union[str, Path]) –

    the path to save the object

  • filesystem (FileSystem, default: None ) –

    PyArrow FileSystem to use for writing. If None, uses local filesystem via standard Python open(). Can also be an s3fs.S3FileSystem or fsspec filesystem.

Source code in fusion_bench/utils/json.py
def save_to_json(obj, path: Union[str, Path], filesystem: "FileSystem" = None):
    """
    save an object to a json file

    Args:
        obj (Any): the object to save
        path (Union[str, Path]): the path to save the object
        filesystem (FileSystem, optional): PyArrow FileSystem to use for writing.
            If None, uses local filesystem via standard Python open().
            Can also be an s3fs.S3FileSystem or fsspec filesystem.
    """
    if filesystem is not None:
        json_str = json.dumps(obj)
        # Check if it's an fsspec-based filesystem (like s3fs)
        if hasattr(filesystem, "open"):
            # Direct fsspec/s3fs usage - more reliable for some endpoints
            path_str = str(path)
            with filesystem.open(path_str, "w") as f:
                f.write(json_str)
        else:
            # Use PyArrow filesystem
            path_str = str(path)
            with filesystem.open_output_stream(path_str) as f:
                f.write(json_str.encode("utf-8"))
    else:
        # Use standard Python file operations
        with open(path, "w") as f:
            json.dump(obj, f)
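
A round-trip sketch through a local JSON file (the file name is illustrative):

from fusion_bench.utils.json import load_from_json, save_to_json

results = {"method": "task_arithmetic", "accuracy": 0.87}
save_to_json(results, "results.json")
assert load_from_json("results.json") == results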

TensorBoard Data Import

fusion_bench.utils.tensorboard

Functions for working with TensorBoard logs.

parse_tensorboard_as_dict(path, scalars)

Returns a dictionary of pandas DataFrames, one for each requested scalar.

Parameters:

  • path (str) –

    A file path to a directory containing tf events files, or a single tf events file. The accumulator will load events from this path.

  • scalars (Iterable[str]) –

    The scalar tags to extract from the event files.

Returns:

  • Dict[str, pandas.DataFrame]: a dictionary of pandas dataframes for each requested scalar

Source code in fusion_bench/utils/tensorboard.py
def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):
    """
    returns a dictionary of pandas dataframes for each requested scalar.

    Args:
        path(str): A file path to a directory containing tf events files, or a single
                   tf events file. The accumulator will load events from this path.
        scalars:   The scalar tags to extract from the event files.

    Returns:
        Dict[str, pandas.DataFrame]: a dictionary of pandas dataframes for each requested scalar
    """
    ea = event_accumulator.EventAccumulator(
        path,
        size_guidance={event_accumulator.SCALARS: 0},
    )
    _absorb_print = ea.Reload()
    # make sure the scalars are in the event accumulator tags
    assert all(
        s in ea.Tags()["scalars"] for s in scalars
    ), "some scalars were not found in the event accumulator"
    return {k: pd.DataFrame(ea.Scalars(k)) for k in scalars}
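
A usage sketch; the log directory and scalar tags are hypothetical and must exist in the event files (the function asserts this):

from fusion_bench.utils.tensorboard import parse_tensorboard_as_dict

dfs = parse_tensorboard_as_dict("logs/version_0", ["train/loss", "val/accuracy"])
loss_df = dfs["train/loss"]  # columns: wall_time, step, value
print(loss_df[["step", "value"]].head())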

parse_tensorboard_as_list(path, scalars)

Returns a list of pandas DataFrames, one for each requested scalar.

See also: parse_tensorboard_as_dict.

Parameters:

  • path (str) –

    A file path to a directory containing tf events files, or a single tf events file. The accumulator will load events from this path.

  • scalars (Iterable[str]) –

    The scalar tags to extract from the event files.

Returns:

  • List[pandas.DataFrame]: a list of pandas dataframes for each requested scalar.

Source code in fusion_bench/utils/tensorboard.py
def parse_tensorboard_as_list(path: str, scalars: Iterable[str]):
    """
    returns a list of pandas dataframes for each requested scalar.

    see also: :py:func:`parse_tensorboard_as_dict`

    Args:
        path(str): A file path to a directory containing tf events files, or a single
                   tf events file. The accumulator will load events from this path.
        scalars:   The scalar tags to extract from the event files.

    Returns:
        List[pandas.DataFrame]: a list of pandas dataframes for each requested scalar.
    """
    d = parse_tensorboard_as_dict(path, scalars)
    return [d[s] for s in scalars]
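
A usage sketch; the DataFrames come back in the same order as the requested tags (path and tags are hypothetical):

from fusion_bench.utils.tensorboard import parse_tensorboard_as_list

loss_df, acc_df = parse_tensorboard_as_list("logs/version_0", ["train/loss", "val/accuracy"])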