aom: Specialize 8-tap HBD dist_wtd_convolve_y_sve2 on bitdepth

From 5da47ec716e9c320afab896b6f41bfd94da7914c Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Mon, 11 Mar 2024 14:10:18 +0000
Subject: [PATCH] Specialize 8-tap HBD dist_wtd_convolve_y_sve2 on bitdepth

The rounding value at the end of the convolution depends on the bitdepth
(8 or 10, versus 12). Add 2 specialized versions of the function, so that
we know the rounding value at compile time and therefore use only one
instruction instead of two to perform the final rounding and narrowing
step. This gives up to 10% uplift over the non-specialized version.

Change-Id: Id30c3ba63567ff92ff0a79aaf94c091e0d807f02
---
 .../arm/highbd_compound_convolve_sve2.c       | 261 ++++++++++++++----
 1 file changed, 212 insertions(+), 49 deletions(-)

diff --git a/av1/common/arm/highbd_compound_convolve_sve2.c b/av1/common/arm/highbd_compound_convolve_sve2.c
index f500110f1..d9ea83d55 100644
--- a/av1/common/arm/highbd_compound_convolve_sve2.c
+++ b/av1/common/arm/highbd_compound_convolve_sve2.c
@@ -438,11 +438,186 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
   }
 }
 
+static INLINE uint16x4_t highbd_12_convolve8_4_y(int16x8_t samples_lo[2],
+                                                 int16x8_t samples_hi[2],
+                                                 int16x8_t filter,
+                                                 int64x2_t offset) {
+  int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
+  sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
+
+  int64x2_t sum23 = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
+  sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
+
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
+
+  return vqrshrun_n_s32(sum0123, ROUND0_BITS + 2);
+}
+
+static INLINE uint16x8_t highbd_12_convolve8_8_y(int16x8_t samples_lo[4],
+                                                 int16x8_t samples_hi[4],
+                                                 int16x8_t filter,
+                                                 int64x2_t offset) {
+  int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
+  sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
+
+  int64x2_t sum23 = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
+  sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
+
+  int64x2_t sum45 = aom_svdot_lane_s16(offset, samples_lo[2], filter, 0);
+  sum45 = aom_svdot_lane_s16(sum45, samples_hi[2], filter, 1);
+
+  int64x2_t sum67 = aom_svdot_lane_s16(offset, samples_lo[3], filter, 0);
+  sum67 = aom_svdot_lane_s16(sum67, samples_hi[3], filter, 1);
+
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
+  int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum45), vmovn_s64(sum67));
+
+  return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS + 2),
+                      vqrshrun_n_s32(sum4567, ROUND0_BITS + 2));
+}
+
+static INLINE void highbd_12_dist_wtd_convolve_y_8tap_sve2(
+    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
+    int width, int height, const int16_t *y_filter_ptr) {
+  const int64x2_t offset =
+      vdupq_n_s64((1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1)));
+  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
+
+  uint16x8x3_t merge_block_tbl = vld1q_u16_x3(kDotProdMergeBlockTbl);
+  // Scale indices by size of the true vector length to avoid reading from an
+  // 'undefined' portion of a vector on a system with SVE vectors > 128-bit.
+  uint16x8_t correction0 =
+      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000000000000ULL));
+  merge_block_tbl.val[0] = vaddq_u16(merge_block_tbl.val[0], correction0);
+  uint16x8_t correction1 =
+      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100000000ULL));
+  merge_block_tbl.val[1] = vaddq_u16(merge_block_tbl.val[1], correction1);
+
+  uint16x8_t correction2 =
+      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100010000ULL));
+  merge_block_tbl.val[2] = vaddq_u16(merge_block_tbl.val[2], correction2);
+
+  if (width == 4) {
+    int16_t *s = (int16_t *)src;
+    int16x4_t s0, s1, s2, s3, s4, s5, s6;
+    load_s16_4x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
+    s += 7 * src_stride;
+
+    // This operation combines a conventional transpose and the sample permute
+    // required before computing the dot product.
+    int16x8_t s0123[2], s1234[2], s2345[2], s3456[2];
+    transpose_concat_4x4(s0, s1, s2, s3, s0123);
+    transpose_concat_4x4(s1, s2, s3, s4, s1234);
+    transpose_concat_4x4(s2, s3, s4, s5, s2345);
+    transpose_concat_4x4(s3, s4, s5, s6, s3456);
+
+    do {
+      int16x4_t s7, s8, s9, s10;
+      load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);
+
+      int16x8_t s4567[2], s5678[2], s6789[2], s789A[2];
+      // Transpose and shuffle the 4 lines that were loaded.
+      transpose_concat_4x4(s7, s8, s9, s10, s789A);
+
+      // Merge new data into block from previous iteration.
+      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
+      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
+      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[2], s6789);
+
+      uint16x4_t d0 = highbd_12_convolve8_4_y(s0123, s4567, y_filter, offset);
+      uint16x4_t d1 = highbd_12_convolve8_4_y(s1234, s5678, y_filter, offset);
+      uint16x4_t d2 = highbd_12_convolve8_4_y(s2345, s6789, y_filter, offset);
+      uint16x4_t d3 = highbd_12_convolve8_4_y(s3456, s789A, y_filter, offset);
+
+      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
+
+      // Prepare block for next iteration - re-using as much as possible.
+      // Shuffle everything up four rows.
+      s0123[0] = s4567[0];
+      s0123[1] = s4567[1];
+      s1234[0] = s5678[0];
+      s1234[1] = s5678[1];
+      s2345[0] = s6789[0];
+      s2345[1] = s6789[1];
+      s3456[0] = s789A[0];
+      s3456[1] = s789A[1];
+
+      s += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+  } else {
+    do {
+      int h = height;
+      int16_t *s = (int16_t *)src;
+      uint16_t *d = dst;
+
+      int16x8_t s0, s1, s2, s3, s4, s5, s6;
+      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
+      s += 7 * src_stride;
+
+      // This operation combines a conventional transpose and the sample permute
+      // required before computing the dot product.
+      int16x8_t s0123[4], s1234[4], s2345[4], s3456[4];
+      transpose_concat_8x4(s0, s1, s2, s3, s0123);
+      transpose_concat_8x4(s1, s2, s3, s4, s1234);
+      transpose_concat_8x4(s2, s3, s4, s5, s2345);
+      transpose_concat_8x4(s3, s4, s5, s6, s3456);
+
+      do {
+        int16x8_t s7, s8, s9, s10;
+        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
+        int16x8_t s4567[4], s5678[4], s6789[4], s789A[4];
+
+        // Transpose and shuffle the 4 lines that were loaded.
+        transpose_concat_8x4(s7, s8, s9, s10, s789A);
+
+        // Merge new data into block from previous iteration.
+        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
+        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
+        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[2], s6789);
+
+        uint16x8_t d0 = highbd_12_convolve8_8_y(s0123, s4567, y_filter, offset);
+        uint16x8_t d1 = highbd_12_convolve8_8_y(s1234, s5678, y_filter, offset);
+        uint16x8_t d2 = highbd_12_convolve8_8_y(s2345, s6789, y_filter, offset);
+        uint16x8_t d3 = highbd_12_convolve8_8_y(s3456, s789A, y_filter, offset);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        // Prepare block for next iteration - re-using as much as possible.
+        // Shuffle everything up four rows.
+        s0123[0] = s4567[0];
+        s0123[1] = s4567[1];
+        s0123[2] = s4567[2];
+        s0123[3] = s4567[3];
+        s1234[0] = s5678[0];
+        s1234[1] = s5678[1];
+        s1234[2] = s5678[2];
+        s1234[3] = s5678[3];
+        s2345[0] = s6789[0];
+        s2345[1] = s6789[1];
+        s2345[2] = s6789[2];
+        s2345[3] = s6789[3];
+        s3456[0] = s789A[0];
+        s3456[1] = s789A[1];
+        s3456[2] = s789A[2];
+        s3456[3] = s789A[3];
+
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        h -= 4;
+      } while (h != 0);
+      src += 8;
+      dst += 8;
+      width -= 8;
+    } while (width != 0);
+  }
+}
+
 static INLINE uint16x4_t highbd_convolve8_4_y(int16x8_t samples_lo[2],
                                               int16x8_t samples_hi[2],
                                               int16x8_t filter,
-                                              int64x2_t offset,
-                                              int32x4_t shift) {
+                                              int64x2_t offset) {
   int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
   sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
 
@@ -450,16 +625,14 @@ static INLINE uint16x4_t highbd_convolve8_4_y(int16x8_t samples_lo[2],
   sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
 
   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
-  sum0123 = vshlq_s32(sum0123, shift);
 
-  return vqmovun_s32(sum0123);
+  return vqrshrun_n_s32(sum0123, ROUND0_BITS);
 }
 
 static INLINE uint16x8_t highbd_convolve8_8_y(int16x8_t samples_lo[4],
                                               int16x8_t samples_hi[4],
                                               int16x8_t filter,
-                                              int64x2_t offset,
-                                              int32x4_t shift) {
+                                              int64x2_t offset) {
   int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
   sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
 
@@ -473,21 +646,18 @@ static INLINE uint16x8_t highbd_convolve8_8_y(int16x8_t samples_lo[4],
   sum67 = aom_svdot_lane_s16(sum67, samples_hi[3], filter, 1);
 
   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
-  sum0123 = vshlq_s32(sum0123, shift);
-
   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum45), vmovn_s64(sum67));
-  sum4567 = vshlq_s32(sum4567, shift);
 
-  return vcombine_u16(vqmovun_s32(sum0123), vqmovun_s32(sum4567));
+  return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS),
+                      vqrshrun_n_s32(sum4567, ROUND0_BITS));
 }
 
 static INLINE void highbd_dist_wtd_convolve_y_8tap_sve2(
     const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
-    int width, int height, const int16_t *y_filter_ptr,
-    ConvolveParams *conv_params, int offset) {
+    int width, int height, const int16_t *y_filter_ptr, const int bd) {
+  const int64x2_t offset =
+      vdupq_n_s64((1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1)));
   const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
-  const int64x2_t offset_s64 = vdupq_n_s64(offset);
-  const int32x4_t shift = vdupq_n_s32(-conv_params->round_0);
 
   uint16x8x3_t merge_block_tbl = vld1q_u16_x3(kDotProdMergeBlockTbl);
   // Scale indices by size of the true vector length to avoid reading from an
@@ -530,14 +700,10 @@ static INLINE void highbd_dist_wtd_convolve_y_8tap_sve2(
       aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
       aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[2], s6789);
 
-      uint16x4_t d0 =
-          highbd_convolve8_4_y(s0123, s4567, y_filter, offset_s64, shift);
-      uint16x4_t d1 =
-          highbd_convolve8_4_y(s1234, s5678, y_filter, offset_s64, shift);
-      uint16x4_t d2 =
-          highbd_convolve8_4_y(s2345, s6789, y_filter, offset_s64, shift);
-      uint16x4_t d3 =
-          highbd_convolve8_4_y(s3456, s789A, y_filter, offset_s64, shift);
+      uint16x4_t d0 = highbd_convolve8_4_y(s0123, s4567, y_filter, offset);
+      uint16x4_t d1 = highbd_convolve8_4_y(s1234, s5678, y_filter, offset);
+      uint16x4_t d2 = highbd_convolve8_4_y(s2345, s6789, y_filter, offset);
+      uint16x4_t d3 = highbd_convolve8_4_y(s3456, s789A, y_filter, offset);
 
       store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
 
@@ -587,14 +753,10 @@ static INLINE void highbd_dist_wtd_convolve_y_8tap_sve2(
         aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
         aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[2], s6789);
 
-        uint16x8_t d0 =
-            highbd_convolve8_8_y(s0123, s4567, y_filter, offset_s64, shift);
-        uint16x8_t d1 =
-            highbd_convolve8_8_y(s1234, s5678, y_filter, offset_s64, shift);
-        uint16x8_t d2 =
-            highbd_convolve8_8_y(s2345, s6789, y_filter, offset_s64, shift);
-        uint16x8_t d3 =
-            highbd_convolve8_8_y(s3456, s789A, y_filter, offset_s64, shift);
+        uint16x8_t d0 = highbd_convolve8_8_y(s0123, s4567, y_filter, offset);
+        uint16x8_t d1 = highbd_convolve8_8_y(s1234, s5678, y_filter, offset);
+        uint16x8_t d2 = highbd_convolve8_8_y(s2345, s6789, y_filter, offset);
+        uint16x8_t d3 = highbd_convolve8_8_y(s3456, s789A, y_filter, offset);
 
         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -648,40 +810,41 @@ void av1_highbd_dist_wtd_convolve_y_sve2(
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = filter_params_y->taps / 2 - 1;
   assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
-  const int round_offset_conv = (1 << (conv_params->round_0 - 1)) +
-                                (1 << (bd + FILTER_BITS)) +
-                                (1 << (bd + FILTER_BITS - 1));
 
   const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
       filter_params_y, subpel_y_qn & SUBPEL_MASK);
 
   src -= vert_offset * src_stride;
 
-  if (conv_params->do_average) {
-    highbd_dist_wtd_convolve_y_8tap_sve2(src, src_stride, im_block, im_stride,
-                                         w, h, y_filter_ptr, conv_params,
-                                         round_offset_conv);
-    if (conv_params->use_dist_wtd_comp_avg) {
-      if (bd == 12) {
+  if (bd == 12) {
+    if (conv_params->do_average) {
+      highbd_12_dist_wtd_convolve_y_8tap_sve2(src, src_stride, im_block,
+                                              im_stride, w, h, y_filter_ptr);
+      if (conv_params->use_dist_wtd_comp_avg) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
                                          w, h, conv_params);
       } else {
-        highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
-                                      h, conv_params, bd);
-      }
-    } else {
-      if (bd == 12) {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                                 conv_params);
-
+      }
+    } else {
+      highbd_12_dist_wtd_convolve_y_8tap_sve2(src, src_stride, dst16,
+                                              dst16_stride, w, h, y_filter_ptr);
+    }
+  } else {
+    if (conv_params->do_average) {
+      highbd_dist_wtd_convolve_y_8tap_sve2(src, src_stride, im_block, im_stride,
+                                           w, h, y_filter_ptr, bd);
+      if (conv_params->use_dist_wtd_comp_avg) {
+        highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
+                                      h, conv_params, bd);
       } else {
         highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                              conv_params, bd);
       }
+    } else {
+      highbd_dist_wtd_convolve_y_8tap_sve2(src, src_stride, dst16, dst16_stride,
+                                           w, h, y_filter_ptr, bd);
     }
-  } else {
-    highbd_dist_wtd_convolve_y_8tap_sve2(src, src_stride, dst16, dst16_stride,
-                                         w, h, y_filter_ptr, conv_params,
-                                         round_offset_conv);
   }
 }