
Commit f434ea3

mohiso22 (Mohit Soni) authored and Mohit Soni committed
Updating Wrappers for Merging and Chunking in DecoderWrapper (#404)
Signed-off-by: Mohit Soni <quic_mohisoni@quicinc.com>
Signed-off-by: Mohit Soni <mohisoni@qti.qualcomm.com>
1 parent e532d74 commit f434ea3

2 files changed: +35 -38 lines changed


QEfficient/transformers/models/gemma3/modeling_gemma3.py

+24 -27
@@ -560,16 +560,9 @@ def __init__(self, model):
         self.model = model
         self.model.vision_model = self.model.vision_tower

-    def forward(self, input_ids, pixel_values):
-        inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        B, N, C = inputs_embeds.shape
+    def forward(self, pixel_values):
         image_features = self.model.get_image_features(pixel_values=pixel_values)
-        selected = input_ids == self.model.config.image_token_index
-        indices1 = selected.to(torch.int64).cumsum(1) - 1
-        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1]
-        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
-        return image_input_embeds
+        return image_features


 class QEffGemma3DecoderWrapper(nn.Module):
@@ -579,14 +572,21 @@ def __init__(self, model):
         self.language_model = self.model.language_model
         self.config = self.model.config

-    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
-        image_embeds = vision_embeds[:, : input_ids.shape[1], :]
-        inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids)
-        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
+    def forward(self, input_ids, vision_embeds, position_ids, index, past_key_values):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        B, N, C = inputs_embeds.shape
+        selected = input_ids == self.model.config.image_token_index
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + index, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vision_embeds, outputs.past_key_values
+        index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, vision_embeds, index, outputs.past_key_values


 class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration):
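The merging that previously lived in the vision wrapper now happens per prefill chunk inside the decoder wrapper, with `index` carrying how many image rows earlier chunks have already consumed. Below is a minimal standalone sketch of that gather logic (toy shapes and a made-up image-token id; not part of the diff):

import torch

# Toy values; only the indexing scheme mirrors QEffGemma3DecoderWrapper.forward above.
image_token_index = 99                        # hypothetical image-token id
vision_embeds = torch.randn(1, 8, 4)          # (1, mm_tokens_per_image=8, C=4)
input_ids = torch.tensor([[5, 99, 99, 7]])    # current prefill chunk
inputs_embeds = torch.randn(1, 4, 4)          # text embeddings for the chunk
index = torch.tensor([[3]])                   # image rows merged by earlier chunks

selected = input_ids == image_token_index                           # image-token positions in this chunk
indices1 = selected.to(torch.int64).cumsum(1) - 1                   # running count, 0-based
indices1 = torch.where(indices1 != -1, indices1 + index, indices1)  # shift past rows already used
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
gathered = vision_embeds.reshape(-1, 4).unsqueeze(0)[indices0, indices1]
merged = torch.where(selected.unsqueeze(-1), gathered, inputs_embeds)

next_index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
print(next_index)  # tensor([[5]]): rows 3 and 4 of vision_embeds were consumed by this chunk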
@@ -605,24 +605,20 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-        vision_seq_len = compiler_options.pop("vision_seq_len", None)
-        if vision_seq_len is None:
-            # TODO: Check properly for Gemma3, Not verified yet.
-            vision_seq_len = 512  # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size)
-
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
             img_size = 896  # FIXME based on gemma3 Image size
             logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)

         vision = [
             {
                 "batch_size": batch_size,
                 "img_size": img_size,
-                "seq_len": vision_seq_len,
+                "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
             }
         ]
@@ -632,14 +628,14 @@ def get_specializations(
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
                 "img_size": img_size,
-                "chunk_length": prefill_seq_len,
+                "mm_tokens_per_image": mm_tokens_per_image,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
                 "img_size": img_size,
-                "chunk_length": prefill_seq_len,
+                "mm_tokens_per_image": mm_tokens_per_image,
             },
         ]

@@ -658,9 +654,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "chunk_length"}
+        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"}
         vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"}
-        vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}

         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -685,6 +680,7 @@ def get_output_names(self, kv_offload: bool = False):
         output_names = {}
         if kv_offload:
             lang_output_names.insert(1, "vision_embeds_RetainedState")
+            lang_output_names.insert(2, "index_output")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
@@ -698,12 +694,13 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         else:
             img_size = 896

+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
             1,  # constants.INTERN_NUM_PATCHES,
-            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,  # constants.INTERN_FEATURE_SIZE,
+            mm_tokens_per_image,  # constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,  # 5120
         )
         inputs_shapes["position_ids"] = (
@@ -716,20 +713,20 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             img_size,
             img_size,
         )
+        inputs_shapes["index"] = (1, 1)

         # Define inputs
         vision_inputs = {}
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
-        vision_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
         lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
-
+        lang_inputs["index"] = torch.zeros((inputs_shapes["index"]), dtype=torch.int64)
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
             config=self.language_model.config,
QEfficient/transformers/models/modeling_auto.py

+11 -11
@@ -751,8 +751,8 @@ def kv_offload_generate(
         input_len = inputs["attention_mask"].sum(1, keepdims=True)
         input_ids_length = inputs["input_ids"].shape[1]
         num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
-        # padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
-        padded_len = vision_session.bindings[vision_session.binding_index_map["input_ids"]].dims[1]
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+
         if generation_len is None:
             generation_len = ctx_len - input_len.max()
         assert generation_len > 0, "generation length should be greater than zero"
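Padding is now derived from the prompt itself rather than from the vision graph's input_ids binding, which no longer exists. For example, with an assumed prefill_seq_len of 128 and a 300-token prompt:

prefill_seq_len = 128          # assumed value for illustration
input_ids_length = 300         # assumed prompt length
num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float -> 3
padded_len = num_chunks * prefill_seq_len             # 384, a multiple of prefill_seq_len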
@@ -783,39 +783,39 @@ def kv_offload_generate(
         }

         vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")
-        vision_inputs["input_ids"] = inputs["input_ids"]
         vision_start = perf_counter()
         vision_outputs = vision_session.run(vision_inputs)
         vision_end = perf_counter()

         lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
-        lang_inputs["input_ids"] = inputs["input_ids"]
         lang_inputs["position_ids"] = np.where(
             lang_inputs.pop("attention_mask"), np.arange(padded_len), -1
         )  # Need to use -1 as position_ids for invalid tokens

         vision_session.deactivate()
         lang_session.activate()
         lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]
-        # lang_session.set_buffers(vision_outputs)
+        lang_session.set_buffers(vision_outputs)
         prefill_start = perf_counter()
         # Run prefill
+        chunk_inputs = lang_inputs.copy()
+        chunk_inputs["index"] = np.array([[0]])
         for i in range(num_chunks):
-            chunk_inputs = lang_inputs.copy()
             chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
             chunk_inputs["position_ids"] = lang_inputs["position_ids"][
                 :, i * prefill_seq_len : (i + 1) * prefill_seq_len
             ]
-            chunk_inputs["vision_embeds"] = lang_inputs["vision_embeds"][
-                :, i * prefill_seq_len : (i + 1) * prefill_seq_len
-            ]
             outputs = lang_session.run(chunk_inputs)
+            chunk_inputs["index"] = outputs["index_output"]

         prefill_time = perf_counter() - prefill_start + vision_end - vision_start
-        lang_inputs["vision_embeds"] = lang_inputs["vision_embeds"][:, :prefill_seq_len]
         # Skip inputs/outputs again
         lang_session.skip_buffers(
-            [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")]
+            [
+                x
+                for x in lang_session.input_names + lang_session.output_names
+                if x.startswith("past_") or x.endswith("_RetainedState")
+            ]
         )

         # Get first token
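A condensed sketch of the resulting prefill control flow: the full vision_embeds buffer stays bound via set_buffers, each chunk slices only input_ids and position_ids, and index_output is fed back so the next chunk resumes merging where the previous one stopped. The session handle and function name below are stand-ins, not an exact copy of kv_offload_generate:

import numpy as np

def run_chunked_prefill(lang_session, lang_inputs, num_chunks, prefill_seq_len):
    # lang_session: stand-in for the compiled language-model session used above.
    chunk_inputs = lang_inputs.copy()
    chunk_inputs["index"] = np.array([[0]])  # no image rows merged yet
    for i in range(num_chunks):
        sl = slice(i * prefill_seq_len, (i + 1) * prefill_seq_len)
        chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, sl]
        chunk_inputs["position_ids"] = lang_inputs["position_ids"][:, sl]
        outputs = lang_session.run(chunk_inputs)
        chunk_inputs["index"] = outputs["index_output"]  # carry the image-row cursor forward
    return outputs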
