aom: Use Arm Neon USMMLA for horiz. 6-tap path for convolve_2d_sr

From 88e4df06ca24c73b79c84a15fc230074062a5584 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Fri, 2 Aug 2024 12:13:34 +0100
Subject: [PATCH] Use Arm Neon USMMLA for horiz. 6-tap path for convolve_2d_sr

By permuting the input samples and the 6-tap filter, we can use the
Armv8.6 I8MM USMMLA matrix multiply instructions to accelerate
horizontal 6-tap convolutions. Each 2x8 by 8x2 matrix multiply
instruction does twice the work of a USDOT dot product instruction
(32 multiply-accumulates instead of 16).

Change-Id: I0f6b3969925a4d7190e277c3cd0221f7f4c98018
---
 av1/common/arm/convolve_neon_i8mm.c | 58 +++++++++++++++++++++--------
 1 file changed, 42 insertions(+), 16 deletions(-)

diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index 796d3f709..c0957aa29 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -1054,6 +1054,27 @@ static INLINE void convolve_2d_sr_horiz_4tap_neon_i8mm(
   }
 }
 
+static INLINE int16x8_t convolve6_8_2d_h(uint8x16_t samples,
+                                         const int8x16_t filter,
+                                         const uint8x16x2_t permute_tbl,
+                                         const int32x4_t horiz_const) {
+  // Permute samples ready for matrix multiply.
+  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
+  // { 4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
+  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+                                 vqtbl1q_u8(samples, permute_tbl.val[1]) };
+
+  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+  // (filter), destructively accumulating into the destination register.
+  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, perm_samples[0], filter);
+  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, perm_samples[1], filter);
+
+  // Narrow and re-pack.
+  // We halved the convolution filter values so -1 from the right shift.
+  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS - 1),
+                      vshrn_n_s32(sum4567, ROUND0_BITS - 1));
+}
+
 static INLINE void convolve_2d_sr_6tap_neon_i8mm(const uint8_t *src,
                                                  int src_stride, uint8_t *dst,
                                                  int dst_stride, int w, int h,
@@ -1061,16 +1082,21 @@ static INLINE void convolve_2d_sr_6tap_neon_i8mm(const uint8_t *src,
                                                  const int16_t *y_filter_ptr) {
   const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
   // Filter values are even, so halve to reduce intermediate precision reqs.
-  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  // Stagger the filter for use with the matrix multiply instructions.
+  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
+  const int8x16_t x_filter =
+      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
 
   const int bd = 8;
   // This shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
-  // shifts - which are generally faster than rounding shifts on modern CPUs.
-  // The outermost -1 is needed because we halved the filter values.
+  // shifts in convolution kernels - which are generally faster than rounding
+  // shifts on modern CPUs. The outermost -1 is needed because we halved the
+  // filter values.
   const int32x4_t horiz_const = vdupq_n_s32((1 << (bd + FILTER_BITS - 2)) +
                                             (1 << ((ROUND0_BITS - 1) - 1)));
   const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
-  const uint8x16x3_t permute_tbl = vld1q_u8_x3(kDotProdPermuteTbl);
+  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
 
   do {
     const uint8_t *s = src;
@@ -1081,24 +1107,24 @@ static INLINE void convolve_2d_sr_6tap_neon_i8mm(const uint8_t *src,
     load_u8_16x5(s, src_stride, &h_s0, &h_s1, &h_s2, &h_s3, &h_s4);
     s += 5 * src_stride;
 
-    int16x8_t v_s0 = convolve8_8_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
-    int16x8_t v_s1 = convolve8_8_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
-    int16x8_t v_s2 = convolve8_8_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
-    int16x8_t v_s3 = convolve8_8_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
-    int16x8_t v_s4 = convolve8_8_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s0 = convolve6_8_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s1 = convolve6_8_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s2 = convolve6_8_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s3 = convolve6_8_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
+    int16x8_t v_s4 = convolve6_8_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
 
     do {
       uint8x16_t h_s5, h_s6, h_s7, h_s8;
       load_u8_16x4(s, src_stride, &h_s5, &h_s6, &h_s7, &h_s8);
 
       int16x8_t v_s5 =
-          convolve8_8_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
       int16x8_t v_s6 =
-          convolve8_8_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
       int16x8_t v_s7 =
-          convolve8_8_2d_h(h_s7, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s7, x_filter, permute_tbl, horiz_const);
       int16x8_t v_s8 =
-          convolve8_8_2d_h(h_s8, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s8, x_filter, permute_tbl, horiz_const);
 
       uint8x8_t d0 = convolve6_8_2d_v(v_s0, v_s1, v_s2, v_s3, v_s4, v_s5,
                                       y_filter, vert_const);
@@ -1294,9 +1320,9 @@ void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
     DECLARE_ALIGNED(16, int16_t,
                     im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);
 
-    if (y_filter_taps == 6 && x_filter_taps >= 6) {
-      convolve_2d_sr_6tap_neon_i8mm(src_ptr, src_stride, dst, dst_stride, w, h,
-                                    x_filter_ptr, y_filter_ptr);
+    if (x_filter_taps == 6 && y_filter_taps == 6) {
+      convolve_2d_sr_6tap_neon_i8mm(src_ptr + 1, src_stride, dst, dst_stride, w,
+                                    h, x_filter_ptr, y_filter_ptr);
       return;
     }