
Added support for Llava model single QPC #265


Merged
1 change: 1 addition & 0 deletions QEfficient/base/modeling_qeff.py
@@ -190,6 +190,7 @@ def _export(

except Exception as e:
logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")

raise e

finally:
13 changes: 13 additions & 0 deletions QEfficient/generation/text_generation_inference.py
@@ -63,6 +63,19 @@ def __repr__(self):
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"


@dataclass
class CloudAI100ExecInfoNew:
batch_size: int
generated_ids: Union[List[np.ndarray], np.ndarray]
perf_metrics: PerfMetrics

def __repr__(self):
return f"Average Prefill time a.k.a TTFT is= {round(self.perf_metrics.prefill_time, 2)}\
\nDecode token/sec is= {round(self.perf_metrics.decode_perf * self.batch_size, 2)}\
\nTotal token/sec is= {round(self.perf_metrics.total_perf * self.batch_size, 2)}\
\nTotal (E2E) inference time is= {round(self.perf_metrics.total_time, 2)}"


io_files = []


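For context during review, here is a minimal sketch (assuming `PerfMetrics` is a dataclass whose fields match the four attributes read in `__repr__`) of how the new `CloudAI100ExecInfoNew` container could be filled and printed; the numbers are illustrative only:

```python
import numpy as np

from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics

# Illustrative values only; real metrics come from the AI 100 execution loop.
metrics = PerfMetrics(prefill_time=0.08, decode_perf=95.0, total_perf=88.0, total_time=1.42)
exec_info = CloudAI100ExecInfoNew(
    batch_size=1,
    generated_ids=np.zeros((1, 32), dtype=np.int64),  # token ids produced for each batch element
    perf_metrics=metrics,
)
print(exec_info)  # formatted by the __repr__ added in the diff above
```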
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/llava/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
294 changes: 294 additions & 0 deletions QEfficient/transformers/models/llava/modeling_llava.py
@@ -0,0 +1,294 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaForConditionalGeneration,
logger,
)

BS = 1
NUM_CHANNEL = 3
SEQ_LEN = 592
IMAGE_SIZE = 336
CTX_LEN = 1024


class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[int] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


Returns:

Example:

```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, LlavaForConditionalGeneration

>>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
>>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

>>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(images=image, text=prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)

legacy_processing = False
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

# If there are at least `image_seq_length` image tokens, the inputs were probably already expanded in
# processing. Not fully reliable, but nobody is expected to pass 500+ images in a single prompt.
# At decode time, legacy behavior is detected by the presence of pixel_values, even when use_cache=True.
legacy_processing = (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or (input_ids.shape[-1] == 1 and pixel_values is not None)

if pixel_values is not None:
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
)

if legacy_processing:
logger.warning_once(
"Expanding inputs for image tokens in LLaVa should be done in processing. "
"Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
"with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)
# prefill stage vs decoding stage (legacy behavior copied)
if input_ids.shape[1] != 1:
inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, attention_mask, labels
)
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
else:
# Retrieve the first layer to inspect the logits and mask out the hidden states
# that are set to 0
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

# Get the target length
target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]

extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)

# Filter out only the tokens that can be un-attended, this can happen
# if one uses Llava + Fused modules where the cache on the
# first iteration is already big enough, or if one passes custom cache
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
new_batch_index = batch_index[valid_indices]
new_non_attended_tokens = non_attended_tokens[valid_indices]

# Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-target_length:
]

# TODO: @raushan retain only the new behavior after v4.47
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)

mask = input_ids == self.config.image_token_index
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
image_features_expanded = image_features[indices0, indices1]
image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
# Skip the image-feature merge at decode time (input_ids.shape[1] == 1), when only text embeddings are needed;
# see the toy sketch after this method for how the merge behaves.
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)

outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)

logits = outputs[0]

loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return logits, pixel_values, outputs.past_key_values
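The cumsum/gather trick above is what lets image features be merged into the text embeddings with static shapes: each `<image>` placeholder position selects the next row of `image_features`, while `torch.where` keeps the text embedding everywhere else (and for the whole sequence at decode time). A toy-shape sketch, not part of the PR, with an assumed placeholder id:

```python
import torch

IMAGE_TOKEN = 32000  # assumed <image> placeholder id (self.config.image_token_index in the model)
input_ids = torch.tensor([[1, IMAGE_TOKEN, IMAGE_TOKEN, 5]])               # (batch=1, seq=4)
inputs_embeds = torch.zeros(1, 4, 8)                                       # text embeddings
image_features = torch.arange(2 * 8, dtype=torch.float32).view(1, 2, 8)    # 2 image "patches"

mask = input_ids == IMAGE_TOKEN
indices1 = mask.to(torch.int64).cumsum(1) - 1        # -1 at text positions, 0..N-1 at image tokens
indices0 = torch.arange(mask.shape[0]).view(-1, 1)   # batch index for the gather
image_features_expanded = image_features[indices0, indices1]
merged = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
# Positions 1 and 2 now carry image patches; positions 0 and 3 keep their text embeddings.
# The stray gather at text positions (index -1) is masked out by torch.where, exactly as in forward().
```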

def get_dummy_inputs(self, **kwargs):
num_layers = self.config.text_config.num_hidden_layers
num_key_value_heads = self.config.text_config.num_key_value_heads
head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads

inputs = {
"input_ids": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
"attention_mask": torch.ones((BS, SEQ_LEN), dtype=torch.int64),
"pixel_values": torch.zeros((BS, NUM_CHANNEL, IMAGE_SIZE, IMAGE_SIZE), dtype=torch.float32),
}
inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1)
inputs["past_key_values"] = []
for i in range(num_layers):
inputs["past_key_values"].append(
(
torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim),
torch.zeros(BS, num_key_value_heads, CTX_LEN, head_dim),
)
)
inputs["position_ids"] = torch.full(inputs["position_ids"].shape, CTX_LEN - 1)
return inputs

def get_specializations(
self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
):
# TODO: check if this should be named num_crops or something else
max_num_images = compiler_options.get("max_num_images", 1)
prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN
ctx_len = ctx_len if ctx_len else CTX_LEN
img_size = img_size if img_size else IMAGE_SIZE

return [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"max_num_images": max_num_images,
"img_size": img_size,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"max_num_images": max_num_images,
"img_size": img_size,
},
]

def get_onnx_dynamic_axes(
self,
):
# Define dynamic axes
num_layers = self.config.text_config.num_hidden_layers

dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
"pixel_values": {0: "batch_size", 2: "img_size", 3: "img_size"},
}
for i in range(num_layers):
dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

return dynamic_axes

def get_output_names(
self,
):
output_names = ["logits", "pixel_values_RetainedState"]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
output_names.append(f"past_{kv}.{i}_RetainedState")
return output_names
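To make the new export surface easier to review, here is a hedged sketch of how the helpers above fit together; the checkpoint name is an assumption, and the actual export presumably runs through the `_export` path in `QEfficient/base/modeling_qeff.py` rather than this standalone snippet:

```python
from QEfficient.transformers.models.llava.modeling_llava import QEffLlavaForConditionalGeneration

# Assumed checkpoint for illustration.
model = QEffLlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
model.eval()

dummy_inputs = model.get_dummy_inputs()        # 1x592 input_ids, 336x336 pixel_values, 1024-token KV cache
dynamic_axes = model.get_onnx_dynamic_axes()   # batch_size / seq_len / ctx_len / img_size axes
output_names = model.get_output_names()        # logits, pixel_values_RetainedState, past KV RetainedStates
specializations = model.get_specializations(
    batch_size=1, prefill_seq_len=592, ctx_len=1024, img_size=336
)                                              # one prefill and one decode (seq_len="1") specialization
```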
54 changes: 0 additions & 54 deletions QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -39,7 +39,6 @@
rotate_half,
)

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_utils import (
_create_causal_mask,
_prepare_aspect_ratio_attention_mask,
@@ -1204,56 +1203,3 @@ def generate_dummy_io_info(self, kv_offload=False):
output_names = lang_output_names

return inputs, output_names, dynamic_axes, inputs_shape


class ModelWrapper(nn.Module):
def __init__(self, mllama):
super().__init__()
self.mllama = mllama
self.num_hidden_layers = mllama.config.get_text_config().num_hidden_layers
self.config = self.mllama.config.get_text_config()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
):
if past_key_values is not None:
past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
outputs = self.mllama(
input_ids=input_ids,
pixel_values=pixel_values,
aspect_ratio_mask=aspect_ratio_mask,
aspect_ratio_ids=aspect_ratio_ids,
attention_mask=attention_mask,
cross_attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)
if "past_key_values" in outputs:
outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()
return outputs