IDWT 5x3: generalize SSE2 version for AVX2

Thanks to our macros that abstract SSE use, the functions can use
AVX2 when it is available (at compile time).

This brings an extra 23% speed improvement on bench_dwt in 64-bit builds
with AVX2 compared to SSE2.
This commit is contained in:
Even Rouault 2017-06-21 12:12:58 +02:00
parent f6e3475cc9
commit fd0dc535ad
3 changed files with 155 additions and 67 deletions

View File

@ -52,6 +52,9 @@
#ifdef __SSSE3__
#include <tmmintrin.h>
#endif
#ifdef __AVX2__
#include <immintrin.h>
#endif
#if defined(__GNUC__)
#pragma GCC poison malloc calloc realloc free
@ -63,7 +66,16 @@
#define OPJ_WS(i) v->mem[(i)*2]
#define OPJ_WD(i) v->mem[(1+(i)*2)]
#define PARALLEL_COLS_53 8
#ifdef __AVX2__
/** Number of int32 values in an AVX2 register */
#define VREG_INT_COUNT 8
#else
/** Number of int32 values in an SSE2 register */
#define VREG_INT_COUNT 4
#endif
/** Number of columns that we can process in parallel in the vertical pass */
#define PARALLEL_COLS_53 (2*VREG_INT_COUNT)
/** @name Local data structures */
/*@{*/
@ -553,19 +565,55 @@ static void opj_idwt53_h(const opj_dwt_t *dwt,
#endif
}
#if defined(__SSE2__) && !defined(STANDARD_SLOW_VERSION)
#if (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION)
/* Convenience macros to improve the readability of the formulas */
#define LOADU(x) _mm_loadu_si128((const __m128i*)(x))
#define STORE(x,y) _mm_store_si128((__m128i*)(x),(y))
#define ADD(x,y) _mm_add_epi32((x),(y))
/* Select the vector flavour at compile time: 256-bit AVX2 registers when */
/* __AVX2__ is defined, 128-bit SSE2 registers otherwise. Both branches */
/* define the same macro names, so the code using them is written once. */
/* Tested with #ifdef (like the other __AVX2__ checks in this file) so that */
/* no undefined identifier is evaluated on non-AVX2 builds. */
#ifdef __AVX2__
/* Vector register type (8 x int32) */
#define VREG __m256i
/* Broadcast an int32 constant to every lane */
#define LOAD_CST(x) _mm256_set1_epi32(x)
/* Aligned / unaligned loads */
#define LOAD(x) _mm256_load_si256((const VREG*)(x))
#define LOADU(x) _mm256_loadu_si256((const VREG*)(x))
/* Aligned / unaligned stores */
#define STORE(x,y) _mm256_store_si256((VREG*)(x),(y))
#define STOREU(x,y) _mm256_storeu_si256((VREG*)(x),(y))
/* Lane-wise int32 add, subtract and arithmetic shift right */
#define ADD(x,y) _mm256_add_epi32((x),(y))
#define SUB(x,y) _mm256_sub_epi32((x),(y))
#define SAR(x,y) _mm256_srai_epi32((x),(y))
#else
/* Vector register type (4 x int32) */
#define VREG __m128i
/* Broadcast an int32 constant to every lane */
#define LOAD_CST(x) _mm_set1_epi32(x)
/* Aligned / unaligned loads */
#define LOAD(x) _mm_load_si128((const VREG*)(x))
#define LOADU(x) _mm_loadu_si128((const VREG*)(x))
/* Aligned / unaligned stores */
#define STORE(x,y) _mm_store_si128((VREG*)(x),(y))
#define STOREU(x,y) _mm_storeu_si128((VREG*)(x),(y))
/* Lane-wise int32 add, subtract and arithmetic shift right */
#define ADD(x,y) _mm_add_epi32((x),(y))
#define SUB(x,y) _mm_sub_epi32((x),(y))
#define SAR(x,y) _mm_srai_epi32((x),(y))
#endif
#define ADD3(x,y,z) ADD(ADD(x,y),z)
#define SUB(x,y) _mm_sub_epi32((x),(y))
#define SAR(x,y) _mm_srai_epi32((x),(y))
/** Vertical inverse 5x3 wavelet transform for 8 columns, when top-most
* pixel is on even coordinate */
static void opj_idwt53_v_cas0_8cols_SSE2(
/** Copies len rows of PARALLEL_COLS_53 values from the temporary buffer back
 * into the tile, where consecutive rows are stride values apart. */
static
void opj_idwt53_v_final_memcpy(OPJ_INT32* tiledp_col,
                               const OPJ_INT32* tmp,
                               OPJ_INT32 len,
                               OPJ_INT32 stride)
{
    OPJ_INT32 row;

    for (row = 0; row < len; ++row) {
        const OPJ_INT32* src = &tmp[PARALLEL_COLS_53 * row];
        OPJ_INT32* dst = &tiledp_col[row * stride];

        /* memcpy(dst, src, PARALLEL_COLS_53 * sizeof(OPJ_INT32)) would do,
           but would be a tiny bit slower. Two vector moves cover one row:
           aligned loads from tmp (we know how it is aligned), unaligned
           stores into the tile. */
        STOREU(dst, LOAD(src));
        STOREU(dst + VREG_INT_COUNT, LOAD(src + VREG_INT_COUNT));
    }
}
/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
* 16 in AVX2, when top-most pixel is on even coordinate */
static void opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(
OPJ_INT32* tmp,
const OPJ_INT32 sn,
const OPJ_INT32 len,
@ -576,17 +624,28 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
const OPJ_INT32* in_odd = &tiledp_col[sn * stride];
OPJ_INT32 i, j;
__m128i d1c_0, d1n_0, s1n_0, s0c_0, s0n_0;
__m128i d1c_1, d1n_1, s1n_1, s0c_1, s0n_1;
const __m128i two = _mm_set1_epi32(2);
VREG d1c_0, d1n_0, s1n_0, s0c_0, s0n_0;
VREG d1c_1, d1n_1, s1n_1, s0c_1, s0n_1;
const VREG two = LOAD_CST(2);
assert(len > 1);
#if __AVX2__
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
assert(PARALLEL_COLS_53 == 8);
assert(VREG_INT_COUNT == 4);
#endif
/* Note: loads of input even/odd values must be done in an unaligned */
/* fashion. But stores in tmp can be done with aligned store, since */
/* the temporary buffer is properly aligned */
assert((size_t)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
s1n_0 = LOADU(in_even + 0);
s1n_1 = LOADU(in_even + 4);
s1n_1 = LOADU(in_even + VREG_INT_COUNT);
d1n_0 = LOADU(in_odd);
d1n_1 = LOADU(in_odd + 4);
d1n_1 = LOADU(in_odd + VREG_INT_COUNT);
/* s0n = s1n - ((d1n + 1) >> 1); <==> */
/* s0n = s1n - ((d1n + d1n + 2) >> 2); */
@ -600,29 +659,29 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
s0c_1 = s0n_1;
s1n_0 = LOADU(in_even + j * stride);
s1n_1 = LOADU(in_even + j * stride + 4);
s1n_1 = LOADU(in_even + j * stride + VREG_INT_COUNT);
d1n_0 = LOADU(in_odd + j * stride);
d1n_1 = LOADU(in_odd + j * stride + 4);
d1n_1 = LOADU(in_odd + j * stride + VREG_INT_COUNT);
/*s0n = s1n - ((d1c + d1n + 2) >> 2);*/
s0n_0 = SUB(s1n_0, SAR(ADD3(d1c_0, d1n_0, two), 2));
s0n_1 = SUB(s1n_1, SAR(ADD3(d1c_1, d1n_1, two), 2));
STORE(tmp + PARALLEL_COLS_53 * (i + 0), s0c_0);
STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 4, s0c_1);
STORE(tmp + PARALLEL_COLS_53 * (i + 0) + VREG_INT_COUNT, s0c_1);
/* d1c + ((s0c + s0n) >> 1) */
STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 0,
ADD(d1c_0, SAR(ADD(s0c_0, s0n_0), 1)));
STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 4,
STORE(tmp + PARALLEL_COLS_53 * (i + 1) + VREG_INT_COUNT,
ADD(d1c_1, SAR(ADD(s0c_1, s0n_1), 1)));
}
STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 0, s0n_0);
STORE(tmp + PARALLEL_COLS_53 * (i + 0) + 4, s0n_1);
STORE(tmp + PARALLEL_COLS_53 * (i + 0) + VREG_INT_COUNT, s0n_1);
if (len & 1) {
__m128i tmp_len_minus_1;
VREG tmp_len_minus_1;
s1n_0 = LOADU(in_even + ((len - 1) / 2) * stride);
/* tmp_len_minus_1 = s1n - ((d1n + 1) >> 1); */
tmp_len_minus_1 = SUB(s1n_0, SAR(ADD3(d1n_0, d1n_0, two), 2));
@ -631,31 +690,30 @@ static void opj_idwt53_v_cas0_8cols_SSE2(
/* The row pitch of tmp is PARALLEL_COLS_53 (asserted to be 16 with AVX2), */
/* so a hard-coded 8 would address the wrong row in AVX2 builds */
STORE(tmp + PARALLEL_COLS_53 * (len - 2),
      ADD(d1n_0, SAR(ADD(s0n_0, tmp_len_minus_1), 1)));
s1n_1 = LOADU(in_even + ((len - 1) / 2) * stride + 4);
s1n_1 = LOADU(in_even + ((len - 1) / 2) * stride + VREG_INT_COUNT);
/* tmp_len_minus_1 = s1n - ((d1n + 1) >> 1); */
tmp_len_minus_1 = SUB(s1n_1, SAR(ADD3(d1n_1, d1n_1, two), 2));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, tmp_len_minus_1);
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
tmp_len_minus_1);
/* d1n + ((s0n + tmp_len_minus_1) >> 1) */
STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 4,
STORE(tmp + PARALLEL_COLS_53 * (len - 2) + VREG_INT_COUNT,
ADD(d1n_1, SAR(ADD(s0n_1, tmp_len_minus_1), 1)));
} else {
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, ADD(d1n_0, s0n_0));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, ADD(d1n_1, s0n_1));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0,
ADD(d1n_0, s0n_0));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
ADD(d1n_1, s0n_1));
}
for (i = 0; i < len; ++i) {
memcpy(&tiledp_col[i * stride],
&tmp[PARALLEL_COLS_53 * i],
PARALLEL_COLS_53 * sizeof(OPJ_INT32));
}
opj_idwt53_v_final_memcpy(tiledp_col, tmp, len, stride);
}
/** Vertical inverse 5x3 wavelet transform for 8 columns, when top-most
* pixel is on odd coordinate */
static void opj_idwt53_v_cas1_8cols_SSE2(
/** Vertical inverse 5x3 wavelet transform for 8 columns in SSE2, or
* 16 in AVX2, when top-most pixel is on odd coordinate */
static void opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(
OPJ_INT32* tmp,
const OPJ_INT32 sn,
const OPJ_INT32 len,
@ -664,15 +722,26 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
{
OPJ_INT32 i, j;
__m128i s1_0, s2_0, dc_0, dn_0;
__m128i s1_1, s2_1, dc_1, dn_1;
const __m128i two = _mm_set1_epi32(2);
VREG s1_0, s2_0, dc_0, dn_0;
VREG s1_1, s2_1, dc_1, dn_1;
const VREG two = LOAD_CST(2);
const OPJ_INT32* in_even = &tiledp_col[sn * stride];
const OPJ_INT32* in_odd = &tiledp_col[0];
assert(len > 2);
#if __AVX2__
assert(PARALLEL_COLS_53 == 16);
assert(VREG_INT_COUNT == 8);
#else
assert(PARALLEL_COLS_53 == 8);
assert(VREG_INT_COUNT == 4);
#endif
/* Note: loads of input even/odd values must be done in an unaligned */
/* fashion. But stores in tmp can be done with aligned store, since */
/* the temporary buffer is properly aligned */
assert((size_t)tmp % (sizeof(OPJ_INT32) * VREG_INT_COUNT) == 0);
s1_0 = LOADU(in_even + stride);
/* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */
@ -680,30 +749,31 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
SAR(ADD3(LOADU(in_even + 0), s1_0, two), 2));
STORE(tmp + PARALLEL_COLS_53 * 0, ADD(LOADU(in_even + 0), dc_0));
s1_1 = LOADU(in_even + stride + 4);
s1_1 = LOADU(in_even + stride + VREG_INT_COUNT);
/* in_odd[0] - ((in_even[0] + s1 + 2) >> 2); */
dc_1 = SUB(LOADU(in_odd + 4),
SAR(ADD3(LOADU(in_even + 4), s1_1, two), 2));
STORE(tmp + PARALLEL_COLS_53 * 0 + 4, ADD(LOADU(in_even + 4), dc_1));
dc_1 = SUB(LOADU(in_odd + VREG_INT_COUNT),
SAR(ADD3(LOADU(in_even + VREG_INT_COUNT), s1_1, two), 2));
STORE(tmp + PARALLEL_COLS_53 * 0 + VREG_INT_COUNT,
ADD(LOADU(in_even + VREG_INT_COUNT), dc_1));
for (i = 1, j = 1; i < (len - 2 - !(len & 1)); i += 2, j++) {
s2_0 = LOADU(in_even + (j + 1) * stride);
s2_1 = LOADU(in_even + (j + 1) * stride + 4);
s2_1 = LOADU(in_even + (j + 1) * stride + VREG_INT_COUNT);
/* dn = in_odd[j * stride] - ((s1 + s2 + 2) >> 2); */
dn_0 = SUB(LOADU(in_odd + j * stride),
SAR(ADD3(s1_0, s2_0, two), 2));
dn_1 = SUB(LOADU(in_odd + j * stride + 4),
dn_1 = SUB(LOADU(in_odd + j * stride + VREG_INT_COUNT),
SAR(ADD3(s1_1, s2_1, two), 2));
STORE(tmp + PARALLEL_COLS_53 * i, dc_0);
STORE(tmp + PARALLEL_COLS_53 * i + 4, dc_1);
STORE(tmp + PARALLEL_COLS_53 * i + VREG_INT_COUNT, dc_1);
/* tmp[i + 1] = s1 + ((dn + dc) >> 1); */
STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 0,
ADD(s1_0, SAR(ADD(dn_0, dc_0), 1)));
STORE(tmp + PARALLEL_COLS_53 * (i + 1) + 4,
STORE(tmp + PARALLEL_COLS_53 * (i + 1) + VREG_INT_COUNT,
ADD(s1_1, SAR(ADD(dn_1, dc_1), 1)));
dc_0 = dn_0;
@ -712,43 +782,44 @@ static void opj_idwt53_v_cas1_8cols_SSE2(
s1_1 = s2_1;
}
STORE(tmp + PARALLEL_COLS_53 * i, dc_0);
STORE(tmp + PARALLEL_COLS_53 * i + 4, dc_1);
STORE(tmp + PARALLEL_COLS_53 * i + VREG_INT_COUNT, dc_1);
if (!(len & 1)) {
/*dn = in_odd[(len / 2 - 1) * stride] - ((s1 + 1) >> 1); */
dn_0 = SUB(LOADU(in_odd + (len / 2 - 1) * stride),
SAR(ADD3(s1_0, s1_0, two), 2));
dn_1 = SUB(LOADU(in_odd + (len / 2 - 1) * stride + 4),
dn_1 = SUB(LOADU(in_odd + (len / 2 - 1) * stride + VREG_INT_COUNT),
SAR(ADD3(s1_1, s1_1, two), 2));
/* tmp[len - 2] = s1 + ((dn + dc) >> 1); */
STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 0,
ADD(s1_0, SAR(ADD(dn_0, dc_0), 1)));
STORE(tmp + PARALLEL_COLS_53 * (len - 2) + 4,
STORE(tmp + PARALLEL_COLS_53 * (len - 2) + VREG_INT_COUNT,
ADD(s1_1, SAR(ADD(dn_1, dc_1), 1)));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, dn_0);
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, dn_1);
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT, dn_1);
} else {
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 0, ADD(s1_0, dc_0));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + 4, ADD(s1_1, dc_1));
STORE(tmp + PARALLEL_COLS_53 * (len - 1) + VREG_INT_COUNT,
ADD(s1_1, dc_1));
}
for (i = 0; i < len; ++i) {
memcpy(&tiledp_col[i * stride],
&tmp[PARALLEL_COLS_53 * i],
PARALLEL_COLS_53 * sizeof(OPJ_INT32));
}
opj_idwt53_v_final_memcpy(tiledp_col, tmp, len, stride);
}
#undef VREG
#undef LOAD_CST
#undef LOADU
#undef LOAD
#undef STORE
#undef STOREU
#undef ADD
#undef ADD3
#undef SUB
#undef SAR
#endif /* defined(__SSE2__) && !defined(STANDARD_SLOW_VERSION) */
#endif /* (defined(__SSE2__) || defined(__AVX2__)) && !defined(STANDARD_SLOW_VERSION) */
#if !defined(STANDARD_SLOW_VERSION)
/** Vertical inverse 5x3 wavelet transform for one column, when top-most
@ -873,11 +944,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
if (dwt->cas == 0) {
/* If len == 1, unmodified value */
#if __SSE2__
#if (defined(__SSE2__) || defined(__AVX2__))
if (len > 1 && nb_cols == PARALLEL_COLS_53) {
/* Same as below general case, except that thanks to SSE2 */
/* we can efficently process 8 columns in parallel */
opj_idwt53_v_cas0_8cols_SSE2(dwt->mem, sn, len, tiledp_col, stride);
/* Same as below general case, except that thanks to SSE2/AVX2 */
/* we can efficiently process 8/16 columns in parallel */
opj_idwt53_v_cas0_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
return;
}
#endif
@ -916,11 +987,11 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
return;
}
#ifdef __SSE2__
#if (defined(__SSE2__) || defined(__AVX2__))
if (len > 2 && nb_cols == PARALLEL_COLS_53) {
/* Same as below general case, except that thanks to SSE2 */
/* we can efficently process 8 columns in parallel */
opj_idwt53_v_cas1_8cols_SSE2(dwt->mem, sn, len, tiledp_col, stride);
/* Same as below general case, except that thanks to SSE2/AVX2 */
/* we can efficiently process 8/16 columns in parallel */
opj_idwt53_v_cas1_mcols_SSE2_OR_AVX2(dwt->mem, sn, len, tiledp_col, stride);
return;
}
#endif
@ -1291,7 +1362,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
/* since for the vertical pass */
/* we process PARALLEL_COLS_53 columns at a time */
h_mem_size *= PARALLEL_COLS_53 * sizeof(OPJ_INT32);
h.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
h.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
if (! h.mem) {
/* FIXME event manager error callback */
return OPJ_FALSE;
@ -1348,7 +1419,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
if (j == (num_jobs - 1U)) { /* this will take care of the overflow */
job->max_j = rh;
}
job->h.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
job->h.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
if (!job->h.mem) {
/* FIXME event manager error callback */
opj_thread_pool_wait_completion(tp, 0);
@ -1403,7 +1474,7 @@ static OPJ_BOOL opj_dwt_decode_tile(opj_thread_pool_t* tp,
if (j == (num_jobs - 1U)) { /* this will take care of the overflow */
job->max_j = rw;
}
job->v.mem = (OPJ_INT32*)opj_aligned_malloc(h_mem_size);
job->v.mem = (OPJ_INT32*)opj_aligned_32_malloc(h_mem_size);
if (!job->v.mem) {
/* FIXME event manager error callback */
opj_thread_pool_wait_completion(tp, 0);

View File

@ -213,6 +213,15 @@ void * opj_aligned_realloc(void *ptr, size_t size)
return opj_aligned_realloc_n(ptr, 16U, size);
}
/* Allocates size bytes by forwarding to opj_aligned_alloc_n() with a */
/* requested alignment of 32 bytes. */
void *opj_aligned_32_malloc(size_t size)
{
return opj_aligned_alloc_n(32U, size);
}
/* Resizes the block ptr to size bytes by forwarding to */
/* opj_aligned_realloc_n() with a requested alignment of 32 bytes. */
void * opj_aligned_32_realloc(void *ptr, size_t size)
{
return opj_aligned_realloc_n(ptr, 32U, size);
}
void opj_aligned_free(void* ptr)
{
#if defined(OPJ_HAVE_POSIX_MEMALIGN) || defined(OPJ_HAVE_MEMALIGN)

View File

@ -71,6 +71,14 @@ void * opj_aligned_malloc(size_t size);
void * opj_aligned_realloc(void *ptr, size_t size);
void opj_aligned_free(void* ptr);
/**
Allocate memory aligned to a 32 byte boundary
@param size Bytes to allocate
@return Returns a void pointer to the allocated space, or NULL if there is insufficient memory available
*/
void * opj_aligned_32_malloc(size_t size);
/**
Reallocate a memory block, requesting a 32 byte alignment
@param ptr Pointer to the memory block to reallocate
@param size New size in bytes
@return NOTE(review): presumably the reallocated block, or NULL on failure, as for the other realloc helpers; confirm against opj_aligned_realloc_n
*/
void * opj_aligned_32_realloc(void *ptr, size_t size);
/**
Reallocate memory blocks.
@param m Pointer to previously allocated memory block