Commit eefc008

DarkLight1337 authored and lk-chen committed
[Optim] Compute multimodal hash only once per item (vllm-project#17314)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 1c6d739 commit eefc008

6 files changed (+233 additions, -128 deletions)
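This commit threads a new keyword-only return_mm_hashes flag through the models' _cached_apply_hf_processor overrides and widens their return type to include Optional[MultiModalHashes], so each multimodal item is hashed once while it is being processed and the hash is handed back to callers instead of being recomputed later. Below is a minimal, self-contained sketch of the idea, assuming a hypothetical hash_item helper, TinyProcessingCache class, and process_items function (illustrative names only, not vLLM's actual API):

from hashlib import blake2b
from typing import Callable, Optional

def hash_item(item_bytes: bytes) -> str:
    # Hypothetical helper: vLLM hashes multimodal items with its own machinery.
    return blake2b(item_bytes, digest_size=16).hexdigest()

class TinyProcessingCache:
    """Toy stand-in for a processing cache keyed by item hash."""

    def __init__(self) -> None:
        self._store: dict[str, dict] = {}

    def get_or_put(self, key: str, compute: Callable[[], dict]) -> dict:
        if key not in self._store:
            self._store[key] = compute()
        return self._store[key]

def process_items(
    items: list[bytes],
    cache: TinyProcessingCache,
    return_hashes: bool,
) -> tuple[list[dict], Optional[list[str]]]:
    outputs: list[dict] = []
    hashes: list[str] = []
    for item in items:
        # The hash is computed exactly once per item ...
        item_hash = hash_item(item)
        hashes.append(item_hash)
        # ... and reused both as the cache key and, if requested, as the
        # value handed back to the caller, so nothing re-hashes the item.
        outputs.append(cache.get_or_put(item_hash, lambda i=item: {"data": i}))
    return outputs, (hashes if return_hashes else None)

# Example: the second call hits the cache and still reports the same hashes.
cache = TinyProcessingCache()
_, first = process_items([b"image-0", b"image-1"], cache, return_hashes=True)
_, second = process_items([b"image-0", b"image-1"], cache, return_hashes=True)
assert first == second

In this sketch the hashes always exist because they double as cache keys; the flag only controls whether they are surfaced to the caller.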

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 9 additions & 7 deletions
@@ -22,8 +22,8 @@
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+                                        BaseProcessingInfo, MultiModalHashes,
+                                        PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,

@@ -279,24 +279,26 @@ def _cached_apply_hf_processor(
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         # The processor logic is different for len(images) <= 2 vs > 2
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only
         # perform caching for the most common case
         if mm_data_items.get_count("image", strict=False) > 2:
-            # This code path corresponds to the cache being disabled
-            return self._apply_hf_processor_main(
+            return self._apply_hf_processor(
                 prompt=prompt,
-                mm_items=mm_data_items,
+                mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                enable_hf_prompt_update=True,
+                return_mm_hashes=return_mm_hashes,
             )

         return super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )
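The h2ovl.py diff below follows the same pattern as deepseek_vl2.py, so one sketch covers both: the override now takes a keyword-only return_mm_hashes argument, returns a four-element tuple whose third element is Optional[MultiModalHashes], and routes the uncached many-image case through _apply_hf_processor while still propagating the flag. A rough, self-contained sketch of that contract follows, using stand-in types and a toy base class (the real ones are BaseMultiModalProcessor, MultiModalKwargs in vllm.multimodal.inputs, and MultiModalHashes in vllm.multimodal.processing):

from collections.abc import Mapping
from typing import Optional, Union

# Stand-ins for illustration; not vLLM's real types.
MultiModalKwargs = dict[str, object]
MultiModalHashes = dict[str, list[str]]   # modality -> per-item hashes

class CachedProcessorSketch:
    """Toy base class mimicking the cached/uncached processing paths."""

    def _apply_hf_processor(self, prompt, mm_data_items,
                            hf_processor_mm_kwargs, *,
                            return_mm_hashes: bool):
        # Uncached path: hash items while processing them.
        hashes = {"image": [f"h{i}" for i in range(len(mm_data_items))]}
        return [1, 2, 3], {}, (hashes if return_mm_hashes else None), False

    def _cached_apply_hf_processor(self, prompt, mm_data_items,
                                   hf_processor_mm_kwargs, *,
                                   return_mm_hashes: bool):
        # Cached path: in this toy version it simply delegates; the point is
        # that the flag and the hashes travel through the same call chain.
        return self._apply_hf_processor(
            prompt, mm_data_items, hf_processor_mm_kwargs,
            return_mm_hashes=return_mm_hashes)

class ManyImageProcessor(CachedProcessorSketch):
    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: list[object],
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        return_mm_hashes: bool,
    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
        # Mirror the deepseek_vl2/h2ovl pattern: skip the cache for the
        # uncommon many-image case but keep propagating return_mm_hashes.
        if len(mm_data_items) > 2:
            return self._apply_hf_processor(
                prompt, mm_data_items, hf_processor_mm_kwargs,
                return_mm_hashes=return_mm_hashes)
        return super()._cached_apply_hf_processor(
            prompt, mm_data_items, hf_processor_mm_kwargs,
            return_mm_hashes=return_mm_hashes)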
vllm/model_executor/models/h2ovl.py

Lines changed: 9 additions & 7 deletions
@@ -19,8 +19,8 @@
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    MultiModalDataItems)
-from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
-                                        PromptUpdateDetails)
+from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
+                                        PromptUpdate, PromptUpdateDetails)
 from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .intern_vit import InternVisionModel

@@ -488,24 +488,26 @@ def _cached_apply_hf_processor(
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
         # The processor logic is different for len(images) <= 1 vs > 1
         # Since the processing cache assumes that the processor output is
         # invariant of how many images are passed per prompt, we only
         # perform caching for the most common case
         if mm_data_items.get_count("image", strict=False) > 1:
-            # This code path corresponds to the cache being disabled
-            return self._apply_hf_processor_main(
+            return self._apply_hf_processor(
                 prompt=prompt,
-                mm_items=mm_data_items,
+                mm_data_items=mm_data_items,
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                enable_hf_prompt_update=True,
+                return_mm_hashes=return_mm_hashes,
             )

         return super()._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )
vllm/model_executor/models/llava.py

Lines changed: 0 additions & 3 deletions
@@ -396,22 +396,19 @@ def _build_llava_or_pixtral_hf_processor(
     dummy_inputs: BaseDummyInputsBuilder[_I],
     *,
     cache: Optional[ProcessingCache] = None,
-    enable_sanity_checks: bool = True,
 ) -> BaseMultiModalProcessor:
     if isinstance(info, PixtralHFProcessingInfo):
         return PixtralHFMultiModalProcessor(
             info,
             dummy_inputs,  # type: ignore
             cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
         )

     if isinstance(info, LlavaProcessingInfo):
         return LlavaMultiModalProcessor(
             info,
             dummy_inputs,  # type: ignore
             cache=cache,
-            enable_sanity_checks=enable_sanity_checks,
         )

     raise NotImplementedError(type(info))

vllm/model_executor/models/mistral3.py

Lines changed: 0 additions & 2 deletions
@@ -312,14 +312,12 @@ def _build_mistral3_processor(
     dummy_inputs: BaseDummyInputsBuilder[_I],
     *,
     cache: Optional[ProcessingCache] = None,
-    enable_sanity_checks: bool = True,
 ) -> BaseMultiModalProcessor:
     assert isinstance(info, Mistral3ProcessingInfo)
     return Mistral3MultiModalProcessor(
         info,
         dummy_inputs,  # type: ignore
         cache=cache,
-        enable_sanity_checks=enable_sanity_checks,
     )

vllm/model_executor/models/pixtral.py

Lines changed: 10 additions & 5 deletions
@@ -36,8 +36,9 @@
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        BaseProcessingInfo, MultiModalHashes,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import (MistralTokenizer,

@@ -271,15 +272,19 @@ def _cached_apply_hf_processor(
         prompt: Union[str, list[int]],
         mm_data_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], MultiModalKwargs, bool]:
-        prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
+        *,
+        return_mm_hashes: bool,
+    ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
+        prompt_ids, mm_kwargs, mm_hashes, _ = super(
+        )._cached_apply_hf_processor(
             prompt=prompt,
             mm_data_items=mm_data_items,
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            return_mm_hashes=return_mm_hashes,
         )

         # NOTE: The tokens are already inserted by the chat template
-        return prompt_ids, mm_kwargs, True
+        return prompt_ids, mm_kwargs, mm_hashes, True


 @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
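For a caller-side view of the Pixtral change, here is an illustrative stub (not vLLM's implementation): the override passes through the hashes produced by the parent call, and the trailing True records that the chat template already inserted the image tokens, so no further prompt update is needed.

from typing import Optional

class _PixtralStyleStub:
    """Illustrative stub with the new method shape; not vLLM's real class."""

    def _cached_apply_hf_processor(self, *, prompt, mm_data_items,
                                   hf_processor_mm_kwargs,
                                   return_mm_hashes: bool):
        hashes: Optional[dict[str, list[str]]] = (
            {"image": ["deadbeef"]} if return_mm_hashes else None)
        # Pixtral-style result: the chat template already inserted the image
        # tokens, hence the trailing True.
        return [101, 102, 103], {"pixel_values": None}, hashes, True

prompt_ids, mm_kwargs, mm_hashes, is_update_applied = (
    _PixtralStyleStub()._cached_apply_hf_processor(
        prompt="describe the image",
        mm_data_items=[object()],
        hf_processor_mm_kwargs={},
        return_mm_hashes=True,
    ))
print(mm_hashes)          # {'image': ['deadbeef']} -- computed once upstream
print(is_update_applied)  # True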
