aom: Refactor and optimize aom_scaled_2d_neon

From 98cbdb3e0cc28ce66bf5e332e7d1c57d4dbca281 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Tue, 11 Jun 2024 00:15:12 +0100
Subject: [PATCH] Refactor and optimize aom_scaled_2d_neon

Tidy up the standard bitdepth Armv8.0 Neon implementation of
aom_scaled_2d. Also halve the filter values (they are all even, so
halving is exact) to avoid the need for saturating arithmetic in the
convolution kernels; the final narrowing shift is reduced by one to
compensate.
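
For reference, a minimal scalar sketch of the halved-filter trick
(illustrative only, not part of this patch; convolve8_px_halved is a
hypothetical helper). It assumes libaom's 8-tap, 7-bit kernels, whose
taps sum to 1 << FILTER_BITS and are all even:

  #include <stdint.h>

  #define FILTER_BITS 7  // as defined in aom_dsp/aom_filter.h

  // Scalar model of one output pixel. 's' points at the first of the 8
  // input pixels, 'filter' at the 8 original (un-halved) kernel taps.
  static inline uint8_t convolve8_px_halved(const uint8_t *s,
                                            const int16_t *filter) {
    int32_t sum = 0;
    for (int k = 0; k < 8; ++k) {
      // Halving (filter[k] / 2, matching vshrq_n_s16 by 1 on even taps)
      // is exact because every tap is even. With halved taps the exact
      // sum also fits in int16_t, which is what lets the Neon kernels
      // use plain vmlaq_s16 instead of saturating vqaddq_s16.
      sum += (filter[k] / 2) * s[k];
    }
    // Compensate for the halved taps: round and shift by FILTER_BITS - 1
    // instead of FILTER_BITS, then clamp to [0, 255]. This mirrors
    // vqrshrun_n_s16(sum, FILTER_BITS - 1).
    const int32_t res = (sum + (1 << (FILTER_BITS - 2))) >> (FILTER_BITS - 1);
    return (uint8_t)(res < 0 ? 0 : (res > 255 ? 255 : res));
  }

The shared Neon kernels convolve8_4() and convolve8_8() in
aom_convolve8_neon.h perform the same computation on 4 and 8 output
pixels at a time; since the halved-tap sums stay within int16_t range,
the saturating vqadd steps of the old scaled-convolve helpers are no
longer needed.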

Change-Id: I6485e609bf667f4517dc480470eca8b1025ac278
---
 aom_dsp/arm/aom_convolve8_neon.c        |  41 --
 aom_dsp/arm/aom_convolve8_neon.h        |  44 +-
 aom_dsp/arm/aom_scaled_convolve8_neon.c | 516 +++++++++++-------------
 3 files changed, 289 insertions(+), 312 deletions(-)

diff --git a/aom_dsp/arm/aom_convolve8_neon.c b/aom_dsp/arm/aom_convolve8_neon.c
index 0928b9327..d2f13ff13 100644
--- a/aom_dsp/arm/aom_convolve8_neon.c
+++ b/aom_dsp/arm/aom_convolve8_neon.c
@@ -26,47 +26,6 @@
 #include "aom_dsp/arm/transpose_neon.h"
 #include "aom_ports/mem.h"
 
-static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
-                                    const int16x4_t s2, const int16x4_t s3,
-                                    const int16x4_t s4, const int16x4_t s5,
-                                    const int16x4_t s6, const int16x4_t s7,
-                                    const int16x8_t filter) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-
-  int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0);
-  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
-  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
-  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
-
-  return sum;
-}
-
-static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
-                                    const int16x8_t s2, const int16x8_t s3,
-                                    const int16x8_t s4, const int16x8_t s5,
-                                    const int16x8_t s6, const int16x8_t s7,
-                                    const int16x8_t filter) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-
-  int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0);
-  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
-  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
-  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
-
-  // We halved the filter values so -1 from right shift.
-  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
-}
-
 static INLINE void convolve8_horiz_8tap_neon(const uint8_t *src,
                                              ptrdiff_t src_stride, uint8_t *dst,
                                              ptrdiff_t dst_stride,
diff --git a/aom_dsp/arm/aom_convolve8_neon.h b/aom_dsp/arm/aom_convolve8_neon.h
index 0b6e5245a..d1384a76e 100644
--- a/aom_dsp/arm/aom_convolve8_neon.h
+++ b/aom_dsp/arm/aom_convolve8_neon.h
@@ -14,8 +14,50 @@
 
 #include <arm_neon.h>
 
-#include "config/aom_config.h"
+#include "aom_dsp/aom_filter.h"
 #include "aom_dsp/arm/mem_neon.h"
+#include "config/aom_config.h"
+
+static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
+                                    const int16x4_t s2, const int16x4_t s3,
+                                    const int16x4_t s4, const int16x4_t s5,
+                                    const int16x4_t s6, const int16x4_t s7,
+                                    const int16x8_t filter) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+
+  int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0);
+  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmla_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmla_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
+
+  return sum;
+}
+
+static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
+                                    const int16x8_t s2, const int16x8_t s3,
+                                    const int16x8_t s4, const int16x8_t s5,
+                                    const int16x8_t s6, const int16x8_t s7,
+                                    const int16x8_t filter) {
+  const int16x4_t filter_lo = vget_low_s16(filter);
+  const int16x4_t filter_hi = vget_high_s16(filter);
+
+  int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0);
+  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
+  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
+  sum = vmlaq_lane_s16(sum, s3, filter_lo, 3);
+  sum = vmlaq_lane_s16(sum, s4, filter_hi, 0);
+  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
+  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
+  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
+
+  // We halved the filter values so -1 from right shift.
+  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
+}
 
 static INLINE void convolve8_horiz_2tap_neon(const uint8_t *src,
                                              ptrdiff_t src_stride, uint8_t *dst,
diff --git a/aom_dsp/arm/aom_scaled_convolve8_neon.c b/aom_dsp/arm/aom_scaled_convolve8_neon.c
index f81a06be9..3c11133b8 100644
--- a/aom_dsp/arm/aom_scaled_convolve8_neon.c
+++ b/aom_dsp/arm/aom_scaled_convolve8_neon.c
@@ -12,310 +12,294 @@
 #include <arm_neon.h>
 #include <assert.h>
 
+#include "aom_dsp/arm/aom_convolve8_neon.h"
 #include "aom_dsp/arm/mem_neon.h"
 #include "aom_dsp/arm/transpose_neon.h"
 #include "config/aom_dsp_rtcd.h"
 
-static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
-                                    const int16x4_t s2, const int16x4_t s3,
-                                    const int16x4_t s4, const int16x4_t s5,
-                                    const int16x4_t s6, const int16x4_t s7,
-                                    const int16x8_t filter) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-
-  int16x4_t sum = vmul_lane_s16(s0, filter_lo, 0);
-  sum = vmla_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmla_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmla_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmla_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmla_lane_s16(sum, s7, filter_hi, 3);
-  sum = vqadd_s16(sum, vmul_lane_s16(s3, filter_lo, 3));
-  sum = vqadd_s16(sum, vmul_lane_s16(s4, filter_hi, 0));
-  return sum;
-}
-
-static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
-                                    const int16x8_t s2, const int16x8_t s3,
-                                    const int16x8_t s4, const int16x8_t s5,
-                                    const int16x8_t s6, const int16x8_t s7,
-                                    const int16x8_t filter) {
-  const int16x4_t filter_lo = vget_low_s16(filter);
-  const int16x4_t filter_hi = vget_high_s16(filter);
-
-  int16x8_t sum = vmulq_lane_s16(s0, filter_lo, 0);
-  sum = vmlaq_lane_s16(sum, s1, filter_lo, 1);
-  sum = vmlaq_lane_s16(sum, s2, filter_lo, 2);
-  sum = vmlaq_lane_s16(sum, s5, filter_hi, 1);
-  sum = vmlaq_lane_s16(sum, s6, filter_hi, 2);
-  sum = vmlaq_lane_s16(sum, s7, filter_hi, 3);
-  sum = vqaddq_s16(sum, vmulq_lane_s16(s3, filter_lo, 3));
-  sum = vqaddq_s16(sum, vmulq_lane_s16(s4, filter_hi, 0));
-  return vqrshrun_n_s16(sum, 7);
-}
-
-static INLINE uint8x8_t scale_filter_8(const uint8x8_t *const s,
-                                       const int16x8_t filter) {
-  int16x8_t ss0 = vreinterpretq_s16_u16(vmovl_u8(s[0]));
-  int16x8_t ss1 = vreinterpretq_s16_u16(vmovl_u8(s[1]));
-  int16x8_t ss2 = vreinterpretq_s16_u16(vmovl_u8(s[2]));
-  int16x8_t ss3 = vreinterpretq_s16_u16(vmovl_u8(s[3]));
-  int16x8_t ss4 = vreinterpretq_s16_u16(vmovl_u8(s[4]));
-  int16x8_t ss5 = vreinterpretq_s16_u16(vmovl_u8(s[5]));
-  int16x8_t ss6 = vreinterpretq_s16_u16(vmovl_u8(s[6]));
-  int16x8_t ss7 = vreinterpretq_s16_u16(vmovl_u8(s[7]));
-
-  return convolve8_8(ss0, ss1, ss2, ss3, ss4, ss5, ss6, ss7, filter);
-}
-
-static INLINE void scaledconvolve_horiz_w4(
+static INLINE void scaled_convolve_horiz_neon(
     const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst,
-    const ptrdiff_t dst_stride, const InterpKernel *const x_filters,
-    const int x0_q4, const int x_step_q4, const int w, const int h) {
-  DECLARE_ALIGNED(16, uint8_t, temp[4 * 4]);
-  int x, y, z;
-
-  src -= SUBPEL_TAPS / 2 - 1;
+    const ptrdiff_t dst_stride, const InterpKernel *const x_filter,
+    const int x0_q4, const int x_step_q4, int w, int h) {
+  DECLARE_ALIGNED(16, uint8_t, temp[8 * 8]);
 
-  y = h;
-  do {
-    int x_q4 = x0_q4;
-    x = 0;
+  if (w == 4) {
     do {
-      // process 4 src_x steps
-      for (z = 0; z < 4; ++z) {
-        const uint8_t *const src_x = &src[x_q4 >> SUBPEL_BITS];
+      int x_q4 = x0_q4;
+
+      // Process a 4x4 tile.
+      for (int r = 0; r < 4; ++r) {
+        const uint8_t *s = &src[x_q4 >> SUBPEL_BITS];
+
         if (x_q4 & SUBPEL_MASK) {
-          const int16x8_t filters = vld1q_s16(x_filters[x_q4 & SUBPEL_MASK]);
-          uint8x8_t s[8], d;
-          int16x8_t ss[4];
-          int16x4_t t[8], tt;
-
-          load_u8_8x4(src_x, src_stride, &s[0], &s[1], &s[2], &s[3]);
-          transpose_elems_inplace_u8_8x4(&s[0], &s[1], &s[2], &s[3]);
-
-          ss[0] = vreinterpretq_s16_u16(vmovl_u8(s[0]));
-          ss[1] = vreinterpretq_s16_u16(vmovl_u8(s[1]));
-          ss[2] = vreinterpretq_s16_u16(vmovl_u8(s[2]));
-          ss[3] = vreinterpretq_s16_u16(vmovl_u8(s[3]));
-          t[0] = vget_low_s16(ss[0]);
-          t[1] = vget_low_s16(ss[1]);
-          t[2] = vget_low_s16(ss[2]);
-          t[3] = vget_low_s16(ss[3]);
-          t[4] = vget_high_s16(ss[0]);
-          t[5] = vget_high_s16(ss[1]);
-          t[6] = vget_high_s16(ss[2]);
-          t[7] = vget_high_s16(ss[3]);
-
-          tt = convolve8_4(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7],
-                           filters);
-          d = vqrshrun_n_s16(vcombine_s16(tt, tt), 7);
-          store_u8_4x1(&temp[4 * z], d);
+          // Halve filter values (all even) to avoid the need for saturating
+          // arithmetic in convolution kernels.
+          const int16x8_t filter =
+              vshrq_n_s16(vld1q_s16(x_filter[x_q4 & SUBPEL_MASK]), 1);
+
+          uint8x8_t t0, t1, t2, t3;
+          load_u8_8x4(s, src_stride, &t0, &t1, &t2, &t3);
+          transpose_elems_inplace_u8_8x4(&t0, &t1, &t2, &t3);
+
+          int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+          int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+          int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+          int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+          int16x4_t s4 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+          int16x4_t s5 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+          int16x4_t s6 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+          int16x4_t s7 = vget_high_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+
+          int16x4_t dd0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+          // We halved the filter values so -1 from right shift.
+          uint8x8_t d0 =
+              vqrshrun_n_s16(vcombine_s16(dd0, vdup_n_s16(0)), FILTER_BITS - 1);
+
+          store_u8_4x1(&temp[4 * r], d0);
         } else {
-          int i;
-          for (i = 0; i < 4; ++i) {
-            temp[z * 4 + i] = src_x[i * src_stride + 3];
+          // Memcpy for non-subpel locations.
+          s += SUBPEL_TAPS / 2 - 1;
+
+          for (int c = 0; c < 4; ++c) {
+            temp[r * 4 + c] = s[c * src_stride];
           }
         }
         x_q4 += x_step_q4;
       }
 
-      // transpose the 4x4 filters values back to dst
-      {
-        const uint8x8x4_t d4 = vld4_u8(temp);
-        store_u8_4x1(&dst[x + 0 * dst_stride], d4.val[0]);
-        store_u8_4x1(&dst[x + 1 * dst_stride], d4.val[1]);
-        store_u8_4x1(&dst[x + 2 * dst_stride], d4.val[2]);
-        store_u8_4x1(&dst[x + 3 * dst_stride], d4.val[3]);
-      }
-      x += 4;
-    } while (x < w);
+      // Transpose the 4x4 result tile and store.
+      uint8x8_t d01 = vld1_u8(temp + 0);
+      uint8x8_t d23 = vld1_u8(temp + 8);
 
-    src += src_stride * 4;
-    dst += dst_stride * 4;
-    y -= 4;
-  } while (y > 0);
-}
+      transpose_elems_inplace_u8_4x4(&d01, &d23);
 
-static INLINE void scaledconvolve_horiz_w8(
-    const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst,
-    const ptrdiff_t dst_stride, const InterpKernel *const x_filters,
-    const int x0_q4, const int x_step_q4, const int w, const int h) {
-  DECLARE_ALIGNED(16, uint8_t, temp[8 * 8]);
-  int x, y, z;
-  src -= SUBPEL_TAPS / 2 - 1;
+      store_u8x4_strided_x2(dst + 0 * dst_stride, 2 * dst_stride, d01);
+      store_u8x4_strided_x2(dst + 1 * dst_stride, 2 * dst_stride, d23);
 
-  // This function processes 8x8 areas. The intermediate height is not always
-  // a multiple of 8, so force it to be a multiple of 8 here.
-  y = (h + 7) & ~7;
+      src += 4 * src_stride;
+      dst += 4 * dst_stride;
+      h -= 4;
+    } while (h > 0);
+    return;
+  }
 
+  // w >= 8
   do {
     int x_q4 = x0_q4;
-    x = 0;
+    uint8_t *d = dst;
+    int width = w;
+
     do {
-      uint8x8_t d[8];
-      // process 8 src_x steps
-      for (z = 0; z < 8; ++z) {
-        const uint8_t *const src_x = &src[x_q4 >> SUBPEL_BITS];
+      // Process an 8x8 tile.
+      for (int r = 0; r < 8; ++r) {
+        const uint8_t *s = &src[x_q4 >> SUBPEL_BITS];
 
         if (x_q4 & SUBPEL_MASK) {
-          const int16x8_t filters = vld1q_s16(x_filters[x_q4 & SUBPEL_MASK]);
-          uint8x8_t s[8];
-          load_u8_8x8(src_x, src_stride, &s[0], &s[1], &s[2], &s[3], &s[4],
-                      &s[5], &s[6], &s[7]);
-          transpose_elems_inplace_u8_8x8(&s[0], &s[1], &s[2], &s[3], &s[4],
-                                         &s[5], &s[6], &s[7]);
-          d[0] = scale_filter_8(s, filters);
-          vst1_u8(&temp[8 * z], d[0]);
+          // Halve filter values (all even) to avoid the need for saturating
+          // arithmetic in convolution kernels.
+          const int16x8_t filter =
+              vshrq_n_s16(vld1q_s16(x_filter[x_q4 & SUBPEL_MASK]), 1);
+
+          uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+          load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+          transpose_elems_inplace_u8_8x8(&t0, &t1, &t2, &t3, &t4, &t5, &t6,
+                                         &t7);
+
+          int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+          int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+          int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+          int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+          int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+          int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
+          int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
+          int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t7));
+
+          uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+
+          vst1_u8(&temp[r * 8], d0);
         } else {
-          int i;
-          for (i = 0; i < 8; ++i) {
-            temp[z * 8 + i] = src_x[i * src_stride + 3];
+          // Memcpy for non-subpel locations.
+          s += SUBPEL_TAPS / 2 - 1;
+
+          for (int c = 0; c < 8; ++c) {
+            temp[r * 8 + c] = s[c * src_stride];
           }
         }
         x_q4 += x_step_q4;
       }
 
-      // transpose the 8x8 filters values back to dst
-      load_u8_8x8(temp, 8, &d[0], &d[1], &d[2], &d[3], &d[4], &d[5], &d[6],
-                  &d[7]);
-      transpose_elems_inplace_u8_8x8(&d[0], &d[1], &d[2], &d[3], &d[4], &d[5],
-                                     &d[6], &d[7]);
-      store_u8_8x8(dst + x, dst_stride, d[0], d[1], d[2], d[3], d[4], d[5],
-                   d[6], d[7]);
-      x += 8;
-    } while (x < w);
-
-    src += src_stride * 8;
-    dst += dst_stride * 8;
-  } while (y -= 8);
-}
+      // Transpose the 8x8 result tile and store.
+      uint8x8_t d0, d1, d2, d3, d4, d5, d6, d7;
+      load_u8_8x8(temp, 8, &d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
 
-static INLINE void scaledconvolve_vert_w4(
-    const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst,
-    const ptrdiff_t dst_stride, const InterpKernel *const y_filters,
-    const int y0_q4, const int y_step_q4, const int w, const int h) {
-  int y;
-  int y_q4 = y0_q4;
+      transpose_elems_inplace_u8_8x8(&d0, &d1, &d2, &d3, &d4, &d5, &d6, &d7);
 
-  src -= src_stride * (SUBPEL_TAPS / 2 - 1);
-  y = h;
-  do {
-    const unsigned char *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
+      store_u8_8x8(d, dst_stride, d0, d1, d2, d3, d4, d5, d6, d7);
 
-    if (y_q4 & SUBPEL_MASK) {
-      const int16x8_t filters = vld1q_s16(y_filters[y_q4 & SUBPEL_MASK]);
-      uint8x8_t s[8], d;
-      int16x4_t t[8], tt;
-
-      load_u8_8x8(src_y, src_stride, &s[0], &s[1], &s[2], &s[3], &s[4], &s[5],
-                  &s[6], &s[7]);
-      t[0] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[0])));
-      t[1] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[1])));
-      t[2] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[2])));
-      t[3] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[3])));
-      t[4] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[4])));
-      t[5] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[5])));
-      t[6] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[6])));
-      t[7] = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(s[7])));
-
-      tt = convolve8_4(t[0], t[1], t[2], t[3], t[4], t[5], t[6], t[7], filters);
-      d = vqrshrun_n_s16(vcombine_s16(tt, tt), 7);
-      store_u8_4x1(dst, d);
-    } else {
-      memcpy(dst, &src_y[3 * src_stride], w);
-    }
+      d += 8;
+      width -= 8;
+    } while (width != 0);
 
-    dst += dst_stride;
-    y_q4 += y_step_q4;
-  } while (--y);
+    src += 8 * src_stride;
+    dst += 8 * dst_stride;
+    h -= 8;
+  } while (h > 0);
 }
 
-static INLINE void scaledconvolve_vert_w8(
+static INLINE void scaled_convolve_vert_neon(
     const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst,
-    const ptrdiff_t dst_stride, const InterpKernel *const y_filters,
-    const int y0_q4, const int y_step_q4, const int w, const int h) {
-  int y;
+    const ptrdiff_t dst_stride, const InterpKernel *const y_filter,
+    const int y0_q4, const int y_step_q4, int w, int h) {
   int y_q4 = y0_q4;
 
-  src -= src_stride * (SUBPEL_TAPS / 2 - 1);
-  y = h;
-  do {
-    const unsigned char *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
-    if (y_q4 & SUBPEL_MASK) {
-      const int16x8_t filters = vld1q_s16(y_filters[y_q4 & SUBPEL_MASK]);
-      uint8x8_t s[8], d;
-      load_u8_8x8(src_y, src_stride, &s[0], &s[1], &s[2], &s[3], &s[4], &s[5],
-                  &s[6], &s[7]);
-      d = scale_filter_8(s, filters);
-      vst1_u8(dst, d);
-    } else {
-      memcpy(dst, &src_y[3 * src_stride], w);
-    }
-    dst += dst_stride;
-    y_q4 += y_step_q4;
-  } while (--y);
-}
+  if (w == 4) {
+    do {
+      const uint8_t *s = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
+
+      if (y_q4 & SUBPEL_MASK) {
+        // Halve filter values (all even) to avoid the need for saturating
+        // arithmetic in convolution kernels.
+        const int16x8_t filter =
+            vshrq_n_s16(vld1q_s16(y_filter[y_q4 & SUBPEL_MASK]), 1);
+
+        uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+        load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+
+        int16x4_t s0 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t0)));
+        int16x4_t s1 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t1)));
+        int16x4_t s2 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t2)));
+        int16x4_t s3 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t3)));
+        int16x4_t s4 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t4)));
+        int16x4_t s5 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t5)));
+        int16x4_t s6 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t6)));
+        int16x4_t s7 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(t7)));
+
+        int16x4_t dd0 = convolve8_4(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+        // We halved the filter values so -1 from right shift.
+        uint8x8_t d0 =
+            vqrshrun_n_s16(vcombine_s16(dd0, vdup_n_s16(0)), FILTER_BITS - 1);
+
+        store_u8_4x1(dst, d0);
+      } else {
+        // Memcpy for non-subpel locations.
+        memcpy(dst, &s[(SUBPEL_TAPS / 2 - 1) * src_stride], 4);
+      }
 
-static INLINE void scaledconvolve_vert_w16(
-    const uint8_t *src, const ptrdiff_t src_stride, uint8_t *dst,
-    const ptrdiff_t dst_stride, const InterpKernel *const y_filters,
-    const int y0_q4, const int y_step_q4, const int w, const int h) {
-  int x, y;
-  int y_q4 = y0_q4;
+      y_q4 += y_step_q4;
+      dst += dst_stride;
+    } while (--h != 0);
+    return;
+  }
 
-  src -= src_stride * (SUBPEL_TAPS / 2 - 1);
-  y = h;
+  if (w == 8) {
+    do {
+      const uint8_t *s = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
+
+      if (y_q4 & SUBPEL_MASK) {
+        // Halve filter values (all even) to avoid the need for saturating
+        // arithmetic in convolution kernels.
+        const int16x8_t filter =
+            vshrq_n_s16(vld1q_s16(y_filter[y_q4 & SUBPEL_MASK]), 1);
+
+        uint8x8_t t0, t1, t2, t3, t4, t5, t6, t7;
+        load_u8_8x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+
+        int16x8_t s0 = vreinterpretq_s16_u16(vmovl_u8(t0));
+        int16x8_t s1 = vreinterpretq_s16_u16(vmovl_u8(t1));
+        int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));
+        int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));
+        int16x8_t s4 = vreinterpretq_s16_u16(vmovl_u8(t4));
+        int16x8_t s5 = vreinterpretq_s16_u16(vmovl_u8(t5));
+        int16x8_t s6 = vreinterpretq_s16_u16(vmovl_u8(t6));
+        int16x8_t s7 = vreinterpretq_s16_u16(vmovl_u8(t7));
+
+        uint8x8_t d0 = convolve8_8(s0, s1, s2, s3, s4, s5, s6, s7, filter);
+
+        vst1_u8(dst, d0);
+      } else {
+        // Memcpy for non-subpel locations.
+        memcpy(dst, &s[(SUBPEL_TAPS / 2 - 1) * src_stride], 8);
+      }
+
+      y_q4 += y_step_q4;
+      dst += dst_stride;
+    } while (--h != 0);
+    return;
+  }
+
+  // w >= 16
   do {
-    const unsigned char *src_y = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
+    const uint8_t *s = &src[(y_q4 >> SUBPEL_BITS) * src_stride];
+    uint8_t *d = dst;
+    int width = w;
+
     if (y_q4 & SUBPEL_MASK) {
-      x = 0;
       do {
-        const int16x8_t filters = vld1q_s16(y_filters[y_q4 & SUBPEL_MASK]);
-        uint8x16_t ss[8];
-        uint8x8_t s[8], d[2];
-        load_u8_16x8(src_y, src_stride, &ss[0], &ss[1], &ss[2], &ss[3], &ss[4],
-                     &ss[5], &ss[6], &ss[7]);
-        s[0] = vget_low_u8(ss[0]);
-        s[1] = vget_low_u8(ss[1]);
-        s[2] = vget_low_u8(ss[2]);
-        s[3] = vget_low_u8(ss[3]);
-        s[4] = vget_low_u8(ss[4]);
-        s[5] = vget_low_u8(ss[5]);
-        s[6] = vget_low_u8(ss[6]);
-        s[7] = vget_low_u8(ss[7]);
-        d[0] = scale_filter_8(s, filters);
-
-        s[0] = vget_high_u8(ss[0]);
-        s[1] = vget_high_u8(ss[1]);
-        s[2] = vget_high_u8(ss[2]);
-        s[3] = vget_high_u8(ss[3]);
-        s[4] = vget_high_u8(ss[4]);
-        s[5] = vget_high_u8(ss[5]);
-        s[6] = vget_high_u8(ss[6]);
-        s[7] = vget_high_u8(ss[7]);
-        d[1] = scale_filter_8(s, filters);
-        vst1q_u8(&dst[x], vcombine_u8(d[0], d[1]));
-        src_y += 16;
-        x += 16;
-      } while (x < w);
+        // Halve filter values (all even) to avoid the need for saturating
+        // arithmetic in convolution kernels.
+        const int16x8_t filter =
+            vshrq_n_s16(vld1q_s16(y_filter[y_q4 & SUBPEL_MASK]), 1);
+
+        uint8x16_t t0, t1, t2, t3, t4, t5, t6, t7;
+        load_u8_16x8(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6, &t7);
+
+        int16x8_t s0[2], s1[2], s2[2], s3[2], s4[2], s5[2], s6[2], s7[2];
+        s0[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t0)));
+        s1[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t1)));
+        s2[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t2)));
+        s3[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t3)));
+        s4[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t4)));
+        s5[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t5)));
+        s6[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t6)));
+        s7[0] = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(t7)));
+
+        s0[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t0)));
+        s1[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t1)));
+        s2[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t2)));
+        s3[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t3)));
+        s4[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t4)));
+        s5[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t5)));
+        s6[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t6)));
+        s7[1] = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(t7)));
+
+        uint8x8_t d0 = convolve8_8(s0[0], s1[0], s2[0], s3[0], s4[0], s5[0],
+                                   s6[0], s7[0], filter);
+        uint8x8_t d1 = convolve8_8(s0[1], s1[1], s2[1], s3[1], s4[1], s5[1],
+                                   s6[1], s7[1], filter);
+
+        vst1q_u8(d, vcombine_u8(d0, d1));
+
+        s += 16;
+        d += 16;
+        width -= 16;
+      } while (width != 0);
     } else {
-      memcpy(dst, &src_y[3 * src_stride], w);
+      // Memcpy for non-subpel locations.
+      s += (SUBPEL_TAPS / 2 - 1) * src_stride;
+
+      do {
+        uint8x16_t s0 = vld1q_u8(s);
+        vst1q_u8(d, s0);
+        s += 16;
+        d += 16;
+        width -= 16;
+      } while (width != 0);
     }
-    dst += dst_stride;
+
     y_q4 += y_step_q4;
-  } while (--y);
+    dst += dst_stride;
+  } while (--h != 0);
 }
 
 void aom_scaled_2d_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
                         ptrdiff_t dst_stride, const InterpKernel *filter,
                         int x0_q4, int x_step_q4, int y0_q4, int y_step_q4,
                         int w, int h) {
-  // Note: Fixed size intermediate buffer, temp, places limits on parameters.
+  // Fixed size intermediate buffer, im_block, places limits on parameters.
   // 2d filtering proceeds in 2 steps:
   //   (1) Interpolate horizontally into an intermediate buffer, temp.
   //   (2) Interpolate temp vertically to derive the sub-pixel result.
-  // Deriving the maximum number of rows in the temp buffer (135):
+  // Deriving the maximum number of rows in the im_block buffer (135):
   // --Smallest scaling factor is x1/2 ==> y_step_q4 = 32 (Normative).
   // --Largest block size is 64x64 pixels.
   // --64 rows in the downscaled frame span a distance of (64 - 1) * 32 in the
@@ -327,33 +311,25 @@ void aom_scaled_2d_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
   // When calling in frame scaling function, the smallest scaling factor is x1/4
   // ==> y_step_q4 = 64. Since w and h are at most 16, the temp buffer is still
   // big enough.
-  DECLARE_ALIGNED(16, uint8_t, temp[(135 + 8) * 64]);
-  const int intermediate_height =
+  DECLARE_ALIGNED(16, uint8_t, im_block[(135 + 8) * 64]);
+  const int im_height =
       (((h - 1) * y_step_q4 + y0_q4) >> SUBPEL_BITS) + SUBPEL_TAPS;
+  const ptrdiff_t im_stride = 64;
 
   assert(w <= 64);
   assert(h <= 64);
   assert(y_step_q4 <= 32 || (y_step_q4 <= 64 && h <= 32));
   assert(x_step_q4 <= 64);
 
-  if (w >= 8) {
-    scaledconvolve_horiz_w8(src - src_stride * (SUBPEL_TAPS / 2 - 1),
-                            src_stride, temp, 64, filter, x0_q4, x_step_q4, w,
-                            intermediate_height);
-  } else {
-    scaledconvolve_horiz_w4(src - src_stride * (SUBPEL_TAPS / 2 - 1),
-                            src_stride, temp, 64, filter, x0_q4, x_step_q4, w,
-                            intermediate_height);
-  }
+  // Account for needing SUBPEL_TAPS / 2 - 1 lines prior and SUBPEL_TAPS / 2
+  // lines post both horizontally and vertically.
+  const ptrdiff_t horiz_offset = SUBPEL_TAPS / 2 - 1;
+  const ptrdiff_t vert_offset = (SUBPEL_TAPS / 2 - 1) * src_stride;
 
-  if (w >= 16) {
-    scaledconvolve_vert_w16(temp + 64 * (SUBPEL_TAPS / 2 - 1), 64, dst,
-                            dst_stride, filter, y0_q4, y_step_q4, w, h);
-  } else if (w == 8) {
-    scaledconvolve_vert_w8(temp + 64 * (SUBPEL_TAPS / 2 - 1), 64, dst,
-                           dst_stride, filter, y0_q4, y_step_q4, w, h);
-  } else {
-    scaledconvolve_vert_w4(temp + 64 * (SUBPEL_TAPS / 2 - 1), 64, dst,
-                           dst_stride, filter, y0_q4, y_step_q4, w, h);
-  }
+  scaled_convolve_horiz_neon(src - horiz_offset - vert_offset, src_stride,
+                             im_block, im_stride, filter, x0_q4, x_step_q4, w,
+                             im_height);
+
+  scaled_convolve_vert_neon(im_block, im_stride, dst, dst_stride, filter, y0_q4,
+                            y_step_q4, w, h);
 }