Skip to content

Model Training/Fine-Tuning

CLIP vision model fine-tuning

ImageClassificationFineTuningForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, ModelFusionAlgorithm

A class for fine-tuning CLIP models for image classification tasks.

Source code in fusion_bench/method/classification/clip_finetune.py
class ImageClassificationFineTuningForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """
    A class for fine-tuning CLIP models for image classification tasks.

    Only the parameters of the CLIP vision model that require gradients are
    handed to the optimizer. Training is multi-task: every optimisation step
    draws one batch from each training dataset and back-propagates the sum of
    the per-task cross-entropy losses.
    """

    def run(self, modelpool: CLIPVisionModelPool):
        """
        Executes the fine-tuning process.

        Args:
            modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.

        Returns:
            VisionModel: The fine-tuned vision model. When
                `config.state_dict_load_path` is set and `config.skip_training`
                is true, the loaded vision model is returned without training.
        """
        # Use the converted pool for every later access so that `run` and
        # `setup_model` (which reads `self.modelpool`) work on the same object.
        modelpool = self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")
        # placeholder label; `setup_model` overwrites it with the actual method
        self.finetune_method = "fine-tune"

        L.seed_everything(config.seed)

        task_names = modelpool.train_dataset_names
        with self.profile("setup model and optimizer"):
            processor, classifier, optimizer, lr_scheduler = self.setup_model()

            if config.state_dict_load_path is not None:
                self.fabric.load(
                    config.state_dict_load_path,
                    {"vision_model": classifier.clip_model.vision_model},
                )
                if config.skip_training:
                    return classifier.clip_model.vision_model

            self.setup_zero_shot_classification_head(
                clip_processor=processor,
                clip_model=classifier.clip_model,
                task_names=task_names,
            )

            # NOTE(review): the objects returned by fabric.setup() are
            # discarded, so the loop below trains through the original
            # `classifier` and `optimizer` references -- confirm this is
            # intended when running with multi-device strategies.
            self.fabric.setup(classifier, optimizer)

        with self.profile("setup data"):
            train_datasets = [
                CLIPDataset(modelpool.load_train_dataset(task_name), processor)
                for task_name in task_names
            ]
            train_dataloaders = [
                DataLoader(
                    dataset,
                    shuffle=True,
                    batch_size=config.batch_size,
                    num_workers=config.num_workers,
                )
                for dataset in train_datasets
            ]
            train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
            # a single dataloader is returned bare rather than in a sequence
            if not isinstance(train_dataloaders, (list, tuple)):
                train_dataloaders = [train_dataloaders]
            train_dataloader_iters = [
                iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
            ]

        # train
        for step_idx in tqdm(
            range(config.num_steps),
            desc=self.finetune_method,
            disable=not self.fabric.is_global_zero,
            dynamic_ncols=True,
        ):
            optimizer.zero_grad()
            # accumulate the loss over one batch of every task
            loss = 0
            for task, loader in zip(task_names, train_dataloader_iters):
                with self.profile("data loading"):
                    batch = next(loader)
                    images, labels = batch
                with self.profile("forward"):
                    # switch the classification head to the current task
                    classifier.zeroshot_weights = self.zeroshot_weights[task]
                    logits = classifier(images)
                loss = loss + nn.functional.cross_entropy(logits, labels)

            with self.profile("backward"):
                self.fabric.backward(loss)
            with self.profile("optimizer step"):
                optimizer.step()
                lr_scheduler.step()

            metrics = {"train/loss": loss}

            self.fabric.log_dict(metrics, step=step_idx)

            # the file name carries the zero-based index of the step just run
            if (step_idx + 1) % config.save_interval == 0:
                save_path = os.path.join(
                    self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
                )
                self.save_model(classifier, save_path)

        if config.state_dict_save_path is not None:
            self.save_model(classifier, config.state_dict_save_path)
        self.print_profile_summary()
        return classifier.clip_model.vision_model

    def save_model(
        self,
        model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
        save_path: str,
    ):
        """
        Save the vision model to the specified path.

        The vision model is extracted from `model` and written through
        `self.fabric.save` under the key "vision_model". The parent directory
        of `save_path` is created when it does not exist yet.

        Args:
            model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
            save_path (str): The path to save the model.

        Raises:
            ValueError: If `model` is none of the supported types.
        """
        if isinstance(model, HFCLIPClassifier):
            vision_model = model.clip_model.vision_model
        elif isinstance(model, (CLIPModel, CLIPVisionModel)):
            # both wrappers expose the vision model as `.vision_model`
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionTransformer):
            vision_model = model
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        # `save_dir` is empty when `save_path` has no directory component
        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        self.fabric.save(save_path, {"vision_model": vision_model})

    def setup_model(self):
        """
        Sets up the model, optimizer, and learning rate scheduler.

        This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.
        It also sets `self.finetune_method` to "full fine-tune",
        "lora fine-tune" or "l-lora fine-tune" depending on the configuration.

        Returns:
            Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
        """
        config = self.config
        modelpool = self.modelpool

        clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
        processor = modelpool.load_processor()

        self.finetune_method = "full fine-tune"
        if config.use_lora or config.use_l_lora:
            self.finetune_method = "lora fine-tune"
            lora_config = LoraConfig(
                **OmegaConf.to_container(
                    config.lora_config, resolve=True, enum_to_str=True
                )
            )
            # only the vision model is wrapped with LoRA adapters
            clip_model.vision_model = get_peft_model(
                clip_model.vision_model, lora_config
            )

            if config.use_l_lora:
                # http://arxiv.org/abs/2310.04742
                # Anke Tang et al. Parameter Efficient Multi-task Model Fusion with Partial Linearization. ICLR 2024.
                self.finetune_method = "l-lora fine-tune"
                print("Linearizing Lora Layers")
                linearize_lora_model_(clip_model.vision_model)

        classifier = HFCLIPClassifier(clip_model, processor=processor)

        if self.fabric.is_global_zero:
            print("=== Model Summary (For Vision Model Only) ===")
            print_parameters(classifier.clip_model.vision_model)
        # configure optimizers: only trainable vision-model parameters are optimized
        optimizer = torch.optim.Adam(
            [
                p
                for p in classifier.clip_model.vision_model.parameters()
                if p.requires_grad
            ],
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        # one cosine cycle spread over the whole training run
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=config.num_steps
        )

        return processor, classifier, optimizer, lr_scheduler

run(modelpool)

Executes the fine-tuning process.

Parameters:

  • modelpool (CLIPVisionModelPool) –

    The modelpool is responsible for loading the pre-trained model and training datasets.

Returns:

  • VisionModel

    The fine-tuned vision model.

Source code in fusion_bench/method/classification/clip_finetune.py
def run(self, modelpool: CLIPVisionModelPool):
    """
    Executes the fine-tuning process.

    Args:
        modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.

    Returns:
        VisionModel: The fine-tuned vision model. When
            `config.state_dict_load_path` is set and `config.skip_training`
            is true, the loaded vision model is returned without training.
    """
    # Use the converted pool for every later access so that `run` and
    # `setup_model` (which reads `self.modelpool`) work on the same object.
    modelpool = self.modelpool = to_modelpool(modelpool)
    config = self.config
    self.log_hyperparams(config, filename="method_config.yaml")
    # placeholder label; `setup_model` overwrites it with the actual method
    self.finetune_method = "fine-tune"

    L.seed_everything(config.seed)

    task_names = modelpool.train_dataset_names
    with self.profile("setup model and optimizer"):
        processor, classifier, optimizer, lr_scheduler = self.setup_model()

        if config.state_dict_load_path is not None:
            self.fabric.load(
                config.state_dict_load_path,
                {"vision_model": classifier.clip_model.vision_model},
            )
            if config.skip_training:
                return classifier.clip_model.vision_model

        self.setup_zero_shot_classification_head(
            clip_processor=processor,
            clip_model=classifier.clip_model,
            task_names=task_names,
        )

        # NOTE(review): the objects returned by fabric.setup() are
        # discarded, so the loop below trains through the original
        # `classifier` and `optimizer` references -- confirm this is
        # intended when running with multi-device strategies.
        self.fabric.setup(classifier, optimizer)

    with self.profile("setup data"):
        train_datasets = [
            CLIPDataset(modelpool.load_train_dataset(task_name), processor)
            for task_name in task_names
        ]
        train_dataloaders = [
            DataLoader(
                dataset,
                shuffle=True,
                batch_size=config.batch_size,
                num_workers=config.num_workers,
            )
            for dataset in train_datasets
        ]
        train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
        # a single dataloader is returned bare rather than in a sequence
        if not isinstance(train_dataloaders, (list, tuple)):
            train_dataloaders = [train_dataloaders]
        train_dataloader_iters = [
            iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
        ]

    # train
    for step_idx in tqdm(
        range(config.num_steps),
        desc=self.finetune_method,
        disable=not self.fabric.is_global_zero,
        dynamic_ncols=True,
    ):
        optimizer.zero_grad()
        # accumulate the loss over one batch of every task
        loss = 0
        for task, loader in zip(task_names, train_dataloader_iters):
            with self.profile("data loading"):
                batch = next(loader)
                images, labels = batch
            with self.profile("forward"):
                # switch the classification head to the current task
                classifier.zeroshot_weights = self.zeroshot_weights[task]
                logits = classifier(images)
            loss = loss + nn.functional.cross_entropy(logits, labels)

        with self.profile("backward"):
            self.fabric.backward(loss)
        with self.profile("optimizer step"):
            optimizer.step()
            lr_scheduler.step()

        metrics = {"train/loss": loss}

        self.fabric.log_dict(metrics, step=step_idx)

        # the file name carries the zero-based index of the step just run
        if (step_idx + 1) % config.save_interval == 0:
            save_path = os.path.join(
                self.log_dir, "checkpoints", f"step={step_idx}.ckpt"
            )
            self.save_model(classifier, save_path)

    if config.state_dict_save_path is not None:
        self.save_model(classifier, config.state_dict_save_path)
    self.print_profile_summary()
    return classifier.clip_model.vision_model

save_model(model, save_path)

Save the vision model to the specified path.

Parameters:

  • model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]) –

    The model to save.

  • save_path (str) –

    The path to save the model.

Source code in fusion_bench/method/classification/clip_finetune.py
def save_model(
    self,
    model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
    save_path: str,
):
    """
    Persist the vision part of `model` at `save_path`.

    The vision model is extracted from `model` and written through
    `self.fabric.save` under the key "vision_model". The parent directory
    of `save_path` is created when it does not exist yet.

    Args:
        model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
        save_path (str): The path to save the model.

    Raises:
        ValueError: If `model` is none of the supported types.
    """
    if isinstance(model, HFCLIPClassifier):
        vision_model = model.clip_model.vision_model
    elif isinstance(model, (CLIPModel, CLIPVisionModel)):
        # both wrappers expose the vision model as `.vision_model`
        vision_model = model.vision_model
    elif isinstance(model, CLIPVisionTransformer):
        vision_model = model
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    # `parent_dir` is empty when `save_path` has no directory component
    parent_dir = os.path.dirname(save_path)
    needs_directory = bool(parent_dir) and not os.path.exists(parent_dir)
    if needs_directory:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {"vision_model": vision_model}
    self.fabric.save(save_path, checkpoint)

setup_model()

Sets up the model, optimizer, and learning rate scheduler.

This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

Returns:

  • Tuple

    A tuple containing the processor, classifier, optimizer, and learning rate scheduler.

Source code in fusion_bench/method/classification/clip_finetune.py
def setup_model(self):
    """
    Build the processor, classifier, optimizer and learning rate scheduler.

    Loads the "_pretrained_" CLIP model and its processor from
    `self.modelpool`, wraps the vision model with LoRA when `use_lora` or
    `use_l_lora` is enabled (linearizing the LoRA layers for the latter),
    and creates an Adam optimizer with a cosine-annealing schedule over the
    trainable vision-model parameters. `self.finetune_method` is set to
    "full fine-tune", "lora fine-tune" or "l-lora fine-tune" accordingly.

    Returns:
        Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
    """
    cfg = self.config
    pool = self.modelpool

    clip_model: CLIPModel = pool.load_clip_model("_pretrained_")
    processor = pool.load_processor()

    self.finetune_method = "full fine-tune"
    if cfg.use_lora or cfg.use_l_lora:
        self.finetune_method = "lora fine-tune"
        lora_kwargs = OmegaConf.to_container(
            cfg.lora_config, resolve=True, enum_to_str=True
        )
        lora_config = LoraConfig(**lora_kwargs)
        # only the vision model is wrapped with LoRA adapters
        peft_vision_model = get_peft_model(clip_model.vision_model, lora_config)
        clip_model.vision_model = peft_vision_model

        if cfg.use_l_lora:
            # http://arxiv.org/abs/2310.04742
            # Anke Tang et al. Parameter Efficient Multi-task Model Fusion with Partial Linearization. ICLR 2024.
            self.finetune_method = "l-lora fine-tune"
            print("Linearizing Lora Layers")
            linearize_lora_model_(clip_model.vision_model)

    classifier = HFCLIPClassifier(clip_model, processor=processor)
    vision_model = classifier.clip_model.vision_model

    if self.fabric.is_global_zero:
        print("=== Model Summary (For Vision Model Only) ===")
        print_parameters(vision_model)

    # configure optimizers: only trainable vision-model parameters are optimized
    trainable_params = [p for p in vision_model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(
        trainable_params,
        lr=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
    )
    # one cosine cycle spread over the whole training run
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer,
        T_max=cfg.num_steps,
    )

    return processor, classifier, optimizer, lr_scheduler

ContinualImageClassificationFineTuningForCLIP

Bases: CLIPClassificationMixin, SimpleProfilerMixin, BaseAlgorithm

Source code in fusion_bench/method/classification/continual_clip_finetune.py
class ContinualImageClassificationFineTuningForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    """
    Fine-tune a CLIP vision model on a sequence of tasks, one task at a time.

    The tasks (optionally shuffled) are visited in order. For every task the
    optimizer and learning rate scheduler are reset to their initial state and
    the model is trained for `num_steps` steps. When the running program has a
    `CLIPVisionModelTaskPool`, the model is evaluated after each task on the
    test datasets of all tasks seen so far.
    """

    # attributes to configuration keys mapping
    _config_mapping = BaseAlgorithm._config_mapping | {
        "seed": "seed",
        "shuffle_order": "shuffle_order",
        "learning_rate": "learning_rate",
        "weight_decay": "weight_decay",
        "num_steps": "num_steps",
        "batch_size": "batch_size",
        "num_workers": "num_workers",
        "save_interval": "save_interval",
        "state_dict_load_path": "state_dict_load_path",
        "state_dict_save_path": "state_dict_save_path",
        "skip_training": "skip_training",
        "use_lora": "use_lora",
        "lora_config": "lora_config",
    }

    def __init__(
        self,
        seed: int = 42,
        shuffle_order: bool = True,
        learning_rate: float = 1e-5,
        weight_decay: float = 0,
        num_steps: int = 4000,
        batch_size: int = 128,
        num_workers: int = 16,
        save_interval: int = 500,
        state_dict_load_path: Optional[str] = None,
        state_dict_save_path: Optional[str] = None,
        skip_training: bool = False,
        use_lora: bool = False,
        lora_config: Optional[LoraConfig] = None,
    ):
        """
        Store the hyper-parameters of the continual fine-tuning run.

        Args:
            seed (int): Random seed; when `None`, `run` seeds by time instead.
            shuffle_order (bool): Whether `run` shuffles the task order.
            learning_rate (float): Learning rate of the Adam optimizer.
            weight_decay (float): Weight decay of the Adam optimizer.
            num_steps (int): Number of training steps per task.
            batch_size (int): Batch size of the training dataloaders.
            num_workers (int): Number of workers of the training dataloaders.
            save_interval (int): A checkpoint is saved every `save_interval` steps within a task.
            state_dict_load_path (Optional[str]): If set, the vision model state is loaded from this path before training.
            state_dict_save_path (Optional[str]): If set, the vision model is saved to this path after training.
            skip_training (bool): If true and `state_dict_load_path` is set, return the loaded model without training.
            use_lora (bool): Whether to wrap the vision model with LoRA.
            lora_config (Optional[LoraConfig]): LoRA configuration, used when `use_lora` is true.
        """
        self.seed = seed
        self.shuffle_order = shuffle_order
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.save_interval = save_interval
        self.state_dict_load_path = state_dict_load_path
        self.state_dict_save_path = state_dict_save_path
        self.skip_training = skip_training
        self.use_lora = use_lora
        self.lora_config = lora_config

    def run(self, modelpool: CLIPVisionModelPool):
        """
        Executes the continual fine-tuning process.

        Args:
            modelpool (CLIPVisionModelPool): The modelpool is responsible for loading the pre-trained model and training datasets.

        Returns:
            The vision model after training on all tasks. When
            `state_dict_load_path` is set and `skip_training` is true, the
            loaded vision model is returned without training.
        """
        # Use the converted pool for every later access so that `run` and
        # `setup_model` (which reads `self.modelpool`) work on the same object.
        modelpool = self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")
        # placeholder label; `setup_model` overwrites it with the actual method
        self.finetune_method = "fine-tune"

        if self.seed is not None:
            L.seed_everything(self.seed)
        else:
            seed_everything_by_time(self.fabric)

        # copy before shuffling so the pool's own list is left untouched
        task_names = list(modelpool.train_dataset_names)
        if self.shuffle_order:
            random.shuffle(task_names)
        if self.fabric.is_global_zero:
            save_to_json(task_names, os.path.join(self.log_dir, "task_names.json"))

        # Per-task evaluation is only possible with a CLIP vision task pool.
        taskpool = None
        test_datasets = None
        has_taskpool = isinstance(self._program.taskpool, CLIPVisionModelTaskPool)
        if has_taskpool:
            taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
            test_datasets = taskpool._test_datasets

        with self.profile("setup model and optimizer"):
            processor, classifier, optimizer, lr_scheduler = self.setup_model()

            if self.state_dict_load_path is not None:
                self.fabric.load(
                    self.state_dict_load_path,
                    {"vision_model": classifier.clip_model.vision_model},
                )
                if self.skip_training:
                    return classifier.clip_model.vision_model

            self.setup_zero_shot_classification_head(
                clip_processor=processor,
                clip_model=classifier.clip_model,
                task_names=task_names,
            )

            # Snapshot the pristine optimizer/scheduler state. Deep copies are
            # taken defensively so that later training cannot alter the
            # snapshot that is reloaded at the start of every task.
            init_optimizer_state_dict = deepcopy(optimizer.state_dict())
            init_lr_scheduler_state_dict = deepcopy(lr_scheduler.state_dict())
            # NOTE(review): the objects returned by fabric.setup() are
            # discarded, so the loop below trains through the original
            # `classifier` and `optimizer` references -- confirm this is
            # intended when running with multi-device strategies.
            self.fabric.setup(classifier, optimizer)

        with self.profile("setup data"):
            train_datasets = [
                CLIPDataset(modelpool.load_train_dataset(task_name), processor)
                for task_name in task_names
            ]
            train_dataloaders = [
                DataLoader(
                    dataset,
                    shuffle=True,
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                )
                for dataset in train_datasets
            ]
            train_dataloaders = self.fabric.setup_dataloaders(*train_dataloaders)
            # a single dataloader is returned bare rather than in a sequence
            if not isinstance(train_dataloaders, (list, tuple)):
                train_dataloaders = [train_dataloaders]
            train_dataloader_iters = [
                iter(InfiniteDataLoader(loader)) for loader in train_dataloaders
            ]

        # continual train
        for task_idx, task_name in tqdm(
            enumerate(task_names),
            dynamic_ncols=True,
            disable=not self.fabric.is_global_zero,
        ):
            train_dataloader_iter = train_dataloader_iters[task_idx]

            # reset optimizer and lr scheduler
            print("reset optimizer and lr scheduler")
            optimizer.load_state_dict(init_optimizer_state_dict)
            # hand the scheduler its own copy so it never shares objects with
            # the snapshot that is reused for the next task
            lr_scheduler.load_state_dict(deepcopy(init_lr_scheduler_state_dict))

            for step_idx in tqdm(
                range(self.num_steps),
                desc=f"continual fine-tune on {task_name}",
                disable=not self.fabric.is_global_zero,
                dynamic_ncols=True,
                leave=False,
            ):
                optimizer.zero_grad()
                with self.profile("data loading"):
                    batch = next(train_dataloader_iter)
                    images, labels = batch
                with self.profile("forward"):
                    # switch the classification head to the current task
                    classifier.zeroshot_weights = self.zeroshot_weights[task_name]
                    logits = classifier(images)
                    assert (
                        labels.max() < logits.shape[1]
                    ), f"for task {task_name}, labels.max() = {labels.max()}, logits.shape[1] = {logits.shape[1]}"
                loss = nn.functional.cross_entropy(logits, labels)

                with self.profile("backward"):
                    self.fabric.backward(loss)
                with self.profile("optimizer step"):
                    optimizer.step()
                    lr_scheduler.step()

                # `step_idx` restarts at 0 for every task; log against a step
                # counter that keeps increasing across tasks so that the
                # curves of different tasks do not land on the same steps.
                global_step = task_idx * self.num_steps + step_idx
                metrics = {"train/loss": loss}
                self.fabric.log_dict(metrics, step=global_step)

                # the file name carries the zero-based index of the step just run
                if (step_idx + 1) % self.save_interval == 0:
                    save_path = os.path.join(
                        self.log_dir,
                        "checkpoints",
                        f"task={task_idx}_step={step_idx}.ckpt",
                    )
                    self.save_model(classifier, save_path)

            if has_taskpool:
                # restrict evaluation to the tasks trained on so far and mark
                # the task pool as not set up before evaluating
                taskpool._is_setup = False
                taskpool._test_datasets = DictConfig(
                    {t: test_datasets[t] for t in task_names[: task_idx + 1]}
                )
                eval_report = taskpool.evaluate(
                    deepcopy(classifier.clip_model.vision_model),
                    name=task_name,
                )
                if self.fabric.is_global_zero:
                    save_to_json(
                        eval_report,
                        os.path.join(self.log_dir, f"results_{task_idx}.json"),
                    )

        if self.state_dict_save_path is not None:
            self.save_model(classifier, self.state_dict_save_path)
        self.print_profile_summary()
        return classifier.clip_model.vision_model

    def save_model(
        self,
        model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
        save_path: str,
    ):
        """
        Save the vision model to the specified path.

        The vision model is extracted from `model` and written through
        `self.fabric.save` under the key "vision_model". The parent directory
        of `save_path` is created when it does not exist yet.

        Args:
            model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
            save_path (str): The path to save the model.

        Raises:
            ValueError: If `model` is none of the supported types.
        """
        if isinstance(model, HFCLIPClassifier):
            vision_model = model.clip_model.vision_model
        elif isinstance(model, (CLIPModel, CLIPVisionModel)):
            # both wrappers expose the vision model as `.vision_model`
            vision_model = model.vision_model
        elif isinstance(model, CLIPVisionTransformer):
            vision_model = model
        else:
            raise ValueError(f"Unsupported model type: {type(model)}")

        # `save_dir` is empty when `save_path` has no directory component
        save_dir = os.path.dirname(save_path)
        if save_dir and not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        self.fabric.save(save_path, {"vision_model": vision_model})

    def setup_model(self):
        """
        Sets up the model, optimizer, and learning rate scheduler.

        This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.
        It also sets `self.finetune_method` to "full fine-tune" or
        "lora fine-tune" depending on `use_lora`.

        Returns:
            Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
        """
        modelpool = self.modelpool

        clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
        processor = modelpool.load_processor()

        self.finetune_method = "full fine-tune"
        if self.use_lora:
            self.finetune_method = "lora fine-tune"
            if isinstance(self.lora_config, LoraConfig):
                # honour the constructor annotation: a ready-made LoraConfig
                # is used as-is
                lora_config = self.lora_config
            else:
                lora_config = LoraConfig(
                    **OmegaConf.to_container(
                        self.lora_config, resolve=True, enum_to_str=True
                    )
                )
            # only the vision model is wrapped with LoRA adapters
            clip_model.vision_model = get_peft_model(
                clip_model.vision_model, lora_config
            )

        classifier = HFCLIPClassifier(clip_model, processor=processor)

        if self.fabric.is_global_zero:
            print("=== Model Summary (For Vision Model Only) ===")
            print_parameters(classifier.clip_model.vision_model)
        # configure optimizers: only trainable vision-model parameters are optimized
        optimizer = torch.optim.Adam(
            [
                p
                for p in classifier.clip_model.vision_model.parameters()
                if p.requires_grad
            ],
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        # one cosine cycle per task (`run` resets the scheduler for each task)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.num_steps
        )

        return processor, classifier, optimizer, lr_scheduler

save_model(model, save_path)

Save the vision model to the specified path.

Parameters:

  • model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]) –

    The model to save.

  • save_path (str) –

    The path to save the model.

Source code in fusion_bench/method/classification/continual_clip_finetune.py
def save_model(
    self,
    model: HFCLIPClassifier | CLIPModel | CLIPVisionModel | CLIPVisionTransformer,
    save_path: str,
):
    """
    Persist the vision part of `model` at `save_path`.

    The vision model is extracted from `model` and written through
    `self.fabric.save` under the key "vision_model". The parent directory
    of `save_path` is created when it does not exist yet.

    Args:
        model (Union[HFCLIPClassifier, CLIPModel, CLIPVisionModel, CLIPVisionTransformer]): The model to save.
        save_path (str): The path to save the model.

    Raises:
        ValueError: If `model` is none of the supported types.
    """
    if isinstance(model, HFCLIPClassifier):
        vision_model = model.clip_model.vision_model
    elif isinstance(model, (CLIPModel, CLIPVisionModel)):
        # both wrappers expose the vision model as `.vision_model`
        vision_model = model.vision_model
    elif isinstance(model, CLIPVisionTransformer):
        vision_model = model
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    # `parent_dir` is empty when `save_path` has no directory component
    parent_dir = os.path.dirname(save_path)
    needs_directory = bool(parent_dir) and not os.path.exists(parent_dir)
    if needs_directory:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {"vision_model": vision_model}
    self.fabric.save(save_path, checkpoint)

setup_model()

Sets up the model, optimizer, and learning rate scheduler.

This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.

Returns:

  • Tuple

    A tuple containing the processor, classifier, optimizer, and learning rate scheduler.

Source code in fusion_bench/method/classification/continual_clip_finetune.py
def setup_model(self):
    """
    Sets up the model, optimizer, and learning rate scheduler.

    This method initializes the CLIP model, applies LoRA if specified, and configures the optimizer and learning rate scheduler.
    It also sets `self.finetune_method` to "full fine-tune" or
    "lora fine-tune" depending on `use_lora`.

    Returns:
        Tuple: A tuple containing the processor, classifier, optimizer, and learning rate scheduler.
    """
    modelpool = self.modelpool

    clip_model: CLIPModel = modelpool.load_clip_model("_pretrained_")
    processor = modelpool.load_processor()

    self.finetune_method = "full fine-tune"
    if self.use_lora:
        self.finetune_method = "lora fine-tune"
        if isinstance(self.lora_config, LoraConfig):
            # a ready-made LoraConfig instance is used as-is
            lora_config = self.lora_config
        else:
            lora_config = LoraConfig(
                **OmegaConf.to_container(
                    self.lora_config, resolve=True, enum_to_str=True
                )
            )
        # only the vision model is wrapped with LoRA adapters
        clip_model.vision_model = get_peft_model(
            clip_model.vision_model, lora_config
        )

    classifier = HFCLIPClassifier(clip_model, processor=processor)

    if self.fabric.is_global_zero:
        print("=== Model Summary (For Vision Model Only) ===")
        print_parameters(classifier.clip_model.vision_model)
    # configure optimizers: only trainable vision-model parameters are optimized
    optimizer = torch.optim.Adam(
        [
            p
            for p in classifier.clip_model.vision_model.parameters()
            if p.requires_grad
        ],
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
    # cosine annealing over `num_steps` scheduler steps
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=self.num_steps
    )

    return processor, classifier, optimizer, lr_scheduler

LLM Fine-tuning

FullFinetuneSFT

Bases: BaseAlgorithm, FabricTrainingMixin

Source code in fusion_bench/method/lm_finetune/fullfinetune_sft.py
class FullFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):
    """
    Full-parameter supervised fine-tuning (SFT) of a causal language model.

    `run` loads the tokenizer, pretrained model and training datasets from the given
    model pool, builds the optimizer / learning rate scheduler from their configs,
    wraps model and optimizer with Fabric, and then calls `self.train`.
    """

    model: Union[nn.Module, "_FabricModule", "LlamaForCausalLM"]
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
    # `None` when no learning rate scheduler is configured (see `configure_optimizer`)
    lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler]
    # value of `global_step_idx` at the most recent `save_checkpoint` call (-1: none yet)
    _latest_saved_checkpoint_global_step: int = -1

    def __init__(
        self,
        optimizer: DictConfig,
        lr_scheduler: Optional[DictConfig],
        dataloader_kwargs: DictConfig,
        max_epochs: int,
        max_steps: int = -1,
        max_steps_per_epoch: int = -1,
        lr_scheduler_interval: Literal["epoch", "step"] = "step",
        lr_scheduler_frequency: int = 1,
        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
        checkpoint_save_frequency: int = 1,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
        save_optimizer_state: bool = False,
        save_full_model: bool = False,
        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
        ckpt_path: Optional[str] = None,
        max_length: int = 6144,
        fix_token_embedding: bool = True,
        **kwargs,
    ):
        """
        Class for full finetuning of a language model on given SFT datasets.

        Args:
            optimizer (DictConfig): Configuration for the optimizer.
            lr_scheduler (Optional[DictConfig]): Configuration for the learning rate scheduler. If None, no learning rate scheduler is created.
            dataloader_kwargs (DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
            max_epochs (int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
            max_steps (int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
            max_steps_per_epoch (int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
            lr_scheduler_interval (str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
            lr_scheduler_frequency (int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
            checkpoint_save_interval (str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
            checkpoint_save_frequency (int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
            accumulate_grad_batches (int): Number of batches to accumulate gradients across before updating the model parameters.
            gradient_clip_val (Optional[float]): Value to clip the gradients. If set to None, no gradient clipping will be applied.
            gradient_clip_algorithm (str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
            save_optimizer_state (bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
            save_full_model (bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
            ckpt_path (Optional[str]): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded. NOTE(review): no method of this class reads `ckpt_path`; confirm that the loading happens in the training mixin.
            max_length (int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
            fix_token_embedding (bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
            **kwargs: Forwarded unchanged to the parent class constructor.
        """
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self.dataloader_kwargs = dataloader_kwargs
        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lr_scheduler_interval = lr_scheduler_interval
        self.lr_scheduler_frequency = lr_scheduler_frequency
        self.checkpoint_save_interval = checkpoint_save_interval
        self.checkpoint_save_frequency = checkpoint_save_frequency
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.save_optimizer_state = save_optimizer_state
        self.save_full_model = save_full_model
        self.save_ckpt_type = save_ckpt_type
        self.ckpt_path = ckpt_path
        self.max_length = max_length
        self.fix_token_embedding = fix_token_embedding
        super().__init__(**kwargs)

    def run(self, modelpool: CausalLMPool):
        """
        Fine-tune the pretrained model of `modelpool` and return it.

        Args:
            modelpool (CausalLMPool): Provides the tokenizer, the pretrained model and the training datasets.

        Returns:
            The model held in `self.model` after `self.train` has finished.
        """
        self.modelpool = modelpool
        self.setup()
        self.train(self.model, self.optimizer, self.lr_scheduler)
        return self.model

    def setup_model(self):
        """
        Load tokenizer and pretrained model and prepare the model for training.

        Sets `self.tokenizer`, `self.model`, `self.use_cache` and `self.model_dtype`.
        """
        self.tokenizer = self.modelpool.load_tokenizer()
        # fall back to the EOS token for padding when the tokenizer defines no pad token
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        model = self.modelpool.load_pretrained_model()
        self.model: "LlamaForCausalLM" = model

        if self.fix_token_embedding:
            # freeze the input token embeddings
            self.model.model.embed_tokens.requires_grad_(False)

        if self.fabric.strategy == "fsdp" or isinstance(
            self.fabric.strategy, FSDPStrategy
        ):
            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": True}
            )
            # use_cache=True is not compatible with gradient checkpointing
            self.use_cache = False
        else:
            self.use_cache = True
        self.model_dtype = get_dtype(self.model)

    def configure_optimizer(self):
        """
        Instantiate the optimizer and, if configured, the learning rate scheduler.

        Every entry of the scheduler config whose value is the placeholder string
        `"_T_max_"` is replaced in place by `self.expected_total_steps`.

        Returns:
            dict: `{"optimizer": ..., "lr_scheduler": ...}`; the scheduler entry is
            `None` when no scheduler config was given.
        """
        # compute expected total steps
        self.compute_expected_total_steps(self.train_dataloader)

        optimizer = instantiate(self._optimizer, self.model.parameters())
        if self._lr_scheduler is not None:
            for key, arg in self._lr_scheduler.items():
                if arg == "_T_max_":
                    log.info(
                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
                    )
                    self._lr_scheduler[key] = self.expected_total_steps
            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
                self._lr_scheduler,
                optimizer=optimizer,
            )
        else:
            lr_scheduler = None
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def setup_data(self):
        """
        Load all training datasets of the model pool and build the training dataloader.

        Multiple datasets are concatenated. Sets `self.train_dataset` and
        `self.train_dataloader` (the latter wrapped by `fabric.setup_dataloaders`).
        """
        fabric = self.fabric
        modelpool = self.modelpool
        assert (
            len(modelpool.train_dataset_names) > 0
        ), "No training datasets found in modelpool."

        train_datasets = [
            modelpool.load_train_dataset(dataset_name)
            for dataset_name in modelpool.train_dataset_names
        ]
        if len(train_datasets) > 1:
            train_dataset = ConcatDataset(train_datasets)
        else:
            train_dataset = train_datasets[0]

        self.train_dataset = train_dataset
        self.train_dataloader = DataLoader(
            train_dataset,
            **self.dataloader_kwargs,
            shuffle=True,
            collate_fn=functools.partial(
                padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
            ),
        )
        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

    def setup(self):
        """
        Prepare model, data, optimizer and scheduler, then hand model and optimizer to Fabric.

        The dataloader is built before the optimizer because `configure_optimizer`
        reads `self.train_dataloader`.
        """
        fabric = self.fabric

        self.setup_model()
        self.setup_data()

        optimizer = self.configure_optimizer()
        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

        self.model, self.optimizer = fabric.setup(self.model, optimizer)
        self.lr_scheduler = lr_scheduler

    @override
    def train_epoch(self, *args, **kwargs):
        """
        Run one pass over `self.train_dataloader`.

        Gradients are accumulated over `accumulate_grad_batches` batches; each
        completed accumulation window performs one optimizer step, optionally one
        scheduler step, logs metrics and triggers the step-level checkpoint hook.
        `self.global_step_idx` counts optimizer steps. The epoch ends early when
        `max_steps_per_epoch` batches were consumed; training is flagged as finished
        (`self.is_training = False`) once `max_steps` optimizer steps were taken.
        """
        fabric = self.fabric

        accumulated_loss = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
                log.warning(
                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
                )
                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
                batch["labels"] = batch["labels"][:, : self.max_length]

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # use_cache=True is not compatible with gradient checkpointing, so we disable it here
                output = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    use_cache=self.use_cache,
                )
                # scale so that the accumulated gradient is the mean over the window
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)
                accumulated_loss += loss.item()

            if is_accumulating:
                continue

            self.clip_gradients_if_needed(self.model, self.optimizer)

            # update the model parameters and zero the gradients
            self.optimizer.step()
            self.optimizer.zero_grad()

            # run lr_scheduler at the end of the step if interval is set to "step".
            # The scheduler must be stepped AFTER the optimizer (otherwise PyTorch skips
            # the first value of the schedule) and may be absent altogether.
            if (
                self.lr_scheduler is not None
                and self.lr_scheduler_interval == "step"
                and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
            ):
                self.lr_scheduler.step()

            metrics = {
                "train/loss": accumulated_loss,
                "train/epoch_idx": self.epoch_idx,
                "train/lr": self.optimizer.param_groups[0]["lr"],
            }
            fabric.log_dict(metrics, step=self.global_step_idx)
            pbar.set_postfix(metrics)

            # save the model at the end of the step if interval is set to "step" and frequency is met
            self.conditional_checkpoint_save(stage="end_of_step")

            # break if max_steps is set, and exit training. Checked before the
            # per-epoch limit so that hitting both on the same step still stops training.
            if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                self.is_training = False
                break

            # this optimizer step is complete: advance the counter before any
            # early exit from the epoch so the next epoch does not reuse the index
            self.global_step_idx += 1
            accumulated_loss = 0

            # break if max_steps_per_epoch is set, and exit epoch
            if (
                self.max_steps_per_epoch > 0
                and step_idx + 1 >= self.max_steps_per_epoch
            ):
                break

    def save_checkpoint(
        self,
        path: Union[str, Path],
        save_optimizer_state: Optional[bool] = None,
        overwrite: bool = False,
    ):
        """
        Write a checkpoint of the current training state to `path`.

        Args:
            path (Union[str, Path]): Destination of the checkpoint.
            save_optimizer_state (Optional[bool]): Pass False to leave optimizer / scheduler state out of this particular checkpoint even if `self.save_optimizer_state` is True. Any other value defers to `self.save_optimizer_state`.
            overwrite (bool): If False and `path` already exists, only a warning is logged and nothing is written.
        """
        if not overwrite and os.path.exists(path):
            log.warning(f"Checkpoint already exists at {path}. Skipping save.")
            return

        fabric = self.fabric

        if self.save_ckpt_type == "lightning":
            state = {"model": self.model}

            # save the optimizer and lr_scheduler state if needed
            if self.save_optimizer_state and save_optimizer_state is not False:
                state.update(
                    {
                        "optimizer": self.optimizer,
                        "lr_scheduler": self.lr_scheduler,
                        "global_step_idx": self.global_step_idx,
                        "epoch_idx": self.epoch_idx,
                    }
                )

            if self.save_full_model:
                state_filter = None
            else:
                # keep only the entries of the model state that require gradients
                trainable_param_names = {
                    name
                    for name, param in self.model.state_dict(keep_vars=True).items()
                    if param.requires_grad
                }
                state_filter = {"model": lambda k, p: k in trainable_param_names}

            fabric.save(path, state=state, filter=state_filter)
        else:
            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)

        self._latest_saved_checkpoint_global_step = self.global_step_idx

    def load_checkpoint(self, path: Union[str, Path]):
        """
        Load a checkpoint from `path` into the current model via `fabric.load`.

        Optimizer and scheduler state are restored as well when
        `self.save_optimizer_state` is True.

        Args:
            path (Union[str, Path]): Location of the checkpoint to load.
        """
        fabric = self.fabric

        state = {"model": self.model}

        # load the optimizer and lr_scheduler state if needed
        if self.save_optimizer_state:
            state.update(
                {
                    "optimizer": self.optimizer,
                    "lr_scheduler": self.lr_scheduler,
                }
            )

        fabric.load(path, state)

__init__(optimizer, lr_scheduler, dataloader_kwargs, max_epochs, max_steps=-1, max_steps_per_epoch=-1, lr_scheduler_interval='step', lr_scheduler_frequency=1, checkpoint_save_interval='epoch', checkpoint_save_frequency=1, accumulate_grad_batches=1, gradient_clip_val=None, gradient_clip_algorithm='norm', save_optimizer_state=False, save_full_model=False, save_ckpt_type='lightning', ckpt_path=None, max_length=6144, fix_token_embedding=True, **kwargs)

Class for full finetuning of a language model on given SFT datasets.

Parameters:

  • optimizer(DictConfig)

    Configuration for the optimizer.

  • lr_scheduler(DictConfig)

    Configuration for the learning rate scheduler. If set to None, no learning rate scheduler is created.

  • dataloader_kwargs(DictConfig)

    Configuration for the dataloader, such as batch size, num_workers, etc.

  • max_epochs(int)

    Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.

  • max_steps(int)

    Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.

  • max_steps_per_epoch(int)

    Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.

  • lr_scheduler_interval(str)

    Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.

  • lr_scheduler_frequency(int)

    Frequency at which to run the learning rate scheduler. The scheduler will run every lr_scheduler_frequency epochs or steps, depending on the value of lr_scheduler_interval.

  • checkpoint_save_interval(str)

    Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.

  • checkpoint_save_frequency(int)

    Frequency at which to save the model checkpoint. The model will be saved every checkpoint_save_frequency epochs or steps, depending on the value of checkpoint_save_interval.

  • accumulate_grad_batches(int)

    Number of batches to accumulate gradients across before updating the model parameters.

  • gradient_clip_val(float)

    Value to clip the gradients. If set to None, no gradient clipping will be applied.

  • gradient_clip_algorithm(str)

    Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.

  • save_optimizer_state(bool)

    Whether to save the optimizer and lr_scheduler state along with the model checkpoint.

  • save_full_model(bool)

    Whether to save the full model or only the trainable parameters in the model checkpoint.

  • save_ckpt_type (str, default: 'lightning') –

    Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.

  • ckpt_path(str)

    Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.

  • max_length(int)

    Maximum input length to consider. If the input length exceeds this value, it will be truncated.

  • fix_token_embedding(bool)

    Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.

Source code in fusion_bench/method/lm_finetune/fullfinetune_sft.py
def __init__(
    self,
    optimizer: DictConfig,
    lr_scheduler: Optional[DictConfig],
    dataloader_kwargs: DictConfig,
    max_epochs: int,
    max_steps: int = -1,
    max_steps_per_epoch: int = -1,
    lr_scheduler_interval: Literal["epoch", "step"] = "step",
    lr_scheduler_frequency: int = 1,
    checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
    checkpoint_save_frequency: int = 1,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Literal["value", "norm"] = "norm",
    save_optimizer_state: bool = False,
    save_full_model: bool = False,
    save_ckpt_type: Literal["lightning", "hf"] = "lightning",
    ckpt_path: Optional[str] = None,
    max_length: int = 6144,
    fix_token_embedding: bool = True,
    **kwargs,
):
    """
    Store the hyperparameters for full fine-tuning of a language model on SFT datasets.

    Every argument is kept on the instance under its own name, except `optimizer`
    and `lr_scheduler`, which are stored as `_optimizer` and `_lr_scheduler`.

    Args:
        optimizer (DictConfig): Optimizer configuration.
        lr_scheduler (Optional[DictConfig]): Learning rate scheduler configuration.
        dataloader_kwargs (DictConfig): Dataloader settings such as batch size, num_workers, etc.
        max_epochs (int): Upper bound on training epochs; -1 trains indefinitely or until max_steps is reached.
        max_steps (int): Upper bound on training steps; -1 trains indefinitely or until max_epochs is reached.
        max_steps_per_epoch (int): Upper bound on steps within one epoch; -1 runs each epoch to its end.
        lr_scheduler_interval (str): 'epoch' or 'step' - whether the scheduler runs at the end of each epoch or of each step.
        lr_scheduler_frequency (int): The scheduler runs every `lr_scheduler_frequency` epochs or steps, depending on `lr_scheduler_interval`.
        checkpoint_save_interval (str): 'epoch' or 'step' - whether checkpoints are saved at the end of each epoch or of each step.
        checkpoint_save_frequency (int): A checkpoint is saved every `checkpoint_save_frequency` epochs or steps, depending on `checkpoint_save_interval`.
        accumulate_grad_batches (int): Number of batches whose gradients are accumulated before the parameters are updated.
        gradient_clip_val (Optional[float]): Gradient clipping threshold; None disables clipping.
        gradient_clip_algorithm (str): 'value' clips gradients to the given value, 'norm' clips them to the given norm.
        save_optimizer_state (bool): Whether optimizer and lr_scheduler state are saved along with the model checkpoint.
        save_full_model (bool): Whether the full model or only its trainable parameters are written to the checkpoint.
        save_ckpt_type (str): 'lightning' saves in the lightning format, 'hf' saves in the huggingface format.
        ckpt_path (Optional[str]): Checkpoint to load before training; None loads nothing.
        max_length (int): Maximum input length; longer inputs are truncated.
        fix_token_embedding (bool): If True, the token embeddings are not updated during training.
        **kwargs: Passed on unchanged to the parent constructor.
    """
    # optimization
    self._optimizer = optimizer
    self._lr_scheduler = lr_scheduler
    self.lr_scheduler_interval = lr_scheduler_interval
    self.lr_scheduler_frequency = lr_scheduler_frequency
    self.accumulate_grad_batches = accumulate_grad_batches
    self.gradient_clip_val = gradient_clip_val
    self.gradient_clip_algorithm = gradient_clip_algorithm
    self.fix_token_embedding = fix_token_embedding

    # data
    self.dataloader_kwargs = dataloader_kwargs
    self.max_length = max_length

    # training length
    self.max_epochs = max_epochs
    self.max_steps = max_steps
    self.max_steps_per_epoch = max_steps_per_epoch

    # checkpointing
    self.checkpoint_save_interval = checkpoint_save_interval
    self.checkpoint_save_frequency = checkpoint_save_frequency
    self.save_optimizer_state = save_optimizer_state
    self.save_full_model = save_full_model
    self.save_ckpt_type = save_ckpt_type
    self.ckpt_path = ckpt_path

    super().__init__(**kwargs)

PeftFinetuneSFT

Bases: BaseAlgorithm, FabricTrainingMixin

Source code in fusion_bench/method/lm_finetune/peftfinetune_sft.py
class PeftFinetuneSFT(BaseAlgorithm, FabricTrainingMixin):

    model: Union[
        nn.Module, "_FabricModule", "LlamaForCausalLM", PeftModel, peft.LoraModel
    ]
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
    _latest_saved_checkpoint_global_step: int = -1

    def __init__(
        self,
        optimizer: DictConfig,
        lr_scheduler: Optional[DictConfig],
        peft_config: DictConfig,
        dataloader_kwargs: DictConfig,
        adapter_name: str = "default",
        merge_and_unload: bool = False,
        max_epochs: int = 1,
        max_steps: int = -1,
        max_steps_per_epoch: int = -1,
        lr_scheduler_interval: Literal["epoch", "step"] = "step",
        lr_scheduler_frequency: int = 1,
        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
        checkpoint_save_frequency: int = 1,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
        save_optimizer_state: bool = False,
        save_full_model: bool = False,
        save_ckpt_type: Literal["lightning", "peft"] = "peft",
        ckpt_path: Optional[str] = None,
        max_length: int = 6144,
        **kwargs,
    ):
        """
        Store the hyperparameters for parameter-efficient (PEFT) fine-tuning of a language model on SFT datasets.

        Every argument is kept on the instance under its own name, except `optimizer`,
        `lr_scheduler` and `peft_config`, which are stored with a leading underscore.

        Args:
            optimizer (DictConfig): Optimizer configuration.
            lr_scheduler (Optional[DictConfig]): Learning rate scheduler configuration.
            peft_config (DictConfig): Configuration for the PEFT model.
            dataloader_kwargs (DictConfig): Dataloader settings such as batch size, num_workers, etc.
            adapter_name (str): Name of the adapter to use for the PEFT model.
            merge_and_unload (bool): Whether to merge and unload the model after training.
            max_epochs (int): Upper bound on training epochs; -1 trains indefinitely or until max_steps is reached.
            max_steps (int): Upper bound on training steps; -1 trains indefinitely or until max_epochs is reached.
            max_steps_per_epoch (int): Upper bound on steps within one epoch; -1 runs each epoch to its end.
            lr_scheduler_interval (str): 'epoch' or 'step' - whether the scheduler runs at the end of each epoch or of each step.
            lr_scheduler_frequency (int): The scheduler runs every `lr_scheduler_frequency` epochs or steps, depending on `lr_scheduler_interval`.
            checkpoint_save_interval (str): 'epoch' or 'step' - whether checkpoints are saved at the end of each epoch or of each step.
            checkpoint_save_frequency (int): A checkpoint is saved every `checkpoint_save_frequency` epochs or steps, depending on `checkpoint_save_interval`.
            accumulate_grad_batches (int): Number of batches whose gradients are accumulated before the parameters are updated.
            gradient_clip_val (Optional[float]): Gradient clipping threshold; None disables clipping.
            gradient_clip_algorithm (str): 'value' clips gradients to the given value, 'norm' clips them to the given norm.
            save_optimizer_state (bool): Whether optimizer and lr_scheduler state are saved along with the model checkpoint.
            save_full_model (bool): Whether the full model or only its trainable parameters are written to the checkpoint.
            save_ckpt_type (str): 'lightning' saves with the Lightning checkpointing mechanism, 'peft' saves with the PEFT checkpointing mechanism.
            ckpt_path (Optional[str]): Checkpoint to load before training; None loads nothing.
            max_length (int): Maximum input length; longer inputs are truncated.
            **kwargs: Passed on unchanged to the parent constructor.
        """
        # optimization
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self.lr_scheduler_interval = lr_scheduler_interval
        self.lr_scheduler_frequency = lr_scheduler_frequency
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm

        # adapter
        self._peft_config = peft_config
        self.adapter_name = adapter_name
        self.merge_and_unload = merge_and_unload

        # data
        self.dataloader_kwargs = dataloader_kwargs
        self.max_length = max_length

        # training length
        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.max_steps_per_epoch = max_steps_per_epoch

        # checkpointing
        self.checkpoint_save_interval = checkpoint_save_interval
        self.checkpoint_save_frequency = checkpoint_save_frequency
        self.save_optimizer_state = save_optimizer_state
        self.save_full_model = save_full_model
        self.save_ckpt_type = save_ckpt_type
        self.ckpt_path = ckpt_path

        super().__init__(**kwargs)

    def run(self, modelpool: CausalLMPool):
        """
        Fine-tune the pretrained model of `modelpool` with a PEFT adapter and return it.

        Args:
            modelpool (CausalLMPool): Provides the tokenizer, the pretrained model and the training datasets.

        Returns:
            The trained model; when `self.merge_and_unload` is True, the result of
            calling `merge_and_unload()` on it.
        """
        self.modelpool = modelpool
        self.setup()
        self.train(self.model, self.optimizer, self.lr_scheduler)

        if not self.merge_and_unload:
            return self.model

        # replace the stored model by its merged-and-unloaded counterpart
        self.model = self.model.merge_and_unload()
        return self.model

    def setup_model(self):
        """
        Load tokenizer and pretrained model, wrap the model with a PEFT adapter and prepare it for training.

        Sets `self.tokenizer`, `self.model`, `self.use_cache` and `self.model_dtype`,
        and writes the instantiated PEFT config to `<log_dir>/peft_config`.
        """
        # https://github.com/Lightning-AI/litgpt/blob/main/litgpt/finetune/lora.py
        self.tokenizer = tokenizer = self.modelpool.load_tokenizer()
        # fall back to the EOS token for padding when the tokenizer defines no pad token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        base_model = self.modelpool.load_pretrained_model()

        # build the adapter config, persist it, and wrap the base model
        adapter_config = instantiate(self._peft_config, _convert_="all")
        adapter_config.save_pretrained(os.path.join(self.log_dir, "peft_config"))
        wrapped_model = get_peft_model(base_model, adapter_config, self.adapter_name)
        wrapped_model.print_trainable_parameters()
        self.model = wrapped_model

        strategy = self.fabric.strategy
        uses_fsdp = strategy == "fsdp" or isinstance(strategy, FSDPStrategy)
        if uses_fsdp:
            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": True}
            )
        # the cache is switched off exactly when gradient checkpointing was enabled
        self.use_cache = not uses_fsdp

        # cast the whole model (adapter included) to the detected dtype
        self.model_dtype = get_dtype(self.model)
        self.model = self.model.to(dtype=self.model_dtype)

    def configure_optimizer(self):
        """
        Instantiate the optimizer and, if configured, the learning rate scheduler.

        Every entry of the scheduler config whose value is the placeholder string
        `"_T_max_"` is replaced in place by `self.expected_total_steps`.

        Returns:
            dict: `{"optimizer": ..., "lr_scheduler": ...}`; the scheduler entry is
            `None` when no scheduler config was given.
        """
        # the total step count is needed to resolve the `_T_max_` placeholder
        self.compute_expected_total_steps(self.train_dataloader)

        optimizer = instantiate(self._optimizer, self.model.parameters())

        scheduler_cfg = self._lr_scheduler
        if scheduler_cfg is None:
            return {"optimizer": optimizer, "lr_scheduler": None}

        for key, arg in scheduler_cfg.items():
            if arg != "_T_max_":
                continue
            log.info(
                f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
            )
            scheduler_cfg[key] = self.expected_total_steps

        lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
            scheduler_cfg,
            optimizer=optimizer,
        )
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def setup_data(self):
        """
        Load all training datasets of the model pool and build the training dataloader.

        Multiple datasets are concatenated. Sets `self.train_dataset` and
        `self.train_dataloader` (the latter wrapped by `fabric.setup_dataloaders`).
        """
        fabric = self.fabric
        pool = self.modelpool
        assert (
            len(pool.train_dataset_names) > 0
        ), "No training datasets found in modelpool."

        loaded = [pool.load_train_dataset(name) for name in pool.train_dataset_names]
        # a single dataset is used as-is; several are chained into one
        self.train_dataset = ConcatDataset(loaded) if len(loaded) > 1 else loaded[0]

        # pad every batch with the tokenizer's pad token
        collate = functools.partial(
            padded_collate_sft, pad_token_id=self.tokenizer.pad_token_id
        )
        self.train_dataloader = DataLoader(
            self.train_dataset,
            **self.dataloader_kwargs,
            shuffle=True,
            collate_fn=collate,
        )
        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

    def setup(self):
        """
        Prepare model, data, optimizer and scheduler, then hand model and optimizer to Fabric.

        The dataloader is built before the optimizer because `configure_optimizer`
        reads `self.train_dataloader`. Module and optimizer are set up through
        separate Fabric calls.
        """
        fabric = self.fabric

        self.setup_model()
        self.setup_data()

        configured = self.configure_optimizer()
        raw_optimizer = configured["optimizer"]
        scheduler = configured["lr_scheduler"]

        self.model = fabric.setup_module(self.model)
        self.optimizer = fabric.setup_optimizers(raw_optimizer)
        self.lr_scheduler = scheduler

    @override
    def train_epoch(self, *args, **kwargs):
        """
        Run one pass over ``self.train_dataloader`` with gradient accumulation.

        Each batch longer than ``self.max_length`` (when that is positive) is
        truncated to its first ``self.max_length`` tokens, forwarded through
        the model, and its loss, divided by ``self.accumulate_grad_batches``,
        is back-propagated. Every ``self.accumulate_grad_batches`` batches the
        gradients are clipped, the lr scheduler is stepped when configured and
        due, the optimizer is stepped, metrics are logged and a checkpoint is
        conditionally saved.

        The epoch ends early when ``max_steps_per_epoch`` or ``max_steps`` is
        reached; in the latter case ``self.is_training`` is set to ``False``.

        Args:
            *args: Unused.
            **kwargs: Unused.
        """
        fabric = self.fabric

        accumulated_loss = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
                log.warning(
                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
                )
                # keep the leading max_length positions of every per-token tensor
                batch["input_ids"] = batch["input_ids"][:, : self.max_length]
                batch["attention_mask"] = batch["attention_mask"][:, : self.max_length]
                batch["labels"] = batch["labels"][:, : self.max_length]

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # the use_cache flag is whatever was decided during model setup
                output = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    use_cache=self.use_cache,
                )
                # scale so the accumulated gradient matches one large batch
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)
                accumulated_loss += loss.item()

            if not is_accumulating:
                self.clip_gradients_if_needed(self.model, self.optimizer)

                # run lr_scheduler at the end of the step if interval is set to "step";
                # the scheduler is optional, so skip stepping when none was configured
                if (
                    self.lr_scheduler is not None
                    and self.lr_scheduler_interval == "step"
                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
                ):
                    self.lr_scheduler.step()

                # update the model parameters and zero the gradients
                self.optimizer.step()
                self.optimizer.zero_grad()

                metrics = {
                    "train/loss": accumulated_loss,
                    "train/epoch_idx": self.epoch_idx,
                    "train/lr": self.optimizer.param_groups[0]["lr"],
                }
                fabric.log_dict(metrics, step=self.global_step_idx)
                pbar.set_postfix(metrics)

                # save the model at the end of the step if interval is set to "step" and frequency is met
                self.conditional_checkpoint_save(stage="end_of_step")

                # break if max_steps_per_epoch is set, and exit epoch
                # NOTE(review): both breaks below happen before global_step_idx
                # is advanced for this update -- confirm this is intended.
                if (
                    self.max_steps_per_epoch > 0
                    and step_idx + 1 >= self.max_steps_per_epoch
                ):
                    break
                # break if max_steps is set, and exit training
                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                    self.is_training = False
                    break

                self.global_step_idx += 1
                accumulated_loss = 0

    def save_checkpoint(
        self,
        path: Union[str, Path],
        save_optimizer_state: Optional[bool] = None,
        overwrite: bool = False,
    ):
        """
        Save a checkpoint in the format selected by ``self.save_ckpt_type``.

        With ``"lightning"`` the model (and optionally the optimizer, lr
        scheduler, global step and epoch indices) is written through
        ``fabric.save``; unless ``self.save_full_model`` is set, only
        parameters with ``requires_grad`` are kept. With ``"peft"`` the model's
        ``save_pretrained`` is used instead.

        Args:
            path: Destination of the checkpoint.
            save_optimizer_state: Pass ``False`` to leave out the optimizer
                state for this call even when ``self.save_optimizer_state`` is
                enabled. ``None`` and ``True`` defer to the instance setting.
            overwrite: When ``False`` and ``path`` already exists, a warning is
                logged and nothing is written.

        Raises:
            ValueError: If ``self.save_ckpt_type`` is neither ``"lightning"``
                nor ``"peft"``.
        """
        if not overwrite and os.path.exists(path):
            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

        fabric = self.fabric
        if self.save_ckpt_type == "lightning":
            state = {"model": self.model}

            # save the optimizer and lr_scheduler state if needed
            if self.save_optimizer_state and save_optimizer_state is not False:
                state.update(
                    {
                        "optimizer": self.optimizer,
                        "lr_scheduler": self.lr_scheduler,
                        "global_step_idx": self.global_step_idx,
                        "epoch_idx": self.epoch_idx,
                    }
                )
            trainable_param_names = {
                name
                for name, param in self.model.state_dict(keep_vars=True).items()
                if param.requires_grad
            }
            # named state_filter to avoid shadowing the builtin `filter`
            state_filter = (
                None
                if self.save_full_model
                else {"model": lambda k, p: k in trainable_param_names}
            )
            # a bare filename has an empty dirname, and os.makedirs("") raises
            parent_dir = os.path.dirname(path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            fabric.save(path, state=state, filter=state_filter)
        elif self.save_ckpt_type == "peft":
            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)
        else:
            raise ValueError(
                f"Unknown save_ckpt_type: {self.save_ckpt_type}. Available options: 'lightning', 'peft'"
            )
        # remember at which step the most recent checkpoint was written
        self._latest_saved_checkpoint_global_step = self.global_step_idx

    def load_checkpoint(self, path: Union[str, Path]):
        """
        Restore training state from ``path`` through Fabric.

        The model is always restored. The optimizer and lr scheduler are
        restored as well when ``self.save_optimizer_state`` is enabled.

        Args:
            path: Location of the checkpoint to load.
        """
        targets = {"model": self.model}

        # also restore the optimizer and lr_scheduler state if requested
        if self.save_optimizer_state:
            targets["optimizer"] = self.optimizer
            targets["lr_scheduler"] = self.lr_scheduler

        self.fabric.load(path, targets)

__init__(optimizer, lr_scheduler, peft_config, dataloader_kwargs, adapter_name='default', merge_and_unload=False, max_epochs=1, max_steps=-1, max_steps_per_epoch=-1, lr_scheduler_interval='step', lr_scheduler_frequency=1, checkpoint_save_interval='epoch', checkpoint_save_frequency=1, accumulate_grad_batches=1, gradient_clip_val=None, gradient_clip_algorithm='norm', save_optimizer_state=False, save_full_model=False, save_ckpt_type='peft', ckpt_path=None, max_length=6144, **kwargs)

Class for parameter-efficient fine-tuning (PEFT) of a language model on given SFT datasets.

Parameters:

  • optimizer(DictConfig)

    Configuration for the optimizer.

  • lr_scheduler(DictConfig)

    Configuration for the learning rate scheduler.

  • peft_config(DictConfig)

    Configuration for the PEFT model.

  • dataloader_kwargs(DictConfig)

    Configuration for the dataloader, such as batch size, num_workers, etc.

  • adapter_name(str)

    Name of the adapter to use for the PEFT model.

  • merge_and_unload(bool)

    Whether to merge and unload the model after training.

  • max_epochs(int)

    Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.

  • max_steps(int)

    Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.

  • max_steps_per_epoch(int)

    Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.

  • lr_scheduler_interval(str)

    Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.

  • lr_scheduler_frequency(int)

    Frequency at which to run the learning rate scheduler. The scheduler will run every lr_scheduler_frequency epochs or steps, depending on the value of lr_scheduler_interval.

  • checkpoint_save_interval(str)

    Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.

  • checkpoint_save_frequency(int)

    Frequency at which to save the model checkpoint. The model will be saved every checkpoint_save_frequency epochs or steps, depending on the value of checkpoint_save_interval.

  • accumulate_grad_batches(int)

    Number of batches to accumulate gradients across before updating the model parameters.

  • gradient_clip_val(float)

    Value to clip the gradients. If set to None, no gradient clipping will be applied.

  • gradient_clip_algorithm(str)

    Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.

  • save_optimizer_state(bool)

    Whether to save the optimizer and lr_scheduler state along with the model checkpoint.

  • save_full_model(bool)

    Whether to save the full model or only the trainable parameters in the model checkpoint.

  • save_ckpt_type(str)

    Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.

  • ckpt_path(str)

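    Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.

  • max_length(int)

    Maximum input length. Batches whose sequence length exceeds this value are truncated to their first max_length tokens; a value <= 0 disables truncation.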

Source code in fusion_bench/method/lm_finetune/peftfinetune_sft.py
def __init__(
    self,
    optimizer: DictConfig,
    lr_scheduler: Optional[DictConfig],
    peft_config: DictConfig,
    dataloader_kwargs: DictConfig,
    adapter_name: str = "default",
    merge_and_unload: bool = False,
    max_epochs: int = 1,
    max_steps: int = -1,
    max_steps_per_epoch: int = -1,
    lr_scheduler_interval: Literal["epoch", "step"] = "step",
    lr_scheduler_frequency: int = 1,
    checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
    checkpoint_save_frequency: int = 1,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Literal["value", "norm"] = "norm",
    save_optimizer_state: bool = False,
    save_full_model: bool = False,
    save_ckpt_type: Literal["lightning", "peft"] = "peft",
    ckpt_path: Optional[str] = None,
    max_length: int = 6144,
    **kwargs,
):
    """
    Class for parameter-efficient fine-tuning (PEFT) of a language model on given SFT datasets.

    The constructor only stores its arguments on the instance and forwards
    any remaining keyword arguments to the parent initializer.

    Args:
        optimizer (DictConfig): Configuration for the optimizer.
        lr_scheduler (Optional[DictConfig]): Configuration for the learning rate scheduler. May be None.
        peft_config (DictConfig): Configuration for the PEFT model.
        dataloader_kwargs (DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
        adapter_name (str): Name of the adapter to use for the PEFT model.
        merge_and_unload (bool): Whether to merge and unload the model after training.
        max_epochs (int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
        max_steps (int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
        max_steps_per_epoch (int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
        lr_scheduler_interval (str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
        lr_scheduler_frequency (int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
        checkpoint_save_interval (str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
        checkpoint_save_frequency (int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
        accumulate_grad_batches (int): Number of batches to accumulate gradients across before updating the model parameters.
        gradient_clip_val (Optional[float]): Value to clip the gradients. If set to None, no gradient clipping will be applied.
        gradient_clip_algorithm (str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
        save_optimizer_state (bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
        save_full_model (bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
        save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'peft'. If set to 'lightning', the model will be saved using the Lightning checkpointing mechanism. If set to 'peft', the model will be saved using the PEFT checkpointing mechanism.
        ckpt_path (Optional[str]): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
        max_length (int): Maximum input length. Stored here; the training loop truncates batches longer than this to their first `max_length` tokens, and a value <= 0 disables truncation.
        **kwargs: Remaining keyword arguments, passed on to the parent class initializer.
    """
    # the two configs are kept under private names; they are instantiated later
    self._optimizer = optimizer
    self._lr_scheduler = lr_scheduler
    self._peft_config = peft_config
    self.dataloader_kwargs = dataloader_kwargs
    self.adapter_name = adapter_name
    self.merge_and_unload = merge_and_unload
    self.max_epochs = max_epochs
    self.max_steps = max_steps
    self.max_steps_per_epoch = max_steps_per_epoch
    self.lr_scheduler_interval = lr_scheduler_interval
    self.lr_scheduler_frequency = lr_scheduler_frequency
    self.checkpoint_save_interval = checkpoint_save_interval
    self.checkpoint_save_frequency = checkpoint_save_frequency
    self.accumulate_grad_batches = accumulate_grad_batches
    self.gradient_clip_val = gradient_clip_val
    self.gradient_clip_algorithm = gradient_clip_algorithm
    self.save_optimizer_state = save_optimizer_state
    self.save_full_model = save_full_model
    self.save_ckpt_type = save_ckpt_type
    self.ckpt_path = ckpt_path
    self.max_length = max_length
    # all attributes are set before the parent initializer runs
    super().__init__(**kwargs)

Reward Modeling

BradleyTerryRewardModeling

Bases: BaseAlgorithm, FabricTrainingMixin

Source code in fusion_bench/method/lm_finetune/bradley_terry_rm.py
class BradleyTerryRewardModeling(BaseAlgorithm, FabricTrainingMixin):
    """
    Train a reward model with the Bradley-Terry pairwise preference loss.

    The model and tokenizer are loaded from a modelpool, the training datasets
    are collated with ``bradley_terry_rm_collate``, and each batch is expected
    to hold the chosen sequences in its first half and the rejected sequences
    in its second half.
    """

    model: Union[nn.Module, "_FabricModule", "LlamaForSequenceClassification"]
    optimizer: Union[torch.optim.Optimizer, "_FabricOptimizer"]
    train_dataloader: Union[DataLoader, "_FabricDataLoader"]
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler

    def __init__(
        self,
        optimizer: DictConfig,
        lr_scheduler: Optional[DictConfig],
        dataloader_kwargs: DictConfig,
        max_epochs: int,
        max_steps: int = -1,
        max_steps_per_epoch: int = -1,
        lr_scheduler_interval: Literal["epoch", "step"] = "step",
        lr_scheduler_frequency: int = 1,
        checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
        checkpoint_save_frequency: int = 1,
        accumulate_grad_batches: int = 1,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Literal["value", "norm"] = "norm",
        save_optimizer_state: bool = False,
        save_full_model: bool = False,
        save_ckpt_type: Literal["lightning", "hf"] = "lightning",
        ckpt_path: Optional[str] = None,
        max_length: int = 6144,
        fix_token_embedding: bool = True,
        **kwargs,
    ):
        """
        Class for reward modeling using Bradley-Terry model.

        Args:
            optimizer (DictConfig): Configuration for the optimizer.
            lr_scheduler (Optional[DictConfig]): Configuration for the learning rate scheduler. If set to None, no scheduler is created.
            dataloader_kwargs (DictConfig): Configuration for the dataloader, such as batch size, num_workers, etc.
            max_epochs (int): Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.
            max_steps (int): Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.
            max_steps_per_epoch (int): Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.
            lr_scheduler_interval (str): Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.
            lr_scheduler_frequency (int): Frequency at which to run the learning rate scheduler. The scheduler will run every `lr_scheduler_frequency` epochs or steps, depending on the value of `lr_scheduler_interval`.
            checkpoint_save_interval (str): Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.
            checkpoint_save_frequency (int): Frequency at which to save the model checkpoint. The model will be saved every `checkpoint_save_frequency` epochs or steps, depending on the value of `checkpoint_save_interval`.
            accumulate_grad_batches (int): Number of batches to accumulate gradients across before updating the model parameters.
            gradient_clip_val (Optional[float]): Value to clip the gradients. If set to None, no gradient clipping will be applied.
            gradient_clip_algorithm (str): Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.
            save_optimizer_state (bool): Whether to save the optimizer and lr_scheduler state along with the model checkpoint.
            save_full_model (bool): Whether to save the full model or only the trainable parameters in the model checkpoint.
            save_ckpt_type (str): Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.
            ckpt_path (Optional[str]): Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.
            max_length (int): Maximum input length to consider. If the input length exceeds this value, it will be truncated.
            fix_token_embedding (bool): Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.
            **kwargs: Remaining keyword arguments, passed on to the parent class initializer.
        """
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self.dataloader_kwargs = dataloader_kwargs
        self.max_epochs = max_epochs
        self.max_steps = max_steps
        self.max_steps_per_epoch = max_steps_per_epoch
        self.lr_scheduler_interval = lr_scheduler_interval
        self.lr_scheduler_frequency = lr_scheduler_frequency
        self.checkpoint_save_interval = checkpoint_save_interval
        self.checkpoint_save_frequency = checkpoint_save_frequency
        self.accumulate_grad_batches = accumulate_grad_batches
        self.gradient_clip_val = gradient_clip_val
        self.gradient_clip_algorithm = gradient_clip_algorithm
        self.save_optimizer_state = save_optimizer_state
        self.save_full_model = save_full_model
        self.save_ckpt_type = save_ckpt_type
        self.ckpt_path = ckpt_path
        self.max_length = max_length
        self.fix_token_embedding = fix_token_embedding
        super().__init__(**kwargs)

    def run(self, modelpool: SequenceClassificationModelPool):
        """
        Train the reward model from ``modelpool`` and return it.

        Args:
            modelpool (SequenceClassificationModelPool): Source of the tokenizer, the pretrained model and the training datasets.

        Returns:
            The trained model, as held in ``self.model`` after training.
        """
        self.modelpool = modelpool
        self.setup()
        self.train(self.model, self.optimizer, self.lr_scheduler)
        return self.model

    def setup_model(self):
        """
        Load the tokenizer and the pretrained model and configure them for training.

        A missing pad token id falls back to the eos token id (tokenizer) and
        to the tokenizer's pad token id (model config). Token embeddings are
        frozen when ``self.fix_token_embedding`` is set. Under an FSDP strategy
        gradient checkpointing is enabled and ``self.use_cache`` is turned off;
        otherwise ``self.use_cache`` is True.
        """
        self.tokenizer = self.modelpool.load_tokenizer()
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = (
                self.tokenizer.eos_token_id
            )  #! make sure eos_token_id only show up at the end of the sequence

        model = self.modelpool.load_pretrained_model()
        self.model: "LlamaForSequenceClassification" = model

        if model.config.pad_token_id is None:
            model.config.pad_token_id = self.tokenizer.pad_token_id

        if self.fix_token_embedding:
            self.model.model.embed_tokens.requires_grad_(False)

        if self.fabric.strategy == "fsdp" or isinstance(
            self.fabric.strategy, FSDPStrategy
        ):
            # https://github.com/Lightning-AI/pytorch-lightning/issues/19267
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": True}
            )
            self.use_cache = False
        else:
            self.use_cache = True
        self.model_dtype = get_dtype(self.model)

    def setup_data(self):
        """
        Build the shuffled training dataloader from the modelpool's datasets.

        All datasets in ``modelpool.train_dataset_names`` are loaded; more than
        one are concatenated. The dataset is stored in ``self.train_dataset``
        and a Fabric-wrapped ``DataLoader`` in ``self.train_dataloader``.
        """
        fabric = self.fabric
        modelpool = self.modelpool
        assert (
            len(modelpool.train_dataset_names) > 0
        ), "No training datasets found in modelpool."

        train_datasets = [
            modelpool.load_train_dataset(dataset_name)
            for dataset_name in modelpool.train_dataset_names
        ]
        if len(train_datasets) > 1:
            train_dataset = ConcatDataset(train_datasets)
        else:
            train_dataset = train_datasets[0]

        self.train_dataset = train_dataset
        self.train_dataloader = DataLoader(
            train_dataset,
            **self.dataloader_kwargs,
            shuffle=True,
            collate_fn=functools.partial(
                bradley_terry_rm_collate,
                pad_token_id=self.tokenizer.pad_token_id,
            ),  # NOTE: different from SFT, uses bradley_terry_rm_collate
        )
        self.train_dataloader = fabric.setup_dataloaders(self.train_dataloader)

    def configure_optimizer(self):
        """
        Instantiate the optimizer and, when configured, the lr scheduler.

        Any scheduler config value equal to the placeholder ``"_T_max_"`` is
        replaced by ``self.expected_total_steps`` before instantiation.

        Returns:
            dict: ``{"optimizer": ..., "lr_scheduler": ...}`` where the
            scheduler is None when no scheduler configuration was given.
        """
        # compute expected total steps
        self.compute_expected_total_steps(self.train_dataloader)

        optimizer = instantiate(self._optimizer, self.model.parameters())
        if self._lr_scheduler is not None:
            for key, arg in self._lr_scheduler.items():
                if arg == "_T_max_":
                    log.info(
                        f"Setting key `{key}` of lr_scheduler configuration to {self.expected_total_steps}"
                    )
                    self._lr_scheduler[key] = self.expected_total_steps
            lr_scheduler: torch.optim.lr_scheduler.LRScheduler = instantiate(
                self._lr_scheduler,
                optimizer=optimizer,
            )
        else:
            lr_scheduler = None
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}

    def setup(self):
        """
        Prepare model, data, optimizer and lr scheduler for training.

        The model and optimizer are wrapped together by ``fabric.setup``; the
        scheduler (possibly None) is stored as returned by
        ``configure_optimizer``.
        """
        fabric = self.fabric

        self.setup_model()
        self.setup_data()

        optimizer = self.configure_optimizer()
        optimizer, lr_scheduler = optimizer["optimizer"], optimizer["lr_scheduler"]

        self.model, self.optimizer = fabric.setup(self.model, optimizer)
        self.lr_scheduler = lr_scheduler

    def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
        """
        Maximize the likelihood of the winner over the loser using the Bradley-Terry model.

        The first half of the batch is treated as the chosen (winning)
        sequences and the second half as the rejected (losing) ones. The loss
        is the mean of ``-log(sigmoid(chosen_reward - rejected_reward))``.

        Args:
            batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids and attention masks for the winner and loser.

        Returns:
            Dict[str, Tensor]: The per-pair ``chosen_reward`` and
            ``rejected_reward`` and the scalar ``loss``.
        """
        batch_size = batch["input_ids"].size(0)
        assert batch_size % 2 == 0, "Batch size must be even."

        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=self.use_cache,
        )

        # first element of the model output is used as the reward score
        rewards = outputs[0]
        half = batch_size // 2
        chosen_reward = rewards[:half]
        rejected_rewards = rewards[half:]
        # logsigmoid is the numerically stable form of log(sigmoid(x)): the
        # naive composition yields -inf once sigmoid underflows to 0
        loss = -torch.nn.functional.logsigmoid(
            chosen_reward - rejected_rewards
        ).mean()

        return {
            "chosen_reward": chosen_reward,
            "rejected_reward": rejected_rewards,
            "loss": loss,
        }

    @override
    def train_epoch(self, *args, **kwargs):
        """
        Run one pass over ``self.train_dataloader`` with gradient accumulation.

        Batches longer than ``self.max_length`` (when positive) are truncated
        to their last ``self.max_length`` tokens. Every
        ``self.accumulate_grad_batches`` batches the gradients are clipped, the
        lr scheduler is stepped when configured and due, the optimizer is
        stepped, metrics are logged and a checkpoint is conditionally saved.

        The epoch ends early when ``max_steps_per_epoch`` or ``max_steps`` is
        reached; in the latter case ``self.is_training`` is set to False.

        Args:
            *args: Unused.
            **kwargs: Unused.
        """
        fabric = self.fabric

        accumulated_loss = 0
        accumulated_chosen_reward = 0
        accumulated_rejected_reward = 0
        for step_idx, batch in enumerate(
            pbar := tqdm(
                self.train_dataloader,
                desc="Training Batches",
                dynamic_ncols=True,
                leave=False,
                disable=not fabric.is_global_zero,
            )
        ):
            is_accumulating = (step_idx + 1) % self.accumulate_grad_batches != 0

            if self.max_length > 0 and batch["input_ids"].shape[1] > self.max_length:
                log.warning(
                    f"Input length exceeds max_length: {batch['input_ids'].shape[1]} > {self.max_length}. Truncating input."
                )
                # unlike SFT, the trailing max_length positions are kept
                batch["input_ids"] = batch["input_ids"][:, -self.max_length :]
                batch["attention_mask"] = batch["attention_mask"][:, -self.max_length :]

            # disable gradient synchronization if accumulating gradients across steps for improved performance
            with fabric.no_backward_sync(self.model, enabled=is_accumulating):
                # compute_loss forwards with self.use_cache, which setup_model
                # turns off when gradient checkpointing is enabled
                output = self.compute_loss(batch)
                loss = output["loss"] / self.accumulate_grad_batches

                fabric.backward(loss)

                accumulated_loss += loss.item()
                accumulated_chosen_reward += output["chosen_reward"].mean().item()
                accumulated_rejected_reward += output["rejected_reward"].mean().item()

            # 1. update the model parameters if not accumulating gradients
            # 2. step the lr_scheduler if interval is set to "step" and frequency is met
            # 3. save the model if interval is set to "step" and frequency is met
            # 4. log metrics
            # 5. increase the global step index
            if not is_accumulating:
                self.clip_gradients_if_needed(self.model, self.optimizer)

                # run lr_scheduler at the end of the step if interval is set to "step";
                # configure_optimizer returns None when no scheduler is configured,
                # so stepping must be skipped in that case
                if (
                    self.lr_scheduler is not None
                    and self.lr_scheduler_interval == "step"
                    and (self.global_step_idx + 1) % self.lr_scheduler_frequency == 0
                ):
                    self.lr_scheduler.step()

                # update the model parameters and zero the gradients
                self.optimizer.step()
                self.optimizer.zero_grad()

                metrics = {
                    "train/loss": accumulated_loss,
                    "train/chosen_reward": accumulated_chosen_reward
                    / self.accumulate_grad_batches,
                    "train/rejected_reward": accumulated_rejected_reward
                    / self.accumulate_grad_batches,
                    "train/epoch_idx": self.epoch_idx,
                    "train/lr": self.optimizer.param_groups[0]["lr"],
                }
                metrics["train/chosen_reward-rejected_reward"] = (
                    metrics["train/chosen_reward"] - metrics["train/rejected_reward"]
                )

                fabric.log_dict(metrics, step=self.global_step_idx)
                pbar.set_postfix(metrics)

                # save the model at the end of the step if interval is set to "step" and frequency is met
                self.conditional_checkpoint_save(stage="end_of_step")

                # break if max_steps_per_epoch is set, and exit epoch
                if (
                    self.max_steps_per_epoch > 0
                    and step_idx + 1 >= self.max_steps_per_epoch
                ):
                    break
                # break if max_steps is set, and exit training
                if self.max_steps > 0 and self.global_step_idx >= self.max_steps - 1:
                    self.is_training = False
                    break

                self.global_step_idx += 1
                accumulated_loss = 0
                accumulated_chosen_reward = 0
                accumulated_rejected_reward = 0

    def save_checkpoint(
        self,
        path: Union[str, Path],
        save_optimizer_state: Optional[bool] = None,
        overwrite: bool = False,
    ):
        """
        Save a checkpoint in the format selected by ``self.save_ckpt_type``.

        With ``"lightning"`` the model (and optionally the optimizer, lr
        scheduler, global step and epoch indices) is written through
        ``fabric.save``; unless ``self.save_full_model`` is set, only
        parameters with ``requires_grad`` are kept. Any other value saves the
        model with its ``save_pretrained`` method.

        Args:
            path (Union[str, Path]): Destination of the checkpoint.
            save_optimizer_state (Optional[bool]): Pass False to leave out the optimizer state for this call even when ``self.save_optimizer_state`` is enabled.
            overwrite (bool): When False and ``path`` already exists, a warning is logged and nothing is written.
        """
        if not overwrite and os.path.exists(path):
            return log.warning(f"Checkpoint already exists at {path}. Skipping save.")

        fabric = self.fabric

        if self.save_ckpt_type == "lightning":
            state = {"model": self.model}

            # save the optimizer and lr_scheduler state if needed
            if self.save_optimizer_state and save_optimizer_state is not False:
                state.update(
                    {
                        "optimizer": self.optimizer,
                        "lr_scheduler": self.lr_scheduler,
                        "global_step_idx": self.global_step_idx,
                        "epoch_idx": self.epoch_idx,
                    }
                )

            trainable_param_names = {
                name
                for name, param in self.model.state_dict(keep_vars=True).items()
                if param.requires_grad
            }
            # named state_filter to avoid shadowing the builtin `filter`
            state_filter = (
                None
                if self.save_full_model
                else {"model": lambda k, p: k in trainable_param_names}
            )

            fabric.save(path, state=state, filter=state_filter)
        else:
            self.model.save_pretrained(path, is_main_process=fabric.is_global_zero)

        # remember at which step the most recent checkpoint was written
        self._latest_saved_checkpoint_global_step = self.global_step_idx

    def load_checkpoint(self, path: Union[str, Path]):
        """
        Restore training state from ``path`` through Fabric.

        The model is always restored; the optimizer and lr scheduler are
        restored as well when ``self.save_optimizer_state`` is enabled.

        Args:
            path (Union[str, Path]): Location of the checkpoint to load.
        """
        fabric = self.fabric

        state = {"model": self.model}

        # load the optimizer and lr_scheduler state if needed
        if self.save_optimizer_state:
            state.update(
                {
                    "optimizer": self.optimizer,
                    "lr_scheduler": self.lr_scheduler,
                }
            )

        fabric.load(path, state)

__init__(optimizer, lr_scheduler, dataloader_kwargs, max_epochs, max_steps=-1, max_steps_per_epoch=-1, lr_scheduler_interval='step', lr_scheduler_frequency=1, checkpoint_save_interval='epoch', checkpoint_save_frequency=1, accumulate_grad_batches=1, gradient_clip_val=None, gradient_clip_algorithm='norm', save_optimizer_state=False, save_full_model=False, save_ckpt_type='lightning', ckpt_path=None, max_length=6144, fix_token_embedding=True, **kwargs)

Class for reward modeling using Bradley-Terry model.

Parameters:

  • optimizer(DictConfig)

    Configuration for the optimizer.

  • lr_scheduler(DictConfig)

    Configuration for the learning rate scheduler.

  • dataloader_kwargs(DictConfig)

    Configuration for the dataloader, such as batch size, num_workers, etc.

  • max_epochs(int)

    Maximum number of epochs to train the model. If set to -1, the training will continue indefinitely or until max_steps is reached.

  • max_steps(int)

    Maximum number of steps to train the model. If set to -1, the training will continue indefinitely or until max_epochs is reached.

  • max_steps_per_epoch(int)

    Maximum number of steps to train the model in each epoch. If set to -1, the training will continue until the end of the epoch.

  • lr_scheduler_interval(str)

    Interval at which to run the learning rate scheduler. Available options: 'epoch', 'step'. If set to 'epoch', the scheduler will run at the end of each epoch. If set to 'step', the scheduler will run at the end of each step.

  • lr_scheduler_frequency(int)

    Frequency at which to run the learning rate scheduler. The scheduler will run every lr_scheduler_frequency epochs or steps, depending on the value of lr_scheduler_interval.

  • checkpoint_save_interval(str)

    Interval at which to save the model checkpoint. Available options: 'epoch', 'step'. If set to 'epoch', the model will be saved at the end of each epoch. If set to 'step', the model will be saved at the end of each step.

  • checkpoint_save_frequency(int)

    Frequency at which to save the model checkpoint. The model will be saved every checkpoint_save_frequency epochs or steps, depending on the value of checkpoint_save_interval.

  • accumulate_grad_batches(int)

    Number of batches to accumulate gradients across before updating the model parameters.

  • gradient_clip_val(float)

    Value to clip the gradients. If set to None, no gradient clipping will be applied.

  • gradient_clip_algorithm(str)

    Algorithm to use for gradient clipping. Available options: 'value', 'norm'. If set to 'value', the gradients will be clipped to the specified value. If set to 'norm', the gradients will be clipped to the specified norm.

  • save_optimizer_state(bool)

    Whether to save the optimizer and lr_scheduler state along with the model checkpoint.

  • save_full_model(bool)

    Whether to save the full model or only the trainable parameters in the model checkpoint.

  • save_ckpt_type (str, default: 'lightning') –

    Type of checkpoint to save. Available options: 'lightning', 'hf'. If set to 'lightning', the checkpoint will be saved in the lightning format. If set to 'hf', the checkpoint will be saved in the huggingface format.

  • ckpt_path(str)

    Path to the checkpoint to load before training. If set to None, no checkpoint will be loaded.

  • max_length(int)

    Maximum input length to consider. If the input length exceeds this value, it will be truncated.

  • fix_token_embedding(bool)

    Whether to fix the token embeddings during training. If set to True, the token embeddings will not be updated during training.

Source code in fusion_bench/method/lm_finetune/bradley_terry_rm.py
def __init__(
    self,
    optimizer: DictConfig,
    lr_scheduler: Optional[DictConfig],
    dataloader_kwargs: DictConfig,
    max_epochs: int,
    max_steps: int = -1,
    max_steps_per_epoch: int = -1,
    lr_scheduler_interval: Literal["epoch", "step"] = "step",
    lr_scheduler_frequency: int = 1,
    checkpoint_save_interval: Literal["epoch", "step"] = "epoch",
    checkpoint_save_frequency: int = 1,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[float] = None,
    gradient_clip_algorithm: Literal["value", "norm"] = "norm",
    save_optimizer_state: bool = False,
    save_full_model: bool = False,
    save_ckpt_type: Literal["lightning", "hf"] = "lightning",
    ckpt_path: Optional[str] = None,
    max_length: int = 6144,
    fix_token_embedding: bool = True,
    **kwargs,
):
    """
    Store the hyperparameters for Bradley-Terry reward-model training.

    Every named argument is recorded on the instance; any remaining keyword
    arguments are forwarded unchanged to the parent class initializer.

    Args:
        optimizer (DictConfig): Optimizer configuration.
        lr_scheduler (Optional[DictConfig]): Learning rate scheduler
            configuration.
        dataloader_kwargs (DictConfig): Dataloader settings such as batch
            size and num_workers.
        max_epochs (int): Upper bound on training epochs. With -1, training
            runs indefinitely or until `max_steps` is reached.
        max_steps (int): Upper bound on training steps. With -1, training
            runs indefinitely or until `max_epochs` is reached.
        max_steps_per_epoch (int): Upper bound on steps within a single
            epoch. With -1, each epoch runs to its end.
        lr_scheduler_interval (str): Unit in which the scheduler is stepped,
            either 'epoch' (end of each epoch) or 'step' (end of each step).
        lr_scheduler_frequency (int): The scheduler runs once every
            `lr_scheduler_frequency` units of `lr_scheduler_interval`.
        checkpoint_save_interval (str): Unit in which checkpoints are saved,
            either 'epoch' (end of each epoch) or 'step' (end of each step).
        checkpoint_save_frequency (int): A checkpoint is saved once every
            `checkpoint_save_frequency` units of `checkpoint_save_interval`.
        accumulate_grad_batches (int): Number of batches whose gradients are
            accumulated before the model parameters are updated.
        gradient_clip_val (Optional[float]): Gradient clipping threshold.
            None disables gradient clipping.
        gradient_clip_algorithm (str): Either 'value' (clip gradients to the
            given value) or 'norm' (clip gradients to the given norm).
        save_optimizer_state (bool): Whether the optimizer and lr_scheduler
            state are saved together with the model checkpoint.
        save_full_model (bool): Whether the checkpoint holds the full model
            or only the trainable parameters.
        save_ckpt_type (str): Checkpoint format, 'lightning' or 'hf'
            (huggingface).
        ckpt_path (Optional[str]): Checkpoint to load before training. None
            means no checkpoint is loaded.
        max_length (int): Maximum input length; longer inputs are truncated.
        fix_token_embedding (bool): When True, the token embeddings are not
            updated during training.
        **kwargs: Passed through to the parent class initializer.
    """
    # Optimizer and scheduler configs are kept under underscore-prefixed names.
    self._optimizer = optimizer
    self._lr_scheduler = lr_scheduler

    # Data loading.
    self.dataloader_kwargs = dataloader_kwargs

    # Training length limits.
    self.max_epochs = max_epochs
    self.max_steps = max_steps
    self.max_steps_per_epoch = max_steps_per_epoch

    # Learning rate scheduler cadence.
    self.lr_scheduler_interval = lr_scheduler_interval
    self.lr_scheduler_frequency = lr_scheduler_frequency

    # Checkpoint cadence.
    self.checkpoint_save_interval = checkpoint_save_interval
    self.checkpoint_save_frequency = checkpoint_save_frequency

    # Gradient accumulation and clipping.
    self.accumulate_grad_batches = accumulate_grad_batches
    self.gradient_clip_val = gradient_clip_val
    self.gradient_clip_algorithm = gradient_clip_algorithm

    # Checkpoint contents, format and resume path.
    self.save_optimizer_state = save_optimizer_state
    self.save_full_model = save_full_model
    self.save_ckpt_type = save_ckpt_type
    self.ckpt_path = ckpt_path

    # Input handling and embedding freezing.
    self.max_length = max_length
    self.fix_token_embedding = fix_token_embedding

    # Remaining keyword arguments belong to the parent class.
    super().__init__(**kwargs)

compute_loss(batch)

Maximize the likelihood of the winner over the loser using the Bradley-Terry model.

Parameters:

  • batch (Dict[str, Union[Tensor, Any]]) –

    A dictionary containing the input token ids and attention masks for the winner and loser.

Source code in fusion_bench/method/lm_finetune/bradley_terry_rm.py
def compute_loss(self, batch: Dict[str, Union[Tensor, Any]]) -> Dict[str, Tensor]:
    """
    Maximize the likelihood of the winner over the loser using the Bradley-Terry model.

    The batch is split in half along the first dimension: the first half of
    the model's rewards is treated as the chosen (winner) samples and the
    second half as the rejected (loser) samples. The loss is the mean of
    `-log(sigmoid(chosen - rejected))`.

    Args:
        batch (Dict[str, Union[Tensor, Any]]): A dictionary containing the input token ids
            (`"input_ids"`) and attention masks (`"attention_mask"`) for the winner and loser.

    Returns:
        Dict[str, Tensor]: A dictionary with the keys `"chosen_reward"`,
            `"rejected_reward"` and `"loss"`.

    Raises:
        AssertionError: If the batch size (first dimension of `input_ids`) is odd.
    """
    batch_size = batch["input_ids"].size(0)
    assert batch_size % 2 == 0, "Batch size must be even."

    outputs = self.model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        use_cache=self.use_cache,
    )

    # The first element of the model output holds the per-sample rewards.
    rewards = outputs[0]
    half = batch_size // 2
    chosen_reward = rewards[:half]
    rejected_rewards = rewards[half:]

    # logsigmoid is the numerically stable form of log(sigmoid(x)): composing
    # the two separately underflows to log(0) = -inf when the rejected reward
    # greatly exceeds the chosen one, which would make the loss infinite.
    loss = -torch.nn.functional.logsigmoid(chosen_reward - rejected_rewards).mean()

    return {
        "chosen_reward": chosen_reward,
        "rejected_reward": rejected_rewards,
        "loss": loss,
    }