@@ -219,6 +219,11 @@ def __init__(
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
         self.num_reqs_paddings = _get_req_paddings(
            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
+
+        max_token_idx = self.max_num_reqs * self.max_model_len - 1
+        # if max token idx exceeds int32 max, use int64 to avoid overflow
+        self.token_indices_dtype = np.int32 \
+            if max_token_idx <= np.iinfo(np.int32).max else np.int64

     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
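A minimal, self-contained sketch of the dtype selection this hunk adds. The helper name `pick_token_indices_dtype` and the example request/context sizes are hypothetical, introduced only for illustration; the real values come from the engine config:

    import numpy as np

    def pick_token_indices_dtype(max_num_reqs: int, max_model_len: int) -> type:
        # The largest flattened token index is max_num_reqs * max_model_len - 1.
        # Python ints are arbitrary precision, so this product cannot itself
        # overflow; the question is only which NumPy dtype can hold it.
        max_token_idx = max_num_reqs * max_model_len - 1
        return np.int32 if max_token_idx <= np.iinfo(np.int32).max else np.int64

    # 256 reqs x 8192-token contexts -> 2_097_151, well within int32.
    assert pick_token_indices_dtype(256, 8192) is np.int32
    # 1024 reqs x 10M-token contexts -> ~1e10, which needs int64.
    assert pick_token_indices_dtype(1024, 10_000_000) is np.int64

Computing the bound once in `__init__` means the hot path in `_prepare_inputs` below only pays for the wider dtype when the configuration actually requires it.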
@@ -457,8 +462,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
         # where M is the max_model_len.
+        # For long context, may need to cast to int64 to avoid overflow
         token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+                         req_indices.astype(self.token_indices_dtype) *
+                         self.input_batch.token_ids_cpu.shape[1])

         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large