aom: Add Arm Neon USMMLA impl. for 6-tap aom_convolve8_horiz

From 466e282391f0ff4c53030422b0d088442dd5f472 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Mon, 5 Aug 2024 17:15:12 +0100
Subject: [PATCH] Add Arm Neon USMMLA impl. for 6-tap aom_convolve8_horiz

By permuting the input samples and the 6-tap filter, we can use the
Armv8.6 I8MM USMMLA matrix multiply instructions to accelerate
horizontal 6-tap convolutions. The 2x8 by 8x2 matrix multiply
instruction does twice the work of a USDOT dot product instruction
(32 multiply-accumulates per instruction instead of 16).

We use this new USMMLA 6-tap path for 4-tap filters as well, since a
4-tap filter is just a 6-tap filter whose outermost taps are zero, and
the new path uses exactly the same number of instructions as the
previous USDOT implementation.

Change-Id: I36ba48eebb54dba7a7717875b2e83985b3b036d3
---
 aom_dsp/arm/aom_convolve8_neon_i8mm.c | 84 +++++++++++++++------------
 1 file changed, 46 insertions(+), 38 deletions(-)

diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index 5b9b88e75..121e89213 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -25,6 +25,13 @@
 #include "aom_dsp/arm/transpose_neon.h"
 #include "aom_ports/mem.h"
 
+DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
+  // clang-format off
+  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
+  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
+  // clang-format on
+};
+
 DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
   0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
   4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
@@ -136,60 +143,61 @@ static inline void convolve8_horiz_8tap_neon_i8mm(
   }
 }
 
-static inline int16x4_t convolve4_4_h(const uint8x16_t samples,
-                                      const int8x8_t filters,
+static inline int16x4_t convolve6_4_h(const uint8x16_t samples,
+                                      const int8x16_t filter,
                                       const uint8x16_t permute_tbl) {
-  // Permute samples ready for dot product.
-  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-  uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);
+  // Permute samples ready for matrix multiply.
+  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
+  uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
 
-  int32x4_t sum =
-      vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples, filters, 0);
+  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+  // (filter), destructively accumulating into the destination register.
+  int32x4_t sum = vusmmlaq_s32(vdupq_n_s32(0), perm_samples, filter);
 
   // Further narrowing and packing is performed by the caller.
   return vmovn_s32(sum);
 }
 
-static inline uint8x8_t convolve4_8_h(const uint8x16_t samples,
-                                      const int8x8_t filters,
+static inline uint8x8_t convolve6_8_h(const uint8x16_t samples,
+                                      const int8x16_t filter,
                                       const uint8x16x2_t permute_tbl) {
-  // Permute samples ready for dot product.
-  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-  uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
-                                     vqtbl1q_u8(samples, permute_tbl.val[1]) };
+  // Permute samples ready for matrix multiply.
+  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
+  // { 4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13 }
+  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+                                 vqtbl1q_u8(samples, permute_tbl.val[1]) };
 
-  // First 4 output values.
-  int32x4_t sum0 =
-      vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
-  // Second 4 output values.
-  int32x4_t sum1 =
-      vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
+  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+  // (filter), destructively accumulating into the destination register.
+  int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], filter);
+  int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], filter);
 
   // Narrow and re-pack.
-  int16x8_t sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
+  int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
   // We halved the filter values so -1 from right shift.
   return vqrshrun_n_s16(sum, FILTER_BITS - 1);
 }
 
-static inline void convolve8_horiz_4tap_neon_i8mm(
+static inline void convolve8_horiz_6tap_neon_i8mm(
     const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
     ptrdiff_t dst_stride, const int16_t *filter_x, int width, int height) {
-  const int16x4_t x_filter = vld1_s16(filter_x + 2);
-  // All 4-tap and bilinear filter values are even, so halve them to reduce
-  // intermediate precision requirements.
-  const int8x8_t filter = vshrn_n_s16(vcombine_s16(x_filter, vdup_n_s16(0)), 1);
+  // Filter values are even, so halve to reduce intermediate precision reqs.
+  const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(filter_x), 1);
+  // Stagger the filter for use with the matrix multiply instructions.
+  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
+  const int8x16_t filter =
+      vcombine_s8(vext_s8(x_filter, x_filter, 1), x_filter);
 
   if (width == 4) {
-    const uint8x16_t perm_tbl = vld1q_u8(kDotProdPermuteTbl);
+    const uint8x16_t perm_tbl = vld1q_u8(kMatMulPermuteTbl);
     do {
       uint8x16_t s0, s1, s2, s3;
       load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
 
-      int16x4_t t0 = convolve4_4_h(s0, filter, perm_tbl);
-      int16x4_t t1 = convolve4_4_h(s1, filter, perm_tbl);
-      int16x4_t t2 = convolve4_4_h(s2, filter, perm_tbl);
-      int16x4_t t3 = convolve4_4_h(s3, filter, perm_tbl);
+      int16x4_t t0 = convolve6_4_h(s0, filter, perm_tbl);
+      int16x4_t t1 = convolve6_4_h(s1, filter, perm_tbl);
+      int16x4_t t2 = convolve6_4_h(s2, filter, perm_tbl);
+      int16x4_t t3 = convolve6_4_h(s3, filter, perm_tbl);
       // We halved the filter values so -1 from right shift.
       uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS - 1);
       uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS - 1);
@@ -202,7 +210,7 @@ static inline void convolve8_horiz_4tap_neon_i8mm(
       height -= 4;
     } while (height > 0);
   } else {
-    const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
+    const uint8x16x2_t perm_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
 
     do {
       int w = width;
@@ -212,10 +220,10 @@ static inline void convolve8_horiz_4tap_neon_i8mm(
         uint8x16_t s0, s1, s2, s3;
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        uint8x8_t d0 = convolve4_8_h(s0, filter, perm_tbl);
-        uint8x8_t d1 = convolve4_8_h(s1, filter, perm_tbl);
-        uint8x8_t d2 = convolve4_8_h(s2, filter, perm_tbl);
-        uint8x8_t d3 = convolve4_8_h(s3, filter, perm_tbl);
+        uint8x8_t d0 = convolve6_8_h(s0, filter, perm_tbl);
+        uint8x8_t d1 = convolve6_8_h(s1, filter, perm_tbl);
+        uint8x8_t d2 = convolve6_8_h(s2, filter, perm_tbl);
+        uint8x8_t d3 = convolve6_8_h(s3, filter, perm_tbl);
 
         store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -249,8 +257,8 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
   if (filter_taps == 2) {
     convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride, filter_x, w,
                               h);
-  } else if (filter_taps == 4) {
-    convolve8_horiz_4tap_neon_i8mm(src + 2, src_stride, dst, dst_stride,
+  } else if (filter_taps <= 6) {
+    convolve8_horiz_6tap_neon_i8mm(src + 1, src_stride, dst, dst_stride,
                                    filter_x, w, h);
   } else {
     convolve8_horiz_8tap_neon_i8mm(src, src_stride, dst, dst_stride, filter_x,