aom: Add 2-tap path for aom_highbd_convolve8_horiz_neon

From 276f8f8011388d70a41bfaac12a2725d660ad69a Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 18 Apr 2024 16:00:10 +0100
Subject: [PATCH] Add 2-tap path for aom_highbd_convolve8_horiz_neon

Add a specialized Neon implementation for 2-tap filters and use it
instead of the 4-tap implementation in both Neon and SVE Neon versions
of aom_highbd_convolve8_horiz. This provides a 40% to 80% performance
uplift over the 4-tap implementation.

Change-Id: Ie24189770a066e1155d0239ed6e80a9e8a7938ce
---
 aom_dsp/arm/highbd_convolve8_neon.c |  8 ++-
 aom_dsp/arm/highbd_convolve8_neon.h | 98 +++++++++++++++++++++++++++++
 aom_dsp/arm/highbd_convolve8_sve.c  |  8 ++-
 3 files changed, 112 insertions(+), 2 deletions(-)
 create mode 100644 aom_dsp/arm/highbd_convolve8_neon.h

diff --git a/aom_dsp/arm/highbd_convolve8_neon.c b/aom_dsp/arm/highbd_convolve8_neon.c
index 6d8ce2961..75d1cd413 100644
--- a/aom_dsp/arm/highbd_convolve8_neon.c
+++ b/aom_dsp/arm/highbd_convolve8_neon.c
@@ -20,6 +20,7 @@
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/aom_filter.h"
 #include "aom_dsp/arm/aom_filter.h"
+#include "aom_dsp/arm/highbd_convolve8_neon.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 #include "aom_ports/mem.h"
@@ -277,7 +278,12 @@ void aom_highbd_convolve8_horiz_neon(const uint8_t *src8, ptrdiff_t src_stride,
 
     src -= SUBPEL_TAPS / 2 - 1;
 
-    if (get_filter_taps_convolve8(filter_x) <= 4) {
+    const int filter_taps = get_filter_taps_convolve8(filter_x);
+
+    if (filter_taps == 2) {
+      highbd_convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride,
+                                       filter_x, w, h, bd);
+    } else if (filter_taps == 4) {
       highbd_convolve_horiz_4tap_neon(src + 2, src_stride, dst, dst_stride,
                                       filter_x, w, h, bd);
     } else {
diff --git a/aom_dsp/arm/highbd_convolve8_neon.h b/aom_dsp/arm/highbd_convolve8_neon.h
new file mode 100644
index 000000000..05cff79a9
--- /dev/null
+++ b/aom_dsp/arm/highbd_convolve8_neon.h
@@ -0,0 +1,114 @@
+/*
+ *  Copyright (c) 2024, Alliance for Open Media. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_
+#define AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_
+
+#include <arm_neon.h>
+
+#include "config/aom_config.h"
+#include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/mem_neon.h"
+
+// Horizontal 2-tap convolution of a w x h block of high bitdepth pixels.
+// Output pixel x of each row is computed from src_ptr[x] and src_ptr[x + 1],
+// weighted by x_filter_ptr[3] and x_filter_ptr[4] respectively; the other
+// taps are ignored, so src_ptr must point at the pixel aligned with tap 3.
+// Each row reads w + 1 source pixels. The weighted sum is rounding-shifted
+// right by FILTER_BITS and clamped to at most (1 << bd) - 1.
+//
+// Preconditions (not checked here):
+// - Either w == 4 and h is a positive multiple of 4, or w is a positive
+//   multiple of 8 and h is a positive multiple of 2.
+// - x_filter_ptr[3] and x_filter_ptr[4] are non-negative multiples of 8, as
+//   is the case for the bilinear kernels.
+// - ((x_filter_ptr[3] + x_filter_ptr[4]) / 8) * ((1 << bd) - 1) fits in 16
+//   bits, because products are accumulated without widening. With taps
+//   summing to 128 this holds for bd <= 12 (16 * 4095 = 65520).
+static INLINE void highbd_convolve8_horiz_2tap_neon(
+    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
+    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
+  // Bilinear filter values are all non-negative multiples of 8. Divide by 8 to
+  // reduce intermediate precision requirements and allow the use of a
+  // non-widening multiply.
+  const uint16x8_t f0 = vdupq_n_u16((uint16_t)x_filter_ptr[3] / 8);
+  const uint16x8_t f1 = vdupq_n_u16((uint16_t)x_filter_ptr[4] / 8);
+
+  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+
+  if (w == 4) {
+    do {
+      uint16x8_t s0 =
+          load_unaligned_u16_4x2(src_ptr + 0 * src_stride + 0, (int)src_stride);
+      uint16x8_t s1 =
+          load_unaligned_u16_4x2(src_ptr + 0 * src_stride + 1, (int)src_stride);
+      uint16x8_t s2 =
+          load_unaligned_u16_4x2(src_ptr + 2 * src_stride + 0, (int)src_stride);
+      uint16x8_t s3 =
+          load_unaligned_u16_4x2(src_ptr + 2 * src_stride + 1, (int)src_stride);
+
+      uint16x8_t sum01 = vmulq_u16(s0, f0);
+      sum01 = vmlaq_u16(sum01, s1, f1);
+      uint16x8_t sum23 = vmulq_u16(s2, f0);
+      sum23 = vmlaq_u16(sum23, s3, f1);
+
+      // We divided filter taps by 8 so subtract 3 from right shift.
+      sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
+      sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);
+
+      sum01 = vminq_u16(sum01, max);
+      sum23 = vminq_u16(sum23, max);
+
+      store_u16x4_strided_x2(dst_ptr + 0 * dst_stride, (int)dst_stride, sum01);
+      store_u16x4_strided_x2(dst_ptr + 2 * dst_stride, (int)dst_stride, sum23);
+
+      src_ptr += 4 * src_stride;
+      dst_ptr += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);
+  } else {
+    do {
+      int width = w;
+      const uint16_t *s = src_ptr;
+      uint16_t *d = dst_ptr;
+
+      do {
+        uint16x8_t s0 = vld1q_u16(s + 0 * src_stride + 0);
+        uint16x8_t s1 = vld1q_u16(s + 0 * src_stride + 1);
+        uint16x8_t s2 = vld1q_u16(s + 1 * src_stride + 0);
+        uint16x8_t s3 = vld1q_u16(s + 1 * src_stride + 1);
+
+        uint16x8_t sum01 = vmulq_u16(s0, f0);
+        sum01 = vmlaq_u16(sum01, s1, f1);
+        uint16x8_t sum23 = vmulq_u16(s2, f0);
+        sum23 = vmlaq_u16(sum23, s3, f1);
+
+        // We divided filter taps by 8 so subtract 3 from right shift.
+        sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
+        sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);
+
+        sum01 = vminq_u16(sum01, max);
+        sum23 = vminq_u16(sum23, max);
+
+        vst1q_u16(d + 0 * dst_stride, sum01);
+        vst1q_u16(d + 1 * dst_stride, sum23);
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width != 0);
+      src_ptr += 2 * src_stride;
+      dst_ptr += 2 * dst_stride;
+      h -= 2;
+    } while (h > 0);
+  }
+}
+
+#endif  // AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_
diff --git a/aom_dsp/arm/highbd_convolve8_sve.c b/aom_dsp/arm/highbd_convolve8_sve.c
index e57c41a0b..ef977181b 100644
--- a/aom_dsp/arm/highbd_convolve8_sve.c
+++ b/aom_dsp/arm/highbd_convolve8_sve.c
@@ -18,6 +18,7 @@
 
 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
 #include "aom_dsp/arm/aom_filter.h"
+#include "aom_dsp/arm/highbd_convolve8_neon.h"
 #include "aom_dsp/arm/mem_neon.h"
 
 static INLINE uint16x4_t highbd_convolve8_4_h(int16x8_t s[4], int16x8_t filter,
@@ -252,7 +253,12 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
 
   src -= SUBPEL_TAPS / 2 - 1;
 
-  if (get_filter_taps_convolve8(filter_x) <= 4) {
+  const int filter_taps = get_filter_taps_convolve8(filter_x);
+
+  if (filter_taps == 2) {
+    highbd_convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride,
+                                     filter_x, width, height, bd);
+  } else if (filter_taps == 4) {
     highbd_convolve8_horiz_4tap_sve(src + 2, src_stride, dst, dst_stride,
                                     filter_x, width, height, bd);
   } else {