aom: Add SVE2 implementation of HBD convolve_x_sr for 8-tap filters

From f909b1012a8b9a58cb7fc44eda54f5ea79e39848 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Wed, 21 Feb 2024 14:38:35 +0000
Subject: [PATCH] Add SVE2 implementation of HBD convolve_x_sr for 8-tap
 filters

Add SVE2 implementation of av1_highbd_convolve_x_sr for 8-tap filters as
well as the corresponding tests.

This implementation is not beneficial on cores that only support SVE
(like Neoverse V1) because they have 256-bit SVE pipes, whereas more
modern cores (which also support SVE2) have 128-bit pipes. It therefore
makes sense to restrict this implementation to SVE2 only. This also
allows us to use extra instructions that are not available in SVE.

Change-Id: I7a383e3d559a7ae609e136e20cae7b2b9484ad65
---
 av1/common/arm/highbd_convolve_sve2.c | 83 +++++++++++++++++++++++++--
 1 file changed, 79 insertions(+), 4 deletions(-)

diff --git a/av1/common/arm/highbd_convolve_sve2.c b/av1/common/arm/highbd_convolve_sve2.c
index 6465b28dd..d34bbbdb3 100644
--- a/av1/common/arm/highbd_convolve_sve2.c
+++ b/av1/common/arm/highbd_convolve_sve2.c
@@ -183,6 +183,76 @@ static INLINE void highbd_convolve_x_sr_12tap_sve2(
   }
 }
 
+// Returns eight u16 results, one per s0[] vector: aom_sdotq_s16 accumulation
+// (helper defined elsewhere), rounding shift by FILTER_BITS, clamp to max.
+static INLINE uint16x8_t convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
+                                       int64x2_t offset, uint16x8_t max) {
+  int64x2_t sum[8];
+  sum[0] = aom_sdotq_s16(offset, s0[0], filter);
+  sum[1] = aom_sdotq_s16(offset, s0[1], filter);
+  sum[2] = aom_sdotq_s16(offset, s0[2], filter);
+  sum[3] = aom_sdotq_s16(offset, s0[3], filter);
+  sum[4] = aom_sdotq_s16(offset, s0[4], filter);
+  sum[5] = aom_sdotq_s16(offset, s0[5], filter);
+  sum[6] = aom_sdotq_s16(offset, s0[6], filter);
+  sum[7] = aom_sdotq_s16(offset, s0[7], filter);
+  // vpaddq_s64(a, b) yields {a0 + a1, b0 + b1}: one 64-bit total per input.
+  sum[0] = vpaddq_s64(sum[0], sum[1]);
+  sum[2] = vpaddq_s64(sum[2], sum[3]);
+  sum[4] = vpaddq_s64(sum[4], sum[5]);
+  sum[6] = vpaddq_s64(sum[6], sum[7]);
+  // vmovn_s64 keeps only the low 32 bits of each 64-bit total.
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
+  int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0123, FILTER_BITS),
+                                vqrshrun_n_s32(sum4567, FILTER_BITS));
+  return vminq_u16(res, max);
+}
+
+// 8-tap convolve_x kernel: handles 8 columns by 4 rows per inner iteration.
+// Loops exit on == 0, so width must be a multiple of 8 and height of 4.
+static INLINE void highbd_convolve_x_sr_8tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *y_filter_ptr,
+    ConvolveParams *conv_params, int bd) {
+  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+  // Adding round_0's rounding term up front allows one rounding shift instead
+  // of two; only lane 0 carries it, as lanes are later summed pairwise.
+  const int64_t offset = 1 << (conv_params->round_0 - 1);
+  const int64x2_t offset_lo = vcombine_s64(vcreate_s64(offset), vdup_n_s64(0));
+  // Holds the x filter despite the name (caller passes x_filter_ptr).
+  const int16x8_t filter = vld1q_s16(y_filter_ptr);
+
+  do {
+    const int16_t *s = (const int16_t *)src;
+    uint16_t *d = dst;
+    int w = width;
+    do {
+      int16x8_t s0[8], s1[8], s2[8], s3[8];
+      // Presumably vector k of a row starts at s + k (pitch argument is 1).
+      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
+                   &s0[4], &s0[5], &s0[6], &s0[7]);
+      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
+                   &s1[4], &s1[5], &s1[6], &s1[7]);
+      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
+                   &s2[4], &s2[5], &s2[6], &s2[7]);
+      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
+                   &s3[4], &s3[5], &s3[6], &s3[7]);
+      uint16x8_t d0 = convolve8_8_x(s0, filter, offset_lo, max);
+      uint16x8_t d1 = convolve8_8_x(s1, filter, offset_lo, max);
+      uint16x8_t d2 = convolve8_8_x(s2, filter, offset_lo, max);
+      uint16x8_t d3 = convolve8_8_x(s3, filter, offset_lo, max);
+      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+      s += 8;
+      d += 8;
+      w -= 8;
+    } while (w != 0);
+    src += 4 * src_stride;
+    dst += 4 * dst_stride;
+    height -= 4;
+  } while (height != 0);
+}
+
 void av1_highbd_convolve_x_sr_sve2(const uint16_t *src, int src_stride,
                                    uint16_t *dst, int dst_stride, int w, int h,
                                    const InterpFilterParams *filter_params_x,
@@ -196,7 +266,7 @@ void av1_highbd_convolve_x_sr_sve2(const uint16_t *src, int src_stride,
 
   const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
 
-  if (x_filter_taps != 12) {
+  if (x_filter_taps != 12 && x_filter_taps != 8) {
     av1_highbd_convolve_x_sr_neon(src, src_stride, dst, dst_stride, w, h,
                                   filter_params_x, subpel_x_qn, conv_params,
                                   bd);
@@ -209,7 +279,12 @@ void av1_highbd_convolve_x_sr_sve2(const uint16_t *src, int src_stride,
 
   src -= horiz_offset;
 
-  highbd_convolve_x_sr_12tap_sve2(src, src_stride, dst, dst_stride, w, h,
-                                  x_filter_ptr, conv_params, bd);
-  return;
+  if (x_filter_taps == 12) {
+    highbd_convolve_x_sr_12tap_sve2(src, src_stride, dst, dst_stride, w, h,
+                                    x_filter_ptr, conv_params, bd);
+    return;
+  }
+
+  highbd_convolve_x_sr_8tap_sve2(src, src_stride, dst, dst_stride, w, h,
+                                 x_filter_ptr, conv_params, bd);
 }