aom: Add 4-tap filter path for Neon HBD vert compound convolution

From 80123cb35215ab8a775d75ac4817090803a92d02 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Tue, 19 Mar 2024 11:22:47 +0000
Subject: [PATCH] Add 4-tap filter path for Neon HBD vert compound convolution

Add a 4-tap filter specialization path to
av1_highbd_dist_wtd_convolve_y_neon for both 12-bit and 8/10-bit depths.
This gives up to a 30% uplift over using the 6-tap path.

Change-Id: Ic7fc8bc12c184a94d41c799b1d54e1c2befffcea
---
 .../arm/highbd_compound_convolve_neon.c       | 230 +++++++++++++++++-
 1 file changed, 226 insertions(+), 4 deletions(-)

diff --git a/av1/common/arm/highbd_compound_convolve_neon.c b/av1/common/arm/highbd_compound_convolve_neon.c
index c93a1d4e2..9247ded6b 100644
--- a/av1/common/arm/highbd_compound_convolve_neon.c
+++ b/av1/common/arm/highbd_compound_convolve_neon.c
@@ -711,6 +711,212 @@ static INLINE void highbd_dist_wtd_convolve_y_6tap_neon(
   }
 }
 
+static INLINE uint16x4_t highbd_12_convolve4_4(
+    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+    const int16x4_t s3, const int16x4_t filter, const int32x4_t offset) {
+  int32x4_t sum = vmlal_lane_s16(offset, s0, filter, 0);
+  sum = vmlal_lane_s16(sum, s1, filter, 1);
+  sum = vmlal_lane_s16(sum, s2, filter, 2);
+  sum = vmlal_lane_s16(sum, s3, filter, 3);
+  // 12-bit: truncating >> (ROUND0_BITS + 2), saturated to [0, 65535].
+  return vqshrun_n_s32(sum, ROUND0_BITS + 2);
+}
+
+static INLINE uint16x8_t highbd_12_convolve4_8(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x4_t filter, const int32x4_t offset) {
+  int32x4_t sum0 = vmlal_lane_s16(offset, vget_low_s16(s0), filter, 0);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter, 1);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter, 2);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter, 3);
+  // Same offset + 4-tap accumulation for the high 4 lanes.
+  int32x4_t sum1 = vmlal_lane_s16(offset, vget_high_s16(s0), filter, 0);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter, 1);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter, 2);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter, 3);
+  // 12-bit: truncating >> (ROUND0_BITS + 2), saturated to u16, recombined.
+  return vcombine_u16(vqshrun_n_s32(sum0, ROUND0_BITS + 2),
+                      vqshrun_n_s32(sum1, ROUND0_BITS + 2));
+}
+
+static INLINE void highbd_12_dist_wtd_convolve_y_4tap_neon(
+    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
+    int w, int h, const int16_t *y_filter_ptr, const int offset) {
+  const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);
+  const int32x4_t offset_vec = vdupq_n_s32(offset);
+  // Uses kernel taps [2..5]; callers pass src advanced by 2 rows to match.
+  if (w == 4) {
+    const int16_t *s = (const int16_t *)src_ptr;
+    uint16_t *d = dst_ptr;
+    // Prime the window with 3 rows; h assumed to be a multiple of 4.
+    int16x4_t s0, s1, s2;
+    load_s16_4x3(s, src_stride, &s0, &s1, &s2);
+    s += 3 * src_stride;
+
+    do {
+      int16x4_t s3, s4, s5, s6;
+      load_s16_4x4(s, src_stride, &s3, &s4, &s5, &s6);
+
+      uint16x4_t d0 =
+          highbd_12_convolve4_4(s0, s1, s2, s3, y_filter, offset_vec);
+      uint16x4_t d1 =
+          highbd_12_convolve4_4(s1, s2, s3, s4, y_filter, offset_vec);
+      uint16x4_t d2 =
+          highbd_12_convolve4_4(s2, s3, s4, s5, y_filter, offset_vec);
+      uint16x4_t d3 =
+          highbd_12_convolve4_4(s3, s4, s5, s6, y_filter, offset_vec);
+
+      store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
+      // Slide the 3-row window down by 4 rows.
+      s0 = s4;
+      s1 = s5;
+      s2 = s6;
+
+      s += 4 * src_stride;
+      d += 4 * dst_stride;
+      h -= 4;
+    } while (h != 0);
+  } else {
+    do {
+      int height = h;
+      const int16_t *s = (const int16_t *)src_ptr;
+      uint16_t *d = dst_ptr;
+      // 8-wide strip: prime 3-row window; w assumed a multiple of 8.
+      int16x8_t s0, s1, s2;
+      load_s16_8x3(s, src_stride, &s0, &s1, &s2);
+      s += 3 * src_stride;
+
+      do {
+        int16x8_t s3, s4, s5, s6;
+        load_s16_8x4(s, src_stride, &s3, &s4, &s5, &s6);
+
+        uint16x8_t d0 =
+            highbd_12_convolve4_8(s0, s1, s2, s3, y_filter, offset_vec);
+        uint16x8_t d1 =
+            highbd_12_convolve4_8(s1, s2, s3, s4, y_filter, offset_vec);
+        uint16x8_t d2 =
+            highbd_12_convolve4_8(s2, s3, s4, s5, y_filter, offset_vec);
+        uint16x8_t d3 =
+            highbd_12_convolve4_8(s3, s4, s5, s6, y_filter, offset_vec);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+        // Slide the 3-row window down by 4 rows.
+        s0 = s4;
+        s1 = s5;
+        s2 = s6;
+
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        height -= 4;
+      } while (height != 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      w -= 8;
+    } while (w != 0);
+  }
+}
+
+static INLINE uint16x4_t highbd_convolve4_4(
+    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
+    const int16x4_t s3, const int16x4_t filter, const int32x4_t offset) {
+  int32x4_t sum = vmlal_lane_s16(offset, s0, filter, 0);
+  sum = vmlal_lane_s16(sum, s1, filter, 1);
+  sum = vmlal_lane_s16(sum, s2, filter, 2);
+  sum = vmlal_lane_s16(sum, s3, filter, 3);
+  // 8/10-bit: truncating >> ROUND0_BITS, saturated to [0, 65535].
+  return vqshrun_n_s32(sum, ROUND0_BITS);
+}
+
+static INLINE uint16x8_t highbd_convolve4_8(
+    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
+    const int16x8_t s3, const int16x4_t filter, const int32x4_t offset) {
+  int32x4_t sum0 = vmlal_lane_s16(offset, vget_low_s16(s0), filter, 0);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter, 1);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter, 2);
+  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter, 3);
+  // Same offset + 4-tap accumulation for the high 4 lanes.
+  int32x4_t sum1 = vmlal_lane_s16(offset, vget_high_s16(s0), filter, 0);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter, 1);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter, 2);
+  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter, 3);
+  // 8/10-bit: truncating >> ROUND0_BITS, saturated to u16, recombined.
+  return vcombine_u16(vqshrun_n_s32(sum0, ROUND0_BITS),
+                      vqshrun_n_s32(sum1, ROUND0_BITS));
+}
+
+static INLINE void highbd_dist_wtd_convolve_y_4tap_neon(
+    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
+    int w, int h, const int16_t *y_filter_ptr, const int offset) {
+  const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);
+  const int32x4_t offset_vec = vdupq_n_s32(offset);
+  // Uses kernel taps [2..5]; callers pass src advanced by 2 rows to match.
+  if (w == 4) {
+    const int16_t *s = (const int16_t *)src_ptr;
+    uint16_t *d = dst_ptr;
+    // Prime the window with 3 rows; h assumed to be a multiple of 4.
+    int16x4_t s0, s1, s2;
+    load_s16_4x3(s, src_stride, &s0, &s1, &s2);
+    s += 3 * src_stride;
+
+    do {
+      int16x4_t s3, s4, s5, s6;
+      load_s16_4x4(s, src_stride, &s3, &s4, &s5, &s6);
+
+      uint16x4_t d0 = highbd_convolve4_4(s0, s1, s2, s3, y_filter, offset_vec);
+      uint16x4_t d1 = highbd_convolve4_4(s1, s2, s3, s4, y_filter, offset_vec);
+      uint16x4_t d2 = highbd_convolve4_4(s2, s3, s4, s5, y_filter, offset_vec);
+      uint16x4_t d3 = highbd_convolve4_4(s3, s4, s5, s6, y_filter, offset_vec);
+
+      store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
+      // Slide the 3-row window down by 4 rows.
+      s0 = s4;
+      s1 = s5;
+      s2 = s6;
+
+      s += 4 * src_stride;
+      d += 4 * dst_stride;
+      h -= 4;
+    } while (h != 0);
+  } else {
+    do {
+      int height = h;
+      const int16_t *s = (const int16_t *)src_ptr;
+      uint16_t *d = dst_ptr;
+      // 8-wide strip: prime 3-row window; w assumed a multiple of 8.
+      int16x8_t s0, s1, s2;
+      load_s16_8x3(s, src_stride, &s0, &s1, &s2);
+      s += 3 * src_stride;
+
+      do {
+        int16x8_t s3, s4, s5, s6;
+        load_s16_8x4(s, src_stride, &s3, &s4, &s5, &s6);
+
+        uint16x8_t d0 =
+            highbd_convolve4_8(s0, s1, s2, s3, y_filter, offset_vec);
+        uint16x8_t d1 =
+            highbd_convolve4_8(s1, s2, s3, s4, y_filter, offset_vec);
+        uint16x8_t d2 =
+            highbd_convolve4_8(s2, s3, s4, s5, y_filter, offset_vec);
+        uint16x8_t d3 =
+            highbd_convolve4_8(s3, s4, s5, s6, y_filter, offset_vec);
+
+        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
+        // Slide the 3-row window down by 4 rows.
+        s0 = s4;
+        s1 = s5;
+        s2 = s6;
+
+        s += 4 * src_stride;
+        d += 4 * dst_stride;
+        height -= 4;
+      } while (height != 0);
+      src_ptr += 8;
+      dst_ptr += 8;
+      w -= 8;
+    } while (w != 0);
+  }
+}
+
 static INLINE void highbd_12_dist_wtd_convolve_y_8tap_neon(
     const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
     int w, int h, const int16_t *y_filter_ptr, const int offset) {
@@ -899,7 +1105,11 @@ void av1_highbd_dist_wtd_convolve_y_neon(
 
   if (bd == 12) {
     if (conv_params->do_average) {
-      if (y_filter_taps <= 6) {
+      if (y_filter_taps <= 4) {
+        highbd_12_dist_wtd_convolve_y_4tap_neon(
+            src + 2 * src_stride, src_stride, im_block, im_stride, w, h,
+            y_filter_ptr, round_offset_conv);
+      } else if (y_filter_taps == 6) {
         highbd_12_dist_wtd_convolve_y_6tap_neon(
             src + src_stride, src_stride, im_block, im_stride, w, h,
             y_filter_ptr, round_offset_conv);
@@ -916,7 +1126,11 @@ void av1_highbd_dist_wtd_convolve_y_neon(
                                 conv_params);
       }
     } else {
-      if (y_filter_taps <= 6) {
+      if (y_filter_taps <= 4) {
+        highbd_12_dist_wtd_convolve_y_4tap_neon(
+            src + 2 * src_stride, src_stride, dst16, dst16_stride, w, h,
+            y_filter_ptr, round_offset_conv);
+      } else if (y_filter_taps == 6) {
         highbd_12_dist_wtd_convolve_y_6tap_neon(
             src + src_stride, src_stride, dst16, dst16_stride, w, h,
             y_filter_ptr, round_offset_conv);
@@ -928,7 +1142,11 @@ void av1_highbd_dist_wtd_convolve_y_neon(
     }
   } else {
     if (conv_params->do_average) {
-      if (y_filter_taps <= 6) {
+      if (y_filter_taps <= 4) {
+        highbd_dist_wtd_convolve_y_4tap_neon(src + 2 * src_stride, src_stride,
+                                             im_block, im_stride, w, h,
+                                             y_filter_ptr, round_offset_conv);
+      } else if (y_filter_taps == 6) {
         highbd_dist_wtd_convolve_y_6tap_neon(src + src_stride, src_stride,
                                              im_block, im_stride, w, h,
                                              y_filter_ptr, round_offset_conv);
@@ -945,7 +1163,11 @@ void av1_highbd_dist_wtd_convolve_y_neon(
                              conv_params, bd);
       }
     } else {
-      if (y_filter_taps <= 6) {
+      if (y_filter_taps <= 4) {
+        highbd_dist_wtd_convolve_y_4tap_neon(src + 2 * src_stride, src_stride,
+                                             dst16, dst16_stride, w, h,
+                                             y_filter_ptr, round_offset_conv);
+      } else if (y_filter_taps == 6) {
         highbd_dist_wtd_convolve_y_6tap_neon(src + src_stride, src_stride,
                                              dst16, dst16_stride, w, h,
                                              y_filter_ptr, round_offset_conv);