aom: Add Neon Dotprod implementation of av1_convolve_y_sr for 12-tap

From dadb003877190556961f36e67d52495cc1742a54 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Wed, 24 Apr 2024 17:18:24 +0100
Subject: [PATCH] Add Neon Dotprod implementation of av1_convolve_y_sr for
 12-tap

Add an Armv8.4 implementation of av1_convolve_y_sr for 12-tap filters.
This gives between 20% and 50% uplift over the Armv8.0 implementation.

Change-Id: Icd163f4f4c5c56b899d268b91d79b5733724eab4
---
 av1/common/arm/convolve_neon_dotprod.c | 345 +++++++++++++++++++++++++
 av1/common/av1_rtcd_defs.pl            |   2 +-
 test/av1_convolve_test.cc              |   5 +
 3 files changed, 351 insertions(+), 1 deletion(-)
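
The kernel splits each 12-tap vertical filter into three 4-element dot
products (taps 0-3, 4-7 and 8-11), one per vdotq_lane_s32 call. Below is a
minimal scalar sketch of that decomposition, for illustration only (dot4 and
convolve12 are hypothetical names, not part of the patch):

    #include <stdint.h>

    // One 4-element dot product - the work done by a single lane of
    // vdotq_lane_s32.
    static int32_t dot4(const int8_t *s, const int8_t *f) {
      return s[0] * f[0] + s[1] * f[1] + s[2] * f[2] + s[3] * f[3];
    }

    // A 12-tap convolution as three 4-element dot products, mirroring the
    // three vdotq_lane_s32 calls in convolve12_4_y(). acc stands in for the
    // 128 << FILTER_BITS range-transform correction.
    static int32_t convolve12(const int8_t *s, const int8_t *f, int32_t acc) {
      return acc + dot4(s, f) + dot4(s + 4, f + 4) + dot4(s + 8, f + 8);
    }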

diff --git a/av1/common/arm/convolve_neon_dotprod.c b/av1/common/arm/convolve_neon_dotprod.c
index d670657f8..bf945c6fa 100644
--- a/av1/common/arm/convolve_neon_dotprod.c
+++ b/av1/common/arm/convolve_neon_dotprod.c
@@ -27,6 +27,15 @@ DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
   8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
 };
 
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
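+  // Indices in the range [16, 31] select from the second vector of the
+  // two-vector table passed to vqtbl2q_s8.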
+  // Shift left and insert new last column in transposed 4x4 block.
+  1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28,
+  // Shift left and insert two new columns in transposed 4x4 block.
+  2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29,
+  // Shift left and insert three new columns in transposed 4x4 block.
+  3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
+};
+
 static INLINE int16x4_t convolve12_4_x(uint8x16_t samples,
                                        const int8x16_t filter,
                                        const int32x4_t correction,
@@ -421,6 +430,342 @@ void av1_convolve_x_sr_neon_dotprod(const uint8_t *src, int src_stride,
   } while (h != 0);
 }
 
+static INLINE void transpose_concat_4x4(int8x8_t a0, int8x8_t a1, int8x8_t a2,
+                                        int8x8_t a3, int8x16_t *b) {
+  // Transpose 8-bit elements and concatenate result rows as follows:
+  // a0: 00, 01, 02, 03, XX, XX, XX, XX
+  // a1: 10, 11, 12, 13, XX, XX, XX, XX
+  // a2: 20, 21, 22, 23, XX, XX, XX, XX
+  // a3: 30, 31, 32, 33, XX, XX, XX, XX
+  //
+  // b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
+
+  int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0));
+  int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0));
+  int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0));
+  int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0));
+
+  int8x16_t a01 = vzipq_s8(a0q, a1q).val[0];
+  int8x16_t a23 = vzipq_s8(a2q, a3q).val[0];
+
+  int16x8_t a0123 =
+      vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23)).val[0];
+
+  *b = vreinterpretq_s8_s16(a0123);
+}
+
+static INLINE void transpose_concat_8x4(int8x8_t a0, int8x8_t a1, int8x8_t a2,
+                                        int8x8_t a3, int8x16_t *b0,
+                                        int8x16_t *b1) {
+  // Transpose 8-bit elements and concatenate result rows as follows:
+  // a0: 00, 01, 02, 03, 04, 05, 06, 07
+  // a1: 10, 11, 12, 13, 14, 15, 16, 17
+  // a2: 20, 21, 22, 23, 24, 25, 26, 27
+  // a3: 30, 31, 32, 33, 34, 35, 36, 37
+  //
+  // b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
+  // b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37
+
+  int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0));
+  int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0));
+  int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0));
+  int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0));
+
+  int8x16_t a01 = vzipq_s8(a0q, a1q).val[0];
+  int8x16_t a23 = vzipq_s8(a2q, a3q).val[0];
+
+  int16x8x2_t a0123 =
+      vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23));
+
+  *b0 = vreinterpretq_s8_s16(a0123.val[0]);
+  *b1 = vreinterpretq_s8_s16(a0123.val[1]);
+}
+
+static INLINE int16x4_t convolve12_4_y(const int8x16_t s0, const int8x16_t s1,
+                                       const int8x16_t s2,
+                                       const int8x8_t filters_0_7,
+                                       const int8x8_t filters_4_11) {
+  // The sample range transform and permutation are performed by the caller.
+  // Accumulate into 128 << FILTER_BITS to account for range transform: the
+  // filter taps sum to 1 << FILTER_BITS, so subtracting 128 from every
+  // sample subtracts exactly 128 << FILTER_BITS from the result.
+  const int32x4_t acc = vdupq_n_s32(128 << FILTER_BITS);
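+  // filters_0_7: lane 0 holds taps 0-3, lane 1 holds taps 4-7.
+  // filters_4_11: lane 1 holds taps 8-11.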
+  int32x4_t sum = vdotq_lane_s32(acc, s0, filters_0_7, 0);
+  sum = vdotq_lane_s32(sum, s1, filters_0_7, 1);
+  sum = vdotq_lane_s32(sum, s2, filters_4_11, 1);
+
+  // Further narrowing and packing are performed by the caller.
+  return vqmovn_s32(sum);
+}
+
+static INLINE uint8x8_t convolve12_8_y(
+    const int8x16_t s0_lo, const int8x16_t s0_hi, const int8x16_t s1_lo,
+    const int8x16_t s1_hi, const int8x16_t s2_lo, const int8x16_t s2_hi,
+    const int8x8_t filters_0_7, const int8x8_t filters_4_11) {
+  // The sample range transform and permutation are performed by the caller.
+  // Accumulate into 128 << FILTER_BITS to account for range transform.
+  const int32x4_t acc = vdupq_n_s32(128 << FILTER_BITS);
+
+  int32x4_t sum0123 = vdotq_lane_s32(acc, s0_lo, filters_0_7, 0);
+  sum0123 = vdotq_lane_s32(sum0123, s1_lo, filters_0_7, 1);
+  sum0123 = vdotq_lane_s32(sum0123, s2_lo, filters_4_11, 1);
+
+  int32x4_t sum4567 = vdotq_lane_s32(acc, s0_hi, filters_0_7, 0);
+  sum4567 = vdotq_lane_s32(sum4567, s1_hi, filters_0_7, 1);
+  sum4567 = vdotq_lane_s32(sum4567, s2_hi, filters_4_11, 1);
+
+  // Narrow and re-pack.
+  int16x8_t sum = vcombine_s16(vqmovn_s32(sum0123), vqmovn_s32(sum4567));
+  return vqrshrun_n_s16(sum, FILTER_BITS);
+}
+
+static INLINE void convolve_y_sr_12tap_neon_dotprod(
+    const uint8_t *src_ptr, int src_stride, uint8_t *dst_ptr, int dst_stride,
+    int w, int h, const int16_t *y_filter_ptr) {
+  // Special case the following no-op filter as the tap value 128 won't fit
+  // into the signed 8-bit operands of the dot-product instruction:
+  // { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0, 0, 0 }
+  if (y_filter_ptr[5] == 128) {
+    // Undo the vertical offset in the calling function.
+    src_ptr += 5 * src_stride;
+
+    do {
+      const uint8_t *s = src_ptr;
+      uint8_t *d = dst_ptr;
+      int width = w;
+
+      do {
+        uint8x8_t d0 = vld1_u8(s);
+        if (w == 4) {
+          store_u8_4x1(d, d0);
+        } else {
+          vst1_u8(d, d0);
+        }
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width > 0);
+      src_ptr += src_stride;
+      dst_ptr += dst_stride;
+    } while (--h != 0);
+  } else {
+    const int8x8_t filter_0_7 = vmovn_s16(vld1q_s16(y_filter_ptr));
+    const int8x8_t filter_4_11 = vmovn_s16(vld1q_s16(y_filter_ptr + 4));
+
+    const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
+
+    if (w == 4) {
+      uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, tA;
+      load_u8_8x11(src_ptr, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7,
+                   &t8, &t9, &tA);
+      src_ptr += 11 * src_stride;
+
+      // Transform sample range to [-128, 127] for 8-bit signed dot product.
+      int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
+      int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
+      int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
+      int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
+      int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
+      int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
+      int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));
+      int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
+      int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
+      int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
+      int8x8_t sA = vreinterpret_s8_u8(vsub_u8(tA, vdup_n_u8(128)));
+
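+      // This operation combines a conventional transpose and the sample
+      // permute (see the horizontal case) required before computing the dot
+      // product.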
+      int8x16_t s0123, s1234, s2345, s3456, s4567, s5678, s6789, s789A;
+      transpose_concat_4x4(s0, s1, s2, s3, &s0123);
+      transpose_concat_4x4(s1, s2, s3, s4, &s1234);
+      transpose_concat_4x4(s2, s3, s4, s5, &s2345);
+      transpose_concat_4x4(s3, s4, s5, s6, &s3456);
+      transpose_concat_4x4(s4, s5, s6, s7, &s4567);
+      transpose_concat_4x4(s5, s6, s7, s8, &s5678);
+      transpose_concat_4x4(s6, s7, s8, s9, &s6789);
+      transpose_concat_4x4(s7, s8, s9, sA, &s789A);
+
+      do {
+        uint8x8_t tB, tC, tD, tE;
+        load_u8_8x4(src_ptr, src_stride, &tB, &tC, &tD, &tE);
+
+        int8x8_t sB = vreinterpret_s8_u8(vsub_u8(tB, vdup_n_u8(128)));
+        int8x8_t sC = vreinterpret_s8_u8(vsub_u8(tC, vdup_n_u8(128)));
+        int8x8_t sD = vreinterpret_s8_u8(vsub_u8(tD, vdup_n_u8(128)));
+        int8x8_t sE = vreinterpret_s8_u8(vsub_u8(tE, vdup_n_u8(128)));
+
+        int8x16_t s89AB, s9ABC, sABCD, sBCDE;
+        transpose_concat_4x4(sB, sC, sD, sE, &sBCDE);
+
+        // Merge new data into block from previous iteration.
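+        // vqtbl2q_s8 indexes into the pair { s789A, sBCDE }, building the
+        // three intermediate blocks without redoing the transpose.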
+        int8x16x2_t samples_LUT = { { s789A, sBCDE } };
+        s89AB = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
+        s9ABC = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
+        sABCD = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
+
+        int16x4_t d0 =
+            convolve12_4_y(s0123, s4567, s89AB, filter_0_7, filter_4_11);
+        int16x4_t d1 =
+            convolve12_4_y(s1234, s5678, s9ABC, filter_0_7, filter_4_11);
+        int16x4_t d2 =
+            convolve12_4_y(s2345, s6789, sABCD, filter_0_7, filter_4_11);
+        int16x4_t d3 =
+            convolve12_4_y(s3456, s789A, sBCDE, filter_0_7, filter_4_11);
+        uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+        uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+
+        store_u8x4_strided_x2(dst_ptr + 0 * dst_stride, dst_stride, d01);
+        store_u8x4_strided_x2(dst_ptr + 2 * dst_stride, dst_stride, d23);
+
+        // Prepare block for next iteration - re-using as much as possible.
+        // Shuffle everything up four rows.
+        s0123 = s4567;
+        s1234 = s5678;
+        s2345 = s6789;
+        s3456 = s789A;
+        s4567 = s89AB;
+        s5678 = s9ABC;
+        s6789 = sABCD;
+        s789A = sBCDE;
+
+        src_ptr += 4 * src_stride;
+        dst_ptr += 4 * dst_stride;
+        h -= 4;
+      } while (h != 0);
+    } else {
+      do {
+        int height = h;
+        const uint8_t *s = src_ptr;
+        uint8_t *d = dst_ptr;
+
+        uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, tA;
+        load_u8_8x11(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7, &t8,
+                     &t9, &tA);
+        s += 11 * src_stride;
+
+        // Transform sample range to [-128, 127] for 8-bit signed dot product.
+        int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
+        int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
+        int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
+        int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
+        int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
+        int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
+        int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));
+        int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
+        int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
+        int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
+        int8x8_t sA = vreinterpret_s8_u8(vsub_u8(tA, vdup_n_u8(128)));
+
+        // This operation combines a conventional transpose and the sample
+        // permute (see horizontal case) required before computing the dot
+        // product.
+        int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
+            s3456_lo, s3456_hi, s4567_lo, s4567_hi, s5678_lo, s5678_hi,
+            s6789_lo, s6789_hi, s789A_lo, s789A_hi;
+        transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);
+        transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi);
+        transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi);
+        transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi);
+        transpose_concat_8x4(s4, s5, s6, s7, &s4567_lo, &s4567_hi);
+        transpose_concat_8x4(s5, s6, s7, s8, &s5678_lo, &s5678_hi);
+        transpose_concat_8x4(s6, s7, s8, s9, &s6789_lo, &s6789_hi);
+        transpose_concat_8x4(s7, s8, s9, sA, &s789A_lo, &s789A_hi);
+
+        do {
+          uint8x8_t tB, tC, tD, tE;
+          load_u8_8x4(s, src_stride, &tB, &tC, &tD, &tE);
+
+          int8x8_t sB = vreinterpret_s8_u8(vsub_u8(tB, vdup_n_u8(128)));
+          int8x8_t sC = vreinterpret_s8_u8(vsub_u8(tC, vdup_n_u8(128)));
+          int8x8_t sD = vreinterpret_s8_u8(vsub_u8(tD, vdup_n_u8(128)));
+          int8x8_t sE = vreinterpret_s8_u8(vsub_u8(tE, vdup_n_u8(128)));
+
+          int8x16_t s89AB_lo, s89AB_hi, s9ABC_lo, s9ABC_hi, sABCD_lo, sABCD_hi,
+              sBCDE_lo, sBCDE_hi;
+          transpose_concat_8x4(sB, sC, sD, sE, &sBCDE_lo, &sBCDE_hi);
+
+          // Merge new data into block from previous iteration.
+          int8x16x2_t samples_LUT_lo = { { s789A_lo, sBCDE_lo } };
+          s89AB_lo = vqtbl2q_s8(samples_LUT_lo, merge_block_tbl.val[0]);
+          s9ABC_lo = vqtbl2q_s8(samples_LUT_lo, merge_block_tbl.val[1]);
+          sABCD_lo = vqtbl2q_s8(samples_LUT_lo, merge_block_tbl.val[2]);
+
+          int8x16x2_t samples_LUT_hi = { { s789A_hi, sBCDE_hi } };
+          s89AB_hi = vqtbl2q_s8(samples_LUT_hi, merge_block_tbl.val[0]);
+          s9ABC_hi = vqtbl2q_s8(samples_LUT_hi, merge_block_tbl.val[1]);
+          sABCD_hi = vqtbl2q_s8(samples_LUT_hi, merge_block_tbl.val[2]);
+
+          uint8x8_t d0 =
+              convolve12_8_y(s0123_lo, s0123_hi, s4567_lo, s4567_hi, s89AB_lo,
+                             s89AB_hi, filter_0_7, filter_4_11);
+          uint8x8_t d1 =
+              convolve12_8_y(s1234_lo, s1234_hi, s5678_lo, s5678_hi, s9ABC_lo,
+                             s9ABC_hi, filter_0_7, filter_4_11);
+          uint8x8_t d2 =
+              convolve12_8_y(s2345_lo, s2345_hi, s6789_lo, s6789_hi, sABCD_lo,
+                             sABCD_hi, filter_0_7, filter_4_11);
+          uint8x8_t d3 =
+              convolve12_8_y(s3456_lo, s3456_hi, s789A_lo, s789A_hi, sBCDE_lo,
+                             sBCDE_hi, filter_0_7, filter_4_11);
+
+          store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
+
+          // Prepare block for next iteration - re-using as much as possible.
+          // Shuffle everything up four rows.
+          s0123_lo = s4567_lo;
+          s0123_hi = s4567_hi;
+          s1234_lo = s5678_lo;
+          s1234_hi = s5678_hi;
+          s2345_lo = s6789_lo;
+          s2345_hi = s6789_hi;
+          s3456_lo = s789A_lo;
+          s3456_hi = s789A_hi;
+          s4567_lo = s89AB_lo;
+          s4567_hi = s89AB_hi;
+          s5678_lo = s9ABC_lo;
+          s5678_hi = s9ABC_hi;
+          s6789_lo = sABCD_lo;
+          s6789_hi = sABCD_hi;
+          s789A_lo = sBCDE_lo;
+          s789A_hi = sBCDE_hi;
+
+          s += 4 * src_stride;
+          d += 4 * dst_stride;
+          height -= 4;
+        } while (height != 0);
+        src_ptr += 8;
+        dst_ptr += 8;
+        w -= 8;
+      } while (w != 0);
+    }
+  }
+}
+
+void av1_convolve_y_sr_neon_dotprod(const uint8_t *src, int src_stride,
+                                    uint8_t *dst, int dst_stride, int w, int h,
+                                    const InterpFilterParams *filter_params_y,
+                                    const int subpel_y_qn) {
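+  // Fall back to the reference C implementation for very small blocks.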
+  if (w == 2 || h == 2) {
+    av1_convolve_y_sr_c(src, src_stride, dst, dst_stride, w, h, filter_params_y,
+                        subpel_y_qn);
+    return;
+  }
+
+  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
+
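+  // Filters with 8 or fewer taps are handled by the existing Armv8.0 Neon
+  // implementation.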
+  if (y_filter_taps <= 8) {
+    av1_convolve_y_sr_neon(src, src_stride, dst, dst_stride, w, h,
+                           filter_params_y, subpel_y_qn);
+    return;
+  }
+
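+  // Offset the source pointer so that the filter taps are centered on the
+  // row being filtered.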
+  const int vert_offset = y_filter_taps / 2 - 1;
+  src -= vert_offset * src_stride;
+
+  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
+      filter_params_y, subpel_y_qn & SUBPEL_MASK);
+
+  convolve_y_sr_12tap_neon_dotprod(src, src_stride, dst, dst_stride, w, h,
+                                   y_filter_ptr);
+}
+
 static INLINE int16x4_t convolve12_4_2d_h(uint8x16_t samples,
                                           const int8x16_t filters,
                                           const int32x4_t correction,
diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 7035fb3bd..59d70f0e8 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -604,7 +604,7 @@ ()
   specialize qw/av1_convolve_2d_sr_intrabc neon/;
   specialize qw/av1_convolve_x_sr sse2 avx2 neon neon_dotprod neon_i8mm/;
   specialize qw/av1_convolve_x_sr_intrabc neon/;
-  specialize qw/av1_convolve_y_sr sse2 avx2 neon/;
+  specialize qw/av1_convolve_y_sr sse2 avx2 neon neon_dotprod/;
   specialize qw/av1_convolve_y_sr_intrabc neon/;
   specialize qw/av1_convolve_2d_scale sse4_1/;
   specialize qw/av1_dist_wtd_convolve_2d ssse3 avx2 neon neon_dotprod neon_i8mm/;
diff --git a/test/av1_convolve_test.cc b/test/av1_convolve_test.cc
index b2392276c..96c060349 100644
--- a/test/av1_convolve_test.cc
+++ b/test/av1_convolve_test.cc
@@ -827,6 +827,11 @@ INSTANTIATE_TEST_SUITE_P(NEON, AV1ConvolveYTest,
                          BuildLowbdParams(av1_convolve_y_sr_neon));
 #endif
 
+#if HAVE_NEON_DOTPROD
+INSTANTIATE_TEST_SUITE_P(NEON_DOTPROD, AV1ConvolveYTest,
+                         BuildLowbdParams(av1_convolve_y_sr_neon_dotprod));
+#endif
+
 ////////////////////////////////////////////////////////////////
 // Single reference convolve-y IntraBC functions (low bit-depth)
 ////////////////////////////////////////////////////////////////