@@ -560,16 +560,9 @@ def __init__(self, model):
         self.model = model
         self.model.vision_model = self.model.vision_tower

-    def forward(self, input_ids, pixel_values):
-        inputs_embeds = self.model.get_input_embeddings()(input_ids)
-        B, N, C = inputs_embeds.shape
+    def forward(self, pixel_values):
         image_features = self.model.get_image_features(pixel_values=pixel_values)
-        selected = input_ids == self.model.config.image_token_index
-        indices1 = selected.to(torch.int64).cumsum(1) - 1
-        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-        image_features_expanded = image_features.reshape(-1, C).unsqueeze(0)[indices0, indices1]
-        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
-        return image_input_embeds
+        return image_features


 class QEffGemma3DecoderWrapper(nn.Module):
@@ -579,14 +572,21 @@ def __init__(self, model):
         self.language_model = self.model.language_model
         self.config = self.model.config

-    def forward(self, input_ids, vision_embeds, position_ids, past_key_values):
-        image_embeds = vision_embeds[:, : input_ids.shape[1], :]
-        inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids)
-        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds)
+    def forward(self, input_ids, vision_embeds, position_ids, index, past_key_values):
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        B, N, C = inputs_embeds.shape
+        selected = input_ids == self.model.config.image_token_index
+        indices1 = selected.to(torch.int64).cumsum(1) - 1
+        indices1 = torch.where(indices1 != -1, indices1 + index, indices1)
+        indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
+        image_features_expanded = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
+        image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
         outputs = self.model.language_model(
             inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
         )
-        return outputs.logits, vision_embeds, outputs.past_key_values
+        index = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
+        return outputs.logits, vision_embeds, index, outputs.past_key_values


 class QEffGemma3ForConditionalGeneration(Gemma3ForConditionalGeneration):
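A standalone sketch of what the decoder-side gather above does across prefill chunks. Everything here is illustrative (toy hidden size, made-up image_token_id, hand-built chunks); it only mirrors the indexing pattern in the new forward, it is not part of the diff:

# Toy reproduction of the chunked image-embedding gather.
# image_token_id, C and all shapes are assumptions for illustration.
import torch

C = 4                                    # toy hidden size
image_token_id = 9                       # stand-in for config.image_token_index
vision_embeds = torch.arange(6 * C, dtype=torch.float32).view(6, C)  # 6 pooled image rows

def scatter_chunk(input_ids, inputs_embeds, index):
    selected = input_ids == image_token_id
    indices1 = selected.to(torch.int64).cumsum(1) - 1
    indices1 = torch.where(indices1 != -1, indices1 + index, indices1)  # shift by rows already consumed
    indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
    gathered = vision_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
    out = torch.where(selected.unsqueeze(-1), gathered, inputs_embeds)
    return out, (indices1.max() + 1).unsqueeze(0).unsqueeze(0)          # next chunk's starting row

# Two prefill chunks of 4 tokens each; the 6 image tokens are split 3 + 3 across them.
chunk1_ids = torch.tensor([[9, 9, 9, 1]])
chunk2_ids = torch.tensor([[9, 9, 9, 2]])
zeros = torch.zeros(1, 4, C)

out1, index = scatter_chunk(chunk1_ids, zeros, torch.zeros(1, 1, dtype=torch.int64))
out2, index = scatter_chunk(chunk2_ids, zeros, index)  # resumes at vision row 3, ends with index == 6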
@@ -605,24 +605,20 @@ def get_specializations(
         kv_offload: bool = False,
         **compiler_options,
     ):
-        vision_seq_len = compiler_options.pop("vision_seq_len", None)
-        if vision_seq_len is None:
-            # TODO: Check properly for Gemma3, Not verified yet.
-            vision_seq_len = 512  # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size)
-
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
             img_size = getattr(self.config.vision_config, "image_size")
         elif img_size is None:
             img_size = 896  # FIXME based on gemma3 Image size
             logger.warning("Setting img_size to be 336, as it was neither passed nor found in vision_config")
+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)

         vision = [
             {
                 "batch_size": batch_size,
                 "img_size": img_size,
-                "seq_len": vision_seq_len,
+                "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
             }
         ]
@@ -632,14 +628,14 @@ def get_specializations(
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
                 "img_size": img_size,
-                "chunk_length": prefill_seq_len,
+                "mm_tokens_per_image": mm_tokens_per_image,
             },
             {
                 "batch_size": batch_size,
                 "seq_len": "1",
                 "ctx_len": ctx_len,
                 "img_size": img_size,
-                "chunk_length": prefill_seq_len,
+                "mm_tokens_per_image": mm_tokens_per_image,
             },
         ]

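For context, mm_tokens_per_image is the number of rows the multimodal projector hands to the language model per image, not the raw patch count. A rough cross-check with assumed Gemma3 config values (image_size and the 256 default appear in the diff; patch_size and the pooling factor are assumptions consistent with the (1, 4096, 1152) feature shape noted in the removed comment):

# Assumed Gemma3 config values; adjust if the checked-in config differs.
image_size = 896                                        # vision_config.image_size
patch_size = 14                                         # vision_config.patch_size (assumed)
raw_patches = (image_size // patch_size) ** 2           # 4096 patch embeddings from the vision tower
mm_tokens_per_image = 256                               # pooled tokens per image fed to the LM
patches_per_token = raw_patches // mm_tokens_per_image  # 16, i.e. a 4x4 pooling window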
@@ -658,9 +654,8 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
-        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "chunk_length"}
+        lang_dynamic_axes["vision_embeds"] = {0: "batch_size", 1: "mm_tokens_per_image"}
         vision_dynamic_axes["pixel_values"] = {0: "batch_size", 2: "img_size", 3: "img_size"}
-        vision_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}

         pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
@@ -685,6 +680,7 @@ def get_output_names(self, kv_offload: bool = False):
         output_names = {}
         if kv_offload:
             lang_output_names.insert(1, "vision_embeds_RetainedState")
+            lang_output_names.insert(2, "index_output")
             output_names["vision"] = vision_output_names
             output_names["lang"] = lang_output_names
         else:
@@ -698,12 +694,13 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         else:
             img_size = 896

+        mm_tokens_per_image = getattr(self.config, "mm_tokens_per_image", 256)
         # Define shapes
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         inputs_shapes["vision_embeds"] = (
             1,  # constants.INTERN_NUM_PATCHES,
-            constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,  # constants.INTERN_FEATURE_SIZE,
+            mm_tokens_per_image,  # constants.INTERN_FEATURE_SIZE,
             self.language_model.config.hidden_size,  # 5120
         )
         inputs_shapes["position_ids"] = (
@@ -716,20 +713,20 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             img_size,
             img_size,
         )
+        inputs_shapes["index"] = (1, 1)

         # Define inputs
         vision_inputs = {}
         lang_inputs = {}
         vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
-        vision_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
         lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
         lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32)
         lang_inputs["position_ids"] = (
             torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
             .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
-
+        lang_inputs["index"] = torch.zeros((inputs_shapes["index"]), dtype=torch.int64)
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
             config=self.language_model.config,
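Downstream, the new index input and index_output retained state would be threaded through chunked prefill roughly as below. This is illustrative only: run_lang_chunk is a stand-in for the compiled language session rather than a real QEfficient API, and all shapes are toy values.

import numpy as np

prefill_seq_len, padded_len = 4, 12
index = np.zeros((1, 1), dtype=np.int64)             # same shape as inputs_shapes["index"]

def run_lang_chunk(inputs):
    # Stand-in for the QPC session; pretend each chunk consumed two image rows.
    return {"index_output": inputs["index"] + 2}

for start in range(0, padded_len, prefill_seq_len):
    outputs = run_lang_chunk(
        {
            "input_ids": np.zeros((1, prefill_seq_len), dtype=np.int64),
            "position_ids": np.arange(start, start + prefill_seq_len, dtype=np.int64)[None, :],
            "vision_embeds": np.zeros((1, 256, 8), dtype=np.float32),  # full embeds passed to every chunk
            "index": index,
        }
    )
    index = outputs["index_output"]                   # carries the image-row offset into the next chunk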