Forward DWT 9-7: major speed up by vectorizing vertical pass

`bench_dwt -I -encode` times goes from 8.6s to 2.1s
This commit is contained in:
Even Rouault 2020-05-22 23:57:51 +02:00
parent a38e970fa5
commit 1e931fdb36
No known key found for this signature in database
GPG Key ID: 33EBBFC47B3DD87D
1 changed files with 250 additions and 86 deletions

View File

@ -125,13 +125,6 @@ static void opj_dwt_deinterleave_h(const OPJ_INT32 * OPJ_RESTRICT a,
OPJ_INT32 * OPJ_RESTRICT b,
OPJ_INT32 dn,
OPJ_INT32 sn, OPJ_INT32 cas);
/**
Forward lazy transform (vertical)
*/
static void opj_dwt_deinterleave_v(const OPJ_INT32 * OPJ_RESTRICT a,
OPJ_INT32 * OPJ_RESTRICT b,
OPJ_INT32 dn,
OPJ_INT32 sn, OPJ_UINT32 x, OPJ_INT32 cas);
/**
Forward 9-7 wavelet transform in 1-D
@ -252,35 +245,6 @@ static void opj_dwt_deinterleave_h(const OPJ_INT32 * OPJ_RESTRICT a,
}
}
/* <summary> */
/* Forward lazy transform (vertical). */
/* </summary> */
static void opj_dwt_deinterleave_v(const OPJ_INT32 * OPJ_RESTRICT a,
OPJ_INT32 * OPJ_RESTRICT b,
OPJ_INT32 dn,
OPJ_INT32 sn, OPJ_UINT32 x, OPJ_INT32 cas)
{
OPJ_INT32 i = sn;
OPJ_INT32 * OPJ_RESTRICT l_dest = b;
const OPJ_INT32 * OPJ_RESTRICT l_src = a + cas;
while (i--) {
*l_dest = *l_src;
l_dest += x;
l_src += 2;
} /* b[i*x]=a[2*i+cas]; */
l_dest = b + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)x;
l_src = a + 1 - cas;
i = dn;
while (i--) {
*l_dest = *l_src;
l_dest += x;
l_src += 2;
} /*b[(sn+i)*x]=a[(2*i+1-cas)];*/
}
#ifdef STANDARD_SLOW_VERSION
/* <summary> */
/* Inverse lazy transform (horizontal). */
@ -989,36 +953,85 @@ static void opj_idwt53_v(const opj_dwt_t *dwt,
#endif
}
#if 0
static void opj_dwt_encode_step1(OPJ_FLOAT32* fw,
OPJ_UINT32 start,
OPJ_UINT32 end,
const OPJ_FLOAT32 c)
{
OPJ_UINT32 i;
for (i = start; i < end; ++i) {
fw[i * 2] *= c;
OPJ_UINT32 i = 0;
for (; i < end; ++i) {
fw[0] *= c;
fw += 2;
}
}
#else
static void opj_dwt_encode_step1_combined(OPJ_FLOAT32* fw,
OPJ_UINT32 iters_c1,
OPJ_UINT32 iters_c2,
const OPJ_FLOAT32 c1,
const OPJ_FLOAT32 c2)
{
OPJ_UINT32 i = 0;
const OPJ_UINT32 iters_common = opj_uint_min(iters_c1, iters_c2);
assert((((OPJ_SIZE_T)fw) & 0xf) == 0);
assert(opj_int_abs((OPJ_INT32)iters_c1 - (OPJ_INT32)iters_c2) <= 1);
for (; i + 3 < iters_common; i += 4) {
#ifdef __SSE__
const __m128 vcst = _mm_set_ps(c2, c1, c2, c1);
*(__m128*)fw = _mm_mul_ps(*(__m128*)fw, vcst);
*(__m128*)(fw + 4) = _mm_mul_ps(*(__m128*)(fw + 4), vcst);
#else
fw[0] *= c1;
fw[1] *= c2;
fw[2] *= c1;
fw[3] *= c2;
fw[4] *= c1;
fw[5] *= c2;
fw[6] *= c1;
fw[7] *= c2;
#endif
fw += 8;
}
for (; i < iters_common; i++) {
fw[0] *= c1;
fw[1] *= c2;
fw += 2;
}
if (i < iters_c1) {
fw[0] *= c1;
} else if (i < iters_c2) {
fw[1] *= c2;
}
}
#endif
static void opj_dwt_encode_step2(OPJ_FLOAT32* fl, OPJ_FLOAT32* fw,
OPJ_UINT32 start,
OPJ_UINT32 end,
OPJ_UINT32 m,
OPJ_FLOAT32 c)
{
OPJ_UINT32 i;
OPJ_UINT32 imax = opj_uint_min(end, m);
if (start > 0) {
fw += 2 * start;
fl = fw - 2;
}
for (i = start; i < imax; ++i) {
if (imax > 0) {
fw[-1] += (fl[0] + fw[0]) * c;
fl = fw;
fw += 2;
i = 1;
for (; i + 3 < imax; i += 4) {
fw[-1] += (fw[-2] + fw[0]) * c;
fw[1] += (fw[0] + fw[2]) * c;
fw[3] += (fw[2] + fw[4]) * c;
fw[5] += (fw[4] + fw[6]) * c;
fw += 8;
}
for (; i < imax; ++i) {
fw[-1] += (fw[-2] + fw[0]) * c;
fw += 2;
}
}
if (m < end) {
assert(m + 1 == end);
fw[-1] += (2 * fl[0]) * c;
fw[-1] += (2 * fw[-2]) * c;
}
}
@ -1027,39 +1040,50 @@ static void opj_dwt_encode_1_real(void *aIn, OPJ_INT32 dn, OPJ_INT32 sn,
{
OPJ_FLOAT32* w = (OPJ_FLOAT32*)aIn;
OPJ_INT32 a, b;
assert(dn + sn > 1);
if (cas == 0) {
if (!((dn > 0) || (sn > 1))) {
return;
}
a = 0;
b = 1;
} else {
if (!((sn > 0) || (dn > 1))) {
return;
}
a = 1;
b = 0;
}
opj_dwt_encode_step2(w + a, w + b + 1,
0, (OPJ_UINT32)dn,
(OPJ_UINT32)dn,
(OPJ_UINT32)opj_int_min(dn, sn - b),
opj_dwt_alpha);
opj_dwt_encode_step2(w + b, w + a + 1,
0, (OPJ_UINT32)sn,
(OPJ_UINT32)sn,
(OPJ_UINT32)opj_int_min(sn, dn - a),
opj_dwt_beta);
opj_dwt_encode_step2(w + a, w + b + 1,
0, (OPJ_UINT32)dn,
(OPJ_UINT32)dn,
(OPJ_UINT32)opj_int_min(dn, sn - b),
opj_dwt_gamma);
opj_dwt_encode_step2(w + b, w + a + 1,
0, (OPJ_UINT32)sn,
(OPJ_UINT32)sn,
(OPJ_UINT32)opj_int_min(sn, dn - a),
opj_dwt_delta);
opj_dwt_encode_step1(w + b, 0, (OPJ_UINT32)dn,
#if 0
opj_dwt_encode_step1(w + b, (OPJ_UINT32)dn,
opj_K);
opj_dwt_encode_step1(w + a, 0, (OPJ_UINT32)sn,
opj_dwt_encode_step1(w + a, (OPJ_UINT32)sn,
opj_invK);
#else
if (a == 0) {
opj_dwt_encode_step1_combined(w,
(OPJ_UINT32)sn,
(OPJ_UINT32)dn,
opj_invK,
opj_K);
} else {
opj_dwt_encode_step1_combined(w,
(OPJ_UINT32)dn,
(OPJ_UINT32)sn,
opj_K,
opj_invK);
}
#endif
}
static void opj_dwt_encode_stepsize(OPJ_INT32 stepsize, OPJ_INT32 numbps,
@ -1143,6 +1167,9 @@ void opj_dwt_encode_and_deinterleave_h_one_row_real(void* rowIn,
OPJ_FLOAT32* OPJ_RESTRICT tmp = (OPJ_FLOAT32*)tmpIn;
const OPJ_INT32 sn = (OPJ_INT32)((width + (even ? 1 : 0)) >> 1);
const OPJ_INT32 dn = (OPJ_INT32)(width - (OPJ_UINT32)sn);
if (width == 1) {
return;
}
memcpy(tmp, row, width * sizeof(OPJ_FLOAT32));
opj_dwt_encode_1_real(tmp, dn, sn, even ? 0 : 1);
opj_dwt_deinterleave_h((OPJ_INT32 * OPJ_RESTRICT)tmp,
@ -1258,29 +1285,49 @@ static INLINE void opj_dwt_deinterleave_v_cols(
OPJ_INT32 cas,
OPJ_UINT32 cols)
{
OPJ_INT32 k;
OPJ_INT32 i = sn;
OPJ_INT32 * OPJ_RESTRICT l_dest = dst;
const OPJ_INT32 * OPJ_RESTRICT l_src = src + cas * NB_ELTS_V8;
OPJ_UINT32 c;
while (i--) {
for (c = 0; c < cols; c++) {
l_dest[c] = l_src[c];
for (k = 0; k < 2; k++) {
while (i--) {
if (cols == NB_ELTS_V8) {
memcpy(l_dest, l_src, NB_ELTS_V8 * sizeof(OPJ_INT32));
} else {
c = 0;
switch (cols) {
case 7:
l_dest[c] = l_src[c];
c++; /* fallthru */
case 6:
l_dest[c] = l_src[c];
c++; /* fallthru */
case 5:
l_dest[c] = l_src[c];
c++; /* fallthru */
case 4:
l_dest[c] = l_src[c];
c++; /* fallthru */
case 3:
l_dest[c] = l_src[c];
c++; /* fallthru */
case 2:
l_dest[c] = l_src[c];
c++; /* fallthru */
default:
l_dest[c] = l_src[c];
break;
}
}
l_dest += stride_width;
l_src += 2 * NB_ELTS_V8;
}
l_dest += stride_width;
l_src += 2 * NB_ELTS_V8;
}
l_dest = dst + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)stride_width;
l_src = src + (1 - cas) * NB_ELTS_V8;
i = dn;
while (i--) {
for (c = 0; c < cols; c++) {
l_dest[c] = l_src[c];
}
l_dest += stride_width;
l_src += 2 * NB_ELTS_V8;
l_dest = dst + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)stride_width;
l_src = src + (1 - cas) * NB_ELTS_V8;
i = dn;
}
}
@ -1517,6 +1564,84 @@ static void opj_dwt_encode_and_deinterleave_v(
}
}
static void opj_v8dwt_encode_step1(OPJ_FLOAT32* fw,
OPJ_UINT32 end,
const OPJ_FLOAT32 cst)
{
OPJ_UINT32 i;
#ifdef __SSE__
__m128* vw = (__m128*) fw;
const __m128 vcst = _mm_set1_ps(cst);
for (i = 0; i < end; ++i) {
vw[0] = _mm_mul_ps(vw[0], vcst);
vw[1] = _mm_mul_ps(vw[1], vcst);
vw += 2 * (NB_ELTS_V8 * sizeof(OPJ_FLOAT32) / sizeof(__m128));
}
#else
OPJ_UINT32 c;
for (i = 0; i < end; ++i) {
for (c = 0; c < NB_ELTS_V8; c++) {
fw[i * 2 * NB_ELTS_V8 + c] *= cst;
}
}
#endif
}
static void opj_v8dwt_encode_step2(OPJ_FLOAT32* fl, OPJ_FLOAT32* fw,
OPJ_UINT32 end,
OPJ_UINT32 m,
OPJ_FLOAT32 cst)
{
OPJ_UINT32 i;
OPJ_UINT32 imax = opj_uint_min(end, m);
#ifdef __SSE__
__m128* vw = (__m128*) fw;
__m128 vcst = _mm_set1_ps(cst);
if (imax > 0) {
__m128* vl = (__m128*) fl;
vw[-2] = _mm_add_ps(vw[-2], _mm_mul_ps(_mm_add_ps(vl[0], vw[0]), vcst));
vw[-1] = _mm_add_ps(vw[-1], _mm_mul_ps(_mm_add_ps(vl[1], vw[1]), vcst));
vw += 2 * (NB_ELTS_V8 * sizeof(OPJ_FLOAT32) / sizeof(__m128));
i = 1;
for (; i < imax; ++i) {
vw[-2] = _mm_add_ps(vw[-2], _mm_mul_ps(_mm_add_ps(vw[-4], vw[0]), vcst));
vw[-1] = _mm_add_ps(vw[-1], _mm_mul_ps(_mm_add_ps(vw[-3], vw[1]), vcst));
vw += 2 * (NB_ELTS_V8 * sizeof(OPJ_FLOAT32) / sizeof(__m128));
}
}
if (m < end) {
assert(m + 1 == end);
vcst = _mm_add_ps(vcst, vcst);
vw[-2] = _mm_add_ps(vw[-2], _mm_mul_ps(vw[-4], vcst));
vw[-1] = _mm_add_ps(vw[-1], _mm_mul_ps(vw[-3], vcst));
}
#else
OPJ_INT32 c;
if (imax > 0) {
for (c = 0; c < NB_ELTS_V8; c++) {
fw[-1 * NB_ELTS_V8 + c] += (fl[0 * NB_ELTS_V8 + c] + fw[0 * NB_ELTS_V8 + c]) *
cst;
}
fw += 2 * NB_ELTS_V8;
i = 1;
for (; i < imax; ++i) {
for (c = 0; c < NB_ELTS_V8; c++) {
fw[-1 * NB_ELTS_V8 + c] += (fw[-2 * NB_ELTS_V8 + c] + fw[0 * NB_ELTS_V8 + c]) *
cst;
}
fw += 2 * NB_ELTS_V8;
}
}
if (m < end) {
assert(m + 1 == end);
for (c = 0; c < NB_ELTS_V8; c++) {
fw[-1 * NB_ELTS_V8 + c] += (2 * fw[-2 * NB_ELTS_V8 + c]) * cst;
}
}
#endif
}
/* Forward 9-7 transform, for the vertical pass, processing cols columns */
/* where cols <= NB_ELTS_V8 */
static void opj_dwt_encode_and_deinterleave_v_real(
@ -1529,20 +1654,59 @@ static void opj_dwt_encode_and_deinterleave_v_real(
{
OPJ_FLOAT32* OPJ_RESTRICT array = (OPJ_FLOAT32 * OPJ_RESTRICT)arrayIn;
OPJ_FLOAT32* OPJ_RESTRICT tmp = (OPJ_FLOAT32 * OPJ_RESTRICT)tmpIn;
OPJ_UINT32 c;
const OPJ_INT32 sn = (OPJ_INT32)((height + (even ? 1 : 0)) >> 1);
const OPJ_INT32 dn = (OPJ_INT32)(height - (OPJ_UINT32)sn);
for (c = 0; c < cols; c++) {
OPJ_UINT32 k;
for (k = 0; k < height; ++k) {
tmp[k] = array[c + k * stride_width];
}
OPJ_INT32 a, b;
opj_dwt_encode_1_real(tmp, dn, sn, even ? 0 : 1);
if (height == 1) {
return;
}
opj_dwt_deinterleave_v((OPJ_INT32*)tmpIn,
((OPJ_INT32*)(arrayIn)) + c,
dn, sn, stride_width, even ? 0 : 1);
opj_dwt_fetch_cols_vertical_pass(arrayIn, tmpIn, height, stride_width, cols);
if (even) {
a = 0;
b = 1;
} else {
a = 1;
b = 0;
}
opj_v8dwt_encode_step2(tmp + a * NB_ELTS_V8,
tmp + (b + 1) * NB_ELTS_V8,
(OPJ_UINT32)dn,
(OPJ_UINT32)opj_int_min(dn, sn - b),
opj_dwt_alpha);
opj_v8dwt_encode_step2(tmp + b * NB_ELTS_V8,
tmp + (a + 1) * NB_ELTS_V8,
(OPJ_UINT32)sn,
(OPJ_UINT32)opj_int_min(sn, dn - a),
opj_dwt_beta);
opj_v8dwt_encode_step2(tmp + a * NB_ELTS_V8,
tmp + (b + 1) * NB_ELTS_V8,
(OPJ_UINT32)dn,
(OPJ_UINT32)opj_int_min(dn, sn - b),
opj_dwt_gamma);
opj_v8dwt_encode_step2(tmp + b * NB_ELTS_V8,
tmp + (a + 1) * NB_ELTS_V8,
(OPJ_UINT32)sn,
(OPJ_UINT32)opj_int_min(sn, dn - a),
opj_dwt_delta);
opj_v8dwt_encode_step1(tmp + b * NB_ELTS_V8, (OPJ_UINT32)dn,
opj_K);
opj_v8dwt_encode_step1(tmp + a * NB_ELTS_V8, (OPJ_UINT32)sn,
opj_invK);
if (cols == NB_ELTS_V8) {
opj_dwt_deinterleave_v_cols((OPJ_INT32*)tmp,
(OPJ_INT32*)array,
(OPJ_INT32)dn, (OPJ_INT32)sn,
stride_width, even ? 0 : 1, NB_ELTS_V8);
} else {
opj_dwt_deinterleave_v_cols((OPJ_INT32*)tmp,
(OPJ_INT32*)array,
(OPJ_INT32)dn, (OPJ_INT32)sn,
stride_width, even ? 0 : 1, cols);
}
}