aom: Simplify Armv8.4 DotProd correction constant computation

From f1f1bf450f8995249ee393605d977c781e1037f8 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Mon, 8 Apr 2024 17:23:31 +0100
Subject: [PATCH] Simplify Armv8.4 DotProd correction constant computation

Simplify the computation of the Armv8.4 DotProd convolution
correction constant. Summing 128 * filter_tap[0,7] is always the same
as 128 * 128 since the filter taps always sum to 128.

Change-Id: Ie0191b764809963c2be8f5032e6196725e20f0d9
---
 aom_dsp/arm/aom_convolve8_neon_dotprod.c      |  6 +-
 .../arm/compound_convolve_neon_dotprod.c      | 55 +++++++++----------
 av1/common/arm/convolve_neon_dotprod.c        | 49 +++++++----------
 3 files changed, 47 insertions(+), 63 deletions(-)

diff --git a/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index c82125ba17..9fd94cd21d 100644
--- a/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -108,8 +108,7 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
                                       const int16_t *filter_y, int y_step_q4,
                                       int w, int h) {
   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x));
-  const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_x), 128);
-  const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+  const int32x4_t correction = vdupq_n_s32(128 << FILTER_BITS);
   const uint8x16_t range_limit = vdupq_n_u8(128);
   uint8x16_t s0, s1, s2, s3;
 
@@ -263,8 +262,7 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
                                      const int16_t *filter_y, int y_step_q4,
                                      int w, int h) {
   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
-  const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter_y), 128);
-  const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
+  const int32x4_t correction = vdupq_n_s32(128 << FILTER_BITS);
   const uint8x8_t range_limit = vdup_n_u8(128);
   const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl);
   int8x16x2_t samples_LUT;
diff --git a/av1/common/arm/compound_convolve_neon_dotprod.c b/av1/common/arm/compound_convolve_neon_dotprod.c
index 3aeffbb0e6..40befdf44e 100644
--- a/av1/common/arm/compound_convolve_neon_dotprod.c
+++ b/av1/common/arm/compound_convolve_neon_dotprod.c
@@ -80,17 +80,15 @@ static INLINE void dist_wtd_convolve_2d_horiz_neon_dotprod(
     const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
     const int16_t *x_filter_ptr, const int im_h, int w) {
   const int bd = 8;
-  const int32_t horiz_const = (1 << (bd + FILTER_BITS - 2));
   // Dot product constants and other shims.
   const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
-  const int32_t correction_s32 =
-      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
-  // Fold horiz_const into the dot-product filter correction constant. The
-  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
-  // rounding shifts - which are generally faster than rounding shifts on
-  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
-  const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const +
-                                           (1 << ((ROUND0_BITS - 1) - 1)));
+  // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
+  // - which are generally faster than rounding shifts on modern CPUs.
+  const int32_t horiz_const =
+      ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
+  // Halve the total because we will halve the filter values.
+  const int32x4_t correction =
+      vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
   const uint8x16_t range_limit = vdupq_n_u8(128);
 
   const uint8_t *src_ptr = src;
@@ -334,15 +332,14 @@ static INLINE void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
 
   // Dot-product constants and other shims.
   const uint8x16_t range_limit = vdupq_n_u8(128);
-  const int32_t correction_s32 =
-      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
   // Fold round_offset into the dot-product filter correction constant. The
-  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
-  // rounding shifts - which are generally faster than rounding shifts on
-  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
+  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // Halve the total because we will halve the filter values.
   int32x4_t correction =
-      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
-                  (1 << ((ROUND0_BITS - 1) - 1)));
+      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
+                   (1 << (ROUND0_BITS - 1))) /
+                  2);
 
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   const uint8_t *src_ptr = src - horiz_offset;
@@ -455,15 +452,14 @@ static INLINE void dist_wtd_convolve_x_avg_neon_dotprod(
 
   // Dot-product constants and other shims.
   const uint8x16_t range_limit = vdupq_n_u8(128);
-  const int32_t correction_s32 =
-      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
   // Fold round_offset into the dot-product filter correction constant. The
-  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
-  // rounding shifts - which are generally faster than rounding shifts on
-  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
+  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // Halve the total because we will halve the filter values.
   int32x4_t correction =
-      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
-                  (1 << ((ROUND0_BITS - 1) - 1)));
+      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
+                   (1 << (ROUND0_BITS - 1))) /
+                  2);
 
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   const uint8_t *src_ptr = src - horiz_offset;
@@ -574,15 +570,14 @@ static INLINE void dist_wtd_convolve_x_neon_dotprod(
 
   // Dot-product constants and other shims.
   const uint8x16_t range_limit = vdupq_n_u8(128);
-  const int32_t correction_s32 =
-      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
   // Fold round_offset into the dot-product filter correction constant. The
-  // additional shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-
-  // rounding shifts - which are generally faster than rounding shifts on
-  // modern CPUs. (The extra -1 is needed because we halved the filter values.)
+  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // Halve the total because we will halve the filter values.
   int32x4_t correction =
-      vdupq_n_s32(correction_s32 + (round_offset << (ROUND0_BITS - 1)) +
-                  (1 << ((ROUND0_BITS - 1) - 1)));
+      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
+                   (1 << (ROUND0_BITS - 1))) /
+                  2);
 
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   const uint8_t *src_ptr = src - horiz_offset;
diff --git a/av1/common/arm/convolve_neon_dotprod.c b/av1/common/arm/convolve_neon_dotprod.c
index c29229eb09..132da2442b 100644
--- a/av1/common/arm/convolve_neon_dotprod.c
+++ b/av1/common/arm/convolve_neon_dotprod.c
@@ -102,14 +102,12 @@ static INLINE void convolve_x_sr_12tap_neon_dotprod(
   const int8x16_t filter =
       vcombine_s8(vmovn_s16(filter_0_7), vmovn_s16(filter_8_15));
 
-  const int32_t correction_s32 =
-      vaddvq_s32(vaddq_s32(vpaddlq_s16(vshlq_n_s16(filter_0_7, FILTER_BITS)),
-                           vpaddlq_s16(vshlq_n_s16(filter_8_15, FILTER_BITS))));
-  // A shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding right
-  // shift by FILTER_BITS - instead of a first rounding right shift by
+  // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding
+  // right shift by FILTER_BITS - instead of a first rounding right shift by
   // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS -
   // ROUND0_BITS.
-  int32x4_t correction = vdupq_n_s32(correction_s32 + (1 << (ROUND0_BITS - 1)));
+  int32x4_t correction =
+      vdupq_n_s32((128 << FILTER_BITS) + (1 << (ROUND0_BITS - 1)));
   const uint8x16_t range_limit = vdupq_n_u8(128);
   const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
 
@@ -274,16 +272,13 @@ void av1_convolve_x_sr_neon_dotprod(const uint8_t *src, int src_stride,
   }
 
   const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
-  // Dot product constants.
-  const int32_t correction_s32 =
-      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
-  // This shim of (1 << ((ROUND0_BITS - 1) - 1) enables us to use a single
-  // rounding right shift by FILTER_BITS - instead of a first rounding right
-  // shift by ROUND0_BITS, followed by second rounding right shift by
-  // FILTER_BITS - ROUND0_BITS.
-  // The outermost -1 is needed because we will halve the filter values.
+  // Dot product constants:
+  // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use a single rounding
+  // right shift by FILTER_BITS - instead of a first rounding right shift by
+  // ROUND0_BITS, followed by second rounding right shift by FILTER_BITS -
+  // ROUND0_BITS. Halve the total because we will halve the filter values.
   const int32x4_t correction =
-      vdupq_n_s32(correction_s32 + (1 << ((ROUND0_BITS - 1) - 1)));
+      vdupq_n_s32(((128 << FILTER_BITS) + (1 << ((ROUND0_BITS - 1)))) / 2);
   const uint8x16_t range_limit = vdupq_n_u8(128);
 
   if (w <= 4) {
@@ -465,16 +460,13 @@ static INLINE void convolve_2d_sr_horiz_12tap_neon_dotprod(
     const int8x16_t x_filter = vcombine_s8(vmovn_s16(x_filter_s16.val[0]),
                                            vmovn_s16(x_filter_s16.val[1]));
 
-    // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
-    // - which are generally faster than rounding shifts on modern CPUs.
+    // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+    // shifts - which are generally faster than rounding shifts on modern CPUs.
     const int32_t horiz_const =
         ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
     // Dot product constants.
-    const int32x4_t correct_tmp =
-        vaddq_s32(vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[0], 7)),
-                  vpaddlq_s16(vshlq_n_s16(x_filter_s16.val[1], 7)));
     const int32x4_t correction =
-        vdupq_n_s32(vaddvq_s32(correct_tmp) + horiz_const);
+        vdupq_n_s32((128 << FILTER_BITS) + horiz_const);
     const uint8x16_t range_limit = vdupq_n_u8(128);
     const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
 
@@ -621,16 +613,15 @@ static INLINE void convolve_2d_sr_horiz_neon_dotprod(
     const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
     int im_h, const int16_t *x_filter_ptr) {
   const int bd = 8;
-  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
-  // shifts - which are generally faster than rounding shifts on modern CPUs.
-  // The outermost -1 is needed because we halved the filter values.
-  const int32_t horiz_const =
-      ((1 << (bd + FILTER_BITS - 2)) + (1 << ((ROUND0_BITS - 1) - 1)));
   // Dot product constants.
   const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
-  const int32_t correction_s32 =
-      vaddlvq_s16(vshlq_n_s16(x_filter_s16, FILTER_BITS - 1));
-  const int32x4_t correction = vdupq_n_s32(correction_s32 + horiz_const);
+  // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  const int32_t horiz_const =
+      ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
+  // Halve the total because we will halve the filter values.
+  const int32x4_t correction =
+      vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
   const uint8x16_t range_limit = vdupq_n_u8(128);
 
   const uint8_t *src_ptr = src;