aom: Add 4-tap path for av1_convolve_2d_sr_neon horizontal pass

From c349b1ebd23dce55cf23c5f8fa56d5a9416971b0 Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 25 Apr 2024 15:16:17 +0100
Subject: [PATCH] Add 4-tap path for av1_convolve_2d_sr_neon horizontal pass

Add 4-tap specialization for the horizontal pass of
av1_convolve_2d_sr_neon. This gives up to 30% uplift over using the
8-tap path.

Change-Id: I6c9b7be0e90661a36cef95db1daedc2fabd6a31e
---
 aom_dsp/arm/mem_neon.h         |  22 ++
 av1/common/arm/convolve_neon.c | 363 ++++++++++++++++++++-------------
 2 files changed, 243 insertions(+), 142 deletions(-)
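
For reference, a scalar model of the 4-tap horizontal kernel that this patch
vectorizes is sketched below. It is illustrative only and not part of the
change: the helper name is invented, and it assumes the libaom constants
FILTER_BITS (7) and ROUND0_BITS (3) plus the convention that 4-tap filters
occupy taps 2..5 of the 8-tap coefficient array.

  static void convolve_horiz_4tap_scalar(const uint8_t *src, int16_t *dst,
                                         int w, const int16_t *filter_x) {
    const int bd = 8;
    // Halve the (even) filter values, as the NEON path does.
    const int16_t f[4] = { filter_x[2] >> 1, filter_x[3] >> 1,
                           filter_x[4] >> 1, filter_x[5] >> 1 };
    // horiz_const from the patch: an offset plus the rounding shim for the
    // final truncating shift.
    const int32_t horiz_const =
        (1 << (bd + FILTER_BITS - 2)) + (1 << ((ROUND0_BITS - 1) - 1));
    for (int x = 0; x < w; ++x) {
      // The NEON code keeps this sum in 16-bit lanes; halving makes it fit.
      int32_t sum = horiz_const;
      for (int k = 0; k < 4; ++k) sum += f[k] * src[x + k];
      // Truncating shift; the shim above makes it round to nearest.
      dst[x] = (int16_t)(sum >> (ROUND0_BITS - 1));
    }
  }

The NEON path computes eight such outputs per convolve4_8_2d_h call; the
uplift over the 8-tap path comes largely from halving the number of
multiply-accumulates per output pixel.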

diff --git a/aom_dsp/arm/mem_neon.h b/aom_dsp/arm/mem_neon.h
index b1f6ebeb1..46aa16e61 100644
--- a/aom_dsp/arm/mem_neon.h
+++ b/aom_dsp/arm/mem_neon.h
@@ -654,6 +654,13 @@ static INLINE void store_s16_8x4(int16_t *s, ptrdiff_t dst_stride,
   vst1q_s16(s, s3);
 }
 
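+// Store two rows of 8 16-bit elements, dst_stride elements apart.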
+static INLINE void store_s16_8x2(int16_t *s, ptrdiff_t dst_stride,
+                                 const int16x8_t s0, const int16x8_t s1) {
+  vst1q_s16(s, s0);
+  s += dst_stride;
+  vst1q_s16(s, s1);
+}
+
 static INLINE void load_u8_8x11(const uint8_t *s, ptrdiff_t p,
                                 uint8x8_t *const s0, uint8x8_t *const s1,
                                 uint8x8_t *const s2, uint8x8_t *const s3,
@@ -1248,6 +1255,12 @@ static INLINE uint8x8_t load_u8_gather_s16_x8(const uint8_t *src,
     memcpy(dst, &a, 8);                                            \
   } while (0)
 
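+// Store one 64-bit lane (4 16-bit elements) of a vector, selected by 'lane'.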
+#define store_s16_4x1_lane(dst, src, lane)                        \
+  do {                                                            \
+    int64_t a = vgetq_lane_s64(vreinterpretq_s64_s16(src), lane); \
+    memcpy(dst, &a, 8);                                           \
+  } while (0)
+
 // Store the low 16-bits from a single vector.
 static INLINE void store_u8_2x1(uint8_t *dst, const uint8x8_t src) {
   store_u8_2x1_lane(dst, src, 0);
@@ -1307,9 +1320,18 @@ static INLINE void store_u16x4_strided_x2(uint16_t *dst, uint32_t dst_stride,
   store_u16_4x1_lane(dst, src, 1);
 }
 
+// Store two blocks of 64 bits from a single vector.
+static INLINE void store_s16x4_strided_x2(int16_t *dst, int32_t dst_stride,
+                                          int16x8_t src) {
+  store_s16_4x1_lane(dst, src, 0);
+  dst += dst_stride;
+  store_s16_4x1_lane(dst, src, 1);
+}
+
 #undef store_u8_2x1_lane
 #undef store_u8_4x1_lane
 #undef store_u16_2x1_lane
 #undef store_u16_4x1_lane
+#undef store_s16_4x1_lane
 
 #endif  // AOM_AOM_DSP_ARM_MEM_NEON_H_
diff --git a/av1/common/arm/convolve_neon.c b/av1/common/arm/convolve_neon.c
index bd11b7cf2..72a85893e 100644
--- a/av1/common/arm/convolve_neon.c
+++ b/av1/common/arm/convolve_neon.c
@@ -1307,18 +1307,122 @@ static INLINE void convolve_2d_sr_horiz_12tap_neon(
   } while (--h != 0);
 }
 
-static INLINE int16x4_t convolve4_4_2d_h(const int16x4_t s0, const int16x4_t s1,
-                                         const int16x4_t s2, const int16x4_t s3,
+static INLINE int16x8_t convolve4_8_2d_h(const int16x8_t s0, const int16x8_t s1,
+                                         const int16x8_t s2, const int16x8_t s3,
                                          const int16x4_t filter,
-                                         const int16x4_t horiz_const) {
-  int16x4_t sum = horiz_const;
-  sum = vmla_lane_s16(sum, s0, filter, 0);
-  sum = vmla_lane_s16(sum, s1, filter, 1);
-  sum = vmla_lane_s16(sum, s2, filter, 2);
-  sum = vmla_lane_s16(sum, s3, filter, 3);
+                                         const int16x8_t horiz_const) {
+  int16x8_t sum = vmlaq_lane_s16(horiz_const, s0, filter, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter, 3);
+  // We halved the convolution filter values so -1 from the right shift.
+  return vshrq_n_s16(sum, ROUND0_BITS - 1);
+}
 
-  // We halved the convolution filter values so -1 from the right shift.
-  return vshr_n_s16(sum, ROUND0_BITS - 1);
+static INLINE void convolve_2d_sr_horiz_4tap_neon(
+    const uint8_t *src, ptrdiff_t src_stride, int16_t *dst,
+    ptrdiff_t dst_stride, int w, int h, const int16_t *filter_x) {
+  const int bd = 8;
+  // All filter values are even; halve them to reduce intermediate precision
+  // requirements.
+  const int16x4_t filter = vshr_n_s16(vld1_s16(filter_x + 2), 1);
+
+  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts, which are generally faster than rounding shifts on modern CPUs.
+  // (The extra -1 is needed because we halved the filter values.)
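+  // For example, with ROUND0_BITS == 3 the halved filters need a final shift
+  // of 2, so adding 1 << 1 here lets the truncating vshrq_n_s16(sum, 2) in
+  // convolve4_8_2d_h round the same way as vrshrq_n_s16(sum, 2) would.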
+  const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) +
+                                            (1 << ((ROUND0_BITS - 1) - 1)));
+
+  if (w == 4) {
+    do {
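+      // Each 8-byte load gathers 4 pixels from each of two consecutive rows,
+      // at the same horizontal offset, one offset per filter tap.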
+      uint8x8_t t01[4];
+      t01[0] = load_unaligned_u8(src + 0, (int)src_stride);
+      t01[1] = load_unaligned_u8(src + 1, (int)src_stride);
+      t01[2] = load_unaligned_u8(src + 2, (int)src_stride);
+      t01[3] = load_unaligned_u8(src + 3, (int)src_stride);
+
+      int16x8_t s01[4];
+      s01[0] = vreinterpretq_s16_u16(vmovl_u8(t01[0]));
+      s01[1] = vreinterpretq_s16_u16(vmovl_u8(t01[1]));
+      s01[2] = vreinterpretq_s16_u16(vmovl_u8(t01[2]));
+      s01[3] = vreinterpretq_s16_u16(vmovl_u8(t01[3]));
+
+      int16x8_t d01 =
+          convolve4_8_2d_h(s01[0], s01[1], s01[2], s01[3], filter, horiz_const);
+
+      store_s16x4_strided_x2(dst, (int)dst_stride, d01);
+
+      src += 2 * src_stride;
+      dst += 2 * dst_stride;
+      h -= 2;
+    } while (h > 0);
+  } else {
+    do {
+      int width = w;
+      const uint8_t *s = src;
+      int16_t *d = dst;
+
+      do {
+        uint8x8_t t0[4], t1[4];
+        load_u8_8x4(s + 0 * src_stride, 1, &t0[0], &t0[1], &t0[2], &t0[3]);
+        load_u8_8x4(s + 1 * src_stride, 1, &t1[0], &t1[1], &t1[2], &t1[3]);
+
+        int16x8_t s0[4];
+        s0[0] = vreinterpretq_s16_u16(vmovl_u8(t0[0]));
+        s0[1] = vreinterpretq_s16_u16(vmovl_u8(t0[1]));
+        s0[2] = vreinterpretq_s16_u16(vmovl_u8(t0[2]));
+        s0[3] = vreinterpretq_s16_u16(vmovl_u8(t0[3]));
+
+        int16x8_t s1[4];
+        s1[0] = vreinterpretq_s16_u16(vmovl_u8(t1[0]));
+        s1[1] = vreinterpretq_s16_u16(vmovl_u8(t1[1]));
+        s1[2] = vreinterpretq_s16_u16(vmovl_u8(t1[2]));
+        s1[3] = vreinterpretq_s16_u16(vmovl_u8(t1[3]));
+
+        int16x8_t d0 =
+            convolve4_8_2d_h(s0[0], s0[1], s0[2], s0[3], filter, horiz_const);
+        int16x8_t d1 =
+            convolve4_8_2d_h(s1[0], s1[1], s1[2], s1[3], filter, horiz_const);
+
+        store_s16_8x2(d, dst_stride, d0, d1);
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width != 0);
+      src += 2 * src_stride;
+      dst += 2 * dst_stride;
+      h -= 2;
+    } while (h > 2);
+
+    do {
+      const uint8_t *s = src;
+      int16_t *d = dst;
+      int width = w;
+
+      do {
+        uint8x8_t t0[4];
+        load_u8_8x4(s, 1, &t0[0], &t0[1], &t0[2], &t0[3]);
+
+        int16x8_t s0[4];
+        s0[0] = vreinterpretq_s16_u16(vmovl_u8(t0[0]));
+        s0[1] = vreinterpretq_s16_u16(vmovl_u8(t0[1]));
+        s0[2] = vreinterpretq_s16_u16(vmovl_u8(t0[2]));
+        s0[3] = vreinterpretq_s16_u16(vmovl_u8(t0[3]));
+
+        int16x8_t d0 =
+            convolve4_8_2d_h(s0[0], s0[1], s0[2], s0[3], filter, horiz_const);
+
+        vst1q_s16(d, d0);
+
+        s += 8;
+        d += 8;
+        width -= 8;
+      } while (width != 0);
+      src += src_stride;
+      dst += dst_stride;
+    } while (--h != 0);
+  }
 }
 
 static INLINE int16x8_t convolve8_8_2d_h(const int16x8_t s0, const int16x8_t s1,
@@ -1344,10 +1448,9 @@ static INLINE int16x8_t convolve8_8_2d_h(const int16x8_t s0, const int16x8_t s1,
   return vshrq_n_s16(sum, ROUND0_BITS - 1);
 }
 
-static INLINE void convolve_2d_sr_horiz_neon(const uint8_t *src, int src_stride,
-                                             int16_t *im_block, int im_stride,
-                                             int w, int im_h,
-                                             const int16_t *x_filter_ptr) {
+static INLINE void convolve_2d_sr_horiz_8tap_neon(
+    const uint8_t *src, int src_stride, int16_t *im_block, int im_stride, int w,
+    int im_h, const int16_t *x_filter_ptr) {
   const int bd = 8;
 
   const uint8_t *src_ptr = src;
@@ -1355,149 +1458,119 @@ static INLINE void convolve_2d_sr_horiz_neon(const uint8_t *src, int src_stride,
   int dst_stride = im_stride;
   int height = im_h;
 
-  if (w <= 4) {
-    // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
-    // shifts - which are generally faster than rounding shifts on modern CPUs.
-    // (The extra -1 is needed because we halved the filter values.)
-    const int16x4_t horiz_const = vdup_n_s16((1 << (bd + FILTER_BITS - 2)) +
-                                             (1 << ((ROUND0_BITS - 1) - 1)));
-    // 4-tap filters are used for blocks having width <= 4.
-    // Filter values are even, so halve to reduce intermediate precision reqs.
-    const int16x4_t x_filter = vshr_n_s16(vld1_s16(x_filter_ptr + 2), 1);
-
-    src_ptr += 2;
-
-    do {
-      uint8x8_t t0 = vld1_u8(src_ptr);  // a0 a1 a2 a3 a4 a5 a6 a7
-      int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
-      int16x4_t s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
-
-      int16x4_t s1 = vext_s16(s0, s4, 1);  // a1 a2 a3 a4
-      int16x4_t s2 = vext_s16(s0, s4, 2);  // a2 a3 a4 a5
-      int16x4_t s3 = vext_s16(s0, s4, 3);  // a3 a4 a5 a6
+  // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
+  // shifts, which are generally faster than rounding shifts on modern CPUs.
+  // (The extra -1 is needed because we halved the filter values.)
+  const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) +
+                                            (1 << ((ROUND0_BITS - 1) - 1)));
+  // Filter values are even, so halve to reduce intermediate precision reqs.
+  const int16x8_t x_filter = vshrq_n_s16(vld1q_s16(x_filter_ptr), 1);
 
-      int16x4_t d0 = convolve4_4_2d_h(s0, s1, s2, s3, x_filter, horiz_const);
+#if AOM_ARCH_AARCH64
+  while (height > 8) {
+    const uint8_t *s = src_ptr;
+    int16_t *d = dst_ptr;
+    int width = w;
 
-      vst1_s16(dst_ptr, d0);
+    uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+    load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+    transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
 
-      src_ptr += src_stride;
-      dst_ptr += dst_stride;
-    } while (--height != 0);
-  } else {
-    // A shim of 1 << ((ROUND0_BITS - 1) - 1) enables us to use non-rounding
-    // shifts - which are generally faster than rounding shifts on modern CPUs.
-    // (The extra -1 is needed because we halved the filter values.)
-    const int16x8_t horiz_const = vdupq_n_s16((1 << (bd + FILTER_BITS - 2)) +
-                                              (1 << ((ROUND0_BITS - 1) - 1)));
-    // Filter values are even, so halve to reduce intermediate precision reqs.
-    const int16x8_t x_filter = vshrq_n_s16(vld1q_s16(x_filter_ptr), 1);
+    int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+    int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+    int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+    int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+    int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+    int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
+    int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
 
-#if AOM_ARCH_AARCH64
-    while (height > 8) {
-      const uint8_t *s = src_ptr;
-      int16_t *d = dst_ptr;
-      int width = w;
+    s += 7;
 
-      uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+    do {
       load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
-      transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
 
-      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
-      int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
-      int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
-      int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
-      int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
-      int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
-      int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
+      transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
 
-      s += 7;
+      int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
+      int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
+      int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
+      int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
+      int16x8_t s11 = vreinterpretq_s16_u16(vmovl_u8(t4));
+      int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t5));
+      int16x8_t s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
+      int16x8_t s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
+
+      int16x8_t d0 = convolve8_8_2d_h(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                      horiz_const);
+      int16x8_t d1 = convolve8_8_2d_h(s1, s2, s3, s4, s5, s6, s7, s8, x_filter,
+                                      horiz_const);
+      int16x8_t d2 = convolve8_8_2d_h(s2, s3, s4, s5, s6, s7, s8, s9, x_filter,
+                                      horiz_const);
+      int16x8_t d3 = convolve8_8_2d_h(s3, s4, s5, s6, s7, s8, s9, s10, x_filter,
+                                      horiz_const);
+      int16x8_t d4 = convolve8_8_2d_h(s4, s5, s6, s7, s8, s9, s10, s11,
+                                      x_filter, horiz_const);
+      int16x8_t d5 = convolve8_8_2d_h(s5, s6, s7, s8, s9, s10, s11, s12,
+                                      x_filter, horiz_const);
+      int16x8_t d6 = convolve8_8_2d_h(s6, s7, s8, s9, s10, s11, s12, s13,
+                                      x_filter, horiz_const);
+      int16x8_t d7 = convolve8_8_2d_h(s7, s8, s9, s10, s11, s12, s13, s14,
+                                      x_filter, horiz_const);
+
+      transpose_elems_inplace_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
+
+      store_s16_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
 
-      do {
-        load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
-
-        transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
-
-        int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t0));
-        int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
-        int16x8_t s9 = vreinterpretq_s16_u16(vmovl_u8(t2));
-        int16x8_t s10 = vreinterpretq_s16_u16(vmovl_u8(t3));
-        int16x8_t s11 = vreinterpretq_s16_u16(vmovl_u8(t4));
-        int16x8_t s12 = vreinterpretq_s16_u16(vmovl_u8(t5));
-        int16x8_t s13 = vreinterpretq_s16_u16(vmovl_u8(t6));
-        int16x8_t s14 = vreinterpretq_s16_u16(vmovl_u8(t7));
-
-        int16x8_t d0 = convolve8_8_2d_h(s0, s1, s2, s3, s4, s5, s6, s7,
-                                        x_filter, horiz_const);
-        int16x8_t d1 = convolve8_8_2d_h(s1, s2, s3, s4, s5, s6, s7, s8,
-                                        x_filter, horiz_const);
-        int16x8_t d2 = convolve8_8_2d_h(s2, s3, s4, s5, s6, s7, s8, s9,
-                                        x_filter, horiz_const);
-        int16x8_t d3 = convolve8_8_2d_h(s3, s4, s5, s6, s7, s8, s9, s10,
-                                        x_filter, horiz_const);
-        int16x8_t d4 = convolve8_8_2d_h(s4, s5, s6, s7, s8, s9, s10, s11,
-                                        x_filter, horiz_const);
-        int16x8_t d5 = convolve8_8_2d_h(s5, s6, s7, s8, s9, s10, s11, s12,
-                                        x_filter, horiz_const);
-        int16x8_t d6 = convolve8_8_2d_h(s6, s7, s8, s9, s10, s11, s12, s13,
-                                        x_filter, horiz_const);
-        int16x8_t d7 = convolve8_8_2d_h(s7, s8, s9, s10, s11, s12, s13, s14,
-                                        x_filter, horiz_const);
-
-        transpose_elems_inplace_s16_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
-
-        store_s16_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
-
-        s0 = s8;
-        s1 = s9;
-        s2 = s10;
-        s3 = s11;
-        s4 = s12;
-        s5 = s13;
-        s6 = s14;
-        s += 8;
-        d += 8;
-        width -= 8;
-      } while (width != 0);
-      src_ptr += 8 * src_stride;
-      dst_ptr += 8 * dst_stride;
-      height -= 8;
-    }
+      s0 = s8;
+      s1 = s9;
+      s2 = s10;
+      s3 = s11;
+      s4 = s12;
+      s5 = s13;
+      s6 = s14;
+      s += 8;
+      d += 8;
+      width -= 8;
+    } while (width != 0);
+    src_ptr += 8 * src_stride;
+    dst_ptr += 8 * dst_stride;
+    height -= 8;
+  }
 #endif  // AOM_ARCH_AARCH64
 
-    do {
-      const uint8_t *s = src_ptr;
-      int16_t *d = dst_ptr;
-      int width = w;
+  do {
+    const uint8_t *s = src_ptr;
+    int16_t *d = dst_ptr;
+    int width = w;
 
-      uint8x8_t t0 = vld1_u8(s);  // a0 a1 a2 a3 a4 a5 a6 a7
-      int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+    uint8x8_t t0 = vld1_u8(s);  // a0 a1 a2 a3 a4 a5 a6 a7
+    int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
 
-      do {
-        uint8x8_t t1 = vld1_u8(s + 8);  // a8 a9 a10 a11 a12 a13 a14 a15
-        int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
+    do {
+      uint8x8_t t1 = vld1_u8(s + 8);  // a8 a9 a10 a11 a12 a13 a14 a15
+      int16x8_t s8 = vreinterpretq_s16_u16(vmovl_u8(t1));
 
-        int16x8_t s1 = vextq_s16(s0, s8, 1);  // a1 a2 a3 a4 a5 a6 a7 a8
-        int16x8_t s2 = vextq_s16(s0, s8, 2);  // a2 a3 a4 a5 a6 a7 a8 a9
-        int16x8_t s3 = vextq_s16(s0, s8, 3);  // a3 a4 a5 a6 a7 a8 a9 a10
-        int16x8_t s4 = vextq_s16(s0, s8, 4);  // a4 a5 a6 a7 a8 a9 a10 a11
-        int16x8_t s5 = vextq_s16(s0, s8, 5);  // a5 a6 a7 a8 a9 a10 a11 a12
-        int16x8_t s6 = vextq_s16(s0, s8, 6);  // a6 a7 a8 a9 a10 a11 a12 a13
-        int16x8_t s7 = vextq_s16(s0, s8, 7);  // a7 a8 a9 a10 a11 a12 a13 a14
+      int16x8_t s1 = vextq_s16(s0, s8, 1);  // a1 a2 a3 a4 a5 a6 a7 a8
+      int16x8_t s2 = vextq_s16(s0, s8, 2);  // a2 a3 a4 a5 a6 a7 a8 a9
+      int16x8_t s3 = vextq_s16(s0, s8, 3);  // a3 a4 a5 a6 a7 a8 a9 a10
+      int16x8_t s4 = vextq_s16(s0, s8, 4);  // a4 a5 a6 a7 a8 a9 a10 a11
+      int16x8_t s5 = vextq_s16(s0, s8, 5);  // a5 a6 a7 a8 a9 a10 a11 a12
+      int16x8_t s6 = vextq_s16(s0, s8, 6);  // a6 a7 a8 a9 a10 a11 a12 a13
+      int16x8_t s7 = vextq_s16(s0, s8, 7);  // a7 a8 a9 a10 a11 a12 a13 a14
 
-        int16x8_t d0 = convolve8_8_2d_h(s0, s1, s2, s3, s4, s5, s6, s7,
-                                        x_filter, horiz_const);
+      int16x8_t d0 = convolve8_8_2d_h(s0, s1, s2, s3, s4, s5, s6, s7, x_filter,
+                                      horiz_const);
 
-        vst1q_s16(d, d0);
+      vst1q_s16(d, d0);
 
-        s0 = s8;
-        s += 8;
-        d += 8;
-        width -= 8;
-      } while (width != 0);
-      src_ptr += src_stride;
-      dst_ptr += dst_stride;
-    } while (--height != 0);
-  }
+      s0 = s8;
+      s += 8;
+      d += 8;
+      width -= 8;
+    } while (width != 0);
+    src_ptr += src_stride;
+    dst_ptr += dst_stride;
+  } while (--height != 0);
 }
 
 void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
@@ -1514,6 +1587,7 @@ void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
   }
 
   const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
+  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
   const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
   const int im_h = h + clamped_y_taps - 1;
   const int im_stride = MAX_SB_SIZE;
@@ -1544,8 +1618,13 @@ void av1_convolve_2d_sr_neon(const uint8_t *src, int src_stride, uint8_t *dst,
     DECLARE_ALIGNED(16, int16_t,
                     im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);
 
-    convolve_2d_sr_horiz_neon(src_ptr, src_stride, im_block, im_stride, w, im_h,
-                              x_filter_ptr);
+    if (x_filter_taps <= 4) {
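+      // 4-tap filters only populate taps 2..5 of the 8-tap coefficient
+      // array, hence the src_ptr + 2 offset here and the filter_x + 2 load
+      // inside the 4-tap path.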
+      convolve_2d_sr_horiz_4tap_neon(src_ptr + 2, src_stride, im_block,
+                                     im_stride, w, im_h, x_filter_ptr);
+    } else {
+      convolve_2d_sr_horiz_8tap_neon(src_ptr, src_stride, im_block, im_stride,
+                                     w, im_h, x_filter_ptr);
+    }
 
     const int16x8_t y_filter = vld1q_s16(y_filter_ptr);