[Model] support MiniMax-VL-01 model #16328

Merged: 291 commits, Apr 29, 2025

Commits (291)
166efb2
fix code
qscqesze Apr 8, 2025
ec91d9a
fix code
qscqesze Apr 8, 2025
91551ca
fix code
qscqesze Apr 8, 2025
4aa950d
fix code
qscqesze Apr 8, 2025
5fb9ffc
fix code
qscqesze Apr 8, 2025
173c9a0
fix code
qscqesze Apr 8, 2025
9c2b6e2
fix bug
qscqesze Apr 8, 2025
ac1b7e2
fix code
qscqesze Apr 8, 2025
789dd0e
fix code
qscqesze Apr 8, 2025
1cdd2cf
fix code
qscqesze Apr 8, 2025
d616131
fix code
qscqesze Apr 8, 2025
06fce4a
fix
qscqesze Apr 8, 2025
7a5c1f5
fix code
qscqesze Apr 8, 2025
509d5fc
fix code
qscqesze Apr 8, 2025
40d9075
fix code
qscqesze Apr 8, 2025
0f3ab18
fix code
qscqesze Apr 8, 2025
999a3ab
fix code
qscqesze Apr 8, 2025
4db1ba8
fix code
qscqesze Apr 8, 2025
37fc348
fix code
qscqesze Apr 8, 2025
0fd255b
fix code
qscqesze Apr 8, 2025
415926e
fix code
qscqesze Apr 8, 2025
fa150a1
fix code
qscqesze Apr 8, 2025
18d8ca3
fix code
qscqesze Apr 8, 2025
3245ca8
fix code
qscqesze Apr 8, 2025
e701618
fix code
qscqesze Apr 8, 2025
86e7656
fix code
qscqesze Apr 8, 2025
4954958
fix code
qscqesze Apr 8, 2025
4efc192
fix code
qscqesze Apr 8, 2025
a855bbe
fix code
qscqesze Apr 8, 2025
6976391
fix code
qscqesze Apr 8, 2025
69db01a
fix code
qscqesze Apr 8, 2025
9421581
fix code
qscqesze Apr 9, 2025
27ccae4
fix code
qscqesze Apr 9, 2025
540cd96
fix code
qscqesze Apr 9, 2025
ac84a09
fix code
qscqesze Apr 9, 2025
df2e97b
fix code
qscqesze Apr 9, 2025
0d7424f
fix code
qscqesze Apr 9, 2025
a1aa44c
fix code
qscqesze Apr 9, 2025
b64e156
fix code
qscqesze Apr 9, 2025
ee7f68d
fix code
qscqesze Apr 9, 2025
dfa4361
fix code
qscqesze Apr 9, 2025
87bbf9b
fix code
qscqesze Apr 9, 2025
153c34a
fix code
qscqesze Apr 9, 2025
79251db
fix code
qscqesze Apr 9, 2025
3e3d16e
fix code
qscqesze Apr 9, 2025
fe6271c
fix code
qscqesze Apr 9, 2025
c8b0c18
fix code
qscqesze Apr 9, 2025
9c2cbb6
fix code
qscqesze Apr 9, 2025
cd882eb
fix code
qscqesze Apr 9, 2025
bc753a0
fix code
qscqesze Apr 9, 2025
2f20b64
fix code
qscqesze Apr 9, 2025
9c82188
fix code
qscqesze Apr 9, 2025
e7514d5
fix code
qscqesze Apr 9, 2025
3746251
fix code
qscqesze Apr 10, 2025
1fd388c
fix code
qscqesze Apr 10, 2025
641ca08
fix code
qscqesze Apr 10, 2025
d27bc8d
fix code
qscqesze Apr 10, 2025
69739f4
fix code
qscqesze Apr 10, 2025
0961a29
fix code
qscqesze Apr 10, 2025
d95adbc
fix code
qscqesze Apr 10, 2025
857e78e
fix code
qscqesze Apr 10, 2025
a4b4ab1
fix code
qscqesze Apr 10, 2025
4dd8722
fix code
qscqesze Apr 10, 2025
be64010
fix code
qscqesze Apr 10, 2025
cc59771
fix code
qscqesze Apr 10, 2025
41eede7
fix image
qscqesze Apr 10, 2025
28e5f83
fix code
qscqesze Apr 10, 2025
2ae6a30
fix code
qscqesze Apr 10, 2025
cb12273
fix code
qscqesze Apr 10, 2025
c35403f
fix code
qscqesze Apr 10, 2025
d8cf84c
add image token
qscqesze Apr 10, 2025
e45cbf8
fix code
qscqesze Apr 10, 2025
47464d1
fix code
qscqesze Apr 10, 2025
6086190
fix code
qscqesze Apr 10, 2025
8183e1b
fix code
qscqesze Apr 10, 2025
0e08d41
fix code
qscqesze Apr 10, 2025
17b31f9
fix code
qscqesze Apr 10, 2025
0c38058
fix code
qscqesze Apr 10, 2025
87a1410
fix code
qscqesze Apr 10, 2025
e141853
fix code
qscqesze Apr 10, 2025
a38b04d
fix code
qscqesze Apr 10, 2025
1c21295
fix code
qscqesze Apr 10, 2025
48232ec
fix code
qscqesze Apr 10, 2025
ca3390e
fix code
qscqesze Apr 10, 2025
9a13af6
fix code
qscqesze Apr 11, 2025
7397a4a
add
qscqesze Apr 11, 2025
1d21b15
fix code
qscqesze Apr 11, 2025
59a98a0
fix code
qscqesze Apr 11, 2025
48a975d
fix code
qscqesze Apr 11, 2025
4b4c81f
fix code
qscqesze Apr 11, 2025
4f8bc81
fix code
qscqesze Apr 11, 2025
6425531
fix code
qscqesze Apr 11, 2025
69cf8d4
fix code
qscqesze Apr 11, 2025
c6f3fcb
fix code
qscqesze Apr 11, 2025
9751164
fix code
qscqesze Apr 11, 2025
e2bab67
fix code
qscqesze Apr 11, 2025
c32a767
fix ce
qscqesze Apr 11, 2025
dad312c
fix code
qscqesze Apr 11, 2025
0a65921
fix code
qscqesze Apr 11, 2025
971753c
fix code
qscqesze Apr 11, 2025
019626d
fix code
qscqesze Apr 14, 2025
0310224
fix code
qscqesze Apr 14, 2025
b464cb9
fix pixel number
qscqesze Apr 14, 2025
fc25ee5
fix code
qscqesze Apr 14, 2025
ee494ab
fix ce
qscqesze Apr 14, 2025
ab2aef2
fix code
qscqesze Apr 14, 2025
9e0eabc
fix code
qscqesze Apr 14, 2025
75d5fff
fix code
qscqesze Apr 14, 2025
04afbc5
fix code
qscqesze Apr 14, 2025
07d4675
fix code
qscqesze Apr 14, 2025
5341027
fix code
qscqesze Apr 14, 2025
bdfa8ef
fix code
qscqesze Apr 14, 2025
0925313
fix code
qscqesze Apr 14, 2025
4d5948a
fix code
qscqesze Apr 14, 2025
b5f1125
fix code
qscqesze Apr 14, 2025
198f3aa
fix code
qscqesze Apr 14, 2025
f56c799
fix code
qscqesze Apr 14, 2025
dc1ab92
fix code
qscqesze Apr 14, 2025
2092aff
fix code
qscqesze Apr 14, 2025
e2be301
fix code
qscqesze Apr 14, 2025
2785b3c
fix code
qscqesze Apr 14, 2025
7f42d0c
fix code
qscqesze Apr 14, 2025
2a41bae
fix code
qscqesze Apr 14, 2025
a9c8bd5
fix code
qscqesze Apr 14, 2025
1cb8bc3
fix code
qscqesze Apr 14, 2025
dcc3a58
fix code
qscqesze Apr 14, 2025
65ff912
fix code
qscqesze Apr 14, 2025
e5d7f4b
fix code
qscqesze Apr 14, 2025
9e416ad
fix code
qscqesze Apr 14, 2025
4bc93e4
fix code
qscqesze Apr 14, 2025
27c1835
fix up processor
qscqesze Apr 15, 2025
54b9b82
fix code
qscqesze Apr 15, 2025
2128b37
fix code
qscqesze Apr 15, 2025
f9bce2f
fix code
qscqesze Apr 15, 2025
de30b06
fix code
qscqesze Apr 15, 2025
f98a6f0
fix code
qscqesze Apr 15, 2025
9bed82d
add code
qscqesze Apr 15, 2025
2af16ea
fix code
qscqesze Apr 15, 2025
9c11ace
fix code
qscqesze Apr 15, 2025
d242762
fix code
qscqesze Apr 15, 2025
77db27d
fix code
qscqesze Apr 15, 2025
705e8b2
change model
qscqesze Apr 15, 2025
5a2677c
fix code
qscqesze Apr 15, 2025
8f2ef44
fix code
qscqesze Apr 15, 2025
54efe71
fix code
qscqesze Apr 15, 2025
f71e00f
fix code
qscqesze Apr 15, 2025
589bc89
fix code
qscqesze Apr 15, 2025
610e6fd
fix code
qscqesze Apr 15, 2025
f705a71
fix code
qscqesze Apr 15, 2025
74fe462
fix code
qscqesze Apr 15, 2025
2449d22
fix code
qscqesze Apr 15, 2025
9776926
fix code
qscqesze Apr 15, 2025
b743790
fix llava
qscqesze Apr 15, 2025
da7ff3f
fix code
qscqesze Apr 15, 2025
069efed
fix code
qscqesze Apr 15, 2025
104f92e
fix code
qscqesze Apr 16, 2025
09842c0
fix import
qscqesze Apr 16, 2025
399c373
fix code
qscqesze Apr 16, 2025
5c701bf
fix code
qscqesze Apr 16, 2025
14b246f
fix code
qscqesze Apr 16, 2025
71a7bf0
fix code
qscqesze Apr 16, 2025
f8c8f54
fix code
qscqesze Apr 16, 2025
d404d25
fix code
qscqesze Apr 16, 2025
088c8de
fix code
qscqesze Apr 16, 2025
89efb39
fix code
qscqesze Apr 16, 2025
84ff3e2
fix code
qscqesze Apr 16, 2025
fc0ccf8
fix code
qscqesze Apr 16, 2025
667f7a9
fix code
qscqesze Apr 16, 2025
bcd0744
fix code
qscqesze Apr 16, 2025
8016ff9
fix code
qscqesze Apr 16, 2025
bef469a
fix code
qscqesze Apr 16, 2025
6b396ff
fix code
qscqesze Apr 16, 2025
9f3daab
fix code
qscqesze Apr 16, 2025
9b9e82b
Refactored the code, optimized import order, fixed comments and forma…
qscqesze Apr 16, 2025
6a2661e
Merge branch 'main' into qingjun/vl
qscqesze Apr 16, 2025
d6cd217
fix import bug
qscqesze Apr 16, 2025
f98c8e2
fix import bug #2
qscqesze Apr 16, 2025
f3837e9
fix import
qscqesze Apr 16, 2025
5b28dd8
fix code
qscqesze Apr 16, 2025
75d9714
fix config
qscqesze Apr 17, 2025
0fa7914
format files
qscqesze Apr 17, 2025
b5f20e1
Removed unused files, including minimax_image_processer.py and minima…
qscqesze Apr 17, 2025
d25e9f7
Added support for the MiniMax-VL-01 model in the test files, and remo…
qscqesze Apr 19, 2025
8d5d9bf
Added SPDX license identifiers to minimax_image_processer.py and mini…
qscqesze Apr 19, 2025
f44ea9a
Deleted the minimax_image_processer.py file and cleaned up redundant …
qscqesze Apr 21, 2025
653c349
Updated the docstrings in the test files to ensure that MiniMaxVL01Mu…
qscqesze Apr 21, 2025
454e015
fix code update
qscqesze Apr 21, 2025
5bede72
format code
qscqesze Apr 21, 2025
d3431f8
fix code
qscqesze Apr 22, 2025
76606f7
fix code
qscqesze Apr 22, 2025
5dda227
Merge branch 'main' into qingjun/vl
qscqesze Apr 22, 2025
467e062
fix code
qscqesze Apr 22, 2025
cc7ee27
fix code
qscqesze Apr 22, 2025
1e7eaae
fix code
qscqesze Apr 22, 2025
85ae307
fix comments
qscqesze Apr 22, 2025
fdc8e32
fix code
qscqesze Apr 22, 2025
07ebdc6
fix code
qscqesze Apr 22, 2025
53d11e1
fix code&fix up
qscqesze Apr 23, 2025
193f35e
fix code
qscqesze Apr 23, 2025
e909429
fix test config
qscqesze Apr 23, 2025
e26ca62
fix dummy run
qscqesze Apr 23, 2025
55419cb
fix code
qscqesze Apr 23, 2025
2b768d3
fix code
qscqesze Apr 23, 2025
cbe89ab
fix code
qscqesze Apr 23, 2025
98c9403
fix code
qscqesze Apr 24, 2025
72647c6
fix code
qscqesze Apr 24, 2025
7c1244c
fix code
qscqesze Apr 24, 2025
1797d48
fix special case
qscqesze Apr 24, 2025
5b7e382
add log
qscqesze Apr 24, 2025
4ce41f2
fix code
qscqesze Apr 24, 2025
4d9b4ac
fix code
qscqesze Apr 24, 2025
31c4b4b
fix code
qscqesze Apr 24, 2025
85e4d11
fix code
qscqesze Apr 24, 2025
0424185
fix code
qscqesze Apr 24, 2025
dbb0e9d
fix code
qscqesze Apr 24, 2025
e831284
fix code
qscqesze Apr 24, 2025
7ae14db
fix code
qscqesze Apr 24, 2025
76f26fb
fix code
qscqesze Apr 24, 2025
6456971
fix code
qscqesze Apr 24, 2025
6663b4e
fix code
qscqesze Apr 24, 2025
9da4960
Merge branch 'main' into qingjun/vl
qscqesze Apr 25, 2025
c1aafbb
fix code
qscqesze Apr 25, 2025
cec44db
fix code
qscqesze Apr 25, 2025
4dd3781
fix code
qscqesze Apr 25, 2025
c5e910e
fix function
qscqesze Apr 25, 2025
79b0baf
fix bug
qscqesze Apr 25, 2025
ea79083
fix bug
qscqesze Apr 25, 2025
4048fb0
fix code
qscqesze Apr 25, 2025
686c56c
fix code
qscqesze Apr 25, 2025
3c18bf5
fix code
qscqesze Apr 25, 2025
656b516
fix code
qscqesze Apr 25, 2025
7b5ff17
fix code
qscqesze Apr 25, 2025
e57d009
fix code
qscqesze Apr 25, 2025
c4f3cef
fix code
qscqesze Apr 25, 2025
6ba30ea
fix code
qscqesze Apr 28, 2025
273278e
add for test
qscqesze Apr 28, 2025
77e2824
add for test
qscqesze Apr 28, 2025
2494e16
fix code
qscqesze Apr 28, 2025
3990d65
fix bug
qscqesze Apr 28, 2025
c4f1225
fix code
qscqesze Apr 28, 2025
80060cb
fix code
qscqesze Apr 28, 2025
2c39391
fix code
qscqesze Apr 28, 2025
b6f26cf
fix code
qscqesze Apr 28, 2025
25d5d66
add test
qscqesze Apr 28, 2025
9ba0a93
fix code
qscqesze Apr 28, 2025
7170f4a
fix code
qscqesze Apr 28, 2025
baae72e
fix code
qscqesze Apr 28, 2025
a1c79ab
fix test
qscqesze Apr 28, 2025
72d7c02
fix code
qscqesze Apr 28, 2025
8f0aee4
fix code
qscqesze Apr 28, 2025
Files changed
13 changes: 13 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -446,6 +446,19 @@
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner,
),
"minimax_vl_01": VLMTestInfo(
models=["MiniMaxAI/MiniMax-VL-01"],
prompt_formatter=lambda img_prompt: f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>", # noqa: E501
img_idx_to_prompt=lambda _: "<image>",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
max_model_len=8192,
max_num_seqs=4,
dtype="bfloat16",
hf_output_post_proc=model_utils.minimax_vl_01_hf_output,
patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner,
auto_cls=AutoModelForImageTextToText,
marks=[large_gpu_mark(min_gb=80)],
),
"molmo": VLMTestInfo(
models=["allenai/Molmo-7B-D-0924"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
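Reviewer's note (not part of the diff): a minimal sketch of the prompt the new "minimax_vl_01" VLMTestInfo entry above builds for a single-image case, reusing its two lambdas; the question text is invented.

# Sketch only: mirrors prompt_formatter / img_idx_to_prompt from the
# "minimax_vl_01" entry above; the question string is made up.
prompt_formatter = lambda img_prompt: (
    f"<beginning_of_sentence>user: {img_prompt} assistant:<end_of_sentence>")
img_idx_to_prompt = lambda _idx: "<image>"

img_prompt = img_idx_to_prompt(0) + "What is in this image?"
print(prompt_formatter(img_prompt))
# <beginning_of_sentence>user: <image>What is in this image? assistant:<end_of_sentence>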
@@ -229,6 +229,14 @@ def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
return output_ids, output_str, out_logprobs


def minimax_vl_01_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<end_of_sentence>"):
output_str = output_str.split("<end_of_sentence>")[0]
return output_ids, output_str, out_logprobs


####### Functions for converting image assets to embeddings
def get_llava_embeddings(image_assets: _ImageAssets):
return [asset.image_embeds for asset in image_assets]
@@ -627,6 +635,17 @@ def _generate(self, *args, image_sizes=None, **kwargs):
return hf_model


def minimax_vl_01_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
orig_generate = hf_model.model.generate

def _generate(self, *args, image_sizes=None, **kwargs):
return orig_generate(*args, decode_text=False, **kwargs)

hf_model.model.generate = types.MethodType(_generate, hf_model.model)

return hf_model


def molmo_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for Molmo."""
hf_processor = hf_model.processor
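Reviewer's note: minimax_vl_01_hf_output above trims the trailing <end_of_sentence> tag from the HF runner's decoded string so it compares cleanly against vLLM's output; token ids and logprobs pass through untouched. A toy run with invented values:

hf_output = ([101, 102, 103], "A cat on a mat.<end_of_sentence>", None)
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<end_of_sentence>"):
    output_str = output_str.split("<end_of_sentence>")[0]
assert (output_ids, output_str) == ([101, 102, 103], "A cat on a mat.")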
99 changes: 99 additions & 0 deletions tests/models/multimodal/processing/test_minimax_vl_01.py
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from PIL import Image

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor

from ....conftest import _ImageAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
num_imgs: int,
):
ctx = build_model_context(
model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=(364, 364))
mm_data = {"image": [image] * num_imgs}

processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]

assert len(image_placeholders) == num_imgs


def _validate_image_prompt_replacements_one(
processor: BaseMultiModalProcessor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
) -> None:
prompt = "<image>" * num_imgs
image = Image.new("RGB", size=image_size)
mm_data = {"image": [image] * num_imgs}

try:
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs

except Exception as exc:
failed_size_excs.append((image_size, exc))


def _test_image_prompt_replacements(
processor,
*,
num_imgs: int,
image_sizes: list[ImageSize],
) -> None:

failed_size_excs = list[tuple[ImageSize, Exception]]()

for size in image_sizes:
_validate_image_prompt_replacements_one(processor, num_imgs,
failed_size_excs, size)

if failed_size_excs:
msg = "Found failing image sizes:" \
+ "\n========\n".join(f"[{size}]\n{exc}"
for size, exc in failed_size_excs)
raise AssertionError(msg)


@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context(
model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
image_sizes = [
size for w, h in image_ratios
for size in [ImageSize(w, h), ImageSize(h, w)]
]

_test_image_prompt_replacements(
processor,
num_imgs=num_imgs,
image_sizes=image_sizes,
)
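Reviewer's note: the regression test above doubles its coverage by exercising every (w, h) ratio in both orientations; the expansion in isolation:

from vllm.multimodal.parse import ImageSize

image_ratios = [(488, 183)]
image_sizes = [size for w, h in image_ratios
               for size in [ImageSize(w, h), ImageSize(h, w)]]
# One landscape and one portrait ImageSize per ratio:
# [ImageSize(488, 183), ImageSize(183, 488)]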
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -333,6 +333,8 @@ def check_available_online(
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
min_transformers_version="4.50", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
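Reviewer's note: the registry entry above only wires the architecture name to a runnable Hub checkpoint; conceptually it records something like the following (a simplified stand-in, not the actual _HfExamplesInfo API):

EXAMPLE_MODELS = {
    "MiniMaxVL01ForConditionalGeneration": {
        "default": "MiniMaxAI/MiniMax-VL-01",
        "trust_remote_code": True,  # the model ships custom code on the Hub
    },
}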
67 changes: 53 additions & 14 deletions vllm/model_executor/models/minimax_text_01.py
@@ -3,7 +3,7 @@
import copy
import math
import re
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.distributed
@@ -110,7 +110,17 @@ def _forward(
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight

weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
full_weight = self.weight.repeat(repeat_count)
weight = full_weight[:x.size(-1)]
else:
weight = self.weight[:x.size(-1)]

x = x.to(orig_dtype) * weight
return x
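Reviewer's note: when the norm weight's length disagrees with the hidden dimension, the branch above tiles (Tensor.repeat) or truncates the weight to fit. The two primitives in isolation, with illustrative sizes:

import torch

w = torch.tensor([1.0, 2.0, 3.0])
print(w.repeat(2))      # tensor([1., 2., 3., 1., 2., 3.])  tiled copy
print(w.repeat(2)[:4])  # tensor([1., 2., 3., 1.])          tile, then cut
print(w[:2])            # tensor([1., 2.])                  plain truncation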

def forward(
@@ -421,6 +431,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
_start = attn_metadata.query_start_loc[_prefill_idx]
_end = attn_metadata.query_start_loc[_prefill_idx + 1]
slot_id = state_indices_tensor[_prefill_idx]
@@ -443,6 +457,10 @@ def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
hidden.append(
self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
attn_metadata))

if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
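Reviewer's note: the `if not hidden` guard above matters because torch.concat on an empty list raises, so a step with zero prefills and zero decodes now yields an empty (0, head_dim) tensor instead of crashing. Minimal repro of the failure mode it prevents:

import torch

q = torch.randn(0, 64)   # no tokens in this step
hidden: list = []        # neither prefill nor decode appended anything

try:
    out = torch.concat(hidden, dim=0)
except RuntimeError:     # torch.cat(): expected a non-empty list of Tensors
    out = torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
assert out.shape == (0, 64)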

@@ -663,6 +681,9 @@ def __init__(
self.shared_moe = False

shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
if isinstance(shared_intermediate, list):
shared_intermediate = shared_intermediate[
layer_id] if layer_id < len(shared_intermediate) else 0
if shared_intermediate > 0:
self.shared_moe = True
self.shared_mlp = MiniMaxText01MLP(
@@ -875,6 +896,8 @@ def _clear_prefill_cache(self, attn_metadata,

slots_to_clear = []
for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_id >= len(seq_id_map):
break
seq_id = seq_id_map[_prefill_id]
if attn_metadata.context_lens_tensor[
_prefill_id] == 0 and seq_id in seq_to_slot_maps:
@@ -886,13 +909,18 @@ def _clear_prefill_cache(self, attn_metadata,
dtype=torch.long)
minimax_cache_tensors[:, slots_tensor, ...] = 0

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
intermediate_tensors=None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
@@ -901,6 +929,7 @@ def forward(self,
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []

(
minimax_cache_tensors,
state_indices_tensor,
@@ -922,15 +951,11 @@ def forward(self,
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

kv_cache_index = 0
minimax_cache_index = 0
attn_metadata.rotary_emb = self.rotary_emb
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
_caches = None
if isinstance(layer.self_attn, MiniMaxText01Attention):
_caches = kv_caches[kv_cache_index]
kv_cache_index += 1
if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
current_state_layer = minimax_cache_index
_caches = minimax_cache_params.at_layer_idx(
@@ -1009,15 +1034,20 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
batch_size)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
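Reviewer's note: exposing get_input_embeddings here is what lets a VL wrapper drive this text backbone through inputs_embeds: embed the token ids, overwrite the image-placeholder positions with vision features, and pass the merged tensor in. A toy sketch of that merge (placeholder id and shapes invented):

import torch

IMAGE_TOKEN_ID = 99  # hypothetical placeholder id, for illustration only

def merge_embeddings(input_ids, text_embeds, image_embeds):
    # Replace each placeholder row with one row of vision features.
    merged = text_embeds.clone()
    merged[input_ids == IMAGE_TOKEN_ID] = image_embeds.to(merged.dtype)
    return merged

embed = torch.nn.Embedding(128, 16)
input_ids = torch.tensor([5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7])
inputs_embeds = merge_embeddings(input_ids, embed(input_ids),
                                 torch.randn(2, 16))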

def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, self.kv_cache,
intermediate_tensors, inputs_embeds,
**kwargs)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, **kwargs)

return hidden_states

@@ -1043,8 +1073,9 @@ def make_empty_intermediate_tensors(
})

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None:
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()

def which_layer(name: str) -> int:
if "layers" in name:
@@ -1108,6 +1139,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
weight_name,
expert_id=expert_id,
shard_id=shard_id)
loaded_params.add(name)
break
else:
if is_pp_missing_parameter(name, self):
@@ -1117,6 +1149,7 @@ def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return

def is_shared_mlp_weight(name: str) -> bool:
@@ -1154,6 +1187,7 @@ def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
else:
raise AssertionError(
"MLP weight not in [gate_up_proj, down_proj]")
loaded_params.add(name)
return

def is_mha_weight(name: str) -> bool:
@@ -1170,6 +1204,7 @@ def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
MiniMaxText01LinearAttention.weight_direct_load)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return

def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
@@ -1194,6 +1229,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
if is_pp_missing_parameter(name, self):
@@ -1204,6 +1240,7 @@ def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return

def is_layer_norm_weight(name: str) -> bool:
@@ -1219,6 +1256,7 @@ def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return

def load_basic_weight(name: str, loaded_weight: torch.Tensor,
@@ -1230,6 +1268,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return

for name, loaded_weight in weights:
Expand Down Expand Up @@ -1258,4 +1297,4 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
continue

load_basic_weight(name, loaded_weight, self)
return
return loaded_params
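Reviewer's note: returning loaded_params from load_weights lets the caller verify that no parameter was silently skipped. The pattern in miniature (helper names invented):

from typing import Iterable, Set, Tuple

import torch

def load_weights_sketch(params: dict,
                        weights: Iterable[Tuple[str, torch.Tensor]]
                        ) -> Set[str]:
    loaded_params: Set[str] = set()
    for name, loaded_weight in weights:
        param = params.get(name)
        if param is None:
            continue  # e.g. the parameter lives on another pipeline stage
        with torch.no_grad():
            param.copy_(loaded_weight)
        loaded_params.add(name)
    return loaded_params

# Caller-side coverage check:
# assert loaded_params == set(params), "unloaded weights detected"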