aom: Add 4-tap specialisation to aom_convolve8_horiz_neon_i8mm

From e7dcb10a44ee19b5a9d0654aad9358f49e8ec962 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Fri, 12 Apr 2024 16:36:54 +0100
Subject: [PATCH] Add 4-tap specialisation to aom_convolve8_horiz_neon_i8mm

Add a specialised path for 4-tap filters in aom_convolve8_horiz_neon_i8mm.
This gives between 30% and 50% uplift compared to using the 8-tap path.

This is a port from libvpx, using hash
9d8d71b41bd5b5be78c765bb099cbde84b99f25c.

Change-Id: I7913edd7c5d8546bc5ec0800dac2a1866c9f3739
---
 aom_dsp/arm/aom_convolve8_neon_i8mm.c | 135 +++++++++++++++++++++++---
 1 file changed, 121 insertions(+), 14 deletions(-)
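
Note for reviewers (kept outside the commit message, so git-am drops it): the
new 4-tap path halves the filter coefficients and then narrows with
FILTER_BITS - 1 instead of FILTER_BITS, as the in-code comments state. A
minimal scalar sketch of that rounding argument is below. It only assumes the
libaom invariant that 4-tap and bilinear coefficients are even and sum to
1 << FILTER_BITS; the filter values used are made up for illustration and are
not taken from the patch.

  /* Scalar illustration of the halved-filter rounding (not part of the patch). */
  #include <assert.h>
  #include <stdint.h>
  #include <stdio.h>

  #define FILTER_BITS 7  /* as in aom_dsp/aom_filter.h */

  static uint8_t clip_pixel(int v) { return v < 0 ? 0 : (v > 255 ? 255 : v); }

  /* Full-precision 4-tap convolution: round and shift by FILTER_BITS. */
  static uint8_t convolve4_ref(const uint8_t *s, const int16_t *f) {
    int sum = s[0] * f[0] + s[1] * f[1] + s[2] * f[2] + s[3] * f[3];
    return clip_pixel((sum + (1 << (FILTER_BITS - 1))) >> FILTER_BITS);
  }

  /* Halved-coefficient variant: since all 4-tap/bilinear coefficients are
   * even, halving them keeps the arithmetic exact and lets the shift drop to
   * FILTER_BITS - 1, which is what convolve4_4_h/convolve4_8_h rely on. */
  static uint8_t convolve4_halved(const uint8_t *s, const int16_t *f) {
    int sum = s[0] * (f[0] >> 1) + s[1] * (f[1] >> 1) + s[2] * (f[2] >> 1) +
              s[3] * (f[3] >> 1);
    return clip_pixel((sum + (1 << (FILTER_BITS - 2))) >> (FILTER_BITS - 1));
  }

  int main(void) {
    const int16_t filter[4] = { -12, 76, 70, -6 };  /* even taps, sum == 128 */
    const uint8_t src[8] = { 3, 250, 128, 7, 99, 42, 200, 17 };
    for (int i = 0; i + 4 <= 8; i++) {
      assert(convolve4_ref(src + i, filter) == convolve4_halved(src + i, filter));
    }
    printf("halved-filter rounding matches the full-precision reference\n");
    return 0;
  }

Halving also echoes the comment in convolve8_horiz_4tap_neon_i8mm about
reducing intermediate precision requirements before the int8 dot product.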

diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index 68e031461..da0210ac3 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -19,6 +19,7 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 #include "aom_ports/mem.h"
@@ -80,22 +81,11 @@ static INLINE uint8x8_t convolve8_8_h(const uint8x16_t samples,
   return vqrshrun_n_s16(sum, FILTER_BITS);
 }
 
-void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
-                                   uint8_t *dst, ptrdiff_t dst_stride,
-                                   const int16_t *filter_x, int x_step_q4,
-                                   const int16_t *filter_y, int y_step_q4,
-                                   int w, int h) {
+static INLINE void convolve8_horiz_8tap_neon_i8mm(
+    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
+    ptrdiff_t dst_stride, const int16_t *filter_x, int w, int h) {
   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x));
 
-  assert((intptr_t)dst % 4 == 0);
-  assert(dst_stride % 4 == 0);
-
-  (void)x_step_q4;
-  (void)filter_y;
-  (void)y_step_q4;
-
-  src -= ((SUBPEL_TAPS / 2) - 1);
-
   if (w == 4) {
     const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
     do {
@@ -145,6 +135,123 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
   }
 }
 
+static INLINE int16x4_t convolve4_4_h(const uint8x16_t samples,
+                                      const int8x8_t filters,
+                                      const uint8x16_t permute_tbl) {
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);
+
+  int32x4_t sum =
+      vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples, filters, 0);
+
+  // Further narrowing and packing is performed by the caller.
+  return vmovn_s32(sum);
+}
+
+static INLINE uint8x8_t convolve4_8_h(const uint8x16_t samples,
+                                      const int8x8_t filters,
+                                      const uint8x16x2_t permute_tbl) {
+  // Permute samples ready for dot product.
+  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
+  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
+  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };
+
+  // First 4 output values.
+  int32x4_t sum0 =
+      vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
+  // Second 4 output values.
+  int32x4_t sum1 =
+      vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
+
+  // Narrow and re-pack.
+  int16x8_t sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
+  // We halved the filter values so -1 from right shift.
+  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
+}
+
+static INLINE void convolve8_horiz_4tap_neon_i8mm(
+    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
+    ptrdiff_t dst_stride, const int16_t *filter_x, int width, int height) {
+  const int16x4_t x_filter = vld1_s16(filter_x + 2);
+  // All 4-tap and bilinear filter values are even, so halve them to reduce
+  // intermediate precision requirements.
+  const int8x8_t filter = vshrn_n_s16(vcombine_s16(x_filter, vdup_n_s16(0)), 1);
+
+  if (width == 4) {
+    const uint8x16_t perm_tbl = vld1q_u8(kDotProdPermuteTbl);
+    do {
+      uint8x16_t s0, s1, s2, s3;
+      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
+
+      int16x4_t t0 = convolve4_4_h(s0, filter, perm_tbl);
+      int16x4_t t1 = convolve4_4_h(s1, filter, perm_tbl);
+      int16x4_t t2 = convolve4_4_h(s2, filter, perm_tbl);
+      int16x4_t t3 = convolve4_4_h(s3, filter, perm_tbl);
+      // We halved the filter values so -1 from right shift.
+      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS - 1);
+      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS - 1);
+
+      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
+      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height > 0);
+  } else {
+    const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
+
+    do {
+      int w = width;
+      const uint8_t *s = src;
+      uint8_t *d = dst;
+      do {
+        uint8x16_t s0, s1, s2, s3;
+        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+        uint8x8_t d0 = convolve4_8_h(s0, filter, perm_tbl);
+        uint8x8_t d1 = convolve4_8_h(s1, filter, perm_tbl);
+        uint8x8_t d2 = convolve4_8_h(s2, filter, perm_tbl);
+        uint8x8_t d3 = convolve4_8_h(s3, filter, perm_tbl);
+
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        w -= 8;
+      } while (w != 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height > 0);
+  }
+}
+
+void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
+                                   uint8_t *dst, ptrdiff_t dst_stride,
+                                   const int16_t *filter_x, int x_step_q4,
+                                   const int16_t *filter_y, int y_step_q4,
+                                   int w, int h) {
+  assert((intptr_t)dst % 4 == 0);
+  assert(dst_stride % 4 == 0);
+
+  (void)x_step_q4;
+  (void)filter_y;
+  (void)y_step_q4;
+
+  src -= ((SUBPEL_TAPS / 2) - 1);
+
+  if (get_filter_taps_convolve8(filter_x) <= 4) {
+    convolve8_horiz_4tap_neon_i8mm(src + 2, src_stride, dst, dst_stride,
+                                   filter_x, w, h);
+  } else {
+    convolve8_horiz_8tap_neon_i8mm(src, src_stride, dst, dst_stride, filter_x,
+                                   w, h);
+  }
+}
+
 static INLINE void transpose_concat_4x4(uint8x8_t a0, uint8x8_t a1,
                                         uint8x8_t a2, uint8x8_t a3,
                                         uint8x16_t *b) {