
Added support for InternVL single QPC #264

Merged
10 commits merged on Feb 10, 2025
Changes from all commits
25 changes: 24 additions & 1 deletion QEfficient/base/pytorch_transforms.py
@@ -4,7 +4,8 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
from typing import Dict, Tuple, Type
from types import MethodType
from typing import Callable, Dict, Tuple, Type

from torch import nn

@@ -87,3 +88,25 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
@classmethod
def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
raise NotImplementedError("Please implement your own method by inheriting this class")


class ModuleMethodMapperTransform(PytorchTransform):
"""
Serves as a base class for any transform that wants to map a particular method of a class to a new method implementation.
"""

_match_class_replace_method: Dict[nn.Module, Dict[str, Callable]]
_match_string_replace_method: Dict[str, Dict[str, Callable]]

@classmethod
def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
transformed = False
for module in model.modules():
if (repl_method_map := cls._match_class_replace_method.get(type(module))) or (
repl_method_map := cls._match_string_replace_method.get(module.__class__.__name__)
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))
transformed = True

return model, transformed
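
For reference, a subclass of ModuleMethodMapperTransform only needs to populate the two class-level maps. The sketch below is a hypothetical usage example (ToyBlock, patched_forward, and ToyMethodMapperTransform are illustrative names, not part of this PR), assuming the transform is importable from QEfficient.base.pytorch_transforms:

import torch
from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMethodMapperTransform


class ToyBlock(nn.Module):
    # Hypothetical module whose forward we want to swap out.
    def forward(self, x):
        return x


def patched_forward(self, x):
    # Bound to each matching module instance via MethodType by the transform.
    return 2 * x


class ToyMethodMapperTransform(ModuleMethodMapperTransform):
    # Modules can be matched either by class object or by class-name string.
    _match_class_replace_method = {ToyBlock: {"forward": patched_forward}}
    _match_string_replace_method = {}


model = nn.Sequential(ToyBlock())
model, transformed = ToyMethodMapperTransform.apply(model)
assert transformed
assert torch.equal(model(torch.ones(2)), 2 * torch.ones(2))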
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/internvl/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
154 changes: 154 additions & 0 deletions QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -0,0 +1,154 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F

from QEfficient.utils import constants
from QEfficient.utils._utils import get_padding_shape_from_config


class QEffInternVLModel(nn.Module):
def get_specializations(
self, batch_size: int, prefill_seq_len: int, ctx_len: int, img_size: int, **compiler_options
):
# TODO: check if this should be named num_crops or something else
num_crops = compiler_options.get("num_crops", 13)
prefill_seq_len = prefill_seq_len if prefill_seq_len else 3840 # 4096-256
ctx_len = ctx_len if ctx_len else 4096
img_size = img_size if img_size else 448

return [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"num_crops": num_crops,
"img_size": img_size,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"num_crops": num_crops,
"img_size": img_size,
},
]

def get_onnx_dynamic_axes(
self,
):
# Define dynamic axes
dynamic_axes = {}
dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
dynamic_axes["pixel_values"] = {0: "num_crops", 2: "img_size", 3: "img_size"}

pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes

return dynamic_axes

def get_output_names(
self,
):
output_names = ["logits", "pixel_values_RetainedState"]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
output_names.append(f"past_{kv}.{i}_RetainedState")
return output_names

def get_dummy_inputs(self, kv_offload: bool = False):
if kv_offload:
raise ValueError("kv_offload method not supported for InternVL yet!")
NUM_CROPS = 13
C, H, W = 3, 448, 448

# Define shapes
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (NUM_CROPS, C, H, W)

# Define inputs
inputs = {}
inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
inputs["position_ids"] = (
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)
inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)

# Add data for KV
kv_cache_shape = get_padding_shape_from_config(
config=self.language_model.config,
batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))

return inputs

def forward(self, input_ids, pixel_values, position_ids, past_key_values):
# TODO: Check if hardcoding this is okay, i.e. check if this value is common for all InternVL models
IMG_CONTEXT_TOKEN = 151667

input_embeds = self.language_model.get_input_embeddings()(input_ids)
vit_embeds = self.extract_feature(pixel_values)
B, N, C = input_embeds.shape
image_input_embeds = input_embeds.reshape(B * N, C)
image_input_ids = input_ids.reshape(B * N)
selected = image_input_ids == IMG_CONTEXT_TOKEN
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
image_input_embeds = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
outputs = self.language_model(
inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
)
return outputs.logits, pixel_values, outputs.past_key_values


class QEffInternVisionEmbeddings(nn.Module):
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, height, width]
batch_size, _, height, width = patch_embeds.shape
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

pos_embed = self.position_embedding[:, 1:, :]
target_dtype = pos_embed.dtype
pos_embed = (
pos_embed.float()
.reshape(1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1)
.permute(0, 3, 1, 2)
)
pos_embed = (
F.interpolate(pos_embed, size=(height, width), mode="bilinear", align_corners=False)
.reshape(1, -1, height * width)
.permute(0, 2, 1)
.to(target_dtype)
)

position_embedding = torch.cat([self.position_embedding[:, :1, :], pos_embed], dim=1)

embeddings = embeddings + position_embedding.to(target_dtype)
return embeddings
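
As an aside, the index arithmetic in QEffInternVLModel.forward (the cumsum-based gather that splices vision features over the IMG_CONTEXT_TOKEN positions) is easier to follow on toy tensors. A minimal, self-contained sketch with made-up sizes and a made-up token id:

import torch

# Toy sizes: batch 1, 5 tokens, hidden size 2; id 9 stands in for IMG_CONTEXT_TOKEN.
IMG_CONTEXT_TOKEN = 9
input_ids = torch.tensor([[1, 9, 9, 2, 3]])
input_embeds = torch.arange(10, dtype=torch.float32).view(1, 5, 2)
vit_embeds = torch.full((2, 2), -1.0)  # two "image" feature vectors to splice in

B, N, C = input_embeds.shape
selected = input_ids.reshape(B * N) == IMG_CONTEXT_TOKEN
# Running count of image tokens tells each selected position which vision feature to take.
indices1 = selected.unsqueeze(0).to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
image_features_expanded = vit_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
# Non-image positions keep the text embeddings; image positions take the vision features.
merged = torch.where(selected.unsqueeze(0).unsqueeze(-1), image_features_expanded, input_embeds)
print(merged)  # rows 1 and 2 are [-1., -1.]; the rest are the original text embeddings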
47 changes: 43 additions & 4 deletions QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -48,10 +48,10 @@
from QEfficient.utils.constants import Constants

bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
max_num_images = constants.ONNX_EXPORT_MAX_NUM_IMAGES
max_image_tiles = constants.ONNX_EXPORT_MAX_IMAGE_TILES
image_size = constants.ONNX_EXPORT_IMAGE_WIDTH
num_channel = constants.ONNX_EXPORT_IMAGE_DEPTH
max_num_images = 1
max_image_tiles = 4
image_size = 560
num_channel = 3
seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN


@@ -998,7 +998,46 @@ def forward(
)


class QEffMllamaVisionEncoder(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.cross_attention_layers = self.model.config.get_text_config().cross_attention_layers

def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
aspect_ratio_mask: Optional[torch.Tensor] = None,
aspect_ratio_ids: Optional[torch.Tensor] = None,
) -> List[Tuple[torch.Tensor]]:
vision_outputs = self.model.vision_model(
pixel_values=pixel_values,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
)
cross_attention_states = vision_outputs[0]
cross_attention_states = self.model.multi_modal_projector(cross_attention_states).reshape(
-1, cross_attention_states.shape[-2], self.model.hidden_size
)

bsz = pixel_values.shape[0]
outputs = []
for i in self.cross_attention_layers:
cross_attn = self.model.language_model.model.layers[i].cross_attn
key_states = cross_attn.k_proj(cross_attention_states)
value_states = cross_attn.v_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, -1, cross_attn.num_key_value_heads, cross_attn.head_dim).transpose(
1, 2
)
outputs.append((key_states, value_states))
return outputs
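
For orientation, each tuple returned by QEffMllamaVisionEncoder.forward is already projected and reshaped into the attention layout [batch, num_key_value_heads, vision_tokens, head_dim], i.e. ready to serve as a cross-attention KV cache entry. A toy sketch of that reshape, using made-up sizes rather than real Mllama dimensions:

import torch
import torch.nn as nn

# Hypothetical sizes: hidden_size == num_key_value_heads * head_dim.
bsz, num_vision_tokens, hidden_size = 1, 6, 8
num_key_value_heads, head_dim = 2, 4

k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
cross_attention_states = torch.randn(bsz, num_vision_tokens, hidden_size)

key_states = k_proj(cross_attention_states)
key_states = key_states.view(bsz, -1, num_key_value_heads, head_dim).transpose(1, 2)
print(key_states.shape)  # torch.Size([1, 2, 6, 4]) -> [batch, kv_heads, vision_tokens, head_dim]

The value projection follows the same pattern, and one (key, value) pair is produced per layer listed in config.get_text_config().cross_attention_layers.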


class QEffMllamaForConditionalGeneration(MllamaForConditionalGeneration):
def get_qeff_vision_encoder(self):
return QEffMllamaVisionEncoder(self)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,