aom: Specialise 8-tap SVE2 HBD 2D dist_wtd_convolve on bitdepth

From 636add45516729ae9d7b86c5f1b78cc9326c0c5f Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Mon, 18 Mar 2024 17:38:33 +0000
Subject: [PATCH] Specialise 8-tap SVE2 HBD 2D dist_wtd_convolve on bitdepth

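Add a 12-bit specialised path for the 8-tap horizontal pass of
av1_highbd_dist_wtd_convolve_2d_sve2, giving up to 10% uplift for the
whole 2D convolution over the non-specialised version.

The horizontal-pass rounding offset is now derived from bd inside each
horizontal helper instead of being computed and passed in by the caller.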

Change-Id: Id0ac318885c25bcd8d12c32c931d68b2ca595203
---
 .../arm/highbd_compound_convolve_sve2.c       | 160 ++++++++++--------
 1 file changed, 93 insertions(+), 67 deletions(-)

diff --git a/av1/common/arm/highbd_compound_convolve_sve2.c b/av1/common/arm/highbd_compound_convolve_sve2.c
index baffc0edb..8d618fd34 100644
--- a/av1/common/arm/highbd_compound_convolve_sve2.c
+++ b/av1/common/arm/highbd_compound_convolve_sve2.c
@@ -849,42 +849,74 @@ void av1_highbd_dist_wtd_convolve_y_sve2(
   }
 }
 
-static INLINE uint16x8_t highbd_convolve8_8_2d_h(int16x8_t s0[8],
-                                                 int16x8_t filter,
-                                                 int64x2_t offset,
-                                                 int32x4_t shift) {
-  int64x2_t sum[8];
+static INLINE void highbd_12_dist_wtd_convolve_2d_horiz_8tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *x_filter_ptr) {
+  const int64x2_t offset = vdupq_n_s64(1 << (12 + FILTER_BITS - 2));
+  const int16x8_t filter = vld1q_s16(x_filter_ptr);

-  sum[0] = aom_sdotq_s16(offset, s0[0], filter);
-  sum[1] = aom_sdotq_s16(offset, s0[1], filter);
-  sum[2] = aom_sdotq_s16(offset, s0[2], filter);
-  sum[3] = aom_sdotq_s16(offset, s0[3], filter);
-  sum[4] = aom_sdotq_s16(offset, s0[4], filter);
-  sum[5] = aom_sdotq_s16(offset, s0[5], filter);
-  sum[6] = aom_sdotq_s16(offset, s0[6], filter);
-  sum[7] = aom_sdotq_s16(offset, s0[7], filter);
+  // 12-bit specialisation of the 8-tap horizontal pass. The vertical pass is
+  // only ever 8-tap or 4-tap, so height (im_h) % 4 == 3: loop over the block
+  // 4 rows at a time, then process the last 3 rows separately.

-  sum[0] = vpaddq_s64(sum[0], sum[1]);
-  sum[2] = vpaddq_s64(sum[2], sum[3]);
-  sum[4] = vpaddq_s64(sum[4], sum[5]);
-  sum[6] = vpaddq_s64(sum[6], sum[7]);
+  do {
+    const int16_t *s = (const int16_t *)src;
+    uint16_t *d = dst;
+    int w = width;

-  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
-  int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
+    do {
+      int16x8_t s0[8], s1[8], s2[8], s3[8];
+      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
+                   &s0[4], &s0[5], &s0[6], &s0[7]);
+      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
+                   &s1[4], &s1[5], &s1[6], &s1[7]);
+      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
+                   &s2[4], &s2[5], &s2[6], &s2[7]);
+      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
+                   &s3[4], &s3[5], &s3[6], &s3[7]);

-  sum0123 = vshlq_s32(sum0123, shift);
-  sum4567 = vshlq_s32(sum4567, shift);
+      uint16x8_t d0 = highbd_12_convolve8_8_x(s0, filter, offset);
+      uint16x8_t d1 = highbd_12_convolve8_8_x(s1, filter, offset);
+      uint16x8_t d2 = highbd_12_convolve8_8_x(s2, filter, offset);
+      uint16x8_t d3 = highbd_12_convolve8_8_x(s3, filter, offset);

-  return vcombine_u16(vqmovun_s32(sum0123), vqmovun_s32(sum4567));
+      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+      s += 8;
+      d += 8;
+      w -= 8;
+    } while (w != 0);
+    src += 4 * src_stride;
+    dst += 4 * dst_stride;
+    height -= 4;
+  } while (height > 4);
+
+  // Process the final 3 rows (height is 3 here since im_h % 4 == 3).
+  const int16_t *s = (const int16_t *)src;
+  do {
+    int16x8_t s0[8], s1[8], s2[8];
+    load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3], &s0[4],
+                 &s0[5], &s0[6], &s0[7]);
+    load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3], &s1[4],
+                 &s1[5], &s1[6], &s1[7]);
+    load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3], &s2[4],
+                 &s2[5], &s2[6], &s2[7]);
+
+    uint16x8_t d0 = highbd_12_convolve8_8_x(s0, filter, offset);
+    uint16x8_t d1 = highbd_12_convolve8_8_x(s1, filter, offset);
+    uint16x8_t d2 = highbd_12_convolve8_8_x(s2, filter, offset);
+
+    store_u16_8x3(dst, dst_stride, d0, d1, d2);
+    s += 8;
+    dst += 8;
+    width -= 8;
+  } while (width != 0);
 }
 
 static INLINE void highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
-    int width, int height, const int16_t *x_filter_ptr,
-    ConvolveParams *conv_params, const int offset) {
-  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
-
-  const int64x2_t offset_lo = vcombine_s64(vcreate_s64(offset), vdup_n_s64(0));
+    int width, int height, const int16_t *x_filter_ptr, const int bd) {
+  const int64x2_t offset = vdupq_n_s64(1 << (bd + FILTER_BITS - 2));
   const int16x8_t filter = vld1q_s16(x_filter_ptr);
 
   // We are only doing 8-tap and 4-tap vertical convolutions, therefore we know
@@ -907,10 +939,10 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
       load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                    &s3[4], &s3[5], &s3[6], &s3[7]);
 
-      uint16x8_t d0 = highbd_convolve8_8_2d_h(s0, filter, offset_lo, shift);
-      uint16x8_t d1 = highbd_convolve8_8_2d_h(s1, filter, offset_lo, shift);
-      uint16x8_t d2 = highbd_convolve8_8_2d_h(s2, filter, offset_lo, shift);
-      uint16x8_t d3 = highbd_convolve8_8_2d_h(s3, filter, offset_lo, shift);
+      uint16x8_t d0 = highbd_convolve8_8_x(s0, filter, offset);
+      uint16x8_t d1 = highbd_convolve8_8_x(s1, filter, offset);
+      uint16x8_t d2 = highbd_convolve8_8_x(s2, filter, offset);
+      uint16x8_t d3 = highbd_convolve8_8_x(s3, filter, offset);
 
       store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -934,9 +966,9 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
     load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3], &s2[4],
                  &s2[5], &s2[6], &s2[7]);
 
-    uint16x8_t d0 = highbd_convolve8_8_2d_h(s0, filter, offset_lo, shift);
-    uint16x8_t d1 = highbd_convolve8_8_2d_h(s1, filter, offset_lo, shift);
-    uint16x8_t d2 = highbd_convolve8_8_2d_h(s2, filter, offset_lo, shift);
+    uint16x8_t d0 = highbd_convolve8_8_x(s0, filter, offset);
+    uint16x8_t d1 = highbd_convolve8_8_x(s1, filter, offset);
+    uint16x8_t d2 = highbd_convolve8_8_x(s2, filter, offset);
 
     store_u16_8x3(dst, dst_stride, d0, d1, d2);
     s += 8;
@@ -984,8 +1016,9 @@ static INLINE uint16x8_t highbd_convolve4_8_2d_h(int16x8_t s0[4],
 static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
     int width, int height, const int16_t *x_filter_ptr,
-    ConvolveParams *conv_params, const int offset) {
-  const int64x2_t offset_vec = vdupq_n_s64(offset);
+    ConvolveParams *conv_params, const int bd) {
+  const int64x2_t offset = vdupq_n_s64((1 << (bd + FILTER_BITS - 1)) +
+                                       (1 << (conv_params->round_0 - 1)));
   const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
   const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
   const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));
@@ -1004,13 +1037,13 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
       load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
 
       uint16x4_t d0 =
-          highbd_convolve4_4_2d_h(s0, filter, offset_vec, shift, permute_tbl);
+          highbd_convolve4_4_2d_h(s0, filter, offset, shift, permute_tbl);
       uint16x4_t d1 =
-          highbd_convolve4_4_2d_h(s1, filter, offset_vec, shift, permute_tbl);
+          highbd_convolve4_4_2d_h(s1, filter, offset, shift, permute_tbl);
       uint16x4_t d2 =
-          highbd_convolve4_4_2d_h(s2, filter, offset_vec, shift, permute_tbl);
+          highbd_convolve4_4_2d_h(s2, filter, offset, shift, permute_tbl);
       uint16x4_t d3 =
-          highbd_convolve4_4_2d_h(s3, filter, offset_vec, shift, permute_tbl);
+          highbd_convolve4_4_2d_h(s3, filter, offset, shift, permute_tbl);
 
       store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
 
@@ -1024,11 +1057,11 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
     load_s16_8x3(s, src_stride, &s0, &s1, &s2);
 
     uint16x4_t d0 =
-        highbd_convolve4_4_2d_h(s0, filter, offset_vec, shift, permute_tbl);
+        highbd_convolve4_4_2d_h(s0, filter, offset, shift, permute_tbl);
     uint16x4_t d1 =
-        highbd_convolve4_4_2d_h(s1, filter, offset_vec, shift, permute_tbl);
+        highbd_convolve4_4_2d_h(s1, filter, offset, shift, permute_tbl);
     uint16x4_t d2 =
-        highbd_convolve4_4_2d_h(s2, filter, offset_vec, shift, permute_tbl);
+        highbd_convolve4_4_2d_h(s2, filter, offset, shift, permute_tbl);
 
     store_u16_4x3(dst, dst_stride, d0, d1, d2);
 
@@ -1047,14 +1080,10 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
         load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
         load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
 
-        uint16x8_t d0 =
-            highbd_convolve4_8_2d_h(s0, filter, offset_vec, shift, idx);
-        uint16x8_t d1 =
-            highbd_convolve4_8_2d_h(s1, filter, offset_vec, shift, idx);
-        uint16x8_t d2 =
-            highbd_convolve4_8_2d_h(s2, filter, offset_vec, shift, idx);
-        uint16x8_t d3 =
-            highbd_convolve4_8_2d_h(s3, filter, offset_vec, shift, idx);
+        uint16x8_t d0 = highbd_convolve4_8_2d_h(s0, filter, offset, shift, idx);
+        uint16x8_t d1 = highbd_convolve4_8_2d_h(s1, filter, offset, shift, idx);
+        uint16x8_t d2 = highbd_convolve4_8_2d_h(s2, filter, offset, shift, idx);
+        uint16x8_t d3 = highbd_convolve4_8_2d_h(s3, filter, offset, shift, idx);
 
         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -1076,12 +1105,9 @@ static INLINE void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
       load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
       load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
 
-      uint16x8_t d0 =
-          highbd_convolve4_8_2d_h(s0, filter, offset_vec, shift, idx);
-      uint16x8_t d1 =
-          highbd_convolve4_8_2d_h(s1, filter, offset_vec, shift, idx);
-      uint16x8_t d2 =
-          highbd_convolve4_8_2d_h(s2, filter, offset_vec, shift, idx);
+      uint16x8_t d0 = highbd_convolve4_8_2d_h(s0, filter, offset, shift, idx);
+      uint16x8_t d1 = highbd_convolve4_8_2d_h(s1, filter, offset, shift, idx);
+      uint16x8_t d2 = highbd_convolve4_8_2d_h(s2, filter, offset, shift, idx);
 
       store_u16_8x3(dst, dst_stride, d0, d1, d2);
 
@@ -1410,10 +1436,6 @@ void av1_highbd_dist_wtd_convolve_2d_sve2(
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = clamped_y_taps / 2 - 1;
   const int horiz_offset = clamped_x_taps / 2 - 1;
-  // The extra shim of (1 << (conv_params->round_0 - 1)) allows us to use a
-  // faster non-rounding non-saturating left shift.
-  const int round_offset_conv_x =
-      (1 << (bd + FILTER_BITS - 1)) + (1 << (conv_params->round_0 - 1));
   const int y_offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
   const int round_offset_conv_y = (1 << y_offset_bits);
 
@@ -1425,13 +1447,17 @@ void av1_highbd_dist_wtd_convolve_2d_sve2(
       filter_params_y, subpel_y_qn & SUBPEL_MASK);
 
   if (x_filter_taps <= 4) {
-    highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
-        src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr,
-        conv_params, round_offset_conv_x);
+    highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(src_ptr, src_stride, im_block,
+                                                im_stride, w, im_h,
+                                                x_filter_ptr, conv_params, bd);
   } else {
-    highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
-        src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr,
-        conv_params, round_offset_conv_x);
+    if (bd == 12) {
+      highbd_12_dist_wtd_convolve_2d_horiz_8tap_sve2(
+          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr);
+    } else {
+      highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
+          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr, bd);
+    }
   }
 
   if (conv_params->do_average) {