aom: Add 8-tap path for av1_convolve_y_sr_neon_dotprod

From 766e37a42fff20511a81f7a0300eaf968e078ce6 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 25 Apr 2024 10:41:28 +0100
Subject: [PATCH] Add 8-tap path for av1_convolve_y_sr_neon_dotprod

Add 8-tap specialisation for av1_convolve_y_sr_neon_dotprod. This gives
around 10-20% uplift over the standard Neon implementation.

Change-Id: I913df3892ab47dc0ee1d0f28ad62de861bb4320d
---
 av1/common/arm/convolve_neon_dotprod.c | 196 ++++++++++++++++++++++++-
 1 file changed, 193 insertions(+), 3 deletions(-)
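
A note on the constant accumulator used by the new convolve8_4_y/convolve8_8_y
helpers: the interpolation filter taps sum to 1 << FILTER_BITS, so subtracting
128 from every sample to reach signed SDOT range removes exactly
128 << FILTER_BITS from the convolution sum, and seeding the accumulator with
that constant restores it. A rough scalar sketch of the identity (the helper
name and scalar form are illustrative only, assuming FILTER_BITS and the usual
<stdint.h> types are in scope):

  // Scalar model of the range-transform trick in convolve8_4_y/convolve8_8_y.
  //   sum(f[k] * (src[k] - 128)) == sum(f[k] * src[k]) - (128 << FILTER_BITS)
  // because the 8 taps f[k] sum to 1 << FILTER_BITS.
  static int32_t convolve8_scalar(const uint8_t *src, const int16_t *f) {
    int32_t sum = 128 << FILTER_BITS;  // compensation for the -128 shift
    for (int k = 0; k < 8; k++) {
      sum += f[k] * ((int32_t)src[k] - 128);  // signed samples, as fed to SDOT
    }
    return sum;  // equals the plain sum(f[k] * src[k]) before rounding
  }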

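The vqtbl2q_s8 merge step exists so that only the newest four rows have to be
transposed each iteration: after transpose_concat_4x4/_8x4, each 4-byte group
of a block holds four consecutive rows of one output column, so a two-register
table lookup across { s3456, s78910 } can slide that window down by one, two
or three rows to rebuild s4567, s5678 and s6789. A scalar model of the first of
those lookups, assuming that byte layout (illustrative only; the actual
selection is encoded in kDotProdMergeBlockTbl):

  // Build s4567 from s3456 and s78910 for the 4-pixel-wide path.
  // Bytes 4*c .. 4*c+3 of a transposed block hold four rows of column c.
  static void merge_s4567(const int8_t s3456[16], const int8_t s78910[16],
                          int8_t s4567[16]) {
    for (int c = 0; c < 4; c++) {
      s4567[4 * c + 0] = s3456[4 * c + 1];   // row 4
      s4567[4 * c + 1] = s3456[4 * c + 2];   // row 5
      s4567[4 * c + 2] = s3456[4 * c + 3];   // row 6
      s4567[4 * c + 3] = s78910[4 * c + 0];  // row 7
    }
  }
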
diff --git a/av1/common/arm/convolve_neon_dotprod.c b/av1/common/arm/convolve_neon_dotprod.c
index bf945c6fa..393f2e81f 100644
--- a/av1/common/arm/convolve_neon_dotprod.c
+++ b/av1/common/arm/convolve_neon_dotprod.c
@@ -738,6 +738,190 @@ static INLINE void convolve_y_sr_12tap_neon_dotprod(
   }
 }
 
+static INLINE int16x4_t convolve8_4_y(const int8x16_t s0, const int8x16_t s1,
+                                      const int8x8_t filters) {
+  // The sample range transform and permutation are performed by the caller.
+  // Accumulate into 128 << FILTER_BITS to account for range transform.
+  const int32x4_t acc = vdupq_n_s32(128 << FILTER_BITS);
+  int32x4_t sum = vdotq_lane_s32(acc, s0, filters, 0);
+  sum = vdotq_lane_s32(sum, s1, filters, 1);
+
+  // Further narrowing and packing is performed by the caller.
+  return vqmovn_s32(sum);
+}
+
+static INLINE uint8x8_t convolve8_8_y(const int8x16_t s0_lo,
+                                      const int8x16_t s0_hi,
+                                      const int8x16_t s1_lo,
+                                      const int8x16_t s1_hi,
+                                      const int8x8_t filters) {
+  // The sample range transform and permutation are performed by the caller.
+  // Accumulate into 128 << FILTER_BITS to account for range transform.
+  const int32x4_t acc = vdupq_n_s32(128 << FILTER_BITS);
+
+  int32x4_t sum0123 = vdotq_lane_s32(acc, s0_lo, filters, 0);
+  sum0123 = vdotq_lane_s32(sum0123, s1_lo, filters, 1);
+
+  int32x4_t sum4567 = vdotq_lane_s32(acc, s0_hi, filters, 0);
+  sum4567 = vdotq_lane_s32(sum4567, s1_hi, filters, 1);
+
+  // Narrow and re-pack.
+  int16x8_t sum = vcombine_s16(vqmovn_s32(sum0123), vqmovn_s32(sum4567));
+  return vqrshrun_n_s16(sum, FILTER_BITS);
+}
+
+static INLINE void convolve_y_sr_8tap_neon_dotprod(
+    const uint8_t *src_ptr, int src_stride, uint8_t *dst_ptr, int dst_stride,
+    int w, int h, const int16_t *y_filter_ptr) {
+  const int8x8_t filter = vmovn_s16(vld1q_s16(y_filter_ptr));
+
+  const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
+
+  if (w == 4) {
+    uint8x8_t t0, t1, t2, t3, t4, t5, t6;
+    load_u8_8x7(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
+    src_ptr += 7 * src_stride;
+
+    // Transform sample range to [-128, 127] for 8-bit signed dot product.
+    int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
+    int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
+    int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
+    int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
+    int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
+    int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
+    int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));
+
+    int8x16_t s0123, s1234, s2345, s3456;
+    transpose_concat_4x4(s0, s1, s2, s3, &s0123);
+    transpose_concat_4x4(s1, s2, s3, s4, &s1234);
+    transpose_concat_4x4(s2, s3, s4, s5, &s2345);
+    transpose_concat_4x4(s3, s4, s5, s6, &s3456);
+
+    do {
+      uint8x8_t t7, t8, t9, t10;
+      load_u8_8x4(src_ptr, src_stride, &t7, &t8, &t9, &t10);
+
+      int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
+      int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
+      int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
+      int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128)));
+
+      int8x16_t s4567, s5678, s6789, s78910;
+      transpose_concat_4x4(s7, s8, s9, s10, &s78910);
+
+      // Merge new data into block from previous iteration.
+      int8x16x2_t samples_LUT = { { s3456, s78910 } };
+      s4567 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
+      s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
+      s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
+
+      int16x4_t d0 = convolve8_4_y(s0123, s4567, filter);
+      int16x4_t d1 = convolve8_4_y(s1234, s5678, filter);
+      int16x4_t d2 = convolve8_4_y(s2345, s6789, filter);
+      int16x4_t d3 = convolve8_4_y(s3456, s78910, filter);
+      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+      store_u8x4_strided_x2(dst_ptr + 0 * dst_stride, dst_stride, d01);
+      store_u8x4_strided_x2(dst_ptr + 2 * dst_stride, dst_stride, d23);
+
+      // Prepare block for next iteration - re-using as much as possible.
+      // Shuffle everything up four rows.
+      s0123 = s4567;
+      s1234 = s5678;
+      s2345 = s6789;
+      s3456 = s78910;
+
+      src_ptr += 4 * src_stride;
+      dst_ptr += 4 * dst_stride;
+      h -= 4;
+    } while (h != 0);
+  } else {
+    do {
+      int height = h;
+      const uint8_t *s = src_ptr;
+      uint8_t *d = dst_ptr;
+
+      uint8x8_t t0, t1, t2, t3, t4, t5, t6;
+      load_u8_8x7(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
+      s += 7 * src_stride;
+
+      // Transform sample range to [-128, 127] for 8-bit signed dot product.
+      int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
+      int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
+      int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
+      int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
+      int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
+      int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
+      int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));
+
+      // This operation combines a conventional transpose and the sample
+      // permute (see horizontal case) required before computing the dot
+      // product.
+      int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
+          s3456_lo, s3456_hi;
+      transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);
+      transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi);
+      transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi);
+      transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi);
+
+      do {
+        uint8x8_t t7, t8, t9, t10;
+        load_u8_8x4(s, src_stride, &t7, &t8, &t9, &t10);
+
+        int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
+        int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
+        int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
+        int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128)));
+
+        int8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi,
+            s78910_lo, s78910_hi;
+        transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi);
+
+        // Merge new data into block from previous iteration.
+        int8x16x2_t samples_LUT_lo = { { s3456_lo, s78910_lo } };
+        s4567_lo = vqtbl2q_s8(samples_LUT_lo, merge_block_tbl.val[0]);
+        s5678_lo = vqtbl2q_s8(samples_LUT_lo, merge_block_tbl.val[1]);
+        s6789_lo = vqtbl2q_s8(samples_LUT_lo, merge_block_tbl.val[2]);
+
+        int8x16x2_t samples_LUT_hi = { { s3456_hi, s78910_hi } };
+        s4567_hi = vqtbl2q_s8(samples_LUT_hi, merge_block_tbl.val[0]);
+        s5678_hi = vqtbl2q_s8(samples_LUT_hi, merge_block_tbl.val[1]);
+        s6789_hi = vqtbl2q_s8(samples_LUT_hi, merge_block_tbl.val[2]);
+
+        uint8x8_t d0 =
+            convolve8_8_y(s0123_lo, s0123_hi, s4567_lo, s4567_hi, filter);
+        uint8x8_t d1 =
+            convolve8_8_y(s1234_lo, s1234_hi, s5678_lo, s5678_hi, filter);
+        uint8x8_t d2 =
+            convolve8_8_y(s2345_lo, s2345_hi, s6789_lo, s6789_hi, filter);
+        uint8x8_t d3 =
+            convolve8_8_y(s3456_lo, s3456_hi, s78910_lo, s78910_hi, filter);
+
+        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        // Prepare block for next iteration - re-using as much as possible.
+        // Shuffle everything up four rows.
+        s0123_lo = s4567_lo;
+        s0123_hi = s4567_hi;
+        s1234_lo = s5678_lo;
+        s1234_hi = s5678_hi;
+        s2345_lo = s6789_lo;
+        s2345_hi = s6789_hi;
+        s3456_lo = s78910_lo;
+        s3456_hi = s78910_hi;
+
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        height -= 4;
+      } while (height != 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      w -= 8;
+    } while (w != 0);
+  }
+}
+
 void av1_convolve_y_sr_neon_dotprod(const uint8_t *src, int src_stride,
                                     uint8_t *dst, int dst_stride, int w, int h,
                                     const InterpFilterParams *filter_params_y,
@@ -750,7 +934,7 @@ void av1_convolve_y_sr_neon_dotprod(const uint8_t *src, int src_stride,
 
   const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
 
-  if (y_filter_taps <= 8) {
+  if (y_filter_taps <= 6) {
     av1_convolve_y_sr_neon(src, src_stride, dst, dst_stride, w, h,
                            filter_params_y, subpel_y_qn);
     return;
@@ -762,8 +946,14 @@ void av1_convolve_y_sr_neon_dotprod(const uint8_t *src, int src_stride,
   const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
       filter_params_y, subpel_y_qn & SUBPEL_MASK);
 
-  convolve_y_sr_12tap_neon_dotprod(src, src_stride, dst, dst_stride, w, h,
-                                   y_filter_ptr);
+  if (y_filter_taps > 8) {
+    convolve_y_sr_12tap_neon_dotprod(src, src_stride, dst, dst_stride, w, h,
+                                     y_filter_ptr);
+    return;
+  }
+
+  convolve_y_sr_8tap_neon_dotprod(src, src_stride, dst, dst_stride, w, h,
+                                  y_filter_ptr);
 }
 
 static INLINE int16x4_t convolve12_4_2d_h(uint8x16_t samples,