Forward DWT 5-3: major speed up by vectorizing vertical pass

`bench_dwt -encode` time goes from 7.9s to 1.7s
This commit is contained in:
Even Rouault 2020-05-22 17:50:15 +02:00
parent e69fa09f60
commit a38e970fa5
No known key found for this signature in database
GPG Key ID: 33EBBFC47B3DD87D
1 changed files with 286 additions and 61 deletions

View File

@ -132,11 +132,7 @@ static void opj_dwt_deinterleave_v(const OPJ_INT32 * OPJ_RESTRICT a,
OPJ_INT32 * OPJ_RESTRICT b,
OPJ_INT32 dn,
OPJ_INT32 sn, OPJ_UINT32 x, OPJ_INT32 cas);
/**
Forward 5-3 wavelet transform in 1-D
@param a   Buffer transformed in place; the definition accesses it as an array of OPJ_INT32
@param dn  Sample count of one of the two sub-bands; sn + dn is the signal length
@param sn  Sample count of the other sub-band
@param cas Parity selector: zero and non-zero select the two lifting variants
*/
static void opj_dwt_encode_1(void *a, OPJ_INT32 dn, OPJ_INT32 sn,
OPJ_INT32 cas);
/**
Forward 9-7 wavelet transform in 1-D
*/
@ -332,52 +328,6 @@ static void opj_dwt_interleave_v(const opj_dwt_t* v, OPJ_INT32 *a, OPJ_INT32 x)
#endif /* STANDARD_SLOW_VERSION */
/* <summary> */
/* Forward 5-3 wavelet transform in 1-D. */
/* </summary> */
static void opj_dwt_encode_1(void *aIn, OPJ_INT32 dn, OPJ_INT32 sn,
                             OPJ_INT32 cas)
{
    /*
     * In-place forward 5-3 lifting on a buffer of sn + dn samples.
     *
     * aIn : buffer of OPJ_INT32 samples, modified in place
     * dn  : count used to bound the loops over OPJ_D() (cas == 0) or
     *       over OPJ_D() updates (cas != 0)
     * sn  : count used to bound the loops over OPJ_S()
     * cas : 0 selects the first variant below, non-zero the second
     *
     * The OPJ_S()/OPJ_D() accessors are defined elsewhere in the file; the
     * local pointer keeps the name `a` because they appear to index it.
     */
    OPJ_INT32* a = (OPJ_INT32*)aIn;
    const OPJ_INT32 len = sn + dn;
    OPJ_INT32 i;

    if (!cas) {
        /* Zero or one sample: nothing to transform in this variant. */
        if (len <= 1) {
            return;
        }

        /* First lifting step: subtract the neighbour average from OPJ_D. */
        i = 0;
        while (i < sn - 1) {
            OPJ_D(i) -= (OPJ_S(i) + OPJ_S(i + 1)) >> 1;
            ++i;
        }
        /* Even length: the last OPJ_D has a single neighbour, used twice. */
        if ((len % 2) == 0) {
            OPJ_D(i) -= OPJ_S(i);
        }

        /* Second lifting step: add the rounded quarter-sum to OPJ_S. */
        OPJ_S(0) += (OPJ_D(0) + OPJ_D(0) + 2) >> 2;
        i = 1;
        while (i < dn) {
            OPJ_S(i) += (OPJ_D(i - 1) + OPJ_D(i) + 2) >> 2;
            ++i;
        }
        /* Odd length: the last OPJ_S only has a left neighbour. */
        if ((len % 2) == 1) {
            OPJ_S(i) += (OPJ_D(i - 1) + OPJ_D(i - 1) + 2) >> 2;
        }
        return;
    }

    /* cas != 0 from here on. */
    if (len == 1) {
        /* Single sample: it is simply doubled. */
        a[0] *= 2;
        return;
    }

    /* First lifting step, applied to OPJ_S in this variant. */
    OPJ_S(0) -= OPJ_D(0);
    i = 1;
    while (i < sn) {
        OPJ_S(i) -= (OPJ_D(i) + OPJ_D(i - 1)) >> 1;
        ++i;
    }
    if ((len % 2) == 1) {
        OPJ_S(i) -= OPJ_D(i - 1);
    }

    /* Second lifting step, applied to OPJ_D in this variant. */
    i = 0;
    while (i < dn - 1) {
        OPJ_D(i) += (OPJ_S(i) + OPJ_S(i + 1) + 2) >> 2;
        ++i;
    }
    if ((len % 2) == 0) {
        OPJ_D(i) += (OPJ_S(i) + OPJ_S(i) + 2) >> 2;
    }
}
#ifdef STANDARD_SLOW_VERSION
/* <summary> */
/* Inverse 5-3 wavelet transform in 1-D. */
@ -1265,6 +1215,76 @@ static void opj_dwt_encode_v_func(void* user_data, opj_tls_t* tls)
opj_free(job);
}
/**
 * Fetch up to cols <= NB_ELTS_V8 columns of each line and store them in
 * tmpOut, which uses an interleave factor of NB_ELTS_V8.
 *
 * Line k of the input starts at arrayIn + k * stride_width (in OPJ_INT32
 * units); its first cols values are written to tmpOut + k * NB_ELTS_V8.
 * When cols < NB_ELTS_V8 the remaining slots of each output line are set
 * to 0 so that every output line is fully initialized.
 *
 * @param arrayIn      source samples, read as OPJ_INT32
 * @param tmpOut       destination buffer, written as OPJ_INT32; must hold
 *                     height * NB_ELTS_V8 values
 * @param height       number of lines to fetch
 * @param stride_width distance, in samples, between two consecutive lines
 *                     of arrayIn
 * @param cols         number of valid columns to copy per line
 */
static void opj_dwt_fetch_cols_vertical_pass(const void *arrayIn,
        void *tmpOut,
        OPJ_UINT32 height,
        OPJ_UINT32 stride_width,
        OPJ_UINT32 cols)
{
    const OPJ_INT32* OPJ_RESTRICT array = (const OPJ_INT32 * OPJ_RESTRICT)arrayIn;
    OPJ_INT32* OPJ_RESTRICT tmp = (OPJ_INT32 * OPJ_RESTRICT)tmpOut;
    OPJ_UINT32 k;

    /* Line offsets are computed in OPJ_SIZE_T: the product of two 32-bit */
    /* unsigned values would otherwise wrap for large height * stride.    */
    if (cols == NB_ELTS_V8) {
        /* Full group of columns: constant-size copy per line. */
        for (k = 0; k < height; ++k) {
            memcpy(tmp + (OPJ_SIZE_T)NB_ELTS_V8 * k,
                   array + (OPJ_SIZE_T)k * stride_width,
                   NB_ELTS_V8 * sizeof(OPJ_INT32));
        }
    } else {
        /* Partial group: copy the valid columns, zero-fill the rest. */
        for (k = 0; k < height; ++k) {
            const OPJ_INT32* src_row = array + (OPJ_SIZE_T)k * stride_width;
            OPJ_INT32* dst_row = tmp + (OPJ_SIZE_T)NB_ELTS_V8 * k;
            OPJ_UINT32 c;
            for (c = 0; c < cols; c++) {
                dst_row[c] = src_row[c];
            }
            for (; c < NB_ELTS_V8; c++) {
                dst_row[c] = 0;
            }
        }
    }
}
/* Deinterleave the result of the forward transform for cols <= NB_ELTS_V8 */
/* columns. src holds lines of NB_ELTS_V8 consecutive values. Lines        */
/* cas, cas + 2, ... of src go to the first sn lines of dst, and lines     */
/* 1 - cas, 3 - cas, ... go to the dn lines of dst that start at line sn.  */
/* Consecutive dst lines are stride_width values apart; only the first     */
/* cols values of each line are copied.                                    */
static INLINE void opj_dwt_deinterleave_v_cols(
    const OPJ_INT32 * OPJ_RESTRICT src,
    OPJ_INT32 * OPJ_RESTRICT dst,
    OPJ_INT32 dn,
    OPJ_INT32 sn,
    OPJ_UINT32 stride_width,
    OPJ_INT32 cas,
    OPJ_UINT32 cols)
{
    const OPJ_INT32 * OPJ_RESTRICT in_row = src + cas * NB_ELTS_V8;
    OPJ_INT32 * OPJ_RESTRICT out_row = dst;
    OPJ_INT32 rows_left;
    OPJ_UINT32 col;

    /* First band: sn lines, written from the top of dst. */
    for (rows_left = sn; rows_left != 0; --rows_left) {
        for (col = 0; col < cols; ++col) {
            out_row[col] = in_row[col];
        }
        out_row += stride_width;
        in_row += 2 * NB_ELTS_V8;   /* skip the line of the other band */
    }

    /* Second band: dn lines, written right after the sn lines above. */
    in_row = src + (1 - cas) * NB_ELTS_V8;
    out_row = dst + (OPJ_SIZE_T)sn * (OPJ_SIZE_T)stride_width;
    for (rows_left = dn; rows_left != 0; --rows_left) {
        for (col = 0; col < cols; ++col) {
            out_row[col] = in_row[col];
        }
        out_row += stride_width;
        in_row += 2 * NB_ELTS_V8;
    }
}
/* Forward 5-3 transform, for the vertical pass, processing cols columns */
/* where cols <= NB_ELTS_V8 */
static void opj_dwt_encode_and_deinterleave_v(
@ -1277,18 +1297,223 @@ static void opj_dwt_encode_and_deinterleave_v(
{
OPJ_INT32* OPJ_RESTRICT array = (OPJ_INT32 * OPJ_RESTRICT)arrayIn;
OPJ_INT32* OPJ_RESTRICT tmp = (OPJ_INT32 * OPJ_RESTRICT)tmpIn;
OPJ_UINT32 c;
const OPJ_INT32 sn = (OPJ_INT32)((height + (even ? 1 : 0)) >> 1);
const OPJ_INT32 dn = (OPJ_INT32)(height - (OPJ_UINT32)sn);
for (c = 0; c < cols; c++) {
OPJ_UINT32 k;
for (k = 0; k < height; ++k) {
tmp[k] = array[c + k * stride_width];
const OPJ_UINT32 sn = (height + (even ? 1 : 0)) >> 1;
const OPJ_UINT32 dn = height - sn;
opj_dwt_fetch_cols_vertical_pass(arrayIn, tmpIn, height, stride_width, cols);
#define OPJ_Sc(i) tmp[(i)*2* NB_ELTS_V8 + c]
#define OPJ_Dc(i) tmp[((1+(i)*2))* NB_ELTS_V8 + c]
#ifdef __SSE2__
if (height == 1) {
if (!even) {
OPJ_UINT32 c;
for (c = 0; c < NB_ELTS_V8; c++) {
tmp[c] *= 2;
}
}
} else if (even) {
OPJ_UINT32 c;
OPJ_UINT32 i;
i = 0;
if (i + 1 < sn) {
__m128i xmm_Si_0 = *(const __m128i*)(tmp + 4 * 0);
__m128i xmm_Si_1 = *(const __m128i*)(tmp + 4 * 1);
for (; i + 1 < sn; i++) {
__m128i xmm_Sip1_0 = *(const __m128i*)(tmp +
(i + 1) * 2 * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Sip1_1 = *(const __m128i*)(tmp +
(i + 1) * 2 * NB_ELTS_V8 + 4 * 1);
__m128i xmm_Di_0 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Di_1 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 1);
xmm_Di_0 = _mm_sub_epi32(xmm_Di_0,
_mm_srai_epi32(_mm_add_epi32(xmm_Si_0, xmm_Sip1_0), 1));
xmm_Di_1 = _mm_sub_epi32(xmm_Di_1,
_mm_srai_epi32(_mm_add_epi32(xmm_Si_1, xmm_Sip1_1), 1));
*(__m128i*)(tmp + (1 + i * 2) * NB_ELTS_V8 + 4 * 0) = xmm_Di_0;
*(__m128i*)(tmp + (1 + i * 2) * NB_ELTS_V8 + 4 * 1) = xmm_Di_1;
xmm_Si_0 = xmm_Sip1_0;
xmm_Si_1 = xmm_Sip1_1;
}
}
if (((height) % 2) == 0) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Dc(i) -= OPJ_Sc(i);
}
}
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(0) += (OPJ_Dc(0) + OPJ_Dc(0) + 2) >> 2;
}
i = 1;
if (i < dn) {
__m128i xmm_Dim1_0 = *(const __m128i*)(tmp + (1 +
(i - 1) * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Dim1_1 = *(const __m128i*)(tmp + (1 +
(i - 1) * 2) * NB_ELTS_V8 + 4 * 1);
const __m128i xmm_two = _mm_set1_epi32(2);
for (; i < dn; i++) {
__m128i xmm_Di_0 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Di_1 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 1);
__m128i xmm_Si_0 = *(const __m128i*)(tmp +
(i * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Si_1 = *(const __m128i*)(tmp +
(i * 2) * NB_ELTS_V8 + 4 * 1);
xmm_Si_0 = _mm_add_epi32(xmm_Si_0,
_mm_srai_epi32(_mm_add_epi32(_mm_add_epi32(xmm_Dim1_0, xmm_Di_0), xmm_two), 2));
xmm_Si_1 = _mm_add_epi32(xmm_Si_1,
_mm_srai_epi32(_mm_add_epi32(_mm_add_epi32(xmm_Dim1_1, xmm_Di_1), xmm_two), 2));
*(__m128i*)(tmp + (i * 2) * NB_ELTS_V8 + 4 * 0) = xmm_Si_0;
*(__m128i*)(tmp + (i * 2) * NB_ELTS_V8 + 4 * 1) = xmm_Si_1;
xmm_Dim1_0 = xmm_Di_0;
xmm_Dim1_1 = xmm_Di_1;
}
}
if (((height) % 2) == 1) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(i) += (OPJ_Dc(i - 1) + OPJ_Dc(i - 1) + 2) >> 2;
}
}
} else {
OPJ_UINT32 c;
OPJ_UINT32 i;
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(0) -= OPJ_Dc(0);
}
i = 1;
if (i < sn) {
__m128i xmm_Dim1_0 = *(const __m128i*)(tmp + (1 +
(i - 1) * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Dim1_1 = *(const __m128i*)(tmp + (1 +
(i - 1) * 2) * NB_ELTS_V8 + 4 * 1);
for (; i < sn; i++) {
__m128i xmm_Di_0 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Di_1 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 1);
__m128i xmm_Si_0 = *(const __m128i*)(tmp +
(i * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Si_1 = *(const __m128i*)(tmp +
(i * 2) * NB_ELTS_V8 + 4 * 1);
xmm_Si_0 = _mm_sub_epi32(xmm_Si_0,
_mm_srai_epi32(_mm_add_epi32(xmm_Di_0, xmm_Dim1_0), 1));
xmm_Si_1 = _mm_sub_epi32(xmm_Si_1,
_mm_srai_epi32(_mm_add_epi32(xmm_Di_1, xmm_Dim1_1), 1));
*(__m128i*)(tmp + (i * 2) * NB_ELTS_V8 + 4 * 0) = xmm_Si_0;
*(__m128i*)(tmp + (i * 2) * NB_ELTS_V8 + 4 * 1) = xmm_Si_1;
xmm_Dim1_0 = xmm_Di_0;
xmm_Dim1_1 = xmm_Di_1;
}
}
if (((height) % 2) == 1) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(i) -= OPJ_Dc(i - 1);
}
}
i = 0;
if (i + 1 < dn) {
__m128i xmm_Si_0 = *((const __m128i*)(tmp + 4 * 0));
__m128i xmm_Si_1 = *((const __m128i*)(tmp + 4 * 1));
const __m128i xmm_two = _mm_set1_epi32(2);
for (; i + 1 < dn; i++) {
__m128i xmm_Sip1_0 = *(const __m128i*)(tmp +
(i + 1) * 2 * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Sip1_1 = *(const __m128i*)(tmp +
(i + 1) * 2 * NB_ELTS_V8 + 4 * 1);
__m128i xmm_Di_0 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 0);
__m128i xmm_Di_1 = *(const __m128i*)(tmp +
(1 + i * 2) * NB_ELTS_V8 + 4 * 1);
xmm_Di_0 = _mm_add_epi32(xmm_Di_0,
_mm_srai_epi32(_mm_add_epi32(_mm_add_epi32(xmm_Si_0, xmm_Sip1_0), xmm_two), 2));
xmm_Di_1 = _mm_add_epi32(xmm_Di_1,
_mm_srai_epi32(_mm_add_epi32(_mm_add_epi32(xmm_Si_1, xmm_Sip1_1), xmm_two), 2));
*(__m128i*)(tmp + (1 + i * 2) * NB_ELTS_V8 + 4 * 0) = xmm_Di_0;
*(__m128i*)(tmp + (1 + i * 2) * NB_ELTS_V8 + 4 * 1) = xmm_Di_1;
xmm_Si_0 = xmm_Sip1_0;
xmm_Si_1 = xmm_Sip1_1;
}
}
if (((height) % 2) == 0) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Dc(i) += (OPJ_Sc(i) + OPJ_Sc(i) + 2) >> 2;
}
}
}
#else
if (even) {
OPJ_UINT32 c;
if (height > 1) {
OPJ_UINT32 i;
for (i = 0; i + 1 < sn; i++) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Dc(i) -= (OPJ_Sc(i) + OPJ_Sc(i + 1)) >> 1;
}
}
if (((height) % 2) == 0) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Dc(i) -= OPJ_Sc(i);
}
}
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(0) += (OPJ_Dc(0) + OPJ_Dc(0) + 2) >> 2;
}
for (i = 1; i < dn; i++) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(i) += (OPJ_Dc(i - 1) + OPJ_Dc(i) + 2) >> 2;
}
}
if (((height) % 2) == 1) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(i) += (OPJ_Dc(i - 1) + OPJ_Dc(i - 1) + 2) >> 2;
}
}
}
} else {
OPJ_UINT32 c;
if (height == 1) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(0) *= 2;
}
} else {
OPJ_UINT32 i;
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(0) -= OPJ_Dc(0);
}
for (i = 1; i < sn; i++) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(i) -= (OPJ_Dc(i) + OPJ_Dc(i - 1)) >> 1;
}
}
if (((height) % 2) == 1) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Sc(i) -= OPJ_Dc(i - 1);
}
}
for (i = 0; i + 1 < dn; i++) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Dc(i) += (OPJ_Sc(i) + OPJ_Sc(i + 1) + 2) >> 2;
}
}
if (((height) % 2) == 0) {
for (c = 0; c < NB_ELTS_V8; c++) {
OPJ_Dc(i) += (OPJ_Sc(i) + OPJ_Sc(i) + 2) >> 2;
}
}
}
}
#endif
opj_dwt_encode_1(tmp, dn, sn, even ? 0 : 1);
opj_dwt_deinterleave_v(tmp, array + c, dn, sn, stride_width, even ? 0 : 1);
if (cols == NB_ELTS_V8) {
opj_dwt_deinterleave_v_cols(tmp, array, (OPJ_INT32)dn, (OPJ_INT32)sn,
stride_width, even ? 0 : 1, NB_ELTS_V8);
} else {
opj_dwt_deinterleave_v_cols(tmp, array, (OPJ_INT32)dn, (OPJ_INT32)sn,
stride_width, even ? 0 : 1, cols);
}
}