
Image Classification Tasks for CLIP Models

CLIPVisionModelTaskPool

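The CLIPVisionModelTaskPool class defines a set of image classification tasks for evaluating CLIP vision encoders. Given a vision model, it swaps that model into a CLIP model, builds a classifier from each task's class names and prompt templates, and reports accuracy and cross-entropy loss on every test dataset, together with their averages across tasks.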

Attributes

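  • test_datasets: A dictionary mapping task names to test datasets (dataset objects or configs to instantiate). The task names are also used to look up each task's class names and prompt templates.
  • processor: The CLIPProcessor passed to the classifier (HFCLIPClassifier) when it is set up.
  • data_processor: The processor used to preprocess the test images. If it is None, raw images are batched with raw_image_collate_fn instead.
  • clip_model: The CLIP model used for evaluation. Its vision encoder is replaced by the model under evaluation.
  • dataloader_kwargs: Keyword arguments forwarded to each DataLoader, such as batch_size and num_workers.
  • layer_wise_feature_save_path: Directory in which layer-wise features are saved. If None, no features are saved.
  • layer_wise_feature_first_token_only: Whether to save only the first token (the class token) of each layer's output.
  • layer_wise_feature_max_num: Maximum number of features to save.
  • fast_dev_run: Whether to run in fast development mode, which evaluates only a single batch per task.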
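When layer_wise_feature_save_path is set, on_task_evaluation_begin registers one saver per encoder layer, each writing to save_path / task_name / f"layer_{i}.pth". The sketch below only reproduces that path layout with pathlib, so you can predict which files a run will produce. The helper name feature_file_paths is hypothetical, and the layer count (12, as in a ViT-B/32 vision encoder) is an assumption you should read from your own model's config.

```python
from pathlib import Path
from typing import Dict, List


def feature_file_paths(
    save_path: str, task_names: List[str], num_layers: int
) -> Dict[str, List[Path]]:
    """Hypothetical helper: list the per-layer feature files for each task.

    Mirrors the path expression used in on_task_evaluation_begin:
    save_path / task_name / f"layer_{i}.pth"
    """
    root = Path(save_path)
    return {
        task: [root / task / f"layer_{i}.pth" for i in range(num_layers)]
        for task in task_names
    }


if __name__ == "__main__":
    # Assumption: a 12-layer vision encoder (e.g. ViT-B/32).
    paths = feature_file_paths("path/to/save/features", ["dataset1", "dataset2"], 12)
    print(paths["dataset1"][0].as_posix())  # path/to/save/features/dataset1/layer_0.pth
```

Since the task name is a directory level, features from different tasks never overwrite each other within one run.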

Methods

  • setup(): Sets up the processor, data processor, CLIP model, test datasets, and data loaders. evaluate() calls it automatically the first time it runs.
  • evaluate(model, name=None): Evaluates the given vision model on every task and returns a report dictionary.
  • on_task_evaluation_begin(classifier, task_name): Called at the beginning of each task's evaluation to set up hooks for saving layer-wise features.
  • on_task_evaluation_end(): Called at the end of each task's evaluation to save the features and remove the hooks.
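The call order of these methods, as evaluate() uses them, can be sketched with a stub that records each step instead of doing any work. ToyTaskPool is hypothetical; its "set_task" entry stands for classifier.set_classification_task(classnames, templates) and its "eval" entry for the internal _evaluate call.

```python
from typing import Dict, List


class ToyTaskPool:
    """Hypothetical stub that records the order in which the task pool's steps run."""

    def __init__(self, task_names: List[str]):
        self.task_names = task_names
        self.calls: List[str] = []
        self._is_setup = False

    def setup(self) -> None:
        self.calls.append("setup")
        self._is_setup = True

    def on_task_evaluation_begin(self, task_name: str) -> None:
        self.calls.append(f"begin:{task_name}")

    def on_task_evaluation_end(self) -> None:
        self.calls.append("end")

    def evaluate(self) -> Dict[str, Dict[str, float]]:
        # Lazy setup: only the first evaluate() call triggers setup().
        if not self._is_setup:
            self.setup()
        report: Dict[str, Dict[str, float]] = {}
        for task_name in self.task_names:
            self.on_task_evaluation_begin(task_name)
            self.calls.append(f"set_task:{task_name}")  # classifier.set_classification_task
            self.calls.append(f"eval:{task_name}")  # self._evaluate(...)
            report[task_name] = {"accuracy": 0.0, "loss": 0.0}  # placeholder numbers
            self.on_task_evaluation_end()
        return report


if __name__ == "__main__":
    pool = ToyTaskPool(["dataset1", "dataset2"])
    pool.evaluate()
    pool.evaluate()
    print(pool.calls.count("setup"))  # 1
```

Because the hooks are installed before the classification head is set and removed right after each task, features from one task cannot leak into the next task's files.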

Configuration

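The CLIPVisionModelTaskPool class can be configured with a YAML file. Entries given as configs (with a _target_ key) are instantiated when setup() runs; objects passed in directly are used as-is. Here is an example configuration: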

test_datasets:
  dataset1: ...
  dataset2: ...
processor:
  _target_: transformers.CLIPProcessor.from_pretrained
  pretrained_model_name_or_path: openai/clip-vit-base-patch32
data_processor:
  _target_: transformers.CLIPProcessor.from_pretrained
  pretrained_model_name_or_path: openai/clip-vit-base-patch32
clip_model:
  _target_: transformers.CLIPModel.from_pretrained
  pretrained_model_name_or_path: openai/clip-vit-base-patch32
dataloader_kwargs:
  batch_size: 32
  num_workers: 4
layer_wise_feature_save_path: path/to/save/features
layer_wise_feature_first_token_only: true
layer_wise_feature_max_num: 1000
fast_dev_run: false
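setup() applies the pattern instantiate(x) if isinstance(x, DictConfig) else x to the processors, the CLIP model, and every test dataset. The imports are not shown on this page, so we assume instantiate is Hydra's, which calls the object named by _target_ with the remaining keys. The stdlib sketch below imitates that with plain dicts standing in for DictConfig, including a class-method target such as transformers.CLIPProcessor.from_pretrained. It is a simplified, non-recursive illustration, not Hydra's implementation; resolve_target and maybe_instantiate are hypothetical names.

```python
import importlib
from typing import Any


def resolve_target(path: str) -> Any:
    """Import the longest importable module prefix of `path`, then walk attributes."""
    parts = path.split(".")
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ImportError:
            continue  # not a module; try a shorter prefix
        for attr in parts[i:]:
            obj = getattr(obj, attr)  # e.g. a class, then one of its classmethods
        return obj
    raise ImportError(f"cannot resolve _target_ {path!r}")


def maybe_instantiate(value: Any) -> Any:
    """Hypothetical analogue of `instantiate(x) if isinstance(x, DictConfig) else x`.

    A plain dict with a `_target_` key plays the role of a DictConfig here;
    anything else is returned unchanged. Nested configs are NOT instantiated.
    """
    if isinstance(value, dict) and "_target_" in value:
        kwargs = {k: v for k, v in value.items() if k != "_target_"}
        return resolve_target(value["_target_"])(**kwargs)
    return value


if __name__ == "__main__":
    # Same shape as the clip_model entry above, but with a stdlib target.
    print(maybe_instantiate({"_target_": "datetime.timedelta", "days": 2}))  # 2 days, 0:00:00
```

Passing an already-built object skips instantiation entirely, which is why test_datasets may mix configs and ready-made datasets.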

References

CLIPVisionModelTaskPool

Bases: BaseTaskPool, LightningFabricMixin

This class is used to define the image classification task for CLIP models.

Attributes:

  • test_datasets (Union[DictConfig, Dict[str, Dataset]]) –

    The test datasets to evaluate the model on.

  • processor (Union[DictConfig, CLIPProcessor]) –

    The processor used for preprocessing the input data.

  • data_processor (Union[DictConfig, CLIPProcessor]) –

    The data processor used for processing the input data.

  • clip_model (Union[DictConfig, CLIPModel]) –

    The CLIP model used for evaluation.

  • dataloader_kwargs (DictConfig) –

    Keyword arguments for the data loader.

  • layer_wise_feature_save_path (Optional[str]) –

    Path to save the layer-wise features.

  • layer_wise_feature_first_token_only (bool) –

    Boolean indicating whether to save only the first token of the features.

  • layer_wise_feature_max_num (Optional[int]) –

    Maximum number of features to save.

  • fast_dev_run (bool) –

    Boolean indicating whether to run in fast development mode.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
class CLIPVisionModelTaskPool(
    BaseTaskPool,
    LightningFabricMixin,
):
    """
    This class is used to define the image classification task for CLIP models.

    Attributes:
        test_datasets (Union[DictConfig, Dict[str, Dataset]]): The test datasets to evaluate the model on.
        processor (Union[DictConfig, CLIPProcessor]): The processor used for preprocessing the input data.
        data_processor (Union[DictConfig, CLIPProcessor]): The data processor used for processing the input data.
        clip_model (Union[DictConfig, CLIPModel]): The CLIP model used for evaluation.
        dataloader_kwargs (DictConfig): Keyword arguments for the data loader.
        layer_wise_feature_save_path (Optional[str]): Path to save the layer-wise features.
        layer_wise_feature_first_token_only (bool): Boolean indicating whether to save only the first token of the features.
        layer_wise_feature_max_num (Optional[int]): Maximum number of features to save.
        fast_dev_run (bool): Boolean indicating whether to run in fast development mode.
    """

    _is_setup = False

    # hooks and handles for saving layer-wise features
    _layer_wise_feature_save_hooks: Dict[int, LayerWiseFeatureSaver] = {}
    _layer_wise_feature_save_hook_handles: Dict[int, RemovableHandle] = {}

    _config_mapping = BaseTaskPool._config_mapping | {
        "_test_datasets": "test_datasets",
        "_processor": "processor",
        "_data_processor": "data_processor",
        "_clip_model": "clip_model",
        "_dataloader_kwargs": "dataloader_kwargs",
        "_layer_wise_feature_save_path": "layer_wise_feature_save_path",
        "fast_dev_run": "fast_dev_run",
    }

    def __init__(
        self,
        test_datasets: Union[DictConfig, Dict[str, Dataset]],
        *,
        processor: Union[DictConfig, CLIPProcessor],
        data_processor: Union[DictConfig, CLIPProcessor],
        clip_model: Union[DictConfig, CLIPModel],
        dataloader_kwargs: DictConfig = None,
        layer_wise_feature_save_path: Optional[str] = None,
        layer_wise_feature_first_token_only: bool = True,
        layer_wise_feature_max_num: Optional[int] = None,
        fast_dev_run: bool = False,
        **kwargs,
    ):
        """
        Initialize the CLIPVisionModelTaskPool.
        """
        self._test_datasets = test_datasets
        self._processor = processor
        self._data_processor = data_processor
        self._clip_model = clip_model
        self._dataloader_kwargs = dataloader_kwargs or {}

        # layer-wise feature saving
        self._layer_wise_feature_save_path = layer_wise_feature_save_path
        self.layer_wise_feature_save_path = (
            Path(layer_wise_feature_save_path)
            if layer_wise_feature_save_path is not None
            else None
        )
        self.layer_wise_feature_first_token_only = layer_wise_feature_first_token_only
        self.layer_wise_feature_max_num = layer_wise_feature_max_num

        self.fast_dev_run = fast_dev_run
        super().__init__(**kwargs)

    def setup(self):
        """
        Set up the processor, data processor, CLIP model, test datasets, and data loaders.
        """
        # setup processor and clip model
        self.processor = (
            instantiate(self._processor)
            if isinstance(self._processor, DictConfig)
            else self._processor
        )
        self.data_processor = (
            instantiate(self._data_processor)
            if isinstance(self._data_processor, DictConfig)
            else self._data_processor
        )
        self.clip_model = (
            instantiate(self._clip_model)
            if isinstance(self._clip_model, DictConfig)
            else self._clip_model
        )
        self.clip_model = self.fabric.to_device(self.clip_model)
        self.clip_model.requires_grad_(False)
        self.clip_model.eval()

        # Load the test datasets
        self.test_datasets = {
            name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
            for name, dataset in self._test_datasets.items()
        }
        self.test_datasets = {
            name: CLIPDataset(dataset, self.data_processor)
            for name, dataset in self.test_datasets.items()
        }
        # Setup the dataloaders
        self.test_dataloaders = {
            name: DataLoader(
                dataset,
                **self._dataloader_kwargs,
                collate_fn=(
                    raw_image_collate_fn if self.data_processor is None else None
                ),
            )
            for name, dataset in self.test_datasets.items()
        }
        self.test_dataloaders = {
            name: self.fabric.setup_dataloaders(dataloader)
            for name, dataloader in self.test_dataloaders.items()
        }

        self._is_setup = True

    @torch.no_grad()
    def _evaluate(
        self,
        classifier: HFCLIPClassifier,
        test_loader: DataLoader,
        num_classes: int,
        task_name: str = None,
    ):
        """
        Evaluate the classifier on the test dataset (single-task evaluation).

        Args:
            classifier (HFCLIPClassifier): The classifier to evaluate.
            test_loader (DataLoader): The data loader for the test dataset.
            num_classes (int): The number of classes in the classification task.
            task_name (str): The name of the task.

        Returns:
            Dict[str, float]: A dictionary containing the accuracy and loss of the classifier on the test dataset.
        """
        accuracy: MulticlassAccuracy = Accuracy(
            task="multiclass", num_classes=num_classes
        )
        classifier.eval()
        loss_metric = MeanMetric()
        # if fast_dev_run is set, we only evaluate on a batch of the data
        if self.fast_dev_run:
            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
            test_loader = itertools.islice(test_loader, 1)

        pbar = tqdm(
            test_loader,
            desc=f"Evaluating {task_name}",
            leave=False,
            dynamic_ncols=True,
        )
        for batch in pbar:
            inputs, targets = batch
            outputs = classifier(
                inputs,
                return_image_embeds=True,
                return_dict=True,
                task_name=task_name,
            )
            logits: Tensor = outputs["logits"]

            loss = F.cross_entropy(logits, targets)
            loss_metric.update(loss.detach().cpu())
            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
            pbar.set_postfix(
                {
                    "accuracy": accuracy.compute().item(),
                    "loss": loss_metric.compute().item(),
                }
            )

        acc = accuracy.compute().item()
        loss = loss_metric.compute().item()
        results = {"accuracy": acc, "loss": loss}
        return results

    def evaluate(
        self,
        model: Union[CLIPVisionModel, CLIPVisionTransformer],
        name=None,
        **kwargs,
    ):
        """
        Evaluate the model on the image classification task.

        Args:
            model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
            name (Optional[str]): The name of the model. This will be logged into the report if not None.

        Returns:
            Dict[str, Any]: A dictionary containing the evaluation results for each task.
        """
        if not self._is_setup:
            self.setup()

        report = {}
        # CLIPVisionModel works the same as CLIPVisionTransformer, so we can use it directly
        if hasattr(model, "is_surgery_model") and model.is_surgery_model:
            log.info("running evaluation on a surgery model.")
            model: "SurgeryModelWrapper" = model
            self.clip_model.vision_model = model
        else:
            # replace the vision encoder with the model
            self.clip_model.vision_model = model
        classifier = HFCLIPClassifier(
            self.clip_model,
            processor=self.processor,
        )
        classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
        # collect basic model information
        training_params, all_params = count_parameters(model)
        report["model_info"] = {
            "trainable_params": training_params,
            "all_params": all_params,
            "trainable_percentage": training_params / all_params,
        }
        if name is not None:
            report["model_info"]["name"] = name

        # evaluate on each task
        pbar = tqdm(
            self.test_dataloaders.items(),
            desc="Evaluating tasks",
            total=len(self.test_dataloaders),
        )
        for task_name, test_dataloader in pbar:
            classnames, templates = get_classnames_and_templates(task_name)
            self.on_task_evaluation_begin(classifier, task_name)
            classifier.set_classification_task(classnames, templates)
            result = self._evaluate(
                classifier,
                test_dataloader,
                num_classes=len(classnames),
                task_name=task_name,
            )
            report[task_name] = result
            self.on_task_evaluation_end()

        # calculate the average accuracy and loss
        if "average" not in report:
            report["average"] = {}
            accuracies = [
                value["accuracy"]
                for key, value in report.items()
                if "accuracy" in value
            ]
            if len(accuracies) > 0:
                average_accuracy = sum(accuracies) / len(accuracies)
                report["average"]["accuracy"] = average_accuracy
            losses = [value["loss"] for key, value in report.items() if "loss" in value]
            if len(losses) > 0:
                average_loss = sum(losses) / len(losses)
                report["average"]["loss"] = average_loss

        log.info(f"Evaluation Result: {report}")
        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
            with open(os.path.join(self.log_dir, "report.json"), "w") as fp:
                json.dump(report, fp)
        return report

    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
        """
        Called at the beginning of task evaluation to set up hooks for saving layer-wise features.

        Args:
            classifier (HFCLIPClassifier): The classifier being evaluated.
            task_name (str): The name of the task being evaluated.
        """
        if self.layer_wise_feature_save_path is not None:
            # setup hooks for saving layer-wise features
            assert isinstance(
                classifier.clip_model.vision_model,
                (CLIPVisionTransformer, CLIPVisionModel),
            ), "Vision model is expected to be a CLIPVisionTransformer or CLIPVisionModel"
            vision_model = classifier.clip_model.vision_model
            if isinstance(vision_model, CLIPVisionModel):
                vision_model = vision_model.vision_model
            # assign forward hooks for each layer
            for i, layer in enumerate(vision_model.encoder.layers):
                self._layer_wise_feature_save_hooks[i] = LayerWiseFeatureSaver(
                    self.layer_wise_feature_save_path / task_name / f"layer_{i}.pth",
                    first_token_only=self.layer_wise_feature_first_token_only,
                    max_num=self.layer_wise_feature_max_num,
                )
                self._layer_wise_feature_save_hook_handles[i] = (
                    layer.register_forward_hook(self._layer_wise_feature_save_hooks[i])
                )

    def on_task_evaluation_end(self):
        """
        Called at the end of task evaluation to save features and remove hooks.
        """
        if self.layer_wise_feature_save_path is not None:
            # save features and remove hooks after evaluation
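            for i, hook in self._layer_wise_feature_save_hooks.items():
                hook.save_features()
                self._layer_wise_feature_save_hook_handles[i].remove()
            # clear the registries so stale hooks are not saved again by a later call
            self._layer_wise_feature_save_hooks.clear()
            self._layer_wise_feature_save_hook_handles.clear()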
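In _evaluate, the accuracy metric accumulates its state across all batches, while MeanMetric receives one mean cross-entropy value per batch, so the reported loss is an average of batch means (equal to the per-sample mean only when all batches have the same size). Under fast_dev_run the loader is truncated to its first batch with itertools.islice. The stdlib sketch below reproduces that bookkeeping on plain lists of logits. It treats accuracy as the fraction of correct top-1 predictions, which is an assumption: check the average argument of torchmetrics' Accuracy for how the real metric aggregates. evaluate_batches is a hypothetical name.

```python
import itertools
import math
from typing import Iterable, List, Sequence, Tuple

Batch = Tuple[List[List[float]], List[int]]  # (logits per sample, target class ids)


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-sample cross-entropy, -log softmax(logits)[target], via a stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def evaluate_batches(batches: Iterable[Batch], fast_dev_run: bool = False) -> dict:
    """Hypothetical stdlib analogue of the metric bookkeeping in _evaluate."""
    if fast_dev_run:
        batches = itertools.islice(batches, 1)  # keep only the first batch
    correct = total = 0
    batch_losses: List[float] = []
    for logits, targets in batches:
        # F.cross_entropy defaults to the mean over the samples of the batch
        losses = [cross_entropy(row, t) for row, t in zip(logits, targets)]
        batch_losses.append(sum(losses) / len(losses))
        # top-1 prediction = index of the largest logit
        preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
        correct += sum(p == t for p, t in zip(preds, targets))
        total += len(targets)
    return {
        "accuracy": correct / total,
        # like MeanMetric fed one scalar per batch: a mean of batch means
        "loss": sum(batch_losses) / len(batch_losses),
    }
```

With fast_dev_run enabled, the numbers in the report describe only one batch per task, so they are useful for smoke tests but not for comparing models.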
__init__(test_datasets, *, processor, data_processor, clip_model, dataloader_kwargs=None, layer_wise_feature_save_path=None, layer_wise_feature_first_token_only=True, layer_wise_feature_max_num=None, fast_dev_run=False, **kwargs)

Initialize the CLIPVisionModelTaskPool.

evaluate(model, name=None, **kwargs)

Evaluate the model on the image classification task.

Parameters:

  • model
    (Union[CLIPVisionModel, CLIPVisionTransformer]) –

    The model to evaluate.

  • name
    (Optional[str], default: None ) –

    The name of the model. If not None, it is recorded in the report under model_info.

Returns:

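  • Dict[str, Any]: A dictionary with one entry per task (its accuracy and loss), plus a "model_info" entry (parameter counts and, if given, the name) and an "average" entry (mean accuracy and loss across tasks).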
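The "average" entry is an unweighted mean over tasks: each task counts once regardless of its dataset size, and model_info is skipped because it has no accuracy or loss key. The sketch below rebuilds the report structure from made-up per-task numbers, mirroring the averaging loop in evaluate(). build_report is a hypothetical helper and the numbers in the example are illustrative only.

```python
from typing import Dict, Optional


def build_report(
    task_results: Dict[str, Dict[str, float]],
    trainable_params: int,
    all_params: int,
    name: Optional[str] = None,
) -> dict:
    """Hypothetical helper mirroring how evaluate() assembles its report."""
    report: dict = {
        "model_info": {
            "trainable_params": trainable_params,
            "all_params": all_params,
            # a fraction in [0, 1], not a percentage in [0, 100]
            "trainable_percentage": trainable_params / all_params,
        }
    }
    if name is not None:
        report["model_info"]["name"] = name
    report.update(task_results)
    report["average"] = {}
    for metric in ("accuracy", "loss"):
        # model_info has no "accuracy"/"loss" key, so only task entries contribute
        values = [v[metric] for v in report.values() if metric in v]
        if values:
            report["average"][metric] = sum(values) / len(values)
    return report


if __name__ == "__main__":
    results = {"dataset1": {"accuracy": 0.8, "loss": 0.5}, "dataset2": {"accuracy": 0.6, "loss": 1.5}}
    print(build_report(results, trainable_params=10, all_params=40)["average"]["loss"])  # 1.0
```

Note that trainable_percentage is stored as a ratio, so a fully frozen model reports 0.0 and a fully trainable one reports 1.0.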

on_task_evaluation_begin(classifier, task_name)

Called at the beginning of task evaluation to set up hooks for saving layer-wise features.

Parameters:

  • classifier
    (HFCLIPClassifier) –

    The classifier being evaluated.

  • task_name
    (str) –

    The name of the task being evaluated.
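on_task_evaluation_begin relies on PyTorch's nn.Module.register_forward_hook: each hook is called as hook(module, inputs, output) after the layer's forward pass, and the returned RemovableHandle detaches it again, which is what on_task_evaluation_end does after saving. The stdlib toy below mimics that register, call, and remove cycle without PyTorch. ToyHandle, ToyLayer, and FeatureCollector are hypothetical; FeatureCollector only stands in for LayerWiseFeatureSaver, whose internals are not shown on this page, so its first-token and max_num handling is an assumption.

```python
from typing import Callable, Dict, List, Optional


class ToyHandle:
    """Mimics a RemovableHandle: remove() detaches the hook from its layer."""

    def __init__(self, hooks: Dict[int, Callable], key: int):
        self._hooks, self._key = hooks, key

    def remove(self) -> None:
        self._hooks.pop(self._key, None)


class ToyLayer:
    """Hypothetical layer; one call processes the token vectors of a single sample."""

    def __init__(self):
        self._forward_hooks: Dict[int, Callable] = {}
        self._next_id = 0

    def register_forward_hook(self, hook: Callable) -> ToyHandle:
        self._forward_hooks[self._next_id] = hook
        self._next_id += 1
        return ToyHandle(self._forward_hooks, self._next_id - 1)

    def __call__(self, tokens: List[List[float]]) -> List[List[float]]:
        output = [[2 * x for x in tok] for tok in tokens]  # stand-in computation
        for hook in list(self._forward_hooks.values()):
            hook(self, (tokens,), output)  # (module, inputs, output), as in PyTorch
        return output


class FeatureCollector:
    """Hypothetical stand-in for LayerWiseFeatureSaver (assumed behaviour)."""

    def __init__(self, first_token_only: bool = True, max_num: Optional[int] = None):
        self.first_token_only, self.max_num = first_token_only, max_num
        self.features: List = []

    def __call__(self, module, inputs, output) -> None:
        if self.max_num is not None and len(self.features) >= self.max_num:
            return  # assumed: stop collecting once max_num features are stored
        self.features.append(output[0] if self.first_token_only else output)
```

Removing the handle is what keeps a hook from firing during the next task; forgetting it would keep appending features to the previous task's saver.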

on_task_evaluation_end()

Called at the end of task evaluation to save features and remove hooks.

setup()

Set up the processor, data processor, CLIP model, test datasets, and data loaders.
