aom: Add 4-tap specialization of aom_highbd_convolve8_horiz_sve

From 3bdf0fc289783bf14caa9ad724f9fd3ad9b46435 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Fri, 23 Feb 2024 16:27:34 +0000
Subject: [PATCH] Add 4-tap specialization of aom_highbd_convolve8_horiz_sve

This function is mostly called with 2-tap and 4-tap filters, so add a
specialised path that uses half as many instructions as the generic
8-tap path for these filters.

Change-Id: I2804cdfcfffc993b33adcfe1a02e40c724dbce08
---
 aom_dsp/arm/aom_filter.h           |  33 +++++++
 aom_dsp/arm/highbd_convolve8_sve.c | 153 ++++++++++++++++++++++++++---
 2 files changed, 170 insertions(+), 16 deletions(-)
 create mode 100644 aom_dsp/arm/aom_filter.h

diff --git a/aom_dsp/arm/aom_filter.h b/aom_dsp/arm/aom_filter.h
new file mode 100644
index 000000000..9972d064f
--- /dev/null
+++ b/aom_dsp/arm/aom_filter.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2024, Alliance for Open Media. All rights reserved
+ *
+ * This source code is subject to the terms of the BSD 2 Clause License and
+ * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
+ * was not distributed with this source code in the LICENSE file, you can
+ * obtain it at www.aomedia.org/license/software. If the Alliance for Open
+ * Media Patent License 1.0 was not distributed with this source code in the
+ * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
+ */
+
+#ifndef AOM_AOM_DSP_ARM_AOM_FILTER_H_
+#define AOM_AOM_DSP_ARM_AOM_FILTER_H_
+
+#include <stdint.h>
+
+#include "config/aom_config.h"
+#include "config/aom_dsp_rtcd.h"
+
+// Returns the number of taps (8, 6, 4 or 2) actually needed to apply the
+// 8-tap kernel `filter`: the width of the smallest window, centred between
+// taps 3 and 4, that contains every non-zero coefficient. `filter` must point
+// to at least 8 coefficients. A result of 2 only means taps 0-2 and 5-7 are
+// zero; it does not check that filter[3] or filter[4] is non-zero.
+static INLINE int get_filter_taps_convolve8(const int16_t *filter) {
+  if (filter[0] | filter[7]) {
+    return 8;
+  }
+  if (filter[1] | filter[6]) {
+    return 6;
+  }
+  if (filter[2] | filter[5]) {
+    return 4;
+  }
+  return 2;
+}
+
+#endif  // AOM_AOM_DSP_ARM_AOM_FILTER_H_
diff --git a/aom_dsp/arm/highbd_convolve8_sve.c b/aom_dsp/arm/highbd_convolve8_sve.c
index 46131b973..189d11b14 100644
--- a/aom_dsp/arm/highbd_convolve8_sve.c
+++ b/aom_dsp/arm/highbd_convolve8_sve.c
@@ -17,6 +17,7 @@
 #include "config/aom_dsp_rtcd.h"
 
 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
+#include "aom_dsp/arm/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
 
 static INLINE uint16x4_t highbd_convolve8_4_h(int16x8_t s[4], int16x8_t filter,
@@ -63,22 +64,10 @@ static INLINE uint16x8_t highbd_convolve8_8_h(int16x8_t s[8], int16x8_t filter,
   return vminq_u16(res, max);
 }
 
-void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
-                                    uint8_t *dst8, ptrdiff_t dst_stride,
-                                    const int16_t *filter_x, int x_step_q4,
-                                    const int16_t *filter_y, int y_step_q4,
-                                    int width, int height, int bd) {
-  assert(x_step_q4 == 16);
-  assert(width >= 4 && height >= 4);
-  (void)filter_y;
-  (void)x_step_q4;
-  (void)y_step_q4;
-
-  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
-  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
-
-  src -= SUBPEL_TAPS / 2 - 1;
-
+static INLINE void highbd_convolve8_horiz_8tap_sve(
+    const uint16_t *src, ptrdiff_t src_stride, uint16_t *dst,
+    ptrdiff_t dst_stride, const int16_t *filter_x, int width, int height,
+    int bd) {
   const int16x8_t filter = vld1q_s16(filter_x);
 
   if (width == 4) {
@@ -140,6 +129,138 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
   }
 }
 
+// clang-format off
+DECLARE_ALIGNED(16, static const uint16_t, kDotProdTbl[16]) = {
+  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
+};
+
+DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
+  0, 2, 4, 6, 1, 3, 5, 7,
+};
+// clang-format on
+
+static INLINE uint16x4_t highbd_convolve4_4_h(int16x8_t s, int16x8_t filter,
+                                              uint16x8x2_t permute_tbl,
+                                              uint16x4_t max) {
+  int16x8_t permuted_samples0 = aom_tbl_s16(s, permute_tbl.val[0]);
+  int16x8_t permuted_samples1 = aom_tbl_s16(s, permute_tbl.val[1]);
+
+  int64x2_t sum0 =
+      aom_svdot_lane_s16(vdupq_n_s64(0), permuted_samples0, filter, 0);
+  int64x2_t sum1 =
+      aom_svdot_lane_s16(vdupq_n_s64(0), permuted_samples1, filter, 0);
+
+  int32x4_t res_s32 = vcombine_s32(vmovn_s64(sum0), vmovn_s64(sum1));
+  uint16x4_t res = vqrshrun_n_s32(res_s32, FILTER_BITS);
+
+  return vmin_u16(res, max);
+}
+
+static INLINE uint16x8_t highbd_convolve4_8_h(int16x8_t s[4], int16x8_t filter,
+                                              uint16x8_t idx, uint16x8_t max) {
+  int64x2_t sum04 = aom_svdot_lane_s16(vdupq_n_s64(0), s[0], filter, 0);
+  int64x2_t sum15 = aom_svdot_lane_s16(vdupq_n_s64(0), s[1], filter, 0);
+  int64x2_t sum26 = aom_svdot_lane_s16(vdupq_n_s64(0), s[2], filter, 0);
+  int64x2_t sum37 = aom_svdot_lane_s16(vdupq_n_s64(0), s[3], filter, 0);
+
+  int32x4_t res0 = vcombine_s32(vmovn_s64(sum04), vmovn_s64(sum15));
+  int32x4_t res1 = vcombine_s32(vmovn_s64(sum26), vmovn_s64(sum37));
+
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(res0, FILTER_BITS),
+                                vqrshrun_n_s32(res1, FILTER_BITS));
+
+  res = aom_tbl_u16(res, idx);
+
+  return vminq_u16(res, max);
+}
+
+static INLINE void highbd_convolve8_horiz_4tap_sve(
+    const uint16_t *src, ptrdiff_t src_stride, uint16_t *dst,
+    ptrdiff_t dst_stride, const int16_t *filter_x, int width, int height,
+    int bd) {
+  const int16x8_t filter = vcombine_s16(vld1_s16(filter_x + 2), vdup_n_s16(0));
+
+  if (width == 4) {
+    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
+    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);
+
+    const int16_t *s = (const int16_t *)src;
+    uint16_t *d = dst;
+
+    do {
+      int16x8_t s0, s1, s2, s3;
+      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);
+
+      uint16x4_t d0 = highbd_convolve4_4_h(s0, filter, permute_tbl, max);
+      uint16x4_t d1 = highbd_convolve4_4_h(s1, filter, permute_tbl, max);
+      uint16x4_t d2 = highbd_convolve4_4_h(s2, filter, permute_tbl, max);
+      uint16x4_t d3 = highbd_convolve4_4_h(s3, filter, permute_tbl, max);
+
+      store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
+
+      s += 4 * src_stride;
+      d += 4 * dst_stride;
+      height -= 4;
+    } while (height > 0);
+  } else {
+    const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);
+
+    do {
+      const int16_t *s = (const int16_t *)src;
+      uint16_t *d = dst;
+      int w = width;
+
+      do {
+        int16x8_t s0[4], s1[4], s2[4], s3[4];
+        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
+        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
+        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
+        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
+
+        uint16x8_t d0 = highbd_convolve4_8_h(s0, filter, idx, max);
+        uint16x8_t d1 = highbd_convolve4_8_h(s1, filter, idx, max);
+        uint16x8_t d2 = highbd_convolve4_8_h(s2, filter, idx, max);
+        uint16x8_t d3 = highbd_convolve4_8_h(s3, filter, idx, max);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+
+        s += 8;
+        d += 8;
+        w -= 8;
+      } while (w != 0);
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      height -= 4;
+    } while (height > 0);
+  }
+}
+
+void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
+                                    uint8_t *dst8, ptrdiff_t dst_stride,
+                                    const int16_t *filter_x, int x_step_q4,
+                                    const int16_t *filter_y, int y_step_q4,
+                                    int width, int height, int bd) {
+  assert(x_step_q4 == 16);
+  assert(width >= 4 && height >= 4);
+  (void)filter_y;
+  (void)x_step_q4;
+  (void)y_step_q4;
+
+  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
+
+  src -= SUBPEL_TAPS / 2 - 1;
+
+  if (get_filter_taps_convolve8(filter_x) <= 4) {
+    highbd_convolve8_horiz_4tap_sve(src + 2, src_stride, dst, dst_stride,
+                                    filter_x, width, height, bd);
+  } else {
+    highbd_convolve8_horiz_8tap_sve(src, src_stride, dst, dst_stride, filter_x,
+                                    width, height, bd);
+  }
+}
+
 DECLARE_ALIGNED(16, static const uint8_t, kDotProdTranConcatTbl[32]) = {
   0, 1, 8,  9,  16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27,
   4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31