This task pool provides a set of text generation tasks from the GLUE benchmark for the Flan-T5 model.
Each task is associated with a dataset.
We report exact-match accuracy for CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST2, and Spearman's rho for STSB.
class FlanT5GLUETextGenerationTaskPool(LightningFabricMixin, TaskPool):
    """
    Task pool for the FlanT5 GLUE text generation tasks.

    The pool builds task objects on demand (`load_task`), exposes a lazily
    created tokenizer (`tokenizer`), and produces an evaluation report for a
    model (`evaluate`).
    """

    # Cache slot for the tokenizer; filled on first access of `tokenizer`.
    _tokenizer_instance = None

    @property
    def tokenizer(self):
        """
        Return the tokenizer, creating it on first access.

        The tokenizer is built with `AutoTokenizer.from_pretrained` from
        `self.config.tokenizer` and then reused on later accesses.
        """
        cached = self._tokenizer_instance
        if cached is None:
            cached = AutoTokenizer.from_pretrained(self.config.tokenizer)
            self._tokenizer_instance = cached
        return cached

    def load_task(self, task_name_or_config: str | DictConfig):
        """
        Build the task object for a task name or a task config.

        Args:
            task_name_or_config (str | DictConfig): Either a task name, which
                is resolved through `self.get_task_config`, or a task config
                that is used as given.

        Returns:
            A `FlanT5GLUETextGenerationClassificationTask` when the config's
            name is in `CLASSIFICATION_TASKS`, or a
            `FlanT5GLUETextGenerationRegressionTask` when it is in
            `REGRESSION_TASKS`. The task's `_taskpool` attribute is set to
            this pool.

        Raises:
            ValueError: If the config's name is in neither collection.
        """
        task_config = (
            self.get_task_config(task_name_or_config)
            if isinstance(task_name_or_config, str)
            else task_name_or_config
        )

        # Pick the task class first so construction and the back-reference
        # assignment are written once for both kinds of task.
        if task_config.name in CLASSIFICATION_TASKS:
            task_cls = FlanT5GLUETextGenerationClassificationTask
        elif task_config.name in REGRESSION_TASKS:
            task_cls = FlanT5GLUETextGenerationRegressionTask
        else:
            raise ValueError(f"Unknown task {task_config.name}")

        task = task_cls(task_config)
        task._taskpool = self
        return task

    def evaluate(self, model: T5ForConditionalGeneration, name: str = None):
        """
        Evaluate the model on the FlanT5 GLUE text generation tasks.

        Args:
            model (T5ForConditionalGeneration): The model to evaluate. A
                warning is logged if it is of a different type; evaluation
                still proceeds.
            name (str, optional): Name of the model, stored in the report
                under `model_info` when given. Defaults to None.

        Returns:
            dict: A report holding a `model_info` entry (parameter counts and
            the trainable fraction) merged with the results returned by the
            parent class's `evaluate`.
        """
        if not isinstance(model, T5ForConditionalGeneration):
            log.warning(
                f"Model is not an instance of T5ForConditionalGeneration, but {type(model)}"
            )

        trainable_count, total_count = count_parameters(model)
        model_info = {
            "trainable_params": trainable_count,
            "all_params": total_count,
            "trainable_percentage": trainable_count / total_count,
        }
        if name is not None:
            model_info["name"] = name
        report = {"model_info": model_info}

        # Pass the model through `self.fabric.setup` before delegating the
        # per-task evaluation to the parent class.
        prepared_model = self.fabric.setup(model)
        report.update(super().evaluate(prepared_model))

        log.info(f"evaluation report: {report}")
        return report
def evaluate(self, model: T5ForConditionalGeneration, name: str = None):
    """
    Evaluate the model on the FlanT5 GLUE text generation tasks.

    Args:
        model (T5ForConditionalGeneration): The model to evaluate. A warning
            is logged if it is of a different type; evaluation still
            proceeds.
        name (str, optional): Name of the model, stored in the report under
            `model_info` when given. Defaults to None.

    Returns:
        dict: A report holding a `model_info` entry (parameter counts and the
        trainable fraction) merged with the results returned by the parent
        class's `evaluate`.
    """
    if not isinstance(model, T5ForConditionalGeneration):
        log.warning(
            f"Model is not an instance of T5ForConditionalGeneration, but {type(model)}"
        )

    trainable_count, total_count = count_parameters(model)
    model_info = {
        "trainable_params": trainable_count,
        "all_params": total_count,
        "trainable_percentage": trainable_count / total_count,
    }
    if name is not None:
        model_info["name"] = name
    report = {"model_info": model_info}

    # Pass the model through `self.fabric.setup` before delegating the
    # per-task evaluation to the parent class.
    prepared_model = self.fabric.setup(model)
    report.update(super().evaluate(prepared_model))

    log.info(f"evaluation report: {report}")
    return report
Loads a task given a task name or config. If the task name is in `CLASSIFICATION_TASKS`, it creates a `FlanT5GLUETextGenerationClassificationTask`.
If the task name is in `REGRESSION_TASKS`, it creates a `FlanT5GLUETextGenerationRegressionTask`. Otherwise, it raises a `ValueError`.
Source code in fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
def load_task(self, task_name_or_config: str | DictConfig):
    """
    Build the task object for a task name or a task config.

    Args:
        task_name_or_config (str | DictConfig): Either a task name, which is
            resolved through `self.get_task_config`, or a task config that is
            used as given.

    Returns:
        A `FlanT5GLUETextGenerationClassificationTask` when the config's name
        is in `CLASSIFICATION_TASKS`, or a
        `FlanT5GLUETextGenerationRegressionTask` when it is in
        `REGRESSION_TASKS`. The task's `_taskpool` attribute is set to this
        pool.

    Raises:
        ValueError: If the config's name is in neither collection.
    """
    task_config = (
        self.get_task_config(task_name_or_config)
        if isinstance(task_name_or_config, str)
        else task_name_or_config
    )

    # Pick the task class first so construction and the back-reference
    # assignment are written once for both kinds of task.
    if task_config.name in CLASSIFICATION_TASKS:
        task_cls = FlanT5GLUETextGenerationClassificationTask
    elif task_config.name in REGRESSION_TASKS:
        task_cls = FlanT5GLUETextGenerationRegressionTask
    else:
        raise ValueError(f"Unknown task {task_config.name}")

    task = task_cls(task_config)
    task._taskpool = self
    return task