fusion_bench.modelpool¶
Base Class¶
BaseModelPool ¶
Bases: HydraConfigMixin, BaseYAMLSerializable
A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for specifications, a vision model pool may contain an image preprocessor, and a language model pool may contain a tokenizer.
Attributes:
- _models (DictConfig) – Configuration for all models in the pool.
- _train_datasets (Optional[DictConfig]) – Configuration for training datasets.
- _val_datasets (Optional[DictConfig]) – Configuration for validation datasets.
- _test_datasets (Optional[DictConfig]) – Configuration for testing datasets.
- _usage_ (Optional[str]) – Optional usage information.
- _version_ (Optional[str]) – Optional version information.
Source code in fusion_bench/modelpool/base_pool.py
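A minimal usage sketch, not taken from the library's docs: the constructor signature and the "_pretrained_" special-model key are assumptions inferred from the attributes and properties documented on this page.

```python
from omegaconf import DictConfig
from fusion_bench.modelpool import BaseModelPool

# Hypothetical pool: regular entries are Hydra instantiation configs;
# "_pretrained_" is assumed to be the special key behind `has_pretrained`.
pool = BaseModelPool(
    models=DictConfig({
        "_pretrained_": {
            "_target_": "transformers.AutoModel.from_pretrained",
            "pretrained_model_name_or_path": "bert-base-uncased",
        },
        "model_a": {
            "_target_": "transformers.AutoModel.from_pretrained",
            "pretrained_model_name_or_path": "bert-base-cased",
        },
    })
)
print(pool.model_names)     # regular models only, e.g. ["model_a"]
print(pool.has_pretrained)  # True when a pretrained entry is present
model = pool.load_model("model_a")  # instantiated via Hydra
```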
all_model_names property ¶
Get the names of all models in the pool, including special models.
Returns:
- List[str] – A list of all model names.
has_pretrained property ¶
Check if the model pool contains a pretrained model.
Returns:
- bool – True if a pretrained model is available, False otherwise.
has_test_dataset property ¶
Check if the model pool contains testing datasets.
Returns:
- bool – True if testing datasets are available, False otherwise.
has_train_dataset property ¶
Check if the model pool contains training datasets.
Returns:
- bool – True if training datasets are available, False otherwise.
has_val_dataset property ¶
Check if the model pool contains validation datasets.
Returns:
- bool – True if validation datasets are available, False otherwise.
model_names property ¶
Get the names of regular models, excluding special models.
Returns:
- List[str] – A list of regular model names.
test_dataset_names property ¶
Get the names of testing datasets.
Returns:
- List[str] – A list of testing dataset names.
train_dataset_names property ¶
Get the names of training datasets.
Returns:
- List[str] – A list of training dataset names.
val_dataset_names property ¶
Get the names of validation datasets.
Returns:
- List[str] – A list of validation dataset names.
get_model_config(model_name, return_copy=True) ¶
Get the configuration for the specified model.
Parameters:
- model_name (str) – The name of the model.
- return_copy (bool, default: True) – Whether to return a copy of the configuration.
Returns:
- DictConfig – The configuration for the specified model.
Source code in fusion_bench/modelpool/base_pool.py
get_model_path(model_name) ¶
Get the path for the specified model.
Parameters:
- model_name (str) – The name of the model.
Returns:
- str – The path for the specified model.
Source code in fusion_bench/modelpool/base_pool.py
is_special_model(model_name) staticmethod ¶
Determine if a model is special based on its name.
Parameters:
- model_name (str) – The name of the model.
Returns:
- bool – True if the model name indicates a special model, False otherwise.
Source code in fusion_bench/modelpool/base_pool.py
load_model(model_name_or_config, *args, **kwargs) ¶
Load a model from the pool based on the provided configuration.
Parameters:
- model_name_or_config (Union[str, DictConfig]) – The model name or configuration. If a str, it should be a key in self._models; if a DictConfig, it is used directly for instantiation.
- *args – Additional positional arguments passed to model instantiation.
- **kwargs – Additional keyword arguments passed to model instantiation.
Returns:
- nn.Module – The instantiated or retrieved model.
Source code in fusion_bench/modelpool/base_pool.py
load_pretrained_or_first_model(*args, **kwargs) ¶
Load the pretrained model if available; otherwise load the first available model.
Returns:
- nn.Module – The loaded model.
Source code in fusion_bench/modelpool/base_pool.py
load_test_dataset(dataset_name, *args, **kwargs) ¶
Load the testing dataset with the specified name.
Parameters:
- dataset_name (str) – The name of the dataset.
Returns:
- Dataset – The instantiated testing dataset.
Source code in fusion_bench/modelpool/base_pool.py
load_train_dataset(dataset_name, *args, **kwargs) ¶
Load the training dataset with the specified name.
Parameters:
- dataset_name (str) – The name of the dataset.
Returns:
- Dataset – The instantiated training dataset.
Source code in fusion_bench/modelpool/base_pool.py
load_val_dataset(dataset_name, *args, **kwargs) ¶
Load the validation dataset with the specified name.
Parameters:
- dataset_name (str) – The name of the dataset.
Returns:
- Dataset – The instantiated validation dataset.
Source code in fusion_bench/modelpool/base_pool.py
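A usage sketch for the three dataset loaders (the dataset name is hypothetical; it must be a key in the corresponding dataset configuration):

```python
# Assumes `pool` was configured with train/val/test dataset configs.
train_ds = pool.load_train_dataset("cifar10")
val_ds = pool.load_val_dataset("cifar10")
test_ds = pool.load_test_dataset("cifar10")
```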
save_model(model, path, *args, **kwargs) ¶
Save the state dictionary of the model to the specified path.
Parameters:
- model (Module) – The model whose state dictionary is to be saved.
- path (str) – The path where the state dictionary will be saved.
Source code in fusion_bench/modelpool/base_pool.py
Vision Model Pool¶
NYUv2 Tasks (ResNet)¶
NYUv2ModelPool ¶
Bases: ModelPool
Source code in fusion_bench/modelpool/nyuv2_modelpool.py
CLIP Vision Encoder¶
CLIPVisionModelPool ¶
Bases: BaseModelPool
A model pool for managing Hugging Face's CLIP Vision models.
This class extends the BaseModelPool class and overrides its methods to handle the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
Source code in fusion_bench/modelpool/clip_vision/modelpool.py
load_model(model_name_or_config, *args, **kwargs) ¶
Load a CLIPVisionModel from the model pool with support for various configuration formats.
This method provides flexible model loading capabilities, handling different types of model configurations including string paths, pre-instantiated models, and complex configurations.
Supported configuration formats:
1. String model paths (e.g., Hugging Face model IDs)
2. Pre-instantiated nn.Module objects
3. DictConfig objects for complex configurations
Example configuration:
```yaml
models:
  # Simple string paths to Hugging Face models
  cifar10: tanganke/clip-vit-base-patch32_cifar10
  sun397: tanganke/clip-vit-base-patch32_sun397
  stanford-cars: tanganke/clip-vit-base-patch32_stanford-cars
  # Complex configuration with additional parameters
  custom_model:
    _target_: transformers.CLIPVisionModel.from_pretrained
    pretrained_model_name_or_path: openai/clip-vit-base-patch32
    torch_dtype: float16
```
Parameters:
- model_name_or_config (Union[str, DictConfig]) – Either a model name from the pool or a configuration dictionary for instantiating the model.
- *args – Additional positional arguments passed to model loading/instantiation.
- **kwargs – Additional keyword arguments passed to model loading/instantiation.
Returns:
- CLIPVisionModel – The loaded CLIPVisionModel instance.
Source code in fusion_bench/modelpool/clip_vision/modelpool.py
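A usage sketch against the example configuration above (assumes `pool` is a CLIPVisionModelPool built from that YAML):

```python
model = pool.load_model("cifar10")        # loaded from a Hugging Face model ID
custom = pool.load_model("custom_model")  # instantiated via the _target_ config
```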
save_model(model, path) ¶
Save a CLIP Vision model to the given path.
Parameters:
- model (CLIPVisionModel) – The model to save.
- path (str) – The path to save the model to.
Source code in fusion_bench/modelpool/clip_vision/modelpool.py
OpenCLIP Vision Encoder¶
OpenCLIPVisionModelPool ¶
Bases: BaseModelPool
A model pool for managing OpenCLIP Vision models (the models used in the task vector paper).
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
load_classification_head(model_name_or_config, *args, **kwargs) ¶
The model config can be:
- A string: the path to a model checkpoint in pickle format, loaded directly with torch.load.
- Otherwise (default): the classification head is instantiated with Hydra's instantiate.
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
load_model(model_name_or_config, *args, **kwargs) ¶
The model config can be:
- A string: the path to a model checkpoint in pickle format, loaded directly with torch.load.
- {"model_name": str, "pickle_path": str}: load the model from a binary (pickle) file. The model is first constructed with ImageEncoder(model_name), then the state dict is copied from the model stored in the pickle file.
- {"model_name": str, "state_dict_path": str}: load the model from a state dict file. The model is first constructed with ImageEncoder(model_name), then the state dict is loaded from the file.
- Otherwise (default): the model is instantiated with Hydra's instantiate.
Source code in fusion_bench/modelpool/openclip_vision/modelpool.py
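A sketch of the accepted forms (checkpoint paths and the architecture name are hypothetical):

```python
# 1. Pickle checkpoint, loaded directly with torch.load
model = pool.load_model("checkpoints/ViT-B-32_sun397.pt")

# 2. Construct ImageEncoder("ViT-B-32"), then copy weights from the pickled model
model = pool.load_model({"model_name": "ViT-B-32",
                         "pickle_path": "checkpoints/finetuned.pt"})

# 3. Construct ImageEncoder("ViT-B-32"), then load a plain state dict file
model = pool.load_model({"model_name": "ViT-B-32",
                         "state_dict_path": "checkpoints/finetuned_state_dict.pt"})
```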
NLP Model Pool¶
GPT-2¶
HuggingFaceGPT2ClassificationPool = GPT2ForSequenceClassificationPool module-attribute ¶
GPT2ForSequenceClassificationPool ¶
Bases: BaseModelPool
Source code in fusion_bench/modelpool/huggingface_gpt2_classification.py
Seq2Seq Language Models (Flan-T5)¶
Seq2SeqLMPool ¶
Bases: BaseModelPool
A model pool specialized for sequence-to-sequence language models.
This model pool provides management and loading capabilities for sequence-to-sequence (seq2seq) language models such as T5, BART, and mT5. It extends the base model pool functionality with seq2seq-specific features including tokenizer management and model configuration handling.
Seq2seq models are particularly useful for tasks that require generating output sequences from input sequences, such as translation, summarization, question answering, and text generation. This pool streamlines the process of loading and configuring multiple seq2seq models for fusion and ensemble scenarios.
Key Features
- Specialized loading for AutoModelForSeq2SeqLM models
- Integrated tokenizer management
- Support for model-specific keyword arguments
- Automatic dtype parsing and configuration
- Compatible with PEFT (Parameter-Efficient Fine-Tuning) adapters
Attributes:
- _tokenizer – Configuration for the tokenizer associated with the models.
- _model_kwargs – Default keyword arguments applied to all model loading operations.
Example:
```python
pool = Seq2SeqLMPool(
    models={
        "t5_base": "t5-base",
        "t5_large": "t5-large",
        "custom_model": "/path/to/local/model",
    },
    tokenizer={
        "_target_": "transformers.T5Tokenizer",
        "pretrained_model_name_or_path": "t5-base",
    },
    model_kwargs={"torch_dtype": "float16", "device_map": "auto"},
)
model = pool.load_model("t5_base")
tokenizer = pool.load_tokenizer()
```
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
__init__(models, *, tokenizer, model_kwargs=None, **kwargs) ¶
Initialize the sequence-to-sequence language model pool.
Sets up the model pool with configurations for models, tokenizer, and default model-loading parameters. Model kwargs are processed automatically to handle special configurations such as torch_dtype parsing.
Parameters:
- models (DictConfig) – Configuration dictionary specifying the seq2seq models to manage. Keys are model names; values can be model paths/names or detailed configs.
- tokenizer (Optional[DictConfig]) – Configuration for the tokenizer to use with the models. Can be a simple path/name or a detailed configuration with a _target_ field.
- model_kwargs (Optional[DictConfig], default: None) – Default keyword arguments applied to all model loading operations. Common options include torch_dtype, device_map, etc. The torch_dtype field is automatically parsed from string to dtype.
- **kwargs – Additional arguments passed to the parent BaseModelPool.
Example: see the class-level example above.
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
load_model(model_name_or_config, *args, **kwargs) ¶
Load a sequence-to-sequence language model from the pool.
Loads a seq2seq model using the parent class loading mechanism while automatically applying the pool's default model kwargs. The method merges the pool's model_kwargs with any additional kwargs provided, giving priority to the explicitly provided kwargs.
Parameters:
- model_name_or_config (str | DictConfig) – Either a string model name from the pool configuration or a DictConfig containing model loading parameters.
- *args – Additional positional arguments passed to the parent load_model method.
- **kwargs – Additional keyword arguments that override the pool's default model_kwargs. Common options include device, torch_dtype, etc.
Returns:
- AutoModelForSeq2SeqLM – The loaded sequence-to-sequence language model.
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
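For instance, a per-call override of a pool-level default (a sketch that reuses the pool from the class-level example above):

```python
import torch

# The pool's model_kwargs set torch_dtype=float16; the explicit keyword
# below takes priority over that default for this call only.
model = pool.load_model("t5_base", torch_dtype=torch.bfloat16)
```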
load_tokenizer(*args, **kwargs) ¶
Load the tokenizer associated with the sequence-to-sequence models.
Loads a tokenizer based on the tokenizer configuration provided during pool initialization. The tokenizer should be compatible with the seq2seq models in the pool and is typically used for preprocessing input text and postprocessing generated output.
Parameters:
- *args – Additional positional arguments passed to the tokenizer constructor.
- **kwargs – Additional keyword arguments passed to the tokenizer constructor.
Returns:
- PreTrainedTokenizer – The loaded tokenizer instance compatible with the seq2seq models in this pool.
Raises:
- AssertionError – If no tokenizer configuration is provided.
Source code in fusion_bench/modelpool/seq2seq_lm/modelpool.py
SequenceClassificationModelPool ¶
Bases: BaseModelPool
Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, **kwargs) ¶
Save the model to the specified path.
Parameters:
- model (PreTrainedModel) – The model to be saved.
- path (str) – The path where the model will be saved.
- push_to_hub (bool, default: False) – Whether to push the model to the Hugging Face Hub.
- model_dtype (Optional[str], default: None) – Optional target dtype (e.g., "float16") to convert the model to before saving.
- save_tokenizer (bool, default: False) – Whether to save the tokenizer along with the model.
- tokenizer_kwargs – Additional keyword arguments for loading the tokenizer when save_tokenizer is True.
- **kwargs – Additional keyword arguments passed to the save_pretrained method.
Source code in fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
PeftModelForSeq2SeqLMPool ¶
Bases: ModelPool
Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
load_model(model_config) ¶
Load a model based on the provided configuration.
The configuration options of model_config are:
- name: The name of the model. If it is "pretrained", a pretrained Seq2Seq language model is returned.
- path: The path where the model is stored.
- is_trainable: A boolean indicating whether the model parameters should be trainable. Default is True.
- merge_and_unload: A boolean indicating whether to merge and unload the PEFT model after loading. Default is True.
Parameters:
- model_config (str | DictConfig) – The configuration for the model. This can be either a string (the name of the model) or a DictConfig object containing the model configuration.
Returns:
- model – The loaded model. If the model name is "pretrained", a pretrained Seq2Seq language model is returned; otherwise, a PEFT model.
Source code in fusion_bench/modelpool/PeftModelForSeq2SeqLM.py
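A hedged YAML sketch of a model entry using the options above (the list-style layout and the paths are assumptions, not taken from the library's configs):

```yaml
models:
  - name: flan_t5_lora
    path: /path/to/peft/adapter
    is_trainable: false
    merge_and_unload: true
```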
Causal Language Models (Llama, Mistral, Qwen...)¶
CausalLMPool ¶
Bases: BaseModelPool
A model pool for managing and loading causal language models.
This class provides a unified interface for loading and managing multiple causal language models, typically used in model fusion and ensemble scenarios. It supports both eager and lazy loading strategies, and handles model configuration through YAML configs or direct instantiation.
The pool can manage models from Hugging Face Hub, local paths, or custom configurations. It also provides tokenizer management and model saving capabilities with optional Hugging Face Hub integration.
Parameters:
- models – Dictionary or configuration specifying the models to be managed. Can contain model names mapped to paths or detailed configurations.
- tokenizer (Optional[DictConfig | str]) – Tokenizer configuration, either a string path/name or a DictConfig with detailed tokenizer settings.
- model_kwargs (Optional[DictConfig], default: None) – Additional keyword arguments passed to model loading. Common options include torch_dtype, device_map, etc.
- enable_lazy_loading (bool, default: False) – Whether to use lazy loading for models. When True, models are loaded as LazyStateDict objects instead of actual models, which can save memory for large model collections.
- **kwargs – Additional arguments passed to the parent BaseModelPool.
Example
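A minimal sketch (model names, paths, and the string-path configuration style are illustrative assumptions):

```python
from fusion_bench.modelpool import CausalLMPool

pool = CausalLMPool(
    models={
        "llama_base": "meta-llama/Llama-2-7b-hf",
        "mistral": "mistralai/Mistral-7B-v0.1",
    },
    tokenizer="meta-llama/Llama-2-7b-hf",
    model_kwargs={"torch_dtype": "bfloat16"},
)
model = pool.load_model("llama_base")
tokenizer = pool.load_tokenizer()
```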
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
get_model_kwargs() ¶
Get processed model keyword arguments for model loading.
Converts the stored model_kwargs from DictConfig to a regular dictionary and processes special arguments like torch_dtype for proper model loading.
Returns:
- dict – Processed keyword arguments ready to be passed to model-loading functions. The torch_dtype field, if present, is converted from string to the appropriate torch dtype object.
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
get_model_path(model_name) ¶
Extract the model path from the model configuration.
Parameters:
- model_name (str) – The name of the model as defined in the models configuration.
Returns:
- str – The path or identifier for the model. For string configurations, returns the string directly. For dict configurations, extracts the 'pretrained_model_name_or_path' field.
Raises:
- RuntimeError – If the model configuration is invalid or the model name is not found in the configuration.
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_model(model_name_or_config, *args, **kwargs) ¶
Load a causal language model from the model pool.
This method supports multiple loading strategies:
1. Loading by model name from the configured model pool
2. Loading from a direct configuration dictionary
3. Lazy loading using LazyStateDict for memory efficiency
The method automatically handles different model configuration formats and applies the appropriate loading strategy based on the enable_lazy_loading flag.
Parameters:
- model_name_or_config (str | DictConfig) – Either a string model name that exists in the model pool configuration, or a DictConfig/dict containing the model configuration directly.
- *args – Additional positional arguments passed to the model constructor.
- **kwargs – Additional keyword arguments passed to the model constructor. These will be merged with the pool's model_kwargs.
Returns:
- Union[PreTrainedModel, LazyStateDict] – The loaded model. Returns a PreTrainedModel for normal loading or a LazyStateDict for lazy loading.
Raises:
- RuntimeError – If the model configuration is invalid.
- KeyError – If the model name is not found in the model pool.
Example YAML configurations
Simple string configuration:
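A plausible minimal form (model names and paths are hypothetical):

```yaml
models:
  llama_base: meta-llama/Llama-2-7b-hf
  mistral: mistralai/Mistral-7B-v0.1
```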
Detailed configuration:
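A plausible detailed form using Hydra-style instantiation (the exact fields are assumptions, modeled on the CLIP example earlier on this page and on get_model_path's use of pretrained_model_name_or_path):

```yaml
models:
  llama_base:
    _target_: transformers.AutoModelForCausalLM.from_pretrained
    pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
    torch_dtype: bfloat16
```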
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_tokenizer(*args, **kwargs) ¶
Load the tokenizer associated with this model pool.
Loads a tokenizer based on the tokenizer configuration provided during pool initialization. Supports both simple string paths and detailed configuration dictionaries.
Parameters:
- *args – Additional positional arguments passed to the tokenizer constructor.
- **kwargs – Additional keyword arguments passed to the tokenizer constructor.
Returns:
- PreTrainedTokenizer – The loaded tokenizer instance.
Raises:
- AssertionError – If no tokenizer is defined in the configuration.
Example YAML configurations
Simple string configuration:
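A plausible minimal form (the tokenizer name is hypothetical):

```yaml
tokenizer: meta-llama/Llama-2-7b-hf
```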
Detailed configuration:
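A plausible detailed form (fields are assumptions following the same Hydra-style convention):

```yaml
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
  use_fast: true
```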
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
save_model(model, path, push_to_hub=False, model_dtype=None, save_tokenizer=False, tokenizer_kwargs=None, tokenizer=None, algorithm_config=None, description=None, base_model_in_modelcard=True, **kwargs) ¶
Save a model to the specified path with optional tokenizer saving and Hub upload.
This method provides comprehensive model saving capabilities, including optional tokenizer saving, dtype conversion, model card creation, and Hugging Face Hub upload. The model is saved in the standard Hugging Face format.
Parameters:
- model (PreTrainedModel) – The PreTrainedModel instance to be saved.
- path (str) – The local path where the model will be saved. Supports tilde expansion for home directory paths.
- push_to_hub (bool, default: False) – Whether to push the saved model to the Hugging Face Hub. Requires proper authentication and repository permissions.
- model_dtype (Optional[str], default: None) – Optional string specifying the target dtype for the model before saving (e.g., "float16", "bfloat16"). The model will be converted to this dtype before saving.
- save_tokenizer (bool, default: False) – Whether to save the tokenizer alongside the model. If True, the tokenizer will be loaded using the pool's tokenizer configuration and saved to the same path.
- tokenizer_kwargs – Additional keyword arguments for tokenizer loading when save_tokenizer is True.
- tokenizer (Optional[PreTrainedTokenizer], default: None) – Optional pre-loaded tokenizer instance. If provided, this tokenizer will be saved regardless of the save_tokenizer flag.
- algorithm_config (Optional[DictConfig], default: None) – Optional DictConfig containing algorithm configuration. If provided, a model card will be created with algorithm details.
- description (Optional[str], default: None) – Optional description for the model card. If not provided and algorithm_config is given, a default description will be generated.
- base_model_in_modelcard (bool, default: True) – Whether to include base model information in the generated model card.
- **kwargs – Additional keyword arguments passed to the model's save_pretrained method.
Example:
```python
>>> pool = CausalLMPool(models=..., tokenizer=...)
>>> model = pool.load_model("my_model")
>>> pool.save_model(
...     model,
...     "/path/to/save",
...     save_tokenizer=True,
...     model_dtype="float16",
...     push_to_hub=True,
...     algorithm_config=algorithm_config,
...     description="Custom merged model"
... )
```
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
CausalLMBackbonePool ¶
Bases: CausalLMPool
A specialized model pool that loads only the transformer backbone layers.
This class extends CausalLMPool to provide access to just the transformer layers (backbone) of causal language models, excluding the language modeling head and embeddings. This is useful for model fusion scenarios where only the core transformer layers are needed.
The class automatically extracts the model.layers component from loaded AutoModelForCausalLM instances, providing direct access to the transformer blocks. Lazy loading is not supported for this pool type.
Note:
This pool automatically disables lazy loading as it needs to access the internal structure of the model to extract the backbone layers.
Example
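A minimal sketch (the model name is hypothetical; the constructor is assumed to match CausalLMPool's):

```python
backbone_pool = CausalLMBackbonePool(
    models={"llama": "meta-llama/Llama-2-7b-hf"},
    tokenizer="meta-llama/Llama-2-7b-hf",
)
layers = backbone_pool.load_model("llama")  # typically an nn.ModuleList of transformer blocks
```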
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_model(model_name_or_config, *args, **kwargs) ¶
Load only the transformer backbone layers from a causal language model.
This method loads a complete causal language model and then extracts only the transformer layers (backbone), discarding the embedding layers and language modeling head. This is useful for model fusion scenarios where only the core transformer computation is needed.
Parameters:
- model_name_or_config (str | DictConfig) – Either a string model name from the pool configuration or a DictConfig with model loading parameters.
- *args – Additional positional arguments passed to the parent load_model method.
- **kwargs – Additional keyword arguments passed to the parent load_model method.
Returns:
- Module – The transformer layers (typically an nn.ModuleList) containing the core transformer blocks without embeddings or output heads.
Note:
Lazy loading is automatically disabled for this method as it needs to access the internal model structure to extract the layers.
Source code in fusion_bench/modelpool/causal_lm/causal_lm.py
load_peft_causal_lm(base_model_path, peft_model_path, torch_dtype='bfloat16', is_trainable=True, merge_and_unload=False) ¶
Load a causal language model with PEFT (Parameter-Efficient Fine-Tuning) adapters.
This function loads a base causal language model and applies PEFT adapters (such as LoRA, AdaLoRA, or other parameter-efficient fine-tuning methods) to create a fine-tuned model. It supports both keeping the adapters separate and merging them into the base model.
Parameters:
- base_model_path (str) – Path or identifier for the base causal language model. Can be a Hugging Face model name or a local path.
- peft_model_path (str) – Path to the PEFT adapter configuration and weights. This should contain adapter_config.json and the adapter weights.
- torch_dtype (str, default: 'bfloat16') – The torch data type to use for the model. Common options include "float16", "bfloat16", "float32".
- is_trainable (bool, default: True) – Whether the loaded PEFT model should be trainable. Set to False for inference-only usage to save memory.
- merge_and_unload (bool, default: False) – Whether to merge the PEFT adapters into the base model and unload the adapter weights. When True, returns a standard PreTrainedModel instead of a PeftModel.
Returns:
- Union[PeftModel, PreTrainedModel] – The loaded model with PEFT adapters. Returns a PeftModel if merge_and_unload is False, or a PreTrainedModel if the adapters are merged and unloaded.
Example:
```python
>>> # Load model with adapters for training
>>> model = load_peft_causal_lm(
...     "microsoft/DialoGPT-medium",
...     "/path/to/lora/adapters",
...     is_trainable=True
... )
>>> # Load and merge adapters for inference
>>> merged_model = load_peft_causal_lm(
...     "microsoft/DialoGPT-medium",
...     "/path/to/lora/adapters",
...     merge_and_unload=True,
...     is_trainable=False
... )
```