In the complex landscape of multi-task learning, AdaMerging has emerged as a potent method for adaptively merging model parameters to optimize performance across tasks. Unlike traditional fixed-coefficient methods, AdaMerging autonomously learns merging coefficients, offering a more refined and responsive approach.[^1]
The cornerstone of AdaMerging lies in its adaptive nature: it learns the merging coefficients on either a task-wise or a layer-wise basis. This adaptation is driven by entropy minimization on unlabeled test samples, which serves as a surrogate objective for refining the merging coefficients. In the layer-wise formulation, the merged parameters of layer \(l\) are

\[
\theta^{l} = \theta^{l}_{0} + \sum_{i=1}^{n} \lambda^{l}_{i} \tau^{l}_{i},
\]

where the merging coefficient \(\lambda^{l}_{i}\) and task vector \(\tau^{l}_{i}\) are specific to each layer \(l\) of the model, \(\theta^{l}_{0}\) denotes the pre-trained weights of layer \(l\), and \(n\) is the number of tasks.
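To make this concrete, here is a minimal sketch of the layer-wise update applied to plain state dicts; the function and variable names are illustrative, not part of the FusionBench API:

```python
import torch

def merge_layer_wise(pretrained_sd, task_vectors, coeffs):
    """Merge a pretrained state dict with task vectors using per-layer coefficients.

    pretrained_sd: dict[str, Tensor], the pre-trained weights (theta_0).
    task_vectors: list of dict[str, Tensor], task_vectors[i][k] = theta_i[k] - theta_0[k].
    coeffs: Tensor of shape (num_models, num_layers), coeffs[i, l] = lambda_i^l.
    """
    merged = {}
    for l, key in enumerate(pretrained_sd):
        merged[key] = pretrained_sd[key] + sum(
            coeffs[i, l] * tv[key] for i, tv in enumerate(task_vectors)
        )
    return merged
```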
By learning the coefficients adaptively, AdaMerging significantly enhances the model's ability to generalize across tasks and layers, resulting in a more robust and finely tuned performance profile. Its reliance on entropy minimization ensures that the merging process continually seeks the most informative and stable configuration, adapting to the specific needs of the data and tasks at hand.
Task-wise Coefficients.
The figure below shows how the merging coefficient of each task vector evolves during the optimization of Task-wise AdaMerging and AdaMerging++, plotted every ten steps. We consistently observe that the merging coefficients of the task vectors differ from one another. When the number of tasks is large, grid-searching a coefficient for each task is clearly impractical; AdaMerging avoids this manual search altogether.
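For contrast, task-wise AdaMerging learns a single coefficient per task vector, while layer-wise AdaMerging learns one per (task, layer) pair. A minimal sketch of the two parameterizations (the 0.3 initialization follows the AdaMerging paper; the variable names are illustrative):

```python
import torch

num_models, num_layers = 8, 12

# task-wise: one learnable coefficient per task vector,
# broadcast across all layers during merging
task_wise_weight = torch.nn.Parameter(torch.full((num_models,), 0.3))

# layer-wise: an independent coefficient per (task, layer) pair
layer_wise_weight = torch.nn.Parameter(torch.full((num_models, num_layers), 0.3))
```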
Layer-wise Coefficients.
The figure below shows the merging coefficients learned by Layer-wise AdaMerging and AdaMerging++ on ViT-B/32. We observe the following:
1. The coefficients learned for each layer of each task vector differ, which shows that the layers do not contribute equally to the model merging process.
2. The coefficients learned for shallow layers are generally smaller than those for deep layers, indicating that shallow layers rely more on the weights of the pre-trained model, while deep layers rely more on the weights provided by the task vectors. This may be because shallow layers learn general, cross-task features, while deep layers learn task-specific features.[^2] This finding is also consistent with the routing analysis in the weight-ensembling mixture-of-experts work.[^3] The trend can be verified directly from the saved coefficients, as sketched below.
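Since the learned coefficients are saved as a `(num_models, num_layers)` tensor (see `save_merging_weights` in the code below), the shallow-versus-deep trend can be checked directly; the file path here is hypothetical:

```python
import torch

# hypothetical path; see `save_merging_weights` below for how the tensor is saved
weights = torch.load("merging_weights.pt")  # shape: (num_models, num_layers)

num_layers = weights.size(1)
shallow = weights[:, : num_layers // 2].mean()
deep = weights[:, num_layers // 2 :].mean()
print(f"mean coefficient, shallow layers: {shallow:.3f}, deep layers: {deep:.3f}")
```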
```python
def entropy_loss(logits: Tensor) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
```
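As a quick sanity check of this loss, uniform predictions attain the maximum entropy \(\log C\) for \(C\) classes, while confident predictions drive it toward zero:

```python
import torch

diffuse = torch.zeros(16, 10)            # uniform predictions over 10 classes
peaked = torch.full((16, 10), -100.0)
peaked[:, 0] = 100.0                     # confident predictions for class 0

print(entropy_loss(diffuse))  # log(10) ≈ 2.3026, maximum entropy
print(entropy_loss(peaked))   # ≈ 0, minimum entropy
```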
```python
class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
    modelpool: HuggingFaceClipVisionPool = None
    _clip_processor: CLIPProcessor = None
    zeroshot_weights = {}

    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    def get_task_config(self, task):
        for task_config in self.modelpool.config.tta_datasets:
            if task_config.name == task:
                return task_config
        raise ValueError(f"Task {task} not found in config")

    def prepare_dataset_config(self, dataset_config: DictConfig):
        if not hasattr(dataset_config, "type"):
            with open_dict(dataset_config):
                dataset_config["type"] = self.modelpool.config.dataset_type
        return dataset_config

    @functools.cache
    def get_test_dataset(self, task: str):
        """
        Load the test dataset for the task.
        This method is cached, so the dataset is loaded only once.
        """
        dataset_config = self.get_task_config(task)["dataset"]
        dataset_config = self.prepare_dataset_config(dataset_config)
        log.info(f"Loading test dataset: {dataset_config.name}")
        dataset = load_dataset_from_config(dataset_config)
        dataset = CLIPDataset(dataset, self._clip_processor)
        return dataset

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str):
        loader = DataLoader(
            self.get_test_dataset(task),
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=True,
        )
        if self._fabric is not None:
            loader = self._fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot
        classification head for each task.
        """
        clip_model_config = self.modelpool.get_model_config("_pretrained_")

        with timeit_context("Loading CLIP processor and pretrained CLIP model."):
            self._clip_processor = CLIPProcessor.from_pretrained(clip_model_config.path)
            clip_model = CLIPModel.from_pretrained(clip_model_config.path)

            clip_classifier = HFCLIPClassifier(clip_model, self._clip_processor)
            self.visual_projection = clip_model.visual_projection.requires_grad_(False)
            self.logit_scale = clip_model.logit_scale.exp()
            if self._fabric is not None:
                self.visual_projection = self._fabric.to_device(self.visual_projection)
                self.logit_scale = self._fabric.to_device(self.logit_scale)

        for task in self.modelpool.model_names:
            cache_file = os.path.join(
                self.config.cache_dir,
                f"{os.path.basename(clip_model_config.path)}_{task}_zeroshot_weights.pt",
            )
            if os.path.exists(cache_file):
                log.info(f"Loading cached zeroshot weights for task: {task}")
                zeroshot_weights = torch.load(cache_file, map_location="cpu")
            else:
                log.info(f"Construct zero shot classification head for task: {task}")
                classnames, templates = get_classnames_and_templates(
                    self.get_task_config(task)["dataset"].name
                )
                clip_classifier.set_classification_task(classnames, templates)
                zeroshot_weights = clip_classifier.zeroshot_weights
                log.info(f"save zeroshot weights to {cache_file}")
                torch.save(zeroshot_weights, cache_file)
            self.zeroshot_weights[task] = zeroshot_weights
            if self._fabric is not None:
                self.zeroshot_weights[task] = self._fabric.to_device(
                    self.zeroshot_weights[task]
                )

    def compute_logits(self, module, batch, task: str) -> Tensor:
        images, _ = batch
        text_embeds = self.zeroshot_weights[task]

        image_embeds = module(images)[1]
        image_embeds = self.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
```
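The `compute_logits` method follows the standard CLIP zero-shot recipe: the cached text embeddings act as a frozen classification head, and the logits are scaled cosine similarities. A shape-level sketch with dummy tensors (dimensions are illustrative):

```python
import torch

batch_size, embed_dim, num_classes = 4, 512, 10

text_embeds = torch.randn(num_classes, embed_dim)  # zeroshot_weights[task]
image_embeds = torch.randn(batch_size, embed_dim)  # visual_projection(module(images)[1])
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

logit_scale = torch.tensor(100.0)                  # clip_model.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.t()             # shape: (batch_size, num_classes)
print(logits_per_image.shape)                      # torch.Size([4, 10])
```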
```python
class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    def __init__(self, algorithm_config: DictConfig):
        super().__init__(algorithm_config)

    @torch.no_grad()
    def construct_layer_wise_merged_model(self, modelpool: ModelPool):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained
        model with those of several fine-tuned models. The merging is controlled by
        layer-wise weights, which is a `torch.Tensor` of the shape
        `(num_models, num_layers)`. The merging weights can be initialized based on
        a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and
                fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise
                weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values`
        # or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        if self.fabric.is_global_zero and self.config.get("save_merging_weights", False):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to the current
                # working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool):
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                self.save_merging_weights(
                    self.config.save_merging_weights, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts.
        Such as setting up the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. Labels are not needed.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        pass

    def test_time_adaptation(self, module: LayerWiseMergedModel):
        self.on_test_time_adaptation_start()

        config = self.config

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        self.print_profile_summary()
        return module
```
```python
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot
        classification head for each task.
        """
        self.setup_zero_shot_classification_head()
```
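Putting the pieces together, the algorithm can be driven by a configuration exposing the options read in the code above; the concrete values and the `modelpool` construction are illustrative:

```python
from omegaconf import DictConfig

config = DictConfig({
    "weights": None,            # learn coefficients instead of loading them from file
    "init_values": 0.3,         # initial value of every merging coefficient
    "clamp_weights": False,
    "tie_weights": True,
    "strict": False,
    "optimizer": "adam",
    "lr": 1e-3,
    "max_steps": 1000,
    "batch_size": 16,
    "num_workers": 8,
    "cache_dir": "outputs/cache",
    "save_merging_weights": "merging_weights.pt",
})

algorithm = CLIPLayerWiseAdaMergingAlgorithm(config)
# `modelpool` holds the "_pretrained_" model and the fine-tuned models to merge
merged_model = algorithm.run(modelpool)
```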
[^1]: (ICLR 2024) AdaMerging: Adaptive Model Merging for Multi-Task Learning. https://openreview.net/pdf?id=nZP6NgD3QY
[^2]: Jason Yosinski, Jeff Clune, Yoshua Bengio, and Hod Lipson. How transferable are features in deep neural networks? Advances in Neural Information Processing Systems, 27, 2014.
[^3]: A. Tang, L. Shen, Y. Luo, N. Yin, L. Zhang, and D. Tao. Merging Multi-Task Models via Weight-Ensembling Mixture of Experts. ICML 2024. doi: 10.48550/arXiv.2402.00433.