
Commit b720817

[Bugfix][Misc] Add a defensive check before importing triton
Signed-off-by: Mengqing Cao <cmq0113@163.com>
1 parent f690372 commit b720817

25 files changed: +301 / -99 lines changed
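
For reference, the HAS_TRITON flag used throughout this commit is imported from vllm.triton_utils. Below is a minimal sketch of how such a flag can be computed without eagerly importing triton; it illustrates the idea only and is not necessarily vLLM's exact implementation.

# Sketch only: assumes importlib-based detection; vllm.triton_utils.importing
# may be implemented differently.
from importlib import util

# find_spec() reports whether the package is installed without importing it,
# so platforms that ship without triton never hit an ImportError here.
HAS_TRITON = util.find_spec("triton") is not None

if HAS_TRITON:
    import triton  # safe: the module is known to be present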

benchmarks/kernels/benchmark_lora.py

Lines changed: 8 additions & 2 deletions

@@ -17,8 +17,14 @@
 from utils import ArgPool, Bench, CudaGraphBenchParams
 from weight_shapes import WEIGHT_SHAPES

-from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
-from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
+                                          lora_shrink)
+    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
+                                                _LORA_B_PTR_DICT)
+
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())

benchmarks/kernels/benchmark_moe.py

Lines changed: 6 additions & 1 deletion

@@ -10,7 +10,12 @@

 import ray
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

benchmarks/kernels/benchmark_rmsnorm.py

Lines changed: 6 additions & 1 deletion

@@ -4,7 +4,12 @@
 from typing import Optional, Union

 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py

Lines changed: 6 additions & 1 deletion

@@ -6,7 +6,12 @@
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions

tests/kernels/test_flashmla.py

Lines changed: 5 additions & 1 deletion

@@ -5,7 +5,11 @@

 import pytest
 import torch
-import triton
+
+from vllm.triton_utils.importing import HAS_TRITON
+
+if HAS_TRITON:
+    import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,

vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

Lines changed: 11 additions & 5 deletions

@@ -1,8 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+
+from vllm.triton_utils import triton_heuristics_decorator, triton_jit_decorator


 def blocksparse_flash_attn_varlen_fwd(
@@ -122,7 +128,7 @@ def blocksparse_flash_attn_varlen_fwd(
     return out


-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_inner(
     acc,
     l_i,
@@ -227,11 +233,11 @@ def _fwd_kernel_inner(
     return acc, l_i, m_i


-@triton.heuristics({
+@triton_heuristics_decorator({
     "M_LT_N":
     lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
 })
-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_batch_inference(
     Q,
     K,
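
The triton_jit_decorator and triton_heuristics_decorator names above are imported from vllm.triton_utils, but their definitions are outside this diff. A plausible sketch of such wrappers, assuming they forward to the real triton decorators when triton is installed and degrade to pass-throughs otherwise, is shown below; this is an assumption, not the actual vLLM implementation.

# Hypothetical fallback decorators; the real vllm.triton_utils code may differ.
from importlib import util

HAS_TRITON = util.find_spec("triton") is not None

if HAS_TRITON:
    import triton

    triton_jit_decorator = triton.jit
    triton_heuristics_decorator = triton.heuristics
else:

    def triton_jit_decorator(fn=None, **kwargs):
        # Without triton the kernel can never run; returning the function
        # unchanged keeps module import and attribute access working.
        if fn is None:
            return lambda f: f
        return fn

    def triton_heuristics_decorator(values):
        # Accepts the heuristics dict but applies nothing.
        return lambda fn: fn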

vllm/attention/ops/blocksparse_attention/utils.py

Lines changed: 5 additions & 1 deletion

@@ -8,7 +8,11 @@

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton


 class csr_matrix:

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 9 additions & 4 deletions

@@ -7,18 +7,23 @@
 # - Thomas Parnell <tpa@zurich.ibm.com>

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+from vllm.triton_utils import triton_jit_decorator

 from .prefix_prefill import context_attention_fwd


-@triton.jit
+@triton_jit_decorator
 def cdiv_fn(x, y):
     return (x + y - 1) // y


-@triton.jit
+@triton_jit_decorator
 def kernel_paged_attention_2d(
     output_ptr,  # [num_tokens, num_query_heads, head_size]
     query_ptr,  # [num_tokens, num_query_heads, head_size]

vllm/attention/ops/prefix_prefill.py

Lines changed: 7 additions & 3 deletions

@@ -4,8 +4,12 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl

 from vllm.platforms import current_platform

@@ -16,7 +20,7 @@
 # To check compatibility
 IS_TURING = current_platform.get_device_capability() == (7, 5)

-if triton.__version__ >= "2.1.0":
+if HAS_TRITON and triton.__version__ >= "2.1.0":

     @triton.jit
     def _fwd_kernel(
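
The rewritten guard relies on `and` short-circuiting: when HAS_TRITON is False, the right-hand operand is never evaluated, so triton.__version__ is never touched on platforms where the module was not imported. A self-contained illustration with stand-in values:

# Stand-in values only; in prefix_prefill.py HAS_TRITON comes from
# vllm.triton_utils and the triton name is simply never imported.
HAS_TRITON = False
triton = None  # placeholder so this snippet runs on its own

# Because `and` short-circuits, the version check on the right is skipped
# whenever HAS_TRITON is False, so the missing module is never touched.
if HAS_TRITON and triton.__version__ >= "2.1.0":
    print("compiling triton kernels")
else:
    print("skipping triton kernels")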

vllm/attention/ops/triton_decode_attention.py

Lines changed: 10 additions & 6 deletions

@@ -30,10 +30,14 @@

 import logging

-import triton
-import triton.language as tl
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import triton_jit_decorator

 is_hip_ = current_platform.is_rocm()

@@ -46,13 +50,13 @@
     "can be ignored.")


-@triton.jit
+@triton_jit_decorator
 def tanh(x):
     # Tanh is just a scaled sigmoid
     return 2 * tl.sigmoid(2 * x) - 1


-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_stage1(
     Q,
     K_Buffer,
@@ -228,7 +232,7 @@ def _decode_att_m_fwd(
     )


-@triton.jit
+@triton_jit_decorator
 def _fwd_grouped_kernel_stage1(
     Q,
     K_Buffer,
@@ -468,7 +472,7 @@ def _decode_grouped_att_m_fwd(
     )


-@triton.jit
+@triton_jit_decorator
 def _fwd_kernel_stage2(
     Mid_O,
     o,

vllm/attention/ops/triton_flash_attention.py

Lines changed: 17 additions & 11 deletions

@@ -22,46 +22,52 @@
 """

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+
+from vllm.triton_utils import triton_autotune_decorator, triton_jit_decorator

 torch_dtype: tl.constexpr = torch.float16


-@triton.jit
+@triton_jit_decorator
 def cdiv_fn(x, y):
     return (x + y - 1) // y


-@triton.jit
+@triton_jit_decorator
 def max_fn(x, y):
     return tl.math.max(x, y)


-@triton.jit
+@triton_jit_decorator
 def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
     ms = tl.arange(0, m)
     ns = tl.arange(0, n)
     return philox_offset + ms[:, None] * stride + ns[None, :]


-@triton.jit
+@triton_jit_decorator
 def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
     rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
                                   stride).to(tl.uint32)
     # TODO: use tl.randint for better performance
     return tl.rand(philox_seed, rng_offsets)


-@triton.jit
+@triton_jit_decorator
 def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
     rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
                              stride)
     rng_keep = rng_output > dropout_p
     return rng_keep


-@triton.jit
+@triton_jit_decorator
 def load_fn(block_ptr, first, second, pad):
     if first and second:
         tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
@@ -74,7 +80,7 @@ def load_fn(block_ptr, first, second, pad):
     return tensor


-@triton.jit
+@triton_jit_decorator
 def _attn_fwd_inner(
     acc,
     l_i,
@@ -208,7 +214,7 @@ def _attn_fwd_inner(
     return acc, l_i, m_i


-@triton.autotune(
+@triton_autotune_decorator(
     configs=[
         triton.Config(
             {
@@ -306,7 +312,7 @@ def _attn_fwd_inner(
     ],
     key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
 )
-@triton.jit
+@triton_jit_decorator
 def attn_fwd(
     Q,
     K,

vllm/attention/ops/triton_merge_attn_states.py

Lines changed: 9 additions & 3 deletions

@@ -2,8 +2,14 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+
+from vllm.triton_utils import triton_jit_decorator


 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
@@ -35,7 +41,7 @@ def merge_attn_states(
     )


-@triton.jit
+@triton_jit_decorator
 def merge_attn_states_kernel(
     output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
     output_lse,  # [NUM_HEADS, NUM_TOKENS]
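
These guards keep the modules importable when triton is missing; code that actually launches a kernel still needs to branch on HAS_TRITON at the call site. A toy usage sketch follows; the function and its eager fallback are illustrative only and are not part of this commit.

# Illustrative call-site guard; not vLLM code.
import torch

from vllm.triton_utils import HAS_TRITON


def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    if HAS_TRITON:
        # A real module would launch its triton-backed kernel here; the
        # branch keeps that path unreachable when triton is absent.
        return torch.add(x, y, alpha=alpha)
    # Portable eager fallback for platforms without triton.
    return x + alpha * y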
