aom: Optimize transpose functions in aom_highbd_convolve8_vert_sve

From 8398313499003afa5129e0e367ccc36fc33cab7f Mon Sep 17 00:00:00 2001
From: Salome Thirot <[EMAIL REDACTED]>
Date: Mon, 4 Mar 2024 11:41:27 +0000
Subject: [PATCH] Optimize transpose functions in aom_highbd_convolve8_vert_sve

Use ZIP instructions instead of TBL to transpose and concatenate
elements in the SVE implementation of aom_highbd_convolve8_vert. This
removes the need to load a table and gives up to 10% uplift.

Change-Id: I92ad082512f263393cb5def409f8dccbb5278016
---
 aom_dsp/arm/highbd_convolve8_sve.c | 80 ++++++++++++------------------
 1 file changed, 33 insertions(+), 47 deletions(-)

diff --git a/aom_dsp/arm/highbd_convolve8_sve.c b/aom_dsp/arm/highbd_convolve8_sve.c
index 189d11b14..9830b7e5d 100644
--- a/aom_dsp/arm/highbd_convolve8_sve.c
+++ b/aom_dsp/arm/highbd_convolve8_sve.c
@@ -261,11 +261,6 @@ void aom_highbd_convolve8_horiz_sve(const uint8_t *src8, ptrdiff_t src_stride,
   }
 }
 
-DECLARE_ALIGNED(16, static const uint8_t, kDotProdTranConcatTbl[32]) = {
-  0, 1, 8,  9,  16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27,
-  4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31
-};
-
 DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
   // Shift left and insert new last column in transposed 4x4 block.
   2, 3, 4, 5, 6, 7, 16, 17, 10, 11, 12, 13, 14, 15, 24, 25,
@@ -277,8 +272,7 @@ DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
 
 static INLINE void transpose_concat_4x4(int16x4_t s0, int16x4_t s1,
                                         int16x4_t s2, int16x4_t s3,
-                                        int16x8_t res[2],
-                                        uint8x16_t permute_tbl[2]) {
+                                        int16x8_t res[2]) {
   // Transpose 16-bit elements and concatenate result rows as follows:
   // s0: 00, 01, 02, 03
   // s1: 10, 11, 12, 13
@@ -287,22 +281,24 @@ static INLINE void transpose_concat_4x4(int16x4_t s0, int16x4_t s1,
   //
   // res[0]: 00 10 20 30 01 11 21 31
   // res[1]: 02 12 22 32 03 13 23 33
-  //
-  // The 'permute_tbl' is always 'kDotProdTranConcatTbl' above. Passing it
-  // as an argument is preferable to loading it directly from memory as this
-  // inline helper is called many times from the same parent function.
 
-  int8x16x2_t samples = { vreinterpretq_s8_s16(vcombine_s16(s0, s1)),
-                          vreinterpretq_s8_s16(vcombine_s16(s2, s3)) };
+  int16x8_t s0q = vcombine_s16(s0, vdup_n_s16(0));
+  int16x8_t s1q = vcombine_s16(s1, vdup_n_s16(0));
+  int16x8_t s2q = vcombine_s16(s2, vdup_n_s16(0));
+  int16x8_t s3q = vcombine_s16(s3, vdup_n_s16(0));
+
+  int32x4_t s01 = vreinterpretq_s32_s16(vzip1q_s16(s0q, s1q));
+  int32x4_t s23 = vreinterpretq_s32_s16(vzip1q_s16(s2q, s3q));
 
-  res[0] = vreinterpretq_s16_s8(vqtbl2q_s8(samples, permute_tbl[0]));
-  res[1] = vreinterpretq_s16_s8(vqtbl2q_s8(samples, permute_tbl[1]));
+  int32x4x2_t s0123 = vzipq_s32(s01, s23);
+
+  res[0] = vreinterpretq_s16_s32(s0123.val[0]);
+  res[1] = vreinterpretq_s16_s32(s0123.val[1]);
 }
 
 static INLINE void transpose_concat_8x4(int16x8_t s0, int16x8_t s1,
                                         int16x8_t s2, int16x8_t s3,
-                                        int16x8_t res[4],
-                                        uint8x16_t permute_tbl[2]) {
+                                        int16x8_t res[4]) {
   // Transpose 16-bit elements and concatenate result rows as follows:
   // s0: 00, 01, 02, 03, 04, 05, 06, 07
   // s1: 10, 11, 12, 13, 14, 15, 16, 17
@@ -313,26 +309,19 @@ static INLINE void transpose_concat_8x4(int16x8_t s0, int16x8_t s1,
   // res_lo[1]: 02 12 22 32 03 13 23 33
   // res_hi[0]: 04 14 24 34 05 15 25 35
   // res_hi[1]: 06 16 26 36 07 17 27 37
-  //
-  // The 'permute_tbl' is always 'kDotProdTranConcatTbl' above. Passing it
-  // as an argument is preferable to loading it directly from memory as this
-  // inline helper is called many times from the same parent function.
-
-  int8x16x2_t samples_lo = {
-    vreinterpretq_s8_s16(vcombine_s16(vget_low_s16(s0), vget_low_s16(s1))),
-    vreinterpretq_s8_s16(vcombine_s16(vget_low_s16(s2), vget_low_s16(s3)))
-  };
 
-  res[0] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_lo, permute_tbl[0]));
-  res[1] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_lo, permute_tbl[1]));
+  int16x8x2_t tr01_16 = vzipq_s16(s0, s1);
+  int16x8x2_t tr23_16 = vzipq_s16(s2, s3);
 
-  int8x16x2_t samples_hi = {
-    vreinterpretq_s8_s16(vcombine_s16(vget_high_s16(s0), vget_high_s16(s1))),
-    vreinterpretq_s8_s16(vcombine_s16(vget_high_s16(s2), vget_high_s16(s3)))
-  };
+  int32x4x2_t tr01_32 = vzipq_s32(vreinterpretq_s32_s16(tr01_16.val[0]),
+                                  vreinterpretq_s32_s16(tr23_16.val[0]));
+  int32x4x2_t tr23_32 = vzipq_s32(vreinterpretq_s32_s16(tr01_16.val[1]),
+                                  vreinterpretq_s32_s16(tr23_16.val[1]));
 
-  res[2] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_hi, permute_tbl[0]));
-  res[3] = vreinterpretq_s16_s8(vqtbl2q_s8(samples_hi, permute_tbl[1]));
+  res[0] = vreinterpretq_s16_s32(tr01_32.val[0]);
+  res[1] = vreinterpretq_s16_s32(tr01_32.val[1]);
+  res[2] = vreinterpretq_s16_s32(tr23_32.val[0]);
+  res[3] = vreinterpretq_s16_s32(tr23_32.val[1]);
 }
 
 static INLINE void aom_tbl2x4_s16(int16x8_t t0[4], int16x8_t t1[4],
@@ -427,9 +416,6 @@ void aom_highbd_convolve8_vert_sve(const uint8_t *src8, ptrdiff_t src_stride,
 
   const int16x8_t y_filter = vld1q_s16(filter_y);
 
-  uint8x16_t tran_concat_tbl[2];
-  tran_concat_tbl[0] = vld1q_u8(kDotProdTranConcatTbl);
-  tran_concat_tbl[1] = vld1q_u8(kDotProdTranConcatTbl + 16);
   uint8x16_t merge_block_tbl[3];
   merge_block_tbl[0] = vld1q_u8(kDotProdMergeBlockTbl);
   merge_block_tbl[1] = vld1q_u8(kDotProdMergeBlockTbl + 16);
@@ -446,10 +432,10 @@ void aom_highbd_convolve8_vert_sve(const uint8_t *src8, ptrdiff_t src_stride,
     // This operation combines a conventional transpose and the sample permute
     // required before computing the dot product.
     int16x8_t s0123[2], s1234[2], s2345[2], s3456[2];
-    transpose_concat_4x4(s0, s1, s2, s3, s0123, tran_concat_tbl);
-    transpose_concat_4x4(s1, s2, s3, s4, s1234, tran_concat_tbl);
-    transpose_concat_4x4(s2, s3, s4, s5, s2345, tran_concat_tbl);
-    transpose_concat_4x4(s3, s4, s5, s6, s3456, tran_concat_tbl);
+    transpose_concat_4x4(s0, s1, s2, s3, s0123);
+    transpose_concat_4x4(s1, s2, s3, s4, s1234);
+    transpose_concat_4x4(s2, s3, s4, s5, s2345);
+    transpose_concat_4x4(s3, s4, s5, s6, s3456);
 
     do {
       int16x4_t s7, s8, s9, s10;
@@ -458,7 +444,7 @@ void aom_highbd_convolve8_vert_sve(const uint8_t *src8, ptrdiff_t src_stride,
       int16x8_t s4567[2], s5678[2], s6789[2], s78910[2];
 
       // Transpose and shuffle the 4 lines that were loaded.
-      transpose_concat_4x4(s7, s8, s9, s10, s78910, tran_concat_tbl);
+      transpose_concat_4x4(s7, s8, s9, s10, s78910);
 
       // Merge new data into block from previous iteration.
       aom_tbl2x2_s16(s3456, s78910, merge_block_tbl[0], s4567);
@@ -501,10 +487,10 @@ void aom_highbd_convolve8_vert_sve(const uint8_t *src8, ptrdiff_t src_stride,
       // This operation combines a conventional transpose and the sample permute
       // required before computing the dot product.
       int16x8_t s0123[4], s1234[4], s2345[4], s3456[4];
-      transpose_concat_8x4(s0, s1, s2, s3, s0123, tran_concat_tbl);
-      transpose_concat_8x4(s1, s2, s3, s4, s1234, tran_concat_tbl);
-      transpose_concat_8x4(s2, s3, s4, s5, s2345, tran_concat_tbl);
-      transpose_concat_8x4(s3, s4, s5, s6, s3456, tran_concat_tbl);
+      transpose_concat_8x4(s0, s1, s2, s3, s0123);
+      transpose_concat_8x4(s1, s2, s3, s4, s1234);
+      transpose_concat_8x4(s2, s3, s4, s5, s2345);
+      transpose_concat_8x4(s3, s4, s5, s6, s3456);
 
       do {
         int16x8_t s7, s8, s9, s10;
@@ -513,7 +499,7 @@ void aom_highbd_convolve8_vert_sve(const uint8_t *src8, ptrdiff_t src_stride,
         int16x8_t s4567[4], s5678[4], s6789[4], s78910[4];
 
         // Transpose and shuffle the 4 lines that were loaded.
-        transpose_concat_8x4(s7, s8, s9, s10, s78910, tran_concat_tbl);
+        transpose_concat_8x4(s7, s8, s9, s10, s78910);
 
         // Merge new data into block from previous iteration.
         aom_tbl2x4_s16(s3456, s78910, merge_block_tbl[0], s4567);