fusion_bench.dataset

NYUv2 Dataset

fusion_bench.dataset.nyuv2.NYUv2

Bases: Dataset

NYUv2 dataset: 3 tasks plus 1 generated noise task.

Included tasks:

1. Semantic Segmentation,
2. Depth prediction,
3. Surface Normal prediction,
4. Noise prediction [a deliberately useless task, to test auxiliary learning under purely conflicting gradients]

Modified from https://github.com/lorenmt/auto-lambda/blob/main/create_dataset.py

Compared with the original implementation, the augmentation argument has been removed and transform arguments have been added.
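
Example:

A minimal usage sketch, assuming the pre-processed .npy files from the auto-lambda repository are available; the root directory data/nyuv2 and the DataLoader settings are placeholders:

from torch.utils.data import DataLoader

from fusion_bench.dataset.nyuv2 import NYUv2

# "data/nyuv2" is a placeholder path; it must contain train/ and val/ folders
# with image/, label/, depth/ and normal/ .npy files.
train_set = NYUv2(root="data/nyuv2", train=True)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

image, targets = train_set[0]
# image: float tensor of shape (C, H, W)
# targets: dict with keys "segmentation", "depth", "normal" and "noise"
print(image.shape, {name: t.shape for name, t in targets.items()})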

Source code in fusion_bench/dataset/nyuv2.py
class NYUv2(Dataset):
    R"""
    NYUv2 dataset, 3 tasks + 1 generated useless task
    Included tasks:

        1. Semantic Segmentation,
        2. Depth prediction,
        3. Surface Normal prediction,
        4. Noise prediction [to test auxiliary learning, purely conflicting gradients]

    Modified from https://github.com/lorenmt/auto-lambda/blob/main/create_dataset.py

    removed the `augmentation` arg and added `transform` args
    """

    num_out_channels = {
        "segmentation": 13,
        "depth": 1,
        "normal": 3,
        "noise": 1,
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        transform: Optional[Callable] = None,
        seg_transform: Optional[Callable] = None,
        sn_transform: Optional[Callable] = None,
        depth_transform: Optional[Callable] = None,
    ):
        """
        Initialize the NYUv2 dataset.

        Args:
            root (str): The root directory of the dataset.
            train (bool, optional): If True, use training set. If False, use validation set. Defaults to True.
            transform (Callable, optional): image transform. Defaults to None.
            seg_transform (Callable, optional): segmentation transform. Defaults to None.
            sn_transform (Callable, optional): surface normal transform. Defaults to None.
            depth_transform (Callable, optional): depth transform. Defaults to None.
        """
        self.root = os.path.expanduser(root)
        self.train = train

        self.transform = transform
        self.seg_transform = seg_transform
        self.sn_transform = sn_transform
        self.depth_transform = depth_transform

        if train:
            self.data_path = self.root + "/train"
        else:
            self.data_path = self.root + "/val"

        # calculate data length
        self.data_len = len(
            fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy")
        )
        self.noise = torch.rand(self.data_len, 1, 288, 384)

    def __getitem__(self, index):
        """
        Retrieve an item from the dataset.

        Args:
            index (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the image and a dictionary of task-specific outputs.
        """
        # load data from the pre-processed npy files
        image = torch.from_numpy(
            np.moveaxis(
                np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0
            )
        ).float()
        semantic = torch.from_numpy(
            np.load(self.data_path + "/label/{:d}.npy".format(index))
        ).float()
        depth = torch.from_numpy(
            np.moveaxis(
                np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0
            )
        ).float()
        normal = torch.from_numpy(
            np.moveaxis(
                np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0
            )
        ).float()
        noise = self.noise[index].float()

        if self.transform is not None:
            image = self.transform(image)
        if self.seg_transform is not None:
            semantic = self.seg_transform(semantic)
        if self.sn_transform is not None:
            normal = self.sn_transform(normal)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        return image, {
            "segmentation": semantic,
            "depth": depth,
            "normal": normal,
            "noise": noise,
        }

    def __len__(self):
        return self.data_len

__getitem__(index)

Retrieve an item from the dataset.

Parameters:

  • index (int) –

    The index of the item to retrieve.

Returns:

  • tuple

    A tuple containing the image and a dictionary of task-specific outputs.

Source code in fusion_bench/dataset/nyuv2.py
def __getitem__(self, index):
    """
    Retrieve an item from the dataset.

    Args:
        index (int): The index of the item to retrieve.

    Returns:
        tuple: A tuple containing the image and a dictionary of task-specific outputs.
    """
    # load data from the pre-processed npy files
    image = torch.from_numpy(
        np.moveaxis(
            np.load(self.data_path + "/image/{:d}.npy".format(index)), -1, 0
        )
    ).float()
    semantic = torch.from_numpy(
        np.load(self.data_path + "/label/{:d}.npy".format(index))
    ).float()
    depth = torch.from_numpy(
        np.moveaxis(
            np.load(self.data_path + "/depth/{:d}.npy".format(index)), -1, 0
        )
    ).float()
    normal = torch.from_numpy(
        np.moveaxis(
            np.load(self.data_path + "/normal/{:d}.npy".format(index)), -1, 0
        )
    ).float()
    noise = self.noise[index].float()

    if self.transform is not None:
        image = self.transform(image)
    if self.seg_transform is not None:
        semantic = self.seg_transform(semantic)
    if self.sn_transform is not None:
        normal = self.sn_transform(normal)
    if self.depth_transform is not None:
        depth = self.depth_transform(depth)

    return image, {
        "segmentation": semantic,
        "depth": depth,
        "normal": normal,
        "noise": noise,
    }

__init__(root, train=True, transform=None, seg_transform=None, sn_transform=None, depth_transform=None)

Initialize the NYUv2 dataset.

Parameters:

  • root (str) –

    The root directory of the dataset.

  • train (bool, default: True ) –

    If True, use training set. If False, use validation set. Defaults to True.

  • transform (Callable, default: None ) –

    image transform. Defaults to None.

  • seg_transform (Callable, default: None ) –

    segmentation transform. Defaults to None.

  • sn_transform (Callable, default: None ) –

    surface normal transform. Defaults to None.

  • depth_transform (Callable, default: None ) –

    depth transform. Defaults to None.

Source code in fusion_bench/dataset/nyuv2.py
def __init__(
    self,
    root: str,
    train: bool = True,
    transform: Optional[Callable] = None,
    seg_transform: Optional[Callable] = None,
    sn_transform: Optional[Callable] = None,
    depth_transform: Optional[Callable] = None,
):
    """
    Initialize the NYUv2 dataset.

    Args:
        root (str): The root directory of the dataset.
        train (bool, optional): If True, use training set. If False, use validation set. Defaults to True.
        transform (Callable, optional): image transform. Defaults to None.
        seg_transform (Callable, optional): segmentation transform. Defaults to None.
        sn_transform (Callable, optional): surface normal transform. Defaults to None.
        depth_transform (Callable, optional): depth transform. Defaults to None.
    """
    self.root = os.path.expanduser(root)
    self.train = train

    self.transform = transform
    self.seg_transform = seg_transform
    self.sn_transform = sn_transform
    self.depth_transform = depth_transform

    if train:
        self.data_path = self.root + "/train"
    else:
        self.data_path = self.root + "/val"

    # calculate data length
    self.data_len = len(
        fnmatch.filter(os.listdir(self.data_path + "/image"), "*.npy")
    )
    self.noise = torch.rand(self.data_len, 1, 288, 384)

Image Classification Tasks

fusion_bench.dataset.clip_dataset.CLIPDataset

Bases: Dataset

A dataset class for CLIP models that converts a dataset of dictionaries or tuples into a format suitable for CLIP processing.

This class wraps an existing dataset and applies CLIP preprocessing to the images. It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).

Parameters:

  • dataset

    The original dataset to wrap.

  • processor (CLIPProcessor, default: None ) –

    The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.

Attributes:

  • dataset

    The wrapped dataset.

  • processor (CLIPProcessor) –

    The CLIP processor used for image preprocessing.
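
Example:

A minimal usage sketch; CIFAR-10 and the openai/clip-vit-base-patch32 checkpoint are placeholder choices, and any dataset yielding (image, label) tuples or {"image": ..., "label": ...} dictionaries can be substituted:

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from transformers import CLIPProcessor

from fusion_bench.dataset.clip_dataset import CLIPDataset

raw_dataset = CIFAR10(root="data", train=True, download=True)  # yields (PIL image, int label)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

dataset = CLIPDataset(raw_dataset, processor=processor)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

pixel_values, label = dataset[0]  # preprocessed image tensor and integer label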

Source code in fusion_bench/dataset/clip_dataset.py
class CLIPDataset(torch.utils.data.Dataset):
    """
    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
    into a format suitable for CLIP processing.

    This class wraps an existing dataset and applies CLIP preprocessing to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        processor (CLIPProcessor): The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.

    Attributes:
        dataset: The wrapped dataset.
        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
    """

    def __init__(self, dataset, processor: Optional[CLIPProcessor] = None):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image tensor and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")
        image = item["image"]
        if self.processor is not None:
            if isinstance(self.processor, ProcessorMixin):
                # Apply the processor to the image to get the input tensor
                inputs = self.processor(images=[image], return_tensors="pt")[
                    "pixel_values"
                ][0]
            elif callable(self.processor):
                inputs = self.processor(image)
            else:
                raise ValueError(
                    "The processor should be a CLIPProcessor or a callable function"
                )
        else:
            # if processor is None, return the raw image directly
            inputs = image
        # convert boolean label to int, this is for the case when the label is a binary classification task
        if isinstance(item["label"], bool):
            item["label"] = 1 if item["label"] else 0
        return inputs, item["label"]

__getitem__(idx)

Retrieves and processes an item from the dataset.

Parameters:

  • idx (int) –

    The index of the item to retrieve.

Returns:

  • tuple ( Tuple[Tensor, int] ) –

    A tuple containing the processed image tensor and the label.

Raises:

  • ValueError

    If the item is neither a dictionary nor a tuple/list of length 2.

Source code in fusion_bench/dataset/clip_dataset.py
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
    """
    Retrieves and processes an item from the dataset.

    Args:
        idx (int): The index of the item to retrieve.

    Returns:
        tuple: A tuple containing the processed image tensor and the label.

    Raises:
        ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
    """
    item = self.dataset[idx]
    if isinstance(item, dict):
        item = item
    elif isinstance(item, (tuple, list)):
        assert len(item) == 2, "Each item should be a tuple or list of length 2"
        item = {"image": item[0], "label": item[1]}
    else:
        raise ValueError("Each item should be a dictionary or a tuple of length 2")
    image = item["image"]
    if self.processor is not None:
        if isinstance(self.processor, ProcessorMixin):
            # Apply the processor to the image to get the input tensor
            inputs = self.processor(images=[image], return_tensors="pt")[
                "pixel_values"
            ][0]
        elif callable(self.processor):
            inputs = self.processor(image)
        else:
            raise ValueError(
                "The processor should be a CLIPProcessor or a callable function"
            )
    else:
        # if processor is None, return the raw image directly
        inputs = image
    # convert boolean label to int, this is for the case when the label is a binary classification task
    if isinstance(item["label"], bool):
        item["label"] = 1 if item["label"] else 0
    return inputs, item["label"]

__len__()

Returns the number of items in the dataset.

Source code in fusion_bench/dataset/clip_dataset.py
def __len__(self):
    """Returns the number of items in the dataset."""
    return len(self.dataset)

fusion_bench.dataset.image_dataset.TransformedImageDataset

Bases: Dataset

A dataset class for image classification tasks that applies a transform to images.

This class wraps an existing dataset and applies a specified transform to the images. It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).

Parameters:

  • dataset

    The original dataset to wrap.

  • transform (Callable) –

    A function/transform to apply on the image.

Attributes:

  • dataset

    The wrapped dataset.

  • transform (Callable) –

    The transform to be applied to the images.
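
Example:

A minimal usage sketch; the CIFAR-10 dataset and the resize/to-tensor pipeline are placeholder choices, and any (image, label) dataset together with any callable image transform can be substituted:

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

from fusion_bench.dataset.image_dataset import TransformedImageDataset

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
raw_dataset = CIFAR10(root="data", train=False, download=True)

dataset = TransformedImageDataset(raw_dataset, transform=transform)
loader = DataLoader(dataset, batch_size=64)

image_tensor, label = dataset[0]  # transformed image tensor and integer label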

Source code in fusion_bench/dataset/image_dataset.py
class TransformedImageDataset(Dataset):
    """
    A dataset class for image classification tasks that applies a transform to images.

    This class wraps an existing dataset and applies a specified transform to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        transform (Callable): A function/transform to apply on the image.

    Attributes:
        dataset: The wrapped dataset.
        transform (Callable): The transform to be applied to the images.
    """

    def __init__(self, dataset, transform: Callable):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")
        image = item["image"]
        inputs = self.transform(image)
        return inputs, item["label"]

__getitem__(idx)

Retrieves and processes an item from the dataset.

Parameters:

  • idx (int) –

    The index of the item to retrieve.

Returns:

  • tuple ( Tuple[Any, Any] ) –

    A tuple containing the processed image and the label.

Raises:

  • ValueError

    If the item is neither a dictionary nor a tuple/list of length 2.

Source code in fusion_bench/dataset/image_dataset.py
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
    """
    Retrieves and processes an item from the dataset.

    Args:
        idx (int): The index of the item to retrieve.

    Returns:
        tuple: A tuple containing the processed image and the label.

    Raises:
        ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
    """
    item = self.dataset[idx]
    if isinstance(item, dict):
        item = item
    elif isinstance(item, (tuple, list)):
        assert len(item) == 2, "Each item should be a tuple or list of length 2"
        item = {"image": item[0], "label": item[1]}
    else:
        raise ValueError("Each item should be a dictionary or a tuple of length 2")
    image = item["image"]
    inputs = self.transform(image)
    return inputs, item["label"]

__len__()

Returns the number of items in the dataset.

Source code in fusion_bench/dataset/image_dataset.py
def __len__(self):
    """Returns the number of items in the dataset."""
    return len(self.dataset)

GPT-2 on the GLUE Benchmark

fusion_bench.dataset.gpt2_glue.TokenizedGLUE

A class to load and cache GLUE datasets for GPT-2 models.

This class provides methods to load various GLUE datasets and tokenize them using a provided tokenizer. The datasets are cached to disk to avoid reloading and tokenizing them multiple times.

Attributes:

  • tokenizer (PreTrainedTokenizer) –

    The tokenizer to use for tokenizing the datasets.
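
Example:

A minimal usage sketch; the GPT-2 tokenizer is a placeholder (any PreTrainedTokenizer can be passed), and assigning a pad token follows common GPT-2 practice rather than a requirement stated here:

from transformers import GPT2Tokenizer

from fusion_bench.dataset.gpt2_glue import TokenizedGLUE

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 defines no pad token by default

glue = TokenizedGLUE(tokenizer)
mrpc = glue.load_dataset("mrpc")  # tokenized GLUE/MRPC splits (train/validation/test)
print(mrpc)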

Source code in fusion_bench/dataset/gpt2_glue.py
class TokenizedGLUE:
    """
    A class to load and cache GLUE datasets for GPT-2 models.

    This class provides methods to load various GLUE datasets and tokenize them
    using a provided tokenizer. The datasets are cached to disk to avoid
    reloading and tokenizing them multiple times.

    Attributes:
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the datasets.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer):
        """
        Initialize the TokenizedGLUE class with a tokenizer.

        Args:
            tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the datasets.
        """
        super().__init__()
        self.tokenizer = tokenizer

    def load_dataset(
        self, name: Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]
    ):
        """
        Load and tokenize a GLUE dataset.

        This method loads a specified GLUE dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Args:
            name (Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]): The name of the GLUE dataset to load.

        Returns:
            Dataset: The tokenized GLUE dataset.
        """
        glue_dataset_loaders = {
            "mrpc": self.load_mrpc_dataset,
            "mnli": self.load_mnli_dataset,
            "cola": self.load_cola_dataset,
            "sst2": self.load_sst2_dataset,
            "qnli": self.load_qnli_dataset,
            "qqp": self.load_qqp_dataset,
            "rte": self.load_rte_dataset,
            # "wnli": load_wnli_dataset,
        }
        return glue_dataset_loaders[name]()

    @cache_dataset
    def load_mrpc_dataset(self):
        """
        Load and tokenize the MRPC dataset.

        This method loads the MRPC dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized MRPC dataset.
        """
        dataset = load_dataset("glue", "mrpc")
        dataset = dataset.map(
            partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2"],
        )
        return dataset

    @cache_dataset
    def load_rte_dataset(self):
        """
        Load and tokenize the RTE dataset.

        This method loads the RTE dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized RTE dataset.
        """
        dataset = load_dataset("glue", "rte")
        dataset = dataset.map(
            # RTE has the same format as MRPC
            partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2"],
        )
        return dataset

    @cache_dataset
    def load_wnli_dataset(self):
        """
        Load and tokenize the WNLI dataset.

        This method loads the WNLI dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized WNLI dataset.
        """
        dataset = load_dataset("glue", "wnli")
        dataset = dataset.map(
            partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["sentence1", "sentence2"],
        )
        return dataset

    @cache_dataset
    def load_qqp_dataset(self):
        """
        Load and tokenize the QQP dataset.

        This method loads the QQP dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized QQP dataset.
        """
        dataset = load_dataset("glue", "qqp")
        dataset = dataset.map(
            partial(qqp_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["question1", "question2"],
        )
        return dataset

    @cache_dataset
    def load_mnli_dataset(self):
        """
        Load and tokenize the MNLI dataset.

        This method loads the MNLI dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized MNLI dataset.
        """
        dataset = load_dataset("glue", "mnli")
        dataset = dataset.map(
            partial(mnli_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["premise", "hypothesis"],
        )
        return dataset

    @cache_dataset
    def load_cola_dataset(self):
        """
        Load and tokenize the CoLA dataset.

        This method loads the CoLA dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized CoLA dataset.
        """
        dataset = load_dataset("glue", "cola")
        dataset = dataset.map(
            partial(cola_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["sentence"],
        )
        return dataset

    @cache_dataset
    def load_sst2_dataset(self):
        """
        Load and tokenize the SST-2 dataset.

        This method loads the SST-2 dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized SST-2 dataset.
        """
        dataset = load_dataset("glue", "sst2")
        dataset = dataset.map(
            partial(cola_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["sentence"],
        )
        return dataset

    @cache_dataset
    def load_qnli_dataset(self):
        """
        Load and tokenize the QNLI dataset.

        This method loads the QNLI dataset, tokenizes it using the provided
        tokenizer, and caches the tokenized dataset to disk.

        Returns:
            Dataset: The tokenized QNLI dataset.
        """
        dataset = load_dataset("glue", "qnli")
        dataset = dataset.map(
            partial(qnli_tokenize_function, tokenizer=self.tokenizer),
            batched=True,
            remove_columns=["question", "sentence"],
        )
        return dataset

__init__(tokenizer)

Initialize the TokenizedGLUE class with a tokenizer.

Parameters:

  • tokenizer (PreTrainedTokenizer) –

    The tokenizer to use for tokenizing the datasets.

Source code in fusion_bench/dataset/gpt2_glue.py
def __init__(self, tokenizer: PreTrainedTokenizer):
    """
    Initialize the TokenizedGLUE class with a tokenizer.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer to use for tokenizing the datasets.
    """
    super().__init__()
    self.tokenizer = tokenizer

load_cola_dataset()

Load and tokenize the CoLA dataset.

This method loads the CoLA dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized CoLA dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_cola_dataset(self):
    """
    Load and tokenize the CoLA dataset.

    This method loads the CoLA dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized CoLA dataset.
    """
    dataset = load_dataset("glue", "cola")
    dataset = dataset.map(
        partial(cola_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["sentence"],
    )
    return dataset

load_dataset(name)

Load and tokenize a GLUE dataset.

This method loads a specified GLUE dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Parameters:

  • name (Literal['mrpc', 'mnli', 'cola', 'sst2', 'qnli', 'qqp', 'rte']) –

    The name of the GLUE dataset to load.

Returns:

  • Dataset

    The tokenized GLUE dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
def load_dataset(
    self, name: Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]
):
    """
    Load and tokenize a GLUE dataset.

    This method loads a specified GLUE dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Args:
        name (Literal["mrpc", "mnli", "cola", "sst2", "qnli", "qqp", "rte"]): The name of the GLUE dataset to load.

    Returns:
        Dataset: The tokenized GLUE dataset.
    """
    glue_dataset_loaders = {
        "mrpc": self.load_mrpc_dataset,
        "mnli": self.load_mnli_dataset,
        "cola": self.load_cola_dataset,
        "sst2": self.load_sst2_dataset,
        "qnli": self.load_qnli_dataset,
        "qqp": self.load_qqp_dataset,
        "rte": self.load_rte_dataset,
        # "wnli": load_wnli_dataset,
    }
    return glue_dataset_loaders[name]()

load_mnli_dataset()

Load and tokenize the MNLI dataset.

This method loads the MNLI dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized MNLI dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_mnli_dataset(self):
    """
    Load and tokenize the MNLI dataset.

    This method loads the MNLI dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized MNLI dataset.
    """
    dataset = load_dataset("glue", "mnli")
    dataset = dataset.map(
        partial(mnli_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["premise", "hypothesis"],
    )
    return dataset

load_mrpc_dataset()

Load and tokenize the MRPC dataset.

This method loads the MRPC dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized MRPC dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_mrpc_dataset(self):
    """
    Load and tokenize the MRPC dataset.

    This method loads the MRPC dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized MRPC dataset.
    """
    dataset = load_dataset("glue", "mrpc")
    dataset = dataset.map(
        partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["sentence1", "sentence2"],
    )
    return dataset

load_qnli_dataset()

Load and tokenize the QNLI dataset.

This method loads the QNLI dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized QNLI dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_qnli_dataset(self):
    """
    Load and tokenize the QNLI dataset.

    This method loads the QNLI dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized QNLI dataset.
    """
    dataset = load_dataset("glue", "qnli")
    dataset = dataset.map(
        partial(qnli_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["question", "sentence"],
    )
    return dataset

load_qqp_dataset()

Load and tokenize the QQP dataset.

This method loads the QQP dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized QQP dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_qqp_dataset(self):
    """
    Load and tokenize the QQP dataset.

    This method loads the QQP dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized QQP dataset.
    """
    dataset = load_dataset("glue", "qqp")
    dataset = dataset.map(
        partial(qqp_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["question1", "question2"],
    )
    return dataset

load_rte_dataset()

Load and tokenize the RTE dataset.

This method loads the RTE dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized RTE dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_rte_dataset(self):
    """
    Load and tokenize the RTE dataset.

    This method loads the RTE dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized RTE dataset.
    """
    dataset = load_dataset("glue", "rte")
    dataset = dataset.map(
        # RTE has the same format as MRPC
        partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["sentence1", "sentence2"],
    )
    return dataset

load_sst2_dataset()

Load and tokenize the SST-2 dataset.

This method loads the SST-2 dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized SST-2 dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_sst2_dataset(self):
    """
    Load and tokenize the SST-2 dataset.

    This method loads the SST-2 dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized SST-2 dataset.
    """
    dataset = load_dataset("glue", "sst2")
    dataset = dataset.map(
        partial(cola_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["sentence"],
    )
    return dataset

load_wnli_dataset()

Load and tokenize the WNLI dataset.

This method loads the WNLI dataset, tokenizes it using the provided tokenizer, and caches the tokenized dataset to disk.

Returns:

  • Dataset

    The tokenized WNLI dataset.

Source code in fusion_bench/dataset/gpt2_glue.py
@cache_dataset
def load_wnli_dataset(self):
    """
    Load and tokenize the WNLI dataset.

    This method loads the WNLI dataset, tokenizes it using the provided
    tokenizer, and caches the tokenized dataset to disk.

    Returns:
        Dataset: The tokenized WNLI dataset.
    """
    dataset = load_dataset("glue", "wnli")
    dataset = dataset.map(
        partial(mrpc_tokenize_function, tokenizer=self.tokenizer),
        batched=True,
        remove_columns=["sentence1", "sentence2"],
    )
    return dataset