From 466e282391f0ff4c53030422b0d088442dd5f472 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Mon, 5 Aug 2024 17:15:12 +0100
Subject: [PATCH] Add Arm Neon USMMLA impl. for 6-tap aom_convolve8_horiz
By permuting the input samples and the 6-tap filter we can use the
Armv8.6 I8MM USMMLA matrix multiply instructions to accelerate
horizontal 6-tap convolutions. The 2x8 by 8x2 matrix multiply
instruction does twice the work of a USDOT dot product instruction.
We use this new USMMLA 6-tap path for 4-tap filters as well: a 4-tap
filter is a 6-tap filter whose outer taps are zero, and the USMMLA path
uses exactly the same number of instructions as the previous USDOT
implementation.
Change-Id: I36ba48eebb54dba7a7717875b2e83985b3b036d3
---
aom_dsp/arm/aom_convolve8_neon_i8mm.c | 84 +++++++++++++++------------
1 file changed, 46 insertions(+), 38 deletions(-)
diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index 5b9b88e75..121e89213 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -25,6 +25,13 @@
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
+DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
+ // clang-format off
+ 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9,
+ 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13
+ // clang-format on
+};
+
DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
@@ -136,60 +143,61 @@ static inline void convolve8_horiz_8tap_neon_i8mm(
}
}
-static inline int16x4_t convolve4_4_h(const uint8x16_t samples,
- const int8x8_t filters,
+static inline int16x4_t convolve6_4_h(const uint8x16_t samples,
+ const int8x16_t filter,
const uint8x16_t permute_tbl) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- uint8x16_t permuted_samples = vqtbl1q_u8(samples, permute_tbl);
+ // Permute samples into two overlapping 8-byte rows (offsets 0 and 2):
+ // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
+ uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
- int32x4_t sum =
- vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples, filters, 0);
+ // This USMMLA multiplies the 2x8 sample matrix by the 8x2 filter matrix;
+ // its row-major 2x2 result holds outputs 0-3 (accumulator starts at 0).
+ int32x4_t sum = vusmmlaq_s32(vdupq_n_s32(0), perm_samples, filter);
// Further narrowing and packing is performed by the caller.
return vmovn_s32(sum);
}
-static inline uint8x8_t convolve4_8_h(const uint8x16_t samples,
- const int8x8_t filters,
+static inline uint8x8_t convolve6_8_h(const uint8x16_t samples,
+ const int8x16_t filter,
const uint8x16x2_t permute_tbl) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
- uint8x16_t permuted_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
- vqtbl1q_u8(samples, permute_tbl.val[1]) };
+ // Permute samples into overlapping 8-byte rows at offsets 0, 2, 4 and 6:
+ // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
+ // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
+ uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
+ vqtbl1q_u8(samples, permute_tbl.val[1]) };
- // First 4 output values.
- int32x4_t sum0 =
- vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[0], filters, 0);
- // Second 4 output values.
- int32x4_t sum1 =
- vusdotq_lane_s32(vdupq_n_s32(0), permuted_samples[1], filters, 0);
+ // Each USMMLA multiplies a 2x8 sample matrix by the 8x2 filter matrix; the
+ // row-major 2x2 results are outputs 0-3 and 4-7 (accumulators start at 0).
+ int32x4_t sum0123 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[0], filter);
+ int32x4_t sum4567 = vusmmlaq_s32(vdupq_n_s32(0), perm_samples[1], filter);
// Narrow and re-pack.
- int16x8_t sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
+ int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
// We halved the filter values so -1 from right shift.
return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}
-static inline void convolve8_horiz_4tap_neon_i8mm(
+static inline void convolve8_horiz_6tap_neon_i8mm(
const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
ptrdiff_t dst_stride, const int16_t *filter_x, int width, int height) {
- const int16x4_t x_filter = vld1_s16(filter_x + 2);
- // All 4-tap and bilinear filter values are even, so halve them to reduce
- // intermediate precision requirements.
- const int8x8_t filter = vshrn_n_s16(vcombine_s16(x_filter, vdup_n_s16(0)), 1);
+ // Filter values are even, so halve them to reduce intermediate precision.
+ const int8x8_t x_filter = vshrn_n_s16(vld1q_s16(filter_x), 1);
+ // Stagger taps f0..f5 = filter_x[1..6] (needs filter_x[0], [7] == 0):
+ // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
+ const int8x16_t filter =
+ vcombine_s8(vext_s8(x_filter, x_filter, 1), x_filter);
if (width == 4) {
- const uint8x16_t perm_tbl = vld1q_u8(kDotProdPermuteTbl);
+ const uint8x16_t perm_tbl = vld1q_u8(kMatMulPermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
- int16x4_t t0 = convolve4_4_h(s0, filter, perm_tbl);
- int16x4_t t1 = convolve4_4_h(s1, filter, perm_tbl);
- int16x4_t t2 = convolve4_4_h(s2, filter, perm_tbl);
- int16x4_t t3 = convolve4_4_h(s3, filter, perm_tbl);
+ int16x4_t t0 = convolve6_4_h(s0, filter, perm_tbl);
+ int16x4_t t1 = convolve6_4_h(s1, filter, perm_tbl);
+ int16x4_t t2 = convolve6_4_h(s2, filter, perm_tbl);
+ int16x4_t t3 = convolve6_4_h(s3, filter, perm_tbl);
// We halved the filter values so -1 from right shift.
uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS - 1);
uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS - 1);
@@ -202,7 +210,7 @@ static inline void convolve8_horiz_4tap_neon_i8mm(
height -= 4;
} while (height > 0);
} else {
- const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
+ const uint8x16x2_t perm_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
do {
int w = width;
@@ -212,10 +220,10 @@ static inline void convolve8_horiz_4tap_neon_i8mm(
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- uint8x8_t d0 = convolve4_8_h(s0, filter, perm_tbl);
- uint8x8_t d1 = convolve4_8_h(s1, filter, perm_tbl);
- uint8x8_t d2 = convolve4_8_h(s2, filter, perm_tbl);
- uint8x8_t d3 = convolve4_8_h(s3, filter, perm_tbl);
+ uint8x8_t d0 = convolve6_8_h(s0, filter, perm_tbl);
+ uint8x8_t d1 = convolve6_8_h(s1, filter, perm_tbl);
+ uint8x8_t d2 = convolve6_8_h(s2, filter, perm_tbl);
+ uint8x8_t d3 = convolve6_8_h(s3, filter, perm_tbl);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -249,8 +257,8 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
if (filter_taps == 2) {
convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride, filter_x, w,
h);
- } else if (filter_taps == 4) {
- convolve8_horiz_4tap_neon_i8mm(src + 2, src_stride, dst, dst_stride,
+ } else if (filter_taps <= 6) {
+ convolve8_horiz_6tap_neon_i8mm(src + 1, src_stride, dst, dst_stride,
filter_x, w, h);
} else {
convolve8_horiz_8tap_neon_i8mm(src, src_stride, dst, dst_stride, filter_x,