aom: Specialise 4-tap SVE2 HBD 2D dist_wtd_convolve on bitdepth

From b218391386aded8e82bdb44f34ad316a16639925 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Tue, 19 Mar 2024 16:39:54 +0000
Subject: [PATCH] Specialise 4-tap SVE2 HBD 2D dist_wtd_convolve on bitdepth

Add a 12-bit specialised path for the 4-tap horizontal pass of
av1_highbd_dist_wtd_convolve_2d_sve2, giving up to 10% uplift for the
whole 2D convolution over the non-specialised version.

Change-Id: I997cfa3945d3920630311d8409ff32ede4f050e0
---
 .../arm/highbd_compound_convolve_sve2.c       | 172 ++++++++++++------
 1 file changed, 112 insertions(+), 60 deletions(-)

diff --git a/av1/common/arm/highbd_compound_convolve_sve2.c b/av1/common/arm/highbd_compound_convolve_sve2.c
index 8d618fd34..1d6c9b4fa 100644
--- a/av1/common/arm/highbd_compound_convolve_sve2.c
+++ b/av1/common/arm/highbd_compound_convolve_sve2.c
@@ -977,49 +977,105 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
   } while (width != 0);
 }
 
-static INLINE uint16x4_t highbd_convolve4_4_2d_h(int16x8_t s0, int16x8_t filter,
-                                                 int64x2_t offset,
-                                                 int32x4_t shift,
-                                                 uint16x8x2_t permute_tbl) {
-  int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
-  int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);
+static INLINE void highbd_12_dist_wtd_convolve_2d_horiz_4tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *x_filter_ptr) {
+  const int64x2_t offset = vdupq_n_s64(1 << (12 + FILTER_BITS - 1));
+  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
+  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
 
-  int64x2_t sum01 = aom_svdot_lane_s16(offset, permuted_samples0, filter, 0);
-  int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);
+  // We are only doing 8-tap and 4-tap vertical convolutions, therefore we know
+  // that im_h % 4 = 3, so we can do the loop across the whole block 4 rows at
+  // a time and then process the last 3 rows separately.
 
-  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
-  sum0123 = vshlq_s32(sum0123, shift);
+  if (width == 4) {
+    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);
 
-  return vqmovun_s32(sum0123);
-}
+    const int16_t *s = (const int16_t *)(src);
 
-static INLINE uint16x8_t highbd_convolve4_8_2d_h(int16x8_t s0[4],
-                                                 int16x8_t filter,
-                                                 int64x2_t offset,
-                                                 int32x4_t shift,
-                                                 uint16x8_t tbl) {
-  int64x2_t sum04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
-  int64x2_t sum15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
-  int64x2_t sum26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
-  int64x2_t sum37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);
+    do {
+      int16x8_t s0, s1, s2, s3;
+      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-  int32x4_t sum0415 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
-  sum0415 = vshlq_s32(sum0415, shift);
+      uint16x4_t d0 = highbd_12_convolve4_4_x(s0, filter, offset, permute_tbl);
+      uint16x4_t d1 = highbd_12_convolve4_4_x(s1, filter, offset, permute_tbl);
+      uint16x4_t d2 = highbd_12_convolve4_4_x(s2, filter, offset, permute_tbl);
+      uint16x4_t d3 = highbd_12_convolve4_4_x(s3, filter, offset, permute_tbl);
 
-  int32x4_t sum2637 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
-  sum2637 = vshlq_s32(sum2637, shift);
+      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
 
-  uint16x8_t res = vcombine_u16(vqmovun_s32(sum0415), vqmovun_s32(sum2637));
-  return aom_tbl_u16(res, tbl);
+      s += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height > 4);
+
+    // Process final 3 rows.
+    int16x8_t s0, s1, s2;
+    load_s16_8x3(s, src_stride, &s0, &s1, &s2);
+
+    uint16x4_t d0 = highbd_12_convolve4_4_x(s0, filter, offset, permute_tbl);
+    uint16x4_t d1 = highbd_12_convolve4_4_x(s1, filter, offset, permute_tbl);
+    uint16x4_t d2 = highbd_12_convolve4_4_x(s2, filter, offset, permute_tbl);
+
+    store_u16_4x3(dst, dst_stride, d0, d1, d2);
+
+  } else {
+    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
+
+    do {
+      const int16_t *s = (const int16_t *)(src);
+      uint16_t *d = dst;
+      int w = width;
+
+      do {
+        int16x8_t s0[4], s1[4], s2[4], s3[4];
+        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
+
+        uint16x8_t d0 = highbd_12_convolve4_8_x(s0, filter, offset, idx);
+        uint16x8_t d1 = highbd_12_convolve4_8_x(s1, filter, offset, idx);
+        uint16x8_t d2 = highbd_12_convolve4_8_x(s2, filter, offset, idx);
+        uint16x8_t d3 = highbd_12_convolve4_8_x(s3, filter, offset, idx);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        w -= 8;
+      } while (w != 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height > 4);
+
+    // Process final 3 rows.
+    const int16_t *s = (const int16_t *)(src);
+
+    do {
+      int16x8_t s0[4], s1[4], s2[4];
+      load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+      load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+      load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+
+      uint16x8_t d0 = highbd_12_convolve4_8_x(s0, filter, offset, idx);
+      uint16x8_t d1 = highbd_12_convolve4_8_x(s1, filter, offset, idx);
+      uint16x8_t d2 = highbd_12_convolve4_8_x(s2, filter, offset, idx);
+
+      store_u16_8x3(dst, dst_stride, d0, d1, d2);
+
+      s += 8;
+      dst += 8;
+      width -= 8;
+    } while (width != 0);
+  }
 }
 
 static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
-    int width, int height, const int16_t *x_filter_ptr,
-    ConvolveParams *conv_params, const int bd) {
-  const int64x2_t offset = vdupq_n_s64((1 << (bd + FILTER_BITS - 1)) +
-                                       (1 << (conv_params->round_0 - 1)));
-  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
+    int width, int height, const int16_t *x_filter_ptr, const int bd) {
+  const int64x2_t offset = vdupq_n_s64(1 << (bd + FILTER_BITS - 1));
   const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
   const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
 
@@ -1036,14 +1092,10 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
       int16x8_t s0, s1, s2, s3;
       load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
 
-      uint16x4_t d0 =
-          highbd_convolve4_4_2d_h(s0, filter, offset, shift, permute_tbl);
-      uint16x4_t d1 =
-          highbd_convolve4_4_2d_h(s1, filter, offset, shift, permute_tbl);
-      uint16x4_t d2 =
-          highbd_convolve4_4_2d_h(s2, filter, offset, shift, permute_tbl);
-      uint16x4_t d3 =
-          highbd_convolve4_4_2d_h(s3, filter, offset, shift, permute_tbl);
+      uint16x4_t d0 = highbd_convolve4_4_x(s0, filter, offset, permute_tbl);
+      uint16x4_t d1 = highbd_convolve4_4_x(s1, filter, offset, permute_tbl);
+      uint16x4_t d2 = highbd_convolve4_4_x(s2, filter, offset, permute_tbl);
+      uint16x4_t d3 = highbd_convolve4_4_x(s3, filter, offset, permute_tbl);
 
       store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
 
@@ -1056,15 +1108,11 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
     int16x8_t s0, s1, s2;
     load_s16_8x3(s, src_stride, &s0, &s1, &s2);
 
-    uint16x4_t d0 =
-        highbd_convolve4_4_2d_h(s0, filter, offset, shift, permute_tbl);
-    uint16x4_t d1 =
-        highbd_convolve4_4_2d_h(s1, filter, offset, shift, permute_tbl);
-    uint16x4_t d2 =
-        highbd_convolve4_4_2d_h(s2, filter, offset, shift, permute_tbl);
+    uint16x4_t d0 = highbd_convolve4_4_x(s0, filter, offset, permute_tbl);
+    uint16x4_t d1 = highbd_convolve4_4_x(s1, filter, offset, permute_tbl);
+    uint16x4_t d2 = highbd_convolve4_4_x(s2, filter, offset, permute_tbl);
 
     store_u16_4x3(dst, dst_stride, d0, d1, d2);
-
   } else {
     uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
 
@@ -1080,10 +1128,10 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
         load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
         load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
 
-        uint16x8_t d0 = highbd_convolve4_8_2d_h(s0, filter, offset, shift, idx);
-        uint16x8_t d1 = highbd_convolve4_8_2d_h(s1, filter, offset, shift, idx);
-        uint16x8_t d2 = highbd_convolve4_8_2d_h(s2, filter, offset, shift, idx);
-        uint16x8_t d3 = highbd_convolve4_8_2d_h(s3, filter, offset, shift, idx);
+        uint16x8_t d0 = highbd_convolve4_8_x(s0, filter, offset, idx);
+        uint16x8_t d1 = highbd_convolve4_8_x(s1, filter, offset, idx);
+        uint16x8_t d2 = highbd_convolve4_8_x(s2, filter, offset, idx);
+        uint16x8_t d3 = highbd_convolve4_8_x(s3, filter, offset, idx);
 
         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -1105,9 +1153,9 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
       load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
       load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
 
-      uint16x8_t d0 = highbd_convolve4_8_2d_h(s0, filter, offset, shift, idx);
-      uint16x8_t d1 = highbd_convolve4_8_2d_h(s1, filter, offset, shift, idx);
-      uint16x8_t d2 = highbd_convolve4_8_2d_h(s2, filter, offset, shift, idx);
+      uint16x8_t d0 = highbd_convolve4_8_x(s0, filter, offset, idx);
+      uint16x8_t d1 = highbd_convolve4_8_x(s1, filter, offset, idx);
+      uint16x8_t d2 = highbd_convolve4_8_x(s2, filter, offset, idx);
 
       store_u16_8x3(dst, dst_stride, d0, d1, d2);
 
@@ -1446,14 +1494,18 @@ void av1_highbd_dist_wtd_convolve_2d_sve2(
   const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
       filter_params_y, subpel_y_qn & SUBPEL_MASK);
 
-  if (x_filter_taps <= 4) {
-    highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(src_ptr, src_stride, im_block,
-                                                im_stride, w, im_h,
-                                                x_filter_ptr, conv_params, bd);
-  } else {
-    if (bd == 12) {
+  if (bd == 12) {
+    if (x_filter_taps <= 4) {
+      highbd_12_dist_wtd_convolve_2d_horiz_4tap_sve2(
+          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr);
+    } else {
       highbd_12_dist_wtd_convolve_2d_horiz_8tap_sve2(
           src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr);
+    }
+  } else {
+    if (x_filter_taps <= 4) {
+      highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
+          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr, bd);
     } else {
       highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
           src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr, bd);