# Flan-T5 Models for Text Generation

## Model Information

This page collects Flan-T5 models fine-tuned on GLUE benchmark tasks with prompt-based (text-to-text) fine-tuning. The prompt templates used for each task are listed below; the source code for these templates can be found in the repository.
cola = {
"description": "template used by GLUE-CoLA",
"input_text": "Indicate if the following sentence is grammatically correct or not: \"{sentence}\". Answere 'acceptable' or 'unacceptable'.",
"target_text": {"0": "unacceptable", "1": "acceptable"},
}
mnli = {
"input_text": "Does the premise: '{premise}' logically imply, contradict, or is neutral to the hypothesis: '{hypothesis}'? Answere with 'entailment', 'contradiction', or 'neutral'.",
"target_text": {"0": "entailment", "1": "neutral", "2": "contradiction"},
}
mrpc = {
"input_text": "Are the following sentences '{sentence1}' and '{sentence2}' conveying the same meaning? Answere with 'yes' or 'no'.",
"target_text": {"0": "no", "1": "yes"},
}
qnli = {
"input_text": "Given the context: '{sentence}', does the question '{question}' have an answer based on the information provided? Answer with 'yes' or 'no'.",
"target_text": {"0": "yes", "1": "no"},
}
qqp = {
"input_text": "Do the questions '{question1}' and '{question2}' have the same intent? Answere with 'yes' or 'no'.",
"target_text": {"0": "no", "1": "yes"},
}
rte = {
"description": "Template used by GLUE-RTE",
"input_text": "Does the text: '{sentence1}' entail that '{sentence2}' is true? Provide 'yes' or 'no'.",
"target_text": {"0": "yes", "1": "no"},
}
sst2 = {
"input_text": "Given the sentence '{sentence}', determine the sentiment. Is it positive or negative?",
"target_text": {"0": "negative", "1": "positive"},
}
stsb = {
"input_text": "Consider the sentences '{sentence1}' and '{sentence2}'. On a scale from 1 (completely different) to 5 (completely similar), rate the similarity.",
"target_text": "{:.1f}",
}
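For illustration, a task input is constructed by formatting the corresponding template with the example's fields, and the target is looked up from the label map. A minimal sketch using the `cola` template above (the example sentence is made up):

```python
# Fill the CoLA template with a sample sentence (hypothetical example input).
example = {"sentence": "The book was written by John."}
prompt = cola["input_text"].format(**example)
print(prompt)
# Indicate if the following sentence is grammatically correct or not: "The book was written by John." ...
# The target text comes from the label map, e.g. cola["target_text"]["1"] -> "acceptable".
```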
### Flan-T5-base

#### Fully Fine-tuned Models

Fully fine-tuned Flan-T5-base models on tasks from the GLUE benchmark. Each row corresponds to the pre-trained model or a model fine-tuned on the named task, evaluated on all eight GLUE tasks:
Model | GLUE-COLA | GLUE-MNLI | GLUE-MRPC | GLUE-QNLI | GLUE-QQP | GLUE-RTE | GLUE-SST2 | GLUE-STSB | Average |
---|---|---|---|---|---|---|---|---|---|
Pre-trained | 69.127517 | 56.454407 | 76.225490 | 88.449570 | 82.119713 | 80.144404 | 91.169725 | 62.190453 | 75.735160 |
GLUE-COLA | 74.976031 | 37.208355 | 72.794118 | 87.552627 | 80.415533 | 76.895307 | 91.399083 | 63.583974 | 73.103128 |
GLUE-MNLI | 65.867689 | 83.413143 | 75.735294 | 89.236683 | 82.616869 | 77.978339 | 90.596330 | 66.215025 | 78.957422 |
GLUE-MRPC | 63.374880 | 48.293428 | 87.500000 | 85.831960 | 81.100668 | 72.563177 | 88.073394 | 76.062875 | 75.350048 |
GLUE-QNLI | 68.744008 | 39.246052 | 75.490196 | 91.488193 | 81.291120 | 78.339350 | 91.628440 | 68.200428 | 74.303474 |
GLUE-QQP | 59.060403 | 50.412634 | 73.774510 | 88.339740 | 85.369775 | 81.227437 | 90.825688 | 75.948390 | 75.619822 |
GLUE-RTE | 65.388303 | 51.115639 | 69.607843 | 88.705839 | 80.774178 | 85.920578 | 90.252294 | 68.944418 | 75.088636 |
GLUE-SST2 | 67.785235 | 53.958227 | 76.470588 | 87.772286 | 83.415780 | 80.505415 | 93.577982 | 63.612718 | 75.887279 |
GLUE-STSB | 69.319271 | 49.302089 | 76.470588 | 88.962109 | 81.662132 | 77.617329 | 90.137615 | 88.695433 | 77.770821 |
#### LoRA Fine-tuned Models (r=16)

LoRA fine-tuned (r=16) Flan-T5-base models on tasks from the GLUE benchmark:
Model | GLUE-COLA | GLUE-MNLI | GLUE-MRPC | GLUE-QNLI | GLUE-QQP | GLUE-RTE | GLUE-SST2 | GLUE-STSB | Average |
---|---|---|---|---|---|---|---|---|---|
Pre-trained | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
GLUE-COLA | 69.1 | 39.9 | 75.2 | 89.1 | 81.1 | 81.9 | 90.7 | 54.0 | |
GLUE-MNLI | 69.4 | 82.7 | 73.8 | 89.3 | 82.0 | 79.4 | 90.9 | 68.1 | |
GLUE-MRPC | 64.0 | 44.9 | 85.5 | 82.6 | 81.0 | 69.0 | 88.6 | 73.6 | |
GLUE-QNLI | 68.9 | 52.7 | 76.7 | 90.9 | 82.8 | 79.8 | 91.5 | 68.9 | |
GLUE-QQP | 65.0 | 54.6 | 75.7 | 89.0 | 84.0 | 81.6 | 90.7 | 75.3 | |
GLUE-RTE | 64.9 | 51.8 | 69.4 | 89.2 | 79.8 | 84.5 | 90.6 | 70.1 | |
GLUE-SST2 | 68.3 | 56.6 | 76.0 | 88.5 | 83.4 | 79.8 | 92.9 | 62.6 | |
GLUE-STSB | 65.7 | 1.7 | 67.4 | 89.3 | 80.1 | 79.8 | 90.8 | 87.4 |
### Flan-T5-Large

#### LoRA Fine-tuned Models (r=16)

LoRA fine-tuned (r=16) Flan-T5-large models on tasks from the GLUE benchmark:
Model | GLUE-COLA | GLUE-MNLI | GLUE-MRPC | GLUE-QNLI | GLUE-QQP | GLUE-RTE | GLUE-SST2 | GLUE-STSB | Average |
---|---|---|---|---|---|---|---|---|---|
Pre-trained | 73.7 | 56.6 | 82.4 | 91.1 | 85.5 | 85.6 | 94.3 | 87.5 | 82.1 |
GLUE-COLA | 80.2 | 53.9 | 81.4 | 90.8 | 84.5 | 84.1 | 93.9 | 87.1 | |
GLUE-MNLI | 73.7 | 88.5 | 77.9 | 92.4 | 85.2 | 87.7 | 94.4 | 86.7 | |
GLUE-MRPC | 75.6 | 52.6 | 89.2 | 92.6 | 84.4 | 86.3 | 94.3 | 86.3 | |
GLUE-QNLI | 73.5 | 54.5 | 82.8 | 94.4 | 85.8 | 85.2 | 93.7 | 87.1 | |
GLUE-QQP | 74.0 | 53.8 | 82.8 | 92.5 | 87.2 | 85.6 | 94.5 | 88.3 | |
GLUE-RTE | 75.6 | 57.5 | 69.9 | 92.8 | 83.8 | 91.7 | 94.6 | 86.0 | |
GLUE-SST2 | 73.6 | 55.3 | 82.1 | 91.6 | 85.5 | 85.2 | 95.2 | 86.9 | |
GLUE-STSB | 73.4 | 39.3 | 82.1 | 92.6 | 86.1 | 83.4 | 94.0 | 90.9 |
## Basic Examples

### Investigate Model Information

Load the pre-trained Flan-T5-base model and print the model information:
fusion_bench \
modelpool=Seq2SeqLMPool/flan-t5-base_glue \
method=dummy taskpool=dummy
# {'model_info': {'trainable_params': 247577856, 'all_params': 247577856, 'trainable_percentage': 1.0}}
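The reported parameter count can be cross-checked directly with Hugging Face Transformers (a quick sketch outside of fusion_bench):

```python
from transformers import AutoModelForSeq2SeqLM

# Load the same pre-trained checkpoint and count its parameters.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
print(sum(p.numel() for p in model.parameters()))  # roughly 248M parameters for flan-t5-base
```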
Load the pre-trained Flan-T5-large model and print the model information:
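One way to do this is to reuse the individual modelpool config from the next example and override the pre-trained checkpoint (a sketch; the repository may also ship a dedicated Flan-T5-large modelpool config):

```bash
fusion_bench \
  modelpool=Seq2SeqLMPool/flan-t5-base_individual \
  modelpool.models._pretrained_.pretrained_model_name_or_path=google/flan-t5-large \
  method=dummy taskpool=dummy
```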
### Evaluate a Single Model

Evaluate the pre-trained Flan-T5-base model on GLUE tasks:
fusion_bench \
method=dummy \
modelpool=Seq2SeqLMPool/flan-t5-base_individual \
modelpool.models._pretrained_.pretrained_model_name_or_path=google/flan-t5-base \
taskpool=flan-t5_glue_text_generation \
report_save_path=outputs/flan-t5-base/pretrained.json
Or evaluate each of the fine-tuned Flan-T5-base models on GLUE tasks:
for task in cola mnli mrpc qnli qqp rte sst2 stsb; do
fusion_bench \
method=dummy \
modelpool=Seq2SeqLMPool/flan-t5-base_individual \
modelpool.models._pretrained_.pretrained_model_name_or_path=tanganke/flan-t5-base_glue-${task} \
taskpool=flan-t5_glue_text_generation \
report_save_path=outputs/flan-t5-base/glue-$task.json
done
### Simple Average

Merge the Flan-T5 models fine-tuned on GLUE tasks using simple averaging and evaluate the merged model on the Flan-T5 GLUE text generation tasks:
# for full fine-tuned models
fusion_bench \
method=simple_average \
modelpool=Seq2SeqLMPool/flan-t5-base_glue \
taskpool=flan-t5_glue_text_generation
# or using the LoRA fine-tuned models
fusion_bench \
method=simple_average \
modelpool=Seq2SeqLMPool/flan-t5-base_glue_lora16 \
taskpool=flan-t5_glue_text_generation
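Conceptually, simple averaging takes the element-wise mean of the fine-tuned models' parameters. A minimal sketch of the idea, assuming plain PyTorch state dicts and a hypothetical helper name (not the fusion_bench implementation):

```python
import torch

def simple_average(state_dicts):
    """Element-wise mean over a list of model state dicts (conceptual sketch)."""
    return {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in state_dicts[0]
    }
```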
### Task Arithmetic

Merge the Flan-T5 models fine-tuned on GLUE tasks using task arithmetic and evaluate the merged model on the Flan-T5 GLUE text generation tasks, with scaling factors from 0.0 to 1.0.
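Task arithmetic forms a task vector for each fine-tuned model, i.e. its parameter difference from the pre-trained model, and adds the scaled sum of these task vectors back to the pre-trained weights:

\[
\theta_{\text{merged}} = \theta_{\text{pre}} + \lambda \sum_{i} \left( \theta_{i} - \theta_{\text{pre}} \right)
\]

where \(\lambda\) is the scaling factor (`method.scaling_factor`) swept below.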
# full fine-tuned models with scaling factor set to 0.3
fusion_bench \
method=task_arithmetic \
method.scaling_factor=0.3 \
modelpool=Seq2SeqLMPool/flan-t5-base_glue \
taskpool=flan-t5_glue_text_generation
# use a for loop to evaluate the performance of task arithmetic with different scaling factors (LoRA fine-tuned models)
for scaling_factor in $(seq 0.0 0.1 1.0)
do
fusion_bench \
method=task_arithmetic \
method.scaling_factor=$scaling_factor \
modelpool=Seq2SeqLMPool/flan-t5-base_glue_lora16 \
taskpool=flan-t5_glue_text_generation
done
scaling_coef | glue-cola | glue-mnli | glue-mrpc | glue-qnli | glue-qqp | glue-rte | glue-sst2 | glue-stsb | Average |
---|---|---|---|---|---|---|---|---|---|
0.0 | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
0.1 | 69.5 | 59.8 | 78.7 | 89.7 | 83.6 | 80.5 | 91.1 | 70.7 | 78.0 |
0.2 | 69.3 | 59.0 | 78.7 | 90.1 | 83.8 | 79.1 | 91.5 | 72.9 | 78.1 |
0.3 | 68.8 | 55.2 | 78.7 | 89.8 | 83.7 | 79.1 | 91.5 | 72.4 | 77.4 |
0.4 | 68.1 | 31.3 | 77.7 | 88.7 | 83.3 | 78.7 | 91.2 | 68.9 | 73.5 |
0.5 | 66.0 | 2.2 | 78.7 | 86.3 | 82.6 | 78.0 | 90.4 | 74.2 | |
0.6 | 5.9 | 0.0 | 78.4 | 81.2 | 81.6 | 74.4 | 88.2 | 74.9 | |
0.7 | 0.0 | 0.0 | 77.0 | 74.1 | 79.8 | 66.1 | 70.8 | 74.7 | |
0.8 | 0.0 | 0.0 | 74.0 | 67.5 | 77.0 | 62.1 | 8.1 | 72.5 | |
0.9 | 0.0 | 0.0 | 66.7 | 60.5 | 71.7 | 58.5 | 0.0 | 71.6 | |
1.0 | 0.0 | 0.0 | 46.3 | 50.6 | 56.3 | 52.3 | 0.0 | 69.8 |
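Merging performance peaks at small scaling factors (around 0.1 to 0.3) and degrades sharply as the scaling factor approaches 1.0, where CoLA, MNLI, and SST-2 collapse to near zero.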
### Ties-Merging

Or merge the models using Ties-Merging and evaluate them in the same way:
# for full fine-tuned models with scaling factor set to 0.3
fusion_bench \
method=ties_merging \
method.scaling_factor=0.3 \
modelpool=Seq2SeqLMPool/flan-t5-base_glue \
taskpool=flan-t5_glue_text_generation
# use a for loop to evaluate the performance of ties-merging with different scaling factors (LoRA fine-tuned models)
for scaling_factor in $(seq 0.0 0.1 1.0)
do
fusion_bench \
method=ties_merging \
method.scaling_factor=$scaling_factor \
modelpool=Seq2SeqLMPool/flan-t5-base_glue_lora16 \
taskpool=flan-t5_glue_text_generation
done
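Ties-Merging differs from task arithmetic in how it combines task vectors: it trims each task vector to its largest-magnitude entries, elects a majority sign per parameter, and averages only the entries that agree with that sign before scaling. A conceptual sketch of these three steps, assuming plain PyTorch state dicts; the function and argument names (`ties_merge`, `pretrained_sd`, `finetuned_sds`, `k`) are hypothetical and this is not the fusion_bench implementation:

```python
import torch

def ties_merge(pretrained_sd, finetuned_sds, k=0.2, scaling_factor=0.3):
    """Conceptual sketch of Ties-Merging: trim, elect sign, disjoint merge."""
    merged = {}
    for key, pre in pretrained_sd.items():
        pre = pre.float()
        # Task vectors: parameter differences between fine-tuned and pre-trained models.
        tvs = torch.stack([sd[key].float() - pre for sd in finetuned_sds])
        flat = tvs.reshape(tvs.shape[0], -1)
        # 1) Trim: keep only the top-k fraction of entries (by magnitude) per task vector.
        n_keep = max(1, int(k * flat.shape[1]))
        thresh = flat.abs().kthvalue(flat.shape[1] - n_keep + 1, dim=1).values
        trimmed = flat * (flat.abs() >= thresh.unsqueeze(1))
        # 2) Elect sign: magnitude-weighted majority sign per parameter.
        sign = torch.sign(trimmed.sum(dim=0))
        # 3) Disjoint merge: average only the trimmed entries that agree with the elected sign.
        agree = (torch.sign(trimmed) == sign) & (trimmed != 0)
        merged_tv = (trimmed * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)
        merged[key] = pre + scaling_factor * merged_tv.reshape(pre.shape)
    return merged
```

The `method.scaling_factor` option in the commands above plays the role of `scaling_factor` in this sketch; the trim fraction `k` is a separate hyperparameter of the method.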
## Experimental Results

Flan-T5-Base models (fully fine-tuned):
Method | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STSB | Average |
---|---|---|---|---|---|---|---|---|---|
Reference Results | |||||||||
Pre-trained | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
Fine-tuned (STL) | 75.0 | 83.4 | 87.5 | 91.5 | 85.4 | 85.9 | 93.6 | 88.7 | 86.4 |
Model Merging Methods | |||||||||
Simple Average | 69.1 | 62.6 | 79.4 | 89.8 | 83.9 | 81.2 | 91.7 | 73.2 | 78.9 |
Task Arithmetic (\(\lambda=0.3\)) | 70.5 | 57.8 | 78.4 | 90.2 | 83.6 | 80.5 | 92.3 | 77.8 | 78.9 |
Ties-Merging (\(\lambda=0.3\)) | 70.3 | 65.0 | 78.9 | 90.2 | 83.5 | 81.6 | 91.7 | 78.3 | 79.9 |
Flan-T5-Base models (LoRA fine-tuned, r=16):

Method | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STSB | Average |
---|---|---|---|---|---|---|---|---|---|
Reference Results | |||||||||
Pre-trained | 69.1 | 56.5 | 76.2 | 88.4 | 82.1 | 80.1 | 91.2 | 62.2 | 75.7 |
Fine-tuned (STL) | 69.1 | 82.7 | 85.5 | 90.9 | 84.0 | 84.4 | 92.9 | 87.4 | 84.6 |
Model Merging Methods | |||||||||
Simple Average | 69.7 | 59.7 | 78.9 | 90.1 | 83.8 | 80.5 | 91.2 | 72.0 | 78.2 |
Task Arithmetic (\(\lambda=0.3\)) | 68.8 | 55.2 | 78.7 | 89.8 | 83.7 | 79.1 | 91.5 | 72.4 | 77.4 |
Ties-Merging (\(\lambda=0.3\)) | 68.3 | 56.3 | 79.4 | 89.8 | 83.7 | 79.4 | 91.6 | 71.2 | 77.5 |
Flan-T5-Large models (LoRA fine-tuned, r=16):
Method | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | STSB | Average |
---|---|---|---|---|---|---|---|---|---|
Reference Results | |||||||||
Pre-trained | 73.7 | 56.6 | 82.4 | 91.1 | 85.5 | 85.6 | 94.3 | 87.5 | 82.1 |
Fine-tuned (STL) | 80.2 | 88.5 | 89.2 | 94.4 | 87.2 | 91.7 | 95.2 | 90.9 | 89.6 |
Model Merging Methods | |||||||||
Simple Average | 74.6 | 84.3 | 84.1 | 92.8 | 86.3 | 87.4 | 94.8 | 88.0 | 86.5 |
Task Arithmetic (\(\lambda=0.3\)) | 76.9 | 85.4 | 85.3 | 93.9 | 85.8 | 88.1 | 95.2 | 87.8 | 87.3 |
Ties-Merging (\(\lambda=0.3\)) | 77.1 | 85.1 | 86.3 | 93.9 | 86.0 | 87.7 | 95.1 | 88.0 | 87.4 |
## References

### Seq2SeqLMPool

Bases: `BaseModelPool`