From 2cee1940c698fe1799eadd42d6314f3bdc526b40 Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Mon, 20 Jan 2025 14:52:57 +0000 Subject: [PATCH 1/2] Adding VLM pipeline Signed-off-by: Dipankar Sarkar --- QEfficient/base/pytorch_transforms.py | 2 +- QEfficient/transformers/modeling_utils.py | 17 + .../transformers/models/modeling_auto.py | 493 ++++- .../models/phi3_vision/__init__.py | 6 + .../phi3_vision/configuration_phi3_v.py | 215 ++ .../models/phi3_vision/modeling_phi3_v.py | 1938 +++++++++++++++++ .../phi3_vision/modeling_phi3_vision.py | 1771 +++++++++++++++ .../transformers/models/pytorch_transforms.py | 15 + QEfficient/utils/constants.py | 8 +- tests/transformers/models/test_vlm_models.py | 183 ++ 10 files changed, 4642 insertions(+), 6 deletions(-) create mode 100644 QEfficient/transformers/models/phi3_vision/__init__.py create mode 100644 QEfficient/transformers/models/phi3_vision/configuration_phi3_v.py create mode 100644 QEfficient/transformers/models/phi3_vision/modeling_phi3_v.py create mode 100644 QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py create mode 100755 tests/transformers/models/test_vlm_models.py diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index 6e21d11b2..d6b396fe1 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -40,7 +40,7 @@ class ModuleMappingTransform(PytorchTransform): def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: transformed = False for module in model.modules(): - if repl_module := cls._module_mapping.get(type(module)): + if repl_module := cls._module_mapping.get(module.__class__.__name__): module.__class__ = repl_module # Handling the __init__ calls in the models if hasattr(module, "__qeff_init__"): diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index f749cc0c3..a53719192 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -77,6 +77,15 @@ ) from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.models.phi3_vision.modeling_phi3_v import Phi3VForCausalLM +from QEfficient.transformers.models.phi3_vision.modeling_phi3_vision import ( + QEffPhi3ImageEmbedding, + QEffPhi3RotaryEmbedding, + QEffPhi3VAttention, + QEffPhi3VDecoderLayer, + QEffPhi3VForCausalLM, + QEffPhi3VModel, +) from .models.codegen.modeling_codegen import ( QEffCodeGenAttention, @@ -157,6 +166,7 @@ Starcoder2ForCausalLM.__name__, GPTBigCodeForCausalLM.__name__, MllamaForCausalLM.__name__, + Phi3VForCausalLM.__name__, ] ) @@ -241,4 +251,11 @@ GPTBigCodeAttention: QEffGPTBigCodeAttention, GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, + # Phi3-vision + "Phi3VModel": QEffPhi3VModel, + "Phi3Attention": QEffPhi3VAttention, + "Phi3RotaryEmbedding": QEffPhi3RotaryEmbedding, + "Phi3VForCausalLM": QEffPhi3VForCausalLM, + "Phi3ImageEmbedding": QEffPhi3ImageEmbedding, + "Phi3DecoderLayer": QEffPhi3VDecoderLayer, } diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index c2e3777bc..31362aa36 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -12,9 +12,19 @@ from typing import List, Optional, Union import numpy as np +import requests import torch import torch.nn as nn -from transformers import AutoModel, AutoModelForCausalLM, PreTrainedTokenizer, PreTrainedTokenizerFast +from PIL import Image 
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForImageTextToText,
+    AutoProcessor,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+    TextStreamer,
+)
 
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
@@ -24,6 +34,8 @@
 from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers
 from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform
 from QEfficient.utils import constants, get_padding_shape_from_config
+
+# from QEfficient.transformers.models.phi3_vision.modeling_phi3_vision import Phi3VModelWrapper
 from QEfficient.utils.cache import to_hashable
 
 logger = logging.getLogger(__file__)
@@ -421,6 +433,485 @@ def generate(
         raise NotImplementedError("Only AI_100 runtime is supported right now via generate API")
 
 
+class QEFFAutoModelForImageTextToText(QEFFTransformersBase):
+    """
+    The QEFFAutoModelForImageTextToText class is designed for manipulating image-text-to-text (vision-language) models from the HuggingFace hub.
+    Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model
+        :processor (AutoProcessor): HuggingFace processor associated with the model.
+        :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
+        :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
+
+
+    .. code-block:: python
+
+        from QEfficient import QEFFAutoModelForImageTextToText
+        from transformers import AutoProcessor, TextStreamer
+
+        model_name = "llava"
+        model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, num_hidden_layers=2)
+        model.compile(prefill_seq_len=1024, ctx_len=1280, num_cores=16, num_devices=1)
+
+        processor = AutoProcessor.from_pretrained(model_name)
+        inputs = processor(images=images, text=prompt, return_tensors="pt")
+        streamer = TextStreamer(processor.tokenizer)
+        model.generate(inputs, streamer)
+    """
+
+    _hf_auto_class = AutoModelForImageTextToText
+    _pytorch_transforms = [AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, CustomOpsTransform, KVCacheTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(
+        self,
+        model: nn.Module,
+        processor: AutoProcessor,
+        continuous_batching: bool = False,
+        is_tlm: bool = False,
+        **kwargs,
+    ):
+        model_class_name = model.__class__.__name__
+        if not (model_class_name.endswith("ForCausalLM") or model_class_name.endswith("ForConditionalGeneration")):
+            raise TypeError(f"Required pytorch module for CausalLM or ConditionalGeneration, got {model_class_name}")
+
+        if continuous_batching:
+            raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if is_tlm:
+            raise NotImplementedError("Speculative Decoding is not supported for image-text-to-text models yet.")
+
+        if kwargs.pop("full_batch_size", None):
+            raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        # warnings.warn(
+        #     "full_batch_size argument is deprecated. Use continuous_batching=True instead.", DeprecationWarning, 2
+        # )
+        # phi3vwrapper_model=Phi3VModelWrapper(model)
+        super().__init__(model)
+
+        # Set use_cache=True to get KV values as output during ONNX export
+        self.model.config.use_cache = True
+        self.processor = processor
+        self.num_layers = model.config.num_hidden_layers
+        self.num_key_value_heads = model.config.num_key_value_heads
+        self.head_dim = model.config.hidden_size // model.config.num_attention_heads
+        self.continuous_batching = continuous_batching
+        self.is_tlm = is_tlm
+        self.pad_token_id = model.config.pad_token_id
+        self.ctx_len = 1280
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path, continuous_batching: bool = False, is_tlm: bool = False, *args, **kwargs
+    ):
+        """
+        This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForImageTextToText.
+        Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
+
+        Args:
+            :pretrained_name_or_path (str): Model card name from HuggingFace or local path to model directory.
+            :continuous_batching (bool): Whether this model will be used for continuous batching in future. If this is not set to True here, the model cannot be exported/compiled for continuous batching later.
+            :is_tlm (bool): Whether this is a Speculative Decoding Target Language Model. If set to True, `num_logits_to_keep` input array will have to be fed to control the number of returned logits during prefill/decode.
+            :args, kwargs: Additional arguments to pass to transformers.AutoModelForImageTextToText.
+
+        .. code-block:: python
+
+            from QEfficient import QEFFAutoModelForImageTextToText
+            from transformers import AutoProcessor
+
+            # Initialize the model using from_pretrained similar to transformers.AutoModelForImageTextToText
+            model_name = "microsoft/Phi-3.5-vision-instruct"
+            model = QEFFAutoModelForImageTextToText.from_pretrained(model_name)
+
+            # Now you can directly compile the model for Cloud AI 100
+            model.compile(num_cores=16)  # Considering you have a Cloud AI 100 Standard SKU
+
+            # You can now prepare multi-modal inputs with the processor and execute the model
+            processor = AutoProcessor.from_pretrained(model_name)
+            inputs = processor(images=images, text=prompt, return_tensors="pt")
+            model.generate(inputs, streamer=None)
+        """
+
+        if kwargs.pop("full_batch_size", None):
+            raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+
+        self = super().from_pretrained(pretrained_model_name_or_path, is_tlm=is_tlm, *args, **kwargs)
+        self.continuous_batching = continuous_batching
+
+        return self
+
+    @property
+    def model_hash(self) -> str:
+        # Compute the hash with: model_config, continuous_batching, transforms
+        mhash = hashlib.sha256()
+        mhash.update(to_hashable(self.model.config.to_diff_dict()))
+        mhash.update(to_hashable({"continuous_batching": self.continuous_batching}))
+        mhash.update(to_hashable({"is_tlm": self.is_tlm}))
+        mhash.update(to_hashable(self._transform_names()))
+        mhash = mhash.hexdigest()[:16]
+        return mhash
+
+    def _generate_inputs(self, **kwargs):
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        # seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+        # fbs = constants.ONNX_EXPORT_EXAMPLE_FBS
+
+        self.ctx_len = kwargs["ctx_len"] if "ctx_len" in kwargs else self.ctx_len
+
+        ## PREPROCESSING THE MULTI-MODAL INPUTS for Phi-3.5 for now
+        # TODO: Create a map for the other models to have their own inputs accordingly
+        images = []
+        placeholder = ""
+
+        # Note: if OOM, you might consider reducing the number of frames in this example.
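+        # The loop below downloads a single sample slide image over HTTP and builds the
+        # matching "<|image_1|>" placeholder prompt, so that the ONNX export graph is
+        # traced with representative image + text inputs.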
+ for i in range(1, 2): + url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" + images.append(Image.open(requests.get(url, stream=True).raw)) + placeholder += f"<|image_{1}|>\n" + + messages = [ + {"role": "user", "content": placeholder + "Summarize the deck of slides."}, + ] + + prompt = self.processor.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = dict(self.processor(images=images, text=prompt, return_tensors="pt")) + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + inputs["past_key_values"] = [] + for i in range(self.num_layers): + inputs["past_key_values"].append( + ( + torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), + torch.zeros(bs, self.num_key_value_heads, self.ctx_len, self.head_dim), + ) + ) + output_names = [ + "logits", + "pixel_values_RetainedState", + "image_sizes_RetainedState", + *[f"past_{kv}.{i}_RetainedState" for i in range(self.num_layers) for kv in ["key", "value"]], + ] + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + # "pixel_values": {0: "img_batch_size"}, + } + for i in range(self.num_layers): + dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + # Avoid issues due to index out of range + inputs["position_ids"] = torch.full(inputs["position_ids"].shape, self.ctx_len - 1) + + return inputs, dynamic_axes, output_names + + def export( + self, + export_dir: Optional[str] = None, + **kwargs, + ) -> str: + """ + Exports the model to ``ONNX`` format using ``torch.onnx.export``. + + ``Optional`` Args: + :export_dir (str, optional): The directory path to store ONNX-graph. + :**kwargs: Keyword arguments for ``_generate_inputs``. If "ctx_len" is passed, it will be used as the context length. Otherwise, it will be set to 1280. + + Returns: + :str: Path of the generated ``ONNX`` graph. + """ + + example_inputs, dynamic_axes, output_names = self._generate_inputs(**kwargs) + # breakpoint() + return self._export( + example_inputs, + output_names, + dynamic_axes, + export_dir=export_dir, + ) + + def compile( + self, + onnx_path: Optional[str] = None, + compile_dir: Optional[str] = None, + *, + prefill_seq_len: int = 1024, + ctx_len: int = 1280, + batch_size: int = 1, + full_batch_size: Optional[int] = None, + kv_cache_batch_size: Optional[int] = None, + num_devices: int = 1, + num_cores: int = 16, # FIXME: Make this mandatory arg + mxfp6_matmul: bool = False, + mxint8_kv_cache: bool = False, + num_speculative_tokens: Optional[int] = None, + enable_qnn: bool = False, + qnn_config: Optional[str] = None, + **compiler_options, + ) -> str: + """ + This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package. + If the model has not been exported yet, this method will handle the export process. + You can pass any other arguments that the `qaic-exec` takes as extra kwargs. + + ``Optional`` Args: + :onnx_path (str, optional): Path to pre-exported onnx model. + :compile_dir (str, optional): Path for saving the qpc generated. + :num_cores (int): Number of cores used to compile the model. + :num_devices (int): Number of devices the model needs to be compiled for. Defaults to 1. + :batch_size (int, optional): Batch size. ``Defaults to 1``. 
+            :prefill_seq_len (int, optional): The length of the Prefill prompt should be less than ``prefill_seq_len``. ``Defaults to 1024``.
+            :ctx_len (int, optional): Maximum ``ctx`` that the compiled model can remember. ``Defaults to 1280``.
+            :full_batch_size (int, optional): Continuous batching batch size.
+            :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
+            :mxint8_kv_cache (bool, optional): Whether to use ``mxint8`` compression for KV cache. ``Defaults to False``.
+            :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
+            :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
+            :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
+            :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
+            :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+
+        Returns:
+            :str: Path of the compiled ``qpc`` package.
+        """
+        # if self.is_tlm:
+        #     # assert num_speculative_tokens cfg is acceptable if defined
+        #     if num_speculative_tokens is None:
+        #         raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
+        #     if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
+        #         ValueError(
+        #             f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
+        #         )
+        #     num_logits_to_keep = num_speculative_tokens + 1
+        #     if prefill_seq_len < num_logits_to_keep:
+        #         raise ValueError(
+        #             f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
+        #         )
+
+        # if self.continuous_batching and full_batch_size is None:
+        #     raise TypeError("missing required argument: 'full_batch_size'")
+
+        # if kv_cache_batch_size and not full_batch_size:
+        #     raise ValueError(
+        #         "Prefix caching is enabled only for continuous batching. Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call"
+        #     )
+
+        kv_cache_batch_size = (
+            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
+        )
+        # Define prefill specialization
+        prefill_specialization = {
+            # Prefill is always run with single BS for continuous batching.
+            "batch_size": 1 if self.continuous_batching else batch_size,
+            "seq_len": prefill_seq_len,
+            "ctx_len": ctx_len,
+            # TODO: should be renamed to kv_cache_batch_size in specialization too
+        }
+        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
+        if self.continuous_batching:
+            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
+        else:
+            prefill_specialization.update({"batch_size": kv_cache_batch_size})
+        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
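+        # Both specializations are compiled into one QPC: prefill (seq_len=prefill_seq_len)
+        # plus, unless prefill_seq_len == 1 without continuous batching, a decode
+        # specialization with seq_len=1 (or num_speculative_tokens + 1 for TLMs).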
+        specializations = [
+            prefill_specialization,
+        ]
+
+        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
+        if prefill_seq_len != 1 or self.continuous_batching:
+            decode_specialization = {
+                "batch_size": full_batch_size if self.continuous_batching else batch_size,
+                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
+                "ctx_len": ctx_len,
+            }
+            if self.continuous_batching:
+                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
+            else:
+                decode_specialization.update({"batch_size": kv_cache_batch_size})
+            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
+            specializations.append(decode_specialization)
+
+        if enable_qnn:
+            if compiler_options:
+                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+
+            qpc_path = self._qnn_compile(
+                onnx_path,
+                compile_dir,
+                specializations=specializations,
+                prefill_seq_len=prefill_seq_len,
+                ctx_len=ctx_len,
+                batch_size=batch_size,
+                full_batch_size=full_batch_size,
+                mdp_ts_num_devices=num_devices,
+                num_cores=num_cores,
+                mxfp6_matmul=mxfp6_matmul,
+                mxint8_kv_cache=mxint8_kv_cache,
+                qnn_config=qnn_config,
+            )
+        else:
+            # Custom IO
+            custom_io = {}
+            kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            custom_io["pixel_values"] = kv_cache_dtype
+            custom_io["pixel_values_RetainedState"] = kv_cache_dtype
+            for suffix in ["", "_RetainedState"]:
+                for i in range(self.num_layers):
+                    for kv in ["key", "value"]:
+                        custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
+
+            qpc_path = self._compile(
+                onnx_path,
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations,
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                custom_io=custom_io,
+                mdp_ts_num_devices=num_devices,
+                num_speculative_tokens=num_speculative_tokens,
+                aic_num_cores=num_cores,
+                **compiler_options,
+            )
+        return qpc_path
+
+    def generate(
+        self,
+        inputs: torch.Tensor,
+        streamer: Optional[TextStreamer],
+        device_ids: List[int] = None,
+        runtime_ai100: bool = True,
+    ) -> Union[torch.Tensor, np.ndarray]:
+        """
+        This method generates output by executing PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
+        ``Mandatory`` Args:
+            :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
+        ``Optional`` Args:
+            :device_ids (List[int]): Ids of devices for running the qpc. Pass as [0] in case of a normal model and as [0, 1, 2, 3] in case of a tensor-sliced model.
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtime is supported as of now. Defaults to ``True`` for ``AI_100`` runtime.
+        Returns:
+            :Union[torch.Tensor, np.ndarray]: Output from the ``AI_100`` or ``PyTorch`` runtime.
+        """
+        # AI_100 runtime
+        if runtime_ai100:
+            if not isinstance(self.qpc_path, Path):
+                raise TypeError("Please run compile API first!")
+
+            return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids)
+        # PyTorch runtime
+        else:
+            return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer)
+
+    # TODO: Add the code based on how we did in single inference script
+    def cloud_ai_100_vlm_generate(
+        self,
+        inputs: torch.Tensor,
+        device_ids: List[int] = [0],
+    ) -> np.ndarray:
+        """
+        Generates output from the provided multi-modal inputs using the AI 100 runtime.
+
+        ``Mandatory`` Args:
+            :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution.
+ ``Optional`` Args: + device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0]. + + Returns: + np.ndarray: A list of dictionaries containing the generated output features. + """ + + if self.qpc_session is None: + self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) + self.batch_size = self.qpc_session.bindings[0].dims[0] + self.seq_len = self.qpc_session.bindings[0].dims[1] + # Skip inputs/outputs + self.qpc_session.skip_buffers( + [x for x in self.qpc_session.input_names + self.qpc_session.output_names if x.startswith("past_")] + + ["pixel_values_RetainedState", "image_sizes_RetainedState"] + ) + + # Read prompt and ctx len from session + # batch_size = max( + # [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] + # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] + # ) + + # prefill_seq_len = max( + # [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] + # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] + # ) + # Prepare input + input_ids_len = inputs["input_ids"].shape[1] + input_ids = np.array( + torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + ) + attention_mask = np.array( + torch.nn.functional.pad( + inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0 + ) + ) + + inputs = dict(input_ids=input_ids, attention_mask=attention_mask) + + outputs = { + "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( + np.float32 + ), + } + self.qpc_session.set_buffers(outputs) + outputs = self.qpc_session.run(inputs) + outputs = outputs["output"][:, :input_ids_len, :] + return outputs + + def pytorch_vlm_generate( + self, + model, + inputs: Union[torch.Tensor, np.ndarray], + streamer: TextStreamer, + ) -> List[torch.Tensor]: + """ + Generates features from a list of text prompts using a PyTorch model. + + ``Mandatory`` Args: + :model: The transformed PyTorch model used for generating features. + :inputs (Union[torch.Tensor, np.ndarray]): inputs to run the execution. + :streamer (TextStreamer): A TextStreamer object used for streaming the generated text. + + Returns: + torch.Tensor: A list of output features generated by the model for each prompt. 
+ """ + # inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + # inputs["past_key_values"] = [] + # for _ in range(model.config.num_hidden_layers): + # inputs["past_key_values"].append(( + # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len,self.head_dim), + # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len, self.head_dim), + # )) + self.batch_size = inputs["input_ids"].shape[0] + generation_len = self.ctx_len - inputs["input_ids"].shape[1] + generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id) + + outputs = model(**inputs) + + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 + streamer.put(inputs["input_ids"]) + + for _ in range(generation_len): + outputs = model(**inputs) + inputs["input_ids"] = outputs[0].argmax(2) + inputs["position_ids"] += 1 + streamer.put(inputs["input_ids"]) + generated_ids[:, _] = inputs["input_ids"].squeeze(1) + generated_texts = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + for i in range(self.batch_size): + print(i, generated_texts[i]) + + return generated_ids + + class QEFFAutoModel(QEFFTransformersBase): """ The QEFFAutoModel class is designed for manipulating any transformer model from the HuggingFace hub. diff --git a/QEfficient/transformers/models/phi3_vision/__init__.py b/QEfficient/transformers/models/phi3_vision/__init__.py new file mode 100644 index 000000000..d259e435a --- /dev/null +++ b/QEfficient/transformers/models/phi3_vision/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/phi3_vision/configuration_phi3_v.py b/QEfficient/transformers/models/phi3_vision/configuration_phi3_v.py new file mode 100644 index 000000000..4bba14069 --- /dev/null +++ b/QEfficient/transformers/models/phi3_vision/configuration_phi3_v.py @@ -0,0 +1,215 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Phi-3-V model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +PHI3V_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/Phi-3-vision-128k-instruct": "https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/resolve/main/config.json", + "microsoft/Phi-3.5-vision-instruct": "https://huggingface.co/microsoft/Phi-3.5-vision-instruct/resolve/main/config.json", +} + + +class Phi3VConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3VModel`]. 
It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3-V model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3VModel`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. 
The `type` must be either `su` or `yarn` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + embd_layer (`str`, *optional*, defaults to `"default"`): + The embedding layer to use. Can be either `"default"` or `"image"`. "default" uses the standard embedding for text. + + Example: + + ```python + >>> from transformers import Phi3VModel, Phi3VConfig + + >>> # Initializing a Phi-3-V style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-vision-128k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3VModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "phi3_v" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + embd_layer: str = "default", + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.sliding_window = sliding_window + self.embd_layer = embd_layer + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/QEfficient/transformers/models/phi3_vision/modeling_phi3_v.py b/QEfficient/transformers/models/phi3_vision/modeling_phi3_v.py new file mode 100644 index 000000000..2f70ca861 --- /dev/null +++ b/QEfficient/transformers/models/phi3_vision/modeling_phi3_v.py @@ -0,0 +1,1938 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PyTorch Phi-3-V model.""" + +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_phi3_v import Phi3VConfig + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError: + pass + +import torch +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig +from transformers.models.clip.modeling_clip import CLIPAttention + +logger = logging.get_logger(__name__) + + +MAX_INPUT_ID = int(1e9) + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( + attention_dropout=0.0, + dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + _attn_implementation="eager", +) + + +class CLIPAttentionFA2(CLIPAttention): + """Add flash attention 2 to CLIPAttention. 
(This is only used in the vision encoder)""" + + def forward( + self, + hidden_states, + attention_mask=None, + causal_attention_mask=None, + output_attentions=False, + ): + """Input shape: Batch x Time x Channel""" + + assert attention_mask is None, "CLIPAttentionFA2 does not support attention_mask" + assert causal_attention_mask is None, "CLIPAttentionFA2 does not support causal_attention_mask" + assert output_attentions is False, "CLIPAttentionFA2 does not support output_attentions" + + bsz, tgt_len, embed_dim = hidden_states.size() + query_states = self.q_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim) + key_states = self.k_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = self.v_proj(hidden_states).reshape(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout_p=self.dropout if self.training else 0.0, + softmax_scale=self.scale, + causal=False, + ).reshape(bsz, tgt_len, embed_dim) + + attn_output = self.out_proj(attn_output) + return attn_output, None + + +class Phi3ImageEmbedding(nn.Module): + """Phi3 Image embedding.""" + + def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = config.embd_pdrop if hasattr(config, "embd_pdrop") else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + self.wte = wte + + if isinstance(config.img_processor, dict) and config.img_processor.get("name", None) == "clip_vision_model": + assert "model_name" in config.img_processor, "model_name must be provided for CLIPVisionModel" + assert "image_dim_out" in config.img_processor, "image_dim_out must be provided for CLIPVisionModel" + assert "num_img_tokens" in config.img_processor, "num_img_tokens must be provided for CLIPVisionModel" + assert config.img_processor["model_name"] == "openai/clip-vit-large-patch14-336" + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config) + image_dim_out = config.img_processor["image_dim_out"] + self.num_img_tokens = config.img_processor["num_img_tokens"] + + # FA2 in CLIP + if config._attn_implementation == "flash_attention_2": + for layer in self.img_processor.vision_model.encoder.layers: + clip_fa2 = CLIPAttentionFA2(clip_config) + del layer.self_attn + layer.self_attn = clip_fa2 + else: + raise NotImplementedError(f"img_processor = {config.img_processor}, not implemented") + + self.image_dim_out = image_dim_out + self.img_sizes = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = kwargs.get("use_hd_transform", False) + self.with_learnable_separator = kwargs.get("with_learnable_separator", False) + self.hd_transform_order = kwargs.get("hd_transform_order", "glb_sub") + # with_hd_transform and with_learnable_separator should have same value + assert ( + self.use_hd_transform == self.with_learnable_separator + ), "use_hd_transform and with_learnable_separator should have same value" + if self.with_learnable_separator: + assert self.use_hd_transform, "learnable separator is only for hd transform" + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, 
self.image_dim_out * 4])) + logger.info(f"learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}") + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.img_projection = nn.Linear(image_dim_out, hidden_size) + elif projection_cls == "mlp" and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == "mlp": + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError(f"projection_cls = {projection_cls}, not implemented") + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get("layer_idx", -2) + self.type_feature = config.img_processor.get("type_feature", "patch") + else: + self.layer_idx = -2 + self.type_feature = "patch" + + def set_img_features(self, img_features: torch.FloatTensor) -> None: + self.img_features = img_features + + def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: + self.img_sizes = img_sizes + + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + raise NotImplementedError + + def forward( + self, input_ids: torch.LongTensor, pixel_values: torch.FloatTensor, image_sizes=None + ) -> torch.FloatTensor: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + # positions for image tokens + positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True) + has_image = len(positions[0].tolist()) > 0 + input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach() + hidden_states = self.wte(input_ids) + + if has_image: + assert self.use_hd_transform + num_images, num_crops, c, h, w = pixel_values.shape + assert c == 3 and h == w == 336 + img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape( + num_images, num_crops, -1, self.image_dim_out + ) + image_features_proj = self.hd_feature_transform(img_features, image_sizes) + hidden_states = hidden_states.index_put(positions, image_features_proj, accumulate=False) + + if self.drop is not None: + hidden_states = self.drop(hidden_states) + + return hidden_states + + def hd_feature_transform(self, image_features, image_sizes): + """ + image_features: (num_images, num_crops+1, 24*24, 1024) + """ + assert self.hd_transform_order == "sub_glb", f"hd_transform_order `{self.hd_transform_order}` not implemented" + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + global_image_features = image_features[:, 0] # (num_images, 24*24, 1024) + # global feature 
can be viewed as a special HD case with num_crops 1x1 + global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1) + global_image_features_hd_newline = self.add_image_newline(global_image_features_hd) + + all_image_embeddings = [] + # need a for loop to process each image because of different image sizes + # (patch arrangement is different for each image) + for i, img_size in enumerate(image_sizes): + h, w = img_size + h_crop = h // 336 + w_crop = w // 336 + num_crops = h_crop * w_crop + + # NOTE: real num_crops is padded + # (num_crops, 24*24, 1024) + sub_image_features = image_features[i, 1 : 1 + num_crops] + sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop) + sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd) + + # [sub features, separator, global features] + all_image_embeddings.extend( + [ + sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ] + ) + + image_features_proj = self.img_projection( + torch.cat(all_image_embeddings, dim=0).to(target_device).to(target_dtype) + ) + + return image_features_proj + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + + # alternative implementation using einops + # from einops import rearrange + # image_features_nhwc = rearrange( + # image_features, + # 'N (H W) c -> N H W c', + # H=H, + # W=H, + # ) + # image_features_2x2merge = rearrange( + # image_features_nhwc, + # 'N (h h_pool) (w w_pool) c -> N h w (h_pool w_pool c)', + # h_pool=2, + # w_pool=2, + # ) + # image_features_hd = rearrange( + # image_features_2x2merge, + # '(n_img h_crop w_crop) h w C -> n_img (h_crop h) (w_crop w) C', + # h_crop=h_crop, + # w_crop=w_crop, + # ) + + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape( + num_images, -1, hid_dim + ) + return image_features_hd_newline + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-vision-128k-instruct" +_CONFIG_FOR_DOC = "Phi3VConfig" + +PHI3V_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/Phi-3-vision-128k-instruct", + # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 +] + + +# Copied from 
transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = seq_len or torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force 
float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = 0.1 * math.log(scale) + 1.0 + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3VConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "su": + self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config) + elif scaling_type == "yarn": + self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Phi3FlashAttention2(Phi3Attention): + """ + Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
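A minimal standalone sketch of the alignment difference described in the comments above, using hypothetical q_len/kv_len (illustrative only, not part of the model code):

```python
import torch

# Hypothetical decode-time shapes: 2 new queries attend over 4 keys (2 cached + 2 new).
q_len, kv_len = 2, 4
bottom_right = torch.ones(q_len, kv_len, dtype=torch.bool).tril(diagonal=kv_len - q_len)
top_left = torch.ones(q_len, kv_len, dtype=torch.bool).tril()
print(bottom_right.int())  # [[1, 1, 1, 0], [1, 1, 1, 1]] -> queries sit at the newest positions
print(top_left.int())      # [[1, 0, 0, 0], [1, 1, 0, 0]] -> wrong once a KV cache is in play
```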
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Phi3FlashAttention2 attention does not support output_attentions + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library." + ) + raise ValueError("The current flash attention version does not support sliding window attention.") + + output_attentions = False + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. 
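The rotary embedding applied a couple of lines below rotates each (i, i + head_dim/2) channel pair of the queries and keys by a position-dependent angle. A standalone sketch with toy sizes (base 10000 and all shapes here are illustrative), showing that the rotation used by `apply_rotary_pos_emb` leaves vector norms unchanged:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 8, 4                                   # toy sizes
q = torch.randn(1, 1, seq_len, head_dim)                   # (batch, heads, seq, head_dim)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)                    # (seq, head_dim)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]    # broadcast over batch and heads
q_rot = q * cos + rotate_half(q) * sin                     # same formula as apply_rotary_pos_emb
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```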
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.qkv_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=attn_dropout, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
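When the batch contains padding, the branch below unpads the tensors and hands flash-attention the packed sequences together with cumulative sequence lengths. A standalone sketch of how those quantities are derived from a padding mask (mirroring what the `_get_unpad_data` helper used by `_upad_input` computes; the lengths here are hypothetical):

```python
import torch
import torch.nn.functional as F

# Hypothetical left-padded batch: true lengths 2 and 4, padded to 4.
attention_mask = torch.tensor([[0, 0, 1, 1],
                               [1, 1, 1, 1]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())
print(indices)      # tensor([2, 3, 4, 5, 6, 7]) -> flat positions of the real tokens
print(cu_seqlens)   # tensor([0, 2, 6], dtype=torch.int32) -> sequence boundaries after packing
print(max_seqlen)   # 4
```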
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
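A quick standalone illustration of why that slice relies on left padding (the masks below are hypothetical):

```python
import torch

query_length = 2
left_padded  = torch.tensor([[0, 0, 1, 1, 1]])   # pads first, newest tokens last
right_padded = torch.tensor([[1, 1, 1, 0, 0]])   # real tokens first, pads last
print(left_padded[:, -query_length:])    # tensor([[1, 1]]) -> lines up with the query tokens
print(right_padded[:, -query_length:])   # tensor([[0, 0]]) -> would mark real queries as padding
```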
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO @Arthur no longer copied from LLama after static cache +class Phi3SdpaAttention(Phi3Attention): + """ + Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Phi3Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with 
custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +PHI3_ATTENTION_CLASSES = { + "eager": Phi3Attention, + "flash_attention_2": Phi3FlashAttention2, + "sdpa": Phi3SdpaAttention, +} + + +class Phi3DecoderLayer(nn.Module): + def __init__(self, config: Phi3VConfig, layer_idx: int): + super().__init__() + + self.config = config + self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHI3V_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3VConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.", + PHI3V_START_DOCSTRING, +) +class Phi3VPreTrainedModel(PreTrainedModel): + config_class = Phi3VConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Phi3DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3V_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. 
+ + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`]. + See [`Phi3ImageProcessor.__call__`] for details. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): + The sizes of the images in the batch, being (height, width) for each image. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.", + PHI3V_START_DOCSTRING, +) +class Phi3VModel(Phi3VPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3VConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + + self.vision_embed_tokens = None + if isinstance(config.embd_layer, dict): + # vision embedding layer + embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"], **config.embd_layer} + self.vision_embed_tokens = Phi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config) + # # set wte the same for vision embedding + # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight + + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + if pixel_values is not None and image_sizes is not None: + assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined" + inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes) + else: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Phi3VForCausalLM(Phi3VPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from 
transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 + def __init__(self, config): + super().__init__(config) + self.model = Phi3VModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + **kwargs, + ): + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = ( + past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2] + ) + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. 
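A small standalone illustration of this second case, handled by the branch just below (token ids and lengths are hypothetical):

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])   # hypothetical token ids
past_length = 5                                        # five of them are already cached
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]
print(input_ids)   # tensor([[16]]) -> only the unprocessed token is fed through the model
```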
+ elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The [`Phi3VModel`] with a sequence classification head on top (linear layer). + + [`Phi3VForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3V_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs +class Phi3VForSequenceClassification(Phi3VPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3VModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +@add_start_docstrings( + """ + [`Phi3VModel`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. 
+ """, + PHI3V_START_DOCSTRING, +) +# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs +class Phi3VForTokenClassification(Phi3VPreTrainedModel): + def __init__(self, config: Phi3VConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = Phi3VModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) diff --git a/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py b/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py new file mode 100644 index 000000000..46e8cd61e --- /dev/null +++ b/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py @@ -0,0 +1,1771 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PyTorch Phi-3-V model.""" + +import math +from typing import List, Optional, Tuple, Union + +# from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +# try: +# from flash_attn import flash_attn_func, flash_attn_varlen_func +# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +# _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +# except ImportError: +# pass +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.models.clip.modeling_clip import CLIPAttention +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +from .configuration_phi3_v import Phi3VConfig + +logger = logging.get_logger(__name__) + + +def _create_causal_mask( + position_ids, + target_length, + sliding_window: Optional[int] = None, +): + """ + A utility attention mask class that allows one to: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + """ + if sliding_window is not None: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, -1) + # --- Rolling buffer --- + pos_max = position_ids.max(1, keepdim=True).values + kv_start = (pos_max // target_length) * target_length + kv_indices_high = kv_indices + kv_start + kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length) + kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high) + kv_indices = kv_indices.unsqueeze(1) + # ------ + causal_mask = kv_indices > query_indices + attention_mask = causal_mask + + window_indices = query_indices - sliding_window + 1 + window_mask = kv_indices < window_indices + attention_mask = attention_mask | window_mask + attention_mask = attention_mask.unsqueeze(1) + else: + query_indices = position_ids.unsqueeze(-1) + kv_indices = torch.arange(target_length).view(1, 1, -1) + attention_mask = kv_indices > query_indices + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +MAX_INPUT_ID = int(1e9) + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( + attention_dropout=0.0, + dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + initializer_factor=1.0, + initializer_range=0.02, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + _attn_implementation="eager", +) + + +def reshape_hd_patches_2x2merge(image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = torch.tensor(N // (h_crop * w_crop), dtype=torch.int64) + H = torch.tensor(int(L**0.5), dtype=torch.int64) + H_div_2 = torch.tensor(H // 2, 
dtype=torch.int64) + + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H_div_2, 2, H_div_2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H_div_2, H_div_2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H_div_2, w_crop * H_div_2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + + return image_features_hd + + +def add_image_newline(image_features_hd, sub_GN): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape( + num_images, -1, hid_dim + ) + return image_features_hd_newline + + +# TODO: vbaddi: comment the tracing part for single image +# @torch.jit.script_if_tracing +def get_image_embeddings(image_dim_out, image_sizes, image_features, global_image_features_hd_newline): + """ + Get image embeddings for all images. + Need a for loop to process each image because of different image sizes + (patch arrangement is different for each image) + """ + + glb_GN = torch.zeros(1, 1, image_dim_out * 4).to(image_features.device) + sub_GN = torch.zeros(1, 1, 1, image_dim_out * 4).to(image_features.device) + + all_image_embeddings = torch.empty(0, 4096).to(image_features.device) + + # Assuming img_size is (672, 672) + h, w = image_sizes.squeeze().tolist() + # print(f"h is {h} and w is {w} and type of image_size is {type(image_sizes)} and image_size0 is {image_sizes[0]}") + h_crop = torch.tensor(h // 336, dtype=torch.int64) + w_crop = torch.tensor(w // 336, dtype=torch.int64) + num_crops = h_crop * w_crop + + # NOTE: real num_crops is padded + # (num_crops, 24*24, 1024) + sub_image_features = image_features[0, 1 : 1 + num_crops] + sub_image_features_hd = reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop) + sub_image_features_hd_newline = add_image_newline(sub_image_features_hd, sub_GN) + + # [sub features, separator, global features] + all_image_embeddings = torch.cat( + [ + sub_image_features_hd_newline.view(-1, 4096), # (h_crop*12*(w_crop*12+1), 4096) + glb_GN.view(-1, 4096), + global_image_features_hd_newline[0], + ] + ) + + return all_image_embeddings + + +# class Phi3Embedding(nn.Module): +# """Phi3 embedding for text-only and vision + text.""" +# def __init__(self, wte, vocab_size): +# super().__init__() +# self.wte = wte +# self.vocab_size = vocab_size + +# # def forward(self, input_ids: torch.LongTensor, image_features: torch.FloatTensor, positions: torch.LongTensor) -> torch.FloatTensor: +# def forward(self, input_ids: torch.LongTensor, image_features: torch.FloatTensor) -> torch.FloatTensor: +# # TODO: vbaddi: comment the script part, precompute and send positions to avoid nonzero +# # breakpoint() +# # input_ids, positions = clamp_input_ids(input_ids, image_features, self.vocab_size) +# # hidden_states = self.wte(input_ids) +# # hidden_states = select_logic(hidden_states, image_features, positions) +# # return hidden_states +# # # breakpoint() +# mask = input_ids < 0 +# indices1 = mask.to(torch.int64).cumsum(1) - 1 +# indices0 = torch.arange(mask.shape[0]).view(-1, 1) 
+# input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach() +# inputs_embeds = self.wte(input_ids) +# image_features = image_features.unsqueeze(0) +# image_features_expanded = image_features[indices0, indices1] +# # breakpoint() +# image_inputs_embeds = torch.where( +# mask.unsqueeze(-1), image_features_expanded, inputs_embeds +# ) + +# # *where to skip image encoder for decode* +# inputs_embeds = torch.where( +# input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds +# ) +# return inputs_embeds + + +class QEffPhi3ImageEmbedding(nn.Module): + """Phi3 Image embedding.""" + + def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: + super().__init__() + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = config.embd_pdrop if hasattr(config, "embd_pdrop") else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + self.wte = wte + + if isinstance(config.img_processor, dict) and config.img_processor.get("name", None) == "clip_vision_model": + assert "model_name" in config.img_processor, "model_name must be provided for CLIPVisionModel" + assert "image_dim_out" in config.img_processor, "image_dim_out must be provided for CLIPVisionModel" + assert "num_img_tokens" in config.img_processor, "num_img_tokens must be provided for CLIPVisionModel" + assert config.img_processor["model_name"] == "openai/clip-vit-large-patch14-336" + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + self.img_processor = CLIPVisionModel(clip_config) + image_dim_out = config.img_processor["image_dim_out"] + self.num_img_tokens = config.img_processor["num_img_tokens"] + + # FA2 in CLIP + # if config._attn_implementation == "flash_attention_2": + # for layer in self.img_processor.vision_model.encoder.layers: + # clip_fa2 = CLIPAttentionFA2(clip_config) + # del layer.self_attn + # layer.self_attn = clip_fa2 + # Testing to manually add Eager Attention to clipVisionModel + + else: + raise NotImplementedError(f"img_processor = {config.img_processor}, not implemented") + + self.image_dim_out = image_dim_out + self.img_sizes = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = kwargs.get("use_hd_transform", False) + self.with_learnable_separator = kwargs.get("with_learnable_separator", False) + self.hd_transform_order = kwargs.get("hd_transform_order", "glb_sub") + # with_hd_transform and with_learnable_separator should have same value + assert ( + self.use_hd_transform == self.with_learnable_separator + ), "use_hd_transform and with_learnable_separator should have same value" + if self.with_learnable_separator: + assert self.use_hd_transform, "learnable separator is only for hd transform" + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * 4])) + logger.info(f"learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}") + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.img_projection = nn.Linear(image_dim_out, hidden_size) + elif projection_cls == "mlp" and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), 
nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == "mlp": + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError(f"projection_cls = {projection_cls}, not implemented") + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get("layer_idx", -2) + self.type_feature = config.img_processor.get("type_feature", "patch") + else: + self.layer_idx = -2 + self.type_feature = "patch" + self.__qeff_init__() + + def __qeff_init__(self): + # breakpoint() + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + for layer in self.img_processor.vision_model.encoder.layers: + clip_eager = CLIPAttention(clip_config) + del layer.self_attn + layer.self_attn = clip_eager + + def set_img_features(self, img_features: torch.FloatTensor) -> None: + self.img_features = img_features + + def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: + self.img_sizes = img_sizes + + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + raise NotImplementedError + + def forward(self, pixel_values: torch.FloatTensor, image_sizes=None) -> torch.FloatTensor: + assert self.use_hd_transform + # breakpoint() + + num_images, num_crops, c, h, w = pixel_values.shape + assert c == 3 and h == w == 336 + img_features = self.get_img_features(pixel_values.flatten(0, 1)).reshape( + num_images, num_crops, -1, self.image_dim_out + ) + image_features_proj = self.hd_feature_transform(img_features, image_sizes) + + return image_features_proj + + def hd_feature_transform(self, image_features, image_sizes): + """ + image_features: (num_images, num_crops+1, 24*24, 1024) + """ + assert self.hd_transform_order == "sub_glb", f"hd_transform_order `{self.hd_transform_order}` not implemented" + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + global_image_features = image_features[:, 0] # (num_images, 24*24, 1024) + # global feature can be viewed as a special HD case with num_crops 1x1 + global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1) + global_image_features_hd_newline = self.add_image_newline(global_image_features_hd) + + all_image_embeddings = get_image_embeddings( + torch.tensor(self.image_dim_out), image_sizes, image_features, global_image_features_hd_newline + ) + image_features_proj = self.img_projection(all_image_embeddings.unsqueeze(0).to(target_device).to(target_dtype)) + return image_features_proj.squeeze() + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops + """ + N, L, C = 
image_features.shape + assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + + # alternative implementation using einops + # from einops import rearrange + # image_features_nhwc = rearrange( + # image_features, + # 'N (H W) c -> N H W c', + # H=H, + # W=H, + # ) + # image_features_2x2merge = rearrange( + # image_features_nhwc, + # 'N (h h_pool) (w w_pool) c -> N h w (h_pool w_pool c)', + # h_pool=2, + # w_pool=2, + # ) + # image_features_hd = rearrange( + # image_features_2x2merge, + # '(n_img h_crop w_crop) h w C -> n_img (h_crop h) (w_crop w) C', + # h_crop=h_crop, + # w_crop=w_crop, + # ) + + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape( + num_images, -1, hid_dim + ) + return image_features_hd_newline + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-vision-128k-instruct" +_CONFIG_FOR_DOC = "Phi3VConfig" + +PHI3V_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/Phi-3-vision-128k-instruct", + # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 +] + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # import pdb; pdb.set_trace() + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + 
self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding): + """ + Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super(Phi3RotaryEmbedding, self).__init__() # Initialize nn.Module + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = seq_len or torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # breakpoint() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3VConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "su": + self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config) + # elif scaling_type == "yarn": + # self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # breakpoint() + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Update the cache_kwargs with position_ids for Cloud AI 100 + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class QEffPhi3VAttention(Phi3Attention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + """ + Copied from Phi3Attention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + """ + + def __init__(self, config: Phi3VConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + # breakpoint() + self.rotary_emb = QEffPhi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + 
bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # breakpoint() + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Update the cache_kwargs with position_ids for Cloud AI 100 + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + # breakpoint() + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + # Added by Illango for the kv_cache thing + kv_seq_len = key_states.shape[-2] + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + # breakpoint() + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +PHI3_ATTENTION_CLASSES = { + "eager": QEffPhi3VAttention, +} + + +class QEffPhi3VDecoderLayer(nn.Module): + def __init__(self, config: Phi3VConfig, layer_idx: int): + super().__init__() + + self.config = config + self.self_attn = 
PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + # breakpoint() + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHI3V_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3VConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.", + PHI3V_START_DOCSTRING, +) +class Phi3VPreTrainedModel(PreTrainedModel): + config_class = Phi3VConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + # _no_split_modules = ["QEffPhi3VDecoderLayer"] + _no_split_modules = ["QEffPhi3VDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3V_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
+ + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using [`AutoImageProcessor`]. + See [`Phi3ImageProcessor.__call__`] for details. + image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): + The sizes of the images in the batch, being (height, width) for each image. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Phi-3-V model outputting raw hidden-states without any specific head on top.", + PHI3V_START_DOCSTRING, +) +class Phi3VModel(Phi3VPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`QEffPhi3VDecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3VConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + # self.combined_embed = Phi3Embedding(self.embed_tokens, config.vocab_size) + + self.vision_embed_tokens = None + if isinstance(config.embd_layer, dict): + # vision embedding layer + embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"], **config.embd_layer} + # breakpoint() + self.vision_embed_tokens = QEffPhi3ImageEmbedding(config, wte=self.embed_tokens, **embedding_config) + # # set wte the same for vision embedding + # self.vision_embed_tokens.wte.weight = self.embed_tokens.weight + + self.layers = nn.ModuleList( + [QEffPhi3VDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + # [QEffPhi3VDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + # cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + if pixel_values is not None and image_sizes is not None: + assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined" + # inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes) + img_feature_proj = self.vision_embed_tokens(pixel_values=pixel_values, image_sizes=image_sizes) + # # TODO : call combined_embedding here, pass `input_embeds`, `input_ids`, `positions` to the object and get mixed embeddings. + inputs_embeds = self.combined_embed(input_ids, img_feature_proj) + else: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + 
return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class QEffPhi3VModel(Phi3VModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QEffPhi3VDecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3VConfig): + super().__init__(config) + + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # breakpoint() + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # past_key_values_length = past_key_values.get_usable_length(seq_length) + past_key_values.get_usable_length(seq_length) + + if inputs_embeds is None: + if pixel_values is not None and image_sizes is not None: + assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined" + # inputs_embeds = self.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes) + img_feature_proj = self.vision_embed_tokens(pixel_values=pixel_values, image_sizes=image_sizes) + # # TODO : call combined_embedding here, pass `input_embeds`, `input_ids`, `positions` to the object and get mixed embeddings. 
+ # inputs_embeds = self.combined_embed(input_ids,img_feature_proj) + mask = input_ids < 0 + indices1 = mask.to(torch.int64).cumsum(1) - 1 + indices0 = torch.arange(mask.shape[0]).view(-1, 1) + input_ids = input_ids.clamp_min(0).clamp_max(self.vocab_size).detach() + inputs_embeds = self.embed_tokens(input_ids) + image_features = img_feature_proj.unsqueeze(0) + image_features_expanded = image_features[indices0, indices1] + image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds) + # *where to skip image encoder for decode* + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds) + else: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + + # Added from QEff + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + elif batch_index is not None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache 
else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class QEffPhi3VForCausalLM(Phi3VPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 + def __init__(self, config): + super().__init__(config) + + self.model = QEffPhi3VModel(config) + + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHI3V_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + # breakpoint() + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithPast( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + outputs = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + return outputs.logits, pixel_values, image_sizes, outputs.past_key_values + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + **kwargs, + ): + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. 
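+        # With `rope_scaling` enabled, Phi-3 applies its short RoPE factors up to
+        # `original_max_position_embeddings` and switches to the long factors beyond that
+        # point, so a cache built with the short factors is discarded here and recomputed.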
+ if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = ( + past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2] + ) + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_sizes": image_sizes, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +# class Phi3VModelWrapper(torch.nn.Module): +# def __init__(self, model): +# super().__init__() +# self.model = model + +# def forward(self, input_ids, position_ids, pixel_values, image_sizes, past_key_values): +# #breakpoint() +# outputs = self.model( +# input_ids=input_ids, +# position_ids=position_ids, +# past_key_values=past_key_values, +# pixel_values=pixel_values, +# image_sizes=image_sizes, +# ) +# # breakpoint() +# return outputs.logits, pixel_values, image_sizes, outputs.past_key_values diff --git a/QEfficient/transformers/models/pytorch_transforms.py 
b/QEfficient/transformers/models/pytorch_transforms.py index 6b8d00689..8325b8ed1 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -188,6 +188,14 @@ QEffPhi3ForCausalLM, QEffPhi3Model, ) +from QEfficient.transformers.models.phi3_vision.modeling_phi3_vision import ( + QEffPhi3ImageEmbedding, + QEffPhi3RotaryEmbedding, + QEffPhi3VAttention, + QEffPhi3VDecoderLayer, + QEffPhi3VForCausalLM, + QEffPhi3VModel, +) from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( QEffQwen2Attention, QEffQwen2DecoderLayer, @@ -301,6 +309,13 @@ class KVCacheTransform(ModuleMappingTransform): GPTBigCodeBlock: QEffGPTBigCodeBlock, GPTBigCodeModel: QEffGPTBigCodeModel, GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM, + # Phi3-vision + "Phi3VModel": QEffPhi3VModel, + "Phi3Attention": QEffPhi3VAttention, + "Phi3RotaryEmbedding": QEffPhi3RotaryEmbedding, + "Phi3VForCausalLM": QEffPhi3VForCausalLM, + "Phi3ImageEmbedding": QEffPhi3ImageEmbedding, + "Phi3DecoderLayer": QEffPhi3VDecoderLayer, } @classmethod diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index ab861a788..ea8de3722 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -45,18 +45,18 @@ def get_models_dir(): QEFF_MODELS_DIR = get_models_dir() ONNX_EXPORT_EXAMPLE_BATCH_SIZE = 1 -ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 +ONNX_EXPORT_EXAMPLE_SEQ_LEN = 1024 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw", "-aic-hw-version=2.0"] class Constants: # Export Constants. - SEQ_LEN = 32 - CTX_LEN = 32 + SEQ_LEN = 1024 + CTX_LEN = 1280 PROMPT_LEN = 8 INPUT_STR = ["My name is"] GB = 2**30 diff --git a/tests/transformers/models/test_vlm_models.py b/tests/transformers/models/test_vlm_models.py new file mode 100755 index 000000000..f8b496d0f --- /dev/null +++ b/tests/transformers/models/test_vlm_models.py @@ -0,0 +1,183 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + + +import pytest +import requests +import torch +from PIL import Image +from transformers import AutoModelForCausalLM, AutoProcessor, TextStreamer + +# from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText + +# from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.utils import hf_download + +# from QEfficient.utils._utils import load_hf_processor +from QEfficient.utils.constants import Constants +from QEfficient.utils.device_utils import get_available_device_id + +# from QEfficient.utils.run_utils import ApiRunner + +test_models = [ + # "liuhaotian/llava-v1.5-7b", + "microsoft/Phi-3.5-vision-instruct", +] + + +def load_vlm_model(model_config): + """ + Function to load model from huggingface and transform to KV model + -------- + + :model_config: Dict + + :return model_hf, params + """ + model_path = hf_download( + repo_id=model_config["model_name"], + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + num_hidden_layers=model_config["n_layer"], + _attn_implementation="eager", + low_cpu_mem_usage=False, + ) # Run models for single layers only + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def _generate_inputs(model, processor): + ## PREPROCESSING THE MULTI-MODAL INPUTS + images = [] + placeholder = "" + + # Note: if OOM, you might consider reduce number of frames in this example. + for i in range(1, 2): + url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" + images.append(Image.open(requests.get(url, stream=True).raw)) + placeholder += f"<|image_{1}|>\n" + + messages = [ + {"role": "user", "content": placeholder + "Summarize the deck of slides."}, + ] + + prompt = processor.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + # padding="max_length", + # max_length=seq_len + ) + head_dim = model.config.hidden_size // model.config.num_attention_heads + ctx_len = 1280 # FIXME: Pass it a ssome arguement later on + inputs = dict(processor(images=images, text=prompt, return_tensors="pt")) + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + inputs["past_key_values"] = [] + for i in range(model.config.num_hidden_layers): + inputs["past_key_values"].append( + ( + torch.zeros(1, model.config.num_key_value_heads, ctx_len, head_dim), + torch.zeros(1, model.config.num_key_value_heads, ctx_len, head_dim), + ) + ) + return inputs + + +def check_vlm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = Constants.CTX_LEN, + n_layer: int = 1, + # num_speculative_tokens: Optional[int] = None, +): + """ + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``Phi-3.5-vision-instruct`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. 
+ """ + # replace_transformers_quantizers() + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_vlm_model(model_config) + # Load processor instead + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + # config = model_hf.config + # batch_size = len(Constants.INPUT_STR) + streamer = TextStreamer(processor) + # Testing for Phi-3.5 only atm + inputs = _generate_inputs(model_hf, processor) + breakpoint() + # Original PyTorch model + pt_model = AutoModelForCausalLM.from_pretrained( + model_name, + num_hidden_layers=n_layer, + _attn_implementation="eager", + trust_remote_code=True, + # Check if this works + rope_scaling=None, + ) + # TODO: Don't use API RUNNER CLASS BUT directly define those functions here + # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + # is_tlm = False if num_speculative_tokens is None else True + qeff_model = QEFFAutoModelForImageTextToText(pt_model, processor, is_tlm=False) + + # pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + # assert ( + # pytorch_hf_tokens == pytorch_kv_tokens + # ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + + # onnx_model_path = qeff_model.export() + qeff_model.export() + # ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) + + # assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + _ = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + # num_speculative_tokens=num_speculative_tokens, + ) + # exec_info = qeff_model.generate(inputs, streamer, device_ids=[0], runtime_ai100=True) + qeff_model.generate(inputs, streamer, device_ids=[0], runtime_ai100=False) + # cloud_ai_100_tokens = exec_info[0] # Because we always run for single input and single batch size + # gen_len = ort_tokens.shape[-1] + # assert ( + # ort_tokens == cloud_ai_100_tokens[:, :gen_len] + # ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output." + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_name", test_models) +def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. 
+ ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + if model_name == "microsoft/Phi-3-mini-4k-instruct": + n_layer = 2 # test only 2 layer models + else: + n_layer = 32 + + check_vlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer) From b97cd22783a8710126fd105e7a1654bdf66c570d Mon Sep 17 00:00:00 2001 From: Dipankar Sarkar Date: Wed, 29 Jan 2025 04:46:11 +0000 Subject: [PATCH 2/2] Adding VLM Pipeline to Qeff Signed-off-by: Dipankar Sarkar --- .../transformers/models/modeling_auto.py | 163 +++++++++--------- .../phi3_vision/modeling_phi3_vision.py | 71 ++++---- .../transformers/models/test_vlm_single.py | 142 +++++++++++++++ QEfficient/utils/constants.py | 13 +- tests/transformers/models/test_vlm_models.py | 57 +----- 5 files changed, 278 insertions(+), 168 deletions(-) create mode 100755 QEfficient/transformers/models/test_vlm_single.py diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 31362aa36..e716bfb1a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -30,12 +30,11 @@ from QEfficient.base.modeling_qeff import QEFFBaseModel from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.text_generation_inference import write_io_files from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform, SpDTransform from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, with_replaced_quantizers from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform from QEfficient.utils import constants, get_padding_shape_from_config - -# from QEfficient.transformers.models.phi3_vision.modeling_phi3_vision import Phi3VModelWrapper from QEfficient.utils.cache import to_hashable logger = logging.getLogger(__file__) @@ -435,7 +434,7 @@ def generate( class QEFFAutoModelForImageTextToText(QEFFTransformersBase): """ - The QEFF class is designed for manipulating any causal language model from the HuggingFace hub. + The QEFF class is designed for manipulating any multimodal language model from the HuggingFace hub. Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization. ``Mandatory`` Args: @@ -466,6 +465,7 @@ def __init__( model: nn.Module, processor: AutoProcessor, continuous_batching: bool = False, + qpc_session=None, is_tlm: bool = False, **kwargs, ): @@ -480,11 +480,7 @@ def __init__( if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - # warnings.warn( - # "full_batch_size argument is deprecated. 
Use continuous_batching=True instead.", DeprecationWarning, 2 - # ) - # breakpoint() - # phi3vwrapper_model=Phi3VModelWrapper(model) + super().__init__(model) # Set use_cache=True to get KV values as output during ONNX export @@ -495,8 +491,9 @@ def __init__( self.head_dim = model.config.hidden_size // model.config.num_attention_heads self.continuous_batching = continuous_batching self.is_tlm = is_tlm + self.qpc_session = qpc_session self.pad_token_id = model.config.pad_token_id - self.ctx_len = 1280 + self.ctx_len = constants.VlmConstants.CTX_LEN @classmethod def from_pretrained( @@ -550,11 +547,7 @@ def model_hash(self) -> str: def _generate_inputs(self, **kwargs): bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - # seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - # fbs = constants.ONNX_EXPORT_EXAMPLE_FBS - self.ctx_len = kwargs["ctx_len"] if "ctx_len" in kwargs else self.ctx_len - ## PREPROCESSING THE MULTI-MODAL INPUTS for Phi-3.5 for now # TODO: Create a map for the other models to have their own inputs accordingly images = [] @@ -594,7 +587,6 @@ def _generate_inputs(self, **kwargs): dynamic_axes = { "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, - # "pixel_values": {0: "img_batch_size"}, } for i in range(self.num_layers): dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} @@ -602,7 +594,6 @@ def _generate_inputs(self, **kwargs): # Avoid issues due to index out of range inputs["position_ids"] = torch.full(inputs["position_ids"].shape, self.ctx_len - 1) - return inputs, dynamic_axes, output_names def export( @@ -620,9 +611,7 @@ def export( Returns: :str: Path of the generated ``ONNX`` graph. """ - example_inputs, dynamic_axes, output_names = self._generate_inputs(**kwargs) - # breakpoint() return self._export( example_inputs, output_names, @@ -674,28 +663,7 @@ def compile( Returns: :str: Path of the compiled ``qpc`` package. """ - # if self.is_tlm: - # # assert num_speculative_tokens cfg is acceptable if defined - # if num_speculative_tokens is None: - # raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.") - # if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2: - # ValueError( - # f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}" - # ) - # num_logits_to_keep = num_speculative_tokens + 1 - # if prefill_seq_len < num_logits_to_keep: - # raise ValueError( - # f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})" - # ) - - # if self.continuous_batching and full_batch_size is None: - # raise TypeError("missing required argument: 'full_batch_size'") - - # if kv_cache_batch_size and not full_batch_size: - # raise ValueError( - # "Prefix caching is enabled only for continuous batching. 
Please pass `full_batch_size` argument and make sure you pass `continuous_batching=True` in the `from_pretrained` call" - # ) - + self.seq_len_constant = prefill_seq_len kv_cache_batch_size = ( kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size) ) @@ -759,8 +727,6 @@ def compile( for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - - breakpoint() qpc_path = self._compile( onnx_path, compile_dir, @@ -799,16 +765,17 @@ def generate( if not isinstance(self.qpc_path, Path): raise TypeError("Please run compile API first!") - return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids) + return self.cloud_ai_100_vlm_generate(inputs=inputs, device_ids=device_ids, streamer=streamer) # PyTorch runtime else: return self.pytorch_vlm_generate(model=self.model, inputs=inputs, streamer=streamer) - # TODO: Add the code based on how we did in single inference script def cloud_ai_100_vlm_generate( self, inputs: torch.Tensor, device_ids: List[int] = [0], + streamer: TextStreamer = None, + write_io_dir: Optional[str] = None, ) -> np.ndarray: """ Generates features with list of prompts using AI 100 runtime. @@ -821,7 +788,6 @@ def cloud_ai_100_vlm_generate( Returns: np.ndarray: A list of dictionaries containing the generated output features. """ - if self.qpc_session is None: self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids) self.batch_size = self.qpc_session.bindings[0].dims[0] @@ -833,37 +799,83 @@ def cloud_ai_100_vlm_generate( ) # Read prompt and ctx len from session - # batch_size = max( - # [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] - # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] - # ) - - # prefill_seq_len = max( - # [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] - # + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] - # ) - # Prepare input - input_ids_len = inputs["input_ids"].shape[1] - input_ids = np.array( - torch.nn.functional.pad(inputs["input_ids"], (0, self.seq_len - inputs["input_ids"].size(1)), "constant", 0) + batch_size = max( + [x[self.qpc_session.binding_index_map["input_ids"]][1][0] for x in self.qpc_session.allowed_shapes] + + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[0]] ) - attention_mask = np.array( - torch.nn.functional.pad( - inputs["attention_mask"], (0, self.seq_len - inputs["attention_mask"].size(1)), "constant", 0 - ) + prefill_seq_len = max( + [x[self.qpc_session.binding_index_map["input_ids"]][1][1] for x in self.qpc_session.allowed_shapes] + + [self.qpc_session.bindings[self.qpc_session.binding_index_map["input_ids"]].dims[1]] + ) + input_len = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + # input_len=padded_len + num_chunks = -(padded_len // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + generation_len = self.ctx_len - input_len.max() # in standalone this is tensor + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id) + # 
inputs["input_ids"]=torch.nn.functional.pad(inputs["input_ids"],(0,self.seq_len_constant-inputs["input_ids"].size(1)),"constant",self.pad_token_id) + # inputs["position_ids"]=torch.nn.functional.pad(inputs["position_ids"],(0,self.seq_len_constant-inputs["position_ids"].size(1)),"constant",-1) + input_ids = inputs["input_ids"] + input_ids_size = input_ids.shape[1] + # attention_mask = inputs["attention_mask"] + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (0, 1024 - input_ids_size), "constant", self.processor.tokenizer.pad_token_id + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, 1024 - input_ids_size), "constant", 0 ) + # for k,v in inputs.items(): + # inputs[k]=np.array(v) + # input_len=np.array(input_len).reshape(1,1) + # inputs["pixel_values"]=inputs["pixel_values"].astype("float16") + inputs["input_ids"] = inputs["input_ids"].numpy() + inputs["attention_mask"] = inputs["attention_mask"].numpy() + inputs["pixel_values"] = inputs["pixel_values"].numpy().astype("float16") + inputs["image_sizes"] = inputs["image_sizes"].numpy() + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + outputs = self.qpc_session.run(chunk_inputs) + if write_io_dir: + write_io_files(inputs, outputs, write_io_dir, "prefill", "aic_batch_io", True, False) + # Get first token + inputs["input_ids"] = outputs["logits"].argmax(2) + inputs["position_ids"] = input_len.numpy() + generated_ids[:, 0] = inputs["input_ids"].squeeze(1) + finished_sequences = inputs["input_ids"] == self.processor.tokenizer.eos_token_id + if streamer is not None: + streamer.put(inputs["input_ids"][0]) + self.qpc_session.skip_buffers(["pixel_values"]) + self.qpc_session.skip_buffers(["image_sizes"]) + inputs.pop("pixel_values") + inputs.pop("image_sizes") + + # Decode loop + generation_len = torch.tensor(generation_len) + for num_token in range(1, generation_len): + outputs = self.qpc_session.run(inputs) + if write_io_dir: + write_io_files(inputs, outputs, write_io_dir, "decode", "aic_batch_io", True, False) + write_io_dir = None + + # Prepare inputs for next iteration + inputs["input_ids"] = outputs["logits"].argmax(2) - inputs = dict(input_ids=input_ids, attention_mask=attention_mask) + inputs["position_ids"] += 1 + generated_ids[:, num_token] = inputs["input_ids"].squeeze(1) - outputs = { - "output": np.random.randn(self.batch_size, self.seq_len, self.qpc_session.bindings[2].dims[2]).astype( - np.float32 - ), - } - self.qpc_session.set_buffers(outputs) - outputs = self.qpc_session.run(inputs) - outputs = outputs["output"][:, :input_ids_len, :] - return outputs + finished_sequences |= inputs["input_ids"] == self.processor.tokenizer.eos_token_id + if streamer is not None: + streamer.put(inputs["input_ids"][0]) + if finished_sequences.all(): + break + generated_texts = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + breakpoint() + return generated_texts def pytorch_vlm_generate( self, @@ -882,13 +894,6 @@ def pytorch_vlm_generate( Returns: torch.Tensor: A list of output features generated by the model for each prompt. 
""" - # inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - # inputs["past_key_values"] = [] - # for _ in range(model.config.num_hidden_layers): - # inputs["past_key_values"].append(( - # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len,self.head_dim), - # torch.zeros(1, model.config.num_key_value_heads, self.ctx_len, self.head_dim), - # )) self.batch_size = inputs["input_ids"].shape[0] generation_len = self.ctx_len - inputs["input_ids"].shape[1] generated_ids = torch.full((self.batch_size, generation_len + 1), self.processor.tokenizer.pad_token_id) @@ -897,7 +902,7 @@ def pytorch_vlm_generate( inputs["input_ids"] = outputs[0].argmax(2) inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 - streamer.put(inputs["input_ids"]) + streamer.put(inputs["input_ids"][0]) for _ in range(generation_len): outputs = model(**inputs) diff --git a/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py b/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py index 46e8cd61e..14b7c880b 100644 --- a/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py +++ b/QEfficient/transformers/models/phi3_vision/modeling_phi3_vision.py @@ -15,22 +15,15 @@ """PyTorch Phi-3-V model.""" +import inspect import math from typing import List, Optional, Tuple, Union -# from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask -# try: -# from flash_attn import flash_attn_func, flash_attn_varlen_func -# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -# _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) -# except ImportError: -# pass import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss -from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask @@ -39,7 +32,6 @@ CausalLMOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel -from transformers.models.clip.modeling_clip import CLIPAttention from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -49,6 +41,19 @@ from .configuration_phi3_v import Phi3VConfig +# from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + +try: + from flash_attn import flash_attn_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +except ImportError: + pass + +import torch +from transformers import CLIPVisionConfig, CLIPVisionModel, PretrainedConfig + logger = logging.get_logger(__name__) @@ -226,7 +231,8 @@ class QEffPhi3ImageEmbedding(nn.Module): """Phi3 Image embedding.""" def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: - super().__init__() + super().__init__(config) + # n_embed or hidden_size hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): @@ -236,6 +242,8 @@ def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: self.drop = None self.wte = wte + breakpoint() + self.config = config if isinstance(config.img_processor, dict) and config.img_processor.get("name", 
None) == "clip_vision_model": assert "model_name" in config.img_processor, "model_name must be provided for CLIPVisionModel" @@ -248,7 +256,7 @@ def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: self.num_img_tokens = config.img_processor["num_img_tokens"] # FA2 in CLIP - # if config._attn_implementation == "flash_attention_2": + # if config._attn_implementation == 'flash_attention_2': # for layer in self.img_processor.vision_model.encoder.layers: # clip_fa2 = CLIPAttentionFA2(clip_config) # del layer.self_attn @@ -305,15 +313,6 @@ def __init__(self, config: PretrainedConfig, wte=None, **kwargs) -> None: else: self.layer_idx = -2 self.type_feature = "patch" - self.__qeff_init__() - - def __qeff_init__(self): - # breakpoint() - clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - for layer in self.img_processor.vision_model.encoder.layers: - clip_eager = CLIPAttention(clip_config) - del layer.self_attn - layer.self_attn = clip_eager def set_img_features(self, img_features: torch.FloatTensor) -> None: self.img_features = img_features @@ -1222,15 +1221,13 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - past_key_values_length = 0 - if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - + past_key_values_length = 0 if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: @@ -1245,7 +1242,6 @@ def forward( position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: if pixel_values is not None and image_sizes is not None: assert self.vision_embed_tokens is not None, "Vision embedding layer is not defined" @@ -1265,18 +1261,18 @@ def forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" ) - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) + # if self._attn_implementation == "flash_attention_2": + # # 2d mask is passed through the layers + # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + # else: + # # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) hidden_states = inputs_embeds @@ -1382,8 +1378,6 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - # past_key_values_length = 0 - if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -1395,7 +1389,6 @@ def forward( use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # past_key_values_length = past_key_values.get_usable_length(seq_length) past_key_values.get_usable_length(seq_length) if inputs_embeds is None: diff --git a/QEfficient/transformers/models/test_vlm_single.py b/QEfficient/transformers/models/test_vlm_single.py new file mode 100755 index 000000000..3fc9f5c8d --- /dev/null +++ b/QEfficient/transformers/models/test_vlm_single.py @@ -0,0 +1,142 @@ +import requests +from PIL import Image +from transformers import AutoModelForCausalLM, AutoProcessor, TextStreamer + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText +from QEfficient.utils import hf_download + + +def load_vlm_model(model_config): + """ + Function to load model from huggingface and transform to KV model + -------- + + :model_config: Dict + + :return model_hf, params + """ + model_path = hf_download( + repo_id=model_config["model_name"], + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + num_hidden_layers=model_config["n_layer"], + _attn_implementation="eager", + low_cpu_mem_usage=False, + ) # Run models for single layers only + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def _generate_inputs(model, processor): + ## PREPROCESSING THE MULTI-MODAL INPUTS + images = [] + placeholder = "" + + # Note: if OOM, you might consider reduce number of frames in this example. 
+ for i in range(1, 2): + url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" + images.append(Image.open(requests.get(url, stream=True).raw)) + placeholder += f"<|image_{1}|>\n" + + messages = [ + {"role": "user", "content": placeholder + "Summarize the deck of slides."}, + ] + + prompt = processor.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + # padding="max_length", + # max_length=seq_len + ) + model.config.hidden_size // model.config.num_attention_heads + # ctx_len = 1280 # FIXME: Pass it a ssome arguement later on + inputs = dict(processor(images=images, text=prompt, return_tensors="pt")) + inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) + # inputs["past_key_values"] = [] + # for i in range(model.config.num_hidden_layers): + # inputs["past_key_values"].append(( + # torch.zeros(1, model.config.num_key_value_heads, ctx_len, head_dim), + # torch.zeros(1, model.config.num_key_value_heads, ctx_len, head_dim), + # )) + return inputs + + +def check_vlm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = 1024, + ctx_len: int = 1280, + n_layer: int = 32, + # num_speculative_tokens: Optional[int] = None, +): + """ + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + ``Mandatory`` Args: + :model_name (str): Hugging Face Model Card name, Example: ``Phi-3.5-vision-instruct`` + :prompt_len (int): Prompt length for the model to compile. + :ctx_len (int): Maximum context length to compile the model. + :n_layers (int): Number of layers for the Model. + """ + # replace_transformers_quantizers() + model_config = {"model_name": model_name} + model_config["n_layer"] = n_layer + + model_hf, _ = load_vlm_model(model_config) + # Load processor instead + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + # config = model_hf.config + # batch_size = len(Constants.INPUT_STR) + streamer = TextStreamer(processor) + # Testing for Phi-3.5 only atm + inputs = _generate_inputs(model_hf, processor) + # Original PyTorch model + pt_model = AutoModelForCausalLM.from_pretrained( + model_name, + num_hidden_layers=n_layer, + _attn_implementation="eager", + trust_remote_code=True, + # Check if this works + rope_scaling=None, + ) + # TODO: Don't use API RUNNER CLASS BUT directly define those functions here + # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + # is_tlm = False if num_speculative_tokens is None else True + qeff_model = QEFFAutoModelForImageTextToText(pt_model, processor, is_tlm=False) + breakpoint() + # pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + # assert ( + # pytorch_hf_tokens == pytorch_kv_tokens + # ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" + + qeff_model.export() + + # ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) + + # assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." 
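+    # The compile step below produces the QPC used by `generate(..., runtime_ai100=True)`,
+    # which runs a chunked prefill over `prompt_len`-sized windows and then a greedy decode
+    # loop that skips the `pixel_values`/`image_sizes` buffers once prefill is done, since
+    # the vision inputs are no longer needed after the image features enter the KV cache.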
+ + # if not get_available_device_id(): + # pytest.skip("No available devices to run model on Cloud AI 100") + breakpoint() + _ = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + # num_speculative_tokens=num_speculative_tokens, + ) + exec_info = qeff_model.generate(inputs, streamer, device_ids=None, runtime_ai100=True) + exec_info[0] # Because we always run for single input and single batch size + # gen_len = ort_tokens.shape[-1] + # assert ( + # ort_tokens == cloud_ai_100_tokens[:, :gen_len] + # ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output." + + +check_vlm_pytorch_vs_kv_vs_ort_vs_ai100("microsoft/Phi-3.5-vision-instruct", 1024) diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index ea8de3722..d93313b33 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -55,10 +55,21 @@ def get_models_dir(): class Constants: # Export Constants. + SEQ_LEN = 32 + CTX_LEN = 32 + PROMPT_LEN = 8 + INPUT_STR = ["My name is"] + GB = 2**30 + MAX_QPC_LIMIT = 30 + MAX_RETRIES = 5 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download + NUM_SPECULATIVE_TOKENS = 2 + + +class VlmConstants: SEQ_LEN = 1024 CTX_LEN = 1280 PROMPT_LEN = 8 - INPUT_STR = ["My name is"] + INPUT_STR = ["Summarize the deck of slides"] GB = 2**30 MAX_QPC_LIMIT = 30 MAX_RETRIES = 5 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download diff --git a/tests/transformers/models/test_vlm_models.py b/tests/transformers/models/test_vlm_models.py index f8b496d0f..1f55daaf3 100755 --- a/tests/transformers/models/test_vlm_models.py +++ b/tests/transformers/models/test_vlm_models.py @@ -8,7 +8,6 @@ import pytest import requests -import torch from PIL import Image from transformers import AutoModelForCausalLM, AutoProcessor, TextStreamer @@ -19,13 +18,10 @@ from QEfficient.utils import hf_download # from QEfficient.utils._utils import load_hf_processor -from QEfficient.utils.constants import Constants +from QEfficient.utils.constants import VlmConstants from QEfficient.utils.device_utils import get_available_device_id -# from QEfficient.utils.run_utils import ApiRunner - test_models = [ - # "liuhaotian/llava-v1.5-7b", "microsoft/Phi-3.5-vision-instruct", ] @@ -74,30 +70,18 @@ def _generate_inputs(model, processor): messages, tokenize=False, add_generation_prompt=True, - # padding="max_length", - # max_length=seq_len ) - head_dim = model.config.hidden_size // model.config.num_attention_heads - ctx_len = 1280 # FIXME: Pass it a ssome arguement later on + model.config.hidden_size // model.config.num_attention_heads + # ctx_len = 1280 # FIXME: Pass it a ssome arguement later on inputs = dict(processor(images=images, text=prompt, return_tensors="pt")) - inputs["position_ids"] = inputs.pop("attention_mask").cumsum(1) - inputs["past_key_values"] = [] - for i in range(model.config.num_hidden_layers): - inputs["past_key_values"].append( - ( - torch.zeros(1, model.config.num_key_value_heads, ctx_len, head_dim), - torch.zeros(1, model.config.num_key_value_heads, ctx_len, head_dim), - ) - ) return inputs def check_vlm_pytorch_vs_kv_vs_ort_vs_ai100( model_name: str, - prompt_len: int = Constants.PROMPT_LEN, - ctx_len: int = Constants.CTX_LEN, + prompt_len: int = VlmConstants.SEQ_LEN, + ctx_len: int = VlmConstants.CTX_LEN, n_layer: int = 1, - # 
num_speculative_tokens: Optional[int] = None, ): """ Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. @@ -114,39 +98,20 @@ def check_vlm_pytorch_vs_kv_vs_ort_vs_ai100( model_hf, _ = load_vlm_model(model_config) # Load processor instead processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) - # config = model_hf.config - # batch_size = len(Constants.INPUT_STR) streamer = TextStreamer(processor) - # Testing for Phi-3.5 only atm + streamer.on_finalized_text("< ") inputs = _generate_inputs(model_hf, processor) - breakpoint() # Original PyTorch model pt_model = AutoModelForCausalLM.from_pretrained( model_name, num_hidden_layers=n_layer, _attn_implementation="eager", trust_remote_code=True, - # Check if this works rope_scaling=None, ) - # TODO: Don't use API RUNNER CLASS BUT directly define those functions here - # pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) - # is_tlm = False if num_speculative_tokens is None else True qeff_model = QEFFAutoModelForImageTextToText(pt_model, processor, is_tlm=False) - - # pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - - # assert ( - # pytorch_hf_tokens == pytorch_kv_tokens - # ).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output" - - # onnx_model_path = qeff_model.export() qeff_model.export() - # ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm) - - # assert (pytorch_kv_tokens == ort_tokens).all(), "Tokens don't match for ONNXRT output and PyTorch output." - if not get_available_device_id(): pytest.skip("No available devices to run model on Cloud AI 100") @@ -156,15 +121,9 @@ def check_vlm_pytorch_vs_kv_vs_ort_vs_ai100( num_cores=14, mxfp6=False, aic_enable_depth_first=False, - # num_speculative_tokens=num_speculative_tokens, ) - # exec_info = qeff_model.generate(inputs, streamer, device_ids=[0], runtime_ai100=True) - qeff_model.generate(inputs, streamer, device_ids=[0], runtime_ai100=False) - # cloud_ai_100_tokens = exec_info[0] # Because we always run for single input and single batch size - # gen_len = ort_tokens.shape[-1] - # assert ( - # ort_tokens == cloud_ai_100_tokens[:, :gen_len] - # ).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output." + exec_info = qeff_model.generate(inputs, streamer, device_ids=None, runtime_ai100=True) + exec_info[0] # Because we always run for single input and single batch size @pytest.mark.on_qaic