aom: Add 4-tap specialisation to aom_highbd_convolve8_horiz_neon

From 68a56cc8679dddd79b1902be40c90dcb64197196 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 18 Apr 2024 15:06:37 +0100
Subject: [PATCH] Add 4-tap specialisation to aom_highbd_convolve8_horiz_neon

Add specialised path for 4-tap filters in
aom_highbd_convolve8_horiz_neon. This gives between 30% and 50% uplift
compared to using the 8-tap path.

Change-Id: I721498e71ba7f2dbeebfa68a78b08b0b5bca5a88
---
 aom_dsp/arm/highbd_convolve8_neon.c | 122 ++++++++++++++++++++++++++--
 1 file changed, 115 insertions(+), 7 deletions(-)

diff --git a/aom_dsp/arm/highbd_convolve8_neon.c b/aom_dsp/arm/highbd_convolve8_neon.c
index f84b17f17..6d8ce2961 100644
--- a/aom_dsp/arm/highbd_convolve8_neon.c
+++ b/aom_dsp/arm/highbd_convolve8_neon.c
@@ -19,6 +19,7 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 #include "aom_ports/mem.h"
@@ -77,11 +78,9 @@ highbd_convolve8_8(const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
   return vminq_u16(res, max);
 }
 
-static void highbd_convolve_horiz_neon(const uint16_t *src_ptr,
-                                       ptrdiff_t src_stride, uint16_t *dst_ptr,
-                                       ptrdiff_t dst_stride,
-                                       const int16_t *x_filter_ptr, int w,
-                                       int h, int bd) {
+static void highbd_convolve_horiz_8tap_neon(
+    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
+    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
   assert(w >= 4 && h >= 4);
   const int16x8_t x_filter = vld1q_s16(x_filter_ptr);
 
@@ -158,6 +157,109 @@ static void highbd_convolve_horiz_neon(const uint16_t *src_ptr,
   }
 }
 
+static INLINE uint16x4_t highbd_convolve4_4(  // 4-tap filter on 4 lanes.
+    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+    const int16x4_t s3, const int16x4_t filter, const uint16x4_t max) {
+  int32x4_t sum = vmull_lane_s16(s0, filter, 0);  // s0 * filter[0], 32-bit.
+  sum = vmlal_lane_s16(sum, s1, filter, 1);       // += s1 * filter[1]
+  sum = vmlal_lane_s16(sum, s2, filter, 2);       // += s2 * filter[2]
+  sum = vmlal_lane_s16(sum, s3, filter, 3);       // += s3 * filter[3]
+  // Rounding shift right by FILTER_BITS with unsigned saturating narrow.
+  uint16x4_t res = vqrshrun_n_s32(sum, FILTER_BITS);
+  // Clamp every lane to at most the corresponding lane of `max`.
+  return vmin_u16(res, max);
+}
+
+static INLINE uint16x8_t highbd_convolve4_8(  // 4-tap filter on 8 lanes.
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x4_t filter, const uint16x8_t max) {
+  int32x4_t sum0 = vmull_lane_s16(vget_low_s16(s0), filter, 0);  // Low half.
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter, 1);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter, 2);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter, 3);
+  // Same multiply-accumulate chain for the high four lanes of each input.
+  int32x4_t sum1 = vmull_lane_s16(vget_high_s16(s0), filter, 0);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter, 1);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter, 2);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter, 3);
+  // Round-shift each half by FILTER_BITS, narrow with unsigned saturation.
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0, FILTER_BITS),
+                                vqrshrun_n_s32(sum1, FILTER_BITS));
+  // Clamp every lane to at most the corresponding lane of `max`.
+  return vminq_u16(res, max);
+}
+
+static void highbd_convolve_horiz_4tap_neon(  // Horizontal 4-tap filter.
+    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
+    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
+  assert(w >= 4 && h >= 4);
+  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);  // Taps [2..5] only.
+  // Width 4: one 4-lane result per row, four rows per loop iteration.
+  if (w == 4) {
+    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);  // Largest bd-bit value.
+    const int16_t *s = (const int16_t *)src_ptr;  // Read samples as signed.
+    uint16_t *d = dst_ptr;
+
+    do {
+      int16x4_t s0[4], s1[4], s2[4], s3[4];  // sN[k]: row N, input for tap k.
+      load_s16_4x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+      load_s16_4x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+      load_s16_4x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+      load_s16_4x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
+      // Assumes load_s16_4x4(p, 1, ...) reads at p+0..p+3 — TODO confirm.
+      uint16x4_t d0 =
+          highbd_convolve4_4(s0[0], s0[1], s0[2], s0[3], x_filter, max);
+      uint16x4_t d1 =
+          highbd_convolve4_4(s1[0], s1[1], s1[2], s1[3], x_filter, max);
+      uint16x4_t d2 =
+          highbd_convolve4_4(s2[0], s2[1], s2[2], s2[3], x_filter, max);
+      uint16x4_t d3 =
+          highbd_convolve4_4(s3[0], s3[1], s3[2], s3[3], x_filter, max);
+
+      store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
+      // Advance source and destination to the next group of four rows.
+      s += 4 * src_stride;
+      d += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);  // NOTE(review): presumably h % 4 == 0 — confirm.
+  } else {  // w > 4: work in tiles of 8 columns by 4 rows.
+    const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);  // Largest bd-bit value.
+    int height = h;
+
+    do {
+      int width = w;
+      const int16_t *s = (const int16_t *)src_ptr;  // Read samples as signed.
+      uint16_t *d = dst_ptr;
+
+      do {
+        int16x8_t s0[4], s1[4], s2[4], s3[4];  // sN[k]: row N, input for tap k.
+        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
+        // Assumes load_s16_8x4(p, 1, ...) reads at p+0..p+3 — TODO confirm.
+        uint16x8_t d0 =
+            highbd_convolve4_8(s0[0], s0[1], s0[2], s0[3], x_filter, max);
+        uint16x8_t d1 =
+            highbd_convolve4_8(s1[0], s1[1], s1[2], s1[3], x_filter, max);
+        uint16x8_t d2 =
+            highbd_convolve4_8(s2[0], s2[1], s2[2], s2[3], x_filter, max);
+        uint16x8_t d3 =
+            highbd_convolve4_8(s3[0], s3[1], s3[2], s3[3], x_filter, max);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+        // Step to the next 8 columns of the same four rows.
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width > 0);  // NOTE(review): presumably w % 8 == 0 — confirm.
+      src_ptr += 4 * src_stride;
+      dst_ptr += 4 * dst_stride;
+      height -= 4;
+    } while (height > 0);  // NOTE(review): presumably h % 4 == 0 — confirm.
+  }
+}
+
 void aom_highbd_convolve8_horiz_neon(const uint8_t *src8, ptrdiff_t src_stride,
                                      uint8_t *dst8, ptrdiff_t dst_stride,
                                      const int16_t *filter_x, int x_step_q4,
@@ -174,8 +276,14 @@ void aom_highbd_convolve8_horiz_neon(const uint8_t *src8, ptrdiff_t src_stride,
     uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
 
     src -= SUBPEL_TAPS / 2 - 1;
-    highbd_convolve_horiz_neon(src, src_stride, dst, dst_stride, filter_x, w, h,
-                               bd);
+
+    if (get_filter_taps_convolve8(filter_x) <= 4) {
+      highbd_convolve_horiz_4tap_neon(src + 2, src_stride, dst, dst_stride,
+                                      filter_x, w, h, bd);
+    } else {
+      highbd_convolve_horiz_8tap_neon(src, src_stride, dst, dst_stride,
+                                      filter_x, w, h, bd);
+    }
   }
 }