aom: Add 2-tap path for aom_convolve8_horiz_neon

From 42e3156b5343e725aa79cda40a464a891e6c8735 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Tue, 16 Apr 2024 11:39:22 +0100
Subject: [PATCH] Add 2-tap path for aom_convolve8_horiz_neon

Add a specialized Neon implementation for 2-tap (bilinear) filters and use
it instead of the 4-tap implementation in all three Neon versions of
aom_convolve8_horiz (Neon, Neon DotProd and Neon I8MM). This provides a
20% to 60% performance uplift depending on the architecture extension.

Change-Id: I48da9553cd391dd801affdc1c62995d1f1a48f15
---
 aom_dsp/arm/aom_convolve8_neon.c         |   8 +-
 aom_dsp/arm/aom_convolve8_neon.h         | 106 +++++++++++++++++++++++
 aom_dsp/arm/aom_convolve8_neon_dotprod.c |   8 +-
 aom_dsp/arm/aom_convolve8_neon_i8mm.c    |   8 +-
 4 files changed, 127 insertions(+), 3 deletions(-)
 create mode 100644 aom_dsp/arm/aom_convolve8_neon.h

diff --git a/aom_dsp/arm/aom_convolve8_neon.c b/aom_dsp/arm/aom_convolve8_neon.c
index 6a177b2e6..5665b5e12 100644
--- a/aom_dsp/arm/aom_convolve8_neon.c
+++ b/aom_dsp/arm/aom_convolve8_neon.c
@@ -20,6 +20,7 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/aom_convolve8_neon.h"
 #include "aom_dsp/arm/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
@@ -354,7 +355,12 @@ void aom_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride,
 
   src -= ((SUBPEL_TAPS / 2) - 1);
 
-  if (get_filter_taps_convolve8(filter_x) <= 4) {
+  int filter_taps = get_filter_taps_convolve8(filter_x);
+
+  if (filter_taps == 2) {
+    convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride, filter_x, w,
+                              h);
+  } else if (filter_taps == 4) {
     convolve8_horiz_4tap_neon(src + 2, src_stride, dst, dst_stride, filter_x, w,
                               h);
   } else {
diff --git a/aom_dsp/arm/aom_convolve8_neon.h b/aom_dsp/arm/aom_convolve8_neon.h
new file mode 100644
index 000000000..0aebc6d12
--- /dev/null
+++ b/aom_dsp/arm/aom_convolve8_neon.h
@@ -0,0 +1,106 @@
+/*
+ *  Copyright (c) 2024, Alliance for Open Media. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#ifndef AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
+#define AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
+
+#include <arm_neon.h>
+
+#include "config/aom_config.h"
+#include "aom_dsp/arm/mem_neon.h"
+
+// Horizontal 2-tap (bilinear) convolution using only filter_x[3] and
+// filter_x[4]: dst[x] = sat_u8((src[x] * filter_x[3] +
+//   src[x + 1] * filter_x[4] + (1 << (FILTER_BITS - 1))) >> FILTER_BITS).
+// w == 4 assumes h is a multiple of 4, w == 8 assumes h is even, and any
+// other w must be a multiple of 16.
+static INLINE void convolve8_horiz_2tap_neon(const uint8_t *src,
+                                             ptrdiff_t src_stride, uint8_t *dst,
+                                             ptrdiff_t dst_stride,
+                                             const int16_t *filter_x, int w,
+                                             int h) {
+  // Bilinear taps are non-negative, so they can be broadcast as uint8 lanes.
+  const uint8x8_t tap0 = vdup_n_u8((uint8_t)filter_x[3]);
+  const uint8x8_t tap1 = vdup_n_u8((uint8_t)filter_x[4]);
+
+  if (w == 4) {
+    // Each vector holds two 4-pixel rows; four rows per iteration.
+    do {
+      const uint8x8_t r01_a = load_unaligned_u8(src, (int)src_stride);
+      const uint8x8_t r01_b = load_unaligned_u8(src + 1, (int)src_stride);
+      const uint8x8_t r23_a =
+          load_unaligned_u8(src + 2 * src_stride, (int)src_stride);
+      const uint8x8_t r23_b =
+          load_unaligned_u8(src + 2 * src_stride + 1, (int)src_stride);
+
+      uint16x8_t acc01 = vmull_u8(r01_a, tap0);
+      uint16x8_t acc23 = vmull_u8(r23_a, tap0);
+      acc01 = vmlal_u8(acc01, r01_b, tap1);
+      acc23 = vmlal_u8(acc23, r23_b, tap1);
+
+      store_u8x4_strided_x2(dst, dst_stride, vqrshrn_n_u16(acc01, FILTER_BITS));
+      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride,
+                            vqrshrn_n_u16(acc23, FILTER_BITS));
+
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);
+  } else if (w == 8) {
+    // One 8-pixel row per vector; two rows per iteration.
+    do {
+      const uint8x8_t r0_a = vld1_u8(src);
+      const uint8x8_t r0_b = vld1_u8(src + 1);
+      const uint8x8_t r1_a = vld1_u8(src + src_stride);
+      const uint8x8_t r1_b = vld1_u8(src + src_stride + 1);
+
+      uint16x8_t acc0 = vmull_u8(r0_a, tap0);
+      uint16x8_t acc1 = vmull_u8(r1_a, tap0);
+      acc0 = vmlal_u8(acc0, r0_b, tap1);
+      acc1 = vmlal_u8(acc1, r1_b, tap1);
+
+      vst1_u8(dst, vqrshrn_n_u16(acc0, FILTER_BITS));
+      vst1_u8(dst + dst_stride, vqrshrn_n_u16(acc1, FILTER_BITS));
+
+      src += 2 * src_stride;
+      dst += 2 * dst_stride;
+      h -= 2;
+    } while (h > 0);
+  } else {
+    // w is a multiple of 16: one row per iteration, 16 pixels at a time.
+    do {
+      const uint8_t *s = src;
+      uint8_t *d = dst;
+      int remaining = w;
+
+      do {
+        const uint8x16_t px_a = vld1q_u8(s);
+        const uint8x16_t px_b = vld1q_u8(s + 1);
+
+        uint16x8_t acc_lo = vmull_u8(vget_low_u8(px_a), tap0);
+        uint16x8_t acc_hi = vmull_u8(vget_high_u8(px_a), tap0);
+        acc_lo = vmlal_u8(acc_lo, vget_low_u8(px_b), tap1);
+        acc_hi = vmlal_u8(acc_hi, vget_high_u8(px_b), tap1);
+
+        vst1q_u8(d, vcombine_u8(vqrshrn_n_u16(acc_lo, FILTER_BITS),
+                                vqrshrn_n_u16(acc_hi, FILTER_BITS)));
+
+        s += 16;
+        d += 16;
+        remaining -= 16;
+      } while (remaining != 0);
+
+      src += src_stride;
+      dst += dst_stride;
+    } while (--h > 0);
+  }
+}
+
+#endif  // AOM_AOM_DSP_ARM_AOM_CONVOLVE8_NEON_H_
diff --git a/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index 576db8e4f..f49d33ff3 100644
--- a/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -20,6 +20,7 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/aom_convolve8_neon.h"
 #include "aom_dsp/arm/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
@@ -269,7 +270,12 @@ void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
 
   src -= ((SUBPEL_TAPS / 2) - 1);
 
-  if (get_filter_taps_convolve8(filter_x) <= 4) {
+  int filter_taps = get_filter_taps_convolve8(filter_x);
+
+  if (filter_taps == 2) {
+    convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride, filter_x, w,
+                              h);
+  } else if (filter_taps == 4) {
     convolve8_horiz_4tap_neon_dotprod(src + 2, src_stride, dst, dst_stride,
                                       filter_x, w, h);
   } else {
diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index da0210ac3..972763999 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -19,6 +19,7 @@
 #include "aom/aom_integer.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_dsp/aom_filter.h"
+#include "aom_dsp/arm/aom_convolve8_neon.h"
 #include "aom_dsp/arm/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
@@ -243,7 +244,12 @@ void aom_convolve8_horiz_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
 
   src -= ((SUBPEL_TAPS / 2) - 1);
 
-  if (get_filter_taps_convolve8(filter_x) <= 4) {
+  int filter_taps = get_filter_taps_convolve8(filter_x);
+
+  if (filter_taps == 2) {
+    convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride, filter_x, w,
+                              h);
+  } else if (filter_taps == 4) {
     convolve8_horiz_4tap_neon_i8mm(src + 2, src_stride, dst, dst_stride,
                                    filter_x, w, h);
   } else {