diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index 8d2141240..6f35b9b8f 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -1133,24 +1133,24 @@ def get_dummy_inputs(self, kv_offload: bool = False):
 
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", 448)
-            max_num_img_tiles = getattr(vis_cfg, "max_num_tiles", 4)
+            max_num_tiles = getattr(vis_cfg, "max_num_tiles", 4)
         else:
             img_size = 448
-            max_num_img_tiles = 4
+            max_num_tiles = 4
 
         # vision inputs
         vision_inputs = {
             "pixel_values": torch.zeros(
-                (BS, MAX_NUM_IMG, max_num_img_tiles, NUM_CHANNEL, img_size, img_size), dtype=torch.float32
+                (BS, MAX_NUM_IMG, max_num_tiles, NUM_CHANNEL, img_size, img_size), dtype=torch.float32
             ),
             "aspect_ratio_ids": torch.ones((BS, MAX_NUM_IMG), dtype=torch.int64),
-            "aspect_ratio_mask": torch.ones((BS, MAX_NUM_IMG, max_num_img_tiles), dtype=torch.int64),
+            "aspect_ratio_mask": torch.ones((BS, MAX_NUM_IMG, max_num_tiles), dtype=torch.int64),
         }
 
         # lang_inputs
         lang_inputs = {
             "input_ids": torch.zeros((BS, SEQ_LEN), dtype=torch.int64),
-            "cross_attention_mask": torch.zeros((BS, SEQ_LEN, MAX_NUM_IMG, max_num_img_tiles), dtype=torch.int64),
+            "cross_attention_mask": torch.zeros((BS, SEQ_LEN, MAX_NUM_IMG, max_num_tiles), dtype=torch.int64),
             "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
         }
 
@@ -1201,6 +1201,7 @@ def get_specializations(
     ):
         vis_cfg = self.config.vision_config
         max_num_images = compiler_options.pop("max_num_images", 1)
+        max_num_tiles = compiler_options.pop("max_num_tiles", 4)
         prefill_seq_len = prefill_seq_len if prefill_seq_len else 32
         ctx_len = ctx_len if ctx_len else 128
         if img_size is None and hasattr(vis_cfg, "image_size"):
@@ -1209,13 +1210,21 @@ def get_specializations(
             img_size = 448
             logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config")
 
-        vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}]
+        vision = [
+            {
+                "batch_size": batch_size,
+                "max_num_images": max_num_images,
+                "max_num_tiles": max_num_tiles,
+                "img_size": img_size,
+            }
+        ]
         lang = [
             {
                 "batch_size": batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
                 "max_num_images": max_num_images,
+                "max_num_tiles": max_num_tiles,
                 "img_size": img_size,
             },
             {
@@ -1223,6 +1232,7 @@
                 "seq_len": "1",
                 "ctx_len": ctx_len,
                 "max_num_images": max_num_images,
+                "max_num_tiles": max_num_tiles,
                 "img_size": img_size,
             },
         ]
@@ -1241,15 +1251,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False):
         cross_attention_layers = txt_cfg.cross_attention_layers
 
         vision_dynamic_axes = {
-            "pixel_values": {0: "batch_size", 1: "max_num_images", 4: "img_size", 5: "img_size"},
+            "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_num_tiles", 4: "img_size", 5: "img_size"},
             "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"},
-            "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images"},
+            "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: "max_num_tiles"},
         }
 
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
             "position_ids": {0: "batch_size", 1: "seq_len"},
-            "cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images"},
+            "cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images", 3: "max_num_tiles"},
         }
 
         for i in range(num_hidden_layers):
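
The net effect of these hunks is that the tile count becomes a named specialization and ONNX dynamic dimension ("max_num_tiles") instead of a value baked in from the vision config. The sketch below shows how the new option would be exercised end to end; it assumes that QEfficient's QEFFAutoModelForImageTextToText.compile() forwards unrecognized keyword arguments to get_specializations() as compiler_options (where max_num_tiles is popped with a default of 4). The model id and shape values are illustrative only, not prescribed by this diff.

# Sketch only: passing the new max_num_tiles compiler option.
# Assumption: extra compile() kwargs reach get_specializations() via compiler_options.
from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # illustrative mllama checkpoint
    kv_offload=True,
)
model.compile(
    prefill_seq_len=32,
    ctx_len=512,
    img_size=560,
    max_num_images=1,
    max_num_tiles=4,  # new option added by this diff; defaults to 4 when omitted
)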