Image Classification with CLIP Models using HFCLIPClassifier
Introduction
The HFCLIPClassifier class provides a wrapper around the CLIP (Contrastive Language-Image Pre-training) model for image classification tasks. It supports zero-shot learning and can be fine-tuned for specific classification tasks.
Basic Steps
Importing Required Modules
First, we need to import the necessary modules for our CLIP-based image classification task:
import torch
from transformers import CLIPModel, CLIPProcessor
from fusion_bench.models.hf_clip import HFCLIPClassifier
from torch.utils.data import DataLoader
Loading CLIP Model and Processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
Initializing HFCLIPClassifier
classifier = HFCLIPClassifier(clip_model, processor)
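Optionally, if a GPU is available, you can move the classifier to it before setting up the classification task so that the text embeddings are computed on that device. A minimal sketch, assuming the imports above:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
classifier = classifier.to(device)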
Setting Up the Classification Task
After initializing the classifier, we need to set up the classification task by defining class names and, optionally, custom text templates.
The text encoder of the CLIP model encodes the class names (rendered through the text templates) into text embeddings, which are then compared with the image embeddings to compute the logits for each class.
class_names = ["cat", "dog", "bird", "fish", "horse"]
classifier.set_classification_task(class_names)
By default, set_classification_task uses the following templates:
default_templates = [
    lambda c: f"a photo of a {c}",
]
You can also use custom templates:
custom_templates = [
    lambda c: f"a photo of a {c}",
    lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
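After the task is set, the averaged and L2-normalized text embeddings are stored in classifier.zeroshot_weights, one row per class. For the ViT-B/32 checkpoint used above, the projection dimension is 512, so for example:

print(classifier.zeroshot_weights.shape)
# torch.Size([5, 512]) -- one text embedding per class name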
Below is the source code for the set_classification_task and forward methods of HFCLIPClassifier:
HFCLIPClassifier
Bases: Module
A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.
This class wraps a CLIP model and provides functionality for image classification
using zero-shot learning. It allows setting a classification task with custom
class names and text templates.
Attributes:

- clip_model (CLIPModel) – The underlying CLIP model.
- processor (CLIPProcessor) – The CLIP processor for preparing inputs.
- zeroshot_weights (Tensor) – Computed text embeddings for zero-shot classification.
- classnames (List[str]) – List of class names for the current classification task.
- templates (List[Callable[[str], str]]) – List of template functions for generating text prompts.
Source code in fusion_bench/models/hf_clip.py
class HFCLIPClassifier(nn.Module):
    """
    A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.

    This class wraps a CLIP model and provides functionality for image classification
    using zero-shot learning. It allows setting a classification task with custom
    class names and text templates.

    Attributes:
        clip_model (CLIPModel): The underlying CLIP model.
        processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (Tensor): Computed text embeddings for zero-shot classification.
        classnames (List[str]): List of class names for the current classification task.
        templates (List[Callable[[str], str]]): List of template functions for generating text prompts.
    """

    def __init__(
        self,
        clip_model: CLIPModel,
        processor: CLIPProcessor,
    ):
        """
        Initialize the HFCLIPClassifier.

        Args:
            clip_model (CLIPModel): The CLIP model to use for classification.
            processor (CLIPProcessor): The CLIP processor for preparing inputs.
        """
        super().__init__()
        # we only fine-tune the vision model
        clip_model.visual_projection.requires_grad_(False)
        clip_model.text_model.requires_grad_(False)
        clip_model.text_projection.requires_grad_(False)
        clip_model.logit_scale.requires_grad_(False)

        self.clip_model = clip_model
        self.processor = processor
        self.register_buffer(
            "zeroshot_weights",
            None,
            persistent=False,
        )

    @property
    def text_model(self):
        """Get the text model component of CLIP."""
        return self.clip_model.text_model

    @property
    def vision_model(self):
        """Get the vision model component of CLIP."""
        return self.clip_model.vision_model

    def set_classification_task(
        self,
        classnames: List[str],
        templates: List[Callable[[str], str]] = default_templates,
    ):
        """
        Set up the zero-shot classification task.

        This method computes text embeddings for the given class names using the
        provided templates. These embeddings are then used for classification.

        Args:
            classnames (List[str]): List of class names for the classification task.
            templates (List[Callable[[str], str]], optional): List of template functions
                for generating text prompts. Defaults to `default_templates`, i.e.
                ["a photo of a {classname}"].
        """
        processor = self.processor

        self.classnames = classnames
        self.templates = templates

        with torch.no_grad():
            zeroshot_weights = []
            for classname in classnames:
                text = [template(classname) for template in templates]
                inputs = processor(text=text, return_tensors="pt", padding=True)
                inputs = {
                    k: v.to(get_device(self.text_model)) for k, v in inputs.items()
                }
                embeddings = self.text_model(**inputs)[1]
                embeddings = self.clip_model.text_projection(embeddings)

                # normalize embeddings
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
                embeddings = embeddings.mean(dim=0)
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                zeroshot_weights.append(embeddings)

            zeroshot_weights = torch.stack(zeroshot_weights, dim=0)

        self.zeroshot_weights = zeroshot_weights

    def forward(self, images, return_image_embeds=False, return_dict=False):
        """
        Perform forward pass for zero-shot image classification.

        This method computes image embeddings for the input images and calculates
        the similarity with the pre-computed text embeddings to produce classification logits.

        Args:
            images (Tensor): Input images to classify.

        Returns:
            Tensor: Classification logits for each input image.

        Raises:
            ValueError: If the classification task hasn't been set using set_classification_task.
        """
        if self.zeroshot_weights is None:
            raise ValueError("Must set classification task before forward pass")
        text_embeds = self.zeroshot_weights

        image_embeds = self.vision_model(images)
        if isinstance(image_embeds, Tensor):
            pass
        elif isinstance(image_embeds, BaseModelOutputWithPooling):
            image_embeds = image_embeds[1]
        image_embeds = self.clip_model.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        if return_dict:
            ret = {"logits": logits_per_image}
            if return_image_embeds:
                ret.update({"image_embeds": image_embeds})
            return ret
        else:
            if return_image_embeds:
                return logits_per_image, image_embeds
            else:
                return logits_per_image
set_classification_task(classnames, templates=default_templates)
Set up the zero-shot classification task.
This method computes text embeddings for the given class names using the
provided templates. These embeddings are then used for classification.
Parameters:

- classnames (List[str]) – List of class names for the classification task.
- templates (List[Callable[[str], str]], default: default_templates) – List of template functions for generating text prompts. Defaults to default_templates, i.e. ["a photo of a {classname}"].
Source code in fusion_bench/models/hf_clip.py
def set_classification_task(
    self,
    classnames: List[str],
    templates: List[Callable[[str], str]] = default_templates,
):
    """
    Set up the zero-shot classification task.

    This method computes text embeddings for the given class names using the
    provided templates. These embeddings are then used for classification.

    Args:
        classnames (List[str]): List of class names for the classification task.
        templates (List[Callable[[str], str]], optional): List of template functions
            for generating text prompts. Defaults to `default_templates`, i.e.
            ["a photo of a {classname}"].
    """
    processor = self.processor

    self.classnames = classnames
    self.templates = templates

    with torch.no_grad():
        zeroshot_weights = []
        for classname in classnames:
            text = [template(classname) for template in templates]
            inputs = processor(text=text, return_tensors="pt", padding=True)
            inputs = {
                k: v.to(get_device(self.text_model)) for k, v in inputs.items()
            }
            embeddings = self.text_model(**inputs)[1]
            embeddings = self.clip_model.text_projection(embeddings)

            # normalize embeddings
            embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
            embeddings = embeddings.mean(dim=0)
            embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

            zeroshot_weights.append(embeddings)

        zeroshot_weights = torch.stack(zeroshot_weights, dim=0)

    self.zeroshot_weights = zeroshot_weights
forward(images, return_image_embeds=False, return_dict=False)
Perform forward pass for zero-shot image classification.
This method computes image embeddings for the input images and calculates
the similarity with the pre-computed text embeddings to produce classification logits.
Parameters:

- images (Tensor) – Input images to classify.

Returns:

- Tensor – Classification logits for each input image.

Raises:

- ValueError – If the classification task hasn't been set using set_classification_task.
Source code in fusion_bench/models/hf_clip.py
def forward(self, images, return_image_embeds=False, return_dict=False):
    """
    Perform forward pass for zero-shot image classification.

    This method computes image embeddings for the input images and calculates
    the similarity with the pre-computed text embeddings to produce classification logits.

    Args:
        images (Tensor): Input images to classify.

    Returns:
        Tensor: Classification logits for each input image.

    Raises:
        ValueError: If the classification task hasn't been set using set_classification_task.
    """
    if self.zeroshot_weights is None:
        raise ValueError("Must set classification task before forward pass")
    text_embeds = self.zeroshot_weights

    image_embeds = self.vision_model(images)
    if isinstance(image_embeds, Tensor):
        pass
    elif isinstance(image_embeds, BaseModelOutputWithPooling):
        image_embeds = image_embeds[1]
    image_embeds = self.clip_model.visual_projection(image_embeds)

    # normalize embeddings
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

    # cosine similarity
    logit_scale = self.clip_model.logit_scale.exp()
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
    logits_per_image = logits_per_text.t()

    if return_dict:
        ret = {"logits": logits_per_image}
        if return_image_embeds:
            ret.update({"image_embeds": image_embeds})
        return ret
    else:
        if return_image_embeds:
            return logits_per_image, image_embeds
        else:
            return logits_per_image
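As the signature above shows, forward can also return the projected image embeddings alongside the logits. A small usage sketch, assuming images is a batch of preprocessed pixel values on the same device as the model:

with torch.no_grad():
    outputs = classifier(images, return_image_embeds=True, return_dict=True)

logits = outputs["logits"]              # shape: (batch_size, num_classes)
image_embeds = outputs["image_embeds"]  # L2-normalized image embeddings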
Preparing Your Dataset
Create a custom dataset class that loads and preprocesses your images:
from typing import List

from PIL import Image
from torchvision import transforms


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths: List[str], labels: List[int]):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Normalize with the standard CLIP image statistics so that inputs
            # match what the pre-trained vision encoder expects.
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        return image, self.labels[idx]


# Create DataLoader
dataset = SimpleDataset(image_paths, labels)  # Replace with your data
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
Alternatively, you can use fusion_bench.dataset.clip_dataset.CLIPDataset or fusion_bench.dataset.image_dataset.TransformedImageDataset to prepare your dataset:
from fusion_bench.dataset.clip_dataset import CLIPDataset
dataset = CLIPDataset(dataset, processor)
from fusion_bench.dataset.image_dataset import TransformedImageDataset
dataset = TransformedImageDataset(dataset, transform)
Here, dataset is your original dataset and transform is the transform you want to apply to the images.
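For instance, CLIPDataset can wrap any dataset that yields (image, label) pairs. The sketch below uses torchvision's CIFAR-10 purely as a stand-in dataset; substitute your own data:

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from fusion_bench.dataset.clip_dataset import CLIPDataset

# CIFAR-10 items are (PIL image, label) tuples, which CLIPDataset accepts directly.
raw_dataset = CIFAR10(root="./data", train=False, download=True)
dataset = CLIPDataset(raw_dataset, processor)  # the processor applies CLIP preprocessing
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

images, labels = next(iter(dataloader))
print(images.shape)  # e.g. torch.Size([32, 3, 224, 224])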
Below is the reference code for these two classes:
CLIPDataset
Bases: Dataset
A dataset class for CLIP models that converts a dataset of dictionaries or tuples
into a format suitable for CLIP processing.
This class wraps an existing dataset and applies CLIP preprocessing to the images.
It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
or a tuple/list of (image, label).
Parameters:

- dataset – The original dataset to wrap.
- processor (CLIPProcessor, default: None) – The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.

Attributes:

- dataset – The wrapped dataset.
- processor (CLIPProcessor) – The CLIP processor used for image preprocessing.
Source code in fusion_bench/dataset/clip_dataset.py
class CLIPDataset(torch.utils.data.Dataset):
    """
    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
    into a format suitable for CLIP processing.

    This class wraps an existing dataset and applies CLIP preprocessing to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        processor (CLIPProcessor): The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.

    Attributes:
        dataset: The wrapped dataset.
        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
    """

    def __init__(self, dataset, processor: Optional[CLIPProcessor] = None):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int):
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image tensor and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")

        image = item["image"]
        if self.processor is not None:
            if isinstance(self.processor, ProcessorMixin):
                # Apply the processor to the image to get the input tensor
                inputs = self.processor(images=[image], return_tensors="pt")[
                    "pixel_values"
                ][0]
        else:
            # if processor is None, return the raw image directly
            inputs = image
        return inputs, item["label"]
__getitem__(idx)
Retrieves and processes an item from the dataset.
Parameters:

- idx (int) – The index of the item to retrieve.

Returns:

- tuple – A tuple containing the processed image tensor and the label.

Raises:

- ValueError – If the item is neither a dictionary nor a tuple/list of length 2.
Source code in fusion_bench/dataset/clip_dataset.py
def __getitem__(self, idx: int):
    """
    Retrieves and processes an item from the dataset.

    Args:
        idx (int): The index of the item to retrieve.

    Returns:
        tuple: A tuple containing the processed image tensor and the label.

    Raises:
        ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
    """
    item = self.dataset[idx]
    if isinstance(item, dict):
        item = item
    elif isinstance(item, (tuple, list)):
        assert len(item) == 2, "Each item should be a tuple or list of length 2"
        item = {"image": item[0], "label": item[1]}
    else:
        raise ValueError("Each item should be a dictionary or a tuple of length 2")

    image = item["image"]
    if self.processor is not None:
        if isinstance(self.processor, ProcessorMixin):
            # Apply the processor to the image to get the input tensor
            inputs = self.processor(images=[image], return_tensors="pt")[
                "pixel_values"
            ][0]
    else:
        # if processor is None, return the raw image directly
        inputs = image
    return inputs, item["label"]
__len__()
Returns the number of items in the dataset.
Source code in fusion_bench/dataset/clip_dataset.py
def __len__(self):
    """Returns the number of items in the dataset."""
    return len(self.dataset)
TransformedImageDataset

Bases: Dataset
A dataset class for image classification tasks that applies a transform to images.
This class wraps an existing dataset and applies a specified transform to the images.
It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
or a tuple/list of (image, label).
Parameters:

- dataset – The original dataset to wrap.
- transform (Callable) – A function/transform to apply on the image.

Attributes:

- dataset – The wrapped dataset.
- transform (Callable) – The transform to be applied to the images.
Source code in fusion_bench/dataset/image_dataset.py
class TransformedImageDataset(Dataset):
    """
    A dataset class for image classification tasks that applies a transform to images.

    This class wraps an existing dataset and applies a specified transform to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        transform (Callable): A function/transform to apply on the image.

    Attributes:
        dataset: The wrapped dataset.
        transform (Callable): The transform to be applied to the images.
    """

    def __init__(self, dataset, transform: Callable):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")

        image = item["image"]
        inputs = self.transform(image)
        return inputs, item["label"]
__getitem__(idx)

Retrieves and processes an item from the dataset.
Parameters:

- idx (int) – The index of the item to retrieve.

Returns:

- tuple (Tuple[Any, Any]) – A tuple containing the processed image and the label.

Raises:

- ValueError – If the item is neither a dictionary nor a tuple/list of length 2.
Source code in fusion_bench/dataset/image_dataset.py
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
    """
    Retrieves and processes an item from the dataset.

    Args:
        idx (int): The index of the item to retrieve.

    Returns:
        tuple: A tuple containing the processed image and the label.

    Raises:
        ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
    """
    item = self.dataset[idx]
    if isinstance(item, dict):
        item = item
    elif isinstance(item, (tuple, list)):
        assert len(item) == 2, "Each item should be a tuple or list of length 2"
        item = {"image": item[0], "label": item[1]}
    else:
        raise ValueError("Each item should be a dictionary or a tuple of length 2")

    image = item["image"]
    inputs = self.transform(image)
    return inputs, item["label"]
__len__()

Returns the number of items in the dataset.
Source code in fusion_bench/dataset/image_dataset.py
def __len__(self):
    """Returns the number of items in the dataset."""
    return len(self.dataset)
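Similarly, TransformedImageDataset can apply a torchvision preprocessing pipeline instead of the CLIPProcessor. A minimal sketch, where raw_dataset is any (image, label) dataset and the normalization statistics are the standard values used by OpenAI's CLIP checkpoints:

from torchvision import transforms

from fusion_bench.dataset.image_dataset import TransformedImageDataset

# A torchvision pipeline matching CLIP's expected input (resize + CLIP normalization).
clip_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

dataset = TransformedImageDataset(raw_dataset, clip_transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)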
Inference
Perform inference on your dataset:
classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        # Process predictions as needed
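To measure zero-shot accuracy over the whole DataLoader, accumulate correct predictions. A minimal sketch, assuming the DataLoader yields (images, labels) batches as in the examples above:

correct, total = 0, 0
classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        # Move images and labels to the classifier's device if you placed it on a GPU.
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

print(f"Zero-shot accuracy: {correct / total:.2%}")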
Fine-tuning (Optional)
If you want to fine-tune the model:
# Only parameters with requires_grad=True (the vision tower) are optimized;
# HFCLIPClassifier freezes the text tower and both projection layers.
optimizer = torch.optim.Adam(
    [p for p in classifier.parameters() if p.requires_grad], lr=1e-5
)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 5  # adjust for your task

classifier.train()
for epoch in range(num_epochs):
    for images, labels in dataloader:
        optimizer.zero_grad()
        logits = classifier(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
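After fine-tuning, the underlying CLIPModel and processor can be saved and reloaded with the standard Hugging Face API (the directory name below is just an example):

save_dir = "clip-vit-base-patch32-finetuned"  # example output directory
classifier.clip_model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)

# Later: reload the weights and rebuild the classifier.
clip_model = CLIPModel.from_pretrained(save_dir)
processor = CLIPProcessor.from_pretrained(save_dir)
classifier = HFCLIPClassifier(clip_model, processor)
classifier.set_classification_task(class_names)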
Advanced Usage
Custom Templates
You can provide custom templates when setting up the classification task:
custom_templates = [
    lambda c: f"a photo of a {c}",
    lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
Accessing Model Components
You can access the text and vision models directly:
text_model = classifier.text_model
vision_model = classifier.vision_model
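Because the constructor freezes the text tower, both projection layers, and the logit scale, only the vision transformer has trainable parameters. You can verify this directly:

trainable = sum(p.numel() for p in classifier.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in classifier.parameters() if not p.requires_grad)
print(f"trainable parameters: {trainable:,}, frozen parameters: {frozen:,}")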
Working with Zero-shot Weights
After setting the classification task, you can access the zero-shot weights:
zeroshot_weights = classifier.zeroshot_weights
These weights represent the text embeddings for each class and can be used for further analysis or custom processing.
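Since each row of zeroshot_weights is L2-normalized, their pairwise dot products give the cosine similarity between class prompts, which can help spot classes whose text descriptions are easily confused:

weights = classifier.zeroshot_weights   # shape: (num_classes, embed_dim)
class_similarity = weights @ weights.T  # cosine similarity between class prompts
print(class_similarity)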
Remember to adjust the code according to your specific dataset and requirements. This documentation provides a comprehensive guide for using the HFCLIPClassifier
for image classification tasks with CLIP models.