aom: Add merged impl of 6-tap av1_convolve_2d_sr_neon_i8mm

From 069d267d7c4b0738d9a46d4be0c48b8d67e42003 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 9 May 2024 14:52:38 +0100
Subject: [PATCH] Add merged impl of 6-tap av1_convolve_2d_sr_neon_i8mm

Merge the horizontal and vertical passes of av1_convolve_2d_sr_neon_i8mm
for 6-tap vertical filters (with a horizontal filter of at least 6 taps),
avoiding the use of an intermediate buffer. This gives around a 10%
uplift over the split implementation.

Change-Id: I34d5dc819bdc36f04ac172bce349257f8f7887d2
---
 av1/common/arm/convolve_neon_i8mm.c | 79 +++++++++++++++++++++++++++++
 1 file changed, 79 insertions(+)

diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index c3d4c94c7..8f54b64fa 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -983,6 +983,79 @@ static INLINE void convolve_2d_sr_horiz_4tap_neon_i8mm(
   }
 }
 
+// Single-pass 2D convolution: 8-coefficient horizontal, 6-tap vertical,
+// 8 columns x 4 rows per step; w and h must be positive multiples of 8 and 4.
+static INLINE void convolve_2d_sr_6tap_neon_i8mm(const uint8_t *src,
+                                                 int src_stride, uint8_t *dst,
+                                                 int dst_stride, int w, int h,
+                                                 const int16_t *x_filter_ptr,
+                                                 const int16_t *y_filter_ptr) {
+  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
+  // Filter values are even, so halve to reduce intermediate precision reqs.
+  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  const int bd = 8;
+  // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  // The outermost -1 is needed because we halved the filter values.
+  const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
+                                            (1 << ((ROUND0_BITS - 1) - 1)));
+  // Passed to convolve6_8_2d_v; presumably cancels horiz_const's offset.
+  const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
+  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
+
+  do {
+    const uint8_t *s = src;
+    uint8_t *d = dst;
+    int height = h;
+    // Prime the vertical filter with the first 5 horizontally filtered rows.
+    uint8x16_t h_s0, h_s1, h_s2, h_s3, h_s4;
+    load_u8_16x5(s, src_stride, &h_s0, &h_s1, &h_s2, &h_s3, &h_s4);
+    s += 5 * src_stride;
+    int16x8_t v_s0 = convolve8_8_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s1 = convolve8_8_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s2 = convolve8_8_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s3 = convolve8_8_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s4 = convolve8_8_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
+
+    do {
+      uint8x16_t h_s5, h_s6, h_s7, h_s8;
+      load_u8_16x4(s, src_stride, &h_s5, &h_s6, &h_s7, &h_s8);
+      int16x8_t v_s5 =
+          convolve8_8_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
+      int16x8_t v_s6 =
+          convolve8_8_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
+      int16x8_t v_s7 =
+          convolve8_8_2d_h(h_s7, x_filter, permute_tbl, horiz_const);
+      int16x8_t v_s8 =
+          convolve8_8_2d_h(h_s8, x_filter, permute_tbl, horiz_const);
+      // Each output row is a vertical filter over 6 consecutive v_s rows.
+      uint8x8_t d0 = convolve6_8_2d_v(v_s0, v_s1, v_s2, v_s3, v_s4, v_s5,
+                                      y_filter, vert_const);
+      uint8x8_t d1 = convolve6_8_2d_v(v_s1, v_s2, v_s3, v_s4, v_s5, v_s6,
+                                      y_filter, vert_const);
+      uint8x8_t d2 = convolve6_8_2d_v(v_s2, v_s3, v_s4, v_s5, v_s6, v_s7,
+                                      y_filter, vert_const);
+      uint8x8_t d3 = convolve6_8_2d_v(v_s3, v_s4, v_s5, v_s6, v_s7, v_s8,
+                                      y_filter, vert_const);
+
+      store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+      // Slide the window down 4 rows; v_s4..v_s8 seed the next iteration.
+      v_s0 = v_s4;
+      v_s1 = v_s5;
+      v_s2 = v_s6;
+      v_s3 = v_s7;
+      v_s4 = v_s8;
+
+      s += 4 * src_stride;
+      d += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+    src += 8;
+    dst += 8;
+    w -= 8;
+  } while (w != 0);
+}
+
 void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
                                   uint8_t *dst, int dst_stride, int w, int h,
                                   const InterpFilterParams *filter_params_x,
@@ -1029,6 +1102,12 @@ void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
     DECLARE_ALIGNED(16, int16_t,
                     im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);
 
+    if (y_filter_taps == 6 && x_filter_taps >= 6) {
+      convolve_2d_sr_6tap_neon_i8mm(src_ptr, src_stride, dst, dst_stride, w, h,
+                                    x_filter_ptr, y_filter_ptr);
+      return;
+    }
+
     if (x_filter_taps <= 4) {
       convolve_2d_sr_horiz_4tap_neon_i8mm(src_ptr + 2, src_stride, im_block,
                                           im_stride, w, im_h, x_filter_ptr);