diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index 4bd18f311..58b837e9c 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -608,7 +608,7 @@ def get_specializations(
         vision_seq_len = compiler_options.pop("vision_seq_len", None)
         if vision_seq_len is None:
             # TODO: Check properly for Gemma3, Not verified yet.
-            vision_seq_len = 2560  # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size)
+            vision_seq_len = 512  # for Gemma3 Vision feature shape is (1, 4096, 1152) --> 1152 is hidden size)

         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 6b5deb8db..1a9610187 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -751,8 +751,8 @@ def kv_offload_generate(
         input_len = inputs["attention_mask"].sum(1, keepdims=True)
         input_ids_length = inputs["input_ids"].shape[1]
         num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
-        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
-
+        # padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+        padded_len = vision_session.bindings[vision_session.binding_index_map["input_ids"]].dims[1]
         if generation_len is None:
             generation_len = ctx_len - input_len.max()
         assert generation_len > 0, "generation length should be greater than zero"
@@ -783,18 +783,22 @@ def kv_offload_generate(
         }
         vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16")

+        vision_inputs["input_ids"] = inputs["input_ids"]
+        vision_start = perf_counter()
         vision_outputs = vision_session.run(vision_inputs)
+        vision_end = perf_counter()
         lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+        lang_inputs["input_ids"] = inputs["input_ids"]
         lang_inputs["position_ids"] = np.where(
             lang_inputs.pop("attention_mask"), np.arange(padded_len), -1
         )  # Need to use -1 as position_ids for invalid tokens

         vision_session.deactivate()
         lang_session.activate()
-
-        lang_session.set_buffers(vision_outputs)
-
+        lang_inputs["vision_embeds"] = vision_outputs["vision_embeds"]
+        # lang_session.set_buffers(vision_outputs)
+
         prefill_start = perf_counter()
         # Run prefill
         for i in range(num_chunks):
             chunk_inputs = lang_inputs.copy()
@@ -802,9 +806,13 @@ def kv_offload_generate(
             chunk_inputs["position_ids"] = lang_inputs["position_ids"][
                 :, i * prefill_seq_len : (i + 1) * prefill_seq_len
             ]
+            chunk_inputs["vision_embeds"] = lang_inputs["vision_embeds"][
+                :, i * prefill_seq_len : (i + 1) * prefill_seq_len
+            ]
             outputs = lang_session.run(chunk_inputs)

-        prefill_time = perf_counter() - prefill_start
+        prefill_time = perf_counter() - prefill_start + vision_end - vision_start
+        lang_inputs["vision_embeds"] = lang_inputs["vision_embeds"][:, :prefill_seq_len]
         # Skip inputs/outputs again
         lang_session.skip_buffers(
             [x for x in lang_session.input_names + lang_session.output_names if x.startswith("past_")]
@@ -838,7 +846,7 @@ def kv_offload_generate(
             streamer.end()

         decode_perf = (num_token - 1) / (decode_end - decode_start)
-        total_time = decode_end - prefill_start
+        total_time = decode_end - decode_start + prefill_time
         total_perf = num_token / total_time

         return CloudAI100ExecInfoNew(
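
For context, here is a minimal sketch of the chunked-prefill slicing this diff introduces: `padded_len` is now read from the compiled vision graph's `input_ids` binding instead of being computed from the prompt, and `vision_embeds` is sliced per chunk alongside `position_ids`. The shapes below (padded length 128, hidden size 1152) are illustrative assumptions, not values pinned down by the diff, and the loop body stands in for `lang_session.run(...)`.

```python
import numpy as np

# Illustrative stand-ins; in the real code padded_len comes from
# vision_session.bindings[...]["input_ids"].dims[1].
prefill_seq_len = 32
padded_len = 128  # assumed; fixed by the compiled QPC
num_chunks = -(padded_len // -prefill_seq_len)  # ceil divide without float

attention_mask = np.ones((1, padded_len), dtype=bool)
position_ids = np.where(attention_mask, np.arange(padded_len), -1)
vision_embeds = np.zeros((1, padded_len, 1152), dtype=np.float16)  # hidden size assumed

for i in range(num_chunks):
    sl = slice(i * prefill_seq_len, (i + 1) * prefill_seq_len)
    chunk_position_ids = position_ids[:, sl]    # one prefill chunk's positions
    chunk_vision_embeds = vision_embeds[:, sl]  # matching slice of vision embeddings
    # each pair would be passed to lang_session.run(...) in the real code
    assert chunk_vision_embeds.shape[1] == prefill_seq_len
```

Passing `vision_embeds` as a per-chunk input (rather than via `lang_session.set_buffers`) keeps the language graph's inputs fully explicit, at the cost of host-side slicing each iteration.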
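And a toy illustration of the revised timing accounting, with simulated phase durations: the vision encoder's wall time is folded into `prefill_time`, and `total_time` is rebuilt from the decode window plus `prefill_time`, so it no longer includes host-side gaps between phases the way the old `decode_end - prefill_start` span did.

```python
import time
from time import perf_counter

# Simulated phase durations (assumed values, for illustration only).
vision_start = perf_counter(); time.sleep(0.01); vision_end = perf_counter()
prefill_start = perf_counter(); time.sleep(0.02); prefill_end = perf_counter()
decode_start = perf_counter(); time.sleep(0.04); decode_end = perf_counter()
num_token = 20

# Vision time is folded into prefill_time, as in the diff.
prefill_time = (prefill_end - prefill_start) + (vision_end - vision_start)
decode_perf = (num_token - 1) / (decode_end - decode_start)

# total_time sums the measured phases, excluding inter-phase host gaps.
total_time = (decode_end - decode_start) + prefill_time
total_perf = num_token / total_time
```

The apparent intent is that `total_perf` reflects only time spent inside the vision, prefill, and decode phases themselves.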