aom: Specialize 4-tap HBD dist_wtd_convolve_x_sve2 on bitdepth

From 8cb23f865eb58e2772c098d8044b4edfdfbdd03d Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Fri, 8 Mar 2024 16:13:57 +0000
Subject: [PATCH] Specialize 4-tap HBD dist_wtd_convolve_x_sve2 on bitdepth

The rounding value applied at the end of the convolution depends on
the bitdepth (8/10 or 12). Add two specialized versions of the
function so that the rounding value is known at compile time, allowing
the final rounding and narrowing step to be performed with one
instruction instead of two. This gives up to a 20% uplift over the
non-specialized version.

Change-Id: Id83a813ddefbb61704adf4b59d93f6910160f22d
---
 .../arm/highbd_compound_convolve_sve2.c       | 212 +++++++++++++-----
 1 file changed, 152 insertions(+), 60 deletions(-)
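
For reference, with the (1 << (round_0 - 1)) rounding term now folded
into vqrshrun_n_s32, the remaining compile-time offset in the 12-bit
path works out to (FILTER_BITS == 7):

  (1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1))
      == (1 << 19) + (1 << 18)
      == 786432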

diff --git a/av1/common/arm/highbd_compound_convolve_sve2.c b/av1/common/arm/highbd_compound_convolve_sve2.c
index 8977538c4..dc983c5f8 100644
--- a/av1/common/arm/highbd_compound_convolve_sve2.c
+++ b/av1/common/arm/highbd_compound_convolve_sve2.c
@@ -164,9 +164,15 @@ static INLINE void highbd_dist_wtd_convolve_x_8tap_sve2(
   } while (height != 0);
 }
 
-static INLINE uint16x4_t convolve4_4_x(int16x8_t s0, int16x8_t filter,
-                                       int64x2_t offset, int32x4_t shift,
-                                       uint16x8x2_t permute_tbl) {
+// clang-format off
+DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
+  0, 2, 4, 6, 1, 3, 5, 7,
+};
+// clang-format on
+
+static INLINE uint16x4_t highbd_12_convolve4_4_x(int16x8_t s0, int16x8_t filter,
+                                                 int64x2_t offset,
+                                                 uint16x8x2_t permute_tbl) {
   int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
   int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);
 
@@ -174,44 +180,124 @@ static INLINE uint16x4_t convolve4_4_x(int16x8_t s0, int16x8_t filter,
   int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);
 
   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
-  sum0123 = vshlq_s32(sum0123, shift);
 
-  return vqmovun_s32(sum0123);
+  return vqrshrun_n_s32(sum0123, ROUND0_BITS + 2);
 }
 
-static INLINE uint16x8_t convolve4_8_x(int16x8_t s0[4], int16x8_t filter,
-                                       int64x2_t offset, int32x4_t shift,
-                                       uint16x8_t tbl) {
+static INLINE uint16x8_t highbd_12_convolve4_8_x(int16x8_t s0[4],
+                                                 int16x8_t filter,
+                                                 int64x2_t offset,
+                                                 uint16x8_t tbl) {
   int64x2_t sum04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
   int64x2_t sum15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
   int64x2_t sum26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
   int64x2_t sum37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);
 
   int32x4_t sum0415 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
-  sum0415 = vshlq_s32(sum0415, shift);
-
   int32x4_t sum2637 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
-  sum2637 = vshlq_s32(sum2637, shift);
 
-  uint16x8_t res = vcombine_u16(vqmovun_s32(sum0415), vqmovun_s32(sum2637));
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0415, ROUND0_BITS + 2),
+                                vqrshrun_n_s32(sum2637, ROUND0_BITS + 2));
   return aom_tbl_u16(res, tbl);
 }
 
-// clang-format off
-DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
-  0, 2, 4, 6, 1, 3, 5, 7,
-};
-// clang-format on
+static INLINE void highbd_12_dist_wtd_convolve_x_4tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *x_filter_ptr) {
+  const int64x2_t offset =
+      vdupq_n_s64((1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1)));
+
+  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
+  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
+
+  if (width == 4) {
+    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);
+
+    const int16_t *s = (const int16_t *)(src);
+
+    do {
+      int16x8_t s0, s1, s2, s3;
+      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+      uint16x4_t d0 = highbd_12_convolve4_4_x(s0, filter, offset, permute_tbl);
+      uint16x4_t d1 = highbd_12_convolve4_4_x(s1, filter, offset, permute_tbl);
+      uint16x4_t d2 = highbd_12_convolve4_4_x(s2, filter, offset, permute_tbl);
+      uint16x4_t d3 = highbd_12_convolve4_4_x(s3, filter, offset, permute_tbl);
+
+      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
+
+      s += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+  } else {
+    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
+
+    do {
+      const int16_t *s = (const int16_t *)(src);
+      uint16_t *d = dst;
+      int w = width;
+
+      do {
+        int16x8_t s0[4], s1[4], s2[4], s3[4];
+        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
+
+        uint16x8_t d0 = highbd_12_convolve4_8_x(s0, filter, offset, idx);
+        uint16x8_t d1 = highbd_12_convolve4_8_x(s1, filter, offset, idx);
+        uint16x8_t d2 = highbd_12_convolve4_8_x(s2, filter, offset, idx);
+        uint16x8_t d3 = highbd_12_convolve4_8_x(s3, filter, offset, idx);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        w -= 8;
+      } while (w != 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+  }
+}
+
+static INLINE uint16x4_t highbd_convolve4_4_x(int16x8_t s0, int16x8_t filter,
+                                              int64x2_t offset,
+                                              uint16x8x2_t permute_tbl) {
+  int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
+  int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);
+
+  int64x2_t sum01 = aom_svdot_lane_s16(offset, permuted_samples0, filter, 0);
+  int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);
+
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
+
+  return vqrshrun_n_s32(sum0123, ROUND0_BITS);
+}
+
+static INLINE uint16x8_t highbd_convolve4_8_x(int16x8_t s0[4], int16x8_t filter,
+                                              int64x2_t offset,
+                                              uint16x8_t tbl) {
+  int64x2_t sum04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
+  int64x2_t sum15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
+  int64x2_t sum26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
+  int64x2_t sum37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);
+
+  int32x4_t sum0415 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
+  int32x4_t sum2637 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
+
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0415, ROUND0_BITS),
+                                vqrshrun_n_s32(sum2637, ROUND0_BITS));
+  return aom_tbl_u16(res, tbl);
+}
 
 static INLINE void highbd_dist_wtd_convolve_x_4tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
-    int width, int height, const int16_t *x_filter_ptr,
-    ConvolveParams *conv_params, const int bd) {
-  const int offset = (1 << (conv_params->round_0 - 1)) +
-                     (1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1));
-
-  const int64x2_t offset_s64 = vdupq_n_s64(offset);
-  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
+    int width, int height, const int16_t *x_filter_ptr, const int bd) {
+  const int64x2_t offset =
+      vdupq_n_s64((1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1)));
 
   const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
   const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
@@ -225,10 +311,10 @@ static INLINE void highbd_dist_wtd_convolve_x_4tap_sve2(
       int16x8_t s0, s1, s2, s3;
       load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-      uint16x4_t d0 = convolve4_4_x(s0, filter, offset_s64, shift, permute_tbl);
-      uint16x4_t d1 = convolve4_4_x(s1, filter, offset_s64, shift, permute_tbl);
-      uint16x4_t d2 = convolve4_4_x(s2, filter, offset_s64, shift, permute_tbl);
-      uint16x4_t d3 = convolve4_4_x(s3, filter, offset_s64, shift, permute_tbl);
+      uint16x4_t d0 = highbd_convolve4_4_x(s0, filter, offset, permute_tbl);
+      uint16x4_t d1 = highbd_convolve4_4_x(s1, filter, offset, permute_tbl);
+      uint16x4_t d2 = highbd_convolve4_4_x(s2, filter, offset, permute_tbl);
+      uint16x4_t d3 = highbd_convolve4_4_x(s3, filter, offset, permute_tbl);
 
       store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
 
@@ -251,10 +337,10 @@ static INLINE void highbd_dist_wtd_convolve_x_4tap_sve2(
         load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
         load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
 
-        uint16x8_t d0 = convolve4_8_x(s0, filter, offset_s64, shift, idx);
-        uint16x8_t d1 = convolve4_8_x(s1, filter, offset_s64, shift, idx);
-        uint16x8_t d2 = convolve4_8_x(s2, filter, offset_s64, shift, idx);
-        uint16x8_t d3 = convolve4_8_x(s3, filter, offset_s64, shift, idx);
+        uint16x8_t d0 = highbd_convolve4_8_x(s0, filter, offset, idx);
+        uint16x8_t d1 = highbd_convolve4_8_x(s1, filter, offset, idx);
+        uint16x8_t d2 = highbd_convolve4_8_x(s2, filter, offset, idx);
+        uint16x8_t d3 = highbd_convolve4_8_x(s3, filter, offset, idx);
 
         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -295,48 +381,54 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
 
   src -= horiz_offset;
 
-  if (conv_params->do_average) {
-    if (x_filter_taps <= 4) {
-      highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
-                                           im_stride, w, h, x_filter_ptr,
-                                           conv_params, bd);
-    } else {
-      if (bd == 12) {
-        highbd_12_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
+  if (bd == 12) {
+    if (conv_params->do_average) {
+      if (x_filter_taps <= 4) {
+        highbd_12_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
                                                 im_stride, w, h, x_filter_ptr);
       } else {
-        highbd_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
-                                             im_stride, w, h, x_filter_ptr, bd);
+        highbd_12_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
+                                                im_stride, w, h, x_filter_ptr);
       }
-    }
-    if (conv_params->use_dist_wtd_comp_avg) {
-      if (bd == 12) {
+
+      if (conv_params->use_dist_wtd_comp_avg) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
                                          w, h, conv_params);
 
       } else {
-        highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
-                                      h, conv_params, bd);
-      }
-
-    } else {
-      if (bd == 12) {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                                 conv_params);
+      }
+    } else {
+      if (x_filter_taps <= 4) {
+        highbd_12_dist_wtd_convolve_x_4tap_sve2(
+            src + 2, src_stride, dst16, dst16_stride, w, h, x_filter_ptr);
       } else {
-        highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                             conv_params, bd);
+        highbd_12_dist_wtd_convolve_x_8tap_sve2(
+            src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr);
       }
     }
   } else {
-    if (x_filter_taps <= 4) {
-      highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, dst16,
-                                           dst16_stride, w, h, x_filter_ptr,
-                                           conv_params, bd);
+    if (conv_params->do_average) {
+      if (x_filter_taps <= 4) {
+        highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
+                                             im_stride, w, h, x_filter_ptr, bd);
+      } else {
+        highbd_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
+                                             im_stride, w, h, x_filter_ptr, bd);
+      }
+
+      if (conv_params->use_dist_wtd_comp_avg) {
+        highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
+                                      h, conv_params, bd);
+      } else {
+        highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
+                             conv_params, bd);
+      }
     } else {
-      if (bd == 12) {
-        highbd_12_dist_wtd_convolve_x_8tap_sve2(
-            src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr);
+      if (x_filter_taps <= 4) {
+        highbd_dist_wtd_convolve_x_4tap_sve2(
+            src + 2, src_stride, dst16, dst16_stride, w, h, x_filter_ptr, bd);
       } else {
         highbd_dist_wtd_convolve_x_8tap_sve2(
             src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr, bd);