Skip to content

Flan-T5 Models for Text Generation Tasks

This task pool provides a set of text generation tasks from the GLUE benchmark for the Flan-T5 model. Each task is associated with a dataset. We report exact-match accuracy for CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST-2, and Spearman's rank correlation (ρ) for STS-B.

References

flan_t5_glue_text_generation

FlanT5GLUETextGenerationTaskPool

Bases: LightningFabricMixin, TaskPool

A task pool for Flan-T5 GLUE text generation tasks. This class manages the tasks and provides methods for loading and evaluating them.

Source code in fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
class FlanT5GLUETextGenerationTaskPool(LightningFabricMixin, TaskPool):
    """
    A task pool for FlanT5 GLUE text generation tasks.
    This class manages the tasks and provides methods for loading and evaluating tasks.
    """

    # Cache for the lazily-created tokenizer; the class-level None is shadowed
    # by an instance attribute the first time `tokenizer` is accessed.
    _tokenizer_instance = None

    @property
    def tokenizer(self):
        """
        Returns the tokenizer. If it's not already initialized, it initializes it using the config's tokenizer.

        The tokenizer is created with
        ``AutoTokenizer.from_pretrained(self.config.tokenizer)`` on first access
        and the same object is returned on every later access.
        """
        if self._tokenizer_instance is None:
            self._tokenizer_instance = AutoTokenizer.from_pretrained(
                self.config.tokenizer
            )
        return self._tokenizer_instance

    def load_task(self, task_name_or_config: str | DictConfig):
        """
        Loads a task given a task name or config. If the task name is in `CLASSIFICATION_TASKS`, it creates a `FlanT5GLUETextGenerationClassificationTask`.
        If the task name is in `REGRESSION_TASKS`, it creates a `FlanT5GLUETextGenerationRegressionTask`. Otherwise, it raises a `ValueError`.

        Args:
            task_name_or_config: Either a task name (resolved through
                ``self.get_task_config``) or a task config object exposing a
                ``name`` attribute.

        Returns:
            The constructed task, with ``_taskpool`` set to this pool.

        Raises:
            ValueError: If the task name is in neither task registry.
        """
        if isinstance(task_name_or_config, str):
            task_config = self.get_task_config(task_name_or_config)
        else:
            task_config = task_name_or_config

        # Pick the task class first so construction/wiring is written once.
        # Classification is checked before regression, as in the original.
        if task_config.name in CLASSIFICATION_TASKS:
            task_cls = FlanT5GLUETextGenerationClassificationTask
        elif task_config.name in REGRESSION_TASKS:
            task_cls = FlanT5GLUETextGenerationRegressionTask
        else:
            raise ValueError(f"Unknown task {task_config.name}")

        task = task_cls(task_config)
        # Back-reference from the task to this pool; presumably used by the task
        # to reach pool-level resources such as the tokenizer — TODO confirm.
        task._taskpool = self
        return task

    def evaluate(self, model: T5ForConditionalGeneration):
        """
        Evaluate the model on the FlanT5 GLUE text generation tasks.

        Args:
            model (T5ForConditionalGeneration): The model to evaluate.

        Returns:
            dict: A dictionary containing the evaluation results for each task,
                plus a ``"model_info"`` entry with parameter counts.
        """
        if not isinstance(model, T5ForConditionalGeneration):
            # Non-T5 models are not rejected; only a warning is emitted.
            log.warning(
                "Model is not an instance of T5ForConditionalGeneration, but %s",
                type(model),
            )
        report = {}
        training_params, all_params = count_parameters(model)
        report["model_info"] = {
            "trainable_params": training_params,
            "all_params": all_params,
            # A fraction in [0, 1] (not multiplied by 100). Guarded so a model
            # with zero parameters reports 0.0 instead of raising ZeroDivisionError.
            "trainable_percentage": (
                training_params / all_params if all_params else 0.0
            ),
        }
        # Wrap the model with the pool's fabric (presumably for device placement)
        # before delegating per-task evaluation to the base TaskPool.
        model = self.fabric.setup(model)
        report.update(super().evaluate(model))
        log.info("evaluation report: %s", report)
        return report
tokenizer property

Returns the tokenizer, creating it on first access from the tokenizer specified in the config and reusing it afterwards.

evaluate(model)

Evaluate the model on the FlanT5 GLUE text generation tasks.

Parameters:

  • model (T5ForConditionalGeneration) –

    The model to evaluate.

Returns:

  • dict

    A dictionary containing the evaluation results for each task.

Source code in fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
def evaluate(self, model: T5ForConditionalGeneration):
    """
    Evaluate the model on the FlanT5 GLUE text generation tasks.

    Args:
        model (T5ForConditionalGeneration): The model to evaluate.

    Returns:
        dict: A dictionary containing the evaluation results for each task,
            plus a ``"model_info"`` entry with parameter counts.
    """
    if not isinstance(model, T5ForConditionalGeneration):
        # Non-T5 models are not rejected; only a warning is emitted.
        log.warning(
            "Model is not an instance of T5ForConditionalGeneration, but %s",
            type(model),
        )
    report = {}
    training_params, all_params = count_parameters(model)
    report["model_info"] = {
        "trainable_params": training_params,
        "all_params": all_params,
        # A fraction in [0, 1] (not multiplied by 100). Guarded so a model
        # with zero parameters reports 0.0 instead of raising ZeroDivisionError.
        "trainable_percentage": (
            training_params / all_params if all_params else 0.0
        ),
    }
    # Wrap the model with the pool's fabric (presumably for device placement)
    # before delegating per-task evaluation to the base TaskPool.
    model = self.fabric.setup(model)
    report.update(super().evaluate(model))
    log.info("evaluation report: %s", report)
    return report
load_task(task_name_or_config)

Loads a task given a task name or config. If the task name is in CLASSIFICATION_TASKS, it creates a FlanT5GLUETextGenerationClassificationTask. If the task name is in REGRESSION_TASKS, it creates a FlanT5GLUETextGenerationRegressionTask. Otherwise, it raises a ValueError.

Source code in fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
def load_task(self, task_name_or_config: str | DictConfig):
    """
    Build the task object for a task name or task config.

    A string is first resolved to a config with ``self.get_task_config``.
    The config's ``name`` is then looked up in ``CLASSIFICATION_TASKS``
    (yielding a ``FlanT5GLUETextGenerationClassificationTask``) and, failing
    that, in ``REGRESSION_TASKS`` (yielding a
    ``FlanT5GLUETextGenerationRegressionTask``).

    Returns:
        The constructed task, with ``_taskpool`` pointing back at this pool.

    Raises:
        ValueError: If the name is in neither registry.
    """
    task_config = (
        self.get_task_config(task_name_or_config)
        if isinstance(task_name_or_config, str)
        else task_name_or_config
    )
    name = task_config.name

    # Ordered (registry, task class) pairs: classification wins over regression.
    for registry, task_cls in (
        (CLASSIFICATION_TASKS, FlanT5GLUETextGenerationClassificationTask),
        (REGRESSION_TASKS, FlanT5GLUETextGenerationRegressionTask),
    ):
        if name in registry:
            built = task_cls(task_config)
            built._taskpool = self
            return built

    raise ValueError(f"Unknown task {name}")