aom: Add 4-tap specialisation to aom_convolve8_vert_neon

From 32f8079ce4d9a5b9a7553100782ab124dfc0be64 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Tue, 16 Apr 2024 15:25:39 +0100
Subject: [PATCH] Add 4-tap specialisation to aom_convolve8_vert_neon

Add a specialised Neon path for 4-tap filters in
aom_convolve8_vert_neon, and use it for the neon_dotprod and neon_i8mm
versions as well. For 4-tap filters this gives a 40-50% uplift compared
to the 8-tap path.

Change-Id: I93e48a62d851af5ff0c6f8015cb91f687c970802
---
 aom_dsp/arm/aom_convolve8_neon.c         | 53 ++++++-------
 aom_dsp/arm/aom_convolve8_neon.h         | 98 ++++++++++++++++++++++++
 aom_dsp/arm/aom_convolve8_neon_dotprod.c | 40 ++++++----
 aom_dsp/arm/aom_convolve8_neon_i8mm.c    | 40 ++++++----
 aom_dsp/arm/mem_neon.h                   | 10 +++
 5 files changed, 187 insertions(+), 54 deletions(-)

diff --git a/aom_dsp/arm/aom_convolve8_neon.c b/aom_dsp/arm/aom_convolve8_neon.c
index 5665b5e12..43aef5428 100644
--- a/aom_dsp/arm/aom_convolve8_neon.c
+++ b/aom_dsp/arm/aom_convolve8_neon.c
@@ -243,18 +243,6 @@ static INLINE int16x4_t convolve4_4(const int16x4_t s0, const int16x4_t s1,
   return sum;
 }
 
-static INLINE uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1,
-                                    const int16x8_t s2, const int16x8_t s3,
-                                    const int16x4_t filter) {
-  int16x8_t sum = vmulq_lane_s16(s0, filter, 0);
-  sum = vmlaq_lane_s16(sum, s1, filter, 1);
-  sum = vmlaq_lane_s16(sum, s2, filter, 2);
-  sum = vmlaq_lane_s16(sum, s3, filter, 3);
-
-  // We halved the filter values so -1 from right shift.
-  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
-}
-
 static INLINE void convolve8_horiz_4tap_neon(const uint8_t *src,
                                              ptrdiff_t src_stride, uint8_t *dst,
                                              ptrdiff_t dst_stride,
@@ -368,22 +356,13 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
   }
 }
 
-void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
-                             uint8_t *dst, ptrdiff_t dst_stride,
-                             const int16_t *filter_x, int x_step_q4,
-                             const int16_t *filter_y, int y_step_q4, int w,
-                             int h) {
+static INLINE void convolve8_vert_8tap_neon(const uint8_t *src,
+                                            ptrdiff_t src_stride, uint8_t *dst,
+                                            ptrdiff_t dst_stride,
+                                            const int16_t *filter_y, int w,
+                                            int h) {
   const int16x8_t filter = vld1q_s16(filter_y);
 
-  assert((intptr_t)dst % 4 == 0);
-  assert(dst_stride % 4 == 0);
-
-  (void)filter_x;
-  (void)x_step_q4;
-  (void)y_step_q4;
-
-  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
-
   if (w == 4) {
     uint8x8_t t0, t1, t2, t3, t4, t5, t6;
     load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
@@ -478,3 +457,25 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
     } while (w != 0);
   }
 }
+
+void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
+                             uint8_t *dst, ptrdiff_t dst_stride,
+                             const int16_t *filter_x, int x_step_q4,
+                             const int16_t *filter_y, int y_step_q4, int w,
+                             int h) {
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+
+  (void)filter_x;
+  (void)x_step_q4;
+  (void)y_step_q4;
+  // Point src at the top row of the SUBPEL_TAPS-row filter window.
+  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
+  // The 4-tap path uses filter_y[2..5], so it starts 2 rows further down.
+  if (get_filter_taps_convolve8(filter_y) <= 4) {
+    convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
+                             filter_y, w, h);
+  } else {
+    convolve8_vert_8tap_neon(src, src_stride, dst, dst_stride, filter_y, w, h);
+  }
+}
diff --git a/aom_dsp/arm/aom_convolve8_neon.h b/aom_dsp/arm/aom_convolve8_neon.h
index 0aebc6d12..83fbd0afc 100644
--- a/aom_dsp/arm/aom_convolve8_neon.h
+++ b/aom_dsp/arm/aom_convolve8_neon.h
@@ -103,4 +103,102 @@ static INLINE void convolve8_horiz_2tap_neon(const uint8_t *src,
   }
 }
 
+static INLINE uint8x8_t convolve4_8(const int16x8_t s0, const int16x8_t s1,
+                                    const int16x8_t s2, const int16x8_t s3,
+                                    const int16x4_t filter) {
+  int16x8_t sum = vmulq_lane_s16(s0, filter, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter, 3);
+  // sum[i] = s0[i]*f[0] + s1[i]*f[1] + s2[i]*f[2] + s3[i]*f[3], f = filter.
+  // Taps were pre-halved, so round by FILTER_BITS - 1 and saturate to u8.
+  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
+}
+
+static INLINE void convolve8_vert_4tap_neon(const uint8_t *src,
+                                            ptrdiff_t src_stride, uint8_t *dst,
+                                            ptrdiff_t dst_stride,
+                                            const int16_t *filter_y, int w,
+                                            int h) {
+  // All filter values are even; halve them so the 16-bit accumulation in
+  // convolve4_8 has headroom. Only filter_y[2..5] are used here.
+  const int16x4_t filter = vshr_n_s16(vld1_s16(filter_y + 2), 1);
+
+  if (w == 4) {
+    uint8x8_t t01 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
+    uint8x8_t t12 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
+
+    int16x8_t s01 = vreinterpretq_s16_u16(vmovl_u8(t01));
+    int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t12));
+
+    src += 2 * src_stride;
+
+    do {
+      uint8x8_t t23 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
+      uint8x8_t t34 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
+      uint8x8_t t45 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
+      uint8x8_t t56 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride);
+
+      int16x8_t s23 = vreinterpretq_s16_u16(vmovl_u8(t23));
+      int16x8_t s34 = vreinterpretq_s16_u16(vmovl_u8(t34));
+      int16x8_t s45 = vreinterpretq_s16_u16(vmovl_u8(t45));
+      int16x8_t s56 = vreinterpretq_s16_u16(vmovl_u8(t56));
+
+      uint8x8_t d01 = convolve4_8(s01, s12, s23, s34, filter);
+      uint8x8_t d23 = convolve4_8(s23, s34, s45, s56, filter);
+
+      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
+      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
+
+      s01 = s45;  // Slide the row window down by 4 rows.
+      s12 = s56;
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;  // h must be a multiple of 4 for this loop to terminate.
+    } while (h != 0);
+  } else {  // w must be a multiple of 8 on this path.
+    do {
+      uint8x8_t t0, t1, t2;
+      load_u8_8x3(src, src_stride, &t0, &t1, &t2);
+
+      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+      int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+      int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+
+      int height = h;
+      const uint8_t *s = src + 3 * src_stride;
+      uint8_t *d = dst;
+
+      do {
+        uint8x8_t t3;
+        load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);
+
+        int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t1));
+        int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t2));
+        int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t3));
+
+        uint8x8_t d0 = convolve4_8(s0, s1, s2, s3, filter);
+        uint8x8_t d1 = convolve4_8(s1, s2, s3, s4, filter);
+        uint8x8_t d2 = convolve4_8(s2, s3, s4, s5, filter);
+        uint8x8_t d3 = convolve4_8(s3, s4, s5, s6, filter);
+
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s0 = s4;  // Slide the row window down by 4 rows.
+        s1 = s5;
+        s2 = s6;
+
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        height -= 4;  // h must be a multiple of 4.
+      } while (height != 0);
+      src += 8;  // Move to the next 8-pixel-wide column strip.
+      dst += 8;
+      w -= 8;
+    } while (w != 0);
+  }
+}
+
 #endif  // AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
diff --git a/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index f49d33ff3..4d47d86ef 100644
--- a/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -370,24 +370,13 @@ static INLINE uint8x8_t convolve8_8_v(const int8x16_t samples0_lo,
   return vqrshrun_n_s16(sum, FILTER_BITS);
 }
 
-void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
-                                     uint8_t *dst, ptrdiff_t dst_stride,
-                                     const int16_t *filter_x, int x_step_q4,
-                                     const int16_t *filter_y, int y_step_q4,
-                                     int w, int h) {
+static INLINE void convolve8_vert_8tap_neon_dotprod(
+    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
+    ptrdiff_t dst_stride, const int16_t *filter_y, int w, int h) {
   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
   const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
   int8x16x2_t samples_LUT;
 
-  assert((intptr_t)dst % 4 == 0);
-  assert(dst_stride % 4 == 0);
-
-  (void)filter_x;
-  (void)x_step_q4;
-  (void)y_step_q4;
-
-  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
-
   if (w == 4) {
     uint8x8_t t0, t1, t2, t3, t4, t5, t6;
     load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
@@ -536,3 +525,26 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
     } while (w != 0);
   }
 }
+
+void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
+                                     uint8_t *dst, ptrdiff_t dst_stride,
+                                     const int16_t *filter_x, int x_step_q4,
+                                     const int16_t *filter_y, int y_step_q4,
+                                     int w, int h) {
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+
+  (void)filter_x;
+  (void)x_step_q4;
+  (void)y_step_q4;
+  // Point src at the top row of the SUBPEL_TAPS-row filter window.
+  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
+  // The shared 4-tap path uses filter_y[2..5], so it starts 2 rows lower.
+  if (get_filter_taps_convolve8(filter_y) <= 4) {
+    convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
+                             filter_y, w, h);
+  } else {
+    convolve8_vert_8tap_neon_dotprod(src, src_stride, dst, dst_stride, filter_y,
+                                     w, h);
+  }
+}
diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index 972763999..21a4551a3 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -340,24 +340,13 @@ static INLINE uint8x8_t convolve8_8_v(const uint8x16_t samples0_lo,
   return vqrshrun_n_s16(sum, FILTER_BITS);
 }
 
-void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
-                                  uint8_t *dst, ptrdiff_t dst_stride,
-                                  const int16_t *filter_x, int x_step_q4,
-                                  const int16_t *filter_y, int y_step_q4, int w,
-                                  int h) {
+static INLINE void convolve8_vert_8tap_neon_i8mm(
+    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
+    ptrdiff_t dst_stride, const int16_t *filter_y, int w, int h) {
   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
   const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
   uint8x16x2_t samples_LUT;
 
-  assert((intptr_t)dst % 4 == 0);
-  assert(dst_stride % 4 == 0);
-
-  (void)filter_x;
-  (void)x_step_q4;
-  (void)y_step_q4;
-
-  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
-
   if (w == 4) {
     uint8x8_t s0, s1, s2, s3, s4, s5, s6;
     load_u8_8x7(src, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
@@ -478,3 +467,26 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
     } while (w != 0);
   }
 }
+
+void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
+                                  uint8_t *dst, ptrdiff_t dst_stride,
+                                  const int16_t *filter_x, int x_step_q4,
+                                  const int16_t *filter_y, int y_step_q4, int w,
+                                  int h) {
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+
+  (void)filter_x;
+  (void)x_step_q4;
+  (void)y_step_q4;
+  // Point src at the top row of the SUBPEL_TAPS-row filter window.
+  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
+  // The shared 4-tap path uses filter_y[2..5], so it starts 2 rows lower.
+  if (get_filter_taps_convolve8(filter_y) <= 4) {
+    convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
+                             filter_y, w, h);
+  } else {
+    convolve8_vert_8tap_neon_i8mm(src, src_stride, dst, dst_stride, filter_y, w,
+                                  h);
+  }
+}
diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index 32a462a18..ba187007c 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -174,6 +174,16 @@ static INLINE void load_u8_8x4(const uint8_t *s, const ptrdiff_t p,
   *s3 = vld1_u8(s);
 }
 
+static INLINE void load_u8_8x3(const uint8_t *s, const ptrdiff_t p,
+                               uint8x8_t *const s0, uint8x8_t *const s1,
+                               uint8x8_t *const s2) {
+  *s0 = vld1_u8(s);  // 8 bytes from each of s, s + p and s + 2 * p.
+  s += p;
+  *s1 = vld1_u8(s);
+  s += p;
+  *s2 = vld1_u8(s);
+}
+
 static INLINE void load_u16_4x4(const uint16_t *s, const ptrdiff_t p,
                                 uint16x4_t *const s0, uint16x4_t *const s1,
                                 uint16x4_t *const s2, uint16x4_t *const s3) {