aom: Add 4-tap path for the vertical pass of av1_convolve_2d_sr

From 279722d6fef0a3117d8c4ce0804667c63aabc2ad Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Wed, 8 May 2024 17:22:20 +0100
Subject: [PATCH] Add 4-tap path for the vertical pass of av1_convolve_2d_sr

Add a 4-tap Neon implementation for the vertical pass of
av1_convolve_2d_sr and use it in the neon, neon_dotprod and neon_i8mm
variants of the function. Vertical filters with 4 or fewer taps were
previously clamped to the 6-tap path; using the dedicated 4-tap path
instead gives up to a 30% uplift.

Change-Id: Ia61667cd54a79c352433fd190c9a1f94872c1efe
---
 av1/common/arm/convolve_neon.c         |   7 +-
 av1/common/arm/convolve_neon.h         | 108 +++++++++++++++++++++++++
 av1/common/arm/convolve_neon_dotprod.c |   7 +-
 av1/common/arm/convolve_neon_i8mm.c    |   7 +-
 4 files changed, 123 insertions(+), 6 deletions(-)

diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index 72a85893e8..70cf23be06 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -1588,7 +1588,7 @@ void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
 
   const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
   const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
-  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
+  const int clamped_y_taps = y_filter_taps < 4 ? 4 : y_filter_taps;
   const int im_h = h + clamped_y_taps - 1;
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = clamped_y_taps / 2 - 1;
@@ -1628,7 +1628,10 @@ void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
 
     const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
 
-    if (clamped_y_taps <= 6) {
+    if (clamped_y_taps <= 4) {
+      convolve_2d_sr_vert_4tap_neon(im_block, im_stride, dst, dst_stride, w, h,
+                                    y_filter_ptr);
+    } else if (clamped_y_taps == 6) {
       convolve_2d_sr_vert_6tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                     y_filter);
     } else {
diff --git a/av1/common/arm/convolve_neon.h b/av1/common/arm/convolve_neon.h
index 9fbf8aa12f..5a9f8b6d39 100644
--- a/av1/common/arm/convolve_neon.h
+++ b/av1/common/arm/convolve_neon.h
@@ -535,4 +535,112 @@ static INLINE void convolve_2d_sr_vert_6tap_neon(int16_t *src_ptr,
   }
 }
 
+// Per lane: sum(sN * y_filter[N]), rounding-shifted and saturated to int16.
+static INLINE int16x4_t convolve4_4_2d_v(const int16x4_t s0, const int16x4_t s1,
+                                         const int16x4_t s2, const int16x4_t s3,
+                                         const int16x4_t y_filter) {
+  int32x4_t sum = vmull_lane_s16(s0, y_filter, 0);
+  sum = vmlal_lane_s16(sum, s1, y_filter, 1);
+  sum = vmlal_lane_s16(sum, s2, y_filter, 2);
+  sum = vmlal_lane_s16(sum, s3, y_filter, 3);
+  return vqrshrn_n_s32(sum, 2 * FILTER_BITS - ROUND0_BITS);
+}
+
+// convolve4_4_2d_v on 8 lanes, then subtract sub_const and saturate to u8.
+static INLINE uint8x8_t convolve4_8_2d_v(const int16x8_t s0, const int16x8_t s1,
+                                         const int16x8_t s2, const int16x8_t s3,
+                                         const int16x4_t y_filter,
+                                         const int16x8_t sub_const) {
+  int32x4_t sum0 = vmull_lane_s16(vget_low_s16(s0), y_filter, 0);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), y_filter, 1);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), y_filter, 2);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), y_filter, 3);
+
+  int32x4_t sum1 = vmull_lane_s16(vget_high_s16(s0), y_filter, 0);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), y_filter, 1);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), y_filter, 2);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), y_filter, 3);
+  // Round-shift and saturate both halves to int16, then subtract sub_const.
+  int16x8_t res =
+      vcombine_s16(vqrshrn_n_s32(sum0, 2 * FILTER_BITS - ROUND0_BITS),
+                   vqrshrn_n_s32(sum1, 2 * FILTER_BITS - ROUND0_BITS));
+  res = vsubq_s16(res, sub_const);
+  return vqmovun_s16(res);  // Saturating narrow to uint8 ([0, 255]).
+}
+
+// Vertical 4-tap pass: filters the int16 rows at src_ptr with y_filter[2..5]
+// and stores uint8 results to dst_ptr. w must be 4 or a positive multiple of 8
+// and h a positive multiple of 4; the loops exit only when a counter hits 0.
+static INLINE void convolve_2d_sr_vert_4tap_neon(int16_t *src_ptr,
+                                                 int src_stride,
+                                                 uint8_t *dst_ptr,
+                                                 int dst_stride, int w, int h,
+                                                 const int16_t *y_filter) {
+  const int bd = 8;
+  // NOTE(review): presumably removes an offset added by the horizontal pass.
+  const int16x8_t sub_const = vdupq_n_s16(1 << (bd - 1));
+  const int16x4_t filter = vld1_s16(y_filter + 2);
+
+  if (w == 4) {
+    int16x4_t s0, s1, s2;
+    load_s16_4x3(src_ptr, src_stride, &s0, &s1, &s2);
+    src_ptr += 3 * src_stride;
+
+    do {
+      int16x4_t s3, s4, s5, s6;
+      load_s16_4x4(src_ptr, src_stride, &s3, &s4, &s5, &s6);
+
+      int16x4_t d0 = convolve4_4_2d_v(s0, s1, s2, s3, filter);
+      int16x4_t d1 = convolve4_4_2d_v(s1, s2, s3, s4, filter);
+      int16x4_t d2 = convolve4_4_2d_v(s2, s3, s4, s5, filter);
+      int16x4_t d3 = convolve4_4_2d_v(s3, s4, s5, s6, filter);
+
+      uint8x8_t d01 = vqmovun_s16(vsubq_s16(vcombine_s16(d0, d1), sub_const));
+      uint8x8_t d23 = vqmovun_s16(vsubq_s16(vcombine_s16(d2, d3), sub_const));
+      store_u8x4_strided_x2(dst_ptr + 0 * dst_stride, dst_stride, d01);
+      store_u8x4_strided_x2(dst_ptr + 2 * dst_stride, dst_stride, d23);
+
+      // Slide the window down 4 rows: s4..s6 become the next s0..s2.
+      s0 = s4;
+      s1 = s5;
+      s2 = s6;
+      src_ptr += 4 * src_stride;
+      dst_ptr += 4 * dst_stride;
+      h -= 4;
+    } while (h != 0);
+  } else {
+    // Width is a multiple of 8: process 8-column strips, 4 rows at a time.
+    do {
+      int height = h;
+      int16_t *s = src_ptr;
+      uint8_t *d = dst_ptr;
+
+      int16x8_t s0, s1, s2;
+      load_s16_8x3(s, src_stride, &s0, &s1, &s2);
+      s += 3 * src_stride;
+
+      do {
+        int16x8_t s3, s4, s5, s6;
+        load_s16_8x4(s, src_stride, &s3, &s4, &s5, &s6);
+
+        uint8x8_t d0 = convolve4_8_2d_v(s0, s1, s2, s3, filter, sub_const);
+        uint8x8_t d1 = convolve4_8_2d_v(s1, s2, s3, s4, filter, sub_const);
+        uint8x8_t d2 = convolve4_8_2d_v(s2, s3, s4, s5, filter, sub_const);
+        uint8x8_t d3 = convolve4_8_2d_v(s3, s4, s5, s6, filter, sub_const);
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s0 = s4;
+        s1 = s5;
+        s2 = s6;
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        height -= 4;
+      } while (height != 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      w -= 8;
+    } while (w != 0);
+  }
+}
+
 #endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_
diff --git a/av1/common/arm/convolve_neon_dotprod.c b/av1/common/arm/convolve_neon_dotprod.c
index 3c85f3cb4b..b558744731 100644
--- a/av1/common/arm/convolve_neon_dotprod.c
+++ b/av1/common/arm/convolve_neon_dotprod.c
@@ -1419,7 +1419,7 @@ void av1_convolve_2d_sr_neon_dotprod(const uint8_t *src, int src_stride,
 
   const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
   const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
-  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
+  const int clamped_y_taps = y_filter_taps < 4 ? 4 : y_filter_taps;
   const int im_h = h + clamped_y_taps - 1;
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = clamped_y_taps / 2 - 1;
@@ -1460,7 +1460,10 @@ void av1_convolve_2d_sr_neon_dotprod(const uint8_t *src, int src_stride,
 
     const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
 
-    if (clamped_y_taps <= 6) {
+    if (clamped_y_taps <= 4) {
+      convolve_2d_sr_vert_4tap_neon(im_block, im_stride, dst, dst_stride, w, h,
+                                    y_filter_ptr);
+    } else if (clamped_y_taps == 6) {
       convolve_2d_sr_vert_6tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                     y_filter);
     } else {
diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index b6a2a41ba0..b2f489f0d4 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -1289,7 +1289,7 @@ void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
 
   const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
   const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
-  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
+  const int clamped_y_taps = y_filter_taps < 4 ? 4 : y_filter_taps;
   const int im_h = h + clamped_y_taps - 1;
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = clamped_y_taps / 2 - 1;
@@ -1330,7 +1330,10 @@ void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
 
     const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
 
-    if (clamped_y_taps <= 6) {
+    if (clamped_y_taps <= 4) {
+      convolve_2d_sr_vert_4tap_neon(im_block, im_stride, dst, dst_stride, w, h,
+                                    y_filter_ptr);
+    } else if (clamped_y_taps == 6) {
       convolve_2d_sr_vert_6tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                     y_filter);
     } else {