import random
import warnings
+ from typing import Any, Dict, Optional, Union

import fire
import numpy as np
import torch.utils.data
from peft import PeftModel, get_peft_model
from torch.optim.lr_scheduler import StepLR
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

- from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+ from QEfficient.finetune.configs.training import TrainConfig
from QEfficient.finetune.utils.config_utils import (
    generate_dataset_config,
    generate_peft_config,
from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
from QEfficient.utils._utils import login_and_download_hf_lm

+ # Try importing QAIC-specific module, proceed without it if unavailable
try:
    import torch_qaic  # noqa: F401
except ImportError as e:
-     print(f"Warning: {e}. Moving ahead without these qaic modules.")
+     print(f"Warning: {e}. Proceeding without QAIC modules.")


- from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+ from transformers import AutoModelForSequenceClassification

# Suppress all warnings
warnings.filterwarnings("ignore")


- def main(**kwargs):
-     """
-     Helper function to finetune the model on QAic.
+ def setup_distributed_training(train_config: TrainConfig) -> None:
+     """Initialize distributed training environment if enabled.

-     .. code-block:: bash
+     Args:
+         train_config (TrainConfig): Training configuration object.

-         python -m QEfficient.cloud.finetune OPTIONS
+     Notes:
+         - If distributed data parallel (DDP) is disabled, this function does nothing.
+         - Ensures the device is not CPU and does not specify an index for DDP compatibility.
+         - Initializes the process group using the specified distributed backend.
+
+     Raises:
+         AssertionError: If device is CPU or includes an index with DDP enabled.
    """
-     # update the configuration for the training process
-     train_config = TRAIN_CONFIG()
-     update_config(train_config, **kwargs)
-     dataset_config = generate_dataset_config(train_config, kwargs)
-     device = train_config.device
+     if not train_config.enable_ddp:
+         return

-     # dist init
-     if train_config.enable_ddp:
-         # TODO: may have to init qccl backend, next try run with torchrun command
-         torch_device = torch.device(device)
-         assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-         assert torch_device.index is None, (
-             f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-         )
-         dist.init_process_group(backend=train_config.dist_backend)
-         # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-         getattr(torch, torch_device.type).set_device(dist.get_rank())
+     torch_device = torch.device(train_config.device)
+     assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+     assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+
+     dist.init_process_group(backend=train_config.dist_backend)
+     # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+     getattr(torch, torch_device.type).set_device(dist.get_rank())


-     # Set the seeds for reproducibility
-     torch.manual_seed(train_config.seed)
-     random.seed(train_config.seed)
-     np.random.seed(train_config.seed)

-     # Load the pre-trained model and setup its configuration
-     # config = AutoConfig.from_pretrained(train_config.model_name)
+ def setup_seeds(seed: int) -> None:
+     """Set random seeds across libraries for reproducibility.
+
+     Args:
+         seed (int): Seed value to set for random number generators.
+
+     Notes:
+         - Sets seeds for PyTorch, Python's random module, and NumPy.
+     """
+     torch.manual_seed(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+
+
+ def load_model_and_tokenizer(
+     train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+ ) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+     """Load the pre-trained model and tokenizer from Hugging Face.
+
+     Args:
+         config (TrainConfig): Training configuration object containing model and tokenizer names.
+         dataset_config (Any): A dataclass object representing dataset configuration.
+         peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
+         kwargs: Additional arguments to override PEFT config.
+
+     Returns:
+         tuple: A tuple of two values.
+             - Model with pretrained weights loaded.
+             - Model's tokenizer (AutoTokenizer).
+
+     Notes:
+         - Downloads the model if not already cached using login_and_download_hf_lm.
+         - Configures the model with FP16 precision and disables caching for training.
+         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
+         - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
+     """
    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
    if train_config.task_type == "seq_classification":
        model = AutoModelForSequenceClassification.from_pretrained(
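A note on the DDP path added above: setup_distributed_training() insists on a bare device type, because the process rank is appended automatically once the process group is initialized. A minimal sketch of what the asserts accept and reject, assuming TrainConfig's fields (enable_ddp, device) can be assigned directly; this is illustration only, not part of the change:

    # Sketch only: device strings accepted by setup_distributed_training()
    from QEfficient.cloud.finetune import setup_distributed_training
    from QEfficient.finetune.configs.training import TrainConfig

    cfg = TrainConfig()
    cfg.enable_ddp = True
    cfg.device = "qaic"      # accepted: mapped to "qaic:<rank>" after init_process_group
    # cfg.device = "qaic:0"  # rejected by the assert: no device index with DDP
    # cfg.device = "cpu"     # rejected by the assert: single-node DDP needs a non-CPU device
    setup_distributed_training(cfg)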
@@ -104,7 +135,6 @@ def main(**kwargs):
            torch_dtype=torch.float16,
        )

-     # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
    )
@@ -114,14 +144,12 @@ def main(**kwargs):
    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+         print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
        model.resize_token_embeddings(len(tokenizer))

+     # FIXME (Meet): Cover below line inside the logger once it is implemented.
    print_model_size(model, train_config)

-     # print the datatype of the model parameters
-     # print(get_parameter_dtypes(model))
-
    # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
    # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
    # apply gradient checkpointing related hooks to the input embeddings. Without this we will get
@@ -134,17 +162,70 @@ def main(**kwargs):
    else:
        raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")

-     if train_config.use_peft:
-         # Load the pre-trained peft model checkpoint and setup its configuration
-         if train_config.from_peft_checkpoint:
-             model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
-             peft_config = model.peft_config
-         # Generate the peft config and start fine-tuning from original model
-         else:
-             peft_config = generate_peft_config(train_config, kwargs)
-             model = get_peft_model(model, peft_config)
-         model.print_trainable_parameters()
+     model = apply_peft(model, train_config, peft_config_file, **kwargs)
+
+     return model, tokenizer
+
+
+ def apply_peft(
+     model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
+ ) -> Union[AutoModel, PeftModel]:
+     """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
+
+     Args:
+         model (AutoModel): Huggingface model.
+         train_config (TrainConfig): Training configuration object.
+         peft_config_file (str, optional): Path to YAML/JSON file containing
+             PEFT (LoRA) config. Defaults to None.
+         kwargs: Additional arguments to override PEFT config params.

+     Returns:
+         Union[AutoModel, PeftModel]: If the use_peft in train_config is True
+             then PeftModel object is returned else original model object
+             (AutoModel) is returned.
+     """
+     if not train_config.use_peft:
+         return model
+
+     # Load the pre-trained peft model checkpoint and setup its configuration
+     if train_config.from_peft_checkpoint:
+         model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+         peft_config = model.peft_config
+     # Generate the peft config and start fine-tuning from original model
+     else:
+         peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
+         model = get_peft_model(model, peft_config)
+     model.print_trainable_parameters()
+
+     return model
+
+
+ def setup_dataloaders(
+     train_config: TrainConfig,
+     dataset_config: Any,
+     tokenizer: AutoTokenizer,
+ ) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader], int]:
+     """Set up training and validation DataLoaders.
+
+     Args:
+         train_config (TrainConfig): Training configuration object.
+         dataset_config (Any): Configuration for the dataset (generated from train_config).
+         tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
+
+     Returns:
+         tuple: A tuple of three values.
+             - First value represents train_dataloader
+             - Second value represents eval_dataloader. It is None if
+               validation is disabled.
+             - Length of longest sequence in the dataset.
+
+     Raises:
+         ValueError: If validation is enabled but the validation set is too small.
+
+     Notes:
+         - Applies a custom data collator if provided by get_custom_data_collator.
+         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
+     """
    # Get the dataset utils
    dataset_processer = tokenizer
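The peft_config_file accepted by apply_peft() above is a YAML/JSON description of a LoRA setup that generate_peft_config() turns into a peft config; when no file is given, the built-in default LoRA config is used. As a rough sketch of the equivalent direct peft calls (the checkpoint and field values below are illustrative assumptions, not the repo's defaults):

    # Sketch only: the LoRA wrapping that apply_peft() performs when use_peft is True
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # illustrative checkpoint
    lora_config = LoraConfig(          # illustrative values; real defaults come from generate_peft_config
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()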
@@ -164,6 +245,8 @@ def main(**kwargs):
    ##
    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
    print("length of dataset_train", len(dataset_train))
+
+     # FIXME (Meet): Add custom data collator registration from the outside by the user.
    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
@@ -208,40 +291,66 @@ def main(**kwargs):
    else:
        longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)

+     return train_dataloader, eval_dataloader, longest_seq_length
+
+
+ def main(peft_config_file: str = None, **kwargs) -> None:
+     """
+     Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
+
+     Args:
+         peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
+         kwargs: Additional arguments to override TrainConfig.
+
+     Example:
+         .. code-block:: bash
+
+             # Using a YAML config file for PEFT
+             python -m QEfficient.cloud.finetune \\
+                 --model_name "meta-llama/Llama-3.2-1B" \\
+                 --lr 5e-4 \\
+                 --peft_config_file "lora_config.yaml"
+
+             # Using default LoRA config
+             python -m QEfficient.cloud.finetune \\
+                 --model_name "meta-llama/Llama-3.2-1B" \\
+                 --lr 5e-4
+     """
+     train_config = TrainConfig()
+     update_config(train_config, **kwargs)
+     dataset_config = generate_dataset_config(train_config.dataset)
+     update_config(dataset_config, **kwargs)
+
+     setup_distributed_training(train_config)
+     setup_seeds(train_config.seed)
+     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
+
+     # Create DataLoaders for the training and validation dataset
+     train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
    print(
        f"The longest sequence length in the train data is {longest_seq_length}, "
        f"passed context length is {train_config.context_length} and overall model's context length is "
        f"{model.config.max_position_embeddings}"
    )
+
    model.to(train_config.device)
-     optimizer = optim.AdamW(
-         model.parameters(),
-         lr=train_config.lr,
-         weight_decay=train_config.weight_decay,
-     )
+     optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
-     # wrap model with DDP
    if train_config.enable_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-
-     _ = train(
+     results = train(
        model,
+         tokenizer,
        train_dataloader,
        eval_dataloader,
-         tokenizer,
        optimizer,
        scheduler,
-         train_config.gradient_accumulation_steps,
        train_config,
-         train_config.device,
        dist.get_rank() if train_config.enable_ddp else None,
-         None,
    )
-
-     # finalize torch distributed
    if train_config.enable_ddp:
        dist.destroy_process_group()
+     return results


if __name__ == "__main__":
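Since main() simply forwards **kwargs to update_config() on a fresh TrainConfig, the CLI examples in its docstring have a direct programmatic equivalent. A minimal sketch with illustrative argument values:

    # Sketch only: programmatic use mirroring the CLI examples in main()'s docstring
    from QEfficient.cloud.finetune import main

    main(
        model_name="meta-llama/Llama-3.2-1B",   # any supported HF checkpoint
        lr=5e-4,                                # overrides the TrainConfig default
        peft_config_file="lora_config.yaml",    # optional; omit to use the default LoRA config
    )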