@@ -4,15 +4,11 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
-from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
-from torch import nn
 from transformers.models.llava.modeling_llava import (
-    LlavaCausalLMOutputWithPast,
     LlavaForConditionalGeneration,
-    logger,
 )
 
 BS = 1
@@ -23,201 +19,34 @@
 
 
 class QEffLlavaForConditionalGeneration(LlavaForConditionalGeneration):
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        pixel_values: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        vision_feature_layer: Optional[int] = None,
-        vision_feature_select_strategy: Optional[str] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        num_logits_to_keep: int = 0,
-    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
-        r"""
-        Args:
-            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-            num_logits_to_keep (`int`, *optional*):
-                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
-                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-
-
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from PIL import Image
-        >>> import requests
-        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
-
-        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
-        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
-
-        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
-        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
-
-        >>> # Generate
-        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
-        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        "USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
-        ```"""
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        vision_feature_layer = (
-            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
-        )
-        vision_feature_select_strategy = (
-            vision_feature_select_strategy
-            if vision_feature_select_strategy is not None
-            else self.config.vision_feature_select_strategy
-        )
-
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if pixel_values is not None and inputs_embeds is not None:
-            raise ValueError(
-                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-            )
-
-        legacy_processing = False
-        if inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
-            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
-            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
-            legacy_processing = (
-                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
-            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
-
-        if pixel_values is not None:
-            image_features = self.get_image_features(
-                pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
-                vision_feature_select_strategy=vision_feature_select_strategy,
-            )
-
-        if legacy_processing:
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-            )
-            # prefill stage vs decoding stage (legacy behavior copied)
-            if input_ids.shape[1] != 1:
-                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
-                )
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
-            else:
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-                # Get the target length
-                target_length = input_ids.shape[1]
-                past_length = first_layer_past_key_value.shape[-1]
-
-                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], past_length),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-
-                # Filter out only the tokens that can be un-attended, this can happen
-                # if one uses Llava + Fused modules where the cache on the
-                # first iteration is already big enough, or if one passes custom cache
-                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-                new_batch_index = batch_index[valid_indices]
-                new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-                # Zero-out the places where we don't need to attend
-                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-                attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-                cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
-                    -target_length:
-                ]
-
-        # TODO: @raushan retain only the new behavior after v4.47
-        else:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
-            n_image_features = image_features.shape[1]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_index
-            indices1 = mask.to(torch.int64).cumsum(1) - 1
-            indices0 = torch.arange(mask.shape[0]).view(-1, 1)
-            image_features_expanded = image_features[indices0, indices1]
-            image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
-            # *where to skip image encoder for decode*
-            inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
-
+    def forward(self, input_ids, position_ids, pixel_values, past_key_values):
+        inputs_embeds = self.get_input_embeddings()(input_ids)
+        # Image features
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        selected_image_feature = image_outputs.hidden_states[self.config.vision_feature_layer]
+        vision_feature_select_strategy = self.config.vision_feature_select_strategy
+        if vision_feature_select_strategy == "default":
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+
+        mask = input_ids == self.config.image_token_index
+        indices1 = mask.to(torch.int64).cumsum(1) - 1
+        indices0 = torch.arange(mask.shape[0]).view(-1, 1)
+        image_features_expanded = image_features[indices0, indices1]
+        image_inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
+        # *where to skip image encoder for decode*
+        inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_inputs_embeds)
         outputs = self.language_model(
-            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             position_ids=position_ids,
             past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            cache_position=cache_position,
-            num_logits_to_keep=num_logits_to_keep,
         )
-
-        logits = outputs[0]
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            if attention_mask is not None:
-                # we use the input attention mask to shift the logits and labels, because it is 2D.
-                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
-                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
-                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
-                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
-            else:
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = nn.CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
-
-        return logits, pixel_values, outputs.past_key_values
+        return outputs.logits, pixel_values, outputs.past_key_values
 
     def get_dummy_inputs(self, **kwargs):
         num_layers = self.config.text_config.num_hidden_layers
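The merge step that both the removed and the added `forward` rely on scatters the projected image features into the text embedding sequence at the image-token positions, using a running `cumsum` over the token mask as a gather index. A minimal standalone sketch of just that step, with toy shapes and a made-up image-token id standing in for `config.image_token_index`:

```python
import torch

IMAGE_TOKEN_ID = 9  # hypothetical placeholder id, stands in for config.image_token_index
hidden = 4

# one prompt with three image placeholders: [bos, <image>, <image>, <image>, text]
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
inputs_embeds = torch.randn(1, 5, hidden)   # text embeddings for every position
image_features = torch.randn(1, 3, hidden)  # one projected feature per placeholder

mask = input_ids == IMAGE_TOKEN_ID             # True where an image feature must go
indices1 = mask.to(torch.int64).cumsum(1) - 1  # k-th placeholder -> k-th image feature
indices0 = torch.arange(mask.shape[0]).view(-1, 1)  # batch index, broadcast over seq
# (bs, seq, hidden); positions before the first placeholder index -1 (wrap around),
# but torch.where discards them because their mask entry is False
image_features_expanded = image_features[indices0, indices1]
merged = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)

assert torch.equal(merged[0, 1], image_features[0, 0])  # first placeholder got first feature
assert torch.equal(merged[0, 4], inputs_embeds[0, 4])   # text positions are untouched
```

Because the gather indices come from a cumulative sum over a fixed-length mask, the merged tensor always has shape `(batch, seq, hidden)` regardless of how many image placeholders appear, which is presumably what keeps the graph export-friendly compared with a data-dependent scatter. The trailing `torch.where(input_ids.shape[1] == torch.tensor(1), ...)` in the model code then keeps the plain text embeddings whenever the sequence length is 1, i.e. the image merge is skipped on decode steps, as the inline comment notes.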