aom: Remove unnecessary operations in convolve8_vert_neon

From ecb58ac8ad319849f8183502afd67da8accf76ef Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Tue, 30 Jan 2024 11:53:49 +0000
Subject: [PATCH] Remove unnecessary operations in convolve8_vert_neon

In the neon_dotprod and neon_i8mm implementations of aom_convolve8_vert,
some TBL operations had their results immediately overwritten on
entering the loop, so remove them. Also take the opportunity to clean up
unnecessary forward declarations, moving variables to their point of
first use.
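
For illustration, the dead code in the w == 4 neon_dotprod path had
roughly the following shape (a simplified sketch based on the hunks
below, not the exact code; the neon_i8mm path is analogous with
unsigned types):

  /* Pre-loop: s7..s9 were zeroed and three TBL-based transpose_concat
   * results computed from partly-zero data. */
  s7 = vdup_n_s8(0);
  s8 = vdup_n_s8(0);
  s9 = vdup_n_s8(0);
  transpose_concat_4x4(s4, s5, s6, s7, &s4567, tran_concat_tbl);
  transpose_concat_4x4(s5, s6, s7, s8, &s5678, tran_concat_tbl);
  transpose_concat_4x4(s6, s7, s8, s9, &s6789, tran_concat_tbl);

  do {
    /* ...but every iteration overwrites s4567/s5678/s6789 via TBL on
     * samples_LUT before reading them, so the pre-loop values are
     * dead and the zeroing plus three transpose_concat calls can go. */
    s4567 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
    s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
    s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
    /* ... dot-product filtering and stores ... */
  } while (h != 0);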

Change-Id: I2f4542355e49f309af8985f82ded8ccf2f430d36
---
 aom_dsp/arm/aom_convolve8_neon_dotprod.c | 116 ++++++++++-------------
 aom_dsp/arm/aom_convolve8_neon_i8mm.c    |  72 ++++++--------
 2 files changed, 75 insertions(+), 113 deletions(-)

diff --git a/aom_dsp/arm/aom_convolve8_neon_dotprod.c b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
index ac0a6efd00..c82125ba17 100644
--- a/aom_dsp/arm/aom_convolve8_neon_dotprod.c
+++ b/aom_dsp/arm/aom_convolve8_neon_dotprod.c
@@ -267,8 +267,6 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
   const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp));
   const uint8x8_t range_limit = vdup_n_u8(128);
   const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl);
-  uint8x8_t t0, t1, t2, t3, t4, t5, t6;
-  int8x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10;
   int8x16x2_t samples_LUT;
 
   assert((intptr_t)dst % 4 == 0);
@@ -282,46 +280,39 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
 
   if (w == 4) {
     const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl);
-    int8x16_t s0123, s1234, s2345, s3456, s4567, s5678, s6789, s78910;
-    int16x4_t d0, d1, d2, d3;
-    uint8x8_t d01, d23;
 
+    uint8x8_t t0, t1, t2, t3, t4, t5, t6;
     load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
     src += 7 * src_stride;
 
     /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
-    s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit));
-    s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit));
-    s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit));
-    s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit));
-    s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit));
-    s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit));
-    s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit));
-    s7 = vdup_n_s8(0);
-    s8 = vdup_n_s8(0);
-    s9 = vdup_n_s8(0);
+    int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit));
+    int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit));
+    int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit));
+    int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit));
+    int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit));
+    int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit));
+    int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit));
 
     /* This operation combines a conventional transpose and the sample permute
      * (see horizontal case) required before computing the dot product.
      */
+    int8x16_t s0123, s1234, s2345, s3456;
     transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl);
     transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl);
     transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl);
     transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl);
-    transpose_concat_4x4(s4, s5, s6, s7, &s4567, tran_concat_tbl);
-    transpose_concat_4x4(s5, s6, s7, s8, &s5678, tran_concat_tbl);
-    transpose_concat_4x4(s6, s7, s8, s9, &s6789, tran_concat_tbl);
 
     do {
       uint8x8_t t7, t8, t9, t10;
-
       load_u8_8x4(src, src_stride, &t7, &t8, &t9, &t10);
 
-      s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit));
-      s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit));
-      s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit));
-      s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit));
+      int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit));
+      int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit));
+      int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit));
+      int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit));
 
+      int8x16_t s4567, s5678, s6789, s78910;
       transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl);
 
       /* Merge new data into block from previous iteration. */
@@ -331,12 +322,13 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
       s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
       s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
 
-      d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filter);
-      d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filter);
-      d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filter);
-      d3 = convolve8_4_sdot_partial(s3456, s78910, correction, filter);
-      d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
-      d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+      int16x4_t d0 = convolve8_4_sdot_partial(s0123, s4567, correction, filter);
+      int16x4_t d1 = convolve8_4_sdot_partial(s1234, s5678, correction, filter);
+      int16x4_t d2 = convolve8_4_sdot_partial(s2345, s6789, correction, filter);
+      int16x4_t d3 =
+          convolve8_4_sdot_partial(s3456, s78910, correction, filter);
+      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
 
       store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
       store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
@@ -354,37 +346,30 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
     } while (h != 0);
   } else {
     const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl);
-    int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
-        s3456_lo, s3456_hi, s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo,
-        s6789_hi, s78910_lo, s78910_hi;
-    uint8x8_t d0, d1, d2, d3;
-    const uint8_t *s;
-    uint8_t *d;
-    int height;
 
     do {
-      height = h;
-      s = src;
-      d = dst;
+      int height = h;
+      const uint8_t *s = src;
+      uint8_t *d = dst;
 
+      uint8x8_t t0, t1, t2, t3, t4, t5, t6;
       load_u8_8x7(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
       s += 7 * src_stride;
 
       /* Clamp sample range to [-128, 127] for 8-bit signed dot product. */
-      s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit));
-      s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit));
-      s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit));
-      s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit));
-      s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit));
-      s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit));
-      s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit));
-      s7 = vdup_n_s8(0);
-      s8 = vdup_n_s8(0);
-      s9 = vdup_n_s8(0);
+      int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, range_limit));
+      int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, range_limit));
+      int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, range_limit));
+      int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, range_limit));
+      int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, range_limit));
+      int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, range_limit));
+      int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, range_limit));
 
       /* This operation combines a conventional transpose and the sample permute
        * (see horizontal case) required before computing the dot product.
        */
+      int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
+          s3456_lo, s3456_hi;
       transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi,
                            tran_concat_tbl);
       transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi,
@@ -393,23 +378,18 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
                            tran_concat_tbl);
       transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi,
                            tran_concat_tbl);
-      transpose_concat_8x4(s4, s5, s6, s7, &s4567_lo, &s4567_hi,
-                           tran_concat_tbl);
-      transpose_concat_8x4(s5, s6, s7, s8, &s5678_lo, &s5678_hi,
-                           tran_concat_tbl);
-      transpose_concat_8x4(s6, s7, s8, s9, &s6789_lo, &s6789_hi,
-                           tran_concat_tbl);
 
       do {
         uint8x8_t t7, t8, t9, t10;
-
         load_u8_8x4(s, src_stride, &t7, &t8, &t9, &t10);
 
-        s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit));
-        s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit));
-        s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit));
-        s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit));
+        int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, range_limit));
+        int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, range_limit));
+        int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, range_limit));
+        int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, range_limit));
 
+        int8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi,
+            s78910_lo, s78910_hi;
         transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi,
                              tran_concat_tbl);
 
@@ -426,14 +406,14 @@ void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
         s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
         s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);
 
-        d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi,
-                                      correction, filter);
-        d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi,
-                                      correction, filter);
-        d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi,
-                                      correction, filter);
-        d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi,
-                                      correction, filter);
+        uint8x8_t d0 = convolve8_8_sdot_partial(s0123_lo, s4567_lo, s0123_hi,
+                                                s4567_hi, correction, filter);
+        uint8x8_t d1 = convolve8_8_sdot_partial(s1234_lo, s5678_lo, s1234_hi,
+                                                s5678_hi, correction, filter);
+        uint8x8_t d2 = convolve8_8_sdot_partial(s2345_lo, s6789_lo, s2345_hi,
+                                                s6789_hi, correction, filter);
+        uint8x8_t d3 = convolve8_8_sdot_partial(s3456_lo, s78910_lo, s3456_hi,
+                                                s78910_hi, correction, filter);
 
         store_u8_8x4(d, dst_stride, d0, d1, d2, d3);
 
diff --git a/aom_dsp/arm/aom_convolve8_neon_i8mm.c b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
index c314c0a192..86f06158eb 100644
--- a/aom_dsp/arm/aom_convolve8_neon_i8mm.c
+++ b/aom_dsp/arm/aom_convolve8_neon_i8mm.c
@@ -246,7 +246,6 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
                                   int h) {
   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
   const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(dot_prod_merge_block_tbl);
-  uint8x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10;
   uint8x16x2_t samples_LUT;
 
   assert((intptr_t)dst % 4 == 0);
@@ -260,31 +259,25 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
 
   if (w == 4) {
     const uint8x16_t tran_concat_tbl = vld1q_u8(dot_prod_tran_concat_tbl);
-    uint8x16_t s0123, s1234, s2345, s3456, s4567, s5678, s6789, s78910;
-    int16x4_t d0, d1, d2, d3;
-    uint8x8_t d01, d23;
 
+    uint8x8_t s0, s1, s2, s3, s4, s5, s6;
     load_u8_8x7(src, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
     src += 7 * src_stride;
 
-    s7 = vdup_n_u8(0);
-    s8 = vdup_n_u8(0);
-    s9 = vdup_n_u8(0);
-
     /* This operation combines a conventional transpose and the sample permute
      * (see horizontal case) required before computing the dot product.
      */
+    uint8x16_t s0123, s1234, s2345, s3456;
     transpose_concat_4x4(s0, s1, s2, s3, &s0123, tran_concat_tbl);
     transpose_concat_4x4(s1, s2, s3, s4, &s1234, tran_concat_tbl);
     transpose_concat_4x4(s2, s3, s4, s5, &s2345, tran_concat_tbl);
     transpose_concat_4x4(s3, s4, s5, s6, &s3456, tran_concat_tbl);
-    transpose_concat_4x4(s4, s5, s6, s7, &s4567, tran_concat_tbl);
-    transpose_concat_4x4(s5, s6, s7, s8, &s5678, tran_concat_tbl);
-    transpose_concat_4x4(s6, s7, s8, s9, &s6789, tran_concat_tbl);
 
     do {
+      uint8x8_t s7, s8, s9, s10;
       load_u8_8x4(src, src_stride, &s7, &s8, &s9, &s10);
 
+      uint8x16_t s4567, s5678, s6789, s78910;
       transpose_concat_4x4(s7, s8, s9, s10, &s78910, tran_concat_tbl);
 
       /* Merge new data into block from previous iteration. */
@@ -294,12 +287,12 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
       s5678 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]);
       s6789 = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]);
 
-      d0 = convolve8_4_usdot_partial(s0123, s4567, filter);
-      d1 = convolve8_4_usdot_partial(s1234, s5678, filter);
-      d2 = convolve8_4_usdot_partial(s2345, s6789, filter);
-      d3 = convolve8_4_usdot_partial(s3456, s78910, filter);
-      d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
-      d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
+      int16x4_t d0 = convolve8_4_usdot_partial(s0123, s4567, filter);
+      int16x4_t d1 = convolve8_4_usdot_partial(s1234, s5678, filter);
+      int16x4_t d2 = convolve8_4_usdot_partial(s2345, s6789, filter);
+      int16x4_t d3 = convolve8_4_usdot_partial(s3456, s78910, filter);
+      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
+      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);
 
       store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
       store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);
@@ -317,29 +310,21 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
     } while (h != 0);
   } else {
     const uint8x16x2_t tran_concat_tbl = vld1q_u8_x2(dot_prod_tran_concat_tbl);
-    uint8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
-        s3456_lo, s3456_hi, s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo,
-        s6789_hi, s78910_lo, s78910_hi;
-    uint8x8_t d0, d1, d2, d3;
-    const uint8_t *s;
-    uint8_t *d;
-    int height;
 
     do {
-      height = h;
-      s = src;
-      d = dst;
+      int height = h;
+      const uint8_t *s = src;
+      uint8_t *d = dst;
 
+      uint8x8_t s0, s1, s2, s3, s4, s5, s6;
       load_u8_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
       s += 7 * src_stride;
 
-      s7 = vdup_n_u8(0);
-      s8 = vdup_n_u8(0);
-      s9 = vdup_n_u8(0);
-
       /* This operation combines a conventional transpose and the sample permute
        * (see horizontal case) required before computing the dot product.
        */
+      uint8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
+          s3456_lo, s3456_hi;
       transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi,
                            tran_concat_tbl);
       transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi,
@@ -348,16 +333,13 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
                            tran_concat_tbl);
       transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi,
                            tran_concat_tbl);
-      transpose_concat_8x4(s4, s5, s6, s7, &s4567_lo, &s4567_hi,
-                           tran_concat_tbl);
-      transpose_concat_8x4(s5, s6, s7, s8, &s5678_lo, &s5678_hi,
-                           tran_concat_tbl);
-      transpose_concat_8x4(s6, s7, s8, s9, &s6789_lo, &s6789_hi,
-                           tran_concat_tbl);
 
       do {
+        uint8x8_t s7, s8, s9, s10;
         load_u8_8x4(s, src_stride, &s7, &s8, &s9, &s10);
 
+        uint8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi,
+            s78910_lo, s78910_hi;
         transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi,
                              tran_concat_tbl);
 
@@ -374,14 +356,14 @@ void aom_convolve8_vert_neon_i8mm(const uint8_t *src, ptrdiff_t src_stride,
         s5678_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[1]);
         s6789_hi = vqtbl2q_u8(samples_LUT, merge_block_tbl.val[2]);
 
-        d0 = convolve8_8_usdot_partial(s0123_lo, s4567_lo, s0123_hi, s4567_hi,
-                                       filter);
-        d1 = convolve8_8_usdot_partial(s1234_lo, s5678_lo, s1234_hi, s5678_hi,
-                                       filter);
-        d2 = convolve8_8_usdot_partial(s2345_lo, s6789_lo, s2345_hi, s6789_hi,
-                                       filter);
-        d3 = convolve8_8_usdot_partial(s3456_lo, s78910_lo, s3456_hi, s78910_hi,
-                                       filter);
+        uint8x8_t d0 = convolve8_8_usdot_partial(s0123_lo, s4567_lo, s0123_hi,
+                                                 s4567_hi, filter);
+        uint8x8_t d1 = convolve8_8_usdot_partial(s1234_lo, s5678_lo, s1234_hi,
+                                                 s5678_hi, filter);
+        uint8x8_t d2 = convolve8_8_usdot_partial(s2345_lo, s6789_lo, s2345_hi,
+                                                 s6789_hi, filter);
+        uint8x8_t d3 = convolve8_8_usdot_partial(s3456_lo, s78910_lo, s3456_hi,
+                                                 s78910_hi, filter);
 
         store_u8_8x4(d, dst_stride, d0, d1, d2, d3);