
Fine-Tune Your Own Vision Transformer

In this guide, we show how to fine-tune your own Vision Transformer (ViT) on a custom dataset using the fusion_bench CLI. FusionBench provides a simple, easy-to-use interface for fine-tuning CLIP vision transformers in either a single-task learning setting or a traditional multi-task learning setting.

Basic Examples

Single-Task Learning

Refer to examples/clip_finetune/clip_finetune.sh for complete examples of fine-tuning a CLIP-ViT model, including full fine-tuning, LoRA fine-tuning, and linearized LoRA (L-LoRA) fine-tuning.
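
For example, LoRA fine-tuning can be enabled by overriding the method options from the command line. The sketch below assumes the method config exposes the use_lora and lora_config keys used by the reference implementation at the end of this page; if a key is not present in your config, prefix the override with + (Hydra's syntax for adding new keys).

# LoRA fine-tuning of CLIP-ViT-B/32 (sketch, assuming a use_lora key)
fusion_bench \
    method=clip_finetune \
        method.use_lora=true \
    modelpool=clip-vit-base-patch32_mtl \
    taskpool=dummy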

Multi-Task Learning

Fine-tune CLIP-ViT-B/32:

fusion_bench \
    method=clip_finetune \
    modelpool=clip-vit-base-patch32_mtl \
    taskpool=dummy

Fine-tune CLIP-ViT-L/14 on eight GPUs with a per-device, per-task batch size of 2 (an effective per-task batch size of 16):

fusion_bench \
    fabric.devices=8 \
    method=clip_finetune \
        method.batch_size=2 \
    modelpool=clip-vit-base-patch32_mtl \
        modelpool.models.0.path=openai/clip-vit-large-patch14 \
    taskpool=dummy

This will save the state dict of the vision model (transformers.models.clip.modeling_clip.CLIPVisionTransformer) to the log directory. Subsequently, fusion_bench/scripts/clip/convert_checkpoint.py can be used to convert the state dict into a HuggingFace model (CLIPVisionModel).
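
The reference implementation below also reads optional state_dict_save_path and state_dict_load_path keys, so the save location can be controlled explicitly. The following is a sketch under that assumption (prefix the override with + if the key is absent from your method config); the output path is just an example.

# save the fine-tuned vision model to an explicit path (example path)
fusion_bench \
    method=clip_finetune \
        method.state_dict_save_path=outputs/clip_vision_model.ckpt \
    modelpool=clip-vit-base-patch32_mtl \
    taskpool=dummy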

config/method/clip_finetune.yaml
name: clip_finetune

seed: 42

learning_rate: 1e-5
num_steps: 4000

batch_size: 32
num_workers: 4

save_interval: 500
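
Any of these hyperparameters can be overridden on the command line using the same dot notation as above, for example:

# train for fewer steps and checkpoint less frequently
fusion_bench \
    method=clip_finetune \
        method.learning_rate=1e-5 \
        method.num_steps=2000 \
        method.save_interval=1000 \
    modelpool=clip-vit-base-patch32_mtl \
    taskpool=dummy
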
config/modelpool/clip-vit-base-patch32_mtl.yaml
type: huggingface_clip_vision
models:
  - name: _pretrained_
    path: openai/clip-vit-base-patch32

dataset_type: huggingface_image_classification
train_datasets:
  - name: svhn
    dataset:
      type: instantiate
      name: svhn
      object:
        _target_: datasets.load_dataset
        _args_:
          - svhn
          - cropped_digits
        split: train
  - name: stanford_cars
    dataset:
      name: tanganke/stanford_cars
      split: train
  # other datasets
  # ...
Convert the saved checkpoint to a HuggingFace CLIPVisionModel:

# for CLIP-ViT-L/14, add the option: --model openai/clip-vit-large-patch14
python fusion_bench/scripts/clip/convert_checkpoint.py \
    --checkpoint /path/to/checkpoint \
    --output /path/to/output

After converting the checkpoint, you can use FusionBench to evaluate the model. For example, the following command evaluates the model on the eight image classification tasks of the clip-vit-classification_TA8 task pool.

path_to_clip_model=/path/to/converted/output
fusion_bench method=dummy \
  modelpool=clip-vit-base-patch32_individual \
    modelpool.models.0.path="'${path_to_clip_model}'" \
  taskpool=clip-vit-classification_TA8

Single-Task Learning

To fine-tune on a single task, simply keep only the corresponding entry in the train_datasets field of the model pool configuration and remove the others.
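
For example, assuming you have created such a single-task model pool config (here named clip-vit-base-patch32_svhn, a hypothetical file that keeps only the svhn entry of the multi-task config above), the command is unchanged apart from the modelpool option:

fusion_bench \
    method=clip_finetune \
    modelpool=clip-vit-base-patch32_svhn \
    taskpool=dummy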

References

ImageClassificationFineTuningForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, ModelFusionAlgorithm

A class for fine-tuning CLIP models for image classification tasks.

Source code in fusion_bench/method/classification/clip_finetune.py
class ImageClassificationFineTuningForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """
    A class for fine-tuning CLIP models for image classification tasks.
    """

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """
        Executes the fine-tuning process.

        Args:
            modelpool (HuggingFaceClipVisionPool): The modelpool is responsible for loading the pre-trained model and training datasets.

        Returns:
            VisionModel: The fine-tuned vision model.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")
        self.finetune_method = "fine-tune"

        L.seed_everything(config.seed)

        task_names = [
            dataset_config["name"] for dataset_config in modelpool.config.train_datasets
        ]
        with self.profile("setup model and optimizer"):
            processor, classifier, optimizer, lr_scheduler = self.setup_model()

            if config.state_dict_load_path is not None:
                self.fabric.load(
                    config.state_dict_load_path,
                    {"vision_model": classifier.clip_model.vision_model},
                )
                if config.skip_training:
                    return classifier.clip_model.vision_model

            self.setup_zero_shot_classification_head(
                clip_processor=processor,
                clip_model=classifier.clip_model,
                task_names=task_names,
            )

            self.fabric.setup(classifier, optimizer)

        with self.profile("setup data"):
            train_datasets = [
                modelpool.get_train_dataset(task_name, processor)
                for task_name in task_names
            ]
            train_dataloaders = [
                DataLoader(
                    dataset,
                    shuffle=True,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                )
                for dataset in train_datasets
            ]
            train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
            if not isinstance(train_dataloaders, (list, tuple)):
                train_dataloaders = [train_dataloaders]
            train_dataloader_iters = [
                iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
            ]

        # train
        for step_idx in tqdm(
            range(config.num_steps),
            desc=self.finetune_method,
            disable=not self.fabric.is_global_zero,
        ):
            optimizer.zero_grad()
            loss = 0
            for task, loader in zip(task_names, train_dataloader_iters):
                with self.profile("data loading"):
                    batch = next(loader)
                    images, labels = batch
                with self.profile("forward"):
                    classifier.zeroshot_weights = self.zeroshot_weights[task]
                    logits = classifier(images)
                loss = loss + nn.functional.cross_entropy(logits, labels)

            with self.profile("backward"):
                self.fabric.backward(loss)
            with self.profile("optimizer step"):
                optimizer.step()
                lr_scheduler.step()

            metrics = {"train/loss": loss}

            self.fabric.log_dict(metrics, step=step_idx)

            if (step_idx + 1) % config.save_interval == 0:
                save_path = os.path.join(
                    self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
                )
                self.save_model(classifier, save_path, trainable_only=True)

        if config.state_dict_save_path is not None:
            self.save_model(
                classifier, config.state_dict_save_path, trainable_only=True
            )
        self.print_profile_summary()
        return classifier.clip_model.vision_model

    def save_model(
        self,
        model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
        save_path: str,
        trainable_only: bool = True,
    ):
        """
        Save the vision sub-model to `save_path`. If `trainable_only` is True,
        only parameters with `requires_grad=True` (e.g. LoRA weights) are saved.
        """
        if isinstance(model, HFCLIPClassifier):
            vision_model = model.clip_model.vision_model
        elif isinstance(model, CLIPModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionTransformer):
            vision_model = model
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)

        if trainable_only:
            # keep only the parameters that were actually trained
            trainable_names = {
                name
                for name, p in vision_model.named_parameters()
                if p.requires_grad
            }
            state_filter = {"vision_model": lambda k, v: k in trainable_names}
        else:
            state_filter = None
        self.fabric.save(save_path, {"vision_model": vision_model}, filter=state_filter)

    def setup_model(self):
        """
        Sets up the model, optimizer, and learning rate scheduler.
        """
        config = self.config
        modelpool = self.modelpool

        pretrained_model_config = modelpool.get_model_config("_pretrained_")
        clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_model_config.path)
        processor = CLIPProcessor.from_pretrained(pretrained_model_config.path)

        self.finetune_method = "full fine-tune"
        if config.use_lora or config.use_l_lora:
            self.finetune_method = "lora fine-tune"
            lora_config = LoraConfig(
                **OmegaConf.to_container(
                    config.lora_config, resolve=True, enum_to_str=True
                )
            )
            clip_model.vision_model = get_peft_model(
                clip_model.vision_model, lora_config
            )

            if config.use_l_lora:
                # http://arxiv.org/abs/2310.04742
                # Anke Tang et al. Parameter Efficient Multi-task Model Fusion with Partial Linearization. ICLR 2024.
                self.finetune_method = "l-lora fine-tune"
                print("Linearizing Lora Layers")
                linearize_lora_model_(clip_model.vision_model)

        classifier = HFCLIPClassifier(clip_model, processor=processor)

        if self.fabric.is_global_zero:
            print("=== Model Summary (For Vision Model Only) ===")
            print_parameters(classifier.clip_model.vision_model)
        # configure optimizers
        optimizer = torch.optim.Adam(
            [
                p
                for p in classifier.clip_model.vision_model.parameters()
                if p.requires_grad
            ],
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=config.num_steps
        )

        return processor, classifier, optimizer, lr_scheduler
run(modelpool)

Executes the fine-tuning process.

Parameters:

  • modelpool (HuggingFaceClipVisionPool) –

    The modelpool is responsible for loading the pre-trained model and training datasets.

Returns:

  • VisionModel

    The fine-tuned vision model.

Source code in fusion_bench/method/classification/clip_finetune.py
def run(self, modelpool: HuggingFaceClipVisionPool):
    """
    Executes the fine-tuning process.

    Args:
        modelpool (HuggingFaceClipVisionPool): The modelpool is responsible for loading the pre-trained model and training datasets.

    Returns:
        VisionModel: The fine-tuned vision model.
    """
    self.modelpool = to_modelpool(modelpool)
    config = self.config
    self.log_hyperparams(config, filename="method_config.yaml")
    self.finetune_method = "fine-tune"

    L.seed_everything(config.seed)

    task_names = [
        dataset_config["name"] for dataset_config in modelpool.config.train_datasets
    ]
    with self.profile("setup model and optimizer"):
        processor, classifier, optimizer, lr_scheduler = self.setup_model()

        if config.state_dict_load_path is not None:
            self.fabric.load(
                config.state_dict_load_path,
                {"vision_model": classifier.clip_model.vision_model},
            )
            if config.skip_training:
                return classifier.clip_model.vision_model

        self.setup_zero_shot_classification_head(
            clip_processor=processor,
            clip_model=classifier.clip_model,
            task_names=task_names,
        )

        self.fabric.setup(classifier, optimizer)

    with self.profile("setup data"):
        train_datasets = [
            modelpool.get_train_dataset(task_name, processor)
            for task_name in task_names
        ]
        train_dataloaders = [
            DataLoader(
                dataset,
                shuffle=True,
                batch_size=config.batch_size,
                num_workers=config.num_workers,
            )
            for dataset in train_datasets
        ]
        train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
        if not isinstance(train_dataloaders, (list, tuple)):
            train_dataloaders = [train_dataloaders]
        train_dataloader_iters = [
            iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
        ]

    # train
    for step_idx in tqdm(
        range(config.num_steps),
        desc=self.finetune_method,
        disable=not self.fabric.is_global_zero,
    ):
        optimizer.zero_grad()
        loss = 0
        for task, loader in zip(task_names, train_dataloader_iters):
            with self.profile("data loading"):
                batch = next(loader)
                images, labels = batch
            with self.profile("forward"):
                classifier.zeroshot_weights = self.zeroshot_weights[task]
                logits = classifier(images)
            loss = loss + nn.functional.cross_entropy(logits, labels)

        with self.profile("backward"):
            self.fabric.backward(loss)
        with self.profile("optimizer step"):
            optimizer.step()
            lr_scheduler.step()

        metrics = {"train/loss": loss}

        self.fabric.log_dict(metrics, step=step_idx)

        if (step_idx + 1) % config.save_interval == 0:
            save_path = os.path.join(
                self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
            )
            self.save_model(classifier, save_path, trainable_only=True)

    if config.state_dict_save_path is not None:
        self.save_model(
            classifier, config.state_dict_save_path, trainable_only=True
        )
    self.print_profile_summary()
    return classifier.clip_model.vision_model
setup_model()

Sets up the model, optimizer, and learning rate scheduler.

Source code in fusion_bench/method/classification/clip_finetune.py
def setup_model(self):
    """
    Sets up the model, optimizer, and learning rate scheduler.
    """
    config = self.config
    modelpool = self.modelpool

    pretrained_model_config = modelpool.get_model_config("_pretrained_")
    clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_model_config.path)
    processor = CLIPProcessor.from_pretrained(pretrained_model_config.path)

    self.finetune_method = "full fine-tune"
    if config.use_lora or config.use_l_lora:
        self.finetune_method = "lora fine-tune"
        lora_config = LoraConfig(
            **OmegaConf.to_container(
                config.lora_config, resolve=True, enum_to_str=True
            )
        )
        clip_model.vision_model = get_peft_model(
            clip_model.vision_model, lora_config
        )

        if config.use_l_lora:
            # http://arxiv.org/abs/2310.04742
            # Anke Tang et al. Parameter Efficient Multi-task Model Fusion with Partial Linearization. ICLR 2024.
            self.finetune_method = "l-lora fine-tune"
            print("Linearizing Lora Layers")
            linearize_lora_model_(clip_model.vision_model)

    classifier = HFCLIPClassifier(clip_model, processor=processor)

    if self.fabric.is_global_zero:
        print("=== Model Summary (For Vision Model Only) ===")
        print_parameters(classifier.clip_model.vision_model)
    # configure optimizers
    optimizer = torch.optim.Adam(
        [
            p
            for p in classifier.clip_model.vision_model.parameters()
            if p.requires_grad
        ],
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=config.num_steps
    )

    return processor, classifier, optimizer, lr_scheduler