Flan-T5 Models for Text Generation

Model Information

These are Flan-T5 models fine-tuned on GLUE benchmark tasks with prompt-based fine-tuning. The models are trained in a text-to-text setting, using the prompt templates shown below. The source code for the prompt templates can be found in the repository.

Source code in `fusion_bench/tasks/flan_t5_text_generation/glue_prompt_templates.py`:

```python
cola = {
    "description": "template used by GLUE-CoLA",
    "input_text": "Indicate if the following sentence is grammatically correct or not: \"{sentence}\". Answer 'acceptable' or 'unacceptable'.",
    "target_text": {"0": "unacceptable", "1": "acceptable"},
}
mnli = {
    "input_text": "Does the premise: '{premise}' logically imply, contradict, or is neutral to the hypothesis: '{hypothesis}'? Answer with 'entailment', 'contradiction', or 'neutral'.",
    "target_text": {"0": "entailment", "1": "neutral", "2": "contradiction"},
}
mrpc = {
    "input_text": "Are the following sentences '{sentence1}' and '{sentence2}' conveying the same meaning? Answer with 'yes' or 'no'.",
    "target_text": {"0": "no", "1": "yes"},
}
qnli = {
    "input_text": "Given the context: '{sentence}', does the question '{question}' have an answer based on the information provided? Answer with 'yes' or 'no'.",
    "target_text": {"0": "yes", "1": "no"},
}
qqp = {
    "input_text": "Do the questions '{question1}' and '{question2}' have the same intent? Answer with 'yes' or 'no'.",
    "target_text": {"0": "no", "1": "yes"},
}
rte = {
    "description": "Template used by GLUE-RTE",
    "input_text": "Does the text: '{sentence1}' entail that '{sentence2}' is true? Provide 'yes' or 'no'.",
    "target_text": {"0": "yes", "1": "no"},
}
sst2 = {
    "input_text": "Given the sentence '{sentence}', determine the sentiment. Is it positive or negative?",
    "target_text": {"0": "negative", "1": "positive"},
}
stsb = {
    "input_text": "Consider the sentences '{sentence1}' and '{sentence2}'. On a scale from 1 (completely different) to 5 (completely similar), rate the similarity.",
    "target_text": "{:.1f}",
}
```
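
To make the text-to-text setup concrete, the sketch below fills one of these templates with a raw GLUE example and runs it through Flan-T5. It assumes the Hugging Face `transformers` library; the example sentences are made up for illustration.

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# The GLUE-MRPC template from above.
mrpc = {
    "input_text": "Are the following sentences '{sentence1}' and '{sentence2}' conveying the same meaning? Answer with 'yes' or 'no'.",
    "target_text": {"0": "no", "1": "yes"},
}

# A toy GLUE-MRPC example (fields follow the GLUE schema).
example = {
    "sentence1": "The cat sat on the mat.",
    "sentence2": "A cat was sitting on the mat.",
    "label": 1,
}

# Fill the template to obtain the text-to-text pair.
input_text = mrpc["input_text"].format(**example)
target_text = mrpc["target_text"][str(example["label"])]  # -> "yes"

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# The model is fine-tuned to generate the target text given the templated input.
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```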

Flan-T5-base

Full Fine-tuned Models

Full fine-tuned Flan-T5-base models on tasks from the GLUE benchmark. Each row is the model fine-tuned on the named task (the first row is the pre-trained model); each column reports performance on one evaluation task:

| Model | GLUE-COLA | GLUE-MNLI | GLUE-MRPC | GLUE-QNLI | GLUE-QQP | GLUE-RTE | GLUE-SST2 | GLUE-STSB | Average |
|-------|-----------|-----------|-----------|-----------|----------|----------|-----------|-----------|---------|
| Pre-trained | 69.127517 | 56.454407 | 76.225490 | 88.449570 | 82.119713 | 80.144404 | 91.169725 | 62.190453 | 75.735160 |
| GLUE-COLA | 74.976031 | 37.208355 | 72.794118 | 87.552627 | 80.415533 | 76.895307 | 91.399083 | 63.583974 | 73.103128 |
| GLUE-MNLI | 65.867689 | 83.413143 | 75.735294 | 89.236683 | 82.616869 | 77.978339 | 90.596330 | 66.215025 | 78.957422 |
| GLUE-MRPC | 63.374880 | 48.293428 | 87.500000 | 85.831960 | 81.100668 | 72.563177 | 88.073394 | 76.062875 | 75.350048 |
| GLUE-QNLI | 68.744008 | 39.246052 | 75.490196 | 91.488193 | 81.291120 | 78.339350 | 91.628440 | 68.200428 | 74.303474 |
| GLUE-QQP | 59.060403 | 50.412634 | 73.774510 | 88.339740 | 85.369775 | 81.227437 | 90.825688 | 75.948390 | 75.619822 |
| GLUE-RTE | 65.388303 | 51.115639 | 69.607843 | 88.705839 | 80.774178 | 85.920578 | 90.252294 | 68.944418 | 75.088636 |
| GLUE-SST2 | 67.785235 | 53.958227 | 76.470588 | 87.772286 | 83.415780 | 80.505415 | 93.577982 | 63.612718 | 75.887279 |
| GLUE-STSB | 69.319271 | 49.302089 | 76.470588 | 88.962109 | 81.662132 | 77.617329 | 90.137615 | 88.695433 | 77.770821 |

LoRA Fine-tuned Models (r=16)

LoRA fine-tuned (r=16) Flan-T5-base models on tasks from the GLUE benchmark (the Average column was not reported for the fine-tuned models):

| Model | GLUE-COLA | GLUE-MNLI | GLUE-MRPC | GLUE-QNLI | GLUE-QQP | GLUE-RTE | GLUE-SST2 | GLUE-STSB | Average |
|-------|-----------|-----------|-----------|-----------|----------|----------|-----------|-----------|---------|
| Pre-trained | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
| GLUE-COLA | 69.1 | 39.9 | 75.2 | 89.1 | 81.1 | 81.9 | 90.7 | 54.0 | |
| GLUE-MNLI | 69.4 | 82.7 | 73.8 | 89.3 | 82.0 | 79.4 | 90.9 | 68.1 | |
| GLUE-MRPC | 64.0 | 44.9 | 85.5 | 82.6 | 81.0 | 69.0 | 88.6 | 73.6 | |
| GLUE-QNLI | 68.9 | 52.7 | 76.7 | 90.9 | 82.8 | 79.8 | 91.5 | 68.9 | |
| GLUE-QQP | 65.0 | 54.6 | 75.7 | 89.0 | 84.0 | 81.6 | 90.7 | 75.3 | |
| GLUE-RTE | 64.9 | 51.8 | 69.4 | 89.2 | 79.8 | 84.5 | 90.6 | 70.1 | |
| GLUE-SST2 | 68.3 | 56.6 | 76.0 | 88.5 | 83.4 | 79.8 | 92.9 | 62.6 | |
| GLUE-STSB | 65.7 | 1.7 | 67.4 | 89.3 | 80.1 | 79.8 | 90.8 | 87.4 | |

Flan-T5-Large

LoRA Fine-tuned Models (r=16)

LoRA fine-tuned (r=16) Flan-T5-large models on tasks from the GLUE benchmark (the Average column was not reported for the fine-tuned models):

| Model | GLUE-COLA | GLUE-MNLI | GLUE-MRPC | GLUE-QNLI | GLUE-QQP | GLUE-RTE | GLUE-SST2 | GLUE-STSB | Average |
|-------|-----------|-----------|-----------|-----------|----------|----------|-----------|-----------|---------|
| Pre-trained | 73.7 | 56.6 | 82.4 | 91.1 | 85.5 | 85.6 | 94.3 | 87.5 | 82.1 |
| GLUE-COLA | 80.2 | 53.9 | 81.4 | 90.8 | 84.5 | 84.1 | 93.9 | 87.1 | |
| GLUE-MNLI | 73.7 | 88.5 | 77.9 | 92.4 | 85.2 | 87.7 | 94.4 | 86.7 | |
| GLUE-MRPC | 75.6 | 52.6 | 89.2 | 92.6 | 84.4 | 86.3 | 94.3 | 86.3 | |
| GLUE-QNLI | 73.5 | 54.5 | 82.8 | 94.4 | 85.8 | 85.2 | 93.7 | 87.1 | |
| GLUE-QQP | 74.0 | 53.8 | 82.8 | 92.5 | 87.2 | 85.6 | 94.5 | 88.3 | |
| GLUE-RTE | 75.6 | 57.5 | 69.9 | 92.8 | 83.8 | 91.7 | 94.6 | 86.0 | |
| GLUE-SST2 | 73.6 | 55.3 | 82.1 | 91.6 | 85.5 | 85.2 | 95.2 | 86.9 | |
| GLUE-STSB | 73.4 | 39.3 | 82.1 | 92.6 | 86.1 | 83.4 | 94.0 | 90.9 | |

Basic Examples

Investigate Model Information

Load the pre-trained Flan-T5-base model and print the model information:

```bash
fusion_bench \
    modelpool=Seq2SeqLMPool/flan-t5-base_glue \
    method=dummy taskpool=dummy
# {'model_info': {'trainable_params': 247577856, 'all_params': 247577856, 'trainable_percentage': 1.0}}
```
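
The reported counts can be reproduced with a few lines of `transformers` and plain PyTorch; a minimal sketch:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# Count parameters the same way the model_info report does.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
print({
    "trainable_params": trainable_params,
    "all_params": all_params,
    "trainable_percentage": trainable_params / all_params,
})
# For flan-t5-base every parameter is trainable, so both counts are 247,577,856.
```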

Load the pre-trained Flan-T5-large model and print the model information:

```bash
fusion_bench \
    modelpool=Seq2SeqLMPool/flan-t5-large_glue_lora16 \
    method=dummy taskpool=dummy
```

Evaluate a Single Model

Evaluate the pre-trained Flan-T5-base model on the GLUE tasks:

```bash
fusion_bench \
    method=dummy \
    modelpool=Seq2SeqLMPool/flan-t5-base_individual \
        modelpool.models._pretrained_.pretrained_model_name_or_path=google/flan-t5-base \
    taskpool=flan-t5_glue_text_generation \
    report_save_path=outputs/flan-t5-base/pretrained.json
```

Or evaluate each of the fine-tuned Flan-T5-base models on the GLUE tasks:

```bash
for task in cola mnli mrpc qnli qqp rte sst2 stsb; do
    fusion_bench \
        method=dummy \
        modelpool=Seq2SeqLMPool/flan-t5-base_individual \
            modelpool.models._pretrained_.pretrained_model_name_or_path=tanganke/flan-t5-base_glue-${task} \
        taskpool=flan-t5_glue_text_generation \
        report_save_path=outputs/flan-t5-base/glue-${task}.json
done
```
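
Each run writes its evaluation report to `report_save_path`. A minimal sketch for inspecting a saved report (the exact schema of the JSON depends on the taskpool):

```python
import json

# Path matches the report_save_path used in the first command above.
with open("outputs/flan-t5-base/pretrained.json") as f:
    report = json.load(f)

# Pretty-print the per-task metrics.
print(json.dumps(report, indent=2))
```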

Simple Average

Merge the Flan-T5 models fine-tuned on the GLUE tasks using simple averaging, and evaluate the merged model on the Flan-T5 GLUE text generation tasks:

```bash
# for full fine-tuned models
fusion_bench \
    method=simple_average \
    modelpool=Seq2SeqLMPool/flan-t5-base_glue \
    taskpool=flan-t5_glue_text_generation

# or using the LoRA fine-tuned models
fusion_bench \
    method=simple_average \
    modelpool=Seq2SeqLMPool/flan-t5-base_glue_lora16 \
    taskpool=flan-t5_glue_text_generation
```
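
Conceptually, simple averaging takes a uniform element-wise mean of the fine-tuned weights. A minimal sketch of the idea in plain PyTorch (not FusionBench's actual implementation):

```python
import torch

def simple_average(state_dicts):
    """Element-wise mean over state dicts sharing the same keys and shapes."""
    merged = {}
    for key in state_dicts[0]:
        # Stack the parameter across models and average; real implementations
        # take care to preserve the original dtype.
        stacked = torch.stack([sd[key].float() for sd in state_dicts])
        merged[key] = stacked.mean(dim=0)
    return merged
```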

Task Arithmetic

Merge the Flan-T5 models fine-tuned on the GLUE tasks using task arithmetic, and evaluate the merged model on the Flan-T5 GLUE text generation tasks, sweeping the scaling factor from 0.0 to 1.0:

```bash
# full fine-tuned models with scaling factor set to 0.3
fusion_bench \
    method=task_arithmetic \
        method.scaling_factor=0.3 \
    modelpool=Seq2SeqLMPool/flan-t5-base_glue \
    taskpool=flan-t5_glue_text_generation

# use a for loop to evaluate task arithmetic with different scaling factors (LoRA fine-tuned models)
for scaling_factor in $(seq 0.0 0.1 1.0)
do
    fusion_bench \
        method=task_arithmetic \
            method.scaling_factor=$scaling_factor \
        modelpool=Seq2SeqLMPool/flan-t5-base_glue_lora16 \
        taskpool=flan-t5_glue_text_generation
done
```
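
For reference, task arithmetic forms a task vector \(\tau_i = \theta_i - \theta_0\) for each fine-tuned model \(\theta_i\) and adds the scaled sum back to the pre-trained weights \(\theta_0\): \(\theta = \theta_0 + \lambda \sum_i \tau_i\). A minimal sketch in plain PyTorch (not FusionBench's implementation):

```python
import torch

def task_arithmetic(pretrained_sd, finetuned_sds, scaling_factor=0.3):
    """theta = theta_0 + lambda * sum_i (theta_i - theta_0)."""
    merged = {}
    for key, theta_0 in pretrained_sd.items():
        # Sum of task vectors for this parameter.
        task_vector_sum = sum(sd[key] - theta_0 for sd in finetuned_sds)
        merged[key] = theta_0 + scaling_factor * task_vector_sum
    return merged
```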
Task arithmetic results for the LoRA fine-tuned Flan-T5-base models at each scaling factor (the Average was not reported for \(\lambda \geq 0.5\)):

| scaling_coef | glue-cola | glue-mnli | glue-mrpc | glue-qnli | glue-qqp | glue-rte | glue-sst2 | glue-stsb | Average |
|--------------|-----------|-----------|-----------|-----------|----------|----------|-----------|-----------|---------|
| 0.0 | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
| 0.1 | 69.5 | 59.8 | 78.7 | 89.7 | 83.6 | 80.5 | 91.1 | 70.7 | 78.0 |
| 0.2 | 69.3 | 59.0 | 78.7 | 90.1 | 83.8 | 79.1 | 91.5 | 72.9 | 78.1 |
| 0.3 | 68.8 | 55.2 | 78.7 | 89.8 | 83.7 | 79.1 | 91.5 | 72.4 | 77.4 |
| 0.4 | 68.1 | 31.3 | 77.7 | 88.7 | 83.3 | 78.7 | 91.2 | 68.9 | 73.5 |
| 0.5 | 66.0 | 2.2 | 78.7 | 86.3 | 82.6 | 78.0 | 90.4 | 74.2 | |
| 0.6 | 5.9 | 0.0 | 78.4 | 81.2 | 81.6 | 74.4 | 88.2 | 74.9 | |
| 0.7 | 0.0 | 0.0 | 77.0 | 74.1 | 79.8 | 66.1 | 70.8 | 74.7 | |
| 0.8 | 0.0 | 0.0 | 74.0 | 67.5 | 77.0 | 62.1 | 8.1 | 72.5 | |
| 0.9 | 0.0 | 0.0 | 66.7 | 60.5 | 71.7 | 58.5 | 0.0 | 71.6 | |
| 1.0 | 0.0 | 0.0 | 46.3 | 50.6 | 56.3 | 52.3 | 0.0 | 69.8 | |

Ties-Merging

Alternatively, merge the models using Ties-Merging:

```bash
# for full fine-tuned models with scaling factor set to 0.3
fusion_bench \
    method=ties_merging \
        method.scaling_factor=0.3 \
    modelpool=Seq2SeqLMPool/flan-t5-base_glue \
    taskpool=flan-t5_glue_text_generation

# use a for loop to evaluate ties-merging with different scaling factors (LoRA fine-tuned models)
for scaling_factor in $(seq 0.0 0.1 1.0)
do
    fusion_bench \
        method=ties_merging \
            method.scaling_factor=$scaling_factor \
        modelpool=Seq2SeqLMPool/flan-t5-base_glue_lora16 \
        taskpool=flan-t5_glue_text_generation
done
```
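
Ties-Merging refines task arithmetic in three steps before scaling: trim each task vector to its largest-magnitude entries, elect a per-parameter sign from the dominant direction, and average only the entries that agree with the elected sign. A simplified sketch of the algorithm (the trim ratio and details are illustrative, not FusionBench's implementation):

```python
import torch

def ties_merging(pretrained_sd, finetuned_sds, scaling_factor=0.3, keep_ratio=0.2):
    merged = {}
    for key, theta_0 in pretrained_sd.items():
        # Task vectors, flattened to shape (num_tasks, num_params).
        vectors = torch.stack([sd[key] - theta_0 for sd in finetuned_sds])
        flat = vectors.reshape(len(finetuned_sds), -1).float()
        # 1) Trim: keep only the top-k% entries (by magnitude) of each task vector.
        k = max(1, int(keep_ratio * flat.shape[1]))
        threshold = flat.abs().kthvalue(flat.shape[1] - k + 1, dim=1).values
        trimmed = torch.where(flat.abs() >= threshold[:, None], flat, torch.zeros_like(flat))
        # 2) Elect sign: dominant direction per parameter across task vectors.
        elected_sign = torch.sign(trimmed.sum(dim=0))
        # 3) Disjoint merge: mean over nonzero entries that agree with the elected sign.
        agree = (torch.sign(trimmed) == elected_sign[None, :]) & (trimmed != 0)
        count = agree.sum(dim=0).clamp(min=1)
        merged_vec = (trimmed * agree).sum(dim=0) / count
        merged[key] = theta_0 + scaling_factor * merged_vec.view_as(theta_0)
    return merged
```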

Experimental Results

Flan-T5-Base models:

Full fine-tuned models:

| Method | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STSB | Average |
|--------|------|------|------|------|-----|-----|-------|------|---------|
| **Reference Results** | | | | | | | | | |
| Pre-trained | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
| Fine-tuned (STL) | 75.0 | 83.4 | 87.5 | 91.5 | 85.4 | 85.9 | 93.6 | 88.7 | 86.4 |
| **Model Merging Methods** | | | | | | | | | |
| Simple Average | 69.1 | 62.6 | 79.4 | 89.8 | 83.9 | 81.2 | 91.7 | 73.2 | 78.9 |
| Task Arithmetic (\(\lambda=0.3\)) | 70.5 | 57.8 | 78.4 | 90.2 | 83.6 | 80.5 | 92.3 | 77.8 | 78.9 |
| Ties-Merging (\(\lambda=0.3\)) | 70.3 | 65.0 | 78.9 | 90.2 | 83.5 | 81.6 | 91.7 | 78.3 | 79.9 |

LoRA fine-tuned models (r=16):

| Method | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STSB | Average |
|--------|------|------|------|------|-----|-----|-------|------|---------|
| **Reference Results** | | | | | | | | | |
| Pre-trained | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
| Fine-tuned (STL) | 69.1 | 82.7 | 85.5 | 90.9 | 84.0 | 84.4 | 92.9 | 87.4 | 84.6 |
| **Model Merging Methods** | | | | | | | | | |
| Simple Average | 69.7 | 59.7 | 78.9 | 90.1 | 83.8 | 80.5 | 91.2 | 72.0 | 78.2 |
| Task Arithmetic (\(\lambda=0.3\)) | 68.8 | 55.2 | 78.7 | 89.8 | 83.7 | 79.1 | 91.5 | 72.4 | 77.4 |
| Ties-Merging (\(\lambda=0.3\)) | 68.3 | 56.3 | 79.4 | 89.8 | 83.7 | 79.4 | 91.6 | 71.2 | 77.5 |

Flan-T5-Large models (LoRA fine-tuned, r=16):

| Method | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STSB | Average |
|--------|------|------|------|------|-----|-----|-------|------|---------|
| **Reference Results** | | | | | | | | | |
| Pre-trained | 73.7 | 56.6 | 82.4 | 91.1 | 85.5 | 85.6 | 94.3 | 87.5 | 82.1 |
| Fine-tuned (STL) | 80.2 | 88.5 | 89.2 | 94.4 | 87.2 | 91.7 | 95.2 | 90.9 | 89.6 |
| **Model Merging Methods** | | | | | | | | | |
| Simple Average | 74.6 | 84.3 | 84.1 | 92.8 | 86.3 | 87.4 | 94.8 | 88.0 | 86.5 |
| Task Arithmetic (\(\lambda=0.3\)) | 76.9 | 85.4 | 85.3 | 93.9 | 85.8 | 88.1 | 95.2 | 87.8 | 87.3 |
| Ties-Merging (\(\lambda=0.3\)) | 77.1 | 85.1 | 86.3 | 93.9 | 86.0 | 87.7 | 95.1 | 88.0 | 87.4 |

References

Seq2SeqLMPool

Bases: BaseModelPool

Source code in `fusion_bench/modelpool/seq2seq_lm/modelpool.py`:

```python
class Seq2SeqLMPool(BaseModelPool):
    _config_mapping = BaseModelPool._config_mapping | {
        "_tokenizer": "tokenizer",
        "_model_kwargs": "model_kwargs",
    }

    def __init__(
        self,
        models: DictConfig,
        *,
        tokenizer: Optional[DictConfig],
        model_kwargs: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__(models, **kwargs)
        self._tokenizer = tokenizer
        self._model_kwargs = model_kwargs
        if self._model_kwargs is None:
            self._model_kwargs = DictConfig({})
        with flag_override(self._model_kwargs, "allow_objects", True):
            if hasattr(self._model_kwargs, "torch_dtype"):
                # Convert a dtype string such as "bfloat16" into the torch.dtype object.
                self._model_kwargs.torch_dtype = parse_dtype(
                    self._model_kwargs.torch_dtype
                )

    def load_model(self, model_name_or_config: str | DictConfig, *args, **kwargs):
        # Pool-level model_kwargs serve as defaults and are overridden by call-site kwargs.
        model_kwargs = deepcopy(self._model_kwargs)
        model_kwargs.update(kwargs)
        return super().load_model(model_name_or_config, *args, **model_kwargs)

    def load_tokenizer(self, *args, **kwargs):
        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
        return tokenizer
```
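
A hedged usage sketch: the config schema below (Hydra-style `_target_` entries) is an assumption inferred from the CLI override `modelpool.models._pretrained_.pretrained_model_name_or_path=...` used earlier, not the exact YAML shipped with FusionBench.

```python
from omegaconf import DictConfig
from fusion_bench.modelpool import Seq2SeqLMPool  # import path assumed

# Hypothetical pool config; mirrors the overrides used in the CLI examples above.
pool = Seq2SeqLMPool(
    models=DictConfig({
        "_pretrained_": {
            "_target_": "transformers.AutoModelForSeq2SeqLM.from_pretrained",
            "pretrained_model_name_or_path": "google/flan-t5-base",
        },
    }),
    tokenizer=DictConfig({
        "_target_": "transformers.AutoTokenizer.from_pretrained",
        "pretrained_model_name_or_path": "google/flan-t5-base",
    }),
)
model = pool.load_model("_pretrained_")  # instantiates the model config with model_kwargs applied
tokenizer = pool.load_tokenizer()        # instantiates the tokenizer config
```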