
Commit 070d1f4

mgoin authored and dbyoung18 committed
Add missing rocm_skinny_gemms kernel test to CI (vllm-project#17060)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 5d3e793 commit 070d1f4

File tree

5 files changed: +62 -65 lines changed


tests/kernels/quant_utils.py

Lines changed: 60 additions & 0 deletions
@@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
         fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
+
+
+def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
+                             As: torch.Tensor, Bs: torch.Tensor, block_size,
+                             output_dtype):
+    """This function performs matrix multiplication with block-wise
+    quantization using native torch.
+    It is agnostic to the input data type and can be used for both int8 and
+    fp8 data types.
+
+    It takes two input tensors `A` and `B` (int8) with scales `As` and
+    `Bs` (float32).
+    The output is returned in the specified `output_dtype`.
+    """
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[
+            j * block_n:min((j + 1) * block_n, N),
+            i * block_k:min((i + 1) * block_k, K),
+        ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
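
For orientation, below is a minimal usage sketch of the new helper. It is not part of the commit: the shapes, block size, scale values, and output dtype are illustrative assumptions, and the import assumes the vLLM repository root is on the Python path.

import torch

from tests.kernels.quant_utils import native_w8a8_block_matmul

# Hypothetical shapes: M x K activations, N x K weights, 128x128 quantization blocks.
M, N, K = 4, 256, 512
block_size = [128, 128]  # (block_n, block_k)

A = torch.randint(-128, 127, (M, K), dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), dtype=torch.int8)

# One float32 scale per (row, K-block) of A and per (N-block, K-block) of B.
k_tiles = (K + block_size[1] - 1) // block_size[1]
n_tiles = (N + block_size[0] - 1) // block_size[0]
As = torch.rand(M, k_tiles, dtype=torch.float32)
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32)

out = native_w8a8_block_matmul(A, B, As, Bs, block_size, torch.float16)
print(out.shape)  # torch.Size([4, 256])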

tests/kernels/quantization/test_block_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from tests.kernels.utils_block import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe

tests/kernels/quantization/test_block_int8.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import pytest
 import torch
 
-from tests.kernels.utils_block import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe

tests/kernels/utils_block.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments