Skip to content

Flan-T5 Models for Text Generation Tasks

This task pool provides text-generation versions of tasks from the GLUE benchmark for evaluating Flan-T5 models. Each task is associated with a dataset. We report exact-match accuracy for CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST-2, and Spearman's rho for STS-B.

References

flan_t5_glue_text_generation

FlanT5GLUETextGenerationTaskPool

Bases: TaskPool

A task pool for FlanT5 GLUE text generation tasks. This class manages the tasks and provides methods for loading and evaluating tasks.

Source code in fusion_bench/taskpool/flan_t5_glue_text_generation.py
class FlanT5GLUETextGenerationTaskPool(TaskPool):
    """
    Task pool for the Flan-T5 GLUE text-generation tasks.

    Builds a classification or regression task object for each task name or
    config. Lazily creates and caches a tokenizer and a single-device
    Lightning Fabric instance. Evaluates a model by reporting its parameter
    counts alongside the results produced by the base ``TaskPool.evaluate``.
    """

    # Lazily created caches; instance assignment in the properties below
    # shadows these class-level ``None`` defaults.
    _fabric: L.Fabric = None
    _tokenizer = None

    @property
    def tokenizer(self):
        """
        Return the tokenizer, loading it on first access.

        The tokenizer is built with ``AutoTokenizer.from_pretrained`` from
        ``self.config.tokenizer`` (presumably a model name or local path;
        ``config`` is provided outside this class). It is then cached on the
        instance.
        """
        if self._tokenizer is not None:
            return self._tokenizer
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer)
        return self._tokenizer

    @property
    def fabric(self):
        """
        Return the Lightning Fabric instance, creating it on first access.

        On first access, a ``L.Fabric(devices=1)`` is created, cached on the
        instance, and then launched. Later accesses return the cached object.
        """
        if self._fabric is None:
            new_fabric = L.Fabric(devices=1)
            # Cache before launching, so the instance is stored even if
            # ``launch`` raises.
            self._fabric = new_fabric
            new_fabric.launch()
        return self._fabric

    def load_task(self, task_name_or_config: str | DictConfig):
        """
        Build the task object for a task name or task config.

        Args:
            task_name_or_config: Either a task name, which is resolved through
                ``self.get_task_config``, or an already-resolved task config.

        Returns:
            A ``FlanT5GLUETextGenerationClassificationTask`` when the config's
            ``name`` is in ``CLASSIFICATION_TASKS``. Otherwise, a
            ``FlanT5GLUETextGenerationRegressionTask`` when it is in
            ``REGRESSION_TASKS``. The returned task's ``_taskpool`` attribute
            is set to this pool.

        Raises:
            ValueError: If the task name is in neither collection.
        """
        task_config = (
            self.get_task_config(task_name_or_config)
            if isinstance(task_name_or_config, str)
            else task_name_or_config
        )

        # Classification membership is checked first, matching the original
        # precedence if a name were ever listed in both collections.
        if task_config.name in CLASSIFICATION_TASKS:
            task_cls = FlanT5GLUETextGenerationClassificationTask
        elif task_config.name in REGRESSION_TASKS:
            task_cls = FlanT5GLUETextGenerationRegressionTask
        else:
            raise ValueError(f"Unknown task {task_config.name}")

        task = task_cls(task_config)
        # Back-reference so the task can reach pool-level resources.
        task._taskpool = self
        return task

    def evaluate(self, model: T5ForConditionalGeneration):
        """
        Evaluate ``model`` and return a report dict.

        If ``model`` is not a ``T5ForConditionalGeneration``, a warning is
        logged and evaluation still proceeds.

        The report contains a ``"model_info"`` entry with parameter counts
        from ``count_parameters``. Its ``"trainable_percentage"`` value is the
        ratio trainable/all and is not scaled by 100. The report is then
        updated with the result of the base ``TaskPool.evaluate`` call on the
        Fabric-wrapped model. The full report is logged and returned.

        Raises:
            ZeroDivisionError: If ``count_parameters`` reports zero total
                parameters.
        """
        if not isinstance(model, T5ForConditionalGeneration):
            log.warning(
                f"Model is not an instance of T5ForConditionalGeneration, but {type(model)}"
            )

        # NOTE(review): assumes count_parameters returns
        # (trainable_count, total_count) — confirm against its definition.
        trainable_count, total_count = count_parameters(model)
        report = {
            "model_info": {
                "trainable_params": trainable_count,
                "all_params": total_count,
                "trainable_percentage": trainable_count / total_count,
            }
        }

        fabric_model = self.fabric.setup(model)
        report.update(super().evaluate(fabric_model))
        log.info(f"evaluation report: {report}")
        return report
tokenizer property

Returns the tokenizer. If it's not already initialized, it initializes it using the config's tokenizer.

load_task(task_name_or_config)

Loads a task given a task name or config. If the task name is in CLASSIFICATION_TASKS, it creates a FlanT5GLUETextGenerationClassificationTask. If the task name is in REGRESSION_TASKS, it creates a FlanT5GLUETextGenerationRegressionTask. Otherwise, it raises a ValueError.

Source code in fusion_bench/taskpool/flan_t5_glue_text_generation.py
def load_task(self, task_name_or_config: str | DictConfig):
    """
    Build the task object for a task name or task config.

    Args:
        task_name_or_config: Either a task name, which is resolved through
            ``self.get_task_config``, or an already-resolved task config.

    Returns:
        A ``FlanT5GLUETextGenerationClassificationTask`` when the config's
        ``name`` is in ``CLASSIFICATION_TASKS``. Otherwise, a
        ``FlanT5GLUETextGenerationRegressionTask`` when it is in
        ``REGRESSION_TASKS``. The returned task's ``_taskpool`` attribute is
        set to ``self``.

    Raises:
        ValueError: If the task name is in neither collection.
    """
    task_config = (
        self.get_task_config(task_name_or_config)
        if isinstance(task_name_or_config, str)
        else task_name_or_config
    )

    # Classification membership is checked first, matching the original
    # precedence if a name were ever listed in both collections.
    if task_config.name in CLASSIFICATION_TASKS:
        task_cls = FlanT5GLUETextGenerationClassificationTask
    elif task_config.name in REGRESSION_TASKS:
        task_cls = FlanT5GLUETextGenerationRegressionTask
    else:
        raise ValueError(f"Unknown task {task_config.name}")

    task = task_cls(task_config)
    # Back-reference so the task can reach pool-level resources.
    task._taskpool = self
    return task