From f1f1bf450f8995249ee393605d977c781e1037f8 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Mon, 8 Apr 2024 17:23:31 +0100
Subject: [PATCH] Simplify Armv8.4 DotProd correction constant computation
Simplify the computation of the Armv8.4 DotProd convolution
correction constant. Summing 128 * filter_tap[0,7] is always the same
as 128 * 128 since the filter taps always sum to 128.
Change-Id: Ie0191b764809963c2be8f5032e6196725e20f0d9
---
aom_dsp/arm/aom_convolve8_neon_dotprod.c | 6 +-
.../arm/compound_convolve_neon_dotprod.c | 55 +++++++++----------
av1/common/arm/convolve_neon_dotprod.c | 49 +++++++----------
3 files changed, 47 insertions(+), 63 deletions(-)
diff --git a/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index c82125ba17..9fd94cd21d 100644
--- a/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -108,8 +108,7 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
const int16_t *filter_y, int y_step_q4,
int w, int h) {
const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x));
- const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_x), 128);
- const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+ const int32x4_t correction = vdupq_n_s32(128 << FILTER_BITS);
const uint8x16_t range_limit = vdupq_n_u8(128);
uint8x16_t s0, s1, s2, s3;
@@ -263,8 +262,7 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
const int16_t *filter_y, int y_step_q4,
int w, int h) {
const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
- const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_y), 128);
- const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+ const int32x4_t correction = vdupq_n_s32(128 << FILTER_BITS);
const uint8x8_t range_limit = vdup_n_u8(128);
const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl);
int8x16x2_t samples_LUT;
diff --git a/av1/common/arm/compound_convolve_neon_dotprod.c b/av1/common/arm/compound_convolve_neon_dotprod.c
index 3aeffbb0e6..40befdf44e 100644
--- a/av1/common/arm/compound_convolve_neon_dotprod.c
+++ b/av1/common/arm/compound_convolve_neon_dotprod.c
@@ -80,17 +80,15 @@ static INLINE void dist_wtd_convolve_2d_horiz_neon_dotprod(
const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
const int16_t *x_filter_ptr, const int im_h, int w) {
const int bd = 8;
- const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2));
// Dot product constants and other shims.
const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
- const int32_t correction_s32 =
- vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
- // Fold horiz_const into the dot-product filter correction constant. The
- // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
- // rounding shifts - which are generally faster than rounding shifts on
- // modern CPUs. (The extra -1 is needed because we halved the filter values.)
- const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const +
- (1 << ((ROUND0_BITS - 1) - 1)));
+ // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
+ // - which are generally faster than rounding shifts on modern CPUs.
+ const int32_t horiz_const =
+ ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
+ // Halve the total because we will halve the filter values.
+ const int32x4_t correction =
+ vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
const uint8x16_t range_limit = vdupq_n_u8(128);
const uint8_t *src_ptr = src;
@@ -334,15 +332,14 @@ static INLINE void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
// Dot-product constants and other shims.
const uint8x16_t range_limit = vdupq_n_u8(128);
- const int32_t correction_s32 =
- vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
// Fold round_offset into the dot-product filter correction constant. The
- // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
- // rounding shifts - which are generally faster than rounding shifts on
- // modern CPUs. (The extra -1 is needed because we halved the filter values.)
+ // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+ // shifts - which are generally faster than rounding shifts on modern CPUs.
+ // Halve the total because we will halve the filter values.
int32x4_t correction =
- vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
- (1 << ((ROUND0_BITS - 1) - 1)));
+ vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
+ (1 << (ROUND0_BITS - 1))) /
+ 2);
const int horiz_offset = filter_params_x->taps / 2 - 1;
const uint8_t *src_ptr = src - horiz_offset;
@@ -455,15 +452,14 @@ static INLINE void dist_wtd_convolve_x_avg_neon_dotprod(
// Dot-product constants and other shims.
const uint8x16_t range_limit = vdupq_n_u8(128);
- const int32_t correction_s32 =
- vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
// Fold round_offset into the dot-product filter correction constant. The
- // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
- // rounding shifts - which are generally faster than rounding shifts on
- // modern CPUs. (The extra -1 is needed because we halved the filter values.)
+ // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+ // shifts - which are generally faster than rounding shifts on modern CPUs.
+ // Halve the total because we will halve the filter values.
int32x4_t correction =
- vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
- (1 << ((ROUND0_BITS - 1) - 1)));
+ vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
+ (1 << (ROUND0_BITS - 1))) /
+ 2);
const int horiz_offset = filter_params_x->taps / 2 - 1;
const uint8_t *src_ptr = src - horiz_offset;
@@ -574,15 +570,14 @@ static INLINE void dist_wtd_convolve_x_neon_dotprod(
// Dot-product constants and other shims.
const uint8x16_t range_limit = vdupq_n_u8(128);
- const int32_t correction_s32 =
- vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
// Fold round_offset into the dot-product filter correction constant. The
- // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
- // rounding shifts - which are generally faster than rounding shifts on
- // modern CPUs. (The extra -1 is needed because we halved the filter values.)
+ // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+ // shifts - which are generally faster than rounding shifts on modern CPUs.
+ // Halve the total because we will halve the filter values.
int32x4_t correction =
- vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
- (1 << ((ROUND0_BITS - 1) - 1)));
+ vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
+ (1 << (ROUND0_BITS - 1))) /
+ 2);
const int horiz_offset = filter_params_x->taps / 2 - 1;
const uint8_t *src_ptr = src - horiz_offset;
diff --git a/av1/common/arm/convolve_neon_dotprod.c b/av1/common/arm/convolve_neon_dotprod.c
index c29229eb09..132da2442b 100644
--- a/av1/common/arm/convolve_neon_dotprod.c
+++ b/av1/common/arm/convolve_neon_dotprod.c
@@ -102,14 +102,12 @@ static INLINE void convolve_x_sr_12tap_neon_dotprod(
const int8x16_t filter =
vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15));
- const int32_t correction_s32 =
- vaddvq_s32(vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, FILTER_BITS)),
- vpaddlq_s16(vshlq_n_s16(filter_8_15, FILTER_BITS))));
- // A shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding right
- // shift by FILTER_BITS - instead of a first rounding right shift by
+ // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding
+ // right shift by FILTER_BITS - instead of a first rounding right shift by
// ROUND0_BITS, followed by second rounding right shift by FILTER_BITS -
// ROUND0_BITS.
- int32x4_t correction = vdupq_n_s32(correction_s32 + (1 << (ROUND0_BITS - 1)));
+ int32x4_t correction =
+ vdupq_n_s32((128 << FILTER_BITS) + (1 << (ROUND0_BITS - 1)));
const uint8x16_t range_limit = vdupq_n_u8(128);
const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
@@ -274,16 +272,13 @@ void av1_convolve_x_sr_neon_dotprod(const uint8_t *src, int src_stride,
}
const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
- // Dot product constants.
- const int32_t correction_s32 =
- vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
- // This shim of (1 << ((ROUND0_BITS - 1) - 1) enables us to use a single
- // rounding right shift by FILTER_BITS - instead of a first rounding right
- // shift by ROUND0_BITS, followed by second rounding right shift by
- // FILTER_BITS - ROUND0_BITS.
- // The outermost -1 is needed because we will halve the filter values.
+ // Dot product constants:
+ // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding
+ // right shift by FILTER_BITS - instead of a first rounding right shift by
+ // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS -
+ // ROUND0_BITS. Halve the total because we will halve the filter values.
const int32x4_t correction =
- vdupq_n_s32(correction_s32 + (1 << ((ROUND0_BITS - 1) - 1)));
+ vdupq_n_s32(((128 << FILTER_BITS) + (1 << ((ROUND0_BITS - 1)))) / 2);
const uint8x16_t range_limit = vdupq_n_u8(128);
if (w <= 4) {
@@ -465,16 +460,13 @@ static INLINE void convolve_2d_sr_horiz_12tap_neon_dotprod(
const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
vmovn_s16(x_filter_s16.val[1]));
- // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
- // - which are generally faster than rounding shifts on modern CPUs.
+ // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+ // shifts - which are generally faster than rounding shifts on modern CPUs.
const int32_t horiz_const =
((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
// Dot product constants.
- const int32x4_t correct_tmp =
- vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)),
- vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7)));
const int32x4_t correction =
- vdupq_n_s32(vaddvq_s32(correct_tmp) + horiz_const);
+ vdupq_n_s32((128 << FILTER_BITS) + horiz_const);
const uint8x16_t range_limit = vdupq_n_u8(128);
const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
@@ -621,16 +613,15 @@ static INLINE void convolve_2d_sr_horiz_neon_dotprod(
const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
int im_h, const int16_t *x_filter_ptr) {
const int bd = 8;
- // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
- // shifts - which are generally faster than rounding shifts on modern CPUs.
- // The outermost -1 is needed because we halved the filter values.
- const int32_t horiz_const =
- ((1 << (bd + FILTER_BITS - 2)) + (1 << ((ROUND0_BITS - 1) - 1)));
// Dot product constants.
const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
- const int32_t correction_s32 =
- vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
- const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const);
+ // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+ // shifts - which are generally faster than rounding shifts on modern CPUs.
+ const int32_t horiz_const =
+ ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
+ // Halve the total because we will halve the filter values.
+ const int32x4_t correction =
+ vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
const uint8x16_t range_limit = vdupq_n_u8(128);
const uint8_t *src_ptr = src;