Commit d3fd512

Rebase on PR116 and make API changes
Signed-off-by: Jou-An Chen <quic_jouachen@quicinc.com>
1 parent c896790 commit d3fd512

7 files changed: +219 -200 lines changed

QEfficient/exporter/export_hf_to_cloud_ai_100.py (+2 -14)
@@ -16,7 +16,6 @@
 from QEfficient.base.common import AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP, QEFF_MODEL_TYPE, QEFFCommonLoader
 from QEfficient.base.modeling_qeff import QEFFBaseModel
 from QEfficient.exporter.export_utils import export_onnx, fix_onnx_fp16, generate_input_files, run_model_on_ort
-from QEfficient.lora.auto import QEffAutoLoraModelForCausalLM
 from QEfficient.transformers.modeling_utils import get_lists_of_cb_qeff_models
 from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
 from QEfficient.utils import load_hf_tokenizer
@@ -149,7 +148,6 @@ def convert_to_cloud_kvstyle(
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     onnx_dir_path: str,
     seq_len: int,
-    max_num_adapters: int,
 ) -> str:
     """
     API to convert model with kv retention and export to ONNX.
@@ -178,7 +176,7 @@
 
     # Decide path for saving exported ONNX files.
     model_name = export_kvstyle_transformed_model_to_onnx(
-        model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len, max_num_adapters
+        model_name, qeff_model.model, tokenizer, onnx_dir_path, seq_len
     )  # type: ignore
 
     # return the model path for automation.
@@ -192,7 +190,6 @@ def export_kvstyle_transformed_model_to_onnx(
     onnx_dir_path: str,
     seq_len: int,
     full_batch_size: Optional[int] = None,
-    max_num_adapters: Optional[int] = None,
 ) -> str:
     # Disabling requires_grad on all parameters
     for _, p in enumerate(transformed_model.parameters()):
@@ -211,7 +208,6 @@
         prompt_len=Constants.PROMPT_LEN,
         ctx_len=seq_len,
         full_batch_size=full_batch_size,
-        max_num_adapters=max_num_adapters,
     )
 
     inputs = input_handler.prepare_pytorch_inputs()
@@ -319,7 +315,6 @@ def export_for_cloud(
     onnx_dir_path: str,
     seq_length: int = Constants.SEQ_LEN,
     full_batch_size: Optional[int] = None,
-    max_num_adapters: Optional[int] = None,
 ) -> str:
     # Check if model architecture is supported for continuous batching.
     if full_batch_size and qeff_model.model.config.architectures[0].lower() not in {
@@ -330,18 +325,14 @@
         )
 
     # FIXME: move all this to class instead of here, and just call qeff_model.export here.
-    if (
-        AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM
-        or qeff_model.__class__ == QEffAutoLoraModelForCausalLM
-    ):  # type: ignore
+    if AUTO_MODEL_MAP_TO_MODEL_TYPE_MAP.get(qeff_model.__class__, None) == QEFF_MODEL_TYPE.CAUSALLM:  # type: ignore
         return export_lm_model_for_cloud(
             model_name=model_name,
             qeff_model=qeff_model,  # type: ignore
             tokenizer=tokenizer,
             onnx_dir_path=onnx_dir_path,
             seq_length=seq_length,
             full_batch_size=full_batch_size,
-            max_num_adapters=max_num_adapters,
         )
     else:
         raise NotImplementedError(
@@ -356,7 +347,6 @@ def export_lm_model_for_cloud(
     onnx_dir_path: str,
     seq_length: int,
     full_batch_size: Optional[int] = None,
-    max_num_adapters: Optional[int] = None,
 ) -> str:
     if os.path.exists(onnx_dir_path):
         logger.warning(f"Overriding {onnx_dir_path}")
@@ -385,7 +375,6 @@ def qualcomm_efficient_converter(
     kv: bool = True,
     form_factor: str = "cloud",
     full_batch_size: Optional[int] = None,
-    max_num_adapters: Optional[int] = None,
 ) -> Tuple[str, str]:
     """
     This method is an alias for ``QEfficient.export``.
@@ -461,7 +450,6 @@
             onnx_dir_path=onnx_dir_path,
             seq_length=seq_length,
             full_batch_size=full_batch_size,
-            max_num_adapters=max_num_adapters,
         )
         return onnx_dir_path, generated_onnx_model_path
     else:
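Net API change in this file: max_num_adapters is removed from convert_to_cloud_kvstyle, export_kvstyle_transformed_model_to_onnx, export_for_cloud, export_lm_model_for_cloud, and qualcomm_efficient_converter, and LoRA models no longer take a special-cased branch (the QEffAutoLoraModelForCausalLM import and class check are gone). A minimal usage sketch of the updated entry point follows; the checkpoint name, and the assumption that model_name is the first parameter, are illustrative rather than taken from this diff:

    # Hedged sketch: calling the exporter after this commit.
    # Note there is no max_num_adapters argument anymore.
    from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter

    onnx_dir_path, onnx_model_path = qualcomm_efficient_converter(
        model_name="gpt2",     # assumed parameter; any supported CausalLM checkpoint
        kv=True,               # export with KV-cache retention (per the signature above)
        form_factor="cloud",
        full_batch_size=None,  # set an int to export for continuous batching
    )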

QEfficient/exporter/export_utils.py (-2)
@@ -83,8 +83,6 @@ def export_onnx(
             dynamic_axes[iname] = {0: dynamic_axis_past_key, 2: "ctx_len"}
         elif iname == "batch_index":
             dynamic_axes[iname] = {0: "batch_size"}
-        elif iname == "lora_ids":
-            dynamic_axes[iname] = {0: "batch_size"}
 
     if "past_key.0" in input_names and "attention_mask" in input_names:
         dynamic_axes["attention_mask"] = {0: "batch_size", 1: "ctx_len"}

QEfficient/generation/text_generation_inference.py (+20 -8)
@@ -230,7 +230,6 @@ def cloud_ai_100_exec_kv(
     stream: bool = True,
     write_io_dir: Optional[str] = None,
     automation=False,
-    full_batch_size: Optional[int] = None,
     prompt_to_lora_id_mapping: Optional[List[int]] = None,
 ):
     """
@@ -348,7 +347,10 @@ def __init__(
 
         if prompt_to_lora_id_mapping:
             self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping)
-            self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping
+            if self.full_batch_size:
+                self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping
+            else:
+                self.prompt_to_lora_id_mapping_decode = deque(prompt_to_lora_id_mapping)
         else:
             self.prompt_to_lora_id_mapping_prefill = None
             self.prompt_to_lora_id_mapping_decode = None
@@ -472,9 +474,15 @@ def prepare_decode_inputs(self):
         if self.batch_index is not None:
             decode_inputs["batch_index"] = self.batch_index
 
-        if self.prompt_to_lora_id_mapping_decode and self.full_batch_size is not None:
-            first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)]
-            decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(self.full_batch_size, 1)
+        if self.prompt_to_lora_id_mapping_decode:
+            if self.full_batch_size:
+                first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)]
+                decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(
+                    self.full_batch_size, 1
+                )
+            else:
+                batch_lora_ids = [self.prompt_to_lora_id_mapping_decode.popleft() for i in range(self.batch_size)]
+                decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
 
         return decode_inputs
 
@@ -565,9 +573,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id
             inputs["batch_index"] = decode_batch_id
 
         if self.prompt_to_lora_id_mapping_prefill:
-            inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape(
-                1, 1
-            )
+            if self.full_batch_size:
+                inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape(
+                    1, 1
+                )
+            else:
+                batch_lora_ids = [self.prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
+                inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
 
         for i in range(num_chunks):
             chunk_inputs = inputs.copy()
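In plain terms: with continuous batching (full_batch_size set), prefill pops one LoRA id per prompt while decode indexes the first full_batch_size ids without consuming them; without it, both prefill and decode now pop batch_size ids per step from their own deques, which is why the decode mapping became a deque in __init__. A standalone sketch of the two modes (values illustrative, not from this diff):

    # Hedged, self-contained illustration of the two lora_ids paths above.
    from collections import deque
    import numpy as np

    prompt_to_lora_id_mapping = [0, 1, 2, 1]  # one adapter id per prompt

    # Continuous batching: prefill consumes one id; decode indexes ids in place.
    full_batch_size = 4
    prefill_q = deque(prompt_to_lora_id_mapping)
    prefill_lora_ids = np.array(prefill_q.popleft(), dtype=np.int64).reshape(1, 1)
    decode_lora_ids = np.array(
        [prompt_to_lora_id_mapping[i] for i in range(full_batch_size)], dtype=np.int64
    ).reshape(full_batch_size, 1)

    # Regular batching: both sides consume batch_size ids from their own deques.
    batch_size = 2
    prefill_q = deque(prompt_to_lora_id_mapping)
    decode_q = deque(prompt_to_lora_id_mapping)
    prefill_lora_ids = np.array(
        [prefill_q.popleft() for _ in range(batch_size)], dtype=np.int64
    ).reshape(batch_size, 1)
    decode_lora_ids = np.array(
        [decode_q.popleft() for _ in range(batch_size)], dtype=np.int64
    ).reshape(batch_size, 1)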
