
Commit b9c0bc1

asmigosw committed
final revision VLM
Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
1 parent e75d07d

File tree: 6 files changed (+329, -325 lines)

QEfficient/transformers/models/llava/modeling_llava.py

+24 -195
@@ -4,15 +4,11 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from torch import nn
 from transformers.models.llava.modeling_llava import (
-    LlavaCausalLMOutputWithPast,
     LlavaForConditionalGeneration,
-    logger,
 )

 BS = 1
@@ -23,201 +19,34 @@


 class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
-    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
-        r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-            num_logits_to_keep (`int`, *optional*):
-                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from PIL import Image
-        >>> import requests
-        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
-
-        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
-        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
-
-        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
-        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
-        >>> # Generate
-        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
-        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
-        ```"""
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        vision_feature_layer = (
-            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_feature_select_strategy
-        )
-
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
-        legacy_processing = False
-        if inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-            legacy_processing = (
-                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-            )
-
-        if legacy_processing:
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-            )
-            # prefill stage vs decoding stage (legacy behavior copied)
-            if input_ids.shape[1] != 1:
-                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
-                )
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-            else:
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                # Get the target length
-                target_length = input_ids.shape[1]
-                past_length = first_layer_past_key_value.shape[-1]
-
-                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], past_length),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-
-                # Filter out only the tokens that can be un-attended, this can happen
-                # if one uses Llava + Fused modules where the cache on the
-                # first iteration is already big enough, or if one passes custom cache
-                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                new_batch_index = batch_index[valid_indices]
-                new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                # Zero-out the places where we don't need to attend
-                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                    -target_length:
-                ]
-
-        # TODO: @raushan retain only the new behavior after v4.47
-        else:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_index
-            indices1 = mask.to(torch.int64).cumsum(1) - 1
-            indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-            image_features_expanded = image_features[indices0, indices1]
-            image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
-            # *where to skip image encoder for decode*
-            inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
-
+    def forward(self, input_ids, position_ids, pixel_values, past_key_values):
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        # Image features
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        selected_image_feature = image_outputs.hidden_states[self.config.vision_feature_layer]
+        vision_feature_select_strategy = self.config.vision_feature_select_strategy
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+
+        mask = input_ids == self.config.image_token_index
+        indices1 = mask.to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
+        image_features_expanded = image_features[indices0, indices1]
+        image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        # *where to skip image encoder for decode*
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
         outputs = self.language_model(
-            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
         )
-
-        logits = outputs[0]
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            if attention_mask is not None:
-                # we use the input attention mask to shift the logits and labels, because it is 2D.
-                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
-                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
-                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
-                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
-            else:
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return logits, pixel_values, outputs.past_key_values
+        return outputs.logits, pixel_values, outputs.past_key_values

     def get_dummy_inputs(self, **kwargs):
         num_layers = self.config.text_config.num_hidden_layers

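For context, a minimal runnable sketch (not part of this commit) of the merge trick used in the new forward above: image placeholder tokens in input_ids are swapped for projected vision features via a cumsum-based gather, and the whole merge is bypassed with torch.where when the sequence length is 1 (decode phase). The tensor shapes and the IMAGE_TOKEN_ID value below are illustrative assumptions, not values from the repository.

import torch

IMAGE_TOKEN_ID = 32000  # assumed placeholder id; the real value comes from config.image_token_index
batch, seq_len, hidden = 1, 8, 16
num_image_tokens = 3

# Toy inputs: positions 1..3 of the prompt hold the image placeholder token.
input_ids = torch.randint(0, 100, (batch, seq_len))
input_ids[:, 1 : 1 + num_image_tokens] = IMAGE_TOKEN_ID
inputs_embeds = torch.randn(batch, seq_len, hidden)            # text token embeddings
image_features = torch.randn(batch, num_image_tokens, hidden)  # projector output

# cumsum(mask) - 1 gives, at every position, the index of the image feature that
# should sit there; non-image positions pick up stale indices but are masked out below.
mask = input_ids == IMAGE_TOKEN_ID
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
image_features_expanded = image_features[indices0, indices1]

# Keep image features where the placeholder token sits, text embeddings elsewhere.
merged = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

# A decode step (sequence length == 1) keeps the plain text embeddings, so the
# exported graph can skip the vision path without a data-dependent Python branch.
inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, merged)
print(inputs_embeds.shape)  # torch.Size([1, 8, 16])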