
Commit d02c4a1

Add int64 casting for TPU worker
1 parent 7c1a09c commit d02c4a1

File tree

1 file changed: +8 −1 lines changed


vllm/v1/worker/tpu_model_runner.py

Lines changed: 8 additions & 1 deletion
@@ -219,6 +219,11 @@ def __init__(
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
+
+        max_token_idx = self.max_num_reqs * self.max_model_len - 1
+        # if max token idx exceeds int32 max, use int64 to avoid overflow
+        self.token_indices_dtype = np.int32 \
+            if max_token_idx <= np.iinfo(np.int32).max else np.int64

     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
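
The guard added above exists because the flat token index can outgrow int32. Below is a minimal standalone sketch of the failure mode and of the same dtype check; the sizes (max_num_reqs = 256, max_model_len = 10_000_000) are hypothetical values chosen for illustration, not vLLM defaults:

import numpy as np

# Hypothetical sizes, chosen so the product crosses the int32 limit.
max_num_reqs = 256
max_model_len = 10_000_000

# Largest flat index into a (max_num_reqs, max_model_len) token table.
max_token_idx = max_num_reqs * max_model_len - 1  # 2_559_999_999

# Same guard as the commit: stay on int32 only while every index fits.
token_indices_dtype = (np.int32
                       if max_token_idx <= np.iinfo(np.int32).max
                       else np.int64)
print(token_indices_dtype)  # np.int64, since 2_559_999_999 > 2_147_483_647

# Without the cast, int32 array arithmetic wraps around to a negative value:
req_indices = np.array([255], dtype=np.int32)
wrapped = req_indices * np.int32(max_model_len)
safe = req_indices.astype(np.int64) * max_model_len
print(wrapped[0], safe[0])  # -1744967296 vs. 2550000000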
@@ -457,8 +462,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
         # where M is the max_model_len.
+        # For long context, may need to cast to int64 to avoid overflow
         token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+                         req_indices.astype(self.token_indices_dtype) *
+                         self.input_batch.token_ids_cpu.shape[1])

         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
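
To make the flattening concrete, here is a minimal sketch of the token_indices computation with toy values: M = 8 stands in for token_ids_cpu.shape[1] (the real max_model_len), and the position/request arrays are taken from the comment's own example of three requests with 2, 5, and 3 scheduled tokens:

import numpy as np

# Toy stand-ins for the arrays built in _prepare_inputs.
M = 8  # toy max_model_len, i.e. token_ids_cpu.shape[1]
positions_np = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2], dtype=np.int32)

# Same computation as the diff: cast req_indices before multiplying so the
# product request * M + position is computed in a wide-enough dtype.
# In practice the dtype comes from the __init__ check; int64 shown here.
token_indices_dtype = np.int64
token_indices = (positions_np +
                 req_indices.astype(token_indices_dtype) * M)
print(token_indices)  # [ 0  1  8  9 10 11 12 16 17 18]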
