
Model Training/Fine-Tuning

CLIP vision model fine-tuning

ImageClassificationFineTuning

Bases: BaseAlgorithm

Fine-tuning algorithm for image classification models.

This class implements end-to-end fine-tuning for image classification tasks using PyTorch Lightning. It supports both epoch-based and step-based training with configurable optimizers, learning rate schedulers, and data loaders.

Parameters:

  • max_epochs (Optional[int]) –

    Maximum number of training epochs. Mutually exclusive with max_steps.

  • max_steps (Optional[int]) –

    Maximum number of training steps. Mutually exclusive with max_epochs.

  • label_smoothing (float) –

    Label smoothing factor for cross-entropy loss (0.0 = no smoothing).

  • optimizer (DictConfig) –

    Configuration for the optimizer (e.g., Adam, SGD).

  • lr_scheduler (DictConfig) –

    Configuration for the learning rate scheduler.

  • dataloader_kwargs (DictConfig) –

    Additional keyword arguments for DataLoader construction.

  • **kwargs

    Additional arguments passed to the base class.

Raises:

  • AssertionError

    If both max_epochs and max_steps are provided.

Example
>>> config = {
...     'max_epochs': 10,
...     'max_steps': None,
...     'label_smoothing': 0.1,
...     'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001},
...     'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
...     'dataloader_kwargs': {'batch_size': 32, 'num_workers': 4}
... }
>>> algorithm = ImageClassificationFineTuning(**config)
Source code in fusion_bench/method/classification/image_classification_finetune.py
@auto_register_config
class ImageClassificationFineTuning(BaseAlgorithm):
    """Fine-tuning algorithm for image classification models.

    This class implements end-to-end fine-tuning for image classification tasks using PyTorch Lightning.
    It supports both epoch-based and step-based training with configurable optimizers, learning rate
    schedulers, and data loaders.

    Args:
        max_epochs (Optional[int]): Maximum number of training epochs. Mutually exclusive with max_steps.
        max_steps (Optional[int]): Maximum number of training steps. Mutually exclusive with max_epochs.
        label_smoothing (float): Label smoothing factor for cross-entropy loss (0.0 = no smoothing).
        optimizer (DictConfig): Configuration for the optimizer (e.g., Adam, SGD).
        lr_scheduler (DictConfig): Configuration for the learning rate scheduler.
        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
        **kwargs: Additional arguments passed to the base class.

    Raises:
        AssertionError: If both max_epochs and max_steps are provided.

    Example:
        ```python
        >>> config = {
        ...     'max_epochs': 10,
        ...     'max_steps': None,
        ...     'label_smoothing': 0.1,
        ...     'optimizer': {'_target_': 'torch.optim.Adam', 'lr': 0.001},
        ...     'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
        ...     'dataloader_kwargs': {'batch_size': 32, 'num_workers': 4}
        ... }
        >>> algorithm = ImageClassificationFineTuning(**config)
        ```
    """

    def __init__(
        self,
        max_epochs: Optional[int],
        max_steps: Optional[int],
        training_data_ratio: Optional[float],
        label_smoothing: float,
        optimizer: DictConfig,
        lr_scheduler: DictConfig,
        dataloader_kwargs: DictConfig,
        save_top_k: int,
        save_interval: int,
        save_on_train_epoch_end: bool,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert (max_epochs is None or max_epochs < 0) or (
            max_steps is None or max_steps < 0
        ), "Only one of max_epochs or max_steps should be set."
        self.training_interval = (
            "epoch" if max_epochs is not None and max_epochs > 0 else "step"
        )
        if self.training_interval == "epoch":
            self.max_steps = -1
        log.info(f"Training interval: {self.training_interval}")
        log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")

    def run(self, modelpool: ResNetForImageClassificationPool):
        """Execute the fine-tuning process on the provided model pool.

        This method performs the complete fine-tuning workflow:
        1. Loads the pretrained model from the model pool
        2. Prepares training and validation datasets
        3. Configures optimizer and learning rate scheduler
        4. Sets up Lightning trainer with appropriate callbacks
        5. Executes the training process
        6. Saves the final fine-tuned model
        """
        # load model and dataset
        model = modelpool.load_pretrained_or_first_model()
        base_model_name = _get_base_model_name(model)

        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."

        assert (
            len(modelpool.train_dataset_names) == 1
        ), "Exactly one training dataset is required."
        self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
        num_classes = get_num_classes(dataset_name)
        log.info(f"Number of classes for dataset {dataset_name}: {num_classes}")
        train_dataset = modelpool.load_train_dataset(dataset_name)
        log.info(f"Training dataset size: {len(train_dataset)}")
        if self.training_data_ratio is not None and 0 < self.training_data_ratio < 1:
            train_dataset, _ = random_split(
                train_dataset,
                lengths=[self.training_data_ratio, 1 - self.training_data_ratio],
            )
            log.info(
                f"Using {len(train_dataset)} samples for training after applying training_data_ratio={self.training_data_ratio}."
            )
        train_dataset = CLIPDataset(
            train_dataset, processor=modelpool.load_processor(stage="train")
        )
        train_loader = self.get_dataloader(train_dataset, stage="train")
        if modelpool.has_val_dataset:
            val_dataset = modelpool.load_val_dataset(dataset_name)
            val_dataset = CLIPDataset(
                val_dataset, processor=modelpool.load_processor(stage="val")
            )
            val_loader = self.get_dataloader(val_dataset, stage="val")
        else:
            val_loader = None

        # configure optimizer
        optimizer = instantiate(self.optimizer, params=model.parameters())
        if self.lr_scheduler is not None:
            lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
            optimizer = {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "interval": self.training_interval,
                    "frequency": 1,
                },
            }
        log.info(f"optimizer:\n{optimizer}")

        lit_module = ERM_LitModule(
            model,
            optimizer,
            objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
            metrics={
                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
                f"acc@{min(5,num_classes)}": Accuracy(
                    task="multiclass",
                    num_classes=num_classes,
                    top_k=min(5, num_classes),
                ),
            },
        )
        lit_module.train()

        log_dir = (
            self._program.path.log_dir
            if self._program is not None
            else "outputs/lightning_logs"
        )
        trainer = L.Trainer(
            max_epochs=self.max_epochs,
            max_steps=self.max_steps,
            accelerator="auto",
            devices="auto",
            callbacks=[
                pl_callbacks.LearningRateMonitor(logging_interval="step"),
                pl_callbacks.DeviceStatsMonitor(),
                pl_callbacks.ModelCheckpoint(
                    save_top_k=self.save_top_k,
                    every_n_train_steps=(
                        self.save_interval if self.training_interval == "step" else None
                    ),
                    every_n_epochs=(
                        self.save_interval
                        if self.training_interval == "epoch"
                        else None
                    ),
                    save_on_train_epoch_end=self.save_on_train_epoch_end,
                    save_last=True,
                ),
            ],
            logger=TensorBoardLogger(save_dir=log_dir, name="", version=""),
            fast_dev_run=RuntimeConstants.debug,
        )

        trainer.fit(
            lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
        )
        model = lit_module.model
        if rank_zero_only.rank == 0:
            log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
            modelpool.save_model(
                model,
                path=os.path.join(
                    trainer.log_dir if trainer.log_dir is not None else log_dir,
                    "raw_checkpoints",
                    "final",
                ),
                algorithm_config=self.config,
                description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
                base_model=base_model_name,
            )
        return model

    def get_dataloader(self, dataset, stage: str):
        """Create a DataLoader for the specified dataset and training stage.

        Constructs a PyTorch DataLoader with stage-appropriate configurations:
        - Training stage: shuffling enabled by default
        - Validation/test stages: shuffling disabled by default

        Args:
            dataset: The dataset to wrap in a DataLoader.
            stage (str): Training stage, must be one of "train", "val", or "test".
                Determines default shuffling behavior.

        Returns:
            DataLoader: Configured DataLoader for the given dataset and stage.
        """
        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
        dataloader_kwargs = dict(self.dataloader_kwargs)
        if "shuffle" not in dataloader_kwargs:
            dataloader_kwargs["shuffle"] = stage == "train"
        return DataLoader(dataset, **dataloader_kwargs)
get_dataloader(dataset, stage)

Create a DataLoader for the specified dataset and training stage.

Constructs a PyTorch DataLoader with stage-appropriate configurations:

  • Training stage: shuffling enabled by default
  • Validation/test stages: shuffling disabled by default

Parameters:

  • dataset

    The dataset to wrap in a DataLoader.

  • stage (str) –

    Training stage, must be one of "train", "val", or "test". Determines default shuffling behavior.

Returns:

  • DataLoader

    Configured DataLoader for the given dataset and stage.
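
Example (hypothetical usage sketch; algorithm, train_dataset, and val_dataset are assumed to have been created as in the class-level example above):
>>> train_loader = algorithm.get_dataloader(train_dataset, stage="train")  # shuffle defaults to True
>>> val_loader = algorithm.get_dataloader(val_dataset, stage="val")        # shuffle defaults to False
>>> # an explicit 'shuffle' key in dataloader_kwargs always overrides the stage-based default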

Source code in fusion_bench/method/classification/image_classification_finetune.py
def get_dataloader(self, dataset, stage: str):
    """Create a DataLoader for the specified dataset and training stage.

    Constructs a PyTorch DataLoader with stage-appropriate configurations:
    - Training stage: shuffling enabled by default
    - Validation/test stages: shuffling disabled by default

    Args:
        dataset: The dataset to wrap in a DataLoader.
        stage (str): Training stage, must be one of "train", "val", or "test".
            Determines default shuffling behavior.

    Returns:
        DataLoader: Configured DataLoader for the given dataset and stage.
    """
    assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
    dataloader_kwargs = dict(self.dataloader_kwargs)
    if "shuffle" not in dataloader_kwargs:
        dataloader_kwargs["shuffle"] = stage == "train"
    return DataLoader(dataset, **dataloader_kwargs)
run(modelpool)

Execute the fine-tuning process on the provided model pool.

This method performs the complete fine-tuning workflow:

  1. Loads the pretrained model from the model pool
  2. Prepares training and validation datasets
  3. Configures optimizer and learning rate scheduler
  4. Sets up Lightning trainer with appropriate callbacks
  5. Executes the training process
  6. Saves the final fine-tuned model
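
A minimal end-to-end sketch (modelpool is assumed to be a configured ResNetForImageClassificationPool created elsewhere, e.g. from a fusion_bench/Hydra config; all values are illustrative):
>>> algorithm = ImageClassificationFineTuning(
...     max_epochs=10,
...     max_steps=None,
...     training_data_ratio=None,
...     label_smoothing=0.1,
...     optimizer={'_target_': 'torch.optim.Adam', 'lr': 0.001},
...     lr_scheduler={'_target_': 'torch.optim.lr_scheduler.StepLR', 'step_size': 5},
...     dataloader_kwargs={'batch_size': 32, 'num_workers': 4},
...     save_top_k=1,
...     save_interval=1,
...     save_on_train_epoch_end=True,
... )
>>> finetuned_model = algorithm.run(modelpool)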

Source code in fusion_bench/method/classification/image_classification_finetune.py
def run(self, modelpool: ResNetForImageClassificationPool):
    """Execute the fine-tuning process on the provided model pool.

    This method performs the complete fine-tuning workflow:
    1. Loads the pretrained model from the model pool
    2. Prepares training and validation datasets
    3. Configures optimizer and learning rate scheduler
    4. Sets up Lightning trainer with appropriate callbacks
    5. Executes the training process
    6. Saves the final fine-tuned model
    """
    # load model and dataset
    model = modelpool.load_pretrained_or_first_model()
    base_model_name = _get_base_model_name(model)

    assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."

    assert (
        len(modelpool.train_dataset_names) == 1
    ), "Exactly one training dataset is required."
    self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
    num_classes = get_num_classes(dataset_name)
    log.info(f"Number of classes for dataset {dataset_name}: {num_classes}")
    train_dataset = modelpool.load_train_dataset(dataset_name)
    log.info(f"Training dataset size: {len(train_dataset)}")
    if self.training_data_ratio is not None and 0 < self.training_data_ratio < 1:
        train_dataset, _ = random_split(
            train_dataset,
            lengths=[self.training_data_ratio, 1 - self.training_data_ratio],
        )
        log.info(
            f"Using {len(train_dataset)} samples for training after applying training_data_ratio={self.training_data_ratio}."
        )
    train_dataset = CLIPDataset(
        train_dataset, processor=modelpool.load_processor(stage="train")
    )
    train_loader = self.get_dataloader(train_dataset, stage="train")
    if modelpool.has_val_dataset:
        val_dataset = modelpool.load_val_dataset(dataset_name)
        val_dataset = CLIPDataset(
            val_dataset, processor=modelpool.load_processor(stage="val")
        )
        val_loader = self.get_dataloader(val_dataset, stage="val")
    else:
        val_loader = None

    # configure optimizer
    optimizer = instantiate(self.optimizer, params=model.parameters())
    if self.lr_scheduler is not None:
        lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
        optimizer = {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": self.training_interval,
                "frequency": 1,
            },
        }
    log.info(f"optimizer:\n{optimizer}")

    lit_module = ERM_LitModule(
        model,
        optimizer,
        objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
        metrics={
            "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
            f"acc@{min(5,num_classes)}": Accuracy(
                task="multiclass",
                num_classes=num_classes,
                top_k=min(5, num_classes),
            ),
        },
    )
    lit_module.train()

    log_dir = (
        self._program.path.log_dir
        if self._program is not None
        else "outputs/lightning_logs"
    )
    trainer = L.Trainer(
        max_epochs=self.max_epochs,
        max_steps=self.max_steps,
        accelerator="auto",
        devices="auto",
        callbacks=[
            pl_callbacks.LearningRateMonitor(logging_interval="step"),
            pl_callbacks.DeviceStatsMonitor(),
            pl_callbacks.ModelCheckpoint(
                save_top_k=self.save_top_k,
                every_n_train_steps=(
                    self.save_interval if self.training_interval == "step" else None
                ),
                every_n_epochs=(
                    self.save_interval
                    if self.training_interval == "epoch"
                    else None
                ),
                save_on_train_epoch_end=self.save_on_train_epoch_end,
                save_last=True,
            ),
        ],
        logger=TensorBoardLogger(save_dir=log_dir, name="", version=""),
        fast_dev_run=RuntimeConstants.debug,
    )

    trainer.fit(
        lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
    )
    model = lit_module.model
    if rank_zero_only.rank == 0:
        log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
        modelpool.save_model(
            model,
            path=os.path.join(
                trainer.log_dir if trainer.log_dir is not None else log_dir,
                "raw_checkpoints",
                "final",
            ),
            algorithm_config=self.config,
            description=f"Fine-tuned ResNet model on dataset {dataset_name}.",
            base_model=base_model_name,
        )
    return model

ImageClassificationFineTuning_Test

Bases: BaseAlgorithm

Test/evaluation algorithm for fine-tuned image classification models.

This class implements model evaluation on test or validation datasets using PyTorch Lightning. It can either evaluate a model directly or load a model from a checkpoint before evaluation. The evaluation computes standard classification metrics including top-1 and top-5 accuracy.

Parameters:

  • checkpoint_path (str) –

    Path to the model checkpoint file. If None, uses the model directly from the model pool without loading from checkpoint.

  • dataloader_kwargs (DictConfig) –

    Additional keyword arguments for DataLoader construction.

  • **kwargs

    Additional arguments passed to the base class.

Example
>>> config = {
...     'checkpoint_path': '/path/to/model/checkpoint.ckpt',
...     'dataloader_kwargs': {'batch_size': 64, 'num_workers': 4}
... }
>>> test_algorithm = ImageClassificationFineTuning_Test(**config)
Source code in fusion_bench/method/classification/image_classification_finetune.py
@auto_register_config
class ImageClassificationFineTuning_Test(BaseAlgorithm):
    """Test/evaluation algorithm for fine-tuned image classification models.

    This class implements model evaluation on test or validation datasets using PyTorch Lightning.
    It can either evaluate a model directly or load a model from a checkpoint before evaluation.
    The evaluation computes standard classification metrics including top-1 and top-5 accuracy.

    Args:
        checkpoint_path (str): Path to the model checkpoint file. If None, uses the model
            directly from the model pool without loading from checkpoint.
        dataloader_kwargs (DictConfig): Additional keyword arguments for DataLoader construction.
        **kwargs: Additional arguments passed to the base class.

    Example:
        ```python
        >>> config = {
        ...     'checkpoint_path': '/path/to/model/checkpoint.ckpt',
        ...     'dataloader_kwargs': {'batch_size': 64, 'num_workers': 4}
        ... }
        >>> test_algorithm = ImageClassificationFineTuning_Test(**config)
        ```
    """

    def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
        super().__init__(**kwargs)

    def run(self, modelpool: ResNetForImageClassificationPool):
        """Execute model evaluation on the provided model pool's test/validation dataset.

        This method performs the complete evaluation workflow:
        1. Loads the model from the model pool (pretrained or first available)
        2. Prepares the test or validation dataset (prioritizes test if both available)
        3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
        4. Loads from checkpoint if specified, otherwise uses the model directly
        5. Executes the evaluation using Lightning trainer
        6. Logs and returns the test metrics
        """
        assert (
            modelpool.has_val_dataset or modelpool.has_test_dataset
        ), "No validation or test dataset found in the model pool."

        # load model and dataset
        model = modelpool.load_pretrained_or_first_model()
        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."

        if modelpool.has_test_dataset:
            assert (
                len(modelpool.test_dataset_names) == 1
            ), "Exactly one test dataset is required."
            self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
            dataset = modelpool.load_test_dataset(dataset_name)
            dataset = CLIPDataset(
                dataset, processor=modelpool.load_processor(stage="test")
            )
        else:
            assert (
                len(modelpool.val_dataset_names) == 1
            ), "Exactly one validation dataset is required."
            self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
            dataset = modelpool.load_val_dataset(dataset_name)
            dataset = CLIPDataset(
                dataset, processor=modelpool.load_processor(stage="test")
            )
        num_classes = get_num_classes(dataset_name)

        test_loader = self.get_dataloader(dataset, stage="test")

        if self.checkpoint_path is None:
            lit_module = ERM_LitModule(
                model,
                metrics={
                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
                    f"acc@{min(5,num_classes)}": Accuracy(
                        task="multiclass",
                        num_classes=num_classes,
                        top_k=min(5, num_classes),
                    ),
                },
            )
        else:
            lit_module = ERM_LitModule.load_from_checkpoint(
                checkpoint_path=self.checkpoint_path,
                model=model,
                metrics={
                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
                    f"acc@{min(5,num_classes)}": Accuracy(
                        task="multiclass",
                        num_classes=num_classes,
                        top_k=min(5, num_classes),
                    ),
                },
            )

        trainer = L.Trainer(
            devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
        )

        test_metrics = trainer.test(lit_module, dataloaders=test_loader)
        log.info(f"Test metrics: {test_metrics}")
        return model

    def get_dataloader(self, dataset, stage: str):
        """Create a DataLoader for the specified dataset and evaluation stage.

        Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation.
        Similar to the training version but typically used for test/validation datasets.

        Args:
            dataset: The dataset to wrap in a DataLoader.
            stage (str): Evaluation stage, must be one of "train", "val", or "test".
                Determines default shuffling behavior (disabled for non-train stages).

        Returns:
            DataLoader: Configured DataLoader for the given dataset and stage.
        """
        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
        dataloader_kwargs = dict(self.dataloader_kwargs)
        if "shuffle" not in dataloader_kwargs:
            dataloader_kwargs["shuffle"] = stage == "train"
        return DataLoader(dataset, **dataloader_kwargs)
get_dataloader(dataset, stage)

Create a DataLoader for the specified dataset and evaluation stage.

Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation. Similar to the training version but typically used for test/validation datasets.

Parameters:

  • dataset

    The dataset to wrap in a DataLoader.

  • stage (str) –

    Evaluation stage, must be one of "train", "val", or "test". Determines default shuffling behavior (disabled for non-train stages).

Returns:

  • DataLoader

    Configured DataLoader for the given dataset and stage.

Source code in fusion_bench/method/classification/image_classification_finetune.py
def get_dataloader(self, dataset, stage: str):
    """Create a DataLoader for the specified dataset and evaluation stage.

    Constructs a PyTorch DataLoader with stage-appropriate configurations for evaluation.
    Similar to the training version but typically used for test/validation datasets.

    Args:
        dataset: The dataset to wrap in a DataLoader.
        stage (str): Evaluation stage, must be one of "train", "val", or "test".
            Determines default shuffling behavior (disabled for non-train stages).

    Returns:
        DataLoader: Configured DataLoader for the given dataset and stage.
    """
    assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
    dataloader_kwargs = dict(self.dataloader_kwargs)
    if "shuffle" not in dataloader_kwargs:
        dataloader_kwargs["shuffle"] = stage == "train"
    return DataLoader(dataset, **dataloader_kwargs)
run(modelpool)

Execute model evaluation on the provided model pool's test/validation dataset.

This method performs the complete evaluation workflow:

  1. Loads the model from the model pool (pretrained or first available)
  2. Prepares the test or validation dataset (prioritizes test if both available)
  3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
  4. Loads from checkpoint if specified, otherwise uses the model directly
  5. Executes the evaluation using Lightning trainer
  6. Logs and returns the test metrics
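
A minimal evaluation sketch (modelpool is assumed to expose exactly one test or validation dataset; set checkpoint_path to None to evaluate the pool's model as-is):
>>> test_algorithm = ImageClassificationFineTuning_Test(
...     checkpoint_path='/path/to/model/checkpoint.ckpt',  # or None
...     dataloader_kwargs={'batch_size': 64, 'num_workers': 4},
... )
>>> model = test_algorithm.run(modelpool)  # logs test metrics (acc@1, acc@5) and returns the model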

Source code in fusion_bench/method/classification/image_classification_finetune.py
def run(self, modelpool: ResNetForImageClassificationPool):
    """Execute model evaluation on the provided model pool's test/validation dataset.

    This method performs the complete evaluation workflow:
    1. Loads the model from the model pool (pretrained or first available)
    2. Prepares the test or validation dataset (prioritizes test if both available)
    3. Sets up the Lightning module with appropriate metrics (top-1 and top-5 accuracy)
    4. Loads from checkpoint if specified, otherwise uses the model directly
    5. Executes the evaluation using Lightning trainer
    6. Logs and returns the test metrics
    """
    assert (
        modelpool.has_val_dataset or modelpool.has_test_dataset
    ), "No validation or test dataset found in the model pool."

    # load model and dataset
    model = modelpool.load_pretrained_or_first_model()
    assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."

    if modelpool.has_test_dataset:
        assert (
            len(modelpool.test_dataset_names) == 1
        ), "Exactly one test dataset is required."
        self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
        dataset = modelpool.load_test_dataset(dataset_name)
        dataset = CLIPDataset(
            dataset, processor=modelpool.load_processor(stage="test")
        )
    else:
        assert (
            len(modelpool.val_dataset_names) == 1
        ), "Exactly one validation dataset is required."
        self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
        dataset = modelpool.load_val_dataset(dataset_name)
        dataset = CLIPDataset(
            dataset, processor=modelpool.load_processor(stage="test")
        )
    num_classes = get_num_classes(dataset_name)

    test_loader = self.get_dataloader(dataset, stage="test")

    if self.checkpoint_path is None:
        lit_module = ERM_LitModule(
            model,
            metrics={
                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
                f"acc@{min(5,num_classes)}": Accuracy(
                    task="multiclass",
                    num_classes=num_classes,
                    top_k=min(5, num_classes),
                ),
            },
        )
    else:
        lit_module = ERM_LitModule.load_from_checkpoint(
            checkpoint_path=self.checkpoint_path,
            model=model,
            metrics={
                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
                f"acc@{min(5,num_classes)}": Accuracy(
                    task="multiclass",
                    num_classes=num_classes,
                    top_k=min(5, num_classes),
                ),
            },
        )

    trainer = L.Trainer(
        devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
    )

    test_metrics = trainer.test(lit_module, dataloaders=test_loader)
    log.info(f"Test metrics: {test_metrics}")
    return model

ImageClassificationFineTuningForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, ModelFusionAlgorithm

A class for fine-tuning CLIP models for image classification tasks.

Source code in fusion_bench/method/classification/clip_finetune.py
class ImageClassificationFineTuningForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """
    A class for fine-tuning CLIP models for image classification tasks.
    """

    def run(self, modelpool: CLIPVisionModelPool):
        """
        Executes the fine-tuning process.

        Args:
            modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.

        Returns:
            VisionModel: The fine-tuned vision model.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")
        self.finetune_method = "fine-tune"

        L.seed_everything(config.seed)

        task_names = modelpool.train_dataset_names
        with self.profile("setup model and optimizer"):
            processor, classifier, optimizer, lr_scheduler = self.setup_model()

            if config.state_dict_load_path is not None:
                self.fabric.load(
                    config.state_dict_load_path,
                    {"vision_model": classifier.clip_model.vision_model},
                )
                if config.skip_training:
                    return classifier.clip_model.vision_model

            self.setup_zero_shot_classification_head(
                clip_processor=processor,
                clip_model=classifier.clip_model,
                task_names=task_names,
            )

            self.fabric.setup(classifier, optimizer)

        with self.profile("setup data"):
            train_datasets = [
                CLIPDataset(modelpool.load_train_dataset(task_name), processor)
                for task_name in task_names
            ]
            train_dataloaders = [
                DataLoader(
                    dataset,
                    shuffle=True,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                )
                for dataset in train_datasets
            ]
            train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
            if not isinstance(train_dataloaders, (list, tuple)):
                train_dataloaders = [train_dataloaders]
            train_dataloader_iters = [
                iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
            ]

        # train
        for step_idx in tqdm(
            range(config.num_steps),
            desc=self.finetune_method,
            disable=not self.fabric.is_global_zero,
            dynamic_ncols=True,
        ):
            optimizer.zero_grad()
            loss = 0
            for task, loader in zip(task_names, train_dataloader_iters):
                with self.profile("data loading"):
                    batch = next(loader)
                    images, labels = batch
                with self.profile("forward"):
                    classifier.zeroshot_weights = self.zeroshot_weights[task]
                    logits = classifier(images)
                loss = loss + nn.functional.cross_entropy(logits, labels)

            with self.profile("backward"):
                self.fabric.backward(loss)
            with self.profile("optimizer step"):
                optimizer.step()
                lr_scheduler.step()

            metrics = {"train/loss": loss}

            self.fabric.log_dict(metrics, step=step_idx)

            if (step_idx + 1) % config.save_interval == 0:
                save_path = os.path.join(
                    self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
                )
                self.save_model(classifier, save_path)

        if config.state_dict_save_path is not None:
            self.save_model(classifier, config.state_dict_save_path)
        self.print_profile_summary()
        return classifier.clip_model.vision_model

    def save_model(
        self,
        model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
        save_path: str,
    ):
        """
        Save the vision model to the specified path.

        Args:
            model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
            save_path (str): The path to save the model.
        """
        if isinstance(model, HFCLIPClassifier):
            vision_model = model.clip_model.vision_model
        elif isinstance(model, CLIPModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionTransformer):
            vision_model = model
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        self.fabric.save(save_path, {"vision_model": vision_model})

    def setup_model(self):
        """
        Sets up the model, optimizer, and learning rate scheduler.

        This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

        Returns:
            Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
        """
        config = self.config
        modelpool = self.modelpool

        clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
        processor = modelpool.load_processor()

        self.finetune_method = "full fine-tune"
        if config.use_lora or config.use_l_lora:
            self.finetune_method = "lora fine-tune"
            lora_config = LoraConfig(
                **OmegaConf.to_container(
                    config.lora_config, resolve=True, enum_to_str=True
                )
            )
            clip_model.vision_model = get_peft_model(
                clip_model.vision_model, lora_config
            )

            if config.use_l_lora:
                # http://arxiv.org/abs/2310.04742
                # Anke Tang et al. Parameter Efficient Multi-task Model Fusion with Partial Linearization. ICLR 2024.
                self.finetune_method = "l-lora fine-tune"
                print("Linearizing Lora Layers")
                linearize_lora_model_(clip_model.vision_model)

        classifier = HFCLIPClassifier(clip_model, processor=processor)

        if self.fabric.is_global_zero:
            print("=== Model Summary (For Vision Model Only) ===")
            print_parameters(classifier.clip_model.vision_model)
        # configure optimizers
        optimizer = torch.optim.Adam(
            [
                p
                for p in classifier.clip_model.vision_model.parameters()
                if p.requires_grad
            ],
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=config.num_steps
        )

        return processor, classifier, optimizer, lr_scheduler
run(modelpool)

Executes the fine-tuning process.

Parameters:

  • modelpool (CLIPVisionModelPool) –

    The modelpool is responsible for loading the pre-trained model and training datasets.

Returns:

  • VisionModel

    The fine-tuned vision model.
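
This method reads its hyperparameters from self.config; the sketch below collects the fields consumed by run and setup_model into an OmegaConf DictConfig (values are illustrative, not repository defaults):
>>> from omegaconf import DictConfig
>>> method_config = DictConfig({
...     'seed': 42,
...     'learning_rate': 1e-5,
...     'weight_decay': 0.0,
...     'num_steps': 4000,
...     'batch_size': 128,
...     'num_workers': 16,
...     'save_interval': 500,
...     'state_dict_load_path': None,
...     'state_dict_save_path': None,
...     'skip_training': False,
...     'use_lora': False,
...     'use_l_lora': False,
...     'lora_config': None,
... })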

Source code in fusion_bench/method/classification/clip_finetune.py
def run(self, modelpool: CLIPVisionModelPool):
    """
    Executes the fine-tuning process.

    Args:
        modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.

    Returns:
        VisionModel: The fine-tuned vision model.
    """
    self.modelpool = to_modelpool(modelpool)
    config = self.config
    self.log_hyperparams(config, filename="method_config.yaml")
    self.finetune_method = "fine-tune"

    L.seed_everything(config.seed)

    task_names = modelpool.train_dataset_names
    with self.profile("setup model and optimizer"):
        processor, classifier, optimizer, lr_scheduler = self.setup_model()

        if config.state_dict_load_path is not None:
            self.fabric.load(
                config.state_dict_load_path,
                {"vision_model": classifier.clip_model.vision_model},
            )
            if config.skip_training:
                return classifier.clip_model.vision_model

        self.setup_zero_shot_classification_head(
            clip_processor=processor,
            clip_model=classifier.clip_model,
            task_names=task_names,
        )

        self.fabric.setup(classifier, optimizer)

    with self.profile("setup data"):
        train_datasets = [
            CLIPDataset(modelpool.load_train_dataset(task_name), processor)
            for task_name in task_names
        ]
        train_dataloaders = [
            DataLoader(
                dataset,
                shuffle=True,
                batch_size=config.batch_size,
                num_workers=config.num_workers,
            )
            for dataset in train_datasets
        ]
        train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
        if not isinstance(train_dataloaders, (list, tuple)):
            train_dataloaders = [train_dataloaders]
        train_dataloader_iters = [
            iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
        ]

    # train
    for step_idx in tqdm(
        range(config.num_steps),
        desc=self.finetune_method,
        disable=not self.fabric.is_global_zero,
        dynamic_ncols=True,
    ):
        optimizer.zero_grad()
        loss = 0
        for task, loader in zip(task_names, train_dataloader_iters):
            with self.profile("data loading"):
                batch = next(loader)
                images, labels = batch
            with self.profile("forward"):
                classifier.zeroshot_weights = self.zeroshot_weights[task]
                logits = classifier(images)
            loss = loss + nn.functional.cross_entropy(logits, labels)

        with self.profile("backward"):
            self.fabric.backward(loss)
        with self.profile("optimizer step"):
            optimizer.step()
            lr_scheduler.step()

        metrics = {"train/loss": loss}

        self.fabric.log_dict(metrics, step=step_idx)

        if (step_idx + 1) % config.save_interval == 0:
            save_path = os.path.join(
                self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
            )
            self.save_model(classifier, save_path)

    if config.state_dict_save_path is not None:
        self.save_model(classifier, config.state_dict_save_path)
    self.print_profile_summary()
    return classifier.clip_model.vision_model
save_model(model, save_path)

Save the vision model to the specified path.

Parameters:

  • model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]) –

    The model to save.

  • save_path (str) –

    The path to save the model.
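
The checkpoint written by this method stores the vision model's weights under the "vision_model" key via fabric.save, so it can be restored in place with fabric.load, mirroring how run handles state_dict_load_path (a sketch; fabric, clip_model, and the checkpoint path are assumed):
>>> # restore a previously saved vision model into a freshly loaded CLIP model
>>> fabric.load('outputs/checkpoints/step=499.ckpt', {'vision_model': clip_model.vision_model})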

Source code in fusion_bench/method/classification/clip_finetune.py
def save_model(
    self,
    model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
    save_path: str,
):
    """
    Save the vision model to the specified path.

    Args:
        model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
        save_path (str): The path to save the model.
    """
    if isinstance(model, HFCLIPClassifier):
        vision_model = model.clip_model.vision_model
    elif isinstance(model, CLIPModel):
        vision_model = model.vision_model
    elif isinstance(model, CLIPVisionModel):
        vision_model = model.vision_model
    elif isinstance(model, CLIPVisionTransformer):
        vision_model = model
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    save_dir = os.path.dirname(save_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    self.fabric.save(save_path, {"vision_model": vision_model})
setup_model()

Sets up the model, optimizer, and learning rate scheduler.

This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

Returns:

  • Tuple

    A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
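
When use_lora (or use_l_lora) is enabled, config.lora_config is unpacked into a peft.LoraConfig before wrapping the vision tower with get_peft_model; an illustrative configuration (field values and target modules are assumptions, not repository defaults):
>>> from peft import LoraConfig
>>> lora_config = LoraConfig(
...     r=16,
...     lora_alpha=32,
...     lora_dropout=0.1,
...     target_modules=['q_proj', 'v_proj'],  # attention projections in the CLIP vision encoder
... )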

Source code in fusion_bench/method/classification/clip_finetune.py
def setup_model(self):
    """
    Sets up the model, optimizer, and learning rate scheduler.

    This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

    Returns:
        Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
    """
    config = self.config
    modelpool = self.modelpool

    clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
    processor = modelpool.load_processor()

    self.finetune_method = "full fine-tune"
    if config.use_lora or config.use_l_lora:
        self.finetune_method = "lora fine-tune"
        lora_config = LoraConfig(
            **OmegaConf.to_container(
                config.lora_config, resolve=True, enum_to_str=True
            )
        )
        clip_model.vision_model = get_peft_model(
            clip_model.vision_model, lora_config
        )

        if config.use_l_lora:
            # http://arxiv.org/abs/2310.04742
            # Anke Tang et al. Parameter Efficient Multi-task Model Fusion with Partial Linearization. ICLR 2024.
            self.finetune_method = "l-lora fine-tune"
            print("Linearizing Lora Layers")
            linearize_lora_model_(clip_model.vision_model)

    classifier = HFCLIPClassifier(clip_model, processor=processor)

    if self.fabric.is_global_zero:
        print("=== Model Summary (For Vision Model Only) ===")
        print_parameters(classifier.clip_model.vision_model)
    # configure optimizers
    optimizer = torch.optim.Adam(
        [
            p
            for p in classifier.clip_model.vision_model.parameters()
            if p.requires_grad
        ],
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=config.num_steps
    )

    return processor, classifier, optimizer, lr_scheduler

ContinualImageClassificationFineTuningForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, BaseAlgorithm
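
A minimal instantiation sketch based on the constructor signature in the source below; the values shown are the constructor defaults, and modelpool is assumed to be a configured CLIPVisionModelPool:
>>> algorithm = ContinualImageClassificationFineTuningForCLIP(
...     seed=42,
...     shuffle_order=True,
...     learning_rate=1e-5,
...     weight_decay=0,
...     num_steps=4000,
...     batch_size=128,
...     num_workers=16,
...     save_interval=500,
...     use_lora=False,
...     lora_config=None,
... )
>>> vision_model = algorithm.run(modelpool)  # returns classifier.clip_model.vision_model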

Source code in fusion_bench/method/classification/continual_clip_finetune.py
class ContinualImageClassificationFineTuningForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    # attributes to configuration keys mapping
    _config_mapping = BaseAlgorithm._config_mapping | {
        "seed": "seed",
        "shuffle_order": "shuffle_order",
        "learning_rate": "learning_rate",
        "weight_decay": "weight_decay",
        "num_steps": "num_steps",
        "batch_size": "batch_size",
        "num_workers": "num_workers",
        "save_interval": "save_interval",
        "state_dict_load_path": "state_dict_load_path",
        "state_dict_save_path": "state_dict_save_path",
        "skip_training": "skip_training",
        "use_lora": "use_lora",
        "lora_config": "lora_config",
    }

    def __init__(
        self,
        seed: int = 42,
        shuffle_order: bool = True,
        learning_rate: float = 1e-5,
        weight_decay: float = 0,
        num_steps: int = 4000,
        batch_size: int = 128,
        num_workers: int = 16,
        save_interval: int = 500,
        state_dict_load_path: Optional[str] = None,
        state_dict_save_path: Optional[str] = None,
        skip_training: bool = False,
        use_lora: bool = False,
        lora_config: Optional[LoraConfig] = None,
    ):
        self.seed = seed
        self.shuffle_order = shuffle_order
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.save_interval = save_interval
        self.state_dict_load_path = state_dict_load_path
        self.state_dict_save_path = state_dict_save_path
        self.skip_training = skip_training
        self.use_lora = use_lora
        self.lora_config = lora_config

    def run(self, modelpool: CLIPVisionModelPool):
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")
        self.finetune_method = "fine-tune"

        if self.seed is not None:
            L.seed_everything(self.seed)
        else:
            seed_everything_by_time(self.fabric)

        task_names = list(modelpool.train_dataset_names)
        if self.shuffle_order:
            random.shuffle(task_names)
        if self.fabric.is_global_zero:
            save_to_json(task_names, os.path.join(self.log_dir, "task_names.json"))

        if self._program.taskpool is not None and isinstance(
            self._program.taskpool, CLIPVisionModelTaskPool
        ):
            has_taskpool = True
            taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
            test_datasets = taskpool._test_datasets
        else:
            has_taskpool = False

        with self.profile("setup model and optimizer"):
            processor, classifier, optimizer, lr_scheduler = self.setup_model()

            if self.state_dict_load_path is not None:
                self.fabric.load(
                    self.state_dict_load_path,
                    {"vision_model": classifier.clip_model.vision_model},
                )
                if self.skip_training:
                    return classifier.clip_model.vision_model

            self.setup_zero_shot_classification_head(
                clip_processor=processor,
                clip_model=classifier.clip_model,
                task_names=task_names,
            )

            init_optimizer_state_dict = optimizer.state_dict()
            init_lr_scheduler_state_dict = lr_scheduler.state_dict()
            self.fabric.setup(classifier, optimizer)

        with self.profile("setup data"):
            train_datasets = [
                CLIPDataset(modelpool.load_train_dataset(task_name), processor)
                for task_name in task_names
            ]
            train_dataloaders = [
                DataLoader(
                    dataset,
                    shuffle=True,
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                )
                for dataset in train_datasets
            ]
            train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
            if not isinstance(train_dataloaders, (list, tuple)):
                train_dataloaders = [train_dataloaders]
            train_dataloader_iters = [
                iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
            ]

        # continual train
        for task_idx, task_name in tqdm(
            enumerate(task_names),
            dynamic_ncols=True,
            disable=not self.fabric.is_global_zero,
        ):
            train_dataloader_iter = train_dataloader_iters[task_idx]

            # reset optimizer and lr scheduler
            print("reset optimizer and lr scheduler")
            optimizer.load_state_dict(init_optimizer_state_dict)
            lr_scheduler.load_state_dict(init_lr_scheduler_state_dict)

            for step_idx in tqdm(
                range(self.num_steps),
                desc=f"continual fine-tune on {task_name}",
                disable=not self.fabric.is_global_zero,
                dynamic_ncols=True,
                leave=False,
            ):
                optimizer.zero_grad()
                loss = 0
                with self.profile("data loading"):
                    batch = next(train_dataloader_iter)
                    images, labels = batch
                with self.profile("forward"):
                    classifier.zeroshot_weights = self.zeroshot_weights[task_name]
                    logits = classifier(images)
                    assert (
                        labels.max() < logits.shape[1]
                    ), f"for task {task_name}, labels.max() = {labels.max()}, logits.shape[1] = {logits.shape[1]}"
                loss = loss + nn.functional.cross_entropy(logits, labels)

                with self.profile("backward"):
                    self.fabric.backward(loss)
                with self.profile("optimizer step"):
                    optimizer.step()
                    lr_scheduler.step()

                metrics = {"train/loss": loss}
                self.fabric.log_dict(metrics, step=step_idx)

                if (step_idx + 1) % self.save_interval == 0:
                    save_path = os.path.join(
                        self.log_dir,
                        "checkpoints",
                        f"task={task_idx}_step={step_idx}.ckpt",
                    )
                    self.save_model(classifier, save_path)

            if has_taskpool:
                taskpool._is_setup = False
                taskpool._test_datasets = DictConfig(
                    {t: test_datasets[t] for t in task_names[: task_idx + 1]}
                )
                eval_report = taskpool.evaluate(
                    deepcopy(classifier.clip_model.vision_model),
                    name=task_name,
                )
                if self.fabric.is_global_zero:
                    save_to_json(
                        eval_report,
                        os.path.join(self.log_dir, f"results_{task_idx}.json"),
                    )

        if self.state_dict_save_path is not None:
            self.save_model(classifier, self.state_dict_save_path)
        self.print_profile_summary()
        return classifier.clip_model.vision_model

    def save_model(
        self,
        model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
        save_path: str,
    ):
        """
        Save the vision model to the specified path.

        Args:
            model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
            save_path (str): The path to save the model.
        """
        if isinstance(model, HFCLIPClassifier):
            vision_model = model.clip_model.vision_model
        elif isinstance(model, CLIPModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionModel):
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionTransformer):
            vision_model = model
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        self.fabric.save(save_path, {"vision_model": vision_model})

    def setup_model(self):
        """
        Sets up the model, optimizer, and learning rate scheduler.

        This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

        Returns:
            Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
        """
        config = self.config
        modelpool = self.modelpool

        clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
        processor = modelpool.load_processor()

        self.finetune_method = "full fine-tune"
        if self.use_lora:
            self.finetune_method = "lora fine-tune"
            lora_config = LoraConfig(
                **OmegaConf.to_container(
                    self.lora_config, resolve=True, enum_to_str=True
                )
            )
            clip_model.vision_model = get_peft_model(
                clip_model.vision_model, lora_config
            )

        classifier = HFCLIPClassifier(clip_model, processor=processor)

        if self.fabric.is_global_zero:
            print("=== Model Summary (For Vision Model Only) ===")
            print_parameters(classifier.clip_model.vision_model)
        # configure optimizers
        optimizer = torch.optim.Adam(
            [
                p
                for p in classifier.clip_model.vision_model.parameters()
                if p.requires_grad
            ],
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.num_steps
        )

        return processor, classifier, optimizer, lr_scheduler
save_model(model, save_path)

Save the vision model to the specified path.

Parameters:

  • model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]) –

    The model to save.

  • save_path (str) –

    The path to save the model.

Source code in fusion_bench/method/classification/continual_clip_finetune.py
def save_model(
    self,
    model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
    save_path: str,
):
    """
    Save the vision model to the specified path.

    Args:
        model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
        save_path (str): The path to save the model.
    """
    if isinstance(model, HFCLIPClassifier):
        vision_model = model.clip_model.vision_model
    elif isinstance(model, CLIPModel):
        vision_model = model.vision_model
    elif isinstance(model, CLIPVisionModel):
        vision_model = model.vision_model
    elif isinstance(model, CLIPVisionTransformer):
        vision_model = model
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    save_dir = os.path.dirname(save_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    self.fabric.save(save_path, {"vision_model": vision_model})
setup_model()

Sets up the model, optimizer, and learning rate scheduler.

This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

Returns:

  • Tuple

    A tuple containing the processor, classifier, optimizer, and learning rate scheduler.

Source code in fusion_bench/method/classification/continual_clip_finetune.py
def setup_model(self):
    """
    Sets up the model, optimizer, and learning rate scheduler.

    This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

    Returns:
        Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
    """
    config = self.config
    modelpool = self.modelpool

    clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
    processor = modelpool.load_processor()

    self.finetune_method = "full fine-tune"
    if self.use_lora:
        self.finetune_method = "lora fine-tune"
        lora_config = LoraConfig(
            **OmegaConf.to_container(
                self.lora_config, resolve=True, enum_to_str=True
            )
        )
        clip_model.vision_model = get_peft_model(
            clip_model.vision_model, lora_config
        )

    classifier = HFCLIPClassifier(clip_model, processor=processor)

    if self.fabric.is_global_zero:
        print("=== Model Summary (For Vision Model Only) ===")
        print_parameters(classifier.clip_model.vision_model)
    # configure optimizers
    optimizer = torch.optim.Adam(
        [
            p
            for p in classifier.clip_model.vision_model.parameters()
            if p.requires_grad
        ],
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=self.num_steps
    )

    return processor, classifier, optimizer, lr_scheduler
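
For reference, the LoRA branch above wraps only the CLIP vision tower with a PEFT adapter while the rest of the model stays untouched. A minimal standalone sketch of that wrapping (the hyperparameters and model name are illustrative, not the project's defaults):

```python
from peft import LoraConfig, get_peft_model
from transformers import CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Attach LoRA adapters to the attention projections of the vision tower;
# all other weights remain frozen.
lora_config = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj"]
)
clip_model.vision_model = get_peft_model(clip_model.vision_model, lora_config)
clip_model.vision_model.print_trainable_parameters()
```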

LLM Fine-tuning

FullFinetuneSFT

Bases: BaseAlgorithm, FabricTrainingMixin

Source code in fusion_bench/method/lm_finetune/fullfinetune_sft.py
class FullFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

    model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
    _latest_saved_checkpoint_global_step: int = -1

    def __init__(
        self,
        optimizer: DictConfig,
        lr_scheduler: Optional[DictConfig],
        dataloader_kwargs: DictConfig,
        max_epochs: int,
        max_steps: int = -1,
        max_steps_per_epoch: int = -1,
        lr_scheduler_interval: Literal["epoch", "step"] = "step",
        lr_scheduler_frequency: int = 1,
        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
        checkpoint_save_frequency: int = 1,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
        save_optimizer_state: bool = False,
        save_full_model: bool = False,
        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
        ckpt_path: Optional[str] = None,
        max_length: int = 6144,
        fix_token_embedding: bool = True,
        **kwargs,
    ):
        """
        Class for full finetuning of a language model on given SFT datasets.

        Args:
            optimizer(DictConfig): Configuration for the optimizer.
            lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
            dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
            max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
            max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
            max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
            lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
            lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
            checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
            checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
            accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
            gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
            gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
            save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
            save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
            ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
        """
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self.dataloader_kwargs = dataloader_kwargs
        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lr_scheduler_interval = lr_scheduler_interval
        self.lr_scheduler_frequency = lr_scheduler_frequency
        self.checkpoint_save_interval = checkpoint_save_interval
        self.checkpoint_save_frequency = checkpoint_save_frequency
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.save_optimizer_state = save_optimizer_state
        self.save_full_model = save_full_model
        self.save_ckpt_type = save_ckpt_type
        self.ckpt_path = ckpt_path
        self.max_length = max_length
        self.fix_token_embedding = fix_token_embedding
        super().__init__(**kwargs)

    def run(self, modelpool: CausalLMPool):
        self.modelpool = modelpool
        self.setup()
        self.train(self.model, self.optimizer, self.lr_scheduler)
        return self.model

    def setup_model(self):
        self.tokenizer = self.modelpool.load_tokenizer()
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        model = self.modelpool.load_pretrained_model()
        self.model: "LlamaForCausalLM" = model

        if self.fix_token_embedding:
            self.model.model.embed_tokens.requires_grad_(False)

        if self.fabric.strategy == "fsdp" or isinstance(
            self.fabric.strategy, FSDPStrategy
        ):
            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": True}
            )
            self.use_cache = False
        else:
            self.use_cache = True
        self.model_dtype = get_dtype(self.model)

    def configure_optimizer(self):
        # compute expected total steps
        self.compute_expected_total_steps(self.train_dataloader)

        optimizer = instantiate(self._optimizer, self.model.parameters())
        if self._lr_scheduler is not None:
            for key, arg in self._lr_scheduler.items():
                if arg == "_T_max_":
                    log.info(
                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
                    )
                    self._lr_scheduler[key] = self.expected_total_steps
            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
                self._lr_scheduler,
                optimizer=optimizer,
            )
        else:
            lr_scheduler = None
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def setup_data(self):
        fabric = self.fabric
        modelpool = self.modelpool
        assert (
            len(modelpool.train_dataset_names) > 0
        ), "No training datasets found in modelpool."

        train_datasets = [
            modelpool.load_train_dataset(dataset_name)
            for dataset_name in modelpool.train_dataset_names
        ]
        if len(train_datasets) > 1:
            train_dataset = ConcatDataset(train_datasets)
        else:
            train_dataset = train_datasets[0]

        self.train_dataset = train_dataset
        self.train_dataloader = DataLoader(
            train_dataset,
            **self.dataloader_kwargs,
            shuffle=True,
            collate_fn=functools.partial(
                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
            ),
        )
        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

    def setup(self):
        fabric = self.fabric

        self.setup_model()
        self.setup_data()

        optimizer = self.configure_optimizer()
        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

        self.model, self.optimizer = fabric.setup(self.model, optimizer)
        self.lr_scheduler = lr_scheduler

    @override
    def train_epoch(self, *args, **kwargs):
        fabric = self.fabric

        accumulated_loss = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
                log.warning(
                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
                )
                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
                batch["labels"] = batch["labels"][:, : self.max_length]

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
                output = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    use_cache=self.use_cache,
                )
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)
                accumulated_loss += loss.item()

            if not is_accumulating:
                self.clip_gradients_if_needed(self.model, self.optimizer)

                # run lr_scheduler at the end of the step if interval is set to "step"
                if (
                    self.lr_scheduler_interval == "step"
                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
                ):
                    self.lr_scheduler.step()

                # update the model parameters and zero the gradients
                self.optimizer.step()
                self.optimizer.zero_grad()

                metrics = {
                    "train/loss": accumulated_loss,
                    "train/epoch_idx": self.epoch_idx,
                    "train/lr": self.optimizer.param_groups[0]["lr"],
                }
                fabric.log_dict(metrics, step=self.global_step_idx)
                pbar.set_postfix(metrics)

                # save the model at the end of the step if interval is set to "step" and frequency is met
                self.conditional_checkpoint_save(stage="end_of_step")

                # break if max_steps_per_epoch is set, and exit epoch
                if (
                    self.max_steps_per_epoch > 0
                    and step_idx + 1 >= self.max_steps_per_epoch
                ):
                    break
                # break if max_steps is set, and exit training
                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                    self.is_training = False
                    break

                self.global_step_idx += 1
                accumulated_loss = 0

    def save_checkpoint(
        self,
        path: Union[str, Path],
        save_optimizer_state: Optional[bool] = None,
        overwrite: bool = False,
    ):
        if not overwrite and os.path.exists(path):
            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

        fabric = self.fabric

        if self.save_ckpt_type == "lightning":
            state = {"model": self.model}

            # save the optimizer and lr_scheduler state if needed
            if self.save_optimizer_state and save_optimizer_state is not False:
                state.update(
                    {
                        "optimizer": self.optimizer,
                        "lr_scheduler": self.lr_scheduler,
                        "global_step_idx": self.global_step_idx,
                        "epoch_idx": self.epoch_idx,
                    }
                )

            trainable_param_names = set(
                name
                for name, param in self.model.state_dict(keep_vars=True).items()
                if param.requires_grad
            )
            filter = (
                None
                if self.save_full_model
                else {"model": lambda k, p: k in trainable_param_names}
            )

            fabric.save(path, state=state, filter=filter)
        else:
            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)

        self._latest_saved_checkpoint_global_step = self.global_step_idx

    def load_checkpoint(self, path: Union[str, Path]):
        fabric = self.fabric

        state = {"model": self.model}

        # save the optimizer and lr_scheduler state if needed
        if self.save_optimizer_state:
            state.update(
                {
                    "optimizer": self.optimizer,
                    "lr_scheduler": self.lr_scheduler,
                }
            )

        fabric.load(path, state)
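
The `train_epoch` loop above scales each micro-batch loss by `accumulate_grad_batches` and only steps the optimizer once a full accumulation window has been processed. A minimal standalone sketch of that pattern (toy model and data; no Fabric, gradient clipping, scheduler, or logging):

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4

for step_idx in range(16):
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
    is_accumulating = (step_idx + 1) % accumulate_grad_batches != 0

    # Scale the loss so the accumulated gradient matches one large batch.
    loss = nn.functional.mse_loss(model(inputs), targets) / accumulate_grad_batches
    loss.backward()  # gradients add up in .grad across micro-batches

    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()
```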

__init__(optimizer, lr_scheduler, dataloader_kwargs, max_epochs, max_steps=-1, max_steps_per_epoch=-1, lr_scheduler_interval='step', lr_scheduler_frequency=1, checkpoint_save_interval='epoch', checkpoint_save_frequency=1, accumulate_grad_batches=1, gradient_clip_val=None, gradient_clip_algorithm='norm', save_optimizer_state=False, save_full_model=False, save_ckpt_type='lightning', ckpt_path=None, max_length=6144, fix_token_embedding=True, **kwargs)

Class for full finetuning of a language model on given SFT datasets.

Parameters:

  • optimizer (DictConfig) –

    Configuration for the optimizer.

  • lr_scheduler (DictConfig) –

    Configuration for the learning rate scheduler.

  • dataloader_kwargs (DictConfig) –

    Configuration for the dataloader, such as batch size, num_workers, etc.

  • max_epochs (int) –

    Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.

  • max_steps (int, default: -1 ) –

    Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.

  • max_steps_per_epoch (int, default: -1 ) –

    Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.

  • lr_scheduler_interval (str, default: 'step' ) –

    Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.

  • lr_scheduler_frequency (int, default: 1 ) –

    Frequency at which to run the learning rate scheduler. The scheduler will run every lr_scheduler_frequency epochs or steps, depending on the value of lr_scheduler_interval.

  • checkpoint_save_interval (str, default: 'epoch' ) –

    Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.

  • checkpoint_save_frequency (int, default: 1 ) –

    Frequency at which to save the model checkpoint. The model will be saved every checkpoint_save_frequency epochs or steps, depending on the value of checkpoint_save_interval.

  • accumulate_grad_batches (int, default: 1 ) –

    Number of batches to accumulate gradients across before updating the model parameters.

  • gradient_clip_val (float, default: None ) –

    Value to clip the gradients. If set to None, no gradient clipping will be applied.

  • gradient_clip_algorithm (str, default: 'norm' ) –

    Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.

  • save_optimizer_state (bool, default: False ) –

    Whether to save the optimizer and lr_scheduler state along with the model checkpoint.

  • save_full_model (bool, default: False ) –

    Whether to save the full model or only the trainable parameters in the model checkpoint.

  • save_ckpt_type (str, default: 'lightning' ) –

    Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.

  • ckpt_path (str, default: None ) –

    Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.

  • max_length (int, default: 6144 ) –

    Maximum input length to consider. If the input length exceeds this value, it will be truncated.

  • fix_token_embedding (bool, default: True ) –

    Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
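
A hedged example of how these arguments might be configured (values are illustrative, not recommended defaults; `configure_optimizer` replaces any scheduler value equal to the string `"_T_max_"` with the expected total number of optimizer steps):

```python
from omegaconf import OmegaConf

from fusion_bench.method.lm_finetune.fullfinetune_sft import FullFinetuneSFT

config = OmegaConf.create(
    {
        "optimizer": {"_target_": "torch.optim.AdamW", "lr": 2e-5, "weight_decay": 0.0},
        "lr_scheduler": {
            "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
            "T_max": "_T_max_",  # filled in with the expected total steps at setup time
        },
        "dataloader_kwargs": {"batch_size": 2, "num_workers": 4},
        "max_epochs": 1,
        "accumulate_grad_batches": 8,
        "gradient_clip_val": 1.0,
        "checkpoint_save_interval": "epoch",
    }
)
algorithm = FullFinetuneSFT(**config)
```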

Source code in fusion_bench/method/lm_finetune/fullfinetune_sft.py
def __init__(
    self,
    optimizer: DictConfig,
    lr_scheduler: Optional[DictConfig],
    dataloader_kwargs: DictConfig,
    max_epochs: int,
    max_steps: int = -1,
    max_steps_per_epoch: int = -1,
    lr_scheduler_interval: Literal["epoch", "step"] = "step",
    lr_scheduler_frequency: int = 1,
    checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
    checkpoint_save_frequency: int = 1,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Literal["value", "norm"] = "norm",
    save_optimizer_state: bool = False,
    save_full_model: bool = False,
    save_ckpt_type: Literal["lightning", "hf"] = "lightning",
    ckpt_path: Optional[str] = None,
    max_length: int = 6144,
    fix_token_embedding: bool = True,
    **kwargs,
):
    """
    Class for full finetuning of a language model on given SFT datasets.

    Args:
        optimizer(DictConfig): Configuration for the optimizer.
        lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
        dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
        max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
        max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
        max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
        lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
        lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
        checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
        checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
        accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
        gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
        gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
        save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
        save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
        save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
        ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
        max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
        fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
    """
    self._optimizer = optimizer
    self._lr_scheduler = lr_scheduler
    self.dataloader_kwargs = dataloader_kwargs
    self.max_epochs = max_epochs
    self.max_steps = max_steps
    self.max_steps_per_epoch = max_steps_per_epoch
    self.lr_scheduler_interval = lr_scheduler_interval
    self.lr_scheduler_frequency = lr_scheduler_frequency
    self.checkpoint_save_interval = checkpoint_save_interval
    self.checkpoint_save_frequency = checkpoint_save_frequency
    self.accumulate_grad_batches = accumulate_grad_batches
    self.gradient_clip_val = gradient_clip_val
    self.gradient_clip_algorithm = gradient_clip_algorithm
    self.save_optimizer_state = save_optimizer_state
    self.save_full_model = save_full_model
    self.save_ckpt_type = save_ckpt_type
    self.ckpt_path = ckpt_path
    self.max_length = max_length
    self.fix_token_embedding = fix_token_embedding
    super().__init__(**kwargs)

PeftFinetuneSFT

Bases: BaseAlgorithm, FabricTrainingMixin

Source code in fusion_bench/method/lm_finetune/peftfinetune_sft.py
class PeftFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

    model: Union[
        nn.Module, "_FabricModule", "LlamaForCausalLM", PeftModel, peft.LoraModel
    ]
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
    _latest_saved_checkpoint_global_step: int = -1

    def __init__(
        self,
        optimizer: DictConfig,
        lr_scheduler: Optional[DictConfig],
        peft_config: DictConfig,
        dataloader_kwargs: DictConfig,
        adapter_name: str = "default",
        merge_and_unload: bool = False,
        max_epochs: int = 1,
        max_steps: int = -1,
        max_steps_per_epoch: int = -1,
        lr_scheduler_interval: Literal["epoch", "step"] = "step",
        lr_scheduler_frequency: int = 1,
        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
        checkpoint_save_frequency: int = 1,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
        save_optimizer_state: bool = False,
        save_full_model: bool = False,
        save_ckpt_type: Literal["lightning", "peft"] = "peft",
        ckpt_path: Optional[str] = None,
        max_length: int = 6144,
        **kwargs,
    ):
        """
        Class for parameter-efficient fine-tuning (PEFT, e.g. LoRA) of a language model on given SFT datasets.

        Args:
            optimizer(DictConfig): Configuration for the optimizer.
            lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
            peft_config(DictConfig): Configuration for the PEFT model.
            dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
            adapter_name(str): Name of the adapter to use for the PEFT model.
            merge_and_unload(bool): Whether to merge and unload the model after training.
            max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
            max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
            max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
            lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
            lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
            checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
            checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
            accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
            gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
            gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
            save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
            save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
            save_ckpt_type(str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
            ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
        """
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self._peft_config = peft_config
        self.dataloader_kwargs = dataloader_kwargs
        self.adapter_name = adapter_name
        self.merge_and_unload = merge_and_unload
        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lr_scheduler_interval = lr_scheduler_interval
        self.lr_scheduler_frequency = lr_scheduler_frequency
        self.checkpoint_save_interval = checkpoint_save_interval
        self.checkpoint_save_frequency = checkpoint_save_frequency
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.save_optimizer_state = save_optimizer_state
        self.save_full_model = save_full_model
        self.save_ckpt_type = save_ckpt_type
        self.ckpt_path = ckpt_path
        self.max_length = max_length
        super().__init__(**kwargs)

    def run(self, modelpool: CausalLMPool):
        self.modelpool = modelpool
        self.setup()
        self.train(self.model, self.optimizer, self.lr_scheduler)

        if self.merge_and_unload:
            self.model = self.model.merge_and_unload()
        return self.model

    def setup_model(self):
        # https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py
        self.tokenizer = self.modelpool.load_tokenizer()
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        model = self.modelpool.load_pretrained_model()

        # get the PEFT model
        peft_config = instantiate(self._peft_config, _convert_="all")
        peft_config.save_pretrained(os.path.join(self.log_dir, "peft_config"))
        peft_model = get_peft_model(model, peft_config, self.adapter_name)
        peft_model.print_trainable_parameters()

        self.model = peft_model

        if self.fabric.strategy == "fsdp" or isinstance(
            self.fabric.strategy, FSDPStrategy
        ):
            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": True}
            )
            self.use_cache = False
        else:
            self.use_cache = True

        self.model_dtype = get_dtype(self.model)
        self.model = self.model.to(dtype=self.model_dtype)

    def configure_optimizer(self):
        # compute expected total steps
        self.compute_expected_total_steps(self.train_dataloader)

        optimizer = instantiate(self._optimizer, self.model.parameters())
        if self._lr_scheduler is not None:
            for key, arg in self._lr_scheduler.items():
                if arg == "_T_max_":
                    log.info(
                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
                    )
                    self._lr_scheduler[key] = self.expected_total_steps
            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
                self._lr_scheduler,
                optimizer=optimizer,
            )
        else:
            lr_scheduler = None
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def setup_data(self):
        fabric = self.fabric
        modelpool = self.modelpool
        assert (
            len(modelpool.train_dataset_names) > 0
        ), "No training datasets found in modelpool."

        train_datasets = [
            modelpool.load_train_dataset(dataset_name)
            for dataset_name in modelpool.train_dataset_names
        ]
        if len(train_datasets) > 1:
            train_dataset = ConcatDataset(train_datasets)
        else:
            train_dataset = train_datasets[0]

        self.train_dataset = train_dataset
        self.train_dataloader = DataLoader(
            train_dataset,
            **self.dataloader_kwargs,
            shuffle=True,
            collate_fn=functools.partial(
                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
            ),
        )
        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

    def setup(self):
        fabric = self.fabric

        self.setup_model()
        self.setup_data()

        optimizer = self.configure_optimizer()
        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

        self.model = self.fabric.setup_module(self.model)
        self.optimizer = self.fabric.setup_optimizers(optimizer)
        self.lr_scheduler = lr_scheduler

    @override
    def train_epoch(self, *args, **kwargs):
        fabric = self.fabric

        accumulated_loss = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
                log.warning(
                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
                )
                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
                batch["labels"] = batch["labels"][:, : self.max_length]

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
                output = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    use_cache=self.use_cache,
                )
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)
                accumulated_loss += loss.item()

            if not is_accumulating:
                self.clip_gradients_if_needed(self.model, self.optimizer)

                # run lr_scheduler at the end of the step if interval is set to "step"
                if (
                    self.lr_scheduler_interval == "step"
                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
                ):
                    self.lr_scheduler.step()

                # update the model parameters and zero the gradients
                self.optimizer.step()
                self.optimizer.zero_grad()

                metrics = {
                    "train/loss": accumulated_loss,
                    "train/epoch_idx": self.epoch_idx,
                    "train/lr": self.optimizer.param_groups[0]["lr"],
                }
                fabric.log_dict(metrics, step=self.global_step_idx)
                pbar.set_postfix(metrics)

                # save the model at the end of the step if interval is set to "step" and frequency is met
                self.conditional_checkpoint_save(stage="end_of_step")

                # break if max_steps_per_epoch is set, and exit epoch
                if (
                    self.max_steps_per_epoch > 0
                    and step_idx + 1 >= self.max_steps_per_epoch
                ):
                    break
                # break if max_steps is set, and exit training
                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                    self.is_training = False
                    break

                self.global_step_idx += 1
                accumulated_loss = 0

    def save_checkpoint(
        self,
        path: Union[str, Path],
        save_optimizer_state: Optional[bool] = None,
        overwrite: bool = False,
    ):
        if not overwrite and os.path.exists(path):
            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

        fabric = self.fabric
        if self.save_ckpt_type == "lightning":
            state = {"model": self.model}

            # save the optimizer and lr_scheduler state if needed
            if self.save_optimizer_state and save_optimizer_state is not False:
                state.update(
                    {
                        "optimizer": self.optimizer,
                        "lr_scheduler": self.lr_scheduler,
                        "global_step_idx": self.global_step_idx,
                        "epoch_idx": self.epoch_idx,
                    }
                )
            trainable_param_names = set(
                name
                for name, param in self.model.state_dict(keep_vars=True).items()
                if param.requires_grad
            )
            filter = (
                None
                if self.save_full_model
                else {"model": lambda k, p: k in trainable_param_names}
            )
            os.makedirs(os.path.dirname(path), exist_ok=True)
            fabric.save(path, state=state, filter=filter)
        elif self.save_ckpt_type == "peft":
            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
        else:
            raise ValueError(
                f"Unknown save_ckpt_type: {self.save_ckpt_type}. Available options: 'lightning', 'peft'"
            )
        self._latest_saved_checkpoint_global_step = self.global_step_idx

    def load_checkpoint(self, path: Union[str, Path]):
        fabric = self.fabric

        state = {"model": self.model}

        # save the optimizer and lr_scheduler state if needed
        if self.save_optimizer_state:
            state.update(
                {
                    "optimizer": self.optimizer,
                    "lr_scheduler": self.lr_scheduler,
                }
            )

        fabric.load(path, state)
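
When `merge_and_unload=True`, `run` folds the trained LoRA weights back into the base model before returning it. A minimal standalone sketch of that final step (the model name is illustrative and training is elided):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
peft_model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM", r=8))

# ... fine-tune peft_model here ...

# Merge the LoRA deltas into the base weights and drop the adapter wrappers,
# leaving a plain causal LM that can be saved or evaluated as usual.
merged_model = peft_model.merge_and_unload()
```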

__init__(optimizer, lr_scheduler, peft_config, dataloader_kwargs, adapter_name='default', merge_and_unload=False, max_epochs=1, max_steps=-1, max_steps_per_epoch=-1, lr_scheduler_interval='step', lr_scheduler_frequency=1, checkpoint_save_interval='epoch', checkpoint_save_frequency=1, accumulate_grad_batches=1, gradient_clip_val=None, gradient_clip_algorithm='norm', save_optimizer_state=False, save_full_model=False, save_ckpt_type='peft', ckpt_path=None, max_length=6144, **kwargs)

Class for parameter-efficient fine-tuning (PEFT, e.g. LoRA) of a language model on given SFT datasets.

Parameters:

  • optimizer (DictConfig) –

    Configuration for the optimizer.

  • lr_scheduler (DictConfig) –

    Configuration for the learning rate scheduler.

  • peft_config (DictConfig) –

    Configuration for the PEFT model.

  • dataloader_kwargs (DictConfig) –

    Configuration for the dataloader, such as batch size, num_workers, etc.

  • adapter_name (str, default: 'default' ) –

    Name of the adapter to use for the PEFT model.

  • merge_and_unload (bool, default: False ) –

    Whether to merge the adapter weights into the base model and unload the PEFT wrapper after training (calls merge_and_unload on the trained model).

  • max_epochs (int, default: 1 ) –

    Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.

  • max_steps (int, default: -1 ) –

    Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.

  • max_steps_per_epoch (int, default: -1 ) –

    Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.

  • lr_scheduler_interval (str, default: 'step' ) –

    Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.

  • lr_scheduler_frequency (int, default: 1 ) –

    Frequency at which to run the learning rate scheduler. The scheduler will run every lr_scheduler_frequency epochs or steps, depending on the value of lr_scheduler_interval.

  • checkpoint_save_interval (str, default: 'epoch' ) –

    Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.

  • checkpoint_save_frequency (int, default: 1 ) –

    Frequency at which to save the model checkpoint. The model will be saved every checkpoint_save_frequency epochs or steps, depending on the value of checkpoint_save_interval.

  • accumulate_grad_batches (int, default: 1 ) –

    Number of batches to accumulate gradients across before updating the model parameters.

  • gradient_clip_val (float, default: None ) –

    Value to clip the gradients. If set to None, no gradient clipping will be applied.

  • gradient_clip_algorithm (str, default: 'norm' ) –

    Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.

  • save_optimizer_state (bool, default: False ) –

    Whether to save the optimizer and lr_scheduler state along with the model checkpoint.

  • save_full_model (bool, default: False ) –

    Whether to save the full model or only the trainable parameters in the model checkpoint.

  • save_ckpt_type (str, default: 'peft' ) –

    Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.

  • ckpt_path (str, default: None ) –

    Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.

  • max_length (int, default: 6144 ) –

    Maximum input length to consider. If the input length exceeds this value, it will be truncated.

Source code in fusion_bench/method/lm_finetune/peftfinetune_sft.py
def __init__(
    self,
    optimizer: DictConfig,
    lr_scheduler: Optional[DictConfig],
    peft_config: DictConfig,
    dataloader_kwargs: DictConfig,
    adapter_name: str = "default",
    merge_and_unload: bool = False,
    max_epochs: int = 1,
    max_steps: int = -1,
    max_steps_per_epoch: int = -1,
    lr_scheduler_interval: Literal["epoch", "step"] = "step",
    lr_scheduler_frequency: int = 1,
    checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
    checkpoint_save_frequency: int = 1,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Literal["value", "norm"] = "norm",
    save_optimizer_state: bool = False,
    save_full_model: bool = False,
    save_ckpt_type: Literal["lightning", "peft"] = "peft",
    ckpt_path: Optional[str] = None,
    max_length: int = 6144,
    **kwargs,
):
    """
    Class for parameter-efficient fine-tuning (PEFT, e.g. LoRA) of a language model on given SFT datasets.

    Args:
        optimizer(DictConfig): Configuration for the optimizer.
        lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
        peft_config(DictConfig): Configuration for the PEFT model.
        dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
        adapter_name(str): Name of the adapter to use for the PEFT model.
        merge_and_unload(bool): Whether to merge and unload the model after training.
        max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
        max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
        max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
        lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
        lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
        checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
        checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
        accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
        gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
        gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
        save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
        save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
        save_ckpt_type(str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
        ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
    """
    self._optimizer = optimizer
    self._lr_scheduler = lr_scheduler
    self._peft_config = peft_config
    self.dataloader_kwargs = dataloader_kwargs
    self.adapter_name = adapter_name
    self.merge_and_unload = merge_and_unload
    self.max_epochs = max_epochs
    self.max_steps = max_steps
    self.max_steps_per_epoch = max_steps_per_epoch
    self.lr_scheduler_interval = lr_scheduler_interval
    self.lr_scheduler_frequency = lr_scheduler_frequency
    self.checkpoint_save_interval = checkpoint_save_interval
    self.checkpoint_save_frequency = checkpoint_save_frequency
    self.accumulate_grad_batches = accumulate_grad_batches
    self.gradient_clip_val = gradient_clip_val
    self.gradient_clip_algorithm = gradient_clip_algorithm
    self.save_optimizer_state = save_optimizer_state
    self.save_full_model = save_full_model
    self.save_ckpt_type = save_ckpt_type
    self.ckpt_path = ckpt_path
    self.max_length = max_length
    super().__init__(**kwargs)
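
A hedged example of how this algorithm might be configured; the `peft_config` entry is instantiated into a `peft.LoraConfig` inside `setup_model`, and all values below are illustrative:

```python
from omegaconf import OmegaConf

from fusion_bench.method.lm_finetune.peftfinetune_sft import PeftFinetuneSFT

config = OmegaConf.create(
    {
        "optimizer": {"_target_": "torch.optim.AdamW", "lr": 1e-4},
        "lr_scheduler": {
            "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
            "T_max": "_T_max_",  # replaced with the expected total steps at setup time
        },
        "peft_config": {
            "_target_": "peft.LoraConfig",
            "task_type": "CAUSAL_LM",
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
        },
        "dataloader_kwargs": {"batch_size": 1, "num_workers": 4},
        "merge_and_unload": True,
        "max_epochs": 1,
        "save_ckpt_type": "peft",
    }
)
algorithm = PeftFinetuneSFT(**config)
```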

Reward Modeling

BradleyTerryRewardModeling

Bases: BaseAlgorithm, FabricTrainingMixin
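
The objective implemented by `compute_loss` below is the pairwise Bradley-Terry loss: each batch stacks the chosen responses in the first half and the rejected responses in the second half, and the model is trained to assign a higher scalar reward to the chosen one. A minimal standalone sketch of that loss:

```python
import torch
import torch.nn.functional as F

def bradley_terry_loss(rewards: torch.Tensor) -> torch.Tensor:
    """-log sigmoid(r_chosen - r_rejected), averaged over the pairs in the batch."""
    batch_size = rewards.size(0)
    assert batch_size % 2 == 0, "Batch size must be even (chosen/rejected pairs)."
    chosen, rejected = rewards[: batch_size // 2], rewards[batch_size // 2 :]
    return -F.logsigmoid(chosen - rejected).mean()

# Example: rewards for two pairs (chosen responses first, then rejected ones).
print(bradley_terry_loss(torch.tensor([1.2, 0.7, -0.3, 0.5])))
```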

Source code in fusion_bench/method/lm_finetune/bradley_terry_rm.py
class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):

    model: Union[nn.Module, "_FabricModule", "LlamaForSequenceClassification"]
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler

    def __init__(
        self,
        optimizer: DictConfig,
        lr_scheduler: Optional[DictConfig],
        dataloader_kwargs: DictConfig,
        max_epochs: int,
        max_steps: int = -1,
        max_steps_per_epoch: int = -1,
        lr_scheduler_interval: Literal["epoch", "step"] = "step",
        lr_scheduler_frequency: int = 1,
        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
        checkpoint_save_frequency: int = 1,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
        save_optimizer_state: bool = False,
        save_full_model: bool = False,
        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
        ckpt_path: Optional[str] = None,
        max_length: int = 6144,
        fix_token_embedding: bool = True,
        **kwargs,
    ):
        """
        Class for reward modeling using the Bradley-Terry model.

        Args:
            optimizer(DictConfig): Configuration for the optimizer.
            lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
            dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
            max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
            max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
            max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
            lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
            lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
            checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
            checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
            accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
            gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
            gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
            save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
            save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
            ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
            max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
            fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
        """
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self.dataloader_kwargs = dataloader_kwargs
        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lr_scheduler_interval = lr_scheduler_interval
        self.lr_scheduler_frequency = lr_scheduler_frequency
        self.checkpoint_save_interval = checkpoint_save_interval
        self.checkpoint_save_frequency = checkpoint_save_frequency
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.save_optimizer_state = save_optimizer_state
        self.save_full_model = save_full_model
        self.save_ckpt_type = save_ckpt_type
        self.ckpt_path = ckpt_path
        self.max_length = max_length
        self.fix_token_embedding = fix_token_embedding
        super().__init__(**kwargs)

    def run(self, modelpool: SequenceClassificationModelPool):
        self.modelpool = modelpool
        self.setup()
        self.train(self.model, self.optimizer, self.lr_scheduler)
        return self.model

    def setup_model(self):
        self.tokenizer = self.modelpool.load_tokenizer()
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = (
                self.tokenizer.eos_token_id
            )  #! make sure eos_token_id only shows up at the end of the sequence

        model = self.modelpool.load_pretrained_model()
        self.model: "LlamaForSequenceClassification" = model

        if model.config.pad_token_id is None:
            model.config.pad_token_id = self.tokenizer.pad_token_id

        if self.fix_token_embedding:
            self.model.model.embed_tokens.requires_grad_(False)

        if self.fabric.strategy == "fsdp" or isinstance(
            self.fabric.strategy, FSDPStrategy
        ):
            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": True}
            )
            self.use_cache = False
        else:
            self.use_cache = True
        self.model_dtype = get_dtype(self.model)

    def setup_data(self):
        fabric = self.fabric
        modelpool = self.modelpool
        assert (
            len(modelpool.train_dataset_names) > 0
        ), "No training datasets found in modelpool."

        train_datasets = [
            modelpool.load_train_dataset(dataset_name)
            for dataset_name in modelpool.train_dataset_names
        ]
        if len(train_datasets) > 1:
            train_dataset = ConcatDataset(train_datasets)
        else:
            train_dataset = train_datasets[0]

        self.train_dataset = train_dataset
        self.train_dataloader = DataLoader(
            train_dataset,
            **self.dataloader_kwargs,
            shuffle=True,
            collate_fn=functools.partial(
                bradley_terry_rm_collate,
                pad_token_id=self.tokenizer.pad_token_id,
            ),  # NOTE: different from SFT, uses bradley_terry_rm_collate
        )
        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

    def configure_optimizer(self):
        # compute expected total steps
        self.compute_expected_total_steps(self.train_dataloader)

        optimizer = instantiate(self._optimizer, self.model.parameters())
        if self._lr_scheduler is not None:
            for key, arg in self._lr_scheduler.items():
                if arg == "_T_max_":
                    log.info(
                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
                    )
                    self._lr_scheduler[key] = self.expected_total_steps
            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
                self._lr_scheduler,
                optimizer=optimizer,
            )
        else:
            lr_scheduler = None
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def setup(self):
        fabric = self.fabric

        self.setup_model()
        self.setup_data()

        optimizer = self.configure_optimizer()
        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

        self.model, self.optimizer = fabric.setup(self.model, optimizer)
        self.lr_scheduler = lr_scheduler

    def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
        """
        Maximize the likelihood of the winner over the loser using the Bradley-Terry model.

        Args:
            batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids and attention masks for the winner and loser.
        """
        batch_size = batch["input_ids"].size(0)
        assert batch_size % 2 == 0, "Batch size must be even."

        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=self.use_cache,
        )

        rewards = outputs[0]
        chosen_reward = rewards[: batch_size // 2]
        rejected_rewards = rewards[batch_size // 2 :]
        loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()

        return {
            "chosen_reward": chosen_reward,
            "rejected_reward": rejected_rewards,
            "loss": loss,
        }

    @override
    def train_epoch(self, *args, **kwargs):
        fabric = self.fabric

        accumulated_loss = 0
        accumulated_chosen_reward = 0
        accumulated_rejected_reward = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
                log.warning(
                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
                )
                batch["input_ids"] = batch["input_ids"][:, -self.max_length :]
                batch["attention_mask"] = batch["attention_mask"][:, -self.max_length :]

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
                output = self.compute_loss(batch)
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)

                accumulated_loss += loss.item()
                accumulated_chosen_reward += output["chosen_reward"].mean().item()
                accumulated_rejected_reward += output["rejected_reward"].mean().item()

            # 1. update the model parameters if not accumulating gradients
            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
            # 3. save the model if interval is set to "step" and frequency is met
            # 4. log metrics
            # 5. increase the global step index
            if not is_accumulating:
                self.clip_gradients_if_needed(self.model, self.optimizer)

                # run lr_scheduler at the end of the step if interval is set to "step"
                if (
                    self.lr_scheduler_interval == "step"
                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
                ):
                    self.lr_scheduler.step()

                # update the model parameters and zero the gradients
                self.optimizer.step()
                self.optimizer.zero_grad()

                metrics = {
                    "train/loss": accumulated_loss,
                    "train/chosen_reward": accumulated_chosen_reward
                    / self.accumulate_grad_batches,
                    "train/rejected_reward": accumulated_rejected_reward
                    / self.accumulate_grad_batches,
                    "train/epoch_idx": self.epoch_idx,
                    "train/lr": self.optimizer.param_groups[0]["lr"],
                }
                metrics["train/chosen_reward-rejected_reward"] = (
                    metrics["train/chosen_reward"] - metrics["train/rejected_reward"]
                )

                fabric.log_dict(metrics, step=self.global_step_idx)
                pbar.set_postfix(metrics)

                # save the model at the end of the step if interval is set to "step" and frequency is met
                self.conditional_checkpoint_save(stage="end_of_step")

                # break if max_steps_per_epoch is set, and exit epoch
                if (
                    self.max_steps_per_epoch > 0
                    and step_idx + 1 >= self.max_steps_per_epoch
                ):
                    break
                # break if max_steps is set, and exit training
                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                    self.is_training = False
                    break

                self.global_step_idx += 1
                accumulated_loss = 0
                accumulated_chosen_reward = 0
                accumulated_rejected_reward = 0

    def save_checkpoint(
        self,
        path: Union[str, Path],
        save_optimizer_state: Optional[bool] = None,
        overwrite: bool = False,
    ):
        if not overwrite and os.path.exists(path):
            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

        fabric = self.fabric

        if self.save_ckpt_type == "lightning":
            state = {"model": self.model}

            # save the optimizer and lr_scheduler state if needed
            if self.save_optimizer_state and save_optimizer_state is not False:
                state.update(
                    {
                        "optimizer": self.optimizer,
                        "lr_scheduler": self.lr_scheduler,
                        "global_step_idx": self.global_step_idx,
                        "epoch_idx": self.epoch_idx,
                    }
                )

            trainable_param_names = set(
                name
                for name, param in self.model.state_dict(keep_vars=True).items()
                if param.requires_grad
            )
            filter = (
                None
                if self.save_full_model
                else {"model": lambda k, p: k in trainable_param_names}
            )

            fabric.save(path, state=state, filter=filter)
        else:
            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)

        self._latest_saved_checkpoint_global_step = self.global_step_idx

    def load_checkpoint(self, path: Union[str, Path]):
        fabric = self.fabric

        state = {"model": self.model}

        # load the optimizer and lr_scheduler state if needed
        if self.save_optimizer_state:
            state.update(
                {
                    "optimizer": self.optimizer,
                    "lr_scheduler": self.lr_scheduler,
                }
            )

        fabric.load(path, state)
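
The `train_epoch` loop above scales each micro-batch loss by `1 / accumulate_grad_batches` and skips cross-process gradient synchronization on accumulation steps via `fabric.no_backward_sync`. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and Fabric setup are placeholders for illustration only and are not part of the class above.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")  # placeholder setup; the class above uses its own Fabric
fabric.launch()

model = torch.nn.Linear(4, 1)  # toy stand-in for the reward model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

accumulate_grad_batches = 4
batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]  # toy data

for step_idx, (x, y) in enumerate(batches):
    is_accumulating = (step_idx + 1) % accumulate_grad_batches != 0
    # skip gradient synchronization while accumulating, as in train_epoch above
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        loss = torch.nn.functional.mse_loss(model(x), y) / accumulate_grad_batches
        fabric.backward(loss)
    if not is_accumulating:
        optimizer.step()
        optimizer.zero_grad()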

__init__(optimizer, lr_scheduler, dataloader_kwargs, max_epochs, max_steps=-1, max_steps_per_epoch=-1, lr_scheduler_interval='step', lr_scheduler_frequency=1, checkpoint_save_interval='epoch', checkpoint_save_frequency=1, accumulate_grad_batches=1, gradient_clip_val=None, gradient_clip_algorithm='norm', save_optimizer_state=False, save_full_model=False, save_ckpt_type='lightning', ckpt_path=None, max_length=6144, fix_token_embedding=True, **kwargs)

Class for reward modeling using the Bradley-Terry model.

Parameters:

  • optimizer (DictConfig) –

    Configuration for the optimizer.

  • lr_scheduler (DictConfig) –

    Configuration for the learning rate scheduler.

  • dataloader_kwargs (DictConfig) –

    Configuration for the dataloader, such as batch size, num_workers, etc.

  • max_epochs (int) –

    Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.

  • max_steps (int, default: -1 ) –

    Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.

  • max_steps_per_epoch (int, default: -1 ) –

    Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.

  • lr_scheduler_interval (str, default: 'step' ) –

    Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.

  • lr_scheduler_frequency (int, default: 1 ) –

    Frequency at which to run the learning rate scheduler. The scheduler will run every lr_scheduler_frequency epochs or steps, depending on the value of lr_scheduler_interval.

  • checkpoint_save_interval (str, default: 'epoch' ) –

    Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.

  • checkpoint_save_frequency (int, default: 1 ) –

    Frequency at which to save the model checkpoint. The model will be saved every checkpoint_save_frequency epochs or steps, depending on the value of checkpoint_save_interval.

  • accumulate_grad_batches (int, default: 1 ) –

    Number of batches to accumulate gradients across before updating the model parameters.

  • gradient_clip_val (float, default: None ) –

    Value to clip the gradients. If set to None, no gradient clipping will be applied.

  • gradient_clip_algorithm (str, default: 'norm' ) –

    Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.

  • save_optimizer_state (bool, default: False ) –

    Whether to save the optimizer and lr_scheduler state along with the model checkpoint.

  • save_full_model (bool, default: False ) –

    Whether to save the full model or only the trainable parameters in the model checkpoint.

  • save_ckpt_type (str, default: 'lightning' ) –

    Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.

  • ckpt_path (str, default: None ) –

    Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.

  • max_length (int, default: 6144 ) –

    Maximum input length to consider. If the input length exceeds this value, it will be truncated.

  • fix_token_embedding (bool, default: True ) –

    Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.

Source code in fusion_bench/method/lm_finetune/bradley_terry_rm.py
def __init__(
    self,
    optimizer: DictConfig,
    lr_scheduler: Optional[DictConfig],
    dataloader_kwargs: DictConfig,
    max_epochs: int,
    max_steps: int = -1,
    max_steps_per_epoch: int = -1,
    lr_scheduler_interval: Literal["epoch", "step"] = "step",
    lr_scheduler_frequency: int = 1,
    checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
    checkpoint_save_frequency: int = 1,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Literal["value", "norm"] = "norm",
    save_optimizer_state: bool = False,
    save_full_model: bool = False,
    save_ckpt_type: Literal["lightning", "hf"] = "lightning",
    ckpt_path: Optional[str] = None,
    max_length: int = 6144,
    fix_token_embedding: bool = True,
    **kwargs,
):
    """
    Class for reward modeling using the Bradley-Terry model.

    Args:
        optimizer(DictConfig): Configuration for the optimizer.
        lr_scheduler(DictConfig): Configuration for the learning rate scheduler.
        dataloader_kwargs(DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
        max_epochs(int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
        max_steps(int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
        max_steps_per_epoch(int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
        lr_scheduler_interval(str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
        lr_scheduler_frequency(int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
        checkpoint_save_interval(str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
        checkpoint_save_frequency(int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
        accumulate_grad_batches(int): Number of batches to accumulate gradients across before updating the model parameters.
        gradient_clip_val(float): Value to clip the gradients. If set to None, no gradient clipping will be applied.
        gradient_clip_algorithm(str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
        save_optimizer_state(bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
        save_full_model(bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
        save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
        ckpt_path(str): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
        max_length(int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
        fix_token_embedding(bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
    """
    self._optimizer = optimizer
    self._lr_scheduler = lr_scheduler
    self.dataloader_kwargs = dataloader_kwargs
    self.max_epochs = max_epochs
    self.max_steps = max_steps
    self.max_steps_per_epoch = max_steps_per_epoch
    self.lr_scheduler_interval = lr_scheduler_interval
    self.lr_scheduler_frequency = lr_scheduler_frequency
    self.checkpoint_save_interval = checkpoint_save_interval
    self.checkpoint_save_frequency = checkpoint_save_frequency
    self.accumulate_grad_batches = accumulate_grad_batches
    self.gradient_clip_val = gradient_clip_val
    self.gradient_clip_algorithm = gradient_clip_algorithm
    self.save_optimizer_state = save_optimizer_state
    self.save_full_model = save_full_model
    self.save_ckpt_type = save_ckpt_type
    self.ckpt_path = ckpt_path
    self.max_length = max_length
    self.fix_token_embedding = fix_token_embedding
    super().__init__(**kwargs)
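
For orientation, a hypothetical keyword-argument set for this constructor; the Hydra-style `_target_` entries and all values below are illustrative assumptions, not shipped defaults. A scheduler argument whose value is the string `"_T_max_"` is replaced by the expected total number of optimizer steps in `configure_optimizer` (shown above). In FusionBench these options are typically supplied through a YAML/Hydra config rather than constructed by hand.

config = {
    "optimizer": {"_target_": "torch.optim.AdamW", "lr": 1e-5},
    "lr_scheduler": {
        "_target_": "torch.optim.lr_scheduler.CosineAnnealingLR",
        "T_max": "_T_max_",  # placeholder resolved to expected_total_steps at setup time
    },
    "dataloader_kwargs": {"batch_size": 8, "num_workers": 4},
    "max_epochs": 1,
    "accumulate_grad_batches": 4,
    "gradient_clip_val": 1.0,
    "max_length": 4096,
}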

compute_loss(batch)

Maximize the likelihood of the winner over the loser using the Bradley-Terry model.

Parameters:

  • batch (Dict[str, Union[Tensor, Any]]) –

    A dictionary containing the input token ids and attention masks for the winner and loser.

Source code in fusion_bench/method/lm_finetune/bradley_terry_rm.py
def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
    """
    Maximize the likelihood of the winner over the loser using the Bradley-Terry model.

    Args:
        batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids and attention masks for the winner and loser.
    """
    batch_size = batch["input_ids"].size(0)
    assert batch_size % 2 == 0, "Batch size must be even."

    outputs = self.model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        use_cache=self.use_cache,
    )

    rewards = outputs[0]
    chosen_reward = rewards[: batch_size // 2]
    rejected_rewards = rewards[batch_size // 2 :]
    loss = -torch.log(torch.sigmoid(chosen_reward - rejected_rewards)).mean()

    return {
        "chosen_reward": chosen_reward,
        "rejected_reward": rejected_rewards,
        "loss": loss,
    }
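
The collate function used in `setup_data` places the chosen (winner) sequences in the first half of the batch and the rejected (loser) sequences in the second half, which is why `compute_loss` simply splits the reward vector down the middle. A toy numerical sketch of the resulting pairwise loss, with made-up reward values and a 1-D reward tensor for simplicity:

import torch

# rewards for a batch of 4 sequences: 2 chosen followed by 2 rejected
rewards = torch.tensor([1.2, 0.3, 0.4, -0.2])
batch_size = rewards.size(0)
chosen_reward = rewards[: batch_size // 2]    # tensor([1.2000, 0.3000])
rejected_reward = rewards[batch_size // 2 :]  # tensor([ 0.4000, -0.2000])

# Bradley-Terry negative log-likelihood: -log sigmoid(r_chosen - r_rejected)
loss = -torch.log(torch.sigmoid(chosen_reward - rejected_reward)).mean()
print(loss)  # approximately 0.4226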

LLM Fine-tuning with AdaMerging

LayerWiseAdaMergingForLlamaSFT

Bases: BaseAlgorithm, LightningFabricMixin, SimpleProfilerMixin

Source code in fusion_bench/method/adamerging/llama_adamerging.py
class LayerWiseAdaMergingForLlamaSFT(
    BaseAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):

    modelpool: CausalLMPool

    def __init__(
        self,
        seed: int,
        output_dir: str,
        optimizer: str,
        lr: float,
        sparsity_ratio: Optional[float],
        average_attntion: bool,
        start_layer_idx: Optional[Union[float, int]],
        init_values: float,
        init_weights_path: str,
        clamp_weights: bool,
        normalized_merging_weights: bool,
        max_steps: int,
        tie_weights: bool,
        strict: bool,
        dataloader_kwargs: bool,
        skip_training: bool = False,
        save_interval: int = None,
        save_merged_model: bool = True,
        **kwargs,
    ):
        R"""
        Layer-wise AdaMerging algorithm for Llama models.
        Unlike the original AdaMerging algorithm, which uses test-time adaptation to optimize an entropy loss, this algorithm optimizes the cross-entropy loss.

        Args:
            seed (int): random seed to set at the beginning of the run.
            output_dir (str): directory to save the merged model. If `None`, the log directory will be used.
            optimizer (str): optimizer to use for training.
            lr (float): learning rate for training.
            sparsity_ratio (Optional[float]): ratio of zero weights in the task vectors. If `None`, no sparsity is enforced.
            average_attntion (bool): whether to average attention weights.
            start_layer_idx (Optional[Union[float, int]]): index of the layer to start merging.
            init_values (float): initial value for the merging weights.
            init_weights_path (str): path to the initial merging weights.
            clamp_weights (bool): whether to clamp the merging weights.
            normalized_merging_weights (bool): whether to normalize the merging weights.
            max_steps (int): maximum number of training steps.
            tie_weights (bool): whether to tie the weights of the same layer.
            strict (bool): whether to enforce strict merging.
            dataloader_kwargs (bool): keyword arguments for dataloaders.
            skip_training (bool): whether to skip training.
            save_interval (int): interval to save the merging weights. If `None`, no intermediate weights are saved. The weights are saved to `{output_dir}/checkpoints/merging-weights_{step_idx}.ckpt`.
            save_merged_model (bool): whether to save the merged model. This will save the model to `{output_dir}/checkpoints/merged_model`.
        """
        self.seed = seed
        self.output_dir = output_dir
        self.optimizer = optimizer
        self.lr = lr
        self.sparsity_ratio = sparsity_ratio
        self.average_attntion = average_attntion
        self.start_layer_idx = start_layer_idx
        self.init_values = init_values
        self.init_weights_path = init_weights_path
        self.clamp_weights = clamp_weights
        self.max_steps = max_steps
        self.tie_weights = tie_weights
        self.strict = strict
        self.normalized_merging_weights = normalized_merging_weights
        self.dataloader_kwargs = dataloader_kwargs
        self.skip_training = skip_training
        self.save_interval = save_interval
        self.save_merged_model = save_merged_model
        super().__init__(**kwargs)

    def run(self, modelpool: CausalLMPool):
        """
        Run the algorithm.

        Args:
            modelpool (CausalLMPool): The pool of models to be merged.

        Returns:
            The merged model.
        """
        self.modelpool = modelpool
        fabric = self.fabric

        assert (
            modelpool.has_pretrained
        ), "Must be a pre-tarined model with name `_pretrained_` in the model pool."
        log.info(f"There are {len(modelpool)} expert models in the model pool.")

        fabric.seed_everything(self.seed)

        if self.output_dir is None:
            log.warning(
                f"`output_dir` is not specified, set to log directory {self.log_dir}."
            )
            self.output_dir = fabric.logger.log_dir
        if fabric.global_rank == 0:
            os.makedirs(self.output_dir, exist_ok=True)

        with self.profile("construct_layer_wise_merged_model"):
            module = self.construct_layer_wise_merged_model(modelpool)
            if fabric.is_global_zero:
                print_parameters(module)

        if not self.skip_training:
            module = self.train(module)

        model = merge_and_unload(module)
        if self.save_merged_model:
            merged_model_path = os.path.join(
                self.output_dir, "checkpoints", "merged_model"
            )
            if self.fabric.global_rank == 0:
                modelpool.load_tokenizer().save_pretrained(merged_model_path)
                model.save_pretrained(merged_model_path)
                print_parameters(model)
        return model

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: CausalLMPool):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which are stored in a `torch.Tensor` of shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_causal_lm = modelpool.load_model("_pretrained_")

        # we only merge the backbone
        pretrained_model = pretrained_causal_lm.model.layers
        finetuned_models = [
            modelpool.load_model(name).model.layers for name in modelpool.model_names
        ]

        if self.start_layer_idx is not None and isinstance(self.start_layer_idx, float):
            self.start_layer_idx = int(self.start_layer_idx * len(pretrained_model))

        if self.start_layer_idx is not None:
            for layer_idx, layer in enumerate(pretrained_model[: self.start_layer_idx]):
                pretrained_model[layer_idx] = simple_average(
                    [m[layer_idx] for m in finetuned_models],
                    base_module=pretrained_model[layer_idx],
                )
                pretrained_model[layer_idx].requires_grad_(False)

        if self.average_attntion:
            for layer_idx, layer in enumerate(pretrained_model):
                if layer_idx < self.start_layer_idx:
                    continue
                layer.self_attn = simple_average(
                    [m[layer_idx].self_attn for m in finetuned_models],
                    base_module=layer.self_attn,
                )
                layer.self_attn.requires_grad_(False)

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `init_weights_path` is provided
        for layer_idx, layer in enumerate(pretrained_model):
            if layer_idx < self.start_layer_idx:
                continue
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(filter(lambda p: p.requires_grad, layer.parameters()))
                ),
                init_values=self.init_values,
                dtype=get_dtype(layer),
            )

            module = LayerWiseMergedModel(
                layer_wise_weight=layer_wise_weight,
                pretrained_model=pretrained_model[layer_idx],
                finetuned_models=[m[layer_idx] for m in finetuned_models],
                clamp_weights=self.clamp_weights,
                tie_weights=self.tie_weights,
                strict=self.strict,
                sparsity_ratio=self.sparsity_ratio,
                normalized_merging_weights=self.normalized_merging_weights,
            )

            pretrained_causal_lm.model.layers[layer_idx] = module

        fix_other_parts(pretrained_causal_lm)
        return pretrained_causal_lm

    def configure_optimizer(self, module: nn.Module):
        if self.optimizer == "adam":
            optimizer = torch.optim.Adam(
                [p for p in module.parameters() if p.requires_grad], lr=self.lr
            )
            return {"optimizer": optimizer}
        else:
            raise ValueError(f"Unknown optimizer type {self.optimizer}")

    def train(self, causal_lm):
        fabric = self.fabric
        modelpool = self.modelpool

        with self.profile("load datasets and setup dataloaders"):
            train_datasets = {
                dataset_name: modelpool.load_train_dataset(dataset_name)
                for dataset_name in modelpool.train_dataset_names
            }
            train_loaders = {
                dataset_name: fabric.setup_dataloaders(
                    DataLoader(
                        dataset,
                        **self.dataloader_kwargs,
                        collate_fn=default_data_collator,
                    )
                )
                for dataset_name, dataset in train_datasets.items()
            }
            train_loader_iters = {
                dataset_name: iter(InfiniteDataLoader(loader))
                for dataset_name, loader in train_loaders.items()
            }

        optimizer = self.configure_optimizer(causal_lm)["optimizer"]
        causal_lm, optimizer = cast(
            Tuple[nn.Module, torch.optim.Optimizer],
            fabric.setup(causal_lm, optimizer),
        )

        causal_lm.train()
        merge_weights(causal_lm)

        self.save_state("init", causal_lm)

        assert len(train_datasets) > 0, "No training datasets are provided."
        for step_idx in tqdm(range(self.max_steps)):
            log_metrics = {}

            losses = []
            for dataset_name, dataloader in train_loader_iters.items():
                # compute loss
                inputs = next(dataloader)
                outputs = causal_lm(**inputs)

                losses.append(outputs.loss)

            if len(losses) > 1:
                total_loss = sum(losses)
            else:
                total_loss = losses[0]

            log_metrics["train/loss"] = total_loss.item()

            fabric.backward(total_loss)
            optimizer.step()
            optimizer.zero_grad()

            if (
                self.save_interval is not None
                and (step_idx + 1) % self.save_interval == 0
            ):
                self.save_state(step_idx=step_idx, causal_lm=causal_lm)

            merge_weights(causal_lm)

            self.fabric.log_dict(log_metrics, step=step_idx)

        self.save_state("latest", causal_lm)

        return causal_lm

    def save_state(self, step_idx: Union[int, str], causal_lm):
        """
        Save the merging weights of each layer. This method must be called on all processes.

        Args:
            step_idx (Union[int, str]): step index of the training.
            causal_lm (nn.Module): the model to save.
        """
        state = {}
        for layer_idx, layer in enumerate(causal_lm.model.layers):
            if isinstance(layer, LayerWiseMergedModel):
                state[f"layer_{layer_idx}"] = layer.merge_weight

        if self.fabric.is_global_zero:
            os.makedirs(os.path.join(self.output_dir, "checkpoints"), exist_ok=True)
        save_path = os.path.join(
            self.output_dir, "checkpoints", f"merging-weights_{step_idx}.ckpt"
        )
        if self.fabric.is_global_zero:
            log.info(f"Saving merging weights to {save_path}")
        self.fabric.save(save_path, state)

__init__(seed, output_dir, optimizer, lr, sparsity_ratio, average_attntion, start_layer_idx, init_values, init_weights_path, clamp_weights, normalized_merging_weights, max_steps, tie_weights, strict, dataloader_kwargs, skip_training=False, save_interval=None, save_merged_model=True, **kwargs)

Layer-wise AdaMerging algorithm for Llama models. Unlike the original AdaMerging algorithm, which uses test-time adaptation to optimize an entropy loss, this algorithm optimizes the cross-entropy loss.

Parameters:

  • seed (int) –

    random seed to set at the beginning of the run.

  • output_dir (str) –

    directory to save the merged model. If None, the log directory will be used.

  • optimizer (str) –

    optimizer to use for training.

  • lr (float) –

    learning rate for training.

  • sparsity_ratio (Optional[float]) –

    ratio of zero weights in the task vectors. If None, no sparsity is enforced.

  • average_attntion (bool) –

    whether to average attention weights.

  • start_layer_idx (Optional[Union[float, int]]) –

    index of the layer to start merging.

  • init_values (float) –

    initial value for the merging weights.

  • init_weights_path (str) –

    path to the initial merging weights.

  • clamp_weights (bool) –

    whether to clamp the merging weights.

  • normalized_merging_weights (bool) –

    whether to normalize the merging weights.

  • max_steps (int) –

    maximum number of training steps.

  • tie_weights (bool) –

    whether to tie the weights of the same layer.

  • strict (bool) –

    whether to enforce strict merging.

  • dataloader_kwargs (bool) –

    keyword arguments for dataloaders.

  • skip_training (bool, default: False ) –

    whether to skip training.

  • save_interval (int, default: None ) –

    interval to save the merging weights. If None, no intermediate weights are saved. The weights are saved to {output_dir}/checkpoints/merging-weights_{step_idx}.ckpt.

  • save_merged_model (bool, default: True ) –

    whether to save the merged model. This will save the model to {output_dir}/checkpoints/merged_model.

Source code in fusion_bench/method/adamerging/llama_adamerging.py
def __init__(
    self,
    seed: int,
    output_dir: str,
    optimizer: str,
    lr: float,
    sparsity_ratio: Optional[float],
    average_attntion: bool,
    start_layer_idx: Optional[Union[float, int]],
    init_values: float,
    init_weights_path: str,
    clamp_weights: bool,
    normalized_merging_weights: bool,
    max_steps: int,
    tie_weights: bool,
    strict: bool,
    dataloader_kwargs: bool,
    skip_training: bool = False,
    save_interval: int = None,
    save_merged_model: bool = True,
    **kwargs,
):
    R"""
    Layer-wise AdaMerging algorithm for Llama models.
    Unlike the original AdaMerging algorithm, which uses test-time adaptation to optimize an entropy loss, this algorithm optimizes the cross-entropy loss.

    Args:
        seed (int): random seed to set at the beginning of the run.
        output_dir (str): directory to save the merged model. If `None`, the log directory will be used.
        optimizer (str): optimizer to use for training.
        lr (float): learning rate for training.
        sparsity_ratio (Optional[float]): ratio of zero weights in the task vectors. If `None`, no sparsity is enforced.
        average_attntion (bool): whether to average attention weights.
        start_layer_idx (Optional[Union[float, int]]): index of the layer to start merging.
        init_values (float): initial value for the merging weights.
        init_weights_path (str): path to the initial merging weights.
        clamp_weights (bool): whether to clamp the merging weights.
        normalized_merging_weights (bool): whether to normalize the merging weights.
        max_steps (int): maximum number of training steps.
        tie_weights (bool): whether to tie the weights of the same layer.
        strict (bool): whether to enforce strict merging.
        dataloader_kwargs (bool): keyword arguments for dataloaders.
        skip_training (bool): whether to skip training.
        save_interval (int): interval to save the merging weights. If `None`, no intermediate weights are saved. The weights are saved to `{output_dir}/checkpoints/merging-weights_{step_idx}.ckpt`.
        save_merged_model (bool): whether to save the merged model. This will save the model to `{output_dir}/checkpoints/merged_model`.
    """
    self.seed = seed
    self.output_dir = output_dir
    self.optimizer = optimizer
    self.lr = lr
    self.sparsity_ratio = sparsity_ratio
    self.average_attntion = average_attntion
    self.start_layer_idx = start_layer_idx
    self.init_values = init_values
    self.init_weights_path = init_weights_path
    self.clamp_weights = clamp_weights
    self.max_steps = max_steps
    self.tie_weights = tie_weights
    self.strict = strict
    self.normalized_merging_weights = normalized_merging_weights
    self.dataloader_kwargs = dataloader_kwargs
    self.skip_training = skip_training
    self.save_interval = save_interval
    self.save_merged_model = save_merged_model
    super().__init__(**kwargs)
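
A hypothetical instantiation of the class; every argument value below is illustrative rather than a recommended setting, and in practice these options would normally come from a FusionBench YAML/Hydra config. Note that `configure_optimizer` currently only accepts `optimizer="adam"`, and a float `start_layer_idx` is interpreted as a fraction of the decoder layers.

from fusion_bench.method.adamerging.llama_adamerging import LayerWiseAdaMergingForLlamaSFT

algorithm = LayerWiseAdaMergingForLlamaSFT(
    seed=42,
    output_dir="outputs/llama_adamerging",
    optimizer="adam",                  # the only optimizer handled by configure_optimizer
    lr=1e-3,
    sparsity_ratio=None,
    average_attntion=True,
    start_layer_idx=0.5,               # the first half of the layers is simply averaged and frozen
    init_values=0.3,
    init_weights_path=None,
    clamp_weights=False,
    normalized_merging_weights=False,
    max_steps=500,
    tie_weights=True,
    strict=False,
    dataloader_kwargs={"batch_size": 1, "num_workers": 2},
    save_interval=100,
)
# `modelpool` is assumed to be a CausalLMPool constructed elsewhere (not shown)
# merged_model = algorithm.run(modelpool)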

construct_layer_wise_merged_model(modelpool)

Constructs a wrapped layer-wise merged model from model pool.

This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models. The merging is controlled by layer-wise weights, which are stored in a torch.Tensor of shape (num_models, num_layers). The merging weights can be initialized based on a provided configuration or loaded from a file.

Parameters:

  • modelpool (ModelPool) –

    An object containing the pretrained model and fine-tuned models to be merged.

Returns:

  • LayerWiseMergedModel

    An instance of the merged model with layer-wise weights applied.

Source code in fusion_bench/method/adamerging/llama_adamerging.py
@torch.no_grad()
def construct_layer_wise_merged_model(self, modelpool: CausalLMPool):
    """
    Constructs a wrapped layer-wise merged model from model pool.

    This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which are stored in a `torch.Tensor` of shape `(num_models, num_layers)`.
    The merging weights can be initialized based on a provided configuration or loaded from a file.

    Args:
        modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

    Returns:
        LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
    """
    pretrained_causal_lm = modelpool.load_model("_pretrained_")

    # we only merge the backbone
    pretrained_model = pretrained_causal_lm.model.layers
    finetuned_models = [
        modelpool.load_model(name).model.layers for name in modelpool.model_names
    ]

    if self.start_layer_idx is not None and isinstance(self.start_layer_idx, float):
        self.start_layer_idx = int(self.start_layer_idx * len(pretrained_model))

    if self.start_layer_idx is not None:
        for layer_idx, layer in enumerate(pretrained_model[: self.start_layer_idx]):
            pretrained_model[layer_idx] = simple_average(
                [m[layer_idx] for m in finetuned_models],
                base_module=pretrained_model[layer_idx],
            )
            pretrained_model[layer_idx].requires_grad_(False)

    if self.average_attntion:
        for layer_idx, layer in enumerate(pretrained_model):
            if layer_idx < self.start_layer_idx:
                continue
            layer.self_attn = simple_average(
                [m[layer_idx].self_attn for m in finetuned_models],
                base_module=layer.self_attn,
            )
            layer.self_attn.requires_grad_(False)

    # initialize layer-wise weights using the provided configuration `init_values` or load from file if `init_weights_path` is provided
    for layer_idx, layer in enumerate(pretrained_model):
        if layer_idx < self.start_layer_idx:
            continue
        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(filter(lambda p: p.requires_grad, layer.parameters()))
            ),
            init_values=self.init_values,
            dtype=get_dtype(layer),
        )

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model[layer_idx],
            finetuned_models=[m[layer_idx] for m in finetuned_models],
            clamp_weights=self.clamp_weights,
            tie_weights=self.tie_weights,
            strict=self.strict,
            sparsity_ratio=self.sparsity_ratio,
            normalized_merging_weights=self.normalized_merging_weights,
        )

        pretrained_causal_lm.model.layers[layer_idx] = module

    fix_other_parts(pretrained_causal_lm)
    return pretrained_causal_lm
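
`LayerWiseMergedModel` itself is not documented on this page, so the sketch below only illustrates the layer-wise task-arithmetic interpolation that AdaMerging-style wrappers perform, i.e. merged = pretrained + Σ_k w_k · (finetuned_k − pretrained) with one learnable weight per expert model and per trainable tensor; treat it as an assumption about the wrapper's behavior, not its exact implementation.

import torch

num_layers = 32
start_layer_idx = 0.5                                # a float is treated as a fraction of the layers
start_layer_idx = int(start_layer_idx * num_layers)  # -> 16: layers 0..15 are simply averaged and frozen

# toy merge of a single weight tensor from two expert models
pretrained = torch.zeros(4, 4)
finetuned = [torch.randn(4, 4), torch.randn(4, 4)]
w = torch.tensor([0.3, 0.3])                         # learnable merging weights, one per expert

task_vectors = [ft - pretrained for ft in finetuned]
merged = pretrained + sum(w_k * tv for w_k, tv in zip(w, task_vectors))
print(merged.shape)  # torch.Size([4, 4])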

run(modelpool)

Run the algorithm.

Parameters:

  • modelpool (CausalLMPool) –

    The pool of models to be merged.

Returns:

  • The merged model.

Source code in fusion_bench/method/adamerging/llama_adamerging.py
def run(self, modelpool: CausalLMPool):
    """
    Run the algorithm.

    Args:
        modelpool (CausalLMPool): The pool of models to be merged.

    Returns:
        The merged model.
    """
    self.modelpool = modelpool
    fabric = self.fabric

    assert (
        modelpool.has_pretrained
    ), "Must be a pre-tarined model with name `_pretrained_` in the model pool."
    log.info(f"There are {len(modelpool)} expert models in the model pool.")

    fabric.seed_everything(self.seed)

    if self.output_dir is None:
        log.warning(
            f"`output_dir` is not specified, set to log directory {self.log_dir}."
        )
        self.output_dir = fabric.logger.log_dir
    if fabric.global_rank == 0:
        os.makedirs(self.output_dir, exist_ok=True)

    with self.profile("construct_layer_wise_merged_model"):
        module = self.construct_layer_wise_merged_model(modelpool)
        if fabric.is_global_zero:
            print_parameters(module)

    if not self.skip_training:
        module = self.train(module)

    model = merge_and_unload(module)
    if self.save_merged_model:
        merged_model_path = os.path.join(
            self.output_dir, "checkpoints", "merged_model"
        )
        if self.fabric.global_rank == 0:
            modelpool.load_tokenizer().save_pretrained(merged_model_path)
            model.save_pretrained(merged_model_path)
            print_parameters(model)
    return model

save_state(step_idx, causal_lm)

Save the merging weights of each layer. This method must be called on all processes.

Parameters:

  • step_idx (Union[int, str]) –

    step index of the training.

  • causal_lm (Module) –

    the model to save.

Source code in fusion_bench/method/adamerging/llama_adamerging.py
def save_state(self, step_idx: Union[int, str], causal_lm):
    """
    Save the merging weights of each layer. This method must be called on all processes.

    Args:
        step_idx (Union[int, str]): step index of the training.
        causal_lm (nn.Module): the model to save.
    """
    state = {}
    for layer_idx, layer in enumerate(causal_lm.model.layers):
        if isinstance(layer, LayerWiseMergedModel):
            state[f"layer_{layer_idx}"] = layer.merge_weight

    if self.fabric.is_global_zero:
        os.makedirs(os.path.join(self.output_dir, "checkpoints"), exist_ok=True)
    save_path = os.path.join(
        self.output_dir, "checkpoints", f"merging-weights_{step_idx}.ckpt"
    )
    if self.fabric.is_global_zero:
        log.info(f"Saving merging weights to {save_path}")
    self.fabric.save(save_path, state)
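
The checkpoints written by `save_state` contain one tensor of merging weights per merged layer, stored under keys of the form `layer_{layer_idx}`. Below is a sketch for inspecting such a checkpoint after training, assuming a single-device or DDP run where Fabric writes a regular checkpoint file; the path follows the `merging-weights_{step_idx}.ckpt` naming used above but is otherwise illustrative.

from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
fabric.launch()

# passing no `state` makes Fabric.load return the full checkpoint contents
state = fabric.load("outputs/llama_adamerging/checkpoints/merging-weights_latest.ckpt")
for key, merge_weight in state.items():
    print(key, tuple(merge_weight.shape))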