aom: Arm SVE: move highbd max value clamp into convolution kernel

From a2d599c9750e3027d3104770fe74ff5d5d012c13 Mon Sep 17 00:00:00 2001
From: Jonathan Wright <[EMAIL REDACTED]>
Date: Thu, 15 Feb 2024 15:10:00 +0000
Subject: [PATCH] Arm SVE: move highbd max value clamp into convolution kernel

Move the highbd max value clamp into the horizontal convolution
kernels (highbd_convolve8_4_h and highbd_convolve8_8_h). This is a
purely cosmetic change with no effect on output.

Change-Id: I871edfc494d8788e952fa97cb94d9decb90fbefe
---
 aom_dsp/arm/highbd_convolve8_sve.c | 54 +++++++++++++-----------------
 1 file changed, 24 insertions(+), 30 deletions(-)

diff --git a/aom_dsp/arm/highbd_convolve8_sve.c b/aom_dsp/arm/highbd_convolve8_sve.c
index 8220a4fa1..b00f4d38f 100644
--- a/aom_dsp/arm/highbd_convolve8_sve.c
+++ b/aom_dsp/arm/highbd_convolve8_sve.c
@@ -19,8 +19,8 @@
 #include "aom_dsp/arm/dot_sve.h"
 #include "aom_dsp/arm/mem_neon.h"
 
-static INLINE uint16x4_t highbd_convolve8_4_h(int16x8_t s[4],
-                                              int16x8_t filter) {
+static INLINE uint16x4_t highbd_convolve8_4_h(int16x8_t s[4], int16x8_t filter,
+                                              uint16x4_t max) {
   int64x2_t sum[4];
 
   sum[0] = aom_sdotq_s16(vdupq_n_s64(0), s[0], filter);
@@ -31,13 +31,14 @@ static INLINE uint16x4_t highbd_convolve8_4_h(int16x8_t s[4],
   int64x2_t sum01 = vpaddq_s64(sum[0], sum[1]);
   int64x2_t sum23 = vpaddq_s64(sum[2], sum[3]);
 
-  int32x4_t res = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
 
-  return vqrshrun_n_s32(res, FILTER_BITS);
+  uint16x4_t res = vqrshrun_n_s32(sum0123, FILTER_BITS);
+  return vmin_u16(res, max);
 }
 
-static INLINE uint16x8_t highbd_convolve8_8_h(int16x8_t s[8],
-                                              int16x8_t filter) {
+static INLINE uint16x8_t highbd_convolve8_8_h(int16x8_t s[8], int16x8_t filter,
+                                              uint16x8_t max) {
   int64x2_t sum[8];
 
   sum[0] = aom_sdotq_s16(vdupq_n_s64(0), s[0], filter);
@@ -54,11 +55,12 @@ static INLINE uint16x8_t highbd_convolve8_8_h(int16x8_t s[8],
   int64x2_t sum45 = vpaddq_s64(sum[4], sum[5]);
   int64x2_t sum67 = vpaddq_s64(sum[6], sum[7]);
 
-  int32x4_t res0 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
-  int32x4_t res1 = vcombine_s32(vmovn_s64(sum45), vmovn_s64(sum67));
+  int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
+  int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum45), vmovn_s64(sum67));
 
-  return vcombine_u16(vqrshrun_n_s32(res0, FILTER_BITS),
-                      vqrshrun_n_s32(res1, FILTER_BITS));
+  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0123, FILTER_BITS),
+                                vqrshrun_n_s32(sum4567, FILTER_BITS));
+  return vminq_u16(res, max);
 }
 
 void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
@@ -76,10 +78,11 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
   uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
 
   src -= SUBPEL_TAPS / 2 - 1;
-  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+
   const int16x8_t filter = vld1q_s16(filter_x);
 
   if (width == 4) {
+    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
     const int16_t *s = (const int16_t *)src;
     uint16_t *d = dst;
 
@@ -90,15 +93,10 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
       load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
       load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
 
-      uint16x4_t d0 = highbd_convolve8_4_h(s0, filter);
-      uint16x4_t d1 = highbd_convolve8_4_h(s1, filter);
-      uint16x4_t d2 = highbd_convolve8_4_h(s2, filter);
-      uint16x4_t d3 = highbd_convolve8_4_h(s3, filter);
-
-      d0 = vmin_u16(d0, vget_low_u16(max));
-      d1 = vmin_u16(d1, vget_low_u16(max));
-      d2 = vmin_u16(d2, vget_low_u16(max));
-      d3 = vmin_u16(d3, vget_low_u16(max));
+      uint16x4_t d0 = highbd_convolve8_4_h(s0, filter, max);
+      uint16x4_t d1 = highbd_convolve8_4_h(s1, filter, max);
+      uint16x4_t d2 = highbd_convolve8_4_h(s2, filter, max);
+      uint16x4_t d3 = highbd_convolve8_4_h(s3, filter, max);
 
       store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
 
@@ -108,9 +106,10 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
     } while (height > 0);
   } else {
     do {
-      int w = width;
+      const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
       const int16_t *s = (const int16_t *)src;
       uint16_t *d = dst;
+      int w = width;
 
       do {
         int16x8_t s0[8], s1[8], s2[8], s3[8];
@@ -123,15 +122,10 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
         load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                      &s3[4], &s3[5], &s3[6], &s3[7]);
 
-        uint16x8_t d0 = highbd_convolve8_8_h(s0, filter);
-        uint16x8_t d1 = highbd_convolve8_8_h(s1, filter);
-        uint16x8_t d2 = highbd_convolve8_8_h(s2, filter);
-        uint16x8_t d3 = highbd_convolve8_8_h(s3, filter);
-
-        d0 = vminq_u16(d0, max);
-        d1 = vminq_u16(d1, max);
-        d2 = vminq_u16(d2, max);
-        d3 = vminq_u16(d3, max);
+        uint16x8_t d0 = highbd_convolve8_8_h(s0, filter, max);
+        uint16x8_t d1 = highbd_convolve8_8_h(s1, filter, max);
+        uint16x8_t d2 = highbd_convolve8_8_h(s2, filter, max);
+        uint16x8_t d3 = highbd_convolve8_8_h(s3, filter, max);
 
         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);