aom: Specialize 8-tap HBD dist_wtd_convolve_x_sve2 on bitdepth

From 416c1a9ff3198f9f4833ba6809ca65eba649273f Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 7 Mar 2024 12:14:27 +0000
Subject: [PATCH] Specialize 8-tap HBD dist_wtd_convolve_x_sve2 on bitdepth

The rounding value at the end of the convolution depends on the bitdepth
(8/10 or 12). Add two specialized versions of the function, so that we
know the rounding value at compile time and therefore use only one
instruction instead of two to perform the final rounding and narrowing
step. This gives up to 10% uplift over the non-specialized version.

Change-Id: Iba3ef98eb82e44d0f67860721a957fb73589f71a
---
 .../arm/highbd_compound_convolve_sve2.c       | 131 +++++++++++++-----
 1 file changed, 100 insertions(+), 31 deletions(-)

diff --git a/av1/common/arm/highbd_compound_convolve_sve2.c b/av1/common/arm/highbd_compound_convolve_sve2.c
index b36e01f2f..8977538c4 100644
--- a/av1/common/arm/highbd_compound_convolve_sve2.c
+++ b/av1/common/arm/highbd_compound_convolve_sve2.c
@@ -31,8 +31,9 @@ DECLARE_ALIGNED(16, static const uint16_t, kDotProdTbl[32]) = {
   4, 5, 6, 7, 5, 6, 7, 0, 6, 7, 0, 1, 7, 0, 1, 2,
 };
 
-static INLINE uint16x8_t convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
-                                       int64x2_t offset, int32x4_t shift) {
+static INLINE uint16x8_t highbd_12_convolve8_8_x(int16x8_t s0[8],
+                                                 int16x8_t filter,
+                                                 int64x2_t offset) {
   int64x2_t sum[8];
   sum[0] = aom_sdotq_s16(offset, s0[0], filter);
   sum[1] = aom_sdotq_s16(offset, s0[1], filter);
@@ -51,22 +52,85 @@ static INLINE uint16x8_t convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
 
-  sum0123 = vshlq_s32(sum0123, shift);
-  sum4567 = vshlq_s32(sum4567, shift);
+  return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS + 2),
+                      vqrshrun_n_s32(sum4567, ROUND0_BITS + 2));
+}
 
-  return vcombine_u16(vqmovun_s32(sum0123), vqmovun_s32(sum4567));
+static INLINE void highbd_12_dist_wtd_convolve_x_8tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *x_filter_ptr) {
+  const int64x1_t offset_vec =
+      vcreate_s64((1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1)));
+  const int64x2_t offset_lo = vcombine_s64(offset_vec, vdup_n_s64(0));
+
+  const int16x8_t filter = vld1q_s16(x_filter_ptr);
+
+  do {
+    const int16_t *s = (const int16_t *)src;
+    uint16_t *d = dst;
+    int w = width;
+
+    do {
+      int16x8_t s0[8], s1[8], s2[8], s3[8];
+      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
+                   &s0[4], &s0[5], &s0[6], &s0[7]);
+      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
+                   &s1[4], &s1[5], &s1[6], &s1[7]);
+      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
+                   &s2[4], &s2[5], &s2[6], &s2[7]);
+      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
+                   &s3[4], &s3[5], &s3[6], &s3[7]);
+
+      uint16x8_t d0 = highbd_12_convolve8_8_x(s0, filter, offset_lo);
+      uint16x8_t d1 = highbd_12_convolve8_8_x(s1, filter, offset_lo);
+      uint16x8_t d2 = highbd_12_convolve8_8_x(s2, filter, offset_lo);
+      uint16x8_t d3 = highbd_12_convolve8_8_x(s3, filter, offset_lo);
+
+      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+      s += 8;
+      d += 8;
+      w -= 8;
+    } while (w != 0);
+    src += 4 * src_stride;
+    dst += 4 * dst_stride;
+    height -= 4;
+  } while (height != 0);
 }
 
-static INLINE void highbd_dist_wtd_convolve_x_sve2(
+static INLINE uint16x8_t highbd_convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
+                                              int64x2_t offset) {
+  int64x2_t sum[8];
+  sum[0] = aom_sdotq_s16(offset, s0[0], filter);
+  sum[1] = aom_sdotq_s16(offset, s0[1], filter);
+  sum[2] = aom_sdotq_s16(offset, s0[2], filter);
+  sum[3] = aom_sdotq_s16(offset, s0[3], filter);
+  sum[4] = aom_sdotq_s16(offset, s0[4], filter);
+  sum[5] = aom_sdotq_s16(offset, s0[5], filter);
+  sum[6] = aom_sdotq_s16(offset, s0[6], filter);
+  sum[7] = aom_sdotq_s16(offset, s0[7], filter);
+
+  sum[0] = vpaddq_s64(sum[0], sum[1]);
+  sum[2] = vpaddq_s64(sum[2], sum[3]);
+  sum[4] = vpaddq_s64(sum[4], sum[5]);
+  sum[6] = vpaddq_s64(sum[6], sum[7]);
+
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
+  int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
+
+  return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS),
+                      vqrshrun_n_s32(sum4567, ROUND0_BITS));
+}
+
+static INLINE void highbd_dist_wtd_convolve_x_8tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
-    int width, int height, const int16_t *x_filter_ptr,
-    ConvolveParams *conv_params, const int offset) {
-  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
-  const int64x2_t offset_vec = vdupq_n_s64(offset);
+    int width, int height, const int16_t *x_filter_ptr, const int bd) {
+  const int64x1_t offset_vec =
+      vcreate_s64((1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1)));
+  const int64x2_t offset_lo = vcombine_s64(offset_vec, vdup_n_s64(0));
 
-  const int64x2_t offset_lo =
-      vcombine_s64(vget_low_s64(offset_vec), vdup_n_s64(0));
   const int16x8_t filter = vld1q_s16(x_filter_ptr);
+
   do {
     const int16_t *s = (const int16_t *)src;
     uint16_t *d = dst;
@@ -83,10 +147,10 @@ static INLINE void highbd_dist_wtd_convolve_x_sve2(
       load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                    &s3[4], &s3[5], &s3[6], &s3[7]);
 
-      uint16x8_t d0 = convolve8_8_x(s0, filter, offset_lo, shift);
-      uint16x8_t d1 = convolve8_8_x(s1, filter, offset_lo, shift);
-      uint16x8_t d2 = convolve8_8_x(s2, filter, offset_lo, shift);
-      uint16x8_t d3 = convolve8_8_x(s3, filter, offset_lo, shift);
+      uint16x8_t d0 = highbd_convolve8_8_x(s0, filter, offset_lo);
+      uint16x8_t d1 = highbd_convolve8_8_x(s1, filter, offset_lo);
+      uint16x8_t d2 = highbd_convolve8_8_x(s2, filter, offset_lo);
+      uint16x8_t d3 = highbd_convolve8_8_x(s3, filter, offset_lo);
 
       store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -142,8 +206,10 @@ DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
 static INLINE void highbd_dist_wtd_convolve_x_4tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
     int width, int height, const int16_t *x_filter_ptr,
-    ConvolveParams *conv_params, const int offset) {
-  // This shim allows to do only one rounding shift instead of two.
+    ConvolveParams *conv_params, const int bd) {
+  const int offset = (1 << (conv_params->round_0 - 1)) +
+                     (1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1));
+
   const int64x2_t offset_s64 = vdupq_n_s64(offset);
   const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
 
@@ -223,9 +289,6 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
   const int im_stride = MAX_SB_SIZE;
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
-  const int offset_convolve = (1 << (conv_params->round_0 - 1)) +
-                              (1 << (bd + FILTER_BITS)) +
-                              (1 << (bd + FILTER_BITS - 1));
 
   const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
       filter_params_x, subpel_x_qn & SUBPEL_MASK);
@@ -236,13 +299,16 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
     if (x_filter_taps <= 4) {
       highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
                                            im_stride, w, h, x_filter_ptr,
-                                           conv_params, offset_convolve);
+                                           conv_params, bd);
     } else {
-      highbd_dist_wtd_convolve_x_sve2(src, src_stride, im_block, im_stride, w,
-                                      h, x_filter_ptr, conv_params,
-                                      offset_convolve);
+      if (bd == 12) {
+        highbd_12_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
+                                                im_stride, w, h, x_filter_ptr);
+      } else {
+        highbd_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
+                                             im_stride, w, h, x_filter_ptr, bd);
+      }
     }
-
     if (conv_params->use_dist_wtd_comp_avg) {
       if (bd == 12) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
@@ -257,7 +323,6 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
       if (bd == 12) {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                                 conv_params);
-
       } else {
         highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                              conv_params, bd);
@@ -267,11 +332,15 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
     if (x_filter_taps <= 4) {
       highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, dst16,
                                            dst16_stride, w, h, x_filter_ptr,
-                                           conv_params, offset_convolve);
+                                           conv_params, bd);
     } else {
-      highbd_dist_wtd_convolve_x_sve2(src, src_stride, dst16, dst16_stride, w,
-                                      h, x_filter_ptr, conv_params,
-                                      offset_convolve);
+      if (bd == 12) {
+        highbd_12_dist_wtd_convolve_x_8tap_sve2(
+            src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr);
+      } else {
+        highbd_dist_wtd_convolve_x_8tap_sve2(
+            src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr, bd);
+      }
     }
   }
 }