aom: Add 4-tap merged impl of av1_convolve_2d_sr_neon_dotprod

From 49d02208d85c05eb000f4326a15c1a9c5f4e5e2e Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Fri, 10 May 2024 15:12:57 +0100
Subject: [PATCH] Add 4-tap merged impl of av1_convolve_2d_sr_neon_dotprod

Merge the vertical and horizontal passes of
av1_convolve_2d_sr_neon_dotprod when both filters have at most 4 taps
(including bilinear filters), avoiding the use of an intermediate
buffer. This gives around a 10% uplift over the split implementation.

Change-Id: Id8a1d16a892827109d210b34ba34043c46227e53
---
 aom_dsp/arm/mem_neon.h                 |  10 ++
 av1/common/arm/convolve_neon_dotprod.c | 131 +++++++++++++++++++++++++
 2 files changed, 141 insertions(+)

diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index 1aebcf951..b5deb9ca3 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -1080,6 +1080,16 @@ static INLINE void load_u8_16x4(const uint8_t *s, ptrdiff_t p,
   *s3 = vld1q_u8(s);
 }
 
+// Loads three consecutive 16-byte rows starting at s, each p bytes after the
+// previous one, into s0, s1 and s2.
+static INLINE void load_u8_16x3(const uint8_t *s, ptrdiff_t p,
+                                uint8x16_t *const s0, uint8x16_t *const s1,
+                                uint8x16_t *const s2) {
+  *s0 = vld1q_u8(s);
+  *s1 = vld1q_u8(s + p);
+  *s2 = vld1q_u8(s + 2 * p);
+}
+
 static INLINE void load_u16_8x8(const uint16_t *s, const ptrdiff_t p,
                                 uint16x8_t *s0, uint16x8_t *s1, uint16x8_t *s2,
                                 uint16x8_t *s3, uint16x8_t *s4, uint16x8_t *s5,
diff --git a/av1/common/arm/convolve_neon_dotprod.c b/av1/common/arm/convolve_neon_dotprod.c
index 964270b36..32b056dc2 100644
--- a/av1/common/arm/convolve_neon_dotprod.c
+++ b/av1/common/arm/convolve_neon_dotprod.c
@@ -1350,6 +1350,131 @@ static INLINE void convolve_2d_sr_6tap_neon_dotprod(
   } while (w != 0);
 }
 
+// Merged horizontal and vertical pass for 8-bit 2D convolution when both
+// filters have at most 4 taps: each row is filtered horizontally straight into
+// registers and fed to the vertical filter, so no intermediate buffer is used.
+// Only taps 2..5 of each kernel are read; src must be pre-offset to match.
+// w must be 4 or a multiple of 8 and h a non-zero multiple of 4, since the
+// loops stop only when their counters reach exactly 0.
+static INLINE void convolve_2d_sr_4tap_neon_dotprod(
+    const uint8_t *src, int src_stride, uint8_t *dst, int dst_stride, int w,
+    int h, const int16_t *x_filter_ptr, const int16_t *y_filter_ptr) {
+  const int bd = 8;
+  const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
+  const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);
+  const int16x4_t x_filter_s16 = vld1_s16(x_filter_ptr + 2);
+  // All 4-tap and bilinear filter values are even, so halve them to reduce
+  // intermediate precision requirements.
+  const int8x8_t x_filter =
+      vshrn_n_s16(vcombine_s16(x_filter_s16, vdup_n_s16(0)), 1);
+
+  // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
+  // shifts - which are generally faster than rounding shifts on modern CPUs.
+  const int32_t horiz_const =
+      ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
+  // Accumulate into 128 << FILTER_BITS to account for range transform.
+  // Halve the total because we halved the filter values.
+  const int32x4_t correction =
+      vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
+
+  if (w == 4) {
+    const uint8x16_t permute_tbl = vld1q_u8(kDotProdPermuteTbl);
+
+    // Horizontally filter the first three rows to prime the vertical filter.
+    uint8x16_t h_s0, h_s1, h_s2;
+    load_u8_16x3(src, src_stride, &h_s0, &h_s1, &h_s2);
+    int16x4_t v_s0 = convolve4_4_2d_h(h_s0, x_filter, permute_tbl, correction);
+    int16x4_t v_s1 = convolve4_4_2d_h(h_s1, x_filter, permute_tbl, correction);
+    int16x4_t v_s2 = convolve4_4_2d_h(h_s2, x_filter, permute_tbl, correction);
+    src += 3 * src_stride;
+
+    do {
+      // Filter four new rows horizontally, then produce four output rows.
+      uint8x16_t h_s3, h_s4, h_s5, h_s6;
+      load_u8_16x4(src, src_stride, &h_s3, &h_s4, &h_s5, &h_s6);
+      int16x4_t v_s3 =
+          convolve4_4_2d_h(h_s3, x_filter, permute_tbl, correction);
+      int16x4_t v_s4 =
+          convolve4_4_2d_h(h_s4, x_filter, permute_tbl, correction);
+      int16x4_t v_s5 =
+          convolve4_4_2d_h(h_s5, x_filter, permute_tbl, correction);
+      int16x4_t v_s6 =
+          convolve4_4_2d_h(h_s6, x_filter, permute_tbl, correction);
+
+      int16x4_t d0 = convolve4_4_2d_v(v_s0, v_s1, v_s2, v_s3, y_filter);
+      int16x4_t d1 = convolve4_4_2d_v(v_s1, v_s2, v_s3, v_s4, y_filter);
+      int16x4_t d2 = convolve4_4_2d_v(v_s2, v_s3, v_s4, v_s5, y_filter);
+      int16x4_t d3 = convolve4_4_2d_v(v_s3, v_s4, v_s5, v_s6, y_filter);
+
+      // Subtract vert_const and narrow to 8 bits with unsigned saturation.
+      uint8x8_t d01 = vqmovun_s16(vsubq_s16(vcombine_s16(d0, d1), vert_const));
+      uint8x8_t d23 = vqmovun_s16(vsubq_s16(vcombine_s16(d2, d3), vert_const));
+      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
+      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
+
+      // The last three filtered rows become the first three of the next block.
+      v_s0 = v_s4;
+      v_s1 = v_s5;
+      v_s2 = v_s6;
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h != 0);
+  } else {
+    // Process 8-column strips with the same sliding window as the w == 4 path.
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
+
+    do {
+      int height = h;
+      const uint8_t *s = src;
+      uint8_t *d = dst;
+
+      uint8x16_t h_s0, h_s1, h_s2;
+      load_u8_16x3(s, src_stride, &h_s0, &h_s1, &h_s2);
+      int16x8_t v_s0 =
+          convolve4_8_2d_h(h_s0, x_filter, permute_tbl, correction);
+      int16x8_t v_s1 =
+          convolve4_8_2d_h(h_s1, x_filter, permute_tbl, correction);
+      int16x8_t v_s2 =
+          convolve4_8_2d_h(h_s2, x_filter, permute_tbl, correction);
+      s += 3 * src_stride;
+
+      do {
+        uint8x16_t h_s3, h_s4, h_s5, h_s6;
+        load_u8_16x4(s, src_stride, &h_s3, &h_s4, &h_s5, &h_s6);
+        int16x8_t v_s3 =
+            convolve4_8_2d_h(h_s3, x_filter, permute_tbl, correction);
+        int16x8_t v_s4 =
+            convolve4_8_2d_h(h_s4, x_filter, permute_tbl, correction);
+        int16x8_t v_s5 =
+            convolve4_8_2d_h(h_s5, x_filter, permute_tbl, correction);
+        int16x8_t v_s6 =
+            convolve4_8_2d_h(h_s6, x_filter, permute_tbl, correction);
+
+        uint8x8_t d0 =
+            convolve4_8_2d_v(v_s0, v_s1, v_s2, v_s3, y_filter, vert_const);
+        uint8x8_t d1 =
+            convolve4_8_2d_v(v_s1, v_s2, v_s3, v_s4, y_filter, vert_const);
+        uint8x8_t d2 =
+            convolve4_8_2d_v(v_s2, v_s3, v_s4, v_s5, y_filter, vert_const);
+        uint8x8_t d3 =
+            convolve4_8_2d_v(v_s3, v_s4, v_s5, v_s6, y_filter, vert_const);
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        v_s0 = v_s4;
+        v_s1 = v_s5;
+        v_s2 = v_s6;
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        height -= 4;
+      } while (height != 0);
+      src += 8;
+      dst += 8;
+      w -= 8;
+    } while (w != 0);
+  }
+}
+
 void av1_convolve_2d_sr_neon_dotprod(const uint8_t *src, int src_stride,
                                      uint8_t *dst, int dst_stride, int w, int h,
                                      const InterpFilterParams *filter_params_x,
@@ -1400,6 +1525,12 @@ void av1_convolve_2d_sr_neon_dotprod(const uint8_t *src, int src_stride,
       return;
     }
 
+    if (x_filter_taps <= 4 && y_filter_taps <= 4) {
+      convolve_2d_sr_4tap_neon_dotprod(src_ptr + 2, src_stride, dst, dst_stride,
+                                       w, h, x_filter_ptr, y_filter_ptr);
+      return;
+    }
+
     DECLARE_ALIGNED(16, int16_t,
                     im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);