
Commit 7c0b51c

mohiso22, quic-rishinr, abukhoy, asmigosw, and vbaddi authored and Mohit Soni committed
Gemma3 Adding Merging and Chunking in DecoderWrapper (#402)
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>
Signed-off-by: Mohit Soni <quic_mohisoni@quicinc.com>
Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>
Signed-off-by: Asmita Goswami <quic_asmigosw@quicinc.com>
Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Meet Patel <quic_meetkuma@quicinc.com>
Co-authored-by: Rishin Raj <quic_rishinr@quicinc.com>
Co-authored-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>
Co-authored-by: asmigosw <quic_asmigosw@quicinc.com>
Co-authored-by: Vinayak Baddi <68580231+vbaddi@users.noreply.github.com>
Co-authored-by: Meet Patel <quic_meetkuma@quicinc.com>
Signed-off-by: Mohit Soni <mohisoni@qti.qualcomm.com>
1 parent 80ef2ca · commit 7c0b51c

18 files changed: +596 −228 lines

QEfficient/cloud/compile.py

+21 −9

@@ -85,17 +85,29 @@
     parser.add_argument(
         "--enable_qnn",
         "--enable-qnn",
-        action="store_true",
+        nargs="?",
+        const=True,
+        type=str,
         default=False,
         help="Enables QNN. Optionally, a configuration file can be provided with [--enable_qnn CONFIG_FILE].\
             If not provided, the default configuration will be used.\
             Sample Config: QEfficient/compile/qnn_config.json",
     )
-    parser.add_argument(
-        "qnn_config",
-        nargs="?",
-        type=str,
-    )
-    # FIXME(ochougul): Allow extra compilation arguments
-    args = parser.parse_args()
-    QEfficient.compile(**vars(args))
+
+    args, compiler_options = parser.parse_known_args()
+
+    if isinstance(args.enable_qnn, str):
+        args.qnn_config = args.enable_qnn
+        args.enable_qnn = True
+
+    compiler_options_dict = {}
+    for i in range(0, len(compiler_options)):
+        if compiler_options[i].startswith("--"):
+            key = compiler_options[i].lstrip("-").replace("-", "_")
+            value = (
+                compiler_options[i + 1]
+                if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-")
+                else True
+            )
+            compiler_options_dict[key] = value
+    QEfficient.compile(**args.__dict__, **compiler_options_dict)
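Note on the new behavior above: --enable_qnn can now optionally carry a config path, and any unrecognized "--key [value]" pairs are collected into a dict that is forwarded to QEfficient.compile. Below is a minimal, self-contained sketch of that parsing flow; the sample argv values and the extra "--some-*" options are placeholders for illustration, not real CLI flags from this commit.

# Illustrative sketch only: mirrors the reworked parsing in QEfficient/cloud/compile.py.
# The argv values and the "--some-*" options are assumptions, not actual CLI flags.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable_qnn", "--enable-qnn", nargs="?", const=True, type=str, default=False)

argv = ["--enable_qnn", "custom_qnn_config.json", "--some-flag", "--some-option", "value1"]
args, compiler_options = parser.parse_known_args(argv)

# A string value means a QNN config file was supplied next to the flag.
if isinstance(args.enable_qnn, str):
    args.qnn_config = args.enable_qnn
    args.enable_qnn = True

# Unknown "--key [value]" pairs become {"key": value_or_True}.
compiler_options_dict = {}
for i in range(len(compiler_options)):
    if compiler_options[i].startswith("--"):
        key = compiler_options[i].lstrip("-").replace("-", "_")
        value = (
            compiler_options[i + 1]
            if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-")
            else True
        )
        compiler_options_dict[key] = value

print(vars(args))             # {'enable_qnn': True, 'qnn_config': 'custom_qnn_config.json'}
print(compiler_options_dict)  # {'some_flag': True, 'some_option': 'value1'}

A trailing flag with no value maps to True, matching the "else True" branch in the diff.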

QEfficient/cloud/finetune.py

+169 −60

@@ -7,6 +7,7 @@
 
 import random
 import warnings
+from typing import Any, Dict, Optional, Union
 
 import fire
 import numpy as np
@@ -17,8 +18,9 @@
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
 
-from QEfficient.finetune.configs.training import train_config as TRAIN_CONFIG
+from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
@@ -32,52 +34,81 @@
 from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
 from QEfficient.utils._utils import login_and_download_hf_lm
 
+# Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    print(f"Warning: {e}. Moving ahead without these qaic modules.")
+    print(f"Warning: {e}. Proceeding without QAIC modules.")
 
 
-from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification
 
 # Suppress all warnings
 warnings.filterwarnings("ignore")
 
 
-def main(**kwargs):
-    """
-    Helper function to finetune the model on QAic.
+def setup_distributed_training(train_config: TrainConfig) -> None:
+    """Initialize distributed training environment if enabled.
 
-    .. code-block:: bash
+    Args:
+        train_config (TrainConfig): Training configuration object.
 
-        python -m QEfficient.cloud.finetune OPTIONS
+    Notes:
+        - If distributed data parallel (DDP) is disabled, this function does nothing.
+        - Ensures the device is not CPU and does not specify an index for DDP compatibility.
+        - Initializes the process group using the specified distributed backend.
 
+    Raises:
+        AssertionError: If device is CPU or includes an index with DDP enabled.
     """
-    # update the configuration for the training process
-    train_config = TRAIN_CONFIG()
-    update_config(train_config, **kwargs)
-    dataset_config = generate_dataset_config(train_config, kwargs)
-    device = train_config.device
+    if not train_config.enable_ddp:
+        return
 
-    # dist init
-    if train_config.enable_ddp:
-        # TODO: may have to init qccl backend, next try run with torchrun command
-        torch_device = torch.device(device)
-        assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
-        assert torch_device.index is None, (
-            f"DDP requires specification of device type only, however provided device index as well: {torch_device}"
-        )
-        dist.init_process_group(backend=train_config.dist_backend)
-        # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
-        getattr(torch, torch_device.type).set_device(dist.get_rank())
+    torch_device = torch.device(train_config.device)
+    assert torch_device.type != "cpu", "Host doesn't support single-node DDP"
+    assert torch_device.index is None, f"DDP requires only device type, got: {torch_device}"
+
+    dist.init_process_group(backend=train_config.dist_backend)
+    # from here onward "qaic/cuda" will automatically map to "qaic:i/cuda:i", where i = process rank
+    getattr(torch, torch_device.type).set_device(dist.get_rank())
 
-    # Set the seeds for reproducibility
-    torch.manual_seed(train_config.seed)
-    random.seed(train_config.seed)
-    np.random.seed(train_config.seed)
 
-    # Load the pre-trained model and setup its configuration
-    # config = AutoConfig.from_pretrained(train_config.model_name)
+def setup_seeds(seed: int) -> None:
+    """Set random seeds across libraries for reproducibility.
+
+    Args:
+        seed (int): Seed value to set for random number generators.
+
+    Notes:
+        - Sets seeds for PyTorch, Python's random module, and NumPy.
+    """
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+
+
+def load_model_and_tokenizer(
+    train_config: TrainConfig, dataset_config: Any, peft_config_file: str, **kwargs
+) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
+    """Load the pre-trained model and tokenizer from Hugging Face.
+
+    Args:
+        config (TrainConfig): Training configuration object containing model and tokenizer names.
+        dataset_config (Any): A dataclass object representing dataset configuration.
+        peft_config_file (str): Path to PEFT config file used for PEFT finetuning.
+        kwargs: Additional arguments to override PEFT config.
+
+    Returns:
+        tuple: A tuple of two values.
+            - Model with pretrained weights loaded.
+            - Model's tokenizer (AutoTokenizer).
+
+    Notes:
+        - Downloads the model if not already cached using login_and_download_hf_lm.
+        - Configures the model with FP16 precision and disables caching for training.
+        - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
+        - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
+    """
     pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
     if train_config.task_type == "seq_classification":
         model = AutoModelForSequenceClassification.from_pretrained(
@@ -104,7 +135,6 @@ def main(**kwargs):
             torch_dtype=torch.float16,
         )
 
-    # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(
         train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name
     )
@@ -114,14 +144,12 @@
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
         model.resize_token_embeddings(len(tokenizer))
 
+    # FIXME (Meet): Cover below line inside the logger once it is implemented.
     print_model_size(model, train_config)
 
-    # print the datatype of the model parameters
-    # print(get_parameter_dtypes(model))
-
     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
     # apply gradient checkpointing related hooks to the input embeddings. Without this we will get
@@ -134,17 +162,70 @@ def main(**kwargs):
     else:
         raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
 
-    if train_config.use_peft:
-        # Load the pre-trained peft model checkpoint and setup its configuration
-        if train_config.from_peft_checkpoint:
-            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
-            peft_config = model.peft_config
-        # Generate the peft config and start fine-tuning from original model
-        else:
-            peft_config = generate_peft_config(train_config, kwargs)
-            model = get_peft_model(model, peft_config)
-            model.print_trainable_parameters()
+    model = apply_peft(model, train_config, peft_config_file, **kwargs)
+
+    return model, tokenizer
+
+
+def apply_peft(
+    model: AutoModel, train_config: TrainConfig, peft_config_file: Dict, **kwargs
+) -> Union[AutoModel, PeftModel]:
+    """Apply Parameter-Efficient Fine-Tuning (PEFT) to the model if enabled.
+
+    Args:
+        model (AutoModel): Huggingface model.
+        train_config (TrainConfig): Training configuration object.
+        peft_config_file (str, optional): Path to YAML/JSON file containing
+            PEFT (LoRA) config. Defaults to None.
+        kwargs: Additional arguments to override PEFT config params.
 
+    Returns:
+        Union[AutoModel, PeftModel]: If the use_peft in train_config is True
+            then PeftModel object is returned else original model object
+            (AutoModel) is returned.
+    """
+    if not train_config.use_peft:
+        return model
+
+    # Load the pre-trained peft model checkpoint and setup its configuration
+    if train_config.from_peft_checkpoint:
+        model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+        peft_config = model.peft_config
+    # Generate the peft config and start fine-tuning from original model
+    else:
+        peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+
+    return model
+
+
+def setup_dataloaders(
+    train_config: TrainConfig,
+    dataset_config: Any,
+    tokenizer: AutoTokenizer,
+) -> tuple[torch.utils.data.DataLoader, Optional[torch.utils.data.DataLoader], int]:
+    """Set up training and validation DataLoaders.
+
+    Args:
+        train_config (TrainConfig): Training configuration object.
+        dataset_config (Any): Configuration for the dataset (generated from train_config).
+        tokenizer (AutoTokenizer): Tokenizer for preprocessing data.
+
+    Returns:
+        tuple: A tuple of three values.
+            - First value represents train_dataloader
+            - Second value represents eval_dataloader. It is None if
+              validation is disabled.
+            - Length of longest sequence in the dataset.
+
+    Raises:
+        ValueError: If validation is enabled but the validation set is too small.
+
+    Notes:
+        - Applies a custom data collator if provided by get_custom_data_collator.
+        - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
+    """
     # Get the dataset utils
     dataset_processer = tokenizer
 
@@ -164,6 +245,8 @@ def main(**kwargs):
     ##
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
+
+    # FIXME (Meet): Add custom data collator registration from the outside by the user.
     custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
     if custom_data_collator:
         print("custom_data_collator is used")
@@ -208,40 +291,66 @@ def main(**kwargs):
     else:
         longest_seq_length, _ = get_longest_seq_length(train_dataloader.dataset)
 
+    return train_dataloader, eval_dataloader, longest_seq_length
+
+
+def main(peft_config_file: str = None, **kwargs) -> None:
+    """
+    Fine-tune a model on QAIC hardware with configurable training and LoRA parameters.
+
+    Args:
+        peft_config_file (str, optional): Path to YAML/JSON file containing PEFT (LoRA) config. Defaults to None.
+        kwargs: Additional arguments to override TrainConfig.
+
+    Example:
+        .. code-block:: bash
+
+            # Using a YAML config file for PEFT
+            python -m QEfficient.cloud.finetune \\
+                --model_name "meta-llama/Llama-3.2-1B" \\
+                --lr 5e-4 \\
+                --peft_config_file "lora_config.yaml"
+
+            # Using default LoRA config
+            python -m QEfficient.cloud.finetune \\
+                --model_name "meta-llama/Llama-3.2-1B" \\
+                --lr 5e-4
+    """
+    train_config = TrainConfig()
+    update_config(train_config, **kwargs)
+    dataset_config = generate_dataset_config(train_config.dataset)
+    update_config(dataset_config, **kwargs)
+
+    setup_distributed_training(train_config)
+    setup_seeds(train_config.seed)
+    model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)
+
+    # Create DataLoaders for the training and validation dataset
+    train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
     print(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
        f"{model.config.max_position_embeddings}"
     )
+
     model.to(train_config.device)
-    optimizer = optim.AdamW(
-        model.parameters(),
-        lr=train_config.lr,
-        weight_decay=train_config.weight_decay,
-    )
+    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=train_config.weight_decay)
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-
-    # wrap model with DDP
     if train_config.enable_ddp:
         model = nn.parallel.DistributedDataParallel(model, device_ids=[dist.get_rank()])
-
-    _ = train(
+    results = train(
         model,
+        tokenizer,
        train_dataloader,
        eval_dataloader,
-        tokenizer,
        optimizer,
        scheduler,
-        train_config.gradient_accumulation_steps,
        train_config,
-        train_config.device,
        dist.get_rank() if train_config.enable_ddp else None,
-        None,
    )
-
-    # finalize torch distributed
     if train_config.enable_ddp:
        dist.destroy_process_group()
+    return results
 
 
 if __name__ == "__main__":
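Since main() now takes peft_config_file explicitly and forwards remaining kwargs to TrainConfig via update_config, the fire-based CLI shown in the docstring has a straightforward programmatic equivalent. A hypothetical usage sketch follows; the config path, model name, and learning rate are illustrative values, not prescribed by this commit.

# Hypothetical programmatic call into the refactored entry point; values are examples only.
from QEfficient.cloud.finetune import main

results = main(
    peft_config_file="lora_config.yaml",   # assumed local YAML/JSON with PEFT (LoRA) settings
    model_name="meta-llama/Llama-3.2-1B",  # forwarded to TrainConfig via update_config
    lr=5e-4,                               # likewise overrides the TrainConfig learning rate
)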

QEfficient/cloud/infer.py

+4

@@ -197,6 +197,10 @@ def main(
         **kwargs,
     )
 
+    # If the io-encrypt flag is passed we will exit after QPC generation.
+    if kwargs.get("io_encrypt", None):
+        exit()
+
     #########
     # Execute
     #########
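The infer flow now stops right after QPC generation whenever a truthy io_encrypt entry appears in kwargs. A minimal sketch of that check is below; the kwargs dict is fabricated for illustration, and how the flag actually reaches kwargs depends on the CLI's pass-through parsing.

# Sketch of the early-exit condition; the kwargs content is assumed for illustration.
kwargs = {"io_encrypt": True}  # e.g. populated when an "--io-encrypt" style option is forwarded

# ... QPC generation would have completed above this point ...
if kwargs.get("io_encrypt", None):
    raise SystemExit(0)  # skip the execute step, mirroring exit() in the diff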
