
Image Classification with CLIP Models using HFCLIPClassifier

Introduction

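The HFCLIPClassifier class wraps a Hugging Face CLIP (Contrastive Language-Image Pre-training) model for image classification. It supports zero-shot classification out of the box and can be fine-tuned for a specific task. During fine-tuning only the vision encoder is trained: the constructor freezes the text encoder, both projection layers, and the logit scale.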

Basic Steps

Importing Required Modules

First, we need to import the necessary modules for our CLIP-based image classification task:

import torch
from transformers import CLIPModel, CLIPProcessor
from fusion_bench.models.hf_clip import HFCLIPClassifier
from torch.utils.data import DataLoader

Loading CLIP Model and Processor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

Initializing HFCLIPClassifier

classifier = HFCLIPClassifier(clip_model, processor)

Setting Up the Classification Task

After initializing the classifier, set up the classification task by providing the class names and, optionally, custom text templates. CLIP's text encoder turns the prompts built from each class name into text embeddings, and the classifier compares image embeddings against these to compute one logit per class.

class_names = ["cat", "dog", "bird", "fish", "horse"]
classifier.set_classification_task(class_names)

By default, set_classification_task uses a single template:

default_templates = [
    lambda c: f"a photo of a {c}",
]

You can also use custom templates:

custom_templates = [
    lambda c: f"a photo of a {c}",
    lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
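To see which strings actually reach the text encoder, the stdlib-only sketch below expands the templates for each class the same way set_classification_task does (text = [template(classname) for template in templates]). The helper expand_prompts is hypothetical, not part of fusion_bench. With N classes and T templates, the encoder sees N batches of T prompts, and each batch is reduced to a single class embedding.

```python
from typing import Callable, Dict, List

# Same template style as above: each template maps a class name to a prompt.
templates: List[Callable[[str], str]] = [
    lambda c: f"a photo of a {c}",
    lambda c: f"an image containing a {c}",
]
class_names = ["cat", "dog", "bird", "fish", "horse"]


def expand_prompts(
    classnames: List[str], templates: List[Callable[[str], str]]
) -> Dict[str, List[str]]:
    """Hypothetical helper: the prompts generated for each class, one per template."""
    return {c: [t(c) for t in templates] for c in classnames}


prompts = expand_prompts(class_names, templates)
for name, texts in prompts.items():
    print(name, texts)
```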

Below is the source code of HFCLIPClassifier, including its set_classification_task and forward methods:

HFCLIPClassifier

Bases: Module

A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.

This class wraps a CLIP model and provides functionality for image classification using zero-shot learning. It allows setting a classification task with custom class names and text templates.

Attributes:

  • clip_model (CLIPModel) –

    The underlying CLIP model.

  • processor (CLIPProcessor) –

    The CLIP processor for preparing inputs.

  • zeroshot_weights (Tensor) –

    Computed text embeddings for zero-shot classification.

  • classnames (List[str]) –

    List of class names for the current classification task.

  • templates (List[Callable[[str], str]]) –

    List of template functions for generating text prompts.

Source code in fusion_bench/models/hf_clip.py
class HFCLIPClassifier(nn.Module):
    """
    A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.

    This class wraps a CLIP model and provides functionality for image classification
    using zero-shot learning. It allows setting a classification task with custom
    class names and text templates.

    Attributes:
        clip_model (CLIPModel): The underlying CLIP model.
        processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (Tensor): Computed text embeddings for zero-shot classification.
        classnames (List[str]): List of class names for the current classification task.
        templates (List[Callable[[str], str]]): List of template functions for generating text prompts.

    """

    def __init__(
        self,
        clip_model: CLIPModel,
        processor: CLIPProcessor,
    ):
        """
        Initialize the HFCLIPClassifier.

        Args:
            clip_model (CLIPModel): The CLIP model to use for classification.
            processor (CLIPProcessor): The CLIP processor for preparing inputs.
        """
        super().__init__()
        # we only fine-tune the vision model
        clip_model.visual_projection.requires_grad_(False)
        clip_model.text_model.requires_grad_(False)
        clip_model.text_projection.requires_grad_(False)
        clip_model.logit_scale.requires_grad_(False)

        self.clip_model = clip_model
        self.processor = processor
        self.register_buffer(
            "zeroshot_weights",
            None,
            persistent=False,
        )

    @property
    def text_model(self):
        """Get the text model component of CLIP."""
        return self.clip_model.text_model

    @property
    def vision_model(self):
        """Get the vision model component of CLIP."""
        return self.clip_model.vision_model

    def set_classification_task(
        self,
        classnames: List[str],
        templates: List[Callable[[str], str]] = default_templates,
    ):
        """
        Set up the zero-shot classification task.

        This method computes text embeddings for the given class names using the
        provided templates. These embeddings are then used for classification.

        Args:
            classnames (List[str]): List of class names for the classification task.
            templates (List[Callable[[str], str]], optional): List of template functions
                for generating text prompts. Defaults to `default_templates`, i.e.
                ["a photo of a {classname}"].
        """
        processor = self.processor

        self.classnames = classnames
        self.templates = templates

        with torch.no_grad():
            zeroshot_weights = []
            for classname in classnames:
                text = [template(classname) for template in templates]
                inputs = processor(text=text, return_tensors="pt", padding=True)

                embeddings = self.text_model(**inputs)[1]
                embeddings = self.clip_model.text_projection(embeddings)

                # normalize embeddings
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                embeddings = embeddings.mean(dim=0)
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                zeroshot_weights.append(embeddings)

            zeroshot_weights = torch.stack(zeroshot_weights, dim=0)

        self.zeroshot_weights = zeroshot_weights

    def forward(self, images):
        """
        Perform forward pass for zero-shot image classification.

        This method computes image embeddings for the input images and calculates
        the similarity with the pre-computed text embeddings to produce classification logits.

        Args:
            images (Tensor): Input images to classify.

        Returns:
            Tensor: Classification logits for each input image.

        Raises:
            ValueError: If the classification task hasn't been set using set_classification_task.
        """
        if self.zeroshot_weights is None:
            raise ValueError("Must set classification task before forward pass")
        text_embeds = self.zeroshot_weights

        image_embeds = self.vision_model(images)
        image_embeds = image_embeds[1]
        image_embeds = self.clip_model.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
set_classification_task(classnames, templates=default_templates)

Set up the zero-shot classification task.

This method computes text embeddings for the given class names using the provided templates. These embeddings are then used for classification.

Parameters:

  • classnames (List[str]) –

    List of class names for the classification task.

  • templates (List[Callable[[str], str]], default: default_templates) –

    List of template functions for generating text prompts. Defaults to default_templates, i.e. ["a photo of a {classname}"].
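The per-class steps in set_classification_task (normalize each template's embedding, average them, re-normalize) are a form of prompt ensembling. The stdlib-only sketch below reproduces that arithmetic on made-up 2-D vectors so the effect of each step is visible; ensemble_class_embedding is a hypothetical helper, not part of fusion_bench.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def ensemble_class_embedding(template_embeddings: List[Vector]) -> Vector:
    """Normalize each template embedding, average them, then re-normalize."""
    normalized = [l2_normalize(e) for e in template_embeddings]
    dim = len(normalized[0])
    mean = [sum(e[i] for e in normalized) / len(normalized) for i in range(dim)]
    return l2_normalize(mean)


# Toy "text embeddings" of one class under two templates (made-up numbers).
emb = ensemble_class_embedding([[3.0, 0.0], [0.0, 0.5]])
print(emb)
```

Normalizing before averaging gives every template equal weight: averaging the raw vectors here would give [1.5, 0.25], dominated by the larger-norm embedding. The final re-normalization puts the class embedding back on the unit sphere, so the dot product in forward is a cosine similarity.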

forward(images)

Perform forward pass for zero-shot image classification.

This method computes image embeddings for the input images and calculates the similarity with the pre-computed text embeddings to produce classification logits.

Parameters:

  • images (Tensor) –

    Input images to classify.

Returns:

  • Tensor

    Classification logits for each input image.

Raises:

  • ValueError

    If the classification task hasn't been set using set_classification_task.
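The logits returned by forward are scaled cosine similarities: entry [i][j] is exp(logit_scale) times the cosine between image i and class j, so the output has shape (batch_size, num_classes). The stdlib-only sketch below mirrors that computation on made-up 2-D embeddings. zero_shot_logits is a hypothetical helper, and log(100) is only a toy value for the logit-scale parameter, which CLIP stores in log space (hence the .exp() in forward).

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def zero_shot_logits(
    image_embeds: List[Vector], text_embeds: List[Vector], logit_scale_param: float
) -> List[List[float]]:
    """Return logits[i][j] = exp(logit_scale_param) * cos(image_i, class_j).

    text_embeds are assumed to be unit-norm already, as set_classification_task
    leaves them; image embeddings are normalized here, as forward does.
    """
    scale = math.exp(logit_scale_param)  # the parameter is stored in log space
    images = [l2_normalize(v) for v in image_embeds]
    return [
        [scale * sum(a * b for a, b in zip(img, txt)) for txt in text_embeds]
        for img in images
    ]


# Toy example: 2 classes, 2 images, 2-D embeddings (made-up numbers).
text_embeds = [[1.0, 0.0], [0.0, 1.0]]
image_embeds = [[2.0, 0.0], [1.0, 1.0]]
logits = zero_shot_logits(image_embeds, text_embeds, math.log(100.0))
print(logits)  # one row per image, one column per class
```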


Preparing Your Dataset

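Create a custom dataset class that loads and preprocesses your images. Because forward passes the image tensor directly to CLIP's vision model, the images should be preprocessed the way CLIP expects, including normalization with CLIP's per-channel mean and standard deviation:

import torch
from typing import List
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

# Per-channel mean and std used by CLIP's image preprocessing
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths: List[str], labels: List[int]):
        self.image_paths = image_paths
        self.labels = labels  # class indices, in the same order as class_names
        # Approximates CLIPProcessor: resize the shorter side to 224 (bicubic),
        # center-crop to 224x224, convert to a [0, 1] tensor, then normalize.
        self.transform = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(CLIP_MEAN, CLIP_STD),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        return image, self.labels[idx]

# Create DataLoader
dataset = SimpleDataset(image_paths, labels)  # Replace image_paths and labels with your data
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)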
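SimpleDataset leaves image_paths and labels to you. If your images are stored in one folder per class (an assumed layout, root/<class_name>/<file>), a hypothetical helper like collect_image_paths below can build both lists with the standard library. The important detail is that each label is the index of its class in the same class_names list passed to set_classification_task, because column j of the classifier's logits corresponds to classnames[j].

```python
from pathlib import Path
from typing import List, Tuple

# Assumed set of image file extensions; adjust for your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def collect_image_paths(root: str, class_names: List[str]) -> Tuple[List[str], List[int]]:
    """Hypothetical helper for a root/<class_name>/<image file> layout.

    Labels are indices into class_names, so they line up with the logit columns.
    """
    image_paths: List[str] = []
    labels: List[int] = []
    for label, name in enumerate(class_names):
        class_dir = Path(root) / name
        for path in sorted(class_dir.iterdir()):
            if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
                image_paths.append(str(path))
                labels.append(label)
    return image_paths, labels
```

For example, image_paths, labels = collect_image_paths("data/train", class_names) would produce the two lists SimpleDataset expects, assuming that directory follows the layout above.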

You can also prepare your dataset with fusion_bench.dataset.clip_dataset.CLIPDataset, which preprocesses each image with a CLIPProcessor, or with fusion_bench.dataset.image_dataset.TransformedImageDataset, which applies a transform of your choice. For example:

from fusion_bench.dataset.clip_dataset import CLIPDataset

clip_dataset = CLIPDataset(dataset, processor)

from fusion_bench.dataset.image_dataset import TransformedImageDataset

transformed_dataset = TransformedImageDataset(dataset, transform)

Here, dataset is your original dataset and transform is the transform to apply to each image. The reference code for both classes follows:

CLIPDataset

Bases: Dataset

A dataset class for CLIP models that converts a dataset of dictionaries or tuples into a format suitable for CLIP processing.

This class wraps an existing dataset and applies CLIP preprocessing to the images. It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).

Parameters:

  • dataset

    The original dataset to wrap.

  • processor (CLIPProcessor) –

    The CLIP processor for preparing inputs.

Attributes:

  • dataset

    The wrapped dataset.

  • processor (CLIPProcessor) –

    The CLIP processor used for image preprocessing.
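Both CLIPDataset and TransformedImageDataset accept items in either of two shapes: a dictionary with "image" and "label" keys, or an (image, label) tuple/list. The stdlib-only sketch below restates that rule as a standalone function, which can be handy for checking your data before wrapping it; unpack_item is a hypothetical name, not part of fusion_bench. One deliberate difference: the library code checks the tuple length with assert, so a tuple of the wrong length raises AssertionError (and is not checked at all under python -O), whereas this sketch raises ValueError in every invalid case.

```python
from typing import Any, Tuple


def unpack_item(item: Any) -> Tuple[Any, Any]:
    """Return (image, label) from a dict item or a 2-element tuple/list item."""
    if isinstance(item, dict):
        return item["image"], item["label"]
    if isinstance(item, (tuple, list)):
        if len(item) != 2:
            raise ValueError("Each item should be a tuple or list of length 2")
        return item[0], item[1]
    raise ValueError("Each item should be a dictionary or a tuple/list of length 2")


# Strings stand in for images here; any object works.
print(unpack_item({"image": "img0", "label": 3}))
print(unpack_item(("img1", 4)))
```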

Source code in fusion_bench/dataset/clip_dataset.py
class CLIPDataset(torch.utils.data.Dataset):
    """
    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
    into a format suitable for CLIP processing.

    This class wraps an existing dataset and applies CLIP preprocessing to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        processor (CLIPProcessor): The CLIP processor for preparing inputs.

    Attributes:
        dataset: The wrapped dataset.
        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
    """

    def __init__(self, dataset, processor: CLIPProcessor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int):
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image tensor and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")
        image = item["image"]
        inputs = self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
        return inputs, item["label"]
__getitem__(idx)

Retrieves and processes an item from the dataset.

Parameters:

  • idx (int) –

    The index of the item to retrieve.

Returns:

  • tuple

    A tuple containing the processed image tensor and the label.

Raises:

  • ValueError

    If the item is neither a dictionary nor a tuple/list of length 2.

__len__()

Returns the number of items in the dataset.


TransformedImageDataset

Bases: Dataset

A dataset class for image classification tasks that applies a transform to images.

This class wraps an existing dataset and applies a specified transform to the images. It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).

Parameters:

  • dataset

    The original dataset to wrap.

  • transform (Callable) –

    A function/transform to apply on the image.

Attributes:

  • dataset

    The wrapped dataset.

  • transform (Callable) –

    The transform to be applied to the images.

Source code in fusion_bench/dataset/image_dataset.py
class TransformedImageDataset(Dataset):
    """
    A dataset class for image classification tasks that applies a transform to images.

    This class wraps an existing dataset and applies a specified transform to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        transform (Callable): A function/transform to apply on the image.

    Attributes:
        dataset: The wrapped dataset.
        transform (Callable): The transform to be applied to the images.
    """

    def __init__(self, dataset, transform: Callable):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")
        image = item["image"]
        inputs = self.transform(image)
        return inputs, item["label"]
__getitem__(idx)

Retrieves and processes an item from the dataset.

Parameters:

  • idx (int) –

    The index of the item to retrieve.

Returns:

  • tuple ( Tuple[Any, Any] ) –

    A tuple containing the processed image and the label.

Raises:

  • ValueError

    If the item is neither a dictionary nor a tuple/list of length 2.

__len__()

Returns the number of items in the dataset.


Inference

Perform inference on your dataset:

classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        # Process predictions as needed
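A common next step is top-1 accuracy; in the PyTorch loop above that amounts to accumulating (predictions == labels).sum().item() and dividing by the number of samples. The sketch below shows the same computation on plain Python lists so it runs without a model. decode_and_score is a hypothetical helper, and the logits are made-up numbers.

```python
from typing import List, Sequence, Tuple


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(row)), key=lambda j: row[j])


def decode_and_score(
    logits: List[List[float]], labels: List[int], class_names: List[str]
) -> Tuple[List[str], float]:
    """Map each row of logits to a class name and compute top-1 accuracy."""
    preds = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(preds, labels))
    return [class_names[p] for p in preds], correct / len(labels)


class_names = ["cat", "dog", "bird"]
logits = [[2.0, 1.0, 0.1], [0.2, 0.3, 3.0]]  # made-up logits for 2 images
labels = [0, 1]
predicted_names, accuracy = decode_and_score(logits, labels, class_names)
print(predicted_names, accuracy)
```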

Fine-tuning (Optional)

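If you want to fine-tune the model (HFCLIPClassifier freezes everything except the vision encoder, so only the vision encoder is updated):

num_epochs = 3  # illustrative value; choose one for your task

# Pass only the trainable parameters (the vision encoder) to the optimizer
optimizer = torch.optim.Adam(
    [p for p in classifier.parameters() if p.requires_grad], lr=1e-5
)
criterion = torch.nn.CrossEntropyLoss()

classifier.train()
for epoch in range(num_epochs):
    for images, labels in dataloader:
        optimizer.zero_grad()
        logits = classifier(images)  # shape: (batch_size, num_classes)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()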
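The criterion receives the classifier's raw logits: torch.nn.CrossEntropyLoss applies log-softmax itself, so do not apply softmax first. These logits are already multiplied by exp(logit_scale), which HFCLIPClassifier freezes, so it acts as a fixed temperature during fine-tuning. The stdlib-only sketch below computes the same per-sample loss in log-sum-exp form to make this explicit; cross_entropy is a hypothetical helper, and by default CrossEntropyLoss then averages the per-sample losses over the batch.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """-log(softmax(logits)[target]), using the log-sum-exp trick for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Made-up logits for one image over 3 classes; the true class is index 0.
print(cross_entropy([2.0, 0.5, -1.0], 0))
```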

Advanced Usage

Accessing Model Components

You can access the text and vision models directly:

text_model = classifier.text_model
vision_model = classifier.vision_model

Working with Zero-shot Weights

After setting the classification task, you can access the zero-shot weights:

zeroshot_weights = classifier.zeroshot_weights

These weights form a tensor of shape (num_classes, projection_dim): row j is the L2-normalized, template-averaged text embedding of classnames[j]. You can use them for further analysis or custom processing.
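Because each row of zeroshot_weights is unit-norm, the dot product of two rows is their cosine similarity; in PyTorch, zeroshot_weights @ zeroshot_weights.T gives the full class-by-class similarity matrix. Class pairs with very similar text embeddings may be harder to separate in zero-shot classification, so this is a quick sanity check on your class names and templates. The sketch below ranks class pairs on made-up 2-D unit vectors using only the standard library; most_similar_pairs is a hypothetical helper.

```python
from itertools import combinations
from typing import List, Tuple


def most_similar_pairs(
    weights: List[List[float]], class_names: List[str], k: int = 3
) -> List[Tuple[float, str, str]]:
    """Rank class pairs by cosine similarity of their unit-norm text embeddings."""
    scored = []
    for i, j in combinations(range(len(class_names)), 2):
        sim = sum(a * b for a, b in zip(weights[i], weights[j]))  # cosine, rows are unit-norm
        scored.append((sim, class_names[i], class_names[j]))
    scored.sort(reverse=True)
    return scored[:k]


# Made-up unit-norm "zero-shot weights" for three classes.
weights = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]]
names = ["cat", "dog", "fish"]
print(most_similar_pairs(weights, names))
```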

Adjust the code above to match your dataset and requirements.