fusion_bench.taskpool

Base Class

BaseTaskPool

Bases: BaseYAMLSerializableModel

Source code in fusion_bench/taskpool/base_pool.py
class BaseTaskPool(BaseYAMLSerializableModel):
    _program = None
    _config_key = "taskpool"

    @abstractmethod
    def evaluate(self, model, *args, **kwargs):
        """
        Evaluate the model on all tasks in the task pool, and return a report.

        Taking image classification as an example, the report will look like:

        ```python
        {
            "mnist": {
                "accuracy": 0.8,
                "loss": 0.2,
            },
            <task_name>: {
                <metric_name>: <metric_value>,
                ...
            },
        }
        ```

        Args:
            model: The model to evaluate.

        Returns:
            report (dict): A dictionary containing the results of the evaluation for each task.
        """
        pass

evaluate(model, *args, **kwargs) abstractmethod

Evaluate the model on all tasks in the task pool, and return a report.

Taking image classification as an example, the report will look like:

{
    "mnist": {
        "accuracy": 0.8,
        "loss": 0.2,
    },
    <task_name>: {
        <metric_name>: <metric_value>,
        ...
    },
}

Parameters:

  • model –

    The model to evaluate.

Returns:

  • report ( dict ) –

    A dictionary containing the results of the evaluation for each task.

Source code in fusion_bench/taskpool/base_pool.py
@abstractmethod
def evaluate(self, model, *args, **kwargs):
    """
    Evaluate the model on all tasks in the task pool, and return a report.

    Taking image classification as an example, the report will look like:

    ```python
    {
        "mnist": {
            "accuracy": 0.8,
            "loss": 0.2,
        },
        <task_name>: {
            <metric_name>: <metric_value>,
            ...
        },
    }
    ```

    Args:
        model: The model to evaluate.

    Returns:
        report (dict): A dictionary containing the results of the evaluation for each task.
    """
    pass
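
Subclasses only need to implement `evaluate` and return a report of the shape documented above. Below is a minimal sketch of a custom task pool; the metric callables are user-provided placeholders, and we assume `BaseTaskPool` is importable from `fusion_bench.taskpool`:

```python
from fusion_bench.taskpool import BaseTaskPool


class MyTaskPool(BaseTaskPool):
    """Minimal custom task pool (hypothetical example)."""

    def __init__(self, metrics, **kwargs):
        # `metrics` maps task names to user-provided callables:
        #   metric_fn(model) -> {"accuracy": ..., "loss": ...}
        super().__init__(**kwargs)
        self.metrics = metrics

    def evaluate(self, model, *args, **kwargs):
        # Build the report as {task_name: {metric_name: metric_value}}.
        return {name: metric_fn(model) for name, metric_fn in self.metrics.items()}
```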

Vision Task Pool

NYUv2 Tasks

NYUv2TaskPool

Bases: TaskPool

Source code in fusion_bench/taskpool/nyuv2_taskpool.py
class NYUv2TaskPool(TaskPool):
    _trainer: L.Trainer = None

    def __init__(self, taskpool_config: DictConfig):
        self.config = taskpool_config

    def load_datasets(self):
        log.info("Loading NYUv2 dataset")
        data_path = str(Path(self.config.data_dir) / "nyuv2")

        train_dataset = NYUv2(root=data_path, train=True)
        val_dataset = NYUv2(root=data_path, train=False)
        return train_dataset, val_dataset

    @property
    def trainer(self):
        if self._trainer is None:
            self._trainer = L.Trainer(devices=1)
        return self._trainer

    def get_decoders(self):
        from fusion_bench.modelpool.nyuv2_modelpool import NYUv2ModelPool

        modelpool: NYUv2ModelPool = self._program.modelpool
        decoders = nn.ModuleDict()
        for task in self.config.tasks:
            decoders[task] = modelpool.load_model(task, encoder_only=False).decoders[
                task
            ]
        return decoders

    def evaluate(self, encoder: ResnetDilated):
        model = NYUv2MTLModule(
            encoder,
            self.get_decoders(),
            tasks=self.config.tasks,
            task_weights=[1] * len(self.config.tasks),
        )
        _, val_dataset = self.load_datasets()
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
        )
        report = self.trainer.validate(model, val_loader)
        if isinstance(report, list) and len(report) == 1:
            report = report[0]
        return report
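
The pool is configured with a `DictConfig`. A minimal construction sketch is shown below; the config values are placeholders, and a full `evaluate` call additionally requires a FusionBench program with a model pool attached, since the decoders are fetched through `self._program.modelpool`:

```python
from omegaconf import DictConfig

from fusion_bench.taskpool.nyuv2_taskpool import NYUv2TaskPool

# Hypothetical configuration values; adjust paths and task names to your setup.
taskpool = NYUv2TaskPool(
    DictConfig(
        {
            "data_dir": "./data",  # expects the dataset under <data_dir>/nyuv2
            "tasks": ["segmentation", "depth"],
            "batch_size": 8,
            "num_workers": 4,
        }
    )
)
train_dataset, val_dataset = taskpool.load_datasets()
```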

CLIP Task Pool

CLIPVisionModelTaskPool

Bases: BaseTaskPool, LightningFabricMixin

This class is used to define the image classification task for CLIP models.

Attributes:

  • test_datasets (Union[DictConfig, Dict[str, Dataset]]) –

    The test datasets to evaluate the model on.

  • processor (Union[DictConfig, CLIPProcessor]) –

    The processor used for preprocessing the input data.

  • data_processor (Union[DictConfig, CLIPProcessor]) –

    The data processor used for processing the input data.

  • clip_model (Union[DictConfig, CLIPModel]) –

    The CLIP model used for evaluation.

  • dataloader_kwargs (DictConfig) –

    Keyword arguments for the data loader.

  • layer_wise_feature_save_path (Optional[str]) –

    Path to save the layer-wise features.

  • layer_wise_feature_first_token_only (bool) –

    Boolean indicating whether to save only the first token of the features.

  • layer_wise_feature_max_num (Optional[int]) –

    Maximum number of features to save.

  • fast_dev_run (bool) –

    Boolean indicating whether to run in fast development mode.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
class CLIPVisionModelTaskPool(
    BaseTaskPool,
    LightningFabricMixin,
):
    """
    This class is used to define the image classification task for CLIP models.

    Attributes:
        test_datasets (Union[DictConfig, Dict[str, Dataset]]): The test datasets to evaluate the model on.
        processor (Union[DictConfig, CLIPProcessor]): The processor used for preprocessing the input data.
        data_processor (Union[DictConfig, CLIPProcessor]): The data processor used for processing the input data.
        clip_model (Union[DictConfig, CLIPModel]): The CLIP model used for evaluation.
        dataloader_kwargs (DictConfig): Keyword arguments for the data loader.
        layer_wise_feature_save_path (Optional[str]): Path to save the layer-wise features.
        layer_wise_feature_first_token_only (bool): Boolean indicating whether to save only the first token of the features.
        layer_wise_feature_max_num (Optional[int]): Maximum number of features to save.
        fast_dev_run (bool): Boolean indicating whether to run in fast development mode.
    """

    _is_setup = False

    # hooks and handles for saving layer-wise features
    _layer_wise_feature_save_hooks: Dict[int, LayerWiseFeatureSaver] = {}
    _layer_wise_feature_save_hook_handles: Dict[int, RemovableHandle] = {}

    _config_mapping = BaseTaskPool._config_mapping | {
        "_test_datasets": "test_datasets",
        "_processor": "processor",
        "_data_processor": "data_processor",
        "_clip_model": "clip_model",
        "_dataloader_kwargs": "dataloader_kwargs",
        "_layer_wise_feature_save_path": "layer_wise_feature_save_path",
        "fast_dev_run": "fast_dev_run",
    }

    def __init__(
        self,
        test_datasets: Union[DictConfig, Dict[str, Dataset]],
        *,
        processor: Union[str, DictConfig, CLIPProcessor],
        clip_model: Union[str, DictConfig, CLIPModel],
        data_processor: Union[DictConfig, CLIPProcessor] = None,
        dataloader_kwargs: DictConfig = None,
        layer_wise_feature_save_path: Optional[str] = None,
        layer_wise_feature_first_token_only: bool = True,
        layer_wise_feature_max_num: Optional[int] = None,
        fast_dev_run: bool = False,
        **kwargs,
    ):
        """
        Initialize the CLIPVisionModelTaskPool.
        """
        self._test_datasets = test_datasets
        self._processor = processor
        self._data_processor = data_processor
        self._clip_model = clip_model
        self._dataloader_kwargs = dataloader_kwargs or {}

        # layer-wise feature saving
        self._layer_wise_feature_save_path = layer_wise_feature_save_path
        self.layer_wise_feature_save_path = (
            Path(layer_wise_feature_save_path)
            if layer_wise_feature_save_path is not None
            else None
        )
        self.layer_wise_feature_first_token_only = layer_wise_feature_first_token_only
        self.layer_wise_feature_max_num = layer_wise_feature_max_num

        self.fast_dev_run = fast_dev_run
        super().__init__(**kwargs)

    def setup(self):
        """
        Set up the processor, data processor, CLIP model, test datasets, and data loaders.
        """
        # setup processor and clip model
        if isinstance(self._processor, str):
            self.processor = CLIPProcessor.from_pretrained(self._processor)
        elif (
            isinstance(self._processor, (dict, DictConfig))
            and "_target_" in self._processor
        ):
            self.processor = instantiate(self._processor)
        else:
            self.processor = self._processor

        if self._data_processor is None:
            self.data_processor = self.processor
        else:
            self.data_processor = (
                instantiate(self._data_processor)
                if isinstance(self._data_processor, DictConfig)
                else self._data_processor
            )

        if isinstance(self._clip_model, str):
            self.clip_model = CLIPModel.from_pretrained(self._clip_model)
        elif (
            isinstance(self._clip_model, (dict, DictConfig))
            and "_target_" in self._clip_model
        ):
            self.clip_model = instantiate(self._clip_model)
        else:
            self.clip_model = self._clip_model

        self.clip_model = self.fabric.to_device(self.clip_model)
        self.clip_model.requires_grad_(False)
        self.clip_model.eval()

        # Load the test datasets
        self.test_datasets = {
            name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
            for name, dataset in self._test_datasets.items()
        }
        self.test_datasets = {
            name: CLIPDataset(dataset, self.data_processor)
            for name, dataset in self.test_datasets.items()
        }
        # Setup the dataloaders
        self.test_dataloaders = {
            name: DataLoader(
                dataset,
                **self._dataloader_kwargs,
                collate_fn=(
                    raw_image_collate_fn if self.data_processor is None else None
                ),
            )
            for name, dataset in self.test_datasets.items()
        }
        self.test_dataloaders = {
            name: self.fabric.setup_dataloaders(dataloader)
            for name, dataloader in self.test_dataloaders.items()
        }

        self._is_setup = True

    @torch.no_grad()
    def _evaluate(
        self,
        classifier: HFCLIPClassifier,
        test_loader: DataLoader,
        num_classes: int,
        task_name: str = None,
    ):
        """
        Evaluate the classifier on the test dataset (single-task evaluation).

        Args:
            classifier (HFCLIPClassifier): The classifier to evaluate.
            test_loader (DataLoader): The data loader for the test dataset.
            num_classes (int): The number of classes in the classification task.
            task_name (str): The name of the task.

        Returns:
            Dict[str, float]: A dictionary containing the accuracy and loss of the classifier on the test dataset.
        """
        accuracy: MulticlassAccuracy = Accuracy(
            task="multiclass", num_classes=num_classes
        )
        classifier.eval()
        loss_metric = MeanMetric()
        # if fast_dev_run is set, we only evaluate on a batch of the data
        if self.fast_dev_run:
            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
            test_loader = itertools.islice(test_loader, 1)
        else:
            test_loader = test_loader

        pbar = tqdm(
            test_loader,
            desc=f"Evaluating {task_name}",
            leave=False,
            dynamic_ncols=True,
        )
        for batch in pbar:
            inputs, targets = batch
            outputs = classifier(
                inputs,
                return_image_embeds=True,
                return_dict=True,
                task_name=task_name,
            )
            logits: Tensor = outputs["logits"]

            loss = F.cross_entropy(logits, targets)
            loss_metric.update(loss.detach().cpu())
            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
            pbar.set_postfix(
                {
                    "accuracy": accuracy.compute().item(),
                    "loss": loss_metric.compute().item(),
                }
            )

        acc = accuracy.compute().item()
        loss = loss_metric.compute().item()
        results = {"accuracy": acc, "loss": loss}
        return results

    def evaluate(
        self,
        model: Union[CLIPVisionModel, CLIPVisionTransformer],
        name=None,
        **kwargs,
    ):
        """
        Evaluate the model on the image classification task.

        Args:
            model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
            name (Optional[str]): The name of the model. This will be logged into the report if not None.

        Returns:
            Dict[str, Any]: A dictionary containing the evaluation results for each task.
        """
        if not self._is_setup:
            self.setup()

        report = {}
        # CLIPVisionModel works the same as CLIPVisionTransformer, so we can use it directly
        if hasattr(model, "is_surgery_model") and model.is_surgery_model:
            log.info("running evaluation on a surgery model.")
            model: "SurgeryModelWrapper" = model
            self.clip_model.vision_model = model
        else:
            # replace the vision encoder with the model
            self.clip_model.vision_model = model
        classifier = HFCLIPClassifier(
            self.clip_model,
            processor=self.processor,
        )
        classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
        # collect basic model information
        training_params, all_params = count_parameters(model)
        report["model_info"] = {
            "trainable_params": training_params,
            "all_params": all_params,
            "trainable_percentage": training_params / all_params,
        }
        if name is not None:
            report["model_info"]["name"] = name

        # evaluate on each task
        pbar = tqdm(
            self.test_dataloaders.items(),
            desc="Evaluating tasks",
            total=len(self.test_dataloaders),
        )
        for task_name, test_dataloader in pbar:
            classnames, templates = get_classnames_and_templates(task_name)
            self.on_task_evaluation_begin(classifier, task_name)
            classifier.set_classification_task(classnames, templates)
            result = self._evaluate(
                classifier,
                test_dataloader,
                num_classes=len(classnames),
                task_name=task_name,
            )
            report[task_name] = result
            self.on_task_evaluation_end()

        # calculate the average accuracy and loss
        if "average" not in report:
            report["average"] = {}
            accuracies = [
                value["accuracy"]
                for key, value in report.items()
                if "accuracy" in value
            ]
            if len(accuracies) > 0:
                average_accuracy = sum(accuracies) / len(accuracies)
                report["average"]["accuracy"] = average_accuracy
            losses = [value["loss"] for key, value in report.items() if "loss" in value]
            if len(losses) > 0:
                average_loss = sum(losses) / len(losses)
                report["average"]["loss"] = average_loss

        log.info(f"Evaluation Result: {report}")
        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
            save_path = os.path.join(self.log_dir, "report.json")
            for version in itertools.count(1):
                if not os.path.exists(save_path):
                    break
                # if the file already exists, increment the version to avoid overwriting
                save_path = os.path.join(self.log_dir, f"report_{version}.json")
            with open(save_path, "w") as fp:
                json.dump(report, fp)
            log.info(f"Evaluation report saved to {save_path}")
        return report

    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
        """
        Called at the beginning of task evaluation to set up hooks for saving layer-wise features.

        Args:
            classifier (HFCLIPClassifier): The classifier being evaluated.
            task_name (str): The name of the task being evaluated.
        """
        if self.layer_wise_feature_save_path is not None:
            # setup hooks for saving layer-wise features
            assert isinstance(
                classifier.clip_model.vision_model,
                (CLIPVisionTransformer, CLIPVisionModel),
            ), "Vision model is expected to be a CLIPVisionTransformer"
            vision_model = classifier.clip_model.vision_model
            if isinstance(vision_model, CLIPVisionModel):
                vision_model = vision_model.vision_model
            # assign forward hooks for each layer
            for i, layer in enumerate(vision_model.encoder.layers):
                self._layer_wise_feature_save_hooks[i] = LayerWiseFeatureSaver(
                    self.layer_wise_feature_save_path / task_name / f"layer_{i}.pth",
                    first_token_only=self.layer_wise_feature_first_token_only,
                    max_num=self.layer_wise_feature_max_num,
                )
                self._layer_wise_feature_save_hook_handles[i] = (
                    layer.register_forward_hook(self._layer_wise_feature_save_hooks[i])
                )

    def on_task_evaluation_end(self):
        """
        Called at the end of task evaluation to save features and remove hooks.
        """
        if self.layer_wise_feature_save_path is not None:
            # save features and remove hooks after evaluation
            for i, hook in self._layer_wise_feature_save_hooks.items():
                hook.save_features()
                self._layer_wise_feature_save_hook_handles[i].remove()
__init__(test_datasets, *, processor, clip_model, data_processor=None, dataloader_kwargs=None, layer_wise_feature_save_path=None, layer_wise_feature_first_token_only=True, layer_wise_feature_max_num=None, fast_dev_run=False, **kwargs)

Initialize the CLIPVisionModelTaskPool.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
def __init__(
    self,
    test_datasets: Union[DictConfig, Dict[str, Dataset]],
    *,
    processor: Union[str, DictConfig, CLIPProcessor],
    clip_model: Union[str, DictConfig, CLIPModel],
    data_processor: Union[DictConfig, CLIPProcessor] = None,
    dataloader_kwargs: DictConfig = None,
    layer_wise_feature_save_path: Optional[str] = None,
    layer_wise_feature_first_token_only: bool = True,
    layer_wise_feature_max_num: Optional[int] = None,
    fast_dev_run: bool = False,
    **kwargs,
):
    """
    Initialize the CLIPVisionModelTaskPool.
    """
    self._test_datasets = test_datasets
    self._processor = processor
    self._data_processor = data_processor
    self._clip_model = clip_model
    self._dataloader_kwargs = dataloader_kwargs or {}

    # layer-wise feature saving
    self._layer_wise_feature_save_path = layer_wise_feature_save_path
    self.layer_wise_feature_save_path = (
        Path(layer_wise_feature_save_path)
        if layer_wise_feature_save_path is not None
        else None
    )
    self.layer_wise_feature_first_token_only = layer_wise_feature_first_token_only
    self.layer_wise_feature_max_num = layer_wise_feature_max_num

    self.fast_dev_run = fast_dev_run
    super().__init__(**kwargs)
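
A hedged construction sketch: `processor` and `clip_model` can simply be Hugging Face model identifiers, and `test_datasets` maps task names to test datasets. The dataset objects below are user-provided placeholders, and the keys must be task names known to `get_classnames_and_templates`:

```python
from fusion_bench.taskpool.clip_vision.taskpool import CLIPVisionModelTaskPool

# Placeholders: user-provided test datasets compatible with fusion_bench's
# CLIPDataset wrapper.
my_mnist_test = ...
my_cifar10_test = ...

taskpool = CLIPVisionModelTaskPool(
    test_datasets={"mnist": my_mnist_test, "cifar10": my_cifar10_test},
    processor="openai/clip-vit-base-patch32",
    clip_model="openai/clip-vit-base-patch32",
    dataloader_kwargs={"batch_size": 64, "num_workers": 4},
)
```
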
evaluate(model, name=None, **kwargs)

Evaluate the model on the image classification task.

Parameters:

  • model (Union[CLIPVisionModel, CLIPVisionTransformer]) –

    The model to evaluate.

  • name (Optional[str], default: None ) –

    The name of the model. This will be logged into the report if not None.

Returns:

  • Dict[str, Any] –

    A dictionary containing the evaluation results for each task.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
def evaluate(
    self,
    model: Union[CLIPVisionModel, CLIPVisionTransformer],
    name=None,
    **kwargs,
):
    """
    Evaluate the model on the image classification task.

    Args:
        model (Union[CLIPVisionModel, CLIPVisionTransformer]): The model to evaluate.
        name (Optional[str]): The name of the model. This will be logged into the report if not None.

    Returns:
        Dict[str, Any]: A dictionary containing the evaluation results for each task.
    """
    if not self._is_setup:
        self.setup()

    report = {}
    # CLIPVisionModel works the same as CLIPVisionTransformer, so we can use it directly
    if hasattr(model, "is_surgery_model") and model.is_surgery_model:
        log.info("running evaluation on a surgery model.")
        model: "SurgeryModelWrapper" = model
        self.clip_model.vision_model = model
    else:
        # replace the vision encoder with the model
        self.clip_model.vision_model = model
    classifier = HFCLIPClassifier(
        self.clip_model,
        processor=self.processor,
    )
    classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
    # collect basic model information
    training_params, all_params = count_parameters(model)
    report["model_info"] = {
        "trainable_params": training_params,
        "all_params": all_params,
        "trainable_percentage": training_params / all_params,
    }
    if name is not None:
        report["model_info"]["name"] = name

    # evaluate on each task
    pbar = tqdm(
        self.test_dataloaders.items(),
        desc="Evaluating tasks",
        total=len(self.test_dataloaders),
    )
    for task_name, test_dataloader in pbar:
        classnames, templates = get_classnames_and_templates(task_name)
        self.on_task_evaluation_begin(classifier, task_name)
        classifier.set_classification_task(classnames, templates)
        result = self._evaluate(
            classifier,
            test_dataloader,
            num_classes=len(classnames),
            task_name=task_name,
        )
        report[task_name] = result
        self.on_task_evaluation_end()

    # calculate the average accuracy and loss
    if "average" not in report:
        report["average"] = {}
        accuracies = [
            value["accuracy"]
            for key, value in report.items()
            if "accuracy" in value
        ]
        if len(accuracies) > 0:
            average_accuracy = sum(accuracies) / len(accuracies)
            report["average"]["accuracy"] = average_accuracy
        losses = [value["loss"] for key, value in report.items() if "loss" in value]
        if len(losses) > 0:
            average_loss = sum(losses) / len(losses)
            report["average"]["loss"] = average_loss

    log.info(f"Evaluation Result: {report}")
    if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
        save_path = os.path.join(self.log_dir, "report.json")
        for version in itertools.count(1):
            if not os.path.exists(save_path):
                break
            # if the file already exists, increment the version to avoid overwriting
            save_path = os.path.join(self.log_dir, f"report_{version}.json")
        with open(save_path, "w") as fp:
            json.dump(report, fp)
        log.info(f"Evaluation report saved to {save_path}")
    return report
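
For reference, the report returned by `evaluate` contains a `model_info` entry, one entry per evaluated task, and an `average` entry aggregated over the per-task metrics; the values below are purely illustrative:

```python
{
    "model_info": {
        "trainable_params": 0,
        "all_params": 87_456_000,
        "trainable_percentage": 0.0,
        "name": "merged-clip-vit-base",  # only present if `name` was passed
    },
    "mnist": {"accuracy": 0.99, "loss": 0.03},
    "cifar10": {"accuracy": 0.97, "loss": 0.10},
    "average": {"accuracy": 0.98, "loss": 0.065},
}
```
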
on_task_evaluation_begin(classifier, task_name)

Called at the beginning of task evaluation to set up hooks for saving layer-wise features.

Parameters:

  • classifier (HFCLIPClassifier) –

    The classifier being evaluated.

  • task_name (str) –

    The name of the task being evaluated.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
    """
    Called at the beginning of task evaluation to set up hooks for saving layer-wise features.

    Args:
        classifier (HFCLIPClassifier): The classifier being evaluated.
        task_name (str): The name of the task being evaluated.
    """
    if self.layer_wise_feature_save_path is not None:
        # setup hooks for saving layer-wise features
        assert isinstance(
            classifier.clip_model.vision_model,
            (CLIPVisionTransformer, CLIPVisionModel),
        ), "Vision model is expected to be a CLIPVisionTransformer"
        vision_model = classifier.clip_model.vision_model
        if isinstance(vision_model, CLIPVisionModel):
            vision_model = vision_model.vision_model
        # assign forward hooks for each layer
        for i, layer in enumerate(vision_model.encoder.layers):
            self._layer_wise_feature_save_hooks[i] = LayerWiseFeatureSaver(
                self.layer_wise_feature_save_path / task_name / f"layer_{i}.pth",
                first_token_only=self.layer_wise_feature_first_token_only,
                max_num=self.layer_wise_feature_max_num,
            )
            self._layer_wise_feature_save_hook_handles[i] = (
                layer.register_forward_hook(self._layer_wise_feature_save_hooks[i])
            )
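
If `layer_wise_feature_save_path` is set, one file per encoder layer is written under `<save_path>/<task_name>/layer_<i>.pth`. Below is a hedged sketch of loading the saved files back with `torch.load`; the exact layout of each file is determined by `LayerWiseFeatureSaver`, so treat the paths and structure as assumptions:

```python
from pathlib import Path

import torch

save_path = Path("outputs/layer_wise_features")  # hypothetical save path
task_name = "mnist"                              # hypothetical task name

features = {
    file.stem: torch.load(file, map_location="cpu")
    for file in sorted((save_path / task_name).glob("layer_*.pth"))
}
```
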
on_task_evaluation_end()

Called at the end of task evaluation to save features and remove hooks.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
def on_task_evaluation_end(self):
    """
    Called at the end of task evaluation to save features and remove hooks.
    """
    if self.layer_wise_feature_save_path is not None:
        # save features and remove hooks after evaluation
        for i, hook in self._layer_wise_feature_save_hooks.items():
            hook.save_features()
            self._layer_wise_feature_save_hook_handles[i].remove()
setup()

Set up the processor, data processor, CLIP model, test datasets, and data loaders.

Source code in fusion_bench/taskpool/clip_vision/taskpool.py
def setup(self):
    """
    Set up the processor, data processor, CLIP model, test datasets, and data loaders.
    """
    # setup processor and clip model
    if isinstance(self._processor, str):
        self.processor = CLIPProcessor.from_pretrained(self._processor)
    elif (
        isinstance(self._processor, (dict, DictConfig))
        and "_target_" in self._processor
    ):
        self.processor = instantiate(self._processor)
    else:
        self.processor = self._processor

    if self._data_processor is None:
        self.data_processor = self.processor
    else:
        self.data_processor = (
            instantiate(self._data_processor)
            if isinstance(self._data_processor, DictConfig)
            else self._data_processor
        )

    if isinstance(self._clip_model, str):
        self.clip_model = CLIPModel.from_pretrained(self._clip_model)
    elif (
        isinstance(self._clip_model, (dict, DictConfig))
        and "_target_" in self._clip_model
    ):
        self.clip_model = instantiate(self._clip_model)
    else:
        self.clip_model = self._clip_model

    self.clip_model = self.fabric.to_device(self.clip_model)
    self.clip_model.requires_grad_(False)
    self.clip_model.eval()

    # Load the test datasets
    self.test_datasets = {
        name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
        for name, dataset in self._test_datasets.items()
    }
    self.test_datasets = {
        name: CLIPDataset(dataset, self.data_processor)
        for name, dataset in self.test_datasets.items()
    }
    # Setup the dataloaders
    self.test_dataloaders = {
        name: DataLoader(
            dataset,
            **self._dataloader_kwargs,
            collate_fn=(
                raw_image_collate_fn if self.data_processor is None else None
            ),
        )
        for name, dataset in self.test_datasets.items()
    }
    self.test_dataloaders = {
        name: self.fabric.setup_dataloaders(dataloader)
        for name, dataloader in self.test_dataloaders.items()
    }

    self._is_setup = True
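
Putting the pieces together, a hedged end-to-end sketch (continuing the construction example above, where `taskpool` is a `CLIPVisionModelTaskPool`; the model name is a placeholder for a merged vision encoder):

```python
from transformers import CLIPVisionModel

# Any CLIPVisionModel / CLIPVisionTransformer can be plugged in as the vision encoder.
vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
report = taskpool.evaluate(vision_model, name="clip-vit-base-patch32")
print(report["average"])
```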

SparseWEMoECLIPVisionModelTaskPool

Bases: CLIPVisionModelTaskPool

Source code in fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py
class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):

    # hooks and handles for saving layer-wise routing weights
    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}

    _config_mapping = CLIPVisionModelTaskPool._config_mapping | {
        "_layer_wise_routing_weights_save_path": "layer_wise_routing_weights_save_path",
    }

    def __init__(
        self,
        layer_wise_routing_weights_save_path: Optional[str],
        layer_wise_routing_weights_max_num: Optional[int] = None,
        **kwargs,
    ):
        # save path for layer-wise routing weights
        self._layer_wise_routing_weights_save_path = (
            layer_wise_routing_weights_save_path
        )
        self.layer_wise_routing_weights_save_path = (
            Path(layer_wise_routing_weights_save_path)
            if layer_wise_routing_weights_save_path is not None
            else None
        )
        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
        super().__init__(**kwargs)

    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
        super().on_task_evaluation_begin(classifier, task_name)
        if self.layer_wise_routing_weights_save_path is not None:
            # setup hooks for saving layer-wise routing weights
            assert isinstance(
                classifier.clip_model.vision_model,
                (CLIPVisionTransformer, CLIPVisionModel),
            ), "Vision model is expected to be a CLIPVisionTransformer"
            vision_model = classifier.clip_model.vision_model
            if isinstance(vision_model, CLIPVisionModel):
                vision_model = vision_model.vision_model
            shared_gate = None
            # assign forward hooks for each layer
            for i, layer in enumerate(vision_model.encoder.layers):
                mlp = layer.mlp
                assert isinstance(
                    mlp,
                    (SparseWeightEnsemblingMoE, SparseWeightEnsemblingMoE_ShardGate),
                ), f"MLP is expected to be a SparseWeightEnsemblingMoE or SparseWeightEnsemblingMoE_ShardGate, but got {type(mlp)}"
                # layer-wise routing weights
                hook = LayerWiseRoutingWeightSaver(
                    self.layer_wise_routing_weights_save_path
                    / task_name
                    / f"layer_{i}.pt",
                    max_num=self.layer_wise_routing_weights_max_num,
                )
                self._layer_wise_routing_weights_save_hooks[i] = hook
                if isinstance(mlp, SparseWeightEnsemblingMoE_ShardGate):
                    # if a shared gate is used, give each layer its own deepcopy so each hook attaches to a distinct gate
                    if shared_gate is None:
                        shared_gate = mlp.gate
                    mlp.gate = deepcopy(shared_gate)
                self._layer_wise_routing_weights_save_hook_handles[i] = (
                    mlp.gate.register_forward_hook(hook)
                )

    def on_task_evaluation_end(self):
        super().on_task_evaluation_end()
        if self.layer_wise_routing_weights_save_path is not None:
            # remove hooks for saving layer-wise routing weights
            # iterate over a copy of the items, since handles are popped inside the loop
            for i, handle in list(
                self._layer_wise_routing_weights_save_hook_handles.items()
            ):
                self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
                self._layer_wise_routing_weights_save_hook_handles.pop(i)
                handle.remove()
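
Construction mirrors the parent class, with an extra save path for the routing weights. A hedged sketch (arguments other than the routing-weight options are the same placeholders as for `CLIPVisionModelTaskPool`):

```python
from fusion_bench.taskpool.clip_vision.clip_sparse_wemoe_taskpool import (
    SparseWEMoECLIPVisionModelTaskPool,
)

my_mnist_test = ...  # placeholder: a user-provided test dataset

taskpool = SparseWEMoECLIPVisionModelTaskPool(
    layer_wise_routing_weights_save_path="outputs/routing_weights",  # hypothetical path
    layer_wise_routing_weights_max_num=1000,
    test_datasets={"mnist": my_mnist_test},
    processor="openai/clip-vit-base-patch32",
    clip_model="openai/clip-vit-base-patch32",
)
```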

RankoneMoECLIPVisionModelTaskPool

Bases: CLIPVisionModelTaskPool

Source code in fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py
class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):

    # hooks and handles for saving layer-wise routing weights
    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}

    _config_mapping = CLIPVisionModelTaskPool._config_mapping | {
        "_layer_wise_routing_weights_save_path": "layer_wise_routing_weights_save_path",
    }

    def __init__(
        self,
        layer_wise_routing_weights_save_path: Optional[str],
        layer_wise_routing_weights_max_num: Optional[int] = None,
        **kwargs,
    ):
        # save path for layer-wise routing weights
        self._layer_wise_routing_weights_save_path = (
            layer_wise_routing_weights_save_path
        )
        self.layer_wise_routing_weights_save_path = (
            Path(layer_wise_routing_weights_save_path)
            if layer_wise_routing_weights_save_path is not None
            else None
        )
        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
        super().__init__(**kwargs)

    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
        super().on_task_evaluation_begin(classifier, task_name)
        if self.layer_wise_routing_weights_save_path is not None:
            # setup hooks for saving layer-wise routing weights
            assert isinstance(
                classifier.clip_model.vision_model,
                (CLIPVisionTransformer, CLIPVisionModel),
            ), "Vision model is expected to be a CLIPVisionTransformer"
            vision_model = classifier.clip_model.vision_model
            if isinstance(vision_model, CLIPVisionModel):
                vision_model = vision_model.vision_model

            # assign forward hooks for each layer
            for i, layer in enumerate(vision_model.encoder.layers):
                mlp = layer.mlp
                assert isinstance(
                    mlp,
                    (RankOneMoE),
                ), f"MLP is expected to be a RankOneWeightEnsemblingMoE, but got {type(mlp)}"
                # layer-wise routing weights
                hook = LayerWiseRoutingWeightSaver(
                    self.layer_wise_routing_weights_save_path
                    / task_name
                    / f"layer_{i}.pt",
                    max_num=self.layer_wise_routing_weights_max_num,
                )
                self._layer_wise_routing_weights_save_hooks[i] = hook
                self._layer_wise_routing_weights_save_hook_handles[i] = (
                    mlp.gate.register_forward_hook(hook)
                )

    def on_task_evaluation_end(self):
        super().on_task_evaluation_end()
        if self.layer_wise_routing_weights_save_path is not None:
            # remove hooks for saving layer-wise routing weights
            # iterate over a copy of the items, since handles are popped inside the loop
            for i, handle in list(
                self._layer_wise_routing_weights_save_hook_handles.items()
            ):
                self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
                self._layer_wise_routing_weights_save_hook_handles.pop(i)
                handle.remove()

Natural Language Processing (NLP) Tasks

GPT-2

GPT2TextClassificationTaskPool

Bases: BaseTaskPool, LightningFabricMixin

A task pool for GPT2 text classification tasks. This class manages the tasks and provides methods for loading test datasets and running evaluation.

Source code in fusion_bench/taskpool/gpt2_text_classification.py
class GPT2TextClassificationTaskPool(BaseTaskPool, LightningFabricMixin):
    """
    A task pool for GPT2 text classification tasks.
    This class manages the tasks and provides methods for loading test datasets and running evaluation.
    """

    _config_mapping = BaseTaskPool._config_mapping | {
        "_test_datasets": "test_datasets",
        "_tokenizer": "tokenizer",
        "dataloader_kwargs": "dataloader_kwargs",
        "fast_dev_run": "fast_dev_run",
    }

    def __init__(
        self,
        test_datasets: DictConfig,
        tokenizer: DictConfig,
        dataloader_kwargs: DictConfig,
        fast_dev_run: bool,
        **kwargs,
    ):
        self._test_datasets = test_datasets
        self._tokenizer = tokenizer
        self.dataloader_kwargs = dataloader_kwargs
        self.fast_dev_run = fast_dev_run
        super().__init__(**kwargs)

        self.setup()

    def setup(self):
        global tokenizer
        self.tokenizer = tokenizer = instantiate(self._tokenizer)

    def get_classifier(
        self, task_name: str, model: GPT2Model
    ) -> GPT2ForSequenceClassification:
        modelpool = self._program.modelpool
        classifier = modelpool.load_classifier(task_name)
        classifier.transformer = deepcopy(model)
        return classifier

    @torch.no_grad()
    def evaluate_single_task(
        self,
        task_name: str,
        model: GPT2Model,
        test_loader: DataLoader,
    ):
        loss_metric = MeanMetric()
        # load classifier and replace the backbone with the passed model
        model: GPT2ForSequenceClassification = self.get_classifier(task_name, model)
        accuracy = Accuracy("multiclass", num_classes=model.num_labels)
        model = self.fabric.setup(model)

        if self.config.get("fast_dev_run", False):
            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
            test_loader = itertools.islice(test_loader, 1)
        else:
            test_loader = test_loader

        for batch in (
            pbar := tqdm(
                test_loader, desc="Evaluating", leave=False, dynamic_ncols=True
            )
        ):
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            loss = F.cross_entropy(logits, labels)

            accuracy(logits.detach().cpu(), labels.detach().cpu())
            loss_metric.update(loss.detach().cpu())
            pbar.set_postfix(
                {
                    "accuracy": accuracy.compute().item(),
                    "loss": loss_metric.compute().item(),
                }
            )

        acc = accuracy.compute().item()
        loss = loss_metric.compute().item()
        results = {"accuracy": acc, "loss": loss}
        log.info(f"Results for task {task_name}: {results}")
        return results

    def get_test_dataloader(self, task_name: str):
        dataset = instantiate(self._test_datasets[task_name])
        dataloader_kwargs = {
            "shuffle": False,
        }
        dataloader_kwargs.update(self.dataloader_kwargs)
        dataloader = DataLoader(
            dataset, collate_fn=default_data_collator, **dataloader_kwargs
        )
        if self.fabric is not None:
            dataloader = self.fabric.setup_dataloaders(dataloader)
        return dataloader

    @override
    def evaluate(self, model: GPT2Model, name: str = None):
        """Evaluate the model on the test datasets.

        Args:
            model (GPT2Model): The model to evaluate.
            name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.

        Returns:
            dict: A dictionary containing the evaluation results for each task.
        """
        report = {}
        if name is not None:
            report["name"] = name
        for task_name in (pbar := tqdm(self._test_datasets, desc="Evaluating tasks")):
            pbar.set_description(f"Evaluating task {task_name}")
            dataloader = self.get_test_dataloader(task_name)
            result = self.evaluate_single_task(task_name, model, dataloader)
            report[task_name] = result

        # calculate the average accuracy and loss
        if "average" not in report:
            report["average"] = {}
            accuracies = [
                value["accuracy"]
                for key, value in report.items()
                if isinstance(value, dict) and "accuracy" in value
            ]
            if len(accuracies) > 0:
                average_accuracy = sum(accuracies) / len(accuracies)
                report["average"]["accuracy"] = average_accuracy
            losses = [value["loss"] for key, value in report.items() if "loss" in value]
            if len(losses) > 0:
                average_loss = sum(losses) / len(losses)
                report["average"]["loss"] = average_loss

        log.info(f"Evaluation Result: {report}")
        return report
evaluate(model, name=None)

Evaluate the model on the test datasets.

Parameters:

  • model (GPT2Model) –

    The model to evaluate.

  • name (str, default: None ) –

    The name of the model. Defaults to None. This is used to identify the model in the report.

Returns:

  • dict –

    A dictionary containing the evaluation results for each task.

Source code in fusion_bench/taskpool/gpt2_text_classification.py
@override
def evaluate(self, model: GPT2Model, name: str = None):
    """Evaluate the model on the test datasets.

    Args:
        model (GPT2Model): The model to evaluate.
        name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.

    Returns:
        dict: A dictionary containing the evaluation results for each task.
    """
    report = {}
    if name is not None:
        report["name"] = name
    for task_name in (pbar := tqdm(self._test_datasets, desc="Evaluating tasks")):
        pbar.set_description(f"Evaluating task {task_name}")
        dataloader = self.get_test_dataloader(task_name)
        result = self.evaluate_single_task(task_name, model, dataloader)
        report[task_name] = result

    # calculate the average accuracy and loss
    if "average" not in report:
        report["average"] = {}
        accuracies = [
            value["accuracy"]
            for key, value in report.items()
            if isinstance(value, dict) and "accuracy" in value
        ]
        if len(accuracies) > 0:
            average_accuracy = sum(accuracies) / len(accuracies)
            report["average"]["accuracy"] = average_accuracy
        losses = [value["loss"] for key, value in report.items() if "loss" in value]
        if len(losses) > 0:
            average_loss = sum(losses) / len(losses)
            report["average"]["loss"] = average_loss

    log.info(f"Evaluation Result: {report}")
    return report
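
A hedged construction sketch: `test_datasets` and `tokenizer` are Hydra-style config nodes (the `_target_` value for the dataset below is hypothetical), and `evaluate` additionally expects the FusionBench program's model pool to provide `load_classifier` for each task:

```python
from omegaconf import OmegaConf

from fusion_bench.taskpool.gpt2_text_classification import GPT2TextClassificationTaskPool

cfg = OmegaConf.create(
    {
        "test_datasets": {
            # hypothetical dataset target; replace with your own
            "cola": {"_target_": "my_project.data.load_cola_test_dataset"},
        },
        "tokenizer": {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "gpt2",
        },
        "dataloader_kwargs": {"batch_size": 32, "num_workers": 4},
        "fast_dev_run": False,
    }
)
taskpool = GPT2TextClassificationTaskPool(**cfg)
```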

Flan-T5

fusion_bench.compat.taskpool.flan_t5_glue_text_generation.FlanT5GLUETextGenerationTask

Bases: BaseTask

Source code in fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py
class FlanT5GLUETextGenerationTask(BaseTask):
    _taskpool: "FlanT5GLUETextGenerationTaskPool" = None

    @property
    def taskpool(self):
        if self._taskpool is not None:
            return self._taskpool
        else:
            raise ValueError("Taskpool not set")

    @property
    def fabric(self):
        return self.taskpool.fabric

    @property
    def tokenizer(self):
        return self.taskpool.tokenizer

    @functools.cached_property
    def dataset(self):
        log.info(f'Loading dataset: "{self.config.dataset.name}"')
        dataset = load_glue_dataset(
            self.config.dataset.name, self.tokenizer, self.taskpool.config.cache_dir
        )
        return dataset

    @functools.cached_property
    def test_dataset(self):
        return self.dataset[self.config.dataset.split]

    @property
    def test_loader(self):
        loader = DataLoader(
            self.test_dataset,
            batch_size=self.taskpool.config.batch_size,
            num_workers=self.taskpool.config.num_workers,
            shuffle=False,
            collate_fn=default_data_collator,
        )
        loader = self.fabric.setup_dataloaders(loader)
        return loader

LM-Eval-Harness Integration (LLM)

LMEvalHarnessTaskPool

Bases: BaseTaskPool, LightningFabricMixin

Source code in fusion_bench/taskpool/lm_eval_harness/taskpool.py
class LMEvalHarnessTaskPool(BaseTaskPool, LightningFabricMixin):
    def __init__(
        self,
        tasks: Union[str, List[str]],
        apply_chat_template: bool = False,
        include_path: Optional[str] = None,
        batch_size: int = 1,
        metadata: Optional[DictConfig] = None,
        verbosity: Optional[
            Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]
        ] = None,
        output_path: Optional[str] = None,
        log_samples: bool = False,
        _usage_: Optional[str] = None,
        _version_: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(_usage_=_usage_, _version_=_version_)
        self.tasks = tasks
        self.include_path = include_path
        self.batch_size = batch_size
        self.metadata = metadata
        self.apply_chat_template = apply_chat_template
        self.verbosity = verbosity
        self.kwargs = kwargs
        self.output_path = output_path
        self.log_samples = log_samples

    def evaluate(self, model, *command_line_args, **kwargs):
        command_line_args = []
        if self.include_path is not None:
            command_line_args.extend(["--include_path", self.include_path])
        if isinstance(self.tasks, (list, ListConfig)):
            command_line_args.extend(["--tasks", ",".join(self.tasks)])
        elif isinstance(self.tasks, str):
            command_line_args.extend(["--tasks", self.tasks])
        if self.apply_chat_template:
            command_line_args.extend(
                ["--apply_chat_template", str(self.apply_chat_template)]
            )
        if self.batch_size is not None:
            command_line_args.extend(["--batch_size", str(self.batch_size)])
        if self.verbosity is not None:
            command_line_args.extend(["--verbosity", str(self.verbosity)])
        if self.metadata is not None:
            command_line_args.extend(["--metadata", str(self.metadata)])
        if self.output_path is None:
            command_line_args.extend(
                [
                    "--output_path",
                    os.path.join(self.log_dir, "lm_eval_results"),
                ]
            )
        else:
            command_line_args.extend(["--output_path", self.output_path])
        if self.log_samples:
            command_line_args.extend(["--log_samples"])
        for key, value in kwargs.items():
            command_line_args.extend([f"--{key}", str(value)])

        parser = setup_parser()
        check_argument_types(parser)
        args = parser.parse_args(args=command_line_args)
        log.info("LM-Eval Harness arguments:\n%s", args)

        if not lightning.fabric.is_wrapped(model):
            model = self.fabric.setup(model)
        args.model = lm_eval.models.huggingface.HFLM(pretrained=model)
        cli_evaluate(args)
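
A hedged usage sketch: the pool wraps the given model in lm-eval-harness's `HFLM` and delegates to its CLI entry point, so the model should be a Hugging Face causal language model (the model name and task below are placeholders):

```python
from transformers import AutoModelForCausalLM

from fusion_bench.taskpool.lm_eval_harness.taskpool import LMEvalHarnessTaskPool

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder for a merged LLM

taskpool = LMEvalHarnessTaskPool(
    tasks=["hellaswag"],            # any task name known to lm-eval-harness
    batch_size=8,
    output_path="outputs/lm_eval",  # otherwise results go under the Fabric log directory
)
taskpool.evaluate(model)
```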

Task Agnostic

Utility Classes

DummyTaskPool

Bases: BaseTaskPool

This is a dummy task pool used for debugging purposes. It inherits from the BaseTaskPool class.

Source code in fusion_bench/taskpool/dummy.py
class DummyTaskPool(BaseTaskPool):
    """
    This is a dummy task pool used for debugging purposes. It inherits from the BaseTaskPool class.
    """

    def __init__(self, model_save_path: Optional[str] = None):
        super().__init__()
        self.model_save_path = model_save_path

    def evaluate(self, model):
        """
        Evaluate the given model.
        This method prints the model's parameters in a human-readable format, optionally saves the model, and returns a model summary.

        Args:
            model: The model to evaluate.
        """
        if rank_zero_only.rank == 0:
            print_parameters(model, is_human_readable=True)

            if self.model_save_path is not None:
                with timeit_context(f"Saving the model to {self.model_save_path}"):
                    separate_save(model, self.model_save_path)

        return get_model_summary(model)
evaluate(model)

Evaluate the given model. This method prints the model's parameters in a human-readable format, optionally saves the model, and returns a model summary.

Parameters:

  • model –

    The model to evaluate.

Source code in fusion_bench/taskpool/dummy.py
def evaluate(self, model):
    """
    Evaluate the given model.
    This method prints the model's parameters in a human-readable format, optionally saves the model, and returns a model summary.

    Args:
        model: The model to evaluate.
    """
    if rank_zero_only.rank == 0:
        print_parameters(model, is_human_readable=True)

        if self.model_save_path is not None:
            with timeit_context(f"Saving the model to {self.model_save_path}"):
                separate_save(model, self.model_save_path)

    return get_model_summary(model)
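
A hedged usage sketch of the dummy pool for inspecting a merged model without running any benchmark (the model below is a stand-in):

```python
import torch.nn as nn

from fusion_bench.taskpool.dummy import DummyTaskPool

model = nn.Linear(16, 4)  # stand-in for a merged model
taskpool = DummyTaskPool(model_save_path=None)
summary = taskpool.evaluate(model)  # prints parameter counts and returns a model summary
```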