
Commit 7e057cc

Add the needed changes to support AVX512VNNI
1 parent 91b5f68 commit 7e057cc

File tree

3 files changed (+10, -14 lines)

Makefile.am

Lines changed: 1 addition & 1 deletion

@@ -171,7 +171,7 @@ noinst_LTLIBRARIES += libtesseract_avx512.la
 endif
 
 if HAVE_AVX512VNNI
-libtesseract_avx512vnni_la_CXXFLAGS = -march=icelake-client
+libtesseract_avx512vnni_la_CXXFLAGS = -mavx512vnni -mavx512vl
 libtesseract_avx512vnni_la_CXXFLAGS += -I$(top_srcdir)/src/ccutil
 libtesseract_avx512vnni_la_SOURCES = src/arch/intsimdmatrixavx512vnni.cpp
 libtesseract_la_LIBADD += libtesseract_avx512vnni.la

configure.ac

Lines changed: 1 addition & 1 deletion

@@ -157,7 +157,7 @@ case "${host_cpu}" in
 AC_DEFINE([HAVE_AVX512F], [1], [Enable AVX512F instructions])
 fi
 
-AX_CHECK_COMPILE_FLAG([-march=icelake-client], [avx512vnni=true], [avx512vnni=false], [$WERROR])
+AX_CHECK_COMPILE_FLAG([-mavx512vnni], [avx512vnni=true], [avx512vnni=false], [$WERROR])
 AM_CONDITIONAL([HAVE_AVX512VNNI], $avx512vnni)
 if $avx512vnni; then
 AC_DEFINE([HAVE_AVX512VNNI], [1], [Enable AVX512VNNI instructions])
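Note on the flag change: -march=icelake-client enables the entire Ice Lake client ISA (and pins the tuning model), while -mavx512vnni -mavx512vl turns on only the two extensions this translation unit actually needs, so configure now only has to probe for -mavx512vnni. As a rough standalone sanity check (not part of this commit; the file name is made up), a file built with those flags should see both feature macros that intsimdmatrixavx512vnni.cpp tests:

// probe_vnni.cpp - hypothetical sanity check, build with:
//   g++ -mavx512vnni -mavx512vl probe_vnni.cpp -o probe_vnni
#include <cstdio>

int main() {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
  // Both macros are defined when the compiler accepts the two -m flags,
  // which is exactly the condition the source file now checks.
  std::puts("AVX512VNNI and AVX512VL enabled at compile time");
#else
  std::puts("AVX512VNNI/VL not enabled; adjust compiler flags");
#endif
  return 0;
}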

src/arch/intsimdmatrixavx512vnni.cpp

Lines changed: 8 additions & 12 deletions

@@ -17,9 +17,9 @@
 
 #include "intsimdmatrix.h"
 
-#if !defined(__AVX2__)
+#if !defined(__AVX512VNNI__) || !defined(__AVX512VL__)
 # if defined(__i686__) || defined(__x86_64__)
-# error Implementation only for AVX2 capable architectures
+# error Implementation only for AVX512VNNI capable architectures
 # endif
 #else
 # include <immintrin.h>

@@ -73,16 +73,12 @@ static inline void MultiplyGroup(const __m256i &rep_input, const __m256i &ones,
 // Normalize the signs on rep_input, weights, so weights is always +ve.
 reps = _mm256_sign_epi8(rep_input, weights);
 weights = _mm256_sign_epi8(weights, weights);
-// Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
-// with adjacent pairs added.
-weights = _mm256_maddubs_epi16(weights, reps);
-// Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
-// with adjacent pairs added. What we really want is a horizontal add of
-// 16+16=32 bit result, but there is no such instruction, so multiply by
-// 16-bit ones instead. It is probably faster than all the sign-extending,
-// permuting and adding that would otherwise be required.
-weights = _mm256_madd_epi16(weights, ones);
-result = _mm256_add_epi32(result, weights);
+
+// VNNI instruction. It replaces 3 AVX2 instructions:
+//weights = _mm256_maddubs_epi16(weights, reps);
+//weights = _mm256_madd_epi16(weights, ones);
+//result = _mm256_add_epi32(result, weights);
+result = _mm256_dpbusd_epi32(result, weights, reps);
 }
 
 // Load 64 bits into the bottom of a 128bit register.
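The deleted comments describe what the AVX2 sequence was doing: _mm256_maddubs_epi16 multiplies 32 unsigned-by-signed byte pairs and adds adjacent pairs into 16 saturated 16-bit sums, _mm256_madd_epi16 with a vector of 16-bit ones widens those into 8 32-bit sums, and _mm256_add_epi32 folds them into the accumulator. The VNNI instruction VPDPBUSD (_mm256_dpbusd_epi32) performs the same widen-multiply-accumulate over each 4-byte group in one step, without the 16-bit saturation that _mm256_maddubs_epi16 can hit. The standalone sketch below (not from this commit; the file name and test values are invented) compares the two paths on one register; it needs the -mavx2 -mavx512vnni -mavx512vl flags and a CPU with AVX512VNNI to run:

// vnni_vs_avx2.cpp - hypothetical comparison of the old and new paths.
// Build with: g++ -mavx2 -mavx512vnni -mavx512vl vnni_vs_avx2.cpp
#include <immintrin.h>
#include <cstdio>
#include <cstring>

int main() {
  // 32 small non-negative weights and 32 signed inputs, mimicking the state
  // in MultiplyGroup after the sign-normalization step.
  unsigned char w[32];
  signed char x[32];
  for (int i = 0; i < 32; ++i) {
    w[i] = static_cast<unsigned char>(i + 1);
    x[i] = static_cast<signed char>((i % 7) - 3);
  }
  __m256i weights = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(w));
  __m256i reps = _mm256_loadu_si256(reinterpret_cast<const __m256i *>(x));
  __m256i ones = _mm256_set1_epi16(1);
  __m256i acc = _mm256_set1_epi32(1000);  // arbitrary running accumulator

  // Old AVX2 path: pairwise 8-bit products -> 16-bit sums -> 32-bit sums -> add.
  __m256i t = _mm256_maddubs_epi16(weights, reps);
  t = _mm256_madd_epi16(t, ones);
  __m256i avx2_result = _mm256_add_epi32(acc, t);

  // New path: one fused multiply-accumulate of each 4-byte group.
  __m256i vnni_result = _mm256_dpbusd_epi32(acc, weights, reps);

  int a[8], b[8];
  std::memcpy(a, &avx2_result, sizeof(a));
  std::memcpy(b, &vnni_result, sizeof(b));
  for (int i = 0; i < 8; ++i) {
    std::printf("lane %d: avx2=%d vnni=%d %s\n", i, a[i], b[i],
                a[i] == b[i] ? "match" : "MISMATCH");
  }
  return 0;
}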
