@@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
         fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
+
+
+def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
+                             As: torch.Tensor, Bs: torch.Tensor, block_size,
+                             output_dtype):
+    """Matrix multiplication with block-wise quantization using native
+    torch. It is agnostic to the quantized input data type and can be used
+    for both int8 and fp8.
+
+    It takes two quantized input tensors `A` and `B` with their per-block
+    scales `As` and `Bs` (float32). The output is returned in the
+    specified `output_dtype`.
+    """
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+
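+    # Flatten any leading batch dims so A and As are 2-D: (M, K) and
+    # (M, k_tiles), respectively.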
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
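+    # Accumulate in float32; cast to output_dtype only at the end.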
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
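+    # Pre-slice tile views: A along K, B along both N and K, C along N.
+    # The C views alias the output buffer, so the loop accumulates in place.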
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[
+            j * block_n:min((j + 1) * block_n, N),
+            i * block_k:min((i + 1) * block_k, K),
+        ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
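+    # Dequantize on the fly: each tile product is scaled by the product of
+    # its per-block activation and weight scales before accumulation.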
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
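
For reference, a minimal sketch of how the new helper might be exercised on the int8 path; the shapes, block size, and output dtype below are illustrative assumptions, not values from this commit:

import torch

# Illustrative shapes and 128x128 blocks (assumptions for this sketch).
M, N, K = 4, 256, 512
block_n, block_k = 128, 128

A = torch.randint(-128, 128, (M, K), dtype=torch.int8)
B = torch.randint(-128, 128, (N, K), dtype=torch.int8)
# One float32 scale per (row, K-block) of A and per (N-block, K-block) of B.
As = torch.rand(M, (K + block_k - 1) // block_k, dtype=torch.float32)
Bs = torch.rand((N + block_n - 1) // block_n,
                (K + block_k - 1) // block_k,
                dtype=torch.float32)

out = native_w8a8_block_matmul(A, B, As, Bs, [block_n, block_k],
                               torch.float16)
assert out.shape == (M, N)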