aom: Add intermediate buffers for av1_compute_stats_highbd

From 860bdd3a127ec9d26b0dff804e06dcbbc2dc299c Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Tue, 9 Apr 2024 17:01:23 +0100
Subject: [PATCH] Add intermediate buffers for av1_compute_stats_highbd

The standard bitdepth version of av1_compute_stats makes use of
intermediate buffers dgd_avg and src_avg to precompute the subtraction
of the average. Port this optimization to the highbd version in
preparation for those buffers being used in a subsequent commit. For now
these buffers are never allocated for high bitdepth.

Also remove the allocation of these buffers for Neon, as the Neon
implementation doesn't make use of them.

Change-Id: I3b646bafaf43013d508f343b318897a83c519410
---
 av1/common/av1_rtcd_defs.pl                |  2 +-
 av1/encoder/arm/neon/highbd_pickrst_neon.c |  5 +-
 av1/encoder/pickrst.c                      | 21 +++++----
 av1/encoder/x86/pickrst_avx2.c             | 12 +++--
 av1/encoder/x86/pickrst_sse4.c             | 18 ++++---
 test/wiener_test.cc                        | 55 ++++++++++++++--------
 6 files changed, 74 insertions(+), 39 deletions(-)

diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index e97f39145d..6a0043c761 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -469,7 +469,7 @@ ()
       specialize qw/av1_calc_proj_params_high_bd sse4_1 avx2 neon/;
       add_proto qw/int64_t av1_highbd_pixel_proj_error/, "const uint8_t *src8, int width, int height, int src_stride, const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride, int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params";
       specialize qw/av1_highbd_pixel_proj_error sse4_1 avx2 neon/;
-      add_proto qw/void av1_compute_stats_highbd/,  "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth";
+      add_proto qw/void av1_compute_stats_highbd/,  "int wiener_win, const uint8_t *dgd8, const uint8_t *src8, int16_t *dgd_avg, int16_t *src_avg, int h_start, int h_end, int v_start, int v_end, int dgd_stride, int src_stride, int64_t *M, int64_t *H, aom_bit_depth_t bit_depth";
       specialize qw/av1_compute_stats_highbd sse4_1 avx2 neon/;
     }
   }
diff --git a/av1/encoder/arm/neon/highbd_pickrst_neon.c b/av1/encoder/arm/neon/highbd_pickrst_neon.c
index 47b5f5cfb7..8b0d3bcc7e 100644
--- a/av1/encoder/arm/neon/highbd_pickrst_neon.c
+++ b/av1/encoder/arm/neon/highbd_pickrst_neon.c
@@ -1008,10 +1008,13 @@ static uint16_t highbd_find_average_neon(const uint16_t *src, int src_stride,
 }
 
 void av1_compute_stats_highbd_neon(int wiener_win, const uint8_t *dgd8,
-                                   const uint8_t *src8, int h_start, int h_end,
+                                   const uint8_t *src8, int16_t *dgd_avg,
+                                   int16_t *src_avg, int h_start, int h_end,
                                    int v_start, int v_end, int dgd_stride,
                                    int src_stride, int64_t *M, int64_t *H,
                                    aom_bit_depth_t bit_depth) {
+  (void)dgd_avg;
+  (void)src_avg;
   assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_REDUCED);
 
   const int wiener_halfwin = wiener_win >> 1;
diff --git a/av1/encoder/pickrst.c b/av1/encoder/pickrst.c
index b0d0d0bb78..a431c4dada 100644
--- a/av1/encoder/pickrst.c
+++ b/av1/encoder/pickrst.c
@@ -1044,10 +1044,13 @@ void av1_compute_stats_c(int wiener_win, const uint8_t *dgd, const uint8_t *src,
 
 #if CONFIG_AV1_HIGHBITDEPTH
 void av1_compute_stats_highbd_c(int wiener_win, const uint8_t *dgd8,
-                                const uint8_t *src8, int h_start, int h_end,
+                                const uint8_t *src8, int16_t *dgd_avg,
+                                int16_t *src_avg, int h_start, int h_end,
                                 int v_start, int v_end, int dgd_stride,
                                 int src_stride, int64_t *M, int64_t *H,
                                 aom_bit_depth_t bit_depth) {
+  (void)dgd_avg;
+  (void)src_avg;
   int i, j, k, l;
   int32_t Y[WIENER_WIN2];
   const int wiener_win2 = wiener_win * wiener_win;
@@ -1659,9 +1662,10 @@ static AOM_INLINE void search_wiener(
     // functions. Optimize intrinsics of HBD design similar to LBD (i.e.,
     // pre-calculate d and s buffers and avoid most of the C operations).
     av1_compute_stats_highbd(reduced_wiener_win, rsc->dgd_buffer,
-                             rsc->src_buffer, limits->h_start, limits->h_end,
-                             limits->v_start, limits->v_end, rsc->dgd_stride,
-                             rsc->src_stride, M, H, cm->seq_params->bit_depth);
+                             rsc->src_buffer, rsc->dgd_avg, rsc->src_avg,
+                             limits->h_start, limits->h_end, limits->v_start,
+                             limits->v_end, rsc->dgd_stride, rsc->src_stride, M,
+                             H, cm->seq_params->bit_depth);
   } else {
     av1_compute_stats(reduced_wiener_win, rsc->dgd_buffer, rsc->src_buffer,
                       rsc->dgd_avg, rsc->src_avg, limits->h_start,
@@ -2081,10 +2085,9 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) {
   // and height aligned to multiple of 16 is considered for intrinsic purpose.
   rsc.dgd_avg = NULL;
   rsc.src_avg = NULL;
-#if HAVE_AVX2 || HAVE_NEON
-  // The buffers allocated below are used during Wiener filter processing of low
-  // bitdepth path. Hence, allocate the same when Wiener filter is enabled in
-  // low bitdepth path.
+#if HAVE_AVX2
+  // The buffers allocated below are used during Wiener filter processing.
+  // Hence, allocate the same when Wiener filter is enabled.
   if (!cpi->sf.lpf_sf.disable_wiener_filter && !highbd) {
     const int buf_size = sizeof(*cpi->pick_lr_ctxt.dgd_avg) * 6 *
                          RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX;
@@ -2221,7 +2224,7 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi) {
                               best_luma_unit_size);
   }
 
-#if HAVE_AVX || HAVE_NEON
+#if HAVE_AVX2
   if (!cpi->sf.lpf_sf.disable_wiener_filter && !highbd) {
     aom_free(cpi->pick_lr_ctxt.dgd_avg);
     cpi->pick_lr_ctxt.dgd_avg = NULL;
diff --git a/av1/encoder/x86/pickrst_avx2.c b/av1/encoder/x86/pickrst_avx2.c
index 6658ed39a8..1f76576c9e 100644
--- a/av1/encoder/x86/pickrst_avx2.c
+++ b/av1/encoder/x86/pickrst_avx2.c
@@ -345,21 +345,27 @@ static INLINE void compute_stats_highbd_win5_opt_avx2(
 }
 
 void av1_compute_stats_highbd_avx2(int wiener_win, const uint8_t *dgd8,
-                                   const uint8_t *src8, int h_start, int h_end,
+                                   const uint8_t *src8, int16_t *dgd_avg,
+                                   int16_t *src_avg, int h_start, int h_end,
                                    int v_start, int v_end, int dgd_stride,
                                    int src_stride, int64_t *M, int64_t *H,
                                    aom_bit_depth_t bit_depth) {
   if (wiener_win == WIENER_WIN) {
+    (void)dgd_avg;
+    (void)src_avg;
     compute_stats_highbd_win7_opt_avx2(dgd8, src8, h_start, h_end, v_start,
                                        v_end, dgd_stride, src_stride, M, H,
                                        bit_depth);
   } else if (wiener_win == WIENER_WIN_CHROMA) {
+    (void)dgd_avg;
+    (void)src_avg;
     compute_stats_highbd_win5_opt_avx2(dgd8, src8, h_start, h_end, v_start,
                                        v_end, dgd_stride, src_stride, M, H,
                                        bit_depth);
   } else {
-    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start,
-                               v_end, dgd_stride, src_stride, M, H, bit_depth);
+    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg,
+                               h_start, h_end, v_start, v_end, dgd_stride,
+                               src_stride, M, H, bit_depth);
   }
 }
 #endif  // CONFIG_AV1_HIGHBITDEPTH
diff --git a/av1/encoder/x86/pickrst_sse4.c b/av1/encoder/x86/pickrst_sse4.c
index 50db305802..3617d33fef 100644
--- a/av1/encoder/x86/pickrst_sse4.c
+++ b/av1/encoder/x86/pickrst_sse4.c
@@ -524,21 +524,27 @@ static INLINE void compute_stats_highbd_win5_opt_sse4_1(
 }
 
 void av1_compute_stats_highbd_sse4_1(int wiener_win, const uint8_t *dgd8,
-                                     const uint8_t *src8, int h_start,
-                                     int h_end, int v_start, int v_end,
-                                     int dgd_stride, int src_stride, int64_t *M,
-                                     int64_t *H, aom_bit_depth_t bit_depth) {
+                                     const uint8_t *src8, int16_t *dgd_avg,
+                                     int16_t *src_avg, int h_start, int h_end,
+                                     int v_start, int v_end, int dgd_stride,
+                                     int src_stride, int64_t *M, int64_t *H,
+                                     aom_bit_depth_t bit_depth) {
   if (wiener_win == WIENER_WIN) {
+    (void)dgd_avg;
+    (void)src_avg;
     compute_stats_highbd_win7_opt_sse4_1(dgd8, src8, h_start, h_end, v_start,
                                          v_end, dgd_stride, src_stride, M, H,
                                          bit_depth);
   } else if (wiener_win == WIENER_WIN_CHROMA) {
+    (void)dgd_avg;
+    (void)src_avg;
     compute_stats_highbd_win5_opt_sse4_1(dgd8, src8, h_start, h_end, v_start,
                                          v_end, dgd_stride, src_stride, M, H,
                                          bit_depth);
   } else {
-    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start,
-                               v_end, dgd_stride, src_stride, M, H, bit_depth);
+    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg,
+                               h_start, h_end, v_start, v_end, dgd_stride,
+                               src_stride, M, H, bit_depth);
   }
 }
 #endif  // CONFIG_AV1_HIGHBITDEPTH
diff --git a/test/wiener_test.cc b/test/wiener_test.cc
index 2886ed77df..c38e10e3c2 100644
--- a/test/wiener_test.cc
+++ b/test/wiener_test.cc
@@ -520,25 +520,27 @@ static void compute_stats_highbd_win_opt_c(int wiener_win, const uint8_t *dgd8,
 }
 
 void compute_stats_highbd_opt_c(int wiener_win, const uint8_t *dgd,
-                                const uint8_t *src, int h_start, int h_end,
-                                int v_start, int v_end, int dgd_stride,
-                                int src_stride, int64_t *M, int64_t *H,
-                                aom_bit_depth_t bit_depth) {
+                                const uint8_t *src, int16_t *d, int16_t *s,
+                                int h_start, int h_end, int v_start, int v_end,
+                                int dgd_stride, int src_stride, int64_t *M,
+                                int64_t *H, aom_bit_depth_t bit_depth) {
   if (wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA) {
     compute_stats_highbd_win_opt_c(wiener_win, dgd, src, h_start, h_end,
                                    v_start, v_end, dgd_stride, src_stride, M, H,
                                    bit_depth);
   } else {
-    av1_compute_stats_highbd_c(wiener_win, dgd, src, h_start, h_end, v_start,
-                               v_end, dgd_stride, src_stride, M, H, bit_depth);
+    av1_compute_stats_highbd_c(wiener_win, dgd, src, d, s, h_start, h_end,
+                               v_start, v_end, dgd_stride, src_stride, M, H,
+                               bit_depth);
   }
 }
 
 static const int kIterations = 100;
 typedef void (*compute_stats_Func)(int wiener_win, const uint8_t *dgd,
-                                   const uint8_t *src, int h_start, int h_end,
-                                   int v_start, int v_end, int dgd_stride,
-                                   int src_stride, int64_t *M, int64_t *H,
+                                   const uint8_t *src, int16_t *d, int16_t *s,
+                                   int h_start, int h_end, int v_start,
+                                   int v_end, int dgd_stride, int src_stride,
+                                   int64_t *M, int64_t *H,
                                    aom_bit_depth_t bit_depth);
 
 typedef std::tuple<const compute_stats_Func> WienerTestParam;
@@ -552,11 +554,17 @@ class WienerTestHighbd : public ::testing::TestWithParam<WienerTestParam> {
     dgd_buf = (uint16_t *)aom_memalign(
         32, MAX_DATA_BLOCK * MAX_DATA_BLOCK * sizeof(*dgd_buf));
     ASSERT_NE(dgd_buf, nullptr);
+    const size_t buf_size =
+        sizeof(*buf) * 6 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX;
+    buf = (int16_t *)aom_memalign(32, buf_size);
+    ASSERT_NE(buf, nullptr);
+    memset(buf, 0, buf_size);
     target_func_ = GET_PARAM(0);
   }
   void TearDown() override {
     aom_free(src_buf);
     aom_free(dgd_buf);
+    aom_free(buf);
   }
   void RunWienerTest(const int32_t wiener_win, int32_t run_times,
                      aom_bit_depth_t bit_depth);
@@ -568,6 +576,7 @@ class WienerTestHighbd : public ::testing::TestWithParam<WienerTestParam> {
   libaom_test::ACMRandom rng_;
   uint16_t *src_buf;
   uint16_t *dgd_buf;
+  int16_t *buf;
 };
 
 void WienerTestHighbd::RunWienerTest(const int32_t wiener_win,
@@ -595,6 +604,9 @@ void WienerTestHighbd::RunWienerTest(const int32_t wiener_win,
   const int dgd_stride = h_end;
   const int src_stride = MAX_DATA_BLOCK;
   const int iters = run_times == 1 ? kIterations : 2;
+  int16_t *dgd_avg = buf;
+  int16_t *src_avg =
+      buf + (3 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX);
   for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
     for (int i = 0; i < MAX_DATA_BLOCK * MAX_DATA_BLOCK; ++i) {
       dgd_buf[i] = rng_.Rand16() % (1 << bit_depth);
@@ -607,16 +619,17 @@ void WienerTestHighbd::RunWienerTest(const int32_t wiener_win,
     aom_usec_timer timer;
     aom_usec_timer_start(&timer);
     for (int i = 0; i < run_times; ++i) {
-      av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end,
-                                 v_start, v_end, dgd_stride, src_stride, M_ref,
-                                 H_ref, bit_depth);
+      av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg,
+                                 h_start, h_end, v_start, v_end, dgd_stride,
+                                 src_stride, M_ref, H_ref, bit_depth);
     }
     aom_usec_timer_mark(&timer);
     const double time1 = static_cast<double>(aom_usec_timer_elapsed(&timer));
     aom_usec_timer_start(&timer);
     for (int i = 0; i < run_times; ++i) {
-      target_func_(wiener_win, dgd8, src8, h_start, h_end, v_start, v_end,
-                   dgd_stride, src_stride, M_test, H_test, bit_depth);
+      target_func_(wiener_win, dgd8, src8, dgd_avg, src_avg, h_start, h_end,
+                   v_start, v_end, dgd_stride, src_stride, M_test, H_test,
+                   bit_depth);
     }
     aom_usec_timer_mark(&timer);
     const double time2 = static_cast<double>(aom_usec_timer_elapsed(&timer));
@@ -663,6 +676,9 @@ void WienerTestHighbd::RunWienerTest_ExtremeValues(const int32_t wiener_win,
   const int dgd_stride = h_end;
   const int src_stride = MAX_DATA_BLOCK;
   const int iters = 1;
+  int16_t *dgd_avg = buf;
+  int16_t *src_avg =
+      buf + (3 * RESTORATION_UNITSIZE_MAX * RESTORATION_UNITSIZE_MAX);
   for (int iter = 0; iter < iters && !HasFatalFailure(); ++iter) {
     // Fill with alternating extreme values to maximize difference with
     // the average.
@@ -674,12 +690,13 @@ void WienerTestHighbd::RunWienerTest_ExtremeValues(const int32_t wiener_win,
         dgd_buf + wiener_halfwin * MAX_DATA_BLOCK + wiener_halfwin);
     const uint8_t *src8 = CONVERT_TO_BYTEPTR(src_buf);
 
-    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, h_start, h_end, v_start,
-                               v_end, dgd_stride, src_stride, M_ref, H_ref,
-                               bit_depth);
+    av1_compute_stats_highbd_c(wiener_win, dgd8, src8, dgd_avg, src_avg,
+                               h_start, h_end, v_start, v_end, dgd_stride,
+                               src_stride, M_ref, H_ref, bit_depth);
 
-    target_func_(wiener_win, dgd8, src8, h_start, h_end, v_start, v_end,
-                 dgd_stride, src_stride, M_test, H_test, bit_depth);
+    target_func_(wiener_win, dgd8, src8, dgd_avg, src_avg, h_start, h_end,
+                 v_start, v_end, dgd_stride, src_stride, M_test, H_test,
+                 bit_depth);
 
     int failed = 0;
     for (int i = 0; i < wiener_win2; ++i) {