@@ -17,9 +17,9 @@
 
 #include "intsimdmatrix.h"
 
-#if !defined(__AVX2__)
+#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__)
 # if defined(__i686__) || defined(__x86_64__)
-# error Implementation only for AVX2 capable architectures
+# error Implementation only for AVX512VNNI capable architectures
 # endif
 #else
 # include <immintrin.h>
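Both feature macros must be defined for this file to compile, which with GCC/Clang means building it with -mavx512vnni -mavx512vl (or an -march= setting that implies them). For illustration only, a minimal sketch of the matching run-time check using the GCC/Clang builtin __builtin_cpu_supports; the patch does not show Tesseract's own dispatch code, so this is just an assumed usage pattern:

    #include <cstdio>

    // Returns true when the 256-bit VNNI code path guarded above is usable
    // on the current CPU. "avx512vnni" and "avx512vl" are the GCC/Clang
    // feature-string names for the required extensions.
    static bool HaveVnni256() {
      return __builtin_cpu_supports("avx512vnni") &&
             __builtin_cpu_supports("avx512vl");
    }

    int main() {
      std::printf("256-bit VNNI path usable: %s\n",
                  HaveVnni256() ? "yes" : "no");
    }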
@@ -73,16 +73,12 @@ static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones,
   // Normalize the signs on rep_input, weights, so weights is always +ve.
   reps = _mm256_sign_epi8(rep_input, weights);
   weights = _mm256_sign_epi8(weights, weights);
-  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
-  // with adjacent pairs added.
-  weights = _mm256_maddubs_epi16(weights, reps);
-  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
-  // with adjacent pairs added. What we really want is a horizontal add of
-  // 16+16=32 bit result, but there is no such instruction, so multiply by
-  // 16-bit ones instead. It is probably faster than all the sign-extending,
-  // permuting and adding that would otherwise be required.
-  weights = _mm256_madd_epi16(weights, ones);
-  result = _mm256_add_epi32(result, weights);
+
+  // The VNNI dot-product instruction replaces three AVX2 instructions:
+  //   weights = _mm256_maddubs_epi16(weights, reps);
+  //   weights = _mm256_madd_epi16(weights, ones);
+  //   result = _mm256_add_epi32(result, weights);
+  result = _mm256_dpbusd_epi32(result, weights, reps);
 }
 
 // Load 64 bits into the bottom of a 128bit register.
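_mm256_dpbusd_epi32(src, a, b) multiplies the unsigned bytes of a by the signed bytes of b, sums each group of four products into a 32-bit lane, and adds the result to src, so it matches the maddubs/madd/add sequence whenever the 16-bit intermediate of _mm256_maddubs_epi16 does not saturate; the VNNI instruction accumulates at full 32-bit width and has no such saturation step. A self-contained sanity check of that equivalence, assuming a build with -mavx2 -mavx512vnni -mavx512vl and a VNNI-capable CPU; the test values are made up:

    #include <cstdint>
    #include <cstdio>
    #include <immintrin.h>

    int main() {
      alignas(32) int8_t w[32], r[32];
      for (int i = 0; i < 32; ++i) {
        w[i] = static_cast<int8_t>(i - 16);      // signed weights
        r[i] = static_cast<int8_t>(3 * i - 40);  // signed inputs
      }
      __m256i weights = _mm256_load_si256(reinterpret_cast<const __m256i *>(w));
      __m256i rep_input = _mm256_load_si256(reinterpret_cast<const __m256i *>(r));
      __m256i ones = _mm256_set1_epi16(1);

      // Same sign normalization as MultiplyGroup: make weights non-negative
      // and fold their signs into the inputs.
      __m256i reps = _mm256_sign_epi8(rep_input, weights);
      weights = _mm256_sign_epi8(weights, weights);

      // Old AVX2 path: two multiply-adds plus an add.
      __m256i avx2 = _mm256_maddubs_epi16(weights, reps);
      avx2 = _mm256_madd_epi16(avx2, ones);
      __m256i result_avx2 = _mm256_add_epi32(_mm256_setzero_si256(), avx2);

      // New path: one VNNI dot-product accumulating straight into 32 bits.
      __m256i result_vnni =
          _mm256_dpbusd_epi32(_mm256_setzero_si256(), weights, reps);

      alignas(32) int32_t a[8], b[8];
      _mm256_store_si256(reinterpret_cast<__m256i *>(a), result_avx2);
      _mm256_store_si256(reinterpret_cast<__m256i *>(b), result_vnni);
      for (int i = 0; i < 8; ++i) {
        std::printf("%d: avx2=%d vnni=%d %s\n", i, a[i], b[i],
                    a[i] == b[i] ? "ok" : "MISMATCH");
      }
    }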