The HFCLIPClassifier class provides a wrapper around the CLIP (Contrastive Language-Image Pre-training) model for image classification tasks. It supports zero-shot learning and can be fine-tuned for specific classification tasks.
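For example, a classifier can be created by wrapping a pretrained CLIP checkpoint from Hugging Face (a minimal sketch; the checkpoint name and import path below are illustrative and may differ in your installed FusionBench version):

```python
from transformers import CLIPModel, CLIPProcessor
from fusion_bench.models.hf_clip import HFCLIPClassifier  # adjust to where HFCLIPClassifier lives in your version

# Load a pretrained CLIP checkpoint (any CLIP checkpoint works here)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Wrap it with the classifier
classifier = HFCLIPClassifier(clip_model, processor)
```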
After initializing the classifier, we need to set up the classification task by defining class names and, optionally, custom text templates.
The text encoder of the CLIP model encodes the class names into text embeddings, which are then used to compute the logits for each class.
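This is essentially a scaled cosine similarity between image embeddings and class text embeddings. A minimal, self-contained sketch with dummy (already normalized) embeddings:

```python
import torch

# Dummy L2-normalized embeddings standing in for CLIP outputs:
# image_embeds: (batch_size, dim), text_embeds: (num_classes, dim)
image_embeds = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
text_embeds = torch.nn.functional.normalize(torch.randn(3, 512), dim=-1)
logit_scale = torch.tensor(100.0)  # roughly the value of clip_model.logit_scale.exp() in pretrained CLIP

# Scaled cosine similarity -> classification logits
logits = logit_scale * image_embeds @ text_embeds.t()  # shape (4, 3)
predictions = logits.argmax(dim=-1)                    # predicted class index per image
```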
By default, set_classification_task uses the following templates:
```python
default_templates = [
    lambda c: f"a photo of a {c}",
]
```
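With the defaults, you only need to pass the class names (the names below are placeholders):

```python
class_names = ["cat", "dog", "bird"]  # replace with your own class names
classifier.set_classification_task(class_names)
```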
You can also use custom templates:
```python
custom_templates = [
    lambda c: f"a photo of a {c}",
    lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
```
Below is the reference code for HFCLIPClassifier, including its set_classification_task and forward methods:
HFCLIPClassifier is a classifier based on the CLIP (Contrastive Language-Image Pre-training) model. It wraps a CLIP model and provides functionality for image classification using zero-shot learning, and it allows setting a classification task with custom class names and text templates.

Attributes:

- `clip_model` (`CLIPModel`): The underlying CLIP model.
- `processor` (`CLIPProcessor`): The CLIP processor for preparing inputs.
- `zeroshot_weights` (`Tensor`): Computed text embeddings for zero-shot classification.
- `classnames` (`List[str]`): List of class names for the current classification task.
- `templates` (`List[Callable[[str], str]]`): List of template functions for generating text prompts.
```python
class HFCLIPClassifier(nn.Module):
    """
    A classifier based on the CLIP (Contrastive Language-Image Pre-training) model.

    This class wraps a CLIP model and provides functionality for image classification
    using zero-shot learning. It allows setting a classification task with custom
    class names and text templates.

    Attributes:
        clip_model (CLIPModel): The underlying CLIP model.
        processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (Tensor): Computed text embeddings for zero-shot classification.
        classnames (List[str]): List of class names for the current classification task.
        templates (List[Callable[[str], str]]): List of template functions for generating text prompts.
    """

    def __init__(
        self,
        clip_model: CLIPModel,
        processor: CLIPProcessor,
        extra_module=None,
    ):
        """
        Initialize the HFCLIPClassifier.

        Args:
            clip_model (CLIPModel): The CLIP model to use for classification.
            processor (CLIPProcessor): The CLIP processor for preparing inputs.
        """
        super().__init__()
        # we only fine-tune the vision model
        clip_model.visual_projection.requires_grad_(False)
        clip_model.text_model.requires_grad_(False)
        clip_model.text_projection.requires_grad_(False)
        clip_model.logit_scale.requires_grad_(False)

        self.clip_model = clip_model
        self.processor = processor
        self.register_buffer(
            "zeroshot_weights",
            None,
            persistent=False,
        )
        self.extra_module = extra_module

    @property
    def text_model(self):
        """Get the text model component of CLIP."""
        return self.clip_model.text_model

    @property
    def vision_model(self):
        """Get the vision model component of CLIP."""
        return self.clip_model.vision_model

    def set_classification_task(
        self,
        classnames: List[str],
        templates: List[Callable[[str], str]] = default_templates,
    ):
        """
        Set up the zero-shot classification task.

        This method computes text embeddings for the given class names using the
        provided templates. These embeddings are then used for classification.

        Args:
            classnames (List[str]): List of class names for the classification task.
            templates (List[Callable[[str], str]], optional): List of template functions
                for generating text prompts. Defaults to `default_templates`, i.e.
                ["a photo of a {classname}"].
        """
        processor = self.processor

        self.classnames = classnames
        self.templates = templates

        with torch.no_grad():
            zeroshot_weights = []
            for classname in classnames:
                text = [template(classname) for template in templates]
                inputs = processor(text=text, return_tensors="pt", padding=True)
                inputs = {k: v.to(get_device(self.text_model)) for k, v in inputs.items()}
                embeddings = self.text_model(**inputs)[1]
                embeddings = self.clip_model.text_projection(embeddings)

                # normalize embeddings
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                embeddings = embeddings.mean(dim=0)
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                zeroshot_weights.append(embeddings)

            zeroshot_weights = torch.stack(zeroshot_weights, dim=0)

        self.zeroshot_weights = zeroshot_weights

    def forward(
        self,
        images: Tensor,
        return_image_embeds=False,
        return_dict=False,
        task_name=None,
    ):
        """
        Perform forward pass for zero-shot image classification.

        This method computes image embeddings for the input images and calculates
        the similarity with the pre-computed text embeddings to produce classification logits.

        Args:
            images (Tensor): Input images to classify.
            return_image_embeds (bool): Whether to return the image embeddings.
            return_dict (bool): Whether to return a dictionary with logits and image embeddings.
            task_name (Optional[str]): The name of the task.

        Returns:
            Tensor: Classification logits for each input image.

        Raises:
            ValueError: If the classification task hasn't been set using set_classification_task.
        """
        if self.zeroshot_weights is None:
            raise ValueError("Must set classification task before forward pass")
        text_embeds = self.zeroshot_weights

        image_embeds = self.get_image_features(images)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        if (
            hasattr(self.vision_model, "is_surgery_model")
            and self.vision_model.is_surgery_model
        ):
            # Dealing with the surgery model, for more details, please refer to:
            # (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging
            # https://arxiv.org/abs/2402.02705
            self.vision_model: "SurgeryModelWrapper" = self.vision_model
            image_embeds, _, _ = self.vision_model.compute_surgery_features(
                image_embeds, dataset_name=task_name
            )

        # cosine similarity
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        if return_dict:
            ret = {"logits": logits_per_image}
            if return_image_embeds:
                ret.update({"image_embeds": image_embeds})
            return ret
        else:
            if return_image_embeds:
                return logits_per_image, image_embeds
            else:
                return logits_per_image

    def get_image_features(self, images: Tensor) -> Tensor:
        """
        Compute the image embeddings.

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
                The image embeddings obtained by applying the projection layer to the
                pooled output of [`CLIPVisionModel`].
        """
        image_embeds = self.vision_model(images)
        if isinstance(image_embeds, Tensor):
            pass
        elif isinstance(image_embeds, BaseModelOutputWithPooling):
            image_embeds = image_embeds[1]
        image_embeds = self.clip_model.visual_projection(image_embeds)
        return image_embeds
```
The `forward` method performs the forward pass for zero-shot image classification: it computes image embeddings for the input images and calculates their similarity with the pre-computed text embeddings to produce classification logits.

Parameters:

- `images` (`Tensor`): Input images to classify.
- `return_image_embeds` (`bool`, default: `False`): Whether to return the image embeddings.
- `return_dict` (`bool`, default: `False`): Whether to return a dictionary with logits and image embeddings.
- `task_name` (`Optional[str]`, default: `None`): The name of the task.

Returns:

- `Tensor`: Classification logits for each input image.

Raises:

- `ValueError`: If the classification task hasn't been set using `set_classification_task`.
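For example, the optional flags control what the forward pass returns (a sketch; `images` stands for a batch of image tensors already preprocessed for CLIP, e.g. shape `(batch_size, 3, 224, 224)`):

```python
# images: a batch of preprocessed image tensors, shape (batch_size, 3, 224, 224)
logits = classifier(images)  # Tensor of shape (batch_size, num_classes)

# Logits plus image embeddings as a tuple
logits, image_embeds = classifier(images, return_image_embeds=True)

# Or as a dictionary
out = classifier(images, return_dict=True, return_image_embeds=True)
logits, image_embeds = out["logits"], out["image_embeds"]
```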
Create a custom dataset class that loads and preprocesses your images:
```python
from typing import List

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image


class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths: List[str], labels: List[int]):
        self.image_paths = image_paths
        self.labels = labels
        # Note: for best results, this transform should match CLIP's preprocessing
        # (e.g. the resizing and normalization applied by CLIPProcessor).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(image)
        return image, self.labels[idx]


# Create DataLoader
dataset = SimpleDataset(image_paths, labels)  # Replace with your data
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```
You can also use fusion_bench.dataset.clip_dataset.CLIPDataset or fusion_bench.dataset.image_dataset.TransformedImageDataset to prepare your dataset. Here are examples of using these two classes:
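A minimal sketch, assuming `dataset` yields `(image, label)` pairs (or dictionaries with `image`/`label` keys) and `processor`/`transform` are defined as in the snippets above:

```python
from fusion_bench.dataset.clip_dataset import CLIPDataset
from fusion_bench.dataset.image_dataset import TransformedImageDataset

# Wrap an existing dataset of (image, label) pairs with the CLIP processor
clip_dataset = CLIPDataset(dataset, processor=processor)

# Or apply an arbitrary torchvision-style transform instead
transformed_dataset = TransformedImageDataset(dataset, transform=transform)
```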
Here, `dataset` is your original dataset, `processor` is the CLIP processor used by the classifier, and `transform` is the transform you want to apply to the images.
Below is the reference code for these two classes:
CLIPDataset is a dataset class for CLIP models that converts a dataset of dictionaries or tuples into a format suitable for CLIP processing. It wraps an existing dataset and applies CLIP preprocessing to the images. Each item in the wrapped dataset is expected to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).

Parameters:

- `dataset` (`Dataset`): The original dataset to wrap.
- `processor` (`CLIPProcessor`, default: `None`): The CLIP processor for preparing inputs. If `None`, no preprocessing is applied and raw images are returned.

Attributes:

- `dataset`: The wrapped dataset.
- `processor` (`CLIPProcessor`): The CLIP processor used for image preprocessing.
Source code in fusion_bench/dataset/clip_dataset.py
```python
class CLIPDataset(torch.utils.data.Dataset):
    """
    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
    into a format suitable for CLIP processing.

    This class wraps an existing dataset and applies CLIP preprocessing to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        processor (CLIPProcessor): The CLIP processor for preparing inputs.
            If None, no preprocessing is applied and raw images are returned.

    Attributes:
        dataset: The wrapped dataset.
        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
    """

    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image tensor and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")
        image = item["image"]
        if self.processor is not None:
            if isinstance(self.processor, ProcessorMixin):
                # Apply the processor to the image to get the input tensor
                inputs = self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
            elif callable(self.processor):
                inputs = self.processor(image)
            else:
                raise ValueError(
                    "The processor should be a CLIPProcessor or a callable function"
                )
        else:
            # if processor is None, return the raw image directly
            inputs = image
        # convert boolean label to int, this is for the case when the label is a binary classification task
        if isinstance(item["label"], bool):
            item["label"] = 1 if item["label"] else 0
        return inputs, item["label"]
```
TransformedImageDataset is a dataset class for image classification tasks that applies a transform to images. It wraps an existing dataset and applies a specified transform to the images. Each item in the wrapped dataset is expected to be either a dictionary with 'image' and 'label' keys, or a tuple/list of (image, label).

Parameters:

- `dataset` (`Dataset`): The original dataset to wrap.
- `transform` (`Callable`): A function/transform to apply on the image.

Attributes:

- `dataset`: The wrapped dataset.
- `transform` (`Callable`): The transform to be applied to the images.
Source code in fusion_bench/dataset/image_dataset.py
```python
class TransformedImageDataset(Dataset):
    """
    A dataset class for image classification tasks that applies a transform to images.

    This class wraps an existing dataset and applies a specified transform to the images.
    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
    or a tuple/list of (image, label).

    Args:
        dataset: The original dataset to wrap.
        transform (Callable): A function/transform to apply on the image.

    Attributes:
        dataset: The wrapped dataset.
        transform (Callable): The transform to be applied to the images.
    """

    def __init__(self, dataset: Dataset, transform: Callable):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Returns the number of items in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Retrieves and processes an item from the dataset.

        Args:
            idx (int): The index of the item to retrieve.

        Returns:
            tuple: A tuple containing the processed image and the label.

        Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
        """
        item = self.dataset[idx]
        if isinstance(item, dict):
            item = item
        elif isinstance(item, (tuple, list)):
            assert len(item) == 2, "Each item should be a tuple or list of length 2"
            item = {"image": item[0], "label": item[1]}
        else:
            raise ValueError("Each item should be a dictionary or a tuple of length 2")
        image = item["image"]
        inputs = self.transform(image)
        return inputs, item["label"]
```
To run inference, put the classifier in eval mode and iterate over the DataLoader:

```python
classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        # Process predictions as needed
```
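For instance, the predicted indices can be compared against the labels to track a simple top-1 accuracy (a small sketch extending the loop above; it assumes the integer labels follow the same class order as the `class_names` passed to `set_classification_task`):

```python
correct, total = 0, 0
classifier.eval()
with torch.no_grad():
    for images, labels in dataloader:
        logits = classifier(images)
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.numel()

print(f"Top-1 accuracy: {correct / total:.4f}")
```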
You can provide custom templates when setting up the classification task:
```python
custom_templates = [
    lambda c: f"a photo of a {c}",
    lambda c: f"an image containing a {c}",
]
classifier.set_classification_task(class_names, templates=custom_templates)
```
After setting the classification task, you can access the zero-shot weights:
```python
zeroshot_weights = classifier.zeroshot_weights
```
These weights represent the text embeddings for each class and can be used for further analysis or custom processing.
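For example, since each row of `zeroshot_weights` is an L2-normalized text embedding for one class, their pairwise dot products give cosine similarities between class prompts in CLIP's embedding space (a small sketch using the attributes shown above):

```python
# zeroshot_weights: (num_classes, embed_dim), rows are L2-normalized
similarity = zeroshot_weights @ zeroshot_weights.t()  # (num_classes, num_classes)
for i, name in enumerate(classifier.classnames):
    # Find the most similar other class for each class name
    sim = similarity[i].clone()
    sim[i] = -1.0  # exclude the class itself
    j = sim.argmax().item()
    print(f"{name} is closest to {classifier.classnames[j]} (cosine={similarity[i, j].item():.3f})")
```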
Remember to adjust the code according to your specific dataset and requirements. This documentation provides a comprehensive guide for using the HFCLIPClassifier for image classification tasks with CLIP models.