Image Classification with CLIP Models using HFCLIPClassifier¶
Introduction¶
The HFCLIPClassifier class provides a wrapper around the CLIP (Contrastive Language-Image Pre-training) model for image classification tasks. It supports zero-shot classification and can be fine-tuned for specific tasks.
Basic Steps¶
Importing Required Modules¶
First, we need to import the necessary modules for our CLIP-based image classification task:
import torch
from transformers import CLIPModel, CLIPProcessor
from fusion_bench.models.hf_clip import HFCLIPClassifier
from torch.utils.data import DataLoader
Loading CLIP Model and Processor¶
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
Initializing HFCLIPClassifier¶
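Create the classifier by wrapping the CLIP model and processor loaded above. This is a minimal sketch that assumes the constructor takes the model and processor as its two arguments, matching the clip_model and processor attributes documented later on this page:

classifier = HFCLIPClassifier(clip_model, processor)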
Setting Up the Classification Task¶
After initializing the classifier, we need to set up the classification task by defining class names and, optionally, custom text templates. The text encoder of the CLIP model encodes the class names into text embeddings, which are then used to compute the logits for each class.
class_names = ["cat", "dog", "bird", "fish", "horse"]
classifier.set_classification_task(class_names)
By default, set_classification_task uses the following templates:
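As the docstring of set_classification_task below notes, the default is the single prompt "a photo of a {classname}"; expressed as a template function, this is roughly equivalent to:

default_templates = [
    lambda c: f"a photo of a {c}",
]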
You can also use custom templates:
custom_templates = [
lambda c: f"a photo of a {c}",
lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
Below is the reference for the set_classification_task and forward methods of HFCLIPClassifier:
HFCLIPClassifier¶
Bases: Module
A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.
This class wraps a CLIP model and provides functionality for image classification using zero-shot learning. It allows setting a classification task with custom class names and text templates.
Attributes:

- clip_model (CLIPModel) – The underlying CLIP model.
- processor (CLIPProcessor) – The CLIP processor for preparing inputs.
- zeroshot_weights (Tensor) – Computed text embeddings for zero-shot classification.
- classnames (List[str]) – List of class names for the current classification task.
- templates (List[Callable[[str], str]]) – List of template functions for generating text prompts.
Source code in fusion_bench/models/hf_clip.py
set_classification_task(classnames, templates=default_templates)¶
Set up the zero-shot classification task.
This method computes text embeddings for the given class names using the provided templates. These embeddings are then used for classification.
Parameters:

- classnames (List[str]) – List of class names for the classification task.
- templates (List[Callable[[str], str]], default: default_templates) – List of template functions for generating text prompts. Defaults to default_templates, i.e. ["a photo of a {classname}"].
Source code in fusion_bench/models/hf_clip.py
forward(images)¶
Perform forward pass for zero-shot image classification.
This method computes image embeddings for the input images and calculates the similarity with the pre-computed text embeddings to produce classification logits.
Parameters:

- images (Tensor) – Input images to classify.

Returns:

- Tensor – Classification logits for each input image.

Raises:

- ValueError – If the classification task hasn't been set using set_classification_task.
Source code in fusion_bench/models/hf_clip.py
Preparing Your Dataset¶
Create a custom dataset class that loads and preprocesses your images:
from typing import List

from torchvision import transforms
from PIL import Image

class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths: List[str], labels: List[int]):
        self.image_paths = image_paths
        self.labels = labels
        # Resize to the input size expected by the ViT-B/32 vision encoder
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the image, convert to RGB, and apply the transform
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        return image, self.labels[idx]

# Create DataLoader
dataset = SimpleDataset(image_paths, labels)  # Replace with your data
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
You can also use fusion_bench.dataset.clip_dataset.CLIPDataset or fusion_bench.dataset.image_dataset.TransformedImageDataset to prepare your dataset. Here is an example of using fusion_bench.dataset.image_dataset.TransformedImageDataset:

from fusion_bench.dataset.image_dataset import TransformedImageDataset

dataset = TransformedImageDataset(dataset, transform)

where dataset is your original dataset and transform is the transform you want to apply to the images.
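Similarly, here is a sketch using fusion_bench.dataset.clip_dataset.CLIPDataset, which applies the CLIP processor to each image; the constructor arguments below follow the parameters documented in its reference:

from fusion_bench.dataset.clip_dataset import CLIPDataset

dataset = CLIPDataset(dataset, processor)

where dataset is your original dataset (yielding (image, label) tuples or dictionaries with 'image' and 'label' keys) and processor is the CLIPProcessor loaded earlier.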
Below is the reference documentation for these two classes:
CLIPDataset¶
Bases: Dataset
A dataset class for CLIP models that converts a dataset of dictionaries or tuples into a format suitable for CLIP processing.
This class wraps an existing dataset and applies CLIP preprocessing to the images. It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).
Parameters:

- dataset – The original dataset to wrap.
- processor (CLIPProcessor) – The CLIP processor for preparing inputs.

Attributes:

- dataset – The wrapped dataset.
- processor (CLIPProcessor) – The CLIP processor used for image preprocessing.
Source code in fusion_bench/dataset/clip_dataset.py
__getitem__(idx)¶
Retrieves and processes an item from the dataset.
Parameters:

- idx (int) – The index of the item to retrieve.

Returns:

- tuple – A tuple containing the processed image tensor and the label.

Raises:

- ValueError – If the item is neither a dictionary nor a tuple/list of length 2.
Source code in fusion_bench/dataset/clip_dataset.py
TransformedImageDataset¶
Bases: Dataset
A dataset class for image classification tasks that applies a transform to images.
This class wraps an existing dataset and applies a specified transform to the images. It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).
Parameters:

- dataset – The original dataset to wrap.
- transform (Callable) – A function/transform to apply on the image.

Attributes:

- dataset – The wrapped dataset.
- transform (Callable) – The transform to be applied to the images.
Source code in fusion_bench/dataset/image_dataset.py
__getitem__(idx)¶
Retrieves and processes an item from the dataset.
Parameters:

- idx (int) – The index of the item to retrieve.

Returns:

- tuple (Tuple[Any, Any]) – A tuple containing the processed image and the label.

Raises:

- ValueError – If the item is neither a dictionary nor a tuple/list of length 2.
Source code in fusion_bench/dataset/image_dataset.py
Inference¶
Perform inference on your dataset:
classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        # Process predictions as needed
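For example, if you have ground-truth labels, the loop above can be turned into a simple accuracy computation (a sketch, not part of the library):

correct = 0
total = 0
classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        # Count how many predictions match the ground-truth labels
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
print(f"Accuracy: {correct / total:.2%}")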
Fine-tuning (Optional)¶
If you want to fine-tune the model:
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-5)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 5  # Set the number of training epochs for your task

classifier.train()
for epoch in range(num_epochs):
    for images, labels in dataloader:
        optimizer.zero_grad()
        logits = classifier(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
Advanced Usage¶
Custom Templates¶
You can provide custom templates when setting up the classification task:
custom_templates = [
lambda c: f"a photo of a {c}",
lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
Accessing Model Components¶
You can access the text and vision models directly:
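For example, the wrapped Hugging Face CLIPModel exposes its text and vision towers as submodules, so you can reach them through the clip_model attribute:

text_model = classifier.clip_model.text_model      # CLIP text encoder
vision_model = classifier.clip_model.vision_model  # CLIP vision encoder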
Working with Zero-shot Weights¶
After setting the classification task, you can access the zero-shot weights:
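They are stored in the zeroshot_weights attribute documented above:

zeroshot_weights = classifier.zeroshot_weights  # Tensor of pre-computed text embeddings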
These weights represent the text embeddings for each class and can be used for further analysis or custom processing.
Remember to adjust the code according to your specific dataset and requirements. This documentation provides a comprehensive guide for using the HFCLIPClassifier for image classification tasks with CLIP models.