Skip to content

Commit dc56f11

Browse files
imkerolk-chen
authored and committed
[Bugfix] Fix Qwen2.5-Omni M-RoPE position ids generation (vllm-project#16878)
Signed-off-by: imkero <kerorek@outlook.com>
1 parent d12ce9c commit dc56f11

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

vllm/model_executor/layers/rotary_embedding.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,7 @@ def _omni_get_input_positions_tensor(
12091209
video_token_id = thinker_config.video_token_index
12101210
audio_start_token_id = thinker_config.audio_start_token_id
12111211
audio_end_token_id = thinker_config.audio_end_token_id
1212+
vision_start_token_id = thinker_config.vision_start_token_id
12121213
vision_end_token_id = thinker_config.vision_end_token_id
12131214
seconds_per_chunk = thinker_config.seconds_per_chunk
12141215
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
@@ -1238,8 +1239,15 @@ def _omni_get_input_positions_tensor(
12381239
if src_item[idx] not in [
12391240
audio_token_id, video_token_id, image_token_id
12401241
]:
1241-
if src_item[idx] == vision_end_token_id and use_audio_in_video:
1242-
start_idx -= 1
1242+
if use_audio_in_video and idx > 0:
1243+
if src_item[idx] == vision_end_token_id and \
1244+
src_item[idx - 1] == audio_end_token_id:
1245+
# processing the <|audio_eos|> before <|vision_eos|>
1246+
start_idx -= 1
1247+
elif src_item[idx] == audio_start_token_id and \
1248+
src_item[idx - 1] == vision_start_token_id:
1249+
# processing the <|audio_bos|> after <|vision_eos|>
1250+
start_idx -= 1
12431251
new_src_item.append(src_item[idx])
12441252
llm_pos_ids = torch.tensor([start_idx],
12451253
dtype=torch.long).expand(3, -1)
@@ -1297,11 +1305,6 @@ def _omni_get_input_positions_tensor(
12971305
tokens_per_second).long()
12981306
t_index_split_chunk = cls._split_list_into_ranges(
12991307
t_index, t_ntoken_per_chunk)
1300-
new_src_item.extend([audio_start_token_id])
1301-
start_idx -= 1
1302-
llm_pos_ids_list.extend([
1303-
torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
1304-
] * 1)
13051308
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
13061309
pure_audio_len = place_num - 2
13071310
added_audio_len = 0
@@ -1312,21 +1315,21 @@ def _omni_get_input_positions_tensor(
13121315
new_src_item.extend([video_token_id] *
13131316
vision_ntoken_per_chunk)
13141317
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
1315-
start_idx + 1, video_idx, spatial_merge_size, t_chunk,
1318+
start_idx, video_idx, spatial_merge_size, t_chunk,
13161319
grid_hs, grid_ws).split(1, dim=1)
13171320
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
13181321
new_src_item.extend(
13191322
min(t_ntoken_per_chunk, pure_audio_len -
13201323
added_audio_len) * [audio_token_id])
13211324
audio_start_idx = start_idx if len(
13221325
audio_llm_pos_ids_list
1323-
) == 0 else audio_llm_pos_ids_list[-1][0].item()
1326+
) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
13241327
if min(t_ntoken_per_chunk,
13251328
pure_audio_len - added_audio_len) > 0:
13261329
audio_llm_pos_ids_list = (torch.arange(
13271330
min(t_ntoken_per_chunk, pure_audio_len -
13281331
added_audio_len)).expand(3, -1) +
1329-
audio_start_idx + 1).split(
1332+
audio_start_idx).split(
13301333
1, dim=1)
13311334
else:
13321335
audio_llm_pos_ids_list = []
@@ -1341,11 +1344,6 @@ def _omni_get_input_positions_tensor(
13411344
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
13421345
1, dim=1)
13431346
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
1344-
llm_pos_ids_list.extend([
1345-
torch.tensor(
1346-
[llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
1347-
] * 1)
1348-
new_src_item.extend([audio_end_token_id])
13491347
audio_idx += 1
13501348
video_idx += 1
13511349
# move to the next token

0 commit comments

Comments (0)