aom: Propagate offset constant into Neon averaging helper functions

From aaac42ee3a62831f2f00190e9578f0840b0001bd Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Thu, 14 Mar 2024 11:26:34 +0000
Subject: [PATCH] Propagate offset constant into Neon averaging helper
 functions

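Averaging helper functions for high bitdepth compound convolutions
already have implementations specialized on bitdepth, but the averaging
offset was still computed in the calling convolution functions. Move the
computation inside the averaging functions, allowing some terms to
become known at compile time. This is valid because, for compound
prediction, round_1 is always COMPOUND_ROUND1_BITS and round_0 is
ROUND0_BITS (plus 2 for 12-bit), so the offset depends only on bd.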

Change-Id: I8ba8cb1bb0c68ea9321d56e4da9027a21243efda
---
 .../arm/highbd_compound_convolve_neon.c       | 43 +++++++------------
 .../arm/highbd_compound_convolve_neon.h       | 33 ++++++++++----
 .../arm/highbd_compound_convolve_sve2.c       | 11 ++---
 3 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/av1/common/arm/highbd_compound_convolve_neon.c b/av1/common/arm/highbd_compound_convolve_neon.c
index 05773393d..c93a1d4e2 100644
--- a/av1/common/arm/highbd_compound_convolve_neon.c
+++ b/av1/common/arm/highbd_compound_convolve_neon.c
@@ -486,9 +486,6 @@ void av1_highbd_dist_wtd_convolve_x_neon(
   const int im_stride = MAX_SB_SIZE;
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int offset_avg = (1 << (offset_bits - conv_params->round_1)) +
-                         (1 << (offset_bits - conv_params->round_1 - 1));
   const int offset_convolve = (1 << (conv_params->round_0 - 1)) +
                               (1 << (bd + FILTER_BITS)) +
                               (1 << (bd + FILTER_BITS - 1));
@@ -511,10 +508,10 @@ void av1_highbd_dist_wtd_convolve_x_neon(
       }
       if (conv_params->use_dist_wtd_comp_avg) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
-                                         w, h, conv_params, offset_avg, bd);
+                                         w, h, conv_params);
       } else {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                                conv_params, offset_avg, bd);
+                                conv_params);
       }
     } else {
       if (x_filter_taps <= 6 && w != 4) {
@@ -538,10 +535,10 @@ void av1_highbd_dist_wtd_convolve_x_neon(
       }
       if (conv_params->use_dist_wtd_comp_avg) {
         highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
-                                      h, conv_params, offset_avg, bd);
+                                      h, conv_params, bd);
       } else {
         highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                             conv_params, offset_avg, bd);
+                             conv_params, bd);
       }
     } else {
       if (x_filter_taps <= 6 && w != 4) {
@@ -891,9 +888,6 @@ void av1_highbd_dist_wtd_convolve_y_neon(
   const int im_stride = MAX_SB_SIZE;
   const int vert_offset = filter_params_y->taps / 2 - 1;
   assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int round_offset_avg = (1 << (offset_bits - conv_params->round_1)) +
-                               (1 << (offset_bits - conv_params->round_1 - 1));
   const int round_offset_conv = (1 << (conv_params->round_0 - 1)) +
                                 (1 << (bd + FILTER_BITS)) +
                                 (1 << (bd + FILTER_BITS - 1));
@@ -916,11 +910,10 @@ void av1_highbd_dist_wtd_convolve_y_neon(
       }
       if (conv_params->use_dist_wtd_comp_avg) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
-                                         w, h, conv_params, round_offset_avg,
-                                         bd);
+                                         w, h, conv_params);
       } else {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                                conv_params, round_offset_avg, bd);
+                                conv_params);
       }
     } else {
       if (y_filter_taps <= 6) {
@@ -946,10 +939,10 @@ void av1_highbd_dist_wtd_convolve_y_neon(
       }
       if (conv_params->use_dist_wtd_comp_avg) {
         highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
-                                      h, conv_params, round_offset_avg, bd);
+                                      h, conv_params, bd);
       } else {
         highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                             conv_params, round_offset_avg, bd);
+                             conv_params, bd);
       }
     } else {
       if (y_filter_taps <= 6) {
@@ -1028,18 +1021,18 @@ void av1_highbd_dist_wtd_convolve_2d_copy_neon(const uint16_t *src,
     if (conv_params->use_dist_wtd_comp_avg) {
       if (bd == 12) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
-                                         w, h, conv_params, round_offset, bd);
+                                         w, h, conv_params);
       } else {
         highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
-                                      h, conv_params, round_offset, bd);
+                                      h, conv_params, bd);
       }
     } else {
       if (bd == 12) {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                                conv_params, round_offset, bd);
+                                conv_params);
       } else {
         highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                             conv_params, round_offset, bd);
+                             conv_params, bd);
       }
     }
   }
@@ -1692,9 +1685,6 @@ void av1_highbd_dist_wtd_convolve_2d_neon(
       (1 << (bd + FILTER_BITS - 1)) + (1 << (conv_params->round_0 - 1));
   const int y_offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
   const int round_offset_conv_y = (1 << y_offset_bits);
-  const int round_offset_avg =
-      ((1 << (y_offset_bits - conv_params->round_1)) +
-       (1 << (y_offset_bits - conv_params->round_1 - 1)));
 
   const uint16_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
 
@@ -1755,19 +1745,18 @@ void av1_highbd_dist_wtd_convolve_2d_neon(
     if (conv_params->use_dist_wtd_comp_avg) {
       if (bd == 12) {
         highbd_12_dist_wtd_comp_avg_neon(im_block2, im_stride, dst, dst_stride,
-                                         w, h, conv_params, round_offset_avg,
-                                         bd);
+                                         w, h, conv_params);
       } else {
         highbd_dist_wtd_comp_avg_neon(im_block2, im_stride, dst, dst_stride, w,
-                                      h, conv_params, round_offset_avg, bd);
+                                      h, conv_params, bd);
       }
     } else {
       if (bd == 12) {
         highbd_12_comp_avg_neon(im_block2, im_stride, dst, dst_stride, w, h,
-                                conv_params, round_offset_avg, bd);
+                                conv_params);
       } else {
         highbd_comp_avg_neon(im_block2, im_stride, dst, dst_stride, w, h,
-                             conv_params, round_offset_avg, bd);
+                             conv_params, bd);
       }
     }
   }
diff --git a/av1/common/arm/highbd_compound_convolve_neon.h b/av1/common/arm/highbd_compound_convolve_neon.h
index efe70440f..c9344f3ad 100644
--- a/av1/common/arm/highbd_compound_convolve_neon.h
+++ b/av1/common/arm/highbd_compound_convolve_neon.h
@@ -24,12 +24,15 @@
 static INLINE void highbd_12_comp_avg_neon(const uint16_t *src_ptr,
                                            int src_stride, uint16_t *dst_ptr,
                                            int dst_stride, int w, int h,
-                                           ConvolveParams *conv_params,
-                                           const int offset, const int bd) {
+                                           ConvolveParams *conv_params) {
+  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
+  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+
   CONV_BUF_TYPE *ref_ptr = conv_params->dst;
   const int ref_stride = conv_params->dst_stride;
-  const uint16x4_t offset_vec = vdup_n_u16(offset);
-  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
+  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);
 
   if (w == 4) {
     do {
@@ -86,10 +89,14 @@ static INLINE void highbd_comp_avg_neon(const uint16_t *src_ptr, int src_stride,
                                         uint16_t *dst_ptr, int dst_stride,
                                         int w, int h,
                                         ConvolveParams *conv_params,
-                                        const int offset, const int bd) {
+                                        const int bd) {
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+
   CONV_BUF_TYPE *ref_ptr = conv_params->dst;
   const int ref_stride = conv_params->dst_stride;
-  const uint16x4_t offset_vec = vdup_n_u16(offset);
+  const uint16x4_t offset_vec = vdup_n_u16((uint16_t)offset);
   const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
 
   if (w == 4) {
@@ -145,11 +152,15 @@ static INLINE void highbd_comp_avg_neon(const uint16_t *src_ptr, int src_stride,
 
 static INLINE void highbd_12_dist_wtd_comp_avg_neon(
     const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
-    int w, int h, ConvolveParams *conv_params, const int offset, const int bd) {
+    int w, int h, ConvolveParams *conv_params) {
+  const int offset_bits = 12 + 2 * FILTER_BITS - ROUND0_BITS - 2;
+  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+
   CONV_BUF_TYPE *ref_ptr = conv_params->dst;
   const int ref_stride = conv_params->dst_stride;
   const uint32x4_t offset_vec = vdupq_n_u32(offset);
-  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
+  const uint16x8_t max = vdupq_n_u16((1 << 12) - 1);
   uint16x4_t fwd_offset = vdup_n_u16(conv_params->fwd_offset);
   uint16x4_t bck_offset = vdup_n_u16(conv_params->bck_offset);
 
@@ -212,7 +223,11 @@ static INLINE void highbd_12_dist_wtd_comp_avg_neon(
 
 static INLINE void highbd_dist_wtd_comp_avg_neon(
     const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
-    int w, int h, ConvolveParams *conv_params, const int offset, const int bd) {
+    int w, int h, ConvolveParams *conv_params, const int bd) {
+  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
+  const int offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
+                     (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
+
   CONV_BUF_TYPE *ref_ptr = conv_params->dst;
   const int ref_stride = conv_params->dst_stride;
   const uint32x4_t offset_vec = vdupq_n_u32(offset);
diff --git a/av1/common/arm/highbd_compound_convolve_sve2.c b/av1/common/arm/highbd_compound_convolve_sve2.c
index f7eda226e..b36e01f2f 100644
--- a/av1/common/arm/highbd_compound_convolve_sve2.c
+++ b/av1/common/arm/highbd_compound_convolve_sve2.c
@@ -223,9 +223,6 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
   const int im_stride = MAX_SB_SIZE;
   const int horiz_offset = filter_params_x->taps / 2 - 1;
   assert(FILTER_BITS == COMPOUND_ROUND1_BITS);
-  const int offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
-  const int offset_avg = (1 << (offset_bits - conv_params->round_1)) +
-                         (1 << (offset_bits - conv_params->round_1 - 1));
   const int offset_convolve = (1 << (conv_params->round_0 - 1)) +
                               (1 << (bd + FILTER_BITS)) +
                               (1 << (bd + FILTER_BITS - 1));
@@ -249,21 +246,21 @@ void av1_highbd_dist_wtd_convolve_x_sve2(
     if (conv_params->use_dist_wtd_comp_avg) {
       if (bd == 12) {
         highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
-                                         w, h, conv_params, offset_avg, bd);
+                                         w, h, conv_params);
 
       } else {
         highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
-                                      h, conv_params, offset_avg, bd);
+                                      h, conv_params, bd);
       }
 
     } else {
       if (bd == 12) {
         highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                                conv_params, offset_avg, bd);
+                                conv_params);
 
       } else {
         highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
-                             conv_params, offset_avg, bd);
+                             conv_params, bd);
       }
     }
   } else {