aom: Add SVE implementation of aom_highbd_convolve8_vert

From 2298f3809ab76a1925ae00478f7c590c0f80de00 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Fri, 26 Jan 2024 16:01:38 +0000
Subject: [PATCH] Add SVE implementation of aom_highbd_convolve8_vert

Add SVE implementation of aom_highbd_convolve8_vert as well as the
corresponding tests.

Change-Id: I9ec674c83dc5848fb46df0e5dd9792691a9de2ed
---
 aom_dsp/aom_dsp_rtcd_defs.pl       |   2 +-
 aom_dsp/arm/dot_sve.h              |   5 +
 aom_dsp/arm/highbd_convolve8_sve.c | 299 +++++++++++++++++++++++++++++
 test/convolve_test.cc              |  12 +-
 4 files changed, 311 insertions(+), 7 deletions(-)

diff --git a/aom_dsp/aom_dsp_rtcd_defs.pl b/aom_dsp/aom_dsp_rtcd_defs.pl
index 02081cd3de..62ad9f4ce8 100755
--- a/aom_dsp/aom_dsp_rtcd_defs.pl
+++ b/aom_dsp/aom_dsp_rtcd_defs.pl
@@ -512,7 +512,7 @@ ()
   specialize qw/aom_highbd_convolve8_horiz sse2 avx2 neon sve/;
 
   add_proto qw/void aom_highbd_convolve8_vert/, "const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const int16_t *filter_x, int x_step_q4, const int16_t *filter_y, int y_step_q4, int w, int h, int bd";
-  specialize qw/aom_highbd_convolve8_vert sse2 avx2 neon/;
+  specialize qw/aom_highbd_convolve8_vert sse2 avx2 neon sve/;
 }
 
 #
diff --git a/aom_dsp/arm/dot_sve.h b/aom_dsp/arm/dot_sve.h
index cf49f23606..a02716933d 100644
--- a/aom_dsp/arm/dot_sve.h
+++ b/aom_dsp/arm/dot_sve.h
@@ -39,4 +39,9 @@ static INLINE int64x2_t aom_sdotq_s16(int64x2_t acc, int16x8_t x, int16x8_t y) {
                                    svset_neonq_s16(svundef_s16(), y)));
 }
 
+#define aom_svdot_lane_s16(sum, s0, f, lane)                          \
+  svget_neonq_s64(svdot_lane_s64(svset_neonq_s64(svundef_s64(), sum), \
+                                 svset_neonq_s16(svundef_s16(), s0),  \
+                                 svset_neonq_s16(svundef_s16(), f), lane))
+
 #endif  // AOM_AOM_DSP_ARM_DOT_SVE_H_
diff --git a/aom_dsp/arm/highbd_convolve8_sve.c b/aom_dsp/arm/highbd_convolve8_sve.c
index ed3d94c212..8220a4fa12 100644
--- a/aom_dsp/arm/highbd_convolve8_sve.c
+++ b/aom_dsp/arm/highbd_convolve8_sve.c
@@ -145,3 +145,302 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
     } while (height > 0);
   }
 }
+
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdTranConcatTbl[32]) = {
+  0, 1, 8,  9,  16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27,
+  4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31
+};
+
+DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
+  // Shift left and insert new last column in transposed 4x4 block.
+  2, 3, 4, 5, 6, 7, 16, 17, 10, 11, 12, 13, 14, 15, 24, 25,
+  // Shift left and insert two new columns in transposed 4x4 block.
+  4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15, 24, 25, 26, 27,
+  // Shift left and insert three new columns in transposed 4x4 block.
+  6, 7, 16, 17, 18, 19, 20, 21, 14, 15, 24, 25, 26, 27, 28, 29
+};
+
+static INLINE void transpose_concat_4x4(int16x4_t s0, int16x4_t s1,
+                                        int16x4_t s2, int16x4_t s3,
+                                        int16x8_t res[2],
+                                        uint8x16_t permute_tbl[2]) {
+  // Transpose 16-bit elements and concatenate result rows as follows:
+  // s0: 00, 01, 02, 03
+  // s1: 10, 11, 12, 13
+  // s2: 20, 21, 22, 23
+  // s3: 30, 31, 32, 33
+  //
+  // res[0]: 00 10 20 30 01 11 21 31
+  // res[1]: 02 12 22 32 03 13 23 33
+  //
+  // The 'permute_tbl' is always 'kDotProdTranConcatTbl' above. Passing it
+  // as an argument is preferable to loading it directly from memory as this
+  // inline helper is called many times from the same parent function.
+
+  int8x16x2_t samples = { vreinterpretq_s8_s16(vcombine_s16(s0, s1)),
+                          vreinterpretq_s8_s16(vcombine_s16(s2, s3)) };
+
+  res[0] = vreinterpretq_s16_s8(vqtbl2q_s8(samples, permute_tbl[0]));
+  res[1] = vreinterpretq_s16_s8(vqtbl2q_s8(samples, permute_tbl[1]));
+}
+
+static INLINE void transpose_concat_8x4(int16x8_t s0, int16x8_t s1,
+                                        int16x8_t s2, int16x8_t s3,
+                                        int16x8_t res[4],
+                                        uint8x16_t permute_tbl[2]) {
+  // Transpose 16-bit elements and concatenate result rows as follows:
+  // s0: 00, 01, 02, 03, 04, 05, 06, 07
+  // s1: 10, 11, 12, 13, 14, 15, 16, 17
+  // s2: 20, 21, 22, 23, 24, 25, 26, 27
+  // s3: 30, 31, 32, 33, 34, 35, 36, 37
+  //
+  // res_lo[0]: 00 10 20 30 01 11 21 31
+  // res_lo[1]: 02 12 22 32 03 13 23 33
+  // res_hi[0]: 04 14 24 34 05 15 25 35
+  // res_hi[1]: 06 16 26 36 07 17 27 37
+  //
+  // The 'permute_tbl' is always 'kDotProdTranConcatTbl' above. Passing it
+  // as an argument is preferable to loading it directly from memory as this
+  // inline helper is called many times from the same parent function.
+
+  int8x16x2_t samples_lo = {
+    vreinterpretq_s8_s16(vcombine_s16(vget_low_s16(s0), vget_low_s16(s1))),
+    vreinterpretq_s8_s16(vcombine_s16(vget_low_s16(s2), vget_low_s16(s3)))
+  };
+
+  res[0] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_lo, permute_tbl[0]));
+  res[1] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_lo, permute_tbl[1]));
+
+  int8x16x2_t samples_hi = {
+    vreinterpretq_s8_s16(vcombine_s16(vget_high_s16(s0), vget_high_s16(s1))),
+    vreinterpretq_s8_s16(vcombine_s16(vget_high_s16(s2), vget_high_s16(s3)))
+  };
+
+  res[2] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_hi, permute_tbl[0]));
+  res[3] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_hi, permute_tbl[1]));
+}
+
+static INLINE void aom_tbl2x4_s16(int16x8_t t0[4], int16x8_t t1[4],
+                                  uint8x16_t tbl, int16x8_t res[4]) {
+  int8x16x2_t samples0 = { vreinterpretq_s8_s16(t0[0]),
+                           vreinterpretq_s8_s16(t1[0]) };
+  int8x16x2_t samples1 = { vreinterpretq_s8_s16(t0[1]),
+                           vreinterpretq_s8_s16(t1[1]) };
+  int8x16x2_t samples2 = { vreinterpretq_s8_s16(t0[2]),
+                           vreinterpretq_s8_s16(t1[2]) };
+  int8x16x2_t samples3 = { vreinterpretq_s8_s16(t0[3]),
+                           vreinterpretq_s8_s16(t1[3]) };
+
+  res[0] = vreinterpretq_s16_s8(vqtbl2q_s8(samples0, tbl));
+  res[1] = vreinterpretq_s16_s8(vqtbl2q_s8(samples1, tbl));
+  res[2] = vreinterpretq_s16_s8(vqtbl2q_s8(samples2, tbl));
+  res[3] = vreinterpretq_s16_s8(vqtbl2q_s8(samples3, tbl));
+}
+
+static INLINE void aom_tbl2x2_s16(int16x8_t t0[2], int16x8_t t1[2],
+                                  uint8x16_t tbl, int16x8_t res[2]) {
+  int8x16x2_t samples0 = { vreinterpretq_s8_s16(t0[0]),
+                           vreinterpretq_s8_s16(t1[0]) };
+  int8x16x2_t samples1 = { vreinterpretq_s8_s16(t0[1]),
+                           vreinterpretq_s8_s16(t1[1]) };
+
+  res[0] = vreinterpretq_s16_s8(vqtbl2q_s8(samples0, tbl));
+  res[1] = vreinterpretq_s16_s8(vqtbl2q_s8(samples1, tbl));
+}
+
+static INLINE uint16x4_t highbd_convolve8_4_v(int16x8_t samples_lo[2],
+                                              int16x8_t samples_hi[2],
+                                              int16x8_t filter,
+                                              uint16x4_t max) {
+  int64x2_t sum[2];
+
+  sum[0] = aom_svdot_lane_s16(vdupq_n_s64(0), samples_lo[0], filter, 0);
+  sum[0] = aom_svdot_lane_s16(sum[0], samples_hi[0], filter, 1);
+
+  sum[1] = aom_svdot_lane_s16(vdupq_n_s64(0), samples_lo[1], filter, 0);
+  sum[1] = aom_svdot_lane_s16(sum[1], samples_hi[1], filter, 1);
+
+  int32x4_t res_s32 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[1]));
+
+  uint16x4_t res = vqrshrun_n_s32(res_s32, FILTER_BITS);
+
+  return vmin_u16(res, max);
+}
+
+static INLINE uint16x8_t highbd_convolve8_8_v(int16x8_t samples_lo[4],
+                                              int16x8_t samples_hi[4],
+                                              int16x8_t filter,
+                                              uint16x8_t max) {
+  int64x2_t sum[4];
+
+  sum[0] = aom_svdot_lane_s16(vdupq_n_s64(0), samples_lo[0], filter, 0);
+  sum[0] = aom_svdot_lane_s16(sum[0], samples_hi[0], filter, 1);
+
+  sum[1] = aom_svdot_lane_s16(vdupq_n_s64(0), samples_lo[1], filter, 0);
+  sum[1] = aom_svdot_lane_s16(sum[1], samples_hi[1], filter, 1);
+
+  sum[2] = aom_svdot_lane_s16(vdupq_n_s64(0), samples_lo[2], filter, 0);
+  sum[2] = aom_svdot_lane_s16(sum[2], samples_hi[2], filter, 1);
+
+  sum[3] = aom_svdot_lane_s16(vdupq_n_s64(0), samples_lo[3], filter, 0);
+  sum[3] = aom_svdot_lane_s16(sum[3], samples_hi[3], filter, 1);
+
+  int32x4_t res0 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[1]));
+  int32x4_t res1 = vcombine_s32(vmovn_s64(sum[2]), vmovn_s64(sum[3]));
+
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(res0, FILTER_BITS),
+                                vqrshrun_n_s32(res1, FILTER_BITS));
+
+  return vminq_u16(res, max);
+}
+
+void aom_highbd_convolve8_vert_sve(const uint8_t *src8, ptrdiff_t src_stride,
+                                   uint8_t *dst8, ptrdiff_t dst_stride,
+                                   const int16_t *filter_x, int x_step_q4,
+                                   const int16_t *filter_y, int y_step_q4,
+                                   int width, int height, int bd) {
+  assert(y_step_q4 == 16);
+  assert(width >= 4 && height >= 4);
+  (void)filter_x;
+  (void)y_step_q4;
+  (void)x_step_q4;
+
+  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
+
+  src -= (SUBPEL_TAPS / 2 - 1) * src_stride;
+
+  const int16x8_t y_filter = vld1q_s16(filter_y);
+
+  uint8x16_t tran_concat_tbl[2];
+  tran_concat_tbl[0] = vld1q_u8(kDotProdTranConcatTbl);
+  tran_concat_tbl[1] = vld1q_u8(kDotProdTranConcatTbl + 16);
+  uint8x16_t merge_block_tbl[3];
+  merge_block_tbl[0] = vld1q_u8(kDotProdMergeBlockTbl);
+  merge_block_tbl[1] = vld1q_u8(kDotProdMergeBlockTbl + 16);
+  merge_block_tbl[2] = vld1q_u8(kDotProdMergeBlockTbl + 32);
+
+  if (width == 4) {
+    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
+    int16_t *s = (int16_t *)src;
+
+    int16x4_t s0, s1, s2, s3, s4, s5, s6;
+    load_s16_4x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
+    s += 7 * src_stride;
+
+    // This operation combines a conventional transpose and the sample permute
+    // required before computing the dot product.
+    int16x8_t s0123[2], s1234[2], s2345[2], s3456[2];
+    transpose_concat_4x4(s0, s1, s2, s3, s0123, tran_concat_tbl);
+    transpose_concat_4x4(s1, s2, s3, s4, s1234, tran_concat_tbl);
+    transpose_concat_4x4(s2, s3, s4, s5, s2345, tran_concat_tbl);
+    transpose_concat_4x4(s3, s4, s5, s6, s3456, tran_concat_tbl);
+
+    do {
+      int16x4_t s7, s8, s9, s10;
+      load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);
+
+      int16x8_t s4567[2], s5678[2], s6789[2], s78910[2];
+
+      // Transpose and shuffle the 4 lines that were loaded.
+      transpose_concat_4x4(s7, s8, s9, s10, s78910, tran_concat_tbl);
+
+      // Merge new data into block from previous iteration.
+      aom_tbl2x2_s16(s3456, s78910, merge_block_tbl[0], s4567);
+      aom_tbl2x2_s16(s3456, s78910, merge_block_tbl[1], s5678);
+      aom_tbl2x2_s16(s3456, s78910, merge_block_tbl[2], s6789);
+
+      uint16x4_t d0 = highbd_convolve8_4_v(s0123, s4567, y_filter, max);
+      uint16x4_t d1 = highbd_convolve8_4_v(s1234, s5678, y_filter, max);
+      uint16x4_t d2 = highbd_convolve8_4_v(s2345, s6789, y_filter, max);
+      uint16x4_t d3 = highbd_convolve8_4_v(s3456, s78910, y_filter, max);
+
+      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);
+
+      // Prepare block for next iteration - re-using as much as possible.
+      // Shuffle everything up four rows.
+      s0123[0] = s4567[0];
+      s0123[1] = s4567[1];
+      s1234[0] = s5678[0];
+      s1234[1] = s5678[1];
+      s2345[0] = s6789[0];
+      s2345[1] = s6789[1];
+      s3456[0] = s78910[0];
+      s3456[1] = s78910[1];
+
+      s += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height != 0);
+  } else {
+    const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+    do {
+      int h = height;
+      int16_t *s = (int16_t *)src;
+      uint16_t *d = dst;
+
+      int16x8_t s0, s1, s2, s3, s4, s5, s6;
+      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
+      s += 7 * src_stride;
+
+      // This operation combines a conventional transpose and the sample permute
+      // required before computing the dot product.
+      int16x8_t s0123[4], s1234[4], s2345[4], s3456[4];
+      transpose_concat_8x4(s0, s1, s2, s3, s0123, tran_concat_tbl);
+      transpose_concat_8x4(s1, s2, s3, s4, s1234, tran_concat_tbl);
+      transpose_concat_8x4(s2, s3, s4, s5, s2345, tran_concat_tbl);
+      transpose_concat_8x4(s3, s4, s5, s6, s3456, tran_concat_tbl);
+
+      do {
+        int16x8_t s7, s8, s9, s10;
+        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
+
+        int16x8_t s4567[4], s5678[4], s6789[4], s78910[4];
+
+        // Transpose and shuffle the 4 lines that were loaded.
+        transpose_concat_8x4(s7, s8, s9, s10, s78910, tran_concat_tbl);
+
+        // Merge new data into block from previous iteration.
+        aom_tbl2x4_s16(s3456, s78910, merge_block_tbl[0], s4567);
+        aom_tbl2x4_s16(s3456, s78910, merge_block_tbl[1], s5678);
+        aom_tbl2x4_s16(s3456, s78910, merge_block_tbl[2], s6789);
+
+        uint16x8_t d0 = highbd_convolve8_8_v(s0123, s4567, y_filter, max);
+        uint16x8_t d1 = highbd_convolve8_8_v(s1234, s5678, y_filter, max);
+        uint16x8_t d2 = highbd_convolve8_8_v(s2345, s6789, y_filter, max);
+        uint16x8_t d3 = highbd_convolve8_8_v(s3456, s78910, y_filter, max);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        // Prepare block for next iteration - re-using as much as possible.
+        // Shuffle everything up four rows.
+        s0123[0] = s4567[0];
+        s0123[1] = s4567[1];
+        s0123[2] = s4567[2];
+        s0123[3] = s4567[3];
+
+        s1234[0] = s5678[0];
+        s1234[1] = s5678[1];
+        s1234[2] = s5678[2];
+        s1234[3] = s5678[3];
+
+        s2345[0] = s6789[0];
+        s2345[1] = s6789[1];
+        s2345[2] = s6789[2];
+        s2345[3] = s6789[3];
+
+        s3456[0] = s78910[0];
+        s3456[1] = s78910[1];
+        s3456[2] = s78910[2];
+        s3456[3] = s78910[3];
+
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        h -= 4;
+      } while (h != 0);
+      src += 8;
+      dst += 8;
+      width -= 8;
+    } while (width != 0);
+  }
+}
diff --git a/test/convolve_test.cc b/test/convolve_test.cc
index 1416015188..cab590927b 100644
--- a/test/convolve_test.cc
+++ b/test/convolve_test.cc
@@ -776,10 +776,13 @@ WRAP(convolve8_vert_neon, 12)
 
 #if HAVE_SVE
 WRAP(convolve8_horiz_sve, 8)
+WRAP(convolve8_vert_sve, 8)
 
 WRAP(convolve8_horiz_sve, 10)
+WRAP(convolve8_vert_sve, 10)
 
 WRAP(convolve8_horiz_sve, 12)
+WRAP(convolve8_vert_sve, 12)
 #endif  // HAVE_SVE
 #endif  // CONFIG_AV1_HIGHBITDEPTH
 
@@ -922,16 +925,13 @@ INSTANTIATE_TEST_SUITE_P(NEON_I8MM, LowbdConvolveTest,
 #endif  // HAVE_NEON_I8MM
 
 #if HAVE_SVE
-// The tests don't allow separate testing of the vertical and horizontal pass,
-// so use the Neon implementation of aom_highbd_convolve8_vert until an SVE one
-// is added.
 #if CONFIG_AV1_HIGHBITDEPTH
 const ConvolveFunctions wrap_convolve8_sve(wrap_convolve8_horiz_sve_8,
-                                           wrap_convolve8_vert_neon_8, 8);
+                                           wrap_convolve8_vert_sve_8, 8);
 const ConvolveFunctions wrap_convolve10_sve(wrap_convolve8_horiz_sve_10,
-                                            wrap_convolve8_vert_neon_10, 10);
+                                            wrap_convolve8_vert_sve_10, 10);
 const ConvolveFunctions wrap_convolve12_sve(wrap_convolve8_horiz_sve_12,
-                                            wrap_convolve8_vert_neon_12, 12);
+                                            wrap_convolve8_vert_sve_12, 12);
 const ConvolveParam kArray_HighbdConvolve8_sve[] = {
   ALL_SIZES_64(wrap_convolve8_sve), ALL_SIZES_64(wrap_convolve10_sve),
   ALL_SIZES_64(wrap_convolve12_sve)