aom: Refactor 4-tap convolve_2d_sr Neon I8MM path

From a1e3c8c721c7523f12b57307d9b5fed4b35ae221 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Mon, 5 Aug 2024 09:24:24 +0100
Subject: [PATCH] Refactor 4-tap convolve_2d_sr Neon I8MM path

A 6-tap horizontal convolution kernel implemented with USMMLA requires
the same number of instructions as a 4-tap kernel implemented with
USDOT; therefore we can use the USMMLA 6-tap path for both 6- and 4-tap
horizontal filters.

This patch uses that observation to extend the Neon I8MM 4-tap
combined 2D convolution path so that it also supports horizontal
filters of up to 6 taps. This is useful because (horizontal, vertical)
= (6, 4) filter-tap combinations are the third most common type for
convolve_2d_sr in --rt encodings, after (6, 6) and (4, 4).

Change-Id: Iefaf3a4f759bbcfd61ff60c829225d4aff93556c
---
 av1/common/arm/convolve_neon_i8mm.c | 81 +++++++++++++++++------------
 1 file changed, 48 insertions(+), 33 deletions(-)

diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index c0957aa29..2d31054f8 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -1054,6 +1054,22 @@ static INLINE void convolve_2d_sr_horiz_4tap_neon_i8mm(
   }
 }
 
+static INLINE int16x4_t convolve6_4_2d_h(uint8x16_t samples,
+                                         const int8x16_t filter,
+                                         const uint8x16_t permute_tbl,
+                                         const int32x4_t horiz_const) {
+  // Computes 4 horizontal outputs. Permute samples into two 8-byte rows:
+  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
+  uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
+
+  // Multiply the 2x8 sample matrix by the 8x2 filter matrix, accumulating
+  // into horiz_const. Expects filter as { f0..f5, 0, 0, 0, f0..f5, 0 }.
+  int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples, filter);
+  // Lanes of sum now hold outputs 0-3. Filter values are halved, so the
+  // shift is ROUND0_BITS - 1; horiz_const should carry the rounding shim.
+  return vshrn_n_s32(sum, ROUND0_BITS - 1);
+}
+
 static INLINE int16x8_t convolve6_8_2d_h(uint8x16_t samples,
                                          const int8x16_t filter,
                                          const uint8x16x2_t permute_tbl,
@@ -1153,36 +1169,33 @@ static INLINE void convolve_2d_sr_6tap_neon_i8mm(const uint8_t *src,
   } while (w != 0);
 }
 
-static INLINE void convolve_2d_sr_4tap_neon_i8mm(const uint8_t *src,
-                                                 int src_stride, uint8_t *dst,
-                                                 int dst_stride, int w, int h,
-                                                 const int16_t *x_filter_ptr,
-                                                 const int16_t *y_filter_ptr) {
-  const int bd = 8;
-  const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
-
+static INLINE void convolve_2d_sr_6tap_4tap_neon_i8mm(
+    const uint8_t *src, int src_stride, uint8_t *dst, int dst_stride, int w,
+    int h, const int16_t *x_filter_ptr, const int16_t *y_filter_ptr) {
   const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);
-  const int16x4_t x_filter_s16 = vld1_s16(x_filter_ptr + 2);
-  // All 4-tap and bilinear filter values are even, so halve them to reduce
-  // intermediate precision requirements.
-  const int8x8_t x_filter =
-      vshrn_n_s16(vcombine_s16(x_filter_s16, vdup_n_s16(0)), 1);
+  // Filter values are even, so halve to reduce intermediate precision reqs.
+  const int8x8_t x_filter_s8 = vshrn_n_s16(vld1q_s16(x_filter_ptr), 1);
+  // Stagger the filter for use with the matrix multiply instructions.
+  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
+  const int8x16_t x_filter =
+      vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
 
+  const int bd = 8;
   // Adding a shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
   // shifts - which are generally faster than rounding shifts on modern CPUs.
   // Halve the total because we halved the filter values.
   const int32x4_t horiz_const = vdupq_n_s32(
       ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1))) / 2);
+  const int16x8_t vert_const = vdupq_n_s16(1 << (bd - 1));
 
   if (w == 4) {
-    const uint8x16_t permute_tbl = vld1q_u8(kDotProdPermuteTbl);
-
+    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
     uint8x16_t h_s0, h_s1, h_s2;
     load_u8_16x3(src, src_stride, &h_s0, &h_s1, &h_s2);
 
-    int16x4_t v_s0 = convolve4_4_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
-    int16x4_t v_s1 = convolve4_4_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
-    int16x4_t v_s2 = convolve4_4_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
+    int16x4_t v_s0 = convolve6_4_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
+    int16x4_t v_s1 = convolve6_4_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
+    int16x4_t v_s2 = convolve6_4_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
 
     src += 3 * src_stride;
 
@@ -1191,13 +1204,13 @@ static INLINE void convolve_2d_sr_4tap_neon_i8mm(const uint8_t *src,
       load_u8_16x4(src, src_stride, &h_s3, &h_s4, &h_s5, &h_s6);
 
       int16x4_t v_s3 =
-          convolve4_4_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
+          convolve6_4_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
       int16x4_t v_s4 =
-          convolve4_4_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
+          convolve6_4_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
       int16x4_t v_s5 =
-          convolve4_4_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
+          convolve6_4_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
       int16x4_t v_s6 =
-          convolve4_4_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
+          convolve6_4_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
 
       int16x4_t d0 = convolve4_4_2d_v(v_s0, v_s1, v_s2, v_s3, y_filter);
       int16x4_t d1 = convolve4_4_2d_v(v_s1, v_s2, v_s3, v_s4, y_filter);
@@ -1219,7 +1232,7 @@ static INLINE void convolve_2d_sr_4tap_neon_i8mm(const uint8_t *src,
       h -= 4;
     } while (h != 0);
   } else {
-    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
 
     do {
       int height = h;
@@ -1230,11 +1243,11 @@ static INLINE void convolve_2d_sr_4tap_neon_i8mm(const uint8_t *src,
       load_u8_16x3(src, src_stride, &h_s0, &h_s1, &h_s2);
 
       int16x8_t v_s0 =
-          convolve4_8_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s0, x_filter, permute_tbl, horiz_const);
       int16x8_t v_s1 =
-          convolve4_8_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s1, x_filter, permute_tbl, horiz_const);
       int16x8_t v_s2 =
-          convolve4_8_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
+          convolve6_8_2d_h(h_s2, x_filter, permute_tbl, horiz_const);
 
       s += 3 * src_stride;
 
@@ -1243,13 +1256,13 @@ static INLINE void convolve_2d_sr_4tap_neon_i8mm(const uint8_t *src,
         load_u8_16x4(s, src_stride, &h_s3, &h_s4, &h_s5, &h_s6);
 
         int16x8_t v_s3 =
-            convolve4_8_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
+            convolve6_8_2d_h(h_s3, x_filter, permute_tbl, horiz_const);
         int16x8_t v_s4 =
-            convolve4_8_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
+            convolve6_8_2d_h(h_s4, x_filter, permute_tbl, horiz_const);
         int16x8_t v_s5 =
-            convolve4_8_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
+            convolve6_8_2d_h(h_s5, x_filter, permute_tbl, horiz_const);
         int16x8_t v_s6 =
-            convolve4_8_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
+            convolve6_8_2d_h(h_s6, x_filter, permute_tbl, horiz_const);
 
         uint8x8_t d0 =
             convolve4_8_2d_v(v_s0, v_s1, v_s2, v_s3, y_filter, vert_const);
@@ -1326,9 +1339,11 @@ void av1_convolve_2d_sr_neon_i8mm(const uint8_t *src, int src_stride,
       return;
     }
 
-    if (y_filter_taps <= 4 && x_filter_taps <= 4) {
-      convolve_2d_sr_4tap_neon_i8mm(src_ptr + 2, src_stride, dst, dst_stride, w,
-                                    h, x_filter_ptr, y_filter_ptr);
+    // Used for both 6, 4 and 4, 4 horiz, vert filter tap combinations.
+    if (x_filter_taps <= 6 && y_filter_taps <= 4) {
+      convolve_2d_sr_6tap_4tap_neon_i8mm(src_ptr + 1, src_stride, dst,
+                                         dst_stride, w, h, x_filter_ptr,
+                                         y_filter_ptr);
       return;
     }