@@ -1209,6 +1209,7 @@ def _omni_get_input_positions_tensor(
         video_token_id = thinker_config.video_token_index
         audio_start_token_id = thinker_config.audio_start_token_id
         audio_end_token_id = thinker_config.audio_end_token_id
+        vision_start_token_id = thinker_config.vision_start_token_id
         vision_end_token_id = thinker_config.vision_end_token_id
         seconds_per_chunk = thinker_config.seconds_per_chunk
         spatial_merge_size = thinker_config.vision_config.spatial_merge_size
@@ -1238,8 +1239,15 @@ def _omni_get_input_positions_tensor(
             if src_item[idx] not in [
                     audio_token_id, video_token_id, image_token_id
             ]:
-                if src_item[idx] == vision_end_token_id and use_audio_in_video:
-                    start_idx -= 1
+                if use_audio_in_video and idx > 0:
+                    if src_item[idx] == vision_end_token_id and \
+                            src_item[idx - 1] == audio_end_token_id:
+                        # processing the <|audio_eos|> before <|vision_eos|>
+                        start_idx -= 1
+                    elif src_item[idx] == audio_start_token_id and \
+                            src_item[idx - 1] == vision_start_token_id:
+                        # processing the <|audio_bos|> after <|vision_bos|>
+                        start_idx -= 1
                 new_src_item.append(src_item[idx])
                 llm_pos_ids = torch.tensor([start_idx],
                                            dtype=torch.long).expand(3, -1)
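What the new branch buys: when `use_audio_in_video` is set, the `<|vision_bos|><|audio_bos|>` prefix and the `<|audio_eos|><|vision_eos|>` suffix around a video are each meant to occupy a single M-RoPE position, so the second token of each pair reuses the previous token's index instead of advancing. A minimal sketch of that bookkeeping (the token ids below are made up for illustration; the real ones come from `thinker_config`):

```python
import torch

# Hypothetical token ids, for illustration only.
VISION_BOS, AUDIO_BOS, AUDIO_EOS, VISION_EOS, TEXT = 1, 2, 3, 4, 5

def toy_positions(src_item, use_audio_in_video=True):
    """Assign one (t, h, w) position per token, collapsing each paired
    delimiter onto a shared index, mirroring the patched branch above."""
    llm_pos_ids_list = []
    for idx, tok in enumerate(src_item):
        # Next free position index.
        start_idx = (llm_pos_ids_list[-1].max().item() + 1
                     if llm_pos_ids_list else 0)
        if use_audio_in_video and idx > 0:
            if tok == VISION_EOS and src_item[idx - 1] == AUDIO_EOS:
                start_idx -= 1  # <|vision_eos|> shares <|audio_eos|>'s index
            elif tok == AUDIO_BOS and src_item[idx - 1] == VISION_BOS:
                start_idx -= 1  # <|audio_bos|> shares <|vision_bos|>'s index
        llm_pos_ids_list.append(
            torch.tensor([start_idx], dtype=torch.long).expand(3, -1))
    return torch.cat(llm_pos_ids_list, dim=1)

pos = toy_positions([TEXT, VISION_BOS, AUDIO_BOS, AUDIO_EOS, VISION_EOS, TEXT])
print(pos[0].tolist())  # [0, 1, 1, 2, 2, 3] -- each delimiter pair collapses
```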
@@ -1297,11 +1305,6 @@ def _omni_get_input_positions_tensor(
                            tokens_per_second).long()
                 t_index_split_chunk = cls._split_list_into_ranges(
                     t_index, t_ntoken_per_chunk)
-                new_src_item.extend([audio_start_token_id])
-                start_idx -= 1
-                llm_pos_ids_list.extend([
-                    torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
-                ] * 1)
                 place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                 pure_audio_len = place_num - 2
                 added_audio_len = 0
@@ -1312,21 +1315,21 @@ def _omni_get_input_positions_tensor(
                     new_src_item.extend([video_token_id] *
                                         vision_ntoken_per_chunk)
                     vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
-                        start_idx + 1, video_idx, spatial_merge_size, t_chunk,
+                        start_idx, video_idx, spatial_merge_size, t_chunk,
                         grid_hs, grid_ws).split(1, dim=1)
                     llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                     new_src_item.extend(
                         min(t_ntoken_per_chunk, pure_audio_len -
                             added_audio_len) * [audio_token_id])
                     audio_start_idx = start_idx if len(
                         audio_llm_pos_ids_list
-                    ) == 0 else audio_llm_pos_ids_list[-1][0].item()
+                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                     if min(t_ntoken_per_chunk,
                            pure_audio_len - added_audio_len) > 0:
                         audio_llm_pos_ids_list = (torch.arange(
                             min(t_ntoken_per_chunk, pure_audio_len -
                                 added_audio_len)).expand(3, -1) +
-                                                  audio_start_idx + 1).split(
+                                                  audio_start_idx).split(
                                                       1, dim=1)
                     else:
                         audio_llm_pos_ids_list = []
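The three `+ 1` moves in this hunk are one coordinated off-by-one fix: with the explicit `<|audio_bos|>` emission gone (previous hunk), the first vision/audio chunk now starts at `start_idx` itself, and `audio_start_idx` is redefined to hold the next unassigned position rather than the last assigned one. A toy check of the new chunk arithmetic, with made-up values:

```python
import torch

# Illustration only: start_idx and chunk size are made up.
start_idx = 10          # next free position after the <|audio_bos|> token
t_ntoken_per_chunk = 4

# First chunk: the list is empty, so positions begin at start_idx itself.
audio_llm_pos_ids_list = []
audio_start_idx = (start_idx if len(audio_llm_pos_ids_list) == 0
                   else audio_llm_pos_ids_list[-1][0].item() + 1)
chunk = (torch.arange(t_ntoken_per_chunk).expand(3, -1) +
         audio_start_idx).split(1, dim=1)
assert chunk[0][0].item() == 10 and chunk[-1][0].item() == 13

# Later chunks continue from one past the last assigned position.
audio_llm_pos_ids_list = list(chunk)
audio_start_idx = audio_llm_pos_ids_list[-1][0].item() + 1
assert audio_start_idx == 14
```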
@@ -1341,11 +1344,6 @@ def _omni_get_input_positions_tensor(
                             3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                                 1, dim=1)
                     llm_pos_ids_list.extend(audio_llm_pos_ids_list)
-                llm_pos_ids_list.extend([
-                    torch.tensor(
-                        [llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
-                ] * 1)
-                new_src_item.extend([audio_end_token_id])
                 audio_idx += 1
                 video_idx += 1
             # move to the next token
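Dropping this block is the counterpart of the `<|audio_bos|>` removal above: both audio delimiters now reach the generic non-placeholder branch at the top of the token loop, which appends them to `new_src_item` and assigns their (shared) positions exactly once, instead of emitting them a second time here.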