Flan-T5 Models for Text Generation Tasks
This task pool provides a set of text generation tasks from the GLUE benchmark for the Flan-T5 model.
Each task is associated with a dataset.
We report exact-match accuracy for CoLA, MNLI, MRPC, QNLI, QQP, RTE, and SST2, and Spearman's rho for STSB.
References
flan_t5_glue_text_generation
FlanT5GLUETextGenerationTaskPool
Bases: TaskPool
A task pool for FlanT5 GLUE text generation tasks.
This class manages the tasks and provides methods for loading and evaluating tasks.
Source code in fusion_bench/taskpool/flan_t5_glue_text_generation.py
class FlanT5GLUETextGenerationTaskPool(TaskPool):
    """
    A task pool for FlanT5 GLUE text generation tasks.

    This class manages the tasks and provides methods for loading and evaluating tasks.
    The tokenizer and the Fabric instance are created on first access and then
    cached on the instance.
    """

    # Lazily populated caches; `None` means "not created yet".
    _fabric: L.Fabric | None = None
    _tokenizer = None

    @property
    def tokenizer(self):
        """
        Returns the tokenizer. If it's not already initialized, it initializes it using the config's tokenizer.
        """
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer)
        return self._tokenizer

    @property
    def fabric(self):
        """
        Returns the Fabric instance. If it's not already initialized, a single-device
        Fabric is created and launched, then cached for later accesses.
        """
        if self._fabric is None:
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()
        return self._fabric

    def load_task(self, task_name_or_config: str | DictConfig):
        """
        Loads a task given a task name or config. If the task name is in `CLASSIFICATION_TASKS`, it creates a `FlanT5GLUETextGenerationClassificationTask`.
        If the task name is in `REGRESSION_TASKS`, it creates a `FlanT5GLUETextGenerationRegressionTask`. Otherwise, it raises a `ValueError`.

        The created task gets a back-reference to this pool in `task._taskpool`.
        """
        # A plain string is a task name that must be resolved to its config first.
        if isinstance(task_name_or_config, str):
            task_config = self.get_task_config(task_name_or_config)
        else:
            task_config = task_name_or_config

        if task_config.name in CLASSIFICATION_TASKS:
            task_cls = FlanT5GLUETextGenerationClassificationTask
        elif task_config.name in REGRESSION_TASKS:
            task_cls = FlanT5GLUETextGenerationRegressionTask
        else:
            raise ValueError(f"Unknown task {task_config.name}")

        task = task_cls(task_config)
        task._taskpool = self
        return task

    def evaluate(self, model: T5ForConditionalGeneration):
        """
        Evaluates the model and returns the evaluation report.

        The report holds a "model_info" entry (trainable parameter count, total
        parameter count, and their ratio) and is then updated with the result of
        the parent class's `evaluate` called on the Fabric-wrapped model.
        A model of an unexpected type only triggers a warning; evaluation proceeds.
        """
        if not isinstance(model, T5ForConditionalGeneration):
            log.warning(
                f"Model is not an instance of T5ForConditionalGeneration, but {type(model)}"
            )

        training_params, all_params = count_parameters(model)
        # Guard the ratio so a parameter-free model reports 0.0 instead of
        # raising ZeroDivisionError before any evaluation has run.
        trainable_ratio = training_params / all_params if all_params else 0.0

        report = {
            "model_info": {
                "trainable_params": training_params,
                "all_params": all_params,
                "trainable_percentage": trainable_ratio,
            }
        }

        model = self.fabric.setup(model)
        report.update(super().evaluate(model))
        log.info(f"evaluation report: {report}")
        return report
|
tokenizer
property
Returns the tokenizer. If it's not already initialized, it initializes it using the config's tokenizer.
load_task(task_name_or_config)
Loads a task given a task name or config. If the task name is in `CLASSIFICATION_TASKS`, it creates a `FlanT5GLUETextGenerationClassificationTask`. If the task name is in `REGRESSION_TASKS`, it creates a `FlanT5GLUETextGenerationRegressionTask`. Otherwise, it raises a `ValueError`.
Source code in fusion_bench/taskpool/flan_t5_glue_text_generation.py
def load_task(self, task_name_or_config: str | DictConfig):
    """
    Build the task object for a task name or an already-resolved task config.

    A name found in `CLASSIFICATION_TASKS` yields a `FlanT5GLUETextGenerationClassificationTask`;
    a name found in `REGRESSION_TASKS` yields a `FlanT5GLUETextGenerationRegressionTask`.
    Any other name raises a `ValueError`. The returned task has `_taskpool` set to this pool.
    """
    # A plain string is a task name that must be resolved to its config first.
    if isinstance(task_name_or_config, str):
        resolved_config = self.get_task_config(task_name_or_config)
    else:
        resolved_config = task_name_or_config

    if resolved_config.name in CLASSIFICATION_TASKS:
        task_cls = FlanT5GLUETextGenerationClassificationTask
    elif resolved_config.name in REGRESSION_TASKS:
        task_cls = FlanT5GLUETextGenerationRegressionTask
    else:
        raise ValueError(f"Unknown task {resolved_config.name}")

    loaded_task = task_cls(resolved_config)
    loaded_task._taskpool = self
    return loaded_task
|