In the complex landscape of multi-task learning, AdaMerging has emerged as a potent method for adaptively merging model parameters to optimize performance across tasks. Unlike traditional fixed-coefficient methods, AdaMerging autonomously learns merging coefficients, offering a more refined and responsive approach[^1].
The cornerstone of AdaMerging lies in its adaptive nature: it learns the merging coefficients either on a task-wise or on a layer-wise basis. This adaptation is driven by entropy minimization over unlabeled test samples, which serves as a surrogate objective for refining the merging coefficients.
For the layer-wise variant, the merged parameters at layer \(l\) take the form

\[
\theta^{l} = \theta^{l}_{0} + \sum_{i=1}^{n} \lambda^{l}_{i}\, \tau^{l}_{i},
\]

where the merging coefficient \(\lambda^{l}_{i}\) and task vector \(\tau^{l}_{i}\) are specific to each layer \(l\) of the model, and \(\theta^{l}_{0}\) denotes the pre-trained parameters of that layer.
By leveraging this adaptive learning approach, AdaMerging significantly enhances the model's ability to generalize across tasks and layers, resulting in a more robust and finely-tuned performance profile. The method’s reliance on entropy minimization ensures that the merging process continually seeks the most informative and stable configuration, adapting to the specific needs of the dataset and tasks at hand.
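Written out, the surrogate objective has the following form (a sketch following the AdaMerging formulation[^1]; the symbols \(f_{\theta(\lambda)}\) for the merged model and \(\mathcal{D}^{\text{test}}_{k}\) for task \(k\)'s unlabeled test set are introduced here for clarity and do not appear in the code below):

\[
\min_{\lambda} \; \sum_{k=1}^{K} \mathbb{E}_{x \sim \mathcal{D}^{\text{test}}_{k}} \Big[ H\big(f_{\theta(\lambda)}(x)\big) \Big],
\qquad
H(p) = -\sum_{c} p_{c} \log p_{c},
\]

where \(K\) is the number of tasks and \(H\) is the Shannon entropy of the predicted class distribution. The `entropy_loss` helper shown later on this page computes exactly this quantity on a batch of logits.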
Task-wise Coefficients.
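In the task-wise formulation, one coefficient is learned per task vector and shared by all of its layers, so the merged model is simply

\[
\theta = \theta_{0} + \sum_{i=1}^{n} \lambda_{i}\, \tau_{i},
\]

with a single learnable \(\lambda_{i}\) for each task vector \(\tau_{i}\).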
The figure below shows how the merging coefficient of each task vector evolves during optimization in Task-wise AdaMerging and AdaMerging++, plotted every ten steps. The learned coefficients clearly differ across task vectors. When the number of tasks is large, grid-searching a coefficient for each task is impractical; AdaMerging avoids this manual search entirely.
Layer-wise Coefficients.
The following figure shows the merging coefficients learned by Layer-wise AdaMerging and AdaMerging++ on ViT-B/32, respectively. We observe that:

- The coefficients learned at each layer of each task vector differ, which shows that the layers are not equally important in the model merging process.
- The coefficients learned for shallow layers are generally smaller than those for deep layers. This indicates that shallow layers rely more on the weights of the pre-trained model, while deep layers rely more on the weights provided by the task vectors. A likely reason is that shallow layers learn general, cross-task features, whereas deep layers learn task-specific features[^2]. This finding is also consistent with the routing analysis in the Weight-Ensembling MoE work[^3]. (A quick way to inspect this trend from saved merging weights is sketched below.)
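Once the learned layer-wise coefficients have been saved to disk (the layer-wise algorithm documented below stores them as a tensor of shape `(num_models, num_layers)`), the shallow-versus-deep trend can be checked with a few lines of PyTorch. A minimal sketch; the file name `merging_weights.pt` is a hypothetical example:

```python
import torch

# Learned layer-wise merging coefficients, shape (num_models, num_layers),
# as saved by LayerWiseAdaMergingAlgorithm.save_merging_weights.
weights = torch.load("merging_weights.pt", map_location="cpu")  # hypothetical path

num_models, num_layers = weights.shape
half = num_layers // 2

# Average coefficient over the shallower half vs. the deeper half of the layers.
shallow_mean = weights[:, :half].mean(dim=1)
deep_mean = weights[:, half:].mean(dim=1)

for i, (s, d) in enumerate(zip(shallow_mean.tolist(), deep_mean.tolist())):
    print(f"task vector {i}: shallow mean = {s:.4f}, deep mean = {d:.4f}")
print(f"overall: shallow = {weights[:, :half].mean():.4f}, deep = {weights[:, half:].mean():.4f}")
```

The remainder of this page walks through the FusionBench implementation: the abstract `TaskWiseAdaMergingAlgorithm` below drives the task-wise variant, and subclasses supply the test data loaders and the logit computation.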
```python
class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
    _fabric: L.Fabric = None

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

        if self._fabric is None and torch.cuda.is_available():
            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
            self._fabric.launch()

    @torch.no_grad()
    def construct_task_wise_merged_model(self, modelpool: ModelPool):
        if self.config.weights is None:
            task_wise_weight = get_task_wise_weights(
                num_models=len(modelpool.model_names),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a .np or .pt file
                if self.config.weights.endswith(".pt"):
                    task_wise_weight = torch.load(
                        self.config.weights, map_location="cpu"
                    ).detach_()
                elif self.config.weights.endswith(".np"):
                    task_wise_weight = torch.from_numpy(
                        np.load(self.config.weights)
                    ).detach_()
                else:
                    raise ValueError(f"Unsupported file format: {self.config.weights}")
            else:
                try:
                    task_wise_weight = torch.tensor(
                        list(self.config.weights), dtype=torch.float32
                    )
                except ValueError:
                    raise ValueError(
                        f"Unsupported weights format: {self.config.weights}"
                    )

        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        return module

    def run(self, modelpool: ModelPool):
        log.info("Fusing models using task-wise adaptive merging.")
        self.modelpool = modelpool

        module = self.construct_task_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                torch.save(module.merge_weight, self.config.save_merging_weights)
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        pass

    @abstractmethod
    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        pass

    def test_time_adaptation(self, module: TaskWiseMergedModel):
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        if self._fabric is not None:
            module, optimizer = self._fabric.setup(module, optimizer)

        module.train()
        module.merge_weights()

        if self.config.get("fast_dev_run", False):
            log.info("Running fast_dev_run, only one step")
            pbar = tqdm(
                range(1),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        else:
            pbar = tqdm(
                range(self.config.max_steps),
                "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        for step_idx in pbar:
            for task in self.modelpool.model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                logits = self.compute_logits(module, batch, task)
                assert (
                    logits.dim() == 2
                ), f"Expected logits to be 2D, got {logits.dim()}"
                loss = entropy_loss(logits)
                # .backward() accumulates when .zero_grad() wasn't called
                # this can save memory
                self._fabric.backward(loss, retain_graph=True)

            optimizer.step()
            optimizer.zero_grad()
            module.merge_weights()

        return module
```
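For orientation, the configuration fields the class reads above (`weights`, `init_values`, `clamp_weights`, `tie_weights`, `strict`, `optimizer`, `lr`, `max_steps`, `devices`, `save_merging_weights`, plus `batch_size`, `num_workers`, and `cache_dir` in the CLIP subclass shown later) could be assembled into a `DictConfig` roughly as follows. This is only a sketch with example values, not the released FusionBench config:

```python
from omegaconf import DictConfig

# Hypothetical configuration; only fields read by the code above are listed,
# and the values are illustrative rather than the project's defaults.
config = DictConfig(
    {
        "weights": None,            # learn the coefficients instead of loading them
        "init_values": 0.3,         # initial coefficient for every task vector
        "clamp_weights": False,
        "tie_weights": True,
        "strict": False,
        "optimizer": "adam",
        "lr": 1e-3,
        "max_steps": 1000,
        "devices": 1,
        "save_merging_weights": "task_wise_weights.pt",
        # fields used by the CLIP-specific subclass
        "batch_size": 16,
        "num_workers": 4,
        "cache_dir": "outputs/cache",
    }
)

# algorithm = CLIPTaskWiseAdaMergingAlgorithm(config)
# merged_model = algorithm.run(modelpool)  # modelpool holds "_pretrained_" and the task models
```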
```python
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
```
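A quick sanity check of this loss (an illustration, not repository code; it assumes the `entropy_loss` definition above is in scope): uniform predictions give the maximum entropy \(\log C\), while confident predictions drive the loss toward zero, which is exactly the pressure that test-time adaptation puts on the merging coefficients.

```python
import torch

# Uniform logits over 10 classes -> entropy close to log(10) ≈ 2.3026
uniform_logits = torch.zeros(4, 10)
print(entropy_loss(uniform_logits))

# Highly confident logits -> entropy close to 0
confident_logits = torch.full((4, 10), -10.0)
confident_logits[:, 0] = 10.0
print(entropy_loss(confident_logits))
```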
`CLIPTaskWiseAdaMergingAlgorithm`: a class for task-wise adaptive merging of CLIP models.
This class extends the TaskWiseAdaMergingAlgorithm to provide specific
functionality for CLIP models, including loading datasets, constructing
zero-shot classification heads, and computing logits.
```python
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    """
    A class for task-wise adaptive merging of CLIP models.

    This class extends the TaskWiseAdaMergingAlgorithm to provide specific
    functionality for CLIP models, including loading datasets, constructing
    zero-shot classification heads, and computing logits.

    Attributes:
        modelpool (CLIPVisionModelPool): The model pool containing CLIP models.
        _clip_processor (CLIPProcessor): The CLIP processor for preparing inputs.
        zeroshot_weights (dict): A dictionary to store zero-shot weights for each task.
    """

    modelpool: CLIPVisionModelPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.

        Args:
            task (str): The name of the task.

        Returns:
            CLIPDataset: The test dataset for the task.
        """
        log.info(f"Loading test dataset: {task}")
        dataset = self.modelpool.load_test_dataset(task)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        """
        Get an iterator over the shuffled test DataLoader for the task.

        Args:
            task (str): The name of the task.

        Returns:
            iterator: An iterator over the shuffled test DataLoader.
        """
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Prepare for test-time adaptation.

        This method loads the CLIP processor and constructs the zero-shot
        classification head for each task.
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")
        pretrained_path = (
            clip_model_config.pretrained_model_name_or_path
            if hasattr(clip_model_config, "pretrained_model_name_or_path")
            else clip_model_config.path
        )

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
            clip_model: CLIPModel = CLIPModel.from_pretrained(pretrained_path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale_exp = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale_exp = self._fabric.to_device(self.logit_scale_exp)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(pretrained_path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(task)
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        """
        Compute the logits for the given batch and task.

        This method computes the image embeddings, normalizes them, and calculates
        the cosine similarity with the text embeddings to produce classification logits.

        Args:
            module (nn.Module): The model module.
            batch (tuple): A batch of input data.
            task (str): The name of the task.

        Returns:
            Tensor: The classification logits for the batch.
        """
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = (
            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
        )
        logits_per_image = logits_per_text.t()

        return logits_per_image
```
`InfiniteDataLoader`: a wrapper class for DataLoader to create an infinite data loader.
This is useful in case we are only interested in the number of steps and not the number of epochs.
This class wraps a DataLoader and provides an iterator that resets
when the end of the dataset is reached, creating an infinite loop.
Attributes:

- `data_loader` (DataLoader): The DataLoader to wrap.
- `data_iter` (iterator): An iterator over the DataLoader.
Source code in fusion_bench/method/adamerging/clip_task_wise_adamerging.py
```python
class InfiniteDataLoader:
    """
    A wrapper class for DataLoader to create an infinite data loader.
    This is useful in case we are only interested in the number of steps and not the number of epochs.

    This class wraps a DataLoader and provides an iterator that resets
    when the end of the dataset is reached, creating an infinite loop.

    Attributes:
        data_loader (DataLoader): The DataLoader to wrap.
        data_iter (iterator): An iterator over the DataLoader.
    """

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data
```
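A minimal illustration of why this wrapper exists (the tiny dataset here is a stand-in, not from the repository): test-time adaptation runs for a fixed number of optimization steps, so the loader must keep producing batches even after the underlying dataset has been exhausted.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# assumes the InfiniteDataLoader class above is in scope

# A tiny stand-in dataset: 10 samples with batch size 4 -> only 3 batches per epoch.
dataset = TensorDataset(torch.randn(10, 3), torch.zeros(10, dtype=torch.long))
loader = InfiniteDataLoader(DataLoader(dataset, batch_size=4, shuffle=True))

# More batches can be drawn than one epoch contains; the iterator resets silently.
it = iter(loader)
for step in range(8):
    images, _ = next(it)
    print(step, images.shape)
```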
`LayerWiseAdaMergingAlgorithm` implements the layer-wise AdaMerging algorithm: it merges the layers of a pretrained model with those of several fine-tuned models.
The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
Source code in fusion_bench/method/adamerging/layer_wise_adamerging.py
```python
class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    """
    Implements the Layer-Wise AdaMerging Algorithm.

    This class merges the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
    """

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: ModelPool):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                self.save_merging_weights(
                    self.config.save_merging_weights, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module: The model module.
            images (Tensor): The input images.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module
```
```python
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        return super().get_shuffled_test_loader_iter(
            task,
            batch_size=self.config.batch_size,
            num_workers=self.config.num_workers,
        )
```
[^1]: AdaMerging: Adaptive Model Merging for Multi-Task Learning. ICLR 2024. https://openreview.net/pdf?id=nZP6NgD3QY
[^2]: Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. "How Transferable Are Features in Deep Neural Networks?" Advances in Neural Information Processing Systems, 27, 2014.
[^3]: A. Tang, L. Shen, Y. Luo, N. Yin, L. Zhang, and D. Tao. "Merging Multi-Task Models via Weight-Ensembling Mixture of Experts." ICML 2024. doi: 10.48550/arXiv.2402.00433.