From a492e1744f2433552f3dd2493ba97377b06f1d9b Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Sun, 9 Feb 2025 14:53:08 +0000 Subject: [PATCH] working single qpc, single soc Signed-off-by: Amit Raj --- QEfficient/base/modeling_qeff.py | 1 + .../generation/text_generation_inference.py | 13 + .../transformers/models/llava/__init__.py | 6 + .../models/llava/modeling_llava.py | 294 ++++++++++++++++++ .../models/mllama/modeling_mllama.py | 54 ---- .../transformers/models/modeling_auto.py | 155 ++++----- .../transformers/models/pytorch_transforms.py | 8 + QEfficient/utils/constants.py | 7 +- 8 files changed, 403 insertions(+), 135 deletions(-) create mode 100644 QEfficient/transformers/models/llava/__init__.py create mode 100644 QEfficient/transformers/models/llava/modeling_llava.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index e8b388710..b2dab6ae6 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -190,6 +190,7 @@ def _export( except Exception as e: logger.error(f"ONNX export (or) ONNXTransforms failed: {e}") + raise e finally: diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 54b6f057e..14e781bfb 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -63,6 +63,19 @@ def __repr__(self): \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}" +@dataclass +class CloudAI100ExecInfoNew: + batch_size: int + generated_ids: Union[List[np.ndarray], np.ndarray] + perf_metrics: PerfMetrics + + def __repr__(self): + return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)}\ + \nDecode token/sec is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)}\ + \nTotal token/sec is= {round(self.perf_metrics.total_perf * self.batch_size, 2)}\ + \nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}" + + io_files = [] diff --git a/QEfficient/transformers/models/llava/__init__.py b/QEfficient/transformers/models/llava/__init__.py new file mode 100644 index 000000000..d259e435a --- /dev/null +++ b/QEfficient/transformers/models/llava/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py new file mode 100644 index 000000000..a7998adc0 --- /dev/null +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -0,0 +1,294 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.models.llava.modeling_llava import ( + LlavaCausalLMOutputWithPast, + LlavaForConditionalGeneration, + logger, +) + +BS = 1 +NUM_CHANNEL = 3 +SEQ_LEN = 592 +IMAGE_SIZE = 336 +CTX_LEN = 1024 + + +class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration): + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + legacy_processing = False + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + legacy_processing = ( + (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + ) or (input_ids.shape[-1] == 1 and pixel_values is not None) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if legacy_processing: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + # prefill stage vs decoding stage (legacy behavior copied) + if input_ids.shape[1] != 1: + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] + + # TODO: @raushan retain only the new behavior after v4.47 + else: + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() + n_image_features = image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_index + indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(mask.shape[0]).view(-1, 1) + image_features_expanded = image_features[indices0, indices1] + image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + # *where to skip image encoder for decode* + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) + shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return logits, pixel_values, outputs.past_key_values + + def get_dummy_inputs(self, **kwargs): + num_layers = self.config.text_config.num_hidden_layers + num_key_value_heads = self.config.text_config.num_key_value_heads + head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads + + inputs = { + "input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64), + "attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64), + "pixel_values": torch.zeros((BS, NUM_CHANNEL, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32), + } + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + inputs["past_key_values"] = [] + for i in range(num_layers): + inputs["past_key_values"].append( + ( + torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), + torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim), + ) + ) + inputs["position_ids"] = torch.full(inputs["position_ids"].shape, CTX_LEN - 1) + return inputs + + def get_specializations( + self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options + ): + # TODO: check if this should be named num_crops or something else + max_num_images = compiler_options.get("max_num_images", 1) + prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN + ctx_len = ctx_len if ctx_len else CTX_LEN + img_size = img_size if img_size else IMAGE_SIZE + + return [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + ] + + def get_onnx_dynamic_axes( + self, + ): + # Define dynamic axes + num_layers = self.config.text_config.num_hidden_layers + + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + "pixel_values": {0: "batch_size", 2: "img_size", 3: "img_size"}, + } + for i in range(num_layers): + dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + return dynamic_axes + + def get_output_names( + self, + ): + output_names = ["logits", "pixel_values_RetainedState"] + for i in range(self.language_model.config.num_hidden_layers): + for kv in ["key", "value"]: + output_names.append(f"past_{kv}.{i}_RetainedState") + return output_names diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 0c8506c2a..4aedf7bfe 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -39,7 +39,6 @@ rotate_half, ) -from 
QEfficient.transformers.cache_utils import QEffDynamicCache from QEfficient.transformers.modeling_utils import ( _create_causal_mask, _prepare_aspect_ratio_attention_mask, @@ -1204,56 +1203,3 @@ def generate_dummy_io_info(self, kv_offload=False): output_names = lang_output_names return inputs, output_names, dynamic_axes, inputs_shape - - -class ModelWrapper(nn.Module): - def __init__(self, mllama): - super().__init__() - self.mllama = mllama - self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers - self.config = self.mllama.config.get_text_config() - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ): - if past_key_values is not None: - past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) - outputs = self.mllama( - input_ids=input_ids, - pixel_values=pixel_values, - aspect_ratio_mask=aspect_ratio_mask, - aspect_ratio_ids=aspect_ratio_ids, - attention_mask=attention_mask, - cross_attention_mask=cross_attention_mask, - cross_attention_states=cross_attention_states, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, - ) - if "past_key_values" in outputs: - outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() - return outputs diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 570826e0b..1c251961b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -29,7 +29,7 @@ from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession -from QEfficient.generation.text_generation_inference import get_compilation_dims +from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics, get_compilation_dims from QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, KVCacheTransform, @@ -1151,6 +1151,7 @@ def __init__( super().__init__(model) self.model.config.text_config.use_cache = True self.input_shapes, self.output_names = None, None + self.num_layers = model.config.text_config.num_hidden_layers @classmethod def from_pretrained( @@ -1166,74 +1167,69 @@ def from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + from transformers import AutoConfig - model = 
cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) - return cls(model, **kwargs) + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) + config._attn_implementation = "eager" + config.vision_config.use_flash_attn = "false" + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) - def set_io_info(self): - if self.output_names is None or self.input_shapes is None: - _, self.output_names, _, self.input_shapes = self.model.generate_dummy_io_info(kv_offload=True) + return cls(model, **kwargs) def export( self, export_dir: Optional[str] = None, **kwargs, ) -> str: - inputs, self.output_names, dynamic_axes, self.input_shapes = self.model.generate_dummy_io_info() - self._export(inputs, self.output_names, dynamic_axes, export_dir=export_dir) + inputs = self.model.get_dummy_inputs() + dynamic_axes = self.model.get_onnx_dynamic_axes() + output_names = self.model.get_output_names() + self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) def compile( self, - img_size: int, + img_size: int = None, onnx_path: Optional[str] = None, compile_dir: Optional[str] = None, *, - prefill_seq_len: int = 32, - ctx_len: int = 128, + prefill_seq_len: int = None, + ctx_len: int = None, batch_size: int = 1, num_devices: int = 1, num_cores: int = 16, # FIXME: Make this mandatory arg mxfp6_matmul: bool = False, - max_num_image: int = 1, + mxint8_kv_cache: bool = False, **compiler_options, ) -> str: - if self.output_names is None: - self.set_io_info() + output_names = self.model.get_output_names() + + # Get specializations from modelling file + specializations = self.model.get_specializations( + batch_size=batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + **compiler_options, + ) - specializations = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_image, - "img_size": img_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_image, - "img_size": img_size, - }, - ] - custom_io = {} - kv_cache_dtype = "float16" + kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" + custom_io = {} # inputs - for input_name in self.output_names: + for input_name in output_names: if input_name.endswith("_RetainedState"): custom_io[input_name[: -len("_RetainedState")]] = kv_cache_dtype # outputs - for output_name in self.output_names: + for output_name in output_names: if output_name.endswith("_RetainedState"): custom_io[output_name] = kv_cache_dtype - compiler_options.update({"retained-state": True}) self._compile( onnx_path, compile_dir, compile_only=True, + retained_state=True, specializations=specializations, convert_to_fp16=True, mxfp6_matmul=mxfp6_matmul, @@ -1243,6 +1239,9 @@ def compile( **compiler_options, ) + def get_onnx_dynamic_axes(self): + return self.model.get_onnx_dynamic_axes() + def generate( self, inputs: torch.Tensor, @@ -1279,12 +1278,15 @@ def cloud_ai_100_generate( batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path) - eos_token_id = 0 pad_token_id = 1 # Skip inputs/outputs qpc_session.skip_buffers( - [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] + [ + x + for x in qpc_session.input_names + qpc_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] ) # Read prompt and ctx len from session @@ -1299,10 +1301,11 @@ def cloud_ai_100_generate( ) 
input_len = inputs["attention_mask"].sum(1, keepdims=True) - padded_len = inputs["input_ids"].shape[1] - num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + input_ids_length = inputs["input_ids"].shape[1] + + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len - generation_len = None if generation_len is None: generation_len = ctx_len - input_len.max() @@ -1310,70 +1313,73 @@ def cloud_ai_100_generate( generated_ids = np.full((batch_size, generation_len + 1), pad_token_id) # Prepare inputs for prefill - start = perf_counter() + prefill_start = perf_counter() + + input_ids = inputs["input_ids"] + input_ids_size = input_ids.shape[1] + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_size), + "constant", + 1, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_size), "constant", 0 + ) - inputs["position_ids"] = np.where( - inputs.pop("attention_mask"), np.arange(padded_len), -1 - ) # Need to use -1 as position_ids for invalid tokens - inputs = dict(inputs) + for k, v in inputs.items(): + inputs[k] = np.array(v) + + inputs["pixel_values"] = inputs["pixel_values"].astype("float16") + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) - # vision_session.deactivate() qpc_session.activate() # Run prefill + for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] outputs = qpc_session.run(chunk_inputs) - # Skip inputs/outputs again - qpc_session.skip_buffers( - [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_")] - ) - + prefill_time = prefill_start - perf_counter() # Get first token inputs["input_ids"] = outputs["logits"].argmax(2) - inputs["position_ids"] = input_len - inputs["cross_attention_mask"] = inputs["cross_attention_mask"][:, -1:, :, :] + inputs["position_ids"] = input_len.numpy() generated_ids[:, 0] = inputs["input_ids"].squeeze(1) - finished_sequences = inputs["input_ids"] == eos_token_id if streamer: streamer.put(inputs["input_ids"][0]) + qpc_session.skip_buffers(["pixel_values"]) + inputs.pop("pixel_values") + # Decode loop - loop_start = perf_counter() + decode_start = perf_counter() for num_token in range(1, generation_len): outputs = qpc_session.run(inputs) - # Prepare inputs for next iteration inputs["input_ids"] = outputs["logits"].argmax(2) inputs["position_ids"] += 1 generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) - finished_sequences |= inputs["input_ids"] == eos_token_id if streamer: streamer.put(inputs["input_ids"][0]) - if finished_sequences.all(): - break - end = perf_counter() + decode_end = perf_counter() if streamer: streamer.end() - prefill_perf = 1 / (loop_start - start) - decode_perf = (num_token - 1) / (end - loop_start) - total_perf = num_token / (end - start) + decode_perf = (num_token - 1) / (decode_end - decode_start) + total_time = decode_end - prefill_start + total_perf = num_token / total_time - print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr) - print("E2ET:", round(end - start, 2), "s", file=sys.stderr) - print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr) - print("Decode:", 
round(decode_perf, 2), "tok/s", file=sys.stderr) - print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr) - if batch_size > 1: - print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr) - print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr) - print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr) - return generated_ids + return CloudAI100ExecInfoNew( + batch_size=batch_size, + generated_ids=generated_ids, + perf_metrics=PerfMetrics( + prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time + ), + ) @property def model_hash(self) -> str: @@ -1417,7 +1423,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, kv_offload=False, **kwar kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - + return cls._get_qeff_class(model, kv_offload, **kwargs) @classmethod @@ -1437,4 +1443,3 @@ def _get_qeff_class(cls, model, kv_offload, **kwargs): return _QEffAutoModelForImageTextToText2QPC(model, **kwargs) else: return _QEFFAutoModelForImageTextToText1QPC(model, **kwargs) - diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5c87a2847..4ae62da49 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -51,6 +51,9 @@ LlamaModel, LlamaRMSNorm, ) +from transformers.models.llava.modeling_llava import ( + LlavaForConditionalGeneration, +) from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, @@ -152,6 +155,9 @@ QEffLlamaForCausalLM, QEffLlamaModel, ) +from QEfficient.transformers.models.llava.modeling_llava import ( + QEffLlavaForConditionalGeneration, +) from QEfficient.transformers.models.mistral.modeling_mistral import ( QEffMistralAttention, QEffMistralDecoderLayer, @@ -250,6 +256,8 @@ class KVCacheTransform(ModuleMappingTransform): LlamaDecoderLayer: QEffLlamaDecoderLayer, LlamaModel: QEffLlamaModel, LlamaForCausalLM: QEffLlamaForCausalLM, + # Llava + LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration, # Gemma GemmaAttention: QEffGemmaAttention, GemmaDecoderLayer: QEffGemmaDecoderLayer, diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 028dd13b7..ab861a788 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -49,12 +49,6 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep ONNX_EXPORT_OPSET = 13 -ONNX_EXPORT_MAX_NUM_IMAGES = 1 -ONNX_EXPORT_MAX_IMAGE_TILES = 4 -ONNX_EXPORT_IMAGE_WIDTH = 560 -ONNX_EXPORT_IMAGE_LENGHT = 560 -ONNX_EXPORT_IMAGE_DEPTH = 3 -ONNX_EXPORT_CTX_LEN = 1024 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] @@ -130,6 +124,7 @@ class QnnConstants: "--float_bitwidth ", "--preserve_io_datatype", "--onnx_skip_simplification", + "--onnx_defer_loading", ] IMMUTABLE_CONTEXT_BIN_GEN_ARGS = [
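
Usage note (illustrative, not part of the patch): a minimal sketch of how the single-QPC LLaVA path added above might be exercised end to end. It assumes the image-text wrapper is exported as `QEFFAutoModelForImageTextToText` and that `generate()` accepts the processor outputs directly and returns the `CloudAI100ExecInfoNew` introduced in this change; adjust names to the actual public API if they differ.

```python
import requests
from PIL import Image
from transformers import AutoProcessor

# Assumed public import path for the wrapper dispatched by _get_qeff_class();
# kv_offload=False selects the single-QPC class (_QEFFAutoModelForImageTextToText1QPC).
from QEfficient import QEFFAutoModelForImageTextToText

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)

qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, kv_offload=False)

# Compile for Cloud AI 100 (ONNX export is triggered if needed); the defaults baked
# into modeling_llava.py are SEQ_LEN=592, CTX_LEN=1024, IMAGE_SIZE=336.
qeff_model.compile(num_devices=1, num_cores=16, prefill_seq_len=592, ctx_len=1024, img_size=336)

prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text=prompt, return_tensors="pt")

# generate() routes to cloud_ai_100_generate() and returns CloudAI100ExecInfoNew
# (generated_ids plus prefill/decode performance metrics).
exec_info = qeff_model.generate(inputs)
print(exec_info)
print(processor.batch_decode(exec_info.generated_ids, skip_special_tokens=True)[0])
```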