# LLaMA-3

## Llama-3.1-8B
This configuration includes the pretrained base model along with domain-specific fine-tuned models from MergeBench:
`config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B.yaml`

```yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.1-8B
  instruction: MergeBench/Llama-3.1-8B_instruction
  math: MergeBench/Llama-3.1-8B_math
  coding: MergeBench/Llama-3.1-8B_coding
  multilingual: MergeBench/Llama-3.1-8B_multilingual
  safety: MergeBench/Llama-3.1-8B_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.1-8B
```
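Every entry in the pool is an ordinary Hugging Face Hub repository ID, so the checkpoints can be pre-fetched before launching a merge. The snippet below is an optional sketch using `huggingface-cli`; Transformers will otherwise download the weights on first use, and access to the gated `meta-llama` repository must already be granted to your account.

```bash
# Optional: pre-download the base model and one fine-tuned expert
# (repository IDs taken from the pool config above).
huggingface-cli download meta-llama/Llama-3.1-8B
huggingface-cli download MergeBench/Llama-3.1-8B_math
```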
This configuration uses the instruction-tuned base model along with its domain-specific fine-tuned counterparts:
`config/modelpool/CausalLMPool/mergebench/Llama-3.1-8B-Instruct.yaml`

```yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.1-8B-Instruct
  instruction: MergeBench/Llama-3.1-8B-Instruct_instruction
  math: MergeBench/Llama-3.1-8B-Instruct_math
  coding: MergeBench/Llama-3.1-8B-Instruct_coding
  multilingual: MergeBench/Llama-3.1-8B-Instruct_multilingual
  safety: MergeBench/Llama-3.1-8B-Instruct_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.1-8B-Instruct
```
### Model Fusion Experiments

#### Simple Average
```bash
fusion_bench path.log_dir=outputs/llama-3.1-8b/simple_average \
    method=linear/simple_average_for_causallm \
    modelpool=CausalLMPool/mergebench/Llama-3.1-8B
```
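The same recipe works for the instruction-tuned pool: only the `modelpool` option changes (the `path.log_dir` value below is an arbitrary output location, not a fixed convention).

```bash
fusion_bench path.log_dir=outputs/llama-3.1-8b-instruct/simple_average \
    method=linear/simple_average_for_causallm \
    modelpool=CausalLMPool/mergebench/Llama-3.1-8B-Instruct
```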
## Llama-3.2-3B
This configuration includes the pretrained base model along with domain-specific fine-tuned models from MergeBench:
`config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B.yaml`

```yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.2-3B
  instruction: MergeBench/Llama-3.2-3B_instruction
  math: MergeBench/Llama-3.2-3B_math
  coding: MergeBench/Llama-3.2-3B_coding
  multilingual: MergeBench/Llama-3.2-3B_multilingual
  safety: MergeBench/Llama-3.2-3B_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.2-3B
```
This configuration uses the instruction-tuned base model along with its domain-specific fine-tuned counterparts:
`config/modelpool/CausalLMPool/mergebench/Llama-3.2-3B-Instruct.yaml`

```yaml
_target_: fusion_bench.modelpool.CausalLMPool
models:
  _pretrained_: meta-llama/Llama-3.2-3B-Instruct
  instruction: MergeBench/Llama-3.2-3B-Instruct_instruction
  math: MergeBench/Llama-3.2-3B-Instruct_math
  coding: MergeBench/Llama-3.2-3B-Instruct_coding
  multilingual: MergeBench/Llama-3.2-3B-Instruct_multilingual
  safety: MergeBench/Llama-3.2-3B-Instruct_safety
model_kwargs:
  torch_dtype: bfloat16
tokenizer: meta-llama/Llama-3.2-3B-Instruct
```
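The Llama-3.2-3B pools can be merged with the same command pattern shown above for Llama-3.1-8B. The sketch below simply swaps in the 3B pool configurations; the log directories are again arbitrary choices.

```bash
# Merge the Llama-3.2-3B base pool with simple averaging
fusion_bench path.log_dir=outputs/llama-3.2-3b/simple_average \
    method=linear/simple_average_for_causallm \
    modelpool=CausalLMPool/mergebench/Llama-3.2-3B

# Merge the Llama-3.2-3B-Instruct pool with simple averaging
fusion_bench path.log_dir=outputs/llama-3.2-3b-instruct/simple_average \
    method=linear/simple_average_for_causallm \
    modelpool=CausalLMPool/mergebench/Llama-3.2-3B-Instruct
```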