aom: Add 2-tap path for aom_convolve8_vert_neon

From 5eb58adeeeeb9ec735d091d31a6d51c8a3c52f2c Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Wed, 17 Apr 2024 11:06:58 +0100
Subject: [PATCH] Add 2-tap path for aom_convolve8_vert_neon

Add a specialized Neon implementation for 2-tap filters and use it
instead of the 4-tap implementation in all 3 Neon versions of
aom_convolve8_vert. This provides around 50% uplift over using the
4-tap implementation.

Change-Id: I411b23905d5f93ea1a8f2cab7d5ece8006b94032
---
 aom_dsp/arm/aom_convolve8_neon.c         |  7 +-
 aom_dsp/arm/aom_convolve8_neon.h         | 81 ++++++++++++++++++++++++
 aom_dsp/arm/aom_convolve8_neon_dotprod.c |  7 +-
 aom_dsp/arm/aom_convolve8_neon_i8mm.c    |  7 +-
 4 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/aom_dsp/arm/aom_convolve8_neon.c b/aom_dsp/arm/aom_convolve8_neon.c
index 43aef5428..9a3ff8079 100644
--- a/aom_dsp/arm/aom_convolve8_neon.c
+++ b/aom_dsp/arm/aom_convolve8_neon.c
@@ -472,7 +472,12 @@ void aom_convolve8_vert_neon(const uint8_t *src, ptrdiff_t src_stride,
 
   src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
 
-  if (get_filter_taps_convolve8(filter_y) <= 4) {
+  int filter_taps = get_filter_taps_convolve8(filter_y);
+
+  if (filter_taps == 2) {
+    convolve8_vert_2tap_neon(src + 3 * src_stride, src_stride, dst, dst_stride,
+                             filter_y, w, h);
+  } else if (filter_taps == 4) {
     convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
                              filter_y, w, h);
   } else {
diff --git a/aom_dsp/arm/aom_convolve8_neon.h b/aom_dsp/arm/aom_convolve8_neon.h
index 83fbd0afc..b523c41bc 100644
--- a/aom_dsp/arm/aom_convolve8_neon.h
+++ b/aom_dsp/arm/aom_convolve8_neon.h
@@ -201,4 +201,85 @@ static INLINE void convolve8_vert_4tap_neon(const uint8_t *src,
   }
 }
 
+static INLINE void convolve8_vert_2tap_neon(const uint8_t *src,
+                                            ptrdiff_t src_stride, uint8_t *dst,
+                                            ptrdiff_t dst_stride,
+                                            const int16_t *filter_y, int w,
+                                            int h) {
+  // Bilinear filter values are all positive; only taps 3 and 4 are used, as u8.
+  uint8x8_t f0 = vdup_n_u8((uint8_t)filter_y[3]);
+  uint8x8_t f1 = vdup_n_u8((uint8_t)filter_y[4]);
+  // w == 4: four output rows per iteration, so h is assumed a multiple of 4.
+  if (w == 4) {
+    do {
+      uint8x8_t s0 = load_unaligned_u8(src + 0 * src_stride, (int)src_stride);
+      uint8x8_t s1 = load_unaligned_u8(src + 1 * src_stride, (int)src_stride);
+      uint8x8_t s2 = load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
+      uint8x8_t s3 = load_unaligned_u8(src + 3 * src_stride, (int)src_stride);
+      // Two-tap sum per lane: a * f0 + b * f1, widened to 16 bits.
+      uint16x8_t sum0 = vmull_u8(s0, f0);
+      sum0 = vmlal_u8(sum0, s1, f1);
+      uint16x8_t sum1 = vmull_u8(s2, f0);
+      sum1 = vmlal_u8(sum1, s3, f1);
+      // Rounding shift right by FILTER_BITS, saturating-narrowed to 8 bits.
+      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
+      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);
+
+      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d0);
+      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d1);
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);
+  } else if (w == 8) {  // Two output rows per iteration.
+    do {
+      uint8x8_t s0, s1, s2;
+      load_u8_8x3(src, src_stride, &s0, &s1, &s2);
+      // s1 is shared: f1 tap for the first row, f0 tap for the second.
+      uint16x8_t sum0 = vmull_u8(s0, f0);
+      sum0 = vmlal_u8(sum0, s1, f1);
+      uint16x8_t sum1 = vmull_u8(s1, f0);
+      sum1 = vmlal_u8(sum1, s2, f1);
+
+      uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
+      uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);
+
+      vst1_u8(dst + 0 * dst_stride, d0);
+      vst1_u8(dst + 1 * dst_stride, d1);
+
+      src += 2 * src_stride;
+      dst += 2 * dst_stride;
+      h -= 2;
+    } while (h > 0);
+  } else {  // Inner loop exits at width == 0: w assumed a multiple of 16.
+    do {
+      int width = w;
+      const uint8_t *s = src;
+      uint8_t *d = dst;
+
+      do {
+        uint8x16_t s0 = vld1q_u8(s + 0 * src_stride);
+        uint8x16_t s1 = vld1q_u8(s + 1 * src_stride);
+        // Filter the low and high 8 lanes separately, then recombine.
+        uint16x8_t sum0 = vmull_u8(vget_low_u8(s0), f0);
+        sum0 = vmlal_u8(sum0, vget_low_u8(s1), f1);
+        uint16x8_t sum1 = vmull_u8(vget_high_u8(s0), f0);
+        sum1 = vmlal_u8(sum1, vget_high_u8(s1), f1);
+
+        uint8x8_t d0 = vqrshrn_n_u16(sum0, FILTER_BITS);
+        uint8x8_t d1 = vqrshrn_n_u16(sum1, FILTER_BITS);
+
+        vst1q_u8(d, vcombine_u8(d0, d1));
+
+        s += 16;
+        d += 16;
+        width -= 16;
+      } while (width != 0);
+      src += src_stride;
+      dst += dst_stride;
+    } while (--h > 0);  // One output row per pass.
+  }
+}
+
 #endif  // AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
diff --git a/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index 4d47d86ef..721957086 100644
--- a/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -540,7 +540,12 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
 
   src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
 
-  if (get_filter_taps_convolve8(filter_y) <= 4) {
+  int filter_taps = get_filter_taps_convolve8(filter_y);
+
+  if (filter_taps == 2) {
+    convolve8_vert_2tap_neon(src + 3 * src_stride, src_stride, dst, dst_stride,
+                             filter_y, w, h);
+  } else if (filter_taps == 4) {
     convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
                              filter_y, w, h);
   } else {
diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index 21a4551a3..34bfe0166 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -482,7 +482,12 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
 
   src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
 
-  if (get_filter_taps_convolve8(filter_y) <= 4) {
+  int filter_taps = get_filter_taps_convolve8(filter_y);
+
+  if (filter_taps == 2) {
+    convolve8_vert_2tap_neon(src + 3 * src_stride, src_stride, dst, dst_stride,
+                             filter_y, w, h);
+  } else if (filter_taps == 4) {
     convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
                              filter_y, w, h);
   } else {