aom: Add SVE2 implementation of HBD convolve_x_sr for 4-tap filters

From 2a3c66fa3d6e288ad73a837e370e9fead4056d0f Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Wed, 21 Feb 2024 14:42:20 +0000
Subject: [PATCH] Add SVE2 implementation of HBD convolve_x_sr for 4-tap
 filters

Add SVE2 implementation of av1_highbd_convolve_x_sr for 4-tap filters as
well as the corresponding tests.

This implementation is not beneficial on cores that only support SVE
(like Neoverse V1) because they have 256-bit SVE pipes, rather than the
128-bit pipes found on more modern cores (which also support SVE2). It
therefore makes sense to restrict this implementation to SVE2 only. This
also allows us to use extra instructions not available in SVE.

Change-Id: I06590e03b56a262f01a25324fc2d2ebb88369495
---
 av1/common/arm/highbd_convolve_sve2.c | 128 +++++++++++++++++++++++++-
 1 file changed, 124 insertions(+), 4 deletions(-)

diff --git a/av1/common/arm/highbd_convolve_sve2.c b/av1/common/arm/highbd_convolve_sve2.c
index d34bbbdb3..5b6cb45c1 100644
--- a/av1/common/arm/highbd_convolve_sve2.c
+++ b/av1/common/arm/highbd_convolve_sve2.c
@@ -22,7 +22,7 @@
 #include "av1/common/convolve.h"
 #include "av1/common/filter.h"
 
-DECLARE_ALIGNED(16, static const uint16_t, kDotProd12TapTbl[32]) = {
+DECLARE_ALIGNED(16, static const uint16_t, kDotProdTbl[32]) = {
   0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
   4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
 };
@@ -121,7 +121,7 @@ static INLINE void highbd_convolve_x_sr_12tap_sve2(
   const int16x8_t y_filter_0_7 = vld1q_s16(y_filter_ptr);
   const int16x8_t y_filter_4_11 = vld1q_s16(y_filter_ptr + 4);
 
-  uint16x8x4_t permute_tbl = vld1q_u16_x4(kDotProd12TapTbl);
+  uint16x8x4_t permute_tbl = vld1q_u16_x4(kDotProdTbl);
 
   if (width == 4) {
     const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
@@ -253,6 +253,120 @@ static INLINE void highbd_convolve_x_sr_8tap_sve2(
   } while (height != 0);
 }
 
+// clang-format off
+DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
+  0, 2, 4, 6, 1, 3, 5, 7,  // TBL indices: even lanes first, then odd lanes.
+};
+// clang-format on
+
+static INLINE uint16x8_t aom_tbl_u16(uint16x8_t src, uint16x8_t table) {
+  svuint16_t idx = svset_neonq_u16(svundef_u16(), table);  // TBL indices.
+  return svget_neonq_u16(svtbl_u16(svset_neonq_u16(svundef_u16(), src), idx));
+}
+
+// Filters 4 pixels from the 8 samples in s0 with 4 taps; clamps them to max.
+static INLINE uint16x4_t convolve4_4_x(int16x8_t s0, int16x8_t filter,
+                                       int64x2_t offset,
+                                       uint16x8x2_t permute_tbl,
+                                       uint16x4_t max) {
+  int16x8_t win01 = aom_tbl_s16(s0, permute_tbl.val[0]);  // For outputs 0, 1.
+  int16x8_t win23 = aom_tbl_s16(s0, permute_tbl.val[1]);  // For outputs 2, 3.
+
+  int64x2_t acc01 = aom_svdot_lane_s16(offset, win01, filter, 0);
+  int64x2_t acc23 = aom_svdot_lane_s16(offset, win23, filter, 0);
+
+  // Narrow to 32 bits, then round, shift by FILTER_BITS and saturate to u16.
+  int32x4_t acc0123 = vcombine_s32(vmovn_s64(acc01), vmovn_s64(acc23));
+  return vmin_u16(vqrshrun_n_s32(acc0123, FILTER_BITS), max);
+}
+
+// Filters 8 pixels with 4 taps. s0[i] holds the 8 samples from column i on;
+// dot product i gives outputs i and i + 4, which tbl puts back in order.
+static INLINE uint16x8_t convolve4_8_x(int16x8_t s0[4], int16x8_t filter,
+                                       int64x2_t offset, uint16x8_t tbl,
+                                       uint16x8_t max) {
+  int64x2_t sum04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
+  int64x2_t sum15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
+  int64x2_t sum26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
+  int64x2_t sum37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);
+
+  int32x4_t sum0415 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
+  int32x4_t sum2637 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
+
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0415, FILTER_BITS),
+                                vqrshrun_n_s32(sum2637, FILTER_BITS));
+  res = aom_tbl_u16(res, tbl);
+
+  return vminq_u16(res, max);
+}
+
+// 4-tap horizontal high bit-depth convolution, processing 4 rows at a time.
+// Requires height to be a multiple of 4 and width to be 4 or a multiple of 8.
+static INLINE void highbd_convolve_x_sr_4tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *x_filter_ptr,
+    ConvolveParams *conv_params, int bd) {
+  // Pre-adding this offset lets one rounding shift by FILTER_BITS replace two.
+  const int64x2_t offset = vdupq_n_s64(1 << (conv_params->round_0 - 1));
+  // Load the 4 taps from x_filter_ptr[2..5] and zero-pad to a full vector.
+  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
+  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
+
+  if (width == 4) {
+    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
+    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);
+
+    const int16_t *s = (const int16_t *)(src);
+
+    do {
+      int16x8_t s0, s1, s2, s3;
+      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+      uint16x4_t d0 = convolve4_4_x(s0, filter, offset, permute_tbl, max);
+      uint16x4_t d1 = convolve4_4_x(s1, filter, offset, permute_tbl, max);
+      uint16x4_t d2 = convolve4_4_x(s2, filter, offset, permute_tbl, max);
+      uint16x4_t d3 = convolve4_4_x(s3, filter, offset, permute_tbl, max);
+
+      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
+
+      s += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+  } else {
+    const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
+
+    do {
+      const int16_t *s = (const int16_t *)(src);
+      uint16_t *d = dst;
+      int w = width;
+
+      do {
+        int16x8_t s0[4], s1[4], s2[4], s3[4];  // sN[i] = row N from column i
+        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
+
+        uint16x8_t d0 = convolve4_8_x(s0, filter, offset, idx, max);
+        uint16x8_t d1 = convolve4_8_x(s1, filter, offset, idx, max);
+        uint16x8_t d2 = convolve4_8_x(s2, filter, offset, idx, max);
+        uint16x8_t d3 = convolve4_8_x(s3, filter, offset, idx, max);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        w -= 8;
+      } while (w != 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+  }
+}
+
 void av1_highbd_convolve_x_sr_sve2(const uint16_t *src, int src_stride,
                                    uint16_t *dst, int dst_stride, int w, int h,
                                    const InterpFilterParams *filter_params_x,
@@ -266,7 +380,7 @@ void av1_highbd_convolve_x_sr_sve2(const uint16_t *src, int src_stride,
 
   const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
 
-  if (x_filter_taps != 12 && x_filter_taps != 8) {
+  if (x_filter_taps == 6) {
     av1_highbd_convolve_x_sr_neon(src, src_stride, dst, dst_stride, w, h,
                                   filter_params_x, subpel_x_qn, conv_params,
                                   bd);
@@ -285,6 +399,12 @@ void av1_highbd_convolve_x_sr_sve2(const uint16_t *src, int src_stride,
     return;
   }
 
-  highbd_convolve_x_sr_8tap_sve2(src, src_stride, dst, dst_stride, w, h,
+  if (x_filter_taps == 8) {
+    highbd_convolve_x_sr_8tap_sve2(src, src_stride, dst, dst_stride, w, h,
+                                   x_filter_ptr, conv_params, bd);
+    return;
+  }
+
+  highbd_convolve_x_sr_4tap_sve2(src + 2, src_stride, dst, dst_stride, w, h,
                                  x_filter_ptr, conv_params, bd);
 }