From 3e5995d6e1ec93b3c2c6b0a1c6782accab3ded59 Mon Sep 17 00:00:00 2001 From: Amit Raj <168538872+quic-amitraj@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:29:57 +0530 Subject: [PATCH 1/5] Dynamic `max_num_tiles` for mllama (#308) Signed-off-by: Amit Raj --- .../models/mllama/modeling_mllama.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 8d2141240..6f35b9b8f 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -1133,24 +1133,24 @@ def get_dummy_inputs(self, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 448) - max_num_img_tiles = getattr(vis_cfg, "max_num_tiles", 4) + max_num_tiles = getattr(vis_cfg, "max_num_tiles", 4) else: img_size = 448 - max_num_img_tiles = 4 + max_num_tiles = 4 # vision inputs vision_inputs = { "pixel_values": torch.zeros( - (BS, MAX_NUM_IMG, max_num_img_tiles, NUM_CHANNEL, img_size, img_size), dtype=torch.float32 + (BS, MAX_NUM_IMG, max_num_tiles, NUM_CHANNEL, img_size, img_size), dtype=torch.float32 ), "aspect_ratio_ids": torch.ones((BS, MAX_NUM_IMG), dtype=torch.int64), - "aspect_ratio_mask": torch.ones((BS, MAX_NUM_IMG, max_num_img_tiles), dtype=torch.int64), + "aspect_ratio_mask": torch.ones((BS, MAX_NUM_IMG, max_num_tiles), dtype=torch.int64), } # lang_inputs lang_inputs = { "input_ids": torch.zeros((BS, SEQ_LEN), dtype=torch.int64), - "cross_attention_mask": torch.zeros((BS, SEQ_LEN, MAX_NUM_IMG, max_num_img_tiles), dtype=torch.int64), + "cross_attention_mask": torch.zeros((BS, SEQ_LEN, MAX_NUM_IMG, max_num_tiles), dtype=torch.int64), "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64), } @@ -1201,6 +1201,7 @@ def get_specializations( ): vis_cfg = self.config.vision_config max_num_images = compiler_options.pop("max_num_images", 1) + max_num_tiles = compiler_options.pop("max_num_tiles", 4) prefill_seq_len = prefill_seq_len if prefill_seq_len else 32 ctx_len = ctx_len if ctx_len else 128 if img_size is None and hasattr(vis_cfg, "image_size"): @@ -1209,13 +1210,21 @@ def get_specializations( img_size = 448 logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config") - vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}] + vision = [ + { + "batch_size": batch_size, + "max_num_images": max_num_images, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + } + ] lang = [ { "batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "max_num_images": max_num_images, + "max_num_tiles": max_num_tiles, "img_size": img_size, }, { @@ -1223,6 +1232,7 @@ def get_specializations( "seq_len": "1", "ctx_len": ctx_len, "max_num_images": max_num_images, + "max_num_tiles": max_num_tiles, "img_size": img_size, }, ] @@ -1241,15 +1251,15 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): cross_attention_layers = txt_cfg.cross_attention_layers vision_dynamic_axes = { - "pixel_values": {0: "batch_size", 1: "max_num_images", 4: "img_size", 5: "img_size"}, + "pixel_values": {0: "batch_size", 1: "max_num_images", 2: "max_num_tiles", 4: "img_size", 5: "img_size"}, "aspect_ratio_ids": {0: "batch_size", 1: "max_num_images"}, - "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images"}, + "aspect_ratio_mask": {0: "batch_size", 1: "max_num_images", 2: 
"max_num_tiles"}, } lang_dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - "cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images"}, + "cross_attention_mask": {0: "batch_size", 1: "seq_len", 2: "max_num_images", 3: "max_num_tiles"}, } for i in range(num_hidden_layers): From c83c2d6127e6f4b679b474b62728dddd39213bf0 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Thu, 6 Mar 2025 18:36:23 +0530 Subject: [PATCH 2/5] Mllama bqa (#310) Signed-off-by: Rishin Raj --- .../models/mllama/modeling_mllama.py | 66 +++++++++++++++++++ .../transformers/models/modeling_auto.py | 15 ++++- .../transformers/models/pytorch_transforms.py | 21 ++++++ 3 files changed, 100 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 6f35b9b8f..65b2730ca 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -33,6 +33,7 @@ MllamaTextCrossAttention, MllamaTextModel, MllamaTextSelfAttention, + MllamaVisionAttention, MllamaVisionModel, logger, repeat_kv, @@ -1315,3 +1316,68 @@ def get_inputs_info(self): ), IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), ] + + +class QEffMllamaVisionAttention(MllamaVisionAttention): + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + block_size: int = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) + if block_size is not None: + num_blocks = q_seq_len // block_size + + attn_output_blocks = [] + attn_weights_blocks = [] + + for i in range(num_blocks): + query_block = query[:, :, i * block_size:(i + 1) * block_size, :] + attn_weights_block = torch.matmul(query_block, key.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + # causal_mask_block = attention_mask[:, :, :, : key.shape[-2]] + causal_mask_block = attention_mask[:, :, i * block_size:(i + 1) * block_size,:] + attn_weights_block = attn_weights_block + causal_mask_block + + attn_weights_block = nn.functional.softmax(attn_weights_block, dim=-1, dtype=torch.float32).to(query.dtype) + attn_output_block = torch.matmul(attn_weights_block, value) + + attn_output_blocks.append(attn_output_block) + attn_weights_blocks.append(attn_weights_block) + + attn_output = torch.cat(attn_output_blocks, dim=2) + attn_weights = torch.cat(attn_weights_blocks, dim=2) + else: + # Regular attention + attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, :key.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + 
if not output_attentions: + attn_weights = None + + return output, attn_weights \ No newline at end of file diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 07aff78ff..b0ffb5eda 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -35,6 +35,7 @@ get_compilation_dims, ) from QEfficient.transformers.models.pytorch_transforms import ( + BlockAttentionTransorm, CustomOpsTransform, KVCacheModuleMethodMapperTransform, KVCacheTransform, @@ -525,6 +526,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + block_size : int = None, **kwargs, ): if kwargs.pop("full_batch_size", None): @@ -535,6 +537,9 @@ def __init__( self.lang_model = QEffCausalLMForTextImageToTextModel(model) self.input_shapes, self.output_names = None, None + + if block_size: + BlockAttentionTransorm.apply(model, block_size=block_size) @property def model_name(self) -> str: @@ -850,11 +855,15 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal def __init__( self, model: nn.Module, + block_size : int = None, **kwargs, ): if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model) + + if block_size: + BlockAttentionTransorm.apply(model, block_size=block_size) # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): @@ -1222,7 +1231,8 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, + block_size : int=None, **kwargs): """Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100. 
Args: @@ -1238,10 +1248,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, **kwargs) + return cls(model, kv_offload=kv_offload, block_size=block_size, **kwargs) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText} diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 8152f0676..670dd572c 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from functools import partial from types import MethodType from typing import Tuple @@ -84,6 +85,7 @@ MllamaTextModel, MllamaTextRMSNorm, MllamaTextSelfAttention, + MllamaVisionAttention, MllamaVisionModel, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel @@ -204,6 +206,7 @@ QEffMllamaTextCrossAttentionTwoQPC, QEffMllamaTextModel, QEffMllamaTextSelfAttention, + QEffMllamaVisionAttention, QEffMllamaVisionModel, ) from QEfficient.transformers.models.mpt.modeling_mpt import ( @@ -439,3 +442,21 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, } _match_class_replace_method = {} + +class BlockAttentionTransorm(ModuleMappingTransform): + # supported architectures + _module_mapping = { + MllamaVisionAttention: QEffMllamaVisionAttention, + } + + @classmethod + def apply(cls, model: nn.Module, block_size) -> Tuple[nn.Module, bool]: + transformed = False + for module in model.modules(): + if repl_module := cls._module_mapping.get(type(module)): + module.__class__ = repl_module + # Bind the partial function to the instance + module.forward = MethodType(partial(repl_module.forward, block_size=block_size), module) + transformed = True + break + return model, transformed \ No newline at end of file From 01aef5d7dab5068fb8a73def9a166f135b22b0c1 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Fri, 7 Mar 2025 05:03:25 +0000 Subject: [PATCH 3/5] nit: bug fix forr bqa in vision mllama Signed-off-by: vbaddi --- .../transformers/models/modeling_auto.py | 20 +++++++++---------- .../transformers/models/pytorch_transforms.py | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b0ffb5eda..8e6fc04e1 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -35,7 +35,7 @@ get_compilation_dims, ) from QEfficient.transformers.models.pytorch_transforms import ( - BlockAttentionTransorm, + BlockAttentionTransform, CustomOpsTransform, KVCacheModuleMethodMapperTransform, KVCacheTransform, @@ -526,7 +526,7 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, - block_size : int = None, + block_size: int = None, **kwargs, ): if kwargs.pop("full_batch_size", None): @@ -537,9 +537,9 @@ def __init__( self.lang_model = QEffCausalLMForTextImageToTextModel(model) self.input_shapes, self.output_names = None, 
None - + if block_size: - BlockAttentionTransorm.apply(model, block_size=block_size) + BlockAttentionTransform.apply(model, block_size=block_size) @property def model_name(self) -> str: @@ -855,15 +855,15 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal def __init__( self, model: nn.Module, - block_size : int = None, + block_size: int = None, **kwargs, ): if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model) - + if block_size: - BlockAttentionTransorm.apply(model, block_size=block_size) + BlockAttentionTransform.apply(model, block_size=block_size) # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): @@ -1231,8 +1231,9 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, - block_size : int=None, **kwargs): + def from_pretrained( + cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, block_size: int = None, **kwargs + ): """Used to load models supported by transformers.AutoModelForImageTextToText for Cloud AI 100. Args: @@ -1248,7 +1249,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona if kwargs.get("low_cpu_mem_usage", None): logger.warning("Updating low_cpu_mem_usage=False") - kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 670dd572c..0ad1bf309 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -443,7 +443,8 @@ class KVCacheModuleMethodMapperTransform(ModuleMethodMapperTransform): } _match_class_replace_method = {} -class BlockAttentionTransorm(ModuleMappingTransform): + +class BlockAttentionTransform(ModuleMappingTransform): # Fixed typo in class name # supported architectures _module_mapping = { MllamaVisionAttention: QEffMllamaVisionAttention, @@ -457,6 +458,5 @@ def apply(cls, model: nn.Module, block_size) -> Tuple[nn.Module, bool]: module.__class__ = repl_module # Bind the partial function to the instance module.forward = MethodType(partial(repl_module.forward, block_size=block_size), module) - transformed = True - break - return model, transformed \ No newline at end of file + transformed = True # Set to True if at least one transformation occurs + return model, transformed From b7e065518adee308a6058c08867145228872f645 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Mon, 17 Mar 2025 10:27:55 +0000 Subject: [PATCH 4/5] Added support Block size not divisible by q_len Signed-off-by: Rishin Raj --- .../models/mllama/modeling_mllama.py | 58 +++++++++++-------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 65b2730ca..d006c33fa 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -1319,6 +1319,19 @@ def get_inputs_info(self): class QEffMllamaVisionAttention(MllamaVisionAttention): + def compute_block_attention(self, query_states, 
key_states, value_states, attention_mask, start_idx, end_idx): + curr_attn_weights = torch.matmul( + query_states[:, :, start_idx:end_idx, :], key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask_block = attention_mask[:, :, start_idx:end_idx, : key_states.shape[-2]] + curr_attn_weights += causal_mask_block + # upcast attention to fp32 + curr_attn_weights = nn.functional.softmax(curr_attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + curr_attn_output = torch.matmul(curr_attn_weights, value_states) + + return curr_attn_output def forward( self, @@ -1327,7 +1340,6 @@ def forward( output_attentions: bool = None, block_size: int = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - query = self.q_proj(hidden_state) key = self.k_proj(hidden_state) value = self.v_proj(hidden_state) @@ -1338,35 +1350,35 @@ def forward( query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2) key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2) - if block_size is not None: - num_blocks = q_seq_len // block_size - - attn_output_blocks = [] - attn_weights_blocks = [] - for i in range(num_blocks): - query_block = query[:, :, i * block_size:(i + 1) * block_size, :] - attn_weights_block = torch.matmul(query_block, key.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - # causal_mask_block = attention_mask[:, :, :, : key.shape[-2]] - causal_mask_block = attention_mask[:, :, i * block_size:(i + 1) * block_size,:] - attn_weights_block = attn_weights_block + causal_mask_block + if block_size is not None: + runtime_block_size = torch.where( + (q_seq_len // torch.tensor(block_size)) > 0, torch.tensor(block_size), torch.tensor(1) + ) + reminder_block_size = q_seq_len % block_size # calculate the remaining query block + attn_output = torch.zeros(batch_size, self.num_heads, q_seq_len, self.head_dim) + num_iterations = q_seq_len // runtime_block_size + + for iteration in range(num_iterations): + start_idx = iteration * runtime_block_size + end_idx = (iteration + 1) * runtime_block_size + attn_output[:, :, start_idx:end_idx, :] = self.compute_block_attention( + query, key, value, attention_mask, start_idx, end_idx + ) - attn_weights_block = nn.functional.softmax(attn_weights_block, dim=-1, dtype=torch.float32).to(query.dtype) - attn_output_block = torch.matmul(attn_weights_block, value) + if reminder_block_size: + start_idx = num_iterations * runtime_block_size + end_idx = start_idx + reminder_block_size + attn_output[:, :, start_idx:end_idx, :] = self.compute_block_attention( + query, key, value, attention_mask, start_idx, end_idx + ) - attn_output_blocks.append(attn_output_block) - attn_weights_blocks.append(attn_weights_block) - - attn_output = torch.cat(attn_output_blocks, dim=2) - attn_weights = torch.cat(attn_weights_blocks, dim=2) else: # Regular attention attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: - causal_mask = attention_mask[:, :, :, :key.shape[-2]] + causal_mask = attention_mask[:, :, :, : key.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -1380,4 +1392,4 @@ def forward( if not output_attentions: attn_weights = None - return output, attn_weights \ No newline at end of file + return 
output, attn_weights From d43d223ffba5546c55218baab7de2b5194a28822 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Tue, 25 Mar 2025 14:42:49 +0000 Subject: [PATCH 5/5] pixel value float update Signed-off-by: Rishin Raj --- QEfficient/transformers/models/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8e6fc04e1..dd90cb32d 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -632,7 +632,7 @@ def compile( custom_io_vision = {} kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" - custom_io_vision["pixel_values"] = kv_cache_dtype + custom_io_vision["pixel_values"] = "float16" for output_name in output_names["vision"]: custom_io_vision[output_name] = kv_cache_dtype
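
Usage sketch (not part of the patch series above): a minimal example of how the `block_size` argument introduced in patches 2-4 and the `max_num_tiles` compiler option from patch 1 might be exercised together. It assumes `QEFFAutoModelForImageTextToText` is importable from the `QEfficient` package root, that the checkpoint name is a placeholder, and that `compile()` forwards the extra keyword arguments shown here to `get_specializations()` as compiler options; none of these details are confirmed by the diffs themselves.

    from QEfficient import QEFFAutoModelForImageTextToText

    # Passing block_size applies BlockAttentionTransform, which swaps
    # MllamaVisionAttention for QEffMllamaVisionAttention and computes the vision
    # attention over query blocks of at most `block_size` queries (a remainder
    # block is handled separately as of patch 4).
    model = QEFFAutoModelForImageTextToText.from_pretrained(
        "meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder checkpoint name
        kv_offload=True,
        block_size=128,
    )

    # get_specializations() pops max_num_images and max_num_tiles from
    # compiler_options (patch 1), so they can be supplied at compile time; the
    # remaining keyword names below are assumptions about the compile() signature,
    # not taken from the diffs.
    model.compile(
        num_devices=1,
        prefill_seq_len=32,
        ctx_len=128,
        max_num_images=1,
        max_num_tiles=4,
    )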