Skip to content

GPT-2 Sequence Classification Tasks

This task pool provides a set of sequence classification tasks from the GLUE benchmark for the GPT-2 model. Each task is paired with its own dataset and is evaluated with the accuracy metric. The tasks are: CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST2.

References

gpt2_text_classification

GPT2TextClassificationTaskPool

Bases: TaskPool

A task pool for GPT2 text classification tasks. This class manages the tasks and provides methods for loading the test datasets and running evaluation.

Source code in fusion_bench/taskpool/gpt2_text_classification.py
class GPT2TextClassificationTaskPool(TaskPool):
    """
    Task pool for GPT2 text classification tasks.

    Holds the shared Fabric instance and tokenizer, fills task/dataset
    configurations with pool-level defaults, and builds
    ``GPT2ClassificationTask`` objects from those configurations.
    """

    # Backing fields for the `fabric` / `tokenizer` properties below;
    # ``None`` means "not set yet".
    _fabric: L.Fabric = None
    _tokenizer: GPT2Tokenizer = None
    # Associated model pool (not read by any method in this class body).
    _modelpool: "fusion_bench.modelpool.HuggingFaceGPT2ClassificationPool" = None

    @property
    def fabric(self):
        """
        Return the Fabric instance used by this pool.

        If none has been assigned, a single-device Fabric is created,
        launched, cached on the pool, and returned.
        """
        if self._fabric is None:
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()
        return self._fabric

    @property
    def tokenizer(self):
        """
        Return the tokenizer assigned to this pool.

        Raises:
            ValueError: If no tokenizer has been set.
        """
        if self._tokenizer is None:
            raise ValueError("Tokenizer not set")
        return self._tokenizer

    def prepare_dataset_config(self, dataset_config: DictConfig):
        """
        Fill in the default ``type`` of a dataset configuration.

        Args:
            dataset_config: The dataset configuration; modified in place.

        Returns:
            The same ``dataset_config`` object, with ``type`` set to
            ``self.config.dataset_type`` if it was not already present.
        """
        if hasattr(dataset_config, "type"):
            return dataset_config

        # open_dict is used so the key can be added to the config.
        with open_dict(dataset_config):
            dataset_config["type"] = self.config.dataset_type
        return dataset_config

    def prepare_task_config(self, task_config: DictConfig):
        """
        Fill in pool-level defaults for a task configuration.

        Args:
            task_config: The task configuration; modified in place.

        Returns:
            The same ``task_config`` object, where each of ``num_workers``,
            ``batch_size`` and ``fast_dev_run`` that was absent has been
            copied from ``self.config``.
        """
        for option in ("num_workers", "batch_size", "fast_dev_run"):
            if hasattr(task_config, option):
                continue
            # open_dict is used so the key can be added to the config.
            with open_dict(task_config):
                task_config[option] = self.config[option]
        return task_config

    def load_task(self, task_name_or_config: str | DictConfig):
        """
        Build a task from a task name or a task configuration.

        Args:
            task_name_or_config: Either a task name, which is resolved with
                ``self.get_task_config``, or a task configuration used as is.

        Returns:
            A ``GPT2ClassificationTask`` built from the prepared
            configuration and linked back to this pool.
        """
        resolved_config = (
            self.get_task_config(task_name_or_config)
            if isinstance(task_name_or_config, str)
            else task_name_or_config
        )
        resolved_config = self.prepare_task_config(resolved_config)

        # Accessing `self.fabric` here creates the Fabric on first use.
        task = GPT2ClassificationTask(resolved_config, self.fabric, self.tokenizer)

        # The private attributes are also assigned directly, in addition to
        # the constructor arguments above.
        task._fabric = self._fabric
        task._tokenizer = self._tokenizer
        task._taskpool = self
        return task
load_task(task_name_or_config)

Loads a task from either a task name or a task config. The task configuration is first filled in with default values, and the task is then constructed from it.

Source code in fusion_bench/taskpool/gpt2_text_classification.py
def load_task(self, task_name_or_config: str | DictConfig):
    """
    Build a task from a task name or a task configuration.

    Args:
        task_name_or_config: Either a task name, which is resolved with
            ``self.get_task_config``, or a task configuration used as is.

    Returns:
        A ``GPT2ClassificationTask`` built from the prepared configuration
        and linked back to this pool.
    """
    resolved_config = (
        self.get_task_config(task_name_or_config)
        if isinstance(task_name_or_config, str)
        else task_name_or_config
    )
    resolved_config = self.prepare_task_config(resolved_config)

    # Construct the task from the prepared configuration.
    task = GPT2ClassificationTask(resolved_config, self.fabric, self.tokenizer)

    # The private attributes are also assigned directly, in addition to the
    # constructor arguments above.
    task._fabric = self._fabric
    task._tokenizer = self._tokenizer
    task._taskpool = self
    return task
prepare_dataset_config(dataset_config)

Set default values for dataset configuration.

Source code in fusion_bench/taskpool/gpt2_text_classification.py
def prepare_dataset_config(self, dataset_config: DictConfig):
    """
    Fill in the default ``type`` of a dataset configuration.

    Args:
        dataset_config: The dataset configuration; modified in place.

    Returns:
        The same ``dataset_config`` object, with ``type`` set to
        ``self.config.dataset_type`` if it was not already present.
    """
    if hasattr(dataset_config, "type"):
        return dataset_config

    # open_dict is used so the key can be added to the config.
    with open_dict(dataset_config):
        dataset_config["type"] = self.config.dataset_type
    return dataset_config
prepare_task_config(task_config)

Set default values for task configuration.

Source code in fusion_bench/taskpool/gpt2_text_classification.py
def prepare_task_config(self, task_config: DictConfig):
    """
    Fill in pool-level defaults for a task configuration.

    Args:
        task_config: The task configuration; modified in place.

    Returns:
        The same ``task_config`` object, where each of ``num_workers``,
        ``batch_size`` and ``fast_dev_run`` that was absent has been copied
        from ``self.config``.
    """
    for option in ("num_workers", "batch_size", "fast_dev_run"):
        if hasattr(task_config, option):
            continue
        # open_dict is used so the key can be added to the config.
        with open_dict(task_config):
            task_config[option] = self.config[option]
    return task_config