aom: Use Arm Neon USMMLA 6-tap impl. for 4-tap convolve_x_sr

From 6e14f9069e58c9abc7ec4277d6e312116ac65b64 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Wed, 7 Aug 2024 11:57:21 +0100
Subject: [PATCH] Use Arm Neon USMMLA 6-tap impl. for 4-tap convolve_x_sr

The 6-tap USMMLA implementation of convolve_x_sr uses the same number
of instructions as the 4-tap USDOT implementation, so delete the
4-tap USDOT path and use the 6-tap USMMLA implementation in both
cases.

Change-Id: Ic390abea63047af623a2ab232532b6b36360293e
---
 av1/common/arm/convolve_neon_i8mm.c | 118 +++++++---------------------
 1 file changed, 29 insertions(+), 89 deletions(-)

diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index dd4a34e0b..acd912e57 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -213,6 +213,22 @@ static inline void convolve_x_sr_8tap_neon_i8mm(
   } while (height != 0);
 }
 
+static inline int16x4_t convolve6_4_x(uint8x16_t samples,
+                                      const int8x16_t filter,
+                                      const uint8x16_t permute_tbl,
+                                      const int32x4_t horiz_const) {
+  // Permute samples ready for matrix multiply.
+  // { 0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
+  uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
+
+  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+  // (filter), destructively accumulating into the destination register.
+  int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples, filter);
+
+  // Further narrowing and packing is performed by the caller.
+  return vmovn_s32(sum);
+}
+
 static inline uint8x8_t convolve6_8_x(uint8x16_t samples,
                                       const int8x16_t filter,
                                       const uint8x16x2_t permute_tbl,
@@ -244,86 +260,16 @@ static inline void convolve_x_sr_6tap_neon_i8mm(
   const int8x16_t x_filter =
       vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
 
-  const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
-  do {
-    const uint8_t *s = src;
-    uint8_t *d = dst;
-    int w = width;
-
-    do {
-      uint8x16_t s0, s1, s2, s3;
-      load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
-
-      uint8x8_t d0 = convolve6_8_x(s0, x_filter, permute_tbl, horiz_const);
-      uint8x8_t d1 = convolve6_8_x(s1, x_filter, permute_tbl, horiz_const);
-      uint8x8_t d2 = convolve6_8_x(s2, x_filter, permute_tbl, horiz_const);
-      uint8x8_t d3 = convolve6_8_x(s3, x_filter, permute_tbl, horiz_const);
-
-      store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
-
-      s += 8;
-      d += 8;
-      w -= 8;
-    } while (w != 0);
-    src += 4 * src_stride;
-    dst += 4 * dst_stride;
-    height -= 4;
-  } while (height != 0);
-}
-
-static inline int16x4_t convolve4_4_x(const uint8x16_t samples,
-                                      const int8x8_t filters,
-                                      const uint8x16_t permute_tbl,
-                                      const int32x4_t horiz_const) {
-  // Permute samples ready for dot product.
-  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-  uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
-
-  int32x4_t sum = vusdotq_lane_s32(horiz_const, perm_samples, filters, 0);
-
-  // Further narrowing and packing is performed by the caller.
-  return vmovn_s32(sum);
-}
-
-static inline uint8x8_t convolve4_8_x(const uint8x16_t samples,
-                                      const int8x8_t filters,
-                                      const uint8x16x2_t permute_tbl,
-                                      const int32x4_t horiz_const) {
-  // Permute samples ready for dot product.
-  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
-  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
-  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
-                                 vqtbl1q_u8(samples, permute_tbl.val[1]) };
-
-  int32x4_t acc = horiz_const;
-  int32x4_t sum0123 = vusdotq_lane_s32(acc, perm_samples[0], filters, 0);
-  int32x4_t sum4567 = vusdotq_lane_s32(acc, perm_samples[1], filters, 0);
-
-  // Narrow and re-pack.
-  int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
-  // We halved the filter values so -1 from right shift.
-  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
-}
-
-static inline void convolve_x_sr_4tap_neon_i8mm(
-    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
-    ptrdiff_t dst_stride, int width, int height, const int16_t *filter_x,
-    const int32x4_t horiz_const) {
-  const int16x4_t x_filter = vld1_s16(filter_x + 2);
-  // All 4-tap and bilinear filter values are even, so halve them to reduce
-  // intermediate precision requirements.
-  const int8x8_t filter = vshrn_n_s16(vcombine_s16(x_filter, vdup_n_s16(0)), 1);
-
   if (width == 4) {
-    const uint8x16_t perm_tbl = vld1q_u8(kDotProdPermuteTbl);
+    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
     do {
       uint8x16_t s0, s1, s2, s3;
       load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
 
-      int16x4_t t0 = convolve4_4_x(s0, filter, perm_tbl, horiz_const);
-      int16x4_t t1 = convolve4_4_x(s1, filter, perm_tbl, horiz_const);
-      int16x4_t t2 = convolve4_4_x(s2, filter, perm_tbl, horiz_const);
-      int16x4_t t3 = convolve4_4_x(s3, filter, perm_tbl, horiz_const);
+      int16x4_t t0 = convolve6_4_x(s0, x_filter, permute_tbl, horiz_const);
+      int16x4_t t1 = convolve6_4_x(s1, x_filter, permute_tbl, horiz_const);
+      int16x4_t t2 = convolve6_4_x(s2, x_filter, permute_tbl, horiz_const);
+      int16x4_t t3 = convolve6_4_x(s3, x_filter, permute_tbl, horiz_const);
       // We halved the filter values so -1 from right shift.
       uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS - 1);
       uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS - 1);
@@ -336,20 +282,20 @@ static inline void convolve_x_sr_4tap_neon_i8mm(
       height -= 4;
     } while (height != 0);
   } else {
-    const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
-
+    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
     do {
-      int w = width;
       const uint8_t *s = src;
       uint8_t *d = dst;
+      int w = width;
+
       do {
         uint8x16_t s0, s1, s2, s3;
         load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-        uint8x8_t d0 = convolve4_8_x(s0, filter, perm_tbl, horiz_const);
-        uint8x8_t d1 = convolve4_8_x(s1, filter, perm_tbl, horiz_const);
-        uint8x8_t d2 = convolve4_8_x(s2, filter, perm_tbl, horiz_const);
-        uint8x8_t d3 = convolve4_8_x(s3, filter, perm_tbl, horiz_const);
+        uint8x8_t d0 = convolve6_8_x(s0, x_filter, permute_tbl, horiz_const);
+        uint8x8_t d1 = convolve6_8_x(s1, x_filter, permute_tbl, horiz_const);
+        uint8x8_t d2 = convolve6_8_x(s2, x_filter, permute_tbl, horiz_const);
+        uint8x8_t d3 = convolve6_8_x(s3, x_filter, permute_tbl, horiz_const);
 
         store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -390,7 +336,7 @@ void av1_convolve_x_sr_neon_i8mm(const uint8_t *src, int src_stride,
   // Halve the total because we will halve the filter values.
  const int32x4_t horiz_const = vdupq_n_s32((1 << (ROUND0_BITS - 1)) / 2);
 
-  if (filter_taps == 6) {
+  if (filter_taps <= 6) {
     convolve_x_sr_6tap_neon_i8mm(src + 1, src_stride, dst, dst_stride, w, h,
                                  x_filter_ptr, horiz_const);
     return;
@@ -402,12 +348,6 @@ void av1_convolve_x_sr_neon_i8mm(const uint8_t *src, int src_stride,
     return;
   }
 
-  if (filter_taps <= 4) {
-    convolve_x_sr_4tap_neon_i8mm(src + 2, src_stride, dst, dst_stride, w, h,
-                                 x_filter_ptr, horiz_const);
-    return;
-  }
-
   convolve_x_sr_8tap_neon_i8mm(src, src_stride, dst, dst_stride, w, h,
                                x_filter_ptr, horiz_const);
 }