Skip to content

Commit 29f7efd

Browse files
mgoin
authored and
Yuqi Zhang
committed
Support LoRA for Mistral3 (vllm-project#17428)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 7290532 commit 29f7efd

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,7 @@ See [this page](#generative-models) for more information on how to use generativ
990990
* Mistral3
991991
* T + I<sup>+</sup>
992992
* `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc.
993-
*
993+
* ✅︎
994994
* ✅︎
995995
* ✅︎
996996
- * `MllamaForConditionalGeneration`

vllm/model_executor/models/mistral3.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
1919
RowParallelLinear)
2020
from vllm.model_executor.layers.quantization import QuantizationConfig
21+
from vllm.model_executor.models.module_mapping import MultiModelKeys
2122
from vllm.model_executor.sampling_metadata import SamplingMetadata
2223
from vllm.multimodal import MULTIMODAL_REGISTRY
2324
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -31,7 +32,8 @@
3132
from vllm.multimodal.profiling import BaseDummyInputsBuilder
3233
from vllm.sequence import IntermediateTensors
3334

34-
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
35+
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
36+
SupportsMultiModal, SupportsPP)
3537
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
3638
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
3739
maybe_prefix, merge_multimodal_embeddings)
@@ -382,8 +384,8 @@ def init_vision_tower_for_llava(
382384
_build_mistral3_processor,
383385
info=_build_mistral3_info,
384386
dummy_inputs=Mistral3DummyInputsBuilder)
385-
class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
386-
SupportsPP):
387+
class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
388+
SupportsMultiModal, SupportsPP):
387389

388390
packed_modules_mapping = {
389391
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
@@ -594,3 +596,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
594596
torch.Tensor]]) -> Set[str]:
595597
loader = AutoWeightsLoader(self)
596598
return loader.load_weights(weights)
599+
600+
def get_mm_mapping(self) -> MultiModelKeys:
601+
"""
602+
Get the module prefix in multimodal models
603+
"""
604+
return MultiModelKeys.from_string_field(
605+
language_model="language_model",
606+
connector="multi_modal_projector",
607+
tower_model="vision_tower")

0 commit comments

Comments (0)