From 6e14f9069e58c9abc7ec4277d6e312116ac65b64 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Wed, 7 Aug 2024 11:57:21 +0100
Subject: [PATCH] Use Arm Neon USMMLA 6-tap impl. for 4-tap convolve_x_sr
The 6-tap USMMLA implementation of convolve_x_sr uses the same number
of instructions as the 4-tap USDOT implementation, so delete the
4-tap USDOT path and use the 6-tap USMMLA implementation in both
cases.
Change-Id: Ic390abea63047af623a2ab232532b6b36360293e
---
av1/common/arm/convolve_neon_i8mm.c | 118 +++++++---------------------
1 file changed, 29 insertions(+), 89 deletions(-)
diff --git a/av1/common/arm/convolve_neon_i8mm.c b/av1/common/arm/convolve_neon_i8mm.c
index dd4a34e0b..acd912e57 100644
--- a/av1/common/arm/convolve_neon_i8mm.c
+++ b/av1/common/arm/convolve_neon_i8mm.c
@@ -213,6 +213,22 @@ static inline void convolve_x_sr_8tap_neon_i8mm(
} while (height != 0);
}
+static inline int16x4_t convolve6_4_x(uint8x16_t samples,
+ const int8x16_t filter,
+ const uint8x16_t permute_tbl,
+ const int32x4_t horiz_const) {
+ // Permute samples ready for matrix multiply.
+ // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
+ uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
+
+ // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
+ // (filter), destructively accumulating into the destination register.
+ int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples, filter);
+
+ // Further narrowing and packing is performed by the caller.
+ return vmovn_s32(sum);
+}
+
static inline uint8x8_t convolve6_8_x(uint8x16_t samples,
const int8x16_t filter,
const uint8x16x2_t permute_tbl,
@@ -244,86 +260,16 @@ static inline void convolve_x_sr_6tap_neon_i8mm(
const int8x16_t x_filter =
vcombine_s8(vext_s8(x_filter_s8, x_filter_s8, 1), x_filter_s8);
- const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
- do {
- const uint8_t *s = src;
- uint8_t *d = dst;
- int w = width;
-
- do {
- uint8x16_t s0, s1, s2, s3;
- load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
-
- uint8x8_t d0 = convolve6_8_x(s0, x_filter, permute_tbl, horiz_const);
- uint8x8_t d1 = convolve6_8_x(s1, x_filter, permute_tbl, horiz_const);
- uint8x8_t d2 = convolve6_8_x(s2, x_filter, permute_tbl, horiz_const);
- uint8x8_t d3 = convolve6_8_x(s3, x_filter, permute_tbl, horiz_const);
-
- store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
-
- s += 8;
- d += 8;
- w -= 8;
- } while (w != 0);
- src += 4 * src_stride;
- dst += 4 * dst_stride;
- height -= 4;
- } while (height != 0);
-}
-
-static inline int16x4_t convolve4_4_x(const uint8x16_t samples,
- const int8x8_t filters,
- const uint8x16_t permute_tbl,
- const int32x4_t horiz_const) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- uint8x16_t perm_samples = vqtbl1q_u8(samples, permute_tbl);
-
- int32x4_t sum = vusdotq_lane_s32(horiz_const, perm_samples, filters, 0);
-
- // Further narrowing and packing is performed by the caller.
- return vmovn_s32(sum);
-}
-
-static inline uint8x8_t convolve4_8_x(const uint8x16_t samples,
- const int8x8_t filters,
- const uint8x16x2_t permute_tbl,
- const int32x4_t horiz_const) {
- // Permute samples ready for dot product.
- // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
- // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
- uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples, permute_tbl.val[0]),
- vqtbl1q_u8(samples, permute_tbl.val[1]) };
-
- int32x4_t acc = horiz_const;
- int32x4_t sum0123 = vusdotq_lane_s32(acc, perm_samples[0], filters, 0);
- int32x4_t sum4567 = vusdotq_lane_s32(acc, perm_samples[1], filters, 0);
-
- // Narrow and re-pack.
- int16x8_t sum = vcombine_s16(vmovn_s32(sum0123), vmovn_s32(sum4567));
- // We halved the filter values so -1 from right shift.
- return vqrshrun_n_s16(sum, FILTER_BITS - 1);
-}
-
-static inline void convolve_x_sr_4tap_neon_i8mm(
- const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
- ptrdiff_t dst_stride, int width, int height, const int16_t *filter_x,
- const int32x4_t horiz_const) {
- const int16x4_t x_filter = vld1_s16(filter_x + 2);
- // All 4-tap and bilinear filter values are even, so halve them to reduce
- // intermediate precision requirements.
- const int8x8_t filter = vshrn_n_s16(vcombine_s16(x_filter, vdup_n_s16(0)), 1);
-
if (width == 4) {
- const uint8x16_t perm_tbl = vld1q_u8(kDotProdPermuteTbl);
+ const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);
- int16x4_t t0 = convolve4_4_x(s0, filter, perm_tbl, horiz_const);
- int16x4_t t1 = convolve4_4_x(s1, filter, perm_tbl, horiz_const);
- int16x4_t t2 = convolve4_4_x(s2, filter, perm_tbl, horiz_const);
- int16x4_t t3 = convolve4_4_x(s3, filter, perm_tbl, horiz_const);
+ int16x4_t t0 = convolve6_4_x(s0, x_filter, permute_tbl, horiz_const);
+ int16x4_t t1 = convolve6_4_x(s1, x_filter, permute_tbl, horiz_const);
+ int16x4_t t2 = convolve6_4_x(s2, x_filter, permute_tbl, horiz_const);
+ int16x4_t t3 = convolve6_4_x(s3, x_filter, permute_tbl, horiz_const);
// We halved the filter values so -1 from right shift.
uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS - 1);
uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS - 1);
@@ -336,20 +282,20 @@ static inline void convolve_x_sr_4tap_neon_i8mm(
height -= 4;
} while (height != 0);
} else {
- const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
-
+ const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);
do {
- int w = width;
const uint8_t *s = src;
uint8_t *d = dst;
+ int w = width;
+
do {
uint8x16_t s0, s1, s2, s3;
load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);
- uint8x8_t d0 = convolve4_8_x(s0, filter, perm_tbl, horiz_const);
- uint8x8_t d1 = convolve4_8_x(s1, filter, perm_tbl, horiz_const);
- uint8x8_t d2 = convolve4_8_x(s2, filter, perm_tbl, horiz_const);
- uint8x8_t d3 = convolve4_8_x(s3, filter, perm_tbl, horiz_const);
+ uint8x8_t d0 = convolve6_8_x(s0, x_filter, permute_tbl, horiz_const);
+ uint8x8_t d1 = convolve6_8_x(s1, x_filter, permute_tbl, horiz_const);
+ uint8x8_t d2 = convolve6_8_x(s2, x_filter, permute_tbl, horiz_const);
+ uint8x8_t d3 = convolve6_8_x(s3, x_filter, permute_tbl, horiz_const);
store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
@@ -390,7 +336,7 @@ void av1_convolve_x_sr_neon_i8mm(const uint8_t *src, int src_stride,
// Halve the total because we will halve the filter values.
 const int32x4_t horiz_const = vdupq_n_s32((1 << (ROUND0_BITS - 1)) / 2);
- if (filter_taps == 6) {
+ if (filter_taps <= 6) {
convolve_x_sr_6tap_neon_i8mm(src + 1, src_stride, dst, dst_stride, w, h,
x_filter_ptr, horiz_const);
return;
@@ -402,12 +348,6 @@ void av1_convolve_x_sr_neon_i8mm(const uint8_t *src, int src_stride,
return;
}
- if (filter_taps <= 4) {
- convolve_x_sr_4tap_neon_i8mm(src + 2, src_stride, dst, dst_stride, w, h,
- x_filter_ptr, horiz_const);
- return;
- }
-
convolve_x_sr_8tap_neon_i8mm(src, src_stride, dst, dst_stride, w, h,
x_filter_ptr, horiz_const);
}