aom: Improve AArch64 performance of highbd warp affine

From 5a44b2f044f22a64685f5e26a492c0b90268c4a3 Mon Sep 17 00:00:00 2001
From: George Steed <[EMAIL REDACTED]>
Date: Thu, 23 May 2024 14:39:59 +0100
Subject: [PATCH] Improve AArch64 performance of highbd warp affine

The common (non-boundary) code path for warp affine currently loads
two vectors of data and then uses the EXT instruction repeatedly to
construct a vector starting at each possible offset.
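
For reference, the pattern in question looks roughly like this
(simplified from the 8-wide filter; src stands for the row pointer
ref + iy * stride + ix4 used in the patch):

  uint16x8x2_t in = vld1q_u16_x2(src - 7);  // 16 contiguous pixels
  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 0);
  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
                            vreinterpretq_s16_u16(in.val[1]), 1);
  // ... and so on up to rv7.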

We can avoid the EXT instructions entirely by instead loading from
each offset directly, trading more loads for fewer vector
instructions. The uncommon boundary code path keeps the EXT
instructions as before.
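
With the change, each shifted window instead comes straight from an
overlapping load (again a simplified sketch):

  int16x8_t rv0 = vreinterpretq_s16_u16(vld1q_u16(src - 7));
  int16x8_t rv1 = vreinterpretq_s16_u16(vld1q_u16(src - 6));
  // ... and so on up to rv7, which loads from src - 0.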

On a Neoverse V2 machine with LLVM 17, this reduces the times reported
by the highbd warp affine speed tests by a geomean of ~4.2%.

Change-Id: Ic0e66b76daa44c3d31d7022e57734cbec5da1af4
---
 av1/common/arm/highbd_warp_plane_neon.c |  68 ++-------
 av1/common/arm/highbd_warp_plane_neon.h | 175 +++++++++++++++++-------
 av1/common/arm/highbd_warp_plane_sve.c  |  68 ++-------
 3 files changed, 144 insertions(+), 167 deletions(-)

diff --git a/av1/common/arm/highbd_warp_plane_neon.c b/av1/common/arm/highbd_warp_plane_neon.c
index 89647bc92..51bf142fe 100644
--- a/av1/common/arm/highbd_warp_plane_neon.c
+++ b/av1/common/arm/highbd_warp_plane_neon.c
@@ -24,19 +24,11 @@
 #include "highbd_warp_plane_neon.h"
 
 static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha) {
+highbd_horizontal_filter_4x1_f4(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
+                                int16x8_t rv3, int bd, int sx, int alpha) {
   int16x8_t f[4];
   load_filters_4(f, sx, alpha);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-
   int32x4_t m0 = vmull_s16(vget_low_s16(f[0]), vget_low_s16(rv0));
   m0 = vmlal_s16(m0, vget_high_s16(f[0]), vget_high_s16(rv0));
   int32x4_t m1 = vmull_s16(vget_low_s16(f[1]), vget_low_s16(rv1));
@@ -57,28 +49,12 @@ highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha) {
   return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
 }
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha) {
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f8(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
+    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx, int alpha) {
   int16x8_t f[8];
   load_filters_8(f, sx, alpha);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 4);
-  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 5);
-  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 6);
-  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 7);
-
   int32x4_t m0 = vmull_s16(vget_low_s16(f[0]), vget_low_s16(rv0));
   m0 = vmlal_s16(m0, vget_high_s16(f[0]), vget_high_s16(rv0));
   int32x4_t m1 = vmull_s16(vget_low_s16(f[1]), vget_low_s16(rv1));
@@ -112,18 +88,10 @@ highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha) {
 }
 
 static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx) {
+highbd_horizontal_filter_4x1_f1(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
+                                int16x8_t rv3, int bd, int sx) {
   int16x8_t f = load_filters_1(sx);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-
   int32x4_t m0 = vmull_s16(vget_low_s16(f), vget_low_s16(rv0));
   m0 = vmlal_s16(m0, vget_high_s16(f), vget_high_s16(rv0));
   int32x4_t m1 = vmull_s16(vget_low_s16(f), vget_low_s16(rv1));
@@ -144,27 +112,11 @@ highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx) {
   return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
 }
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_8x1_f1(uint16x8x2_t in, int bd, int sx) {
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f1(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
+    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx) {
   int16x8_t f = load_filters_1(sx);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 4);
-  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 5);
-  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 6);
-  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 7);
-
   int32x4_t m0 = vmull_s16(vget_low_s16(f), vget_low_s16(rv0));
   m0 = vmlal_s16(m0, vget_high_s16(f), vget_high_s16(rv0));
   int32x4_t m1 = vmull_s16(vget_low_s16(f), vget_low_s16(rv1));
diff --git a/av1/common/arm/highbd_warp_plane_neon.h b/av1/common/arm/highbd_warp_plane_neon.h
index 48af4a707..2ec45d1e0 100644
--- a/av1/common/arm/highbd_warp_plane_neon.h
+++ b/av1/common/arm/highbd_warp_plane_neon.h
@@ -24,16 +24,19 @@
 #include "config/av1_rtcd.h"
 
 static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha);
+highbd_horizontal_filter_4x1_f4(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
+                                int16x8_t rv3, int bd, int sx, int alpha);
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha);
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f8(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
+    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx, int alpha);
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx);
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_4x1_f1(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int bd, int sx);
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_8x1_f1(uint16x8x2_t in, int bd, int sx);
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f1(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
+    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx);
 
 static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f1(const int16x8_t *tmp,
                                                          int sy);
@@ -99,6 +102,29 @@ static AOM_FORCE_INLINE uint16x4_t clip_pixel_highbd_vec(int32x4_t val,
   return vqmovun_s32(vminq_s32(val, vdupq_n_s32(limit)));
 }
 
+static AOM_FORCE_INLINE uint16x8x2_t clamp_horizontal(
+    uint16x8x2_t src_1, int out_of_boundary_left, int out_of_boundary_right,
+    const uint16_t *ref, int iy, int stride, int width, const uint16x8_t indx0,
+    const uint16x8_t indx1) {
+  if (out_of_boundary_left >= 0) {
+    uint16x8_t cmp_vec = vdupq_n_u16(out_of_boundary_left);
+    uint16x8_t vec_dup = vdupq_n_u16(ref[iy * stride]);
+    uint16x8_t mask0 = vcleq_u16(indx0, cmp_vec);
+    uint16x8_t mask1 = vcleq_u16(indx1, cmp_vec);
+    src_1.val[0] = vbslq_u16(mask0, vec_dup, src_1.val[0]);
+    src_1.val[1] = vbslq_u16(mask1, vec_dup, src_1.val[1]);
+  }
+  if (out_of_boundary_right >= 0) {
+    uint16x8_t cmp_vec = vdupq_n_u16(15 - out_of_boundary_right);
+    uint16x8_t vec_dup = vdupq_n_u16(ref[iy * stride + width - 1]);
+    uint16x8_t mask0 = vcgeq_u16(indx0, cmp_vec);
+    uint16x8_t mask1 = vcgeq_u16(indx1, cmp_vec);
+    src_1.val[0] = vbslq_u16(mask0, vec_dup, src_1.val[0]);
+    src_1.val[1] = vbslq_u16(mask1, vec_dup, src_1.val[1]);
+  }
+  return src_1;
+}
+
 static AOM_FORCE_INLINE void warp_affine_horizontal(const uint16_t *ref,
                                                     int width, int height,
                                                     int stride, int p_width,
@@ -134,73 +160,120 @@ static AOM_FORCE_INLINE void warp_affine_horizontal(const uint16_t *ref,
   const int out_of_boundary_left = -(ix4 - 6);
   const int out_of_boundary_right = (ix4 + 8) - width;
 
-#define APPLY_HORIZONTAL_SHIFT(fn, ...)                                   \
-  do {                                                                    \
-    if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) {        \
-      for (int k = 0; k < 15; ++k) {                                      \
-        const int iy = clamp(iy4 + k - 7, 0, height - 1);                 \
-        uint16x8x2_t src_1 = vld1q_u16_x2(ref + iy * stride + ix4 - 7);   \
-                                                                          \
-        if (out_of_boundary_left >= 0) {                                  \
-          uint16x8_t cmp_vec = vdupq_n_u16(out_of_boundary_left);         \
-          uint16x8_t vec_dup = vdupq_n_u16(ref[iy * stride]);             \
-          uint16x8_t mask0 = vcleq_u16(indx0, cmp_vec);                   \
-          uint16x8_t mask1 = vcleq_u16(indx1, cmp_vec);                   \
-          src_1.val[0] = vbslq_u16(mask0, vec_dup, src_1.val[0]);         \
-          src_1.val[1] = vbslq_u16(mask1, vec_dup, src_1.val[1]);         \
-        }                                                                 \
-        if (out_of_boundary_right >= 0) {                                 \
-          uint16x8_t cmp_vec = vdupq_n_u16(15 - out_of_boundary_right);   \
-          uint16x8_t vec_dup = vdupq_n_u16(ref[iy * stride + width - 1]); \
-          uint16x8_t mask0 = vcgeq_u16(indx0, cmp_vec);                   \
-          uint16x8_t mask1 = vcgeq_u16(indx1, cmp_vec);                   \
-          src_1.val[0] = vbslq_u16(mask0, vec_dup, src_1.val[0]);         \
-          src_1.val[1] = vbslq_u16(mask1, vec_dup, src_1.val[1]);         \
-        }                                                                 \
-        tmp[k] = (fn)(src_1, __VA_ARGS__);                                \
-      }                                                                   \
-    } else {                                                              \
-      for (int k = 0; k < 15; ++k) {                                      \
-        const int iy = clamp(iy4 + k - 7, 0, height - 1);                 \
-        uint16x8x2_t src_1 = vld1q_u16_x2(ref + iy * stride + ix4 - 7);   \
-        tmp[k] = (fn)(src_1, __VA_ARGS__);                                \
-      }                                                                   \
-    }                                                                     \
+#define APPLY_HORIZONTAL_SHIFT_4X1(fn, ...)                                \
+  do {                                                                     \
+    if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) {         \
+      for (int k = 0; k < 15; ++k) {                                       \
+        const int iy = clamp(iy4 + k - 7, 0, height - 1);                  \
+        uint16x8x2_t src_1 = vld1q_u16_x2(ref + iy * stride + ix4 - 7);    \
+        src_1 = clamp_horizontal(src_1, out_of_boundary_left,              \
+                                 out_of_boundary_right, ref, iy, stride,   \
+                                 width, indx0, indx1);                     \
+        int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),     \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 0); \
+        int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),     \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 1); \
+        int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),     \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 2); \
+        int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),     \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 3); \
+        tmp[k] = (fn)(rv0, rv1, rv2, rv3, __VA_ARGS__);                    \
+      }                                                                    \
+    } else {                                                               \
+      for (int k = 0; k < 15; ++k) {                                       \
+        const int iy = clamp(iy4 + k - 7, 0, height - 1);                  \
+        const uint16_t *src = ref + iy * stride + ix4;                     \
+        int16x8_t rv0 = vreinterpretq_s16_u16(vld1q_u16(src - 7));         \
+        int16x8_t rv1 = vreinterpretq_s16_u16(vld1q_u16(src - 6));         \
+        int16x8_t rv2 = vreinterpretq_s16_u16(vld1q_u16(src - 5));         \
+        int16x8_t rv3 = vreinterpretq_s16_u16(vld1q_u16(src - 4));         \
+        tmp[k] = (fn)(rv0, rv1, rv2, rv3, __VA_ARGS__);                    \
+      }                                                                    \
+    }                                                                      \
+  } while (0)
+
+#define APPLY_HORIZONTAL_SHIFT_8X1(fn, ...)                                 \
+  do {                                                                      \
+    if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) {          \
+      for (int k = 0; k < 15; ++k) {                                        \
+        const int iy = clamp(iy4 + k - 7, 0, height - 1);                   \
+        uint16x8x2_t src_1 = vld1q_u16_x2(ref + iy * stride + ix4 - 7);     \
+        src_1 = clamp_horizontal(src_1, out_of_boundary_left,               \
+                                 out_of_boundary_right, ref, iy, stride,    \
+                                 width, indx0, indx1);                      \
+        int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 0);  \
+        int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 1);  \
+        int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 2);  \
+        int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 3);  \
+        int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 4);  \
+        int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 5);  \
+        int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 6);  \
+        int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(src_1.val[0]),      \
+                                  vreinterpretq_s16_u16(src_1.val[1]), 7);  \
+        tmp[k] = (fn)(rv0, rv1, rv2, rv3, rv4, rv5, rv6, rv7, __VA_ARGS__); \
+      }                                                                     \
+    } else {                                                                \
+      for (int k = 0; k < 15; ++k) {                                        \
+        const int iy = clamp(iy4 + k - 7, 0, height - 1);                   \
+        const uint16_t *src = ref + iy * stride + ix4;                      \
+        int16x8_t rv0 = vreinterpretq_s16_u16(vld1q_u16(src - 7));          \
+        int16x8_t rv1 = vreinterpretq_s16_u16(vld1q_u16(src - 6));          \
+        int16x8_t rv2 = vreinterpretq_s16_u16(vld1q_u16(src - 5));          \
+        int16x8_t rv3 = vreinterpretq_s16_u16(vld1q_u16(src - 4));          \
+        int16x8_t rv4 = vreinterpretq_s16_u16(vld1q_u16(src - 3));          \
+        int16x8_t rv5 = vreinterpretq_s16_u16(vld1q_u16(src - 2));          \
+        int16x8_t rv6 = vreinterpretq_s16_u16(vld1q_u16(src - 1));          \
+        int16x8_t rv7 = vreinterpretq_s16_u16(vld1q_u16(src - 0));          \
+        tmp[k] = (fn)(rv0, rv1, rv2, rv3, rv4, rv5, rv6, rv7, __VA_ARGS__); \
+      }                                                                     \
+    }                                                                       \
   } while (0)
 
   if (p_width == 4) {
     if (beta == 0) {
       if (alpha == 0) {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_4x1_f1, bd, sx4);
+        APPLY_HORIZONTAL_SHIFT_4X1(highbd_horizontal_filter_4x1_f1, bd, sx4);
       } else {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_4x1_f4, bd, sx4, alpha);
+        APPLY_HORIZONTAL_SHIFT_4X1(highbd_horizontal_filter_4x1_f4, bd, sx4,
+                                   alpha);
       }
     } else {
       if (alpha == 0) {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_4x1_f1, bd,
-                               (sx4 + beta * (k - 3)));
+        APPLY_HORIZONTAL_SHIFT_4X1(highbd_horizontal_filter_4x1_f1, bd,
+                                   (sx4 + beta * (k - 3)));
       } else {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_4x1_f4, bd,
-                               (sx4 + beta * (k - 3)), alpha);
+        APPLY_HORIZONTAL_SHIFT_4X1(highbd_horizontal_filter_4x1_f4, bd,
+                                   (sx4 + beta * (k - 3)), alpha);
       }
     }
   } else {
     if (beta == 0) {
       if (alpha == 0) {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_8x1_f1, bd, sx4);
+        APPLY_HORIZONTAL_SHIFT_8X1(highbd_horizontal_filter_8x1_f1, bd, sx4);
       } else {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_8x1_f8, bd, sx4, alpha);
+        APPLY_HORIZONTAL_SHIFT_8X1(highbd_horizontal_filter_8x1_f8, bd, sx4,
+                                   alpha);
       }
     } else {
       if (alpha == 0) {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_8x1_f1, bd,
-                               (sx4 + beta * (k - 3)));
+        APPLY_HORIZONTAL_SHIFT_8X1(highbd_horizontal_filter_8x1_f1, bd,
+                                   (sx4 + beta * (k - 3)));
       } else {
-        APPLY_HORIZONTAL_SHIFT(highbd_horizontal_filter_8x1_f8, bd,
-                               (sx4 + beta * (k - 3)), alpha);
+        APPLY_HORIZONTAL_SHIFT_8X1(highbd_horizontal_filter_8x1_f8, bd,
+                                   (sx4 + beta * (k - 3)), alpha);
       }
     }
   }
+
+#undef APPLY_HORIZONTAL_SHIFT_4X1
+#undef APPLY_HORIZONTAL_SHIFT_8X1
 }
 
 static AOM_FORCE_INLINE void highbd_vertical_filter_4x1_f4(
diff --git a/av1/common/arm/highbd_warp_plane_sve.c b/av1/common/arm/highbd_warp_plane_sve.c
index 87e033fd0..c2e1e995b 100644
--- a/av1/common/arm/highbd_warp_plane_sve.c
+++ b/av1/common/arm/highbd_warp_plane_sve.c
@@ -25,19 +25,11 @@
 #include "highbd_warp_plane_neon.h"
 
 static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha) {
+highbd_horizontal_filter_4x1_f4(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
+                                int16x8_t rv3, int bd, int sx, int alpha) {
   int16x8_t f[4];
   load_filters_4(f, sx, alpha);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-
   int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f[0]);
   int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f[1]);
   int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f[2]);
@@ -55,28 +47,12 @@ highbd_horizontal_filter_4x1_f4(uint16x8x2_t in, int bd, int sx, int alpha) {
   return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
 }
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha) {
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f8(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
+    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx, int alpha) {
   int16x8_t f[8];
   load_filters_8(f, sx, alpha);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 4);
-  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 5);
-  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 6);
-  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 7);
-
   int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f[0]);
   int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f[1]);
   int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f[2]);
@@ -104,18 +80,10 @@ highbd_horizontal_filter_8x1_f8(uint16x8x2_t in, int bd, int sx, int alpha) {
 }
 
 static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx) {
+highbd_horizontal_filter_4x1_f1(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
+                                int16x8_t rv3, int bd, int sx) {
   int16x8_t f = load_filters_1(sx);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-
   int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f);
   int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f);
   int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f);
@@ -133,27 +101,11 @@ highbd_horizontal_filter_4x1_f1(uint16x8x2_t in, int bd, int sx) {
   return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
 }
 
-static AOM_FORCE_INLINE int16x8_t
-highbd_horizontal_filter_8x1_f1(uint16x8x2_t in, int bd, int sx) {
+static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f1(
+    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
+    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx) {
   int16x8_t f = load_filters_1(sx);
 
-  int16x8_t rv0 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 0);
-  int16x8_t rv1 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 1);
-  int16x8_t rv2 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 2);
-  int16x8_t rv3 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 3);
-  int16x8_t rv4 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 4);
-  int16x8_t rv5 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 5);
-  int16x8_t rv6 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 6);
-  int16x8_t rv7 = vextq_s16(vreinterpretq_s16_u16(in.val[0]),
-                            vreinterpretq_s16_u16(in.val[1]), 7);
-
   int64x2_t m0 = aom_sdotq_s16(vdupq_n_s64(0), rv0, f);
   int64x2_t m1 = aom_sdotq_s16(vdupq_n_s64(0), rv1, f);
   int64x2_t m2 = aom_sdotq_s16(vdupq_n_s64(0), rv2, f);