aom: {,highbd_}intrapred_neon.c: Avoid over-reads in z1 and z3 preds (d2c3b)

From d2c3b1078aa7f82398c8d533b127833c5e27a62b Mon Sep 17 00:00:00 2001
From: George Steed <[EMAIL REDACTED]>
Date: Thu, 9 May 2024 12:09:57 +0000
Subject: [PATCH] {,highbd_}intrapred_neon.c: Avoid over-reads in z1 and z3
 preds

The existing z1 and z3 predictors already contain checks to see if the
first element of the vector would over-read; however, this is not
sufficient since the vector may straddle the end of the input array.

To get around this, add an additional check against the end of the
array. If we would over-read, load a full vector up to the end of the
array and then use TBL to shuffle the data into the correct place. This
also means that we no longer need the compare and BSL at the end of each
loop iteration to select between the computed data and a duplicate of
the last element.

Bug: aomedia:3571
Change-Id: I03e2313b9bf0b44d64811fff1bedf4eb7381518a
(cherry picked from commit f1b43b5c0d0c98a37713e9939a782ebe014c1d1f)
---
 aom_dsp/arm/highbd_intrapred_neon.c | 108 ++++++++++++++++++++--------
 aom_dsp/arm/intrapred_neon.c        |  65 +++++++++++++----
 2 files changed, 133 insertions(+), 40 deletions(-)

diff --git a/aom_dsp/arm/highbd_intrapred_neon.c b/aom_dsp/arm/highbd_intrapred_neon.c
index dc47974c68..e66f523e37 100644
--- a/aom_dsp/arm/highbd_intrapred_neon.c
+++ b/aom_dsp/arm/highbd_intrapred_neon.c
@@ -1293,6 +1293,33 @@ static AOM_FORCE_INLINE uint16x8_t highbd_dr_z1_apply_shift_x8(uint16x8_t a0,
       highbd_dr_z1_apply_shift_x4(vget_high_u16(a0), vget_high_u16(a1), shift));
 }
 
+// clang-format off
+static const uint8_t kLoadMaxShuffles[] = {
+  14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+  12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+  10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+   8,  9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
+   6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15, 14, 15,
+   4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 14, 15, 14, 15,
+   2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 14, 15,
+   0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
+};
+// clang-format on
+
+static INLINE uint16x8_t zn_load_masked_neon(const uint16_t *ptr,
+                                             int shuffle_idx) {
+  uint8x16_t shuffle = vld1q_u8(&kLoadMaxShuffles[16 * shuffle_idx]);
+  uint8x16_t src = vreinterpretq_u8_u16(vld1q_u16(ptr));
+#if AOM_ARCH_AARCH64
+  return vreinterpretq_u16_u8(vqtbl1q_u8(src, shuffle));
+#else
+  uint8x8x2_t src2 = { { vget_low_u8(src), vget_high_u8(src) } };
+  uint8x8_t lo = vtbl2_u8(src2, vget_low_u8(shuffle));
+  uint8x8_t hi = vtbl2_u8(src2, vget_high_u8(shuffle));
+  return vreinterpretq_u16_u8(vcombine_u8(lo, hi));
+#endif
+}
+
 static void highbd_dr_prediction_z1_upsample0_neon(uint16_t *dst,
                                                    ptrdiff_t stride, int bw,
                                                    int bh,
@@ -1336,13 +1363,26 @@ static void highbd_dr_prediction_z1_upsample0_neon(uint16_t *dst,
     } else {
       int c = 0;
       do {
-        const uint16x8_t a0 = vld1q_u16(&above[base + c]);
-        const uint16x8_t a1 = vld1q_u16(&above[base + c + 1]);
-        const uint16x8_t val = highbd_dr_z1_apply_shift_x8(a0, a1, shift);
-        const uint16x8_t cmp =
-            vcgtq_s16(vdupq_n_s16(max_base_x - base - c), iota1x8);
-        const uint16x8_t res = vbslq_u16(cmp, val, vdupq_n_u16(above_max));
-        vst1q_u16(dst + c, res);
+        uint16x8_t a0;
+        uint16x8_t a1;
+        if (base + c >= max_base_x) {
+          a0 = a1 = vdupq_n_u16(above_max);
+        } else {
+          if (base + c + 7 >= max_base_x) {
+            int shuffle_idx = max_base_x - base - c;
+            a0 = zn_load_masked_neon(above + (max_base_x - 7), shuffle_idx);
+          } else {
+            a0 = vld1q_u16(above + base + c);
+          }
+          if (base + c + 8 >= max_base_x) {
+            int shuffle_idx = max_base_x - base - c - 1;
+            a1 = zn_load_masked_neon(above + (max_base_x - 7), shuffle_idx);
+          } else {
+            a1 = vld1q_u16(above + base + c + 1);
+          }
+        }
+
+        vst1q_u16(dst + c, highbd_dr_z1_apply_shift_x8(a0, a1, shift));
         c += 8;
       } while (c < bw);
     }
@@ -2456,13 +2496,29 @@ void av1_highbd_dr_prediction_z2_neon(uint16_t *dst, ptrdiff_t stride, int bw,
     val_lo = vmlal_lane_u16(val_lo, vget_low_u16(in1), (s1), (lane));     \
     uint32x4_t val_hi = vmull_lane_u16(vget_high_u16(in0), (s0), (lane)); \
     val_hi = vmlal_lane_u16(val_hi, vget_high_u16(in1), (s1), (lane));    \
-    const uint16x8_t cmp = vaddq_u16((iota), vdupq_n_u16(base));          \
-    const uint16x8_t res = vcombine_u16(vrshrn_n_u32(val_lo, (shift)),    \
-                                        vrshrn_n_u32(val_hi, (shift)));   \
-    *(out) = vbslq_u16(vcltq_u16(cmp, vdupq_n_u16(max_base_y)), res,      \
-                       vdupq_n_u16(left_max));                            \
+    *(out) = vcombine_u16(vrshrn_n_u32(val_lo, (shift)),                  \
+                          vrshrn_n_u32(val_hi, (shift)));                 \
   } while (0)
 
+static INLINE uint16x8x2_t z3_load_left_neon(const uint16_t *left0, int ofs,
+                                             int max_ofs) {
+  uint16x8_t r0;
+  uint16x8_t r1;
+  if (ofs + 7 >= max_ofs) {
+    int shuffle_idx = max_ofs - ofs;
+    r0 = zn_load_masked_neon(left0 + (max_ofs - 7), shuffle_idx);
+  } else {
+    r0 = vld1q_u16(left0 + ofs);
+  }
+  if (ofs + 8 >= max_ofs) {
+    int shuffle_idx = max_ofs - ofs - 1;
+    r1 = zn_load_masked_neon(left0 + (max_ofs - 7), shuffle_idx);
+  } else {
+    r1 = vld1q_u16(left0 + ofs + 1);
+  }
+  return (uint16x8x2_t){ { r0, r1 } };
+}
+
 static void highbd_dr_prediction_z3_upsample0_neon(uint16_t *dst,
                                                    ptrdiff_t stride, int bw,
                                                    int bh, const uint16_t *left,
@@ -2561,34 +2617,30 @@ static void highbd_dr_prediction_z3_upsample0_neon(uint16_t *dst,
         if (base0 >= max_base_y) {
           out[0] = vdupq_n_u16(left_max);
         } else {
-          const uint16x8_t l00 = vld1q_u16(left + base0);
-          const uint16x8_t l01 = vld1q_u16(left1 + base0);
-          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota1x8, base0, l00, l01,
-                                         shifts0, shifts1, 0, 6);
+          const uint16x8x2_t l0 = z3_load_left_neon(left, base0, max_base_y);
+          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[0], iota1x8, base0, l0.val[0],
+                                         l0.val[1], shifts0, shifts1, 0, 6);
         }
         if (base1 >= max_base_y) {
           out[1] = vdupq_n_u16(left_max);
         } else {
-          const uint16x8_t l10 = vld1q_u16(left + base1);
-          const uint16x8_t l11 = vld1q_u16(left1 + base1);
-          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota1x8, base1, l10, l11,
-                                         shifts0, shifts1, 1, 6);
+          const uint16x8x2_t l1 = z3_load_left_neon(left, base1, max_base_y);
+          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[1], iota1x8, base1, l1.val[0],
+                                         l1.val[1], shifts0, shifts1, 1, 6);
         }
         if (base2 >= max_base_y) {
           out[2] = vdupq_n_u16(left_max);
         } else {
-          const uint16x8_t l20 = vld1q_u16(left + base2);
-          const uint16x8_t l21 = vld1q_u16(left1 + base2);
-          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota1x8, base2, l20, l21,
-                                         shifts0, shifts1, 2, 6);
+          const uint16x8x2_t l2 = z3_load_left_neon(left, base2, max_base_y);
+          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[2], iota1x8, base2, l2.val[0],
+                                         l2.val[1], shifts0, shifts1, 2, 6);
         }
         if (base3 >= max_base_y) {
           out[3] = vdupq_n_u16(left_max);
         } else {
-          const uint16x8_t l30 = vld1q_u16(left + base3);
-          const uint16x8_t l31 = vld1q_u16(left1 + base3);
-          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota1x8, base3, l30, l31,
-                                         shifts0, shifts1, 3, 6);
+          const uint16x8x2_t l3 = z3_load_left_neon(left, base3, max_base_y);
+          HIGHBD_DR_PREDICTOR_Z3_STEP_X8(&out[3], iota1x8, base3, l3.val[0],
+                                         l3.val[1], shifts0, shifts1, 3, 6);
         }
         transpose_array_inplace_u16_4x8(out);
         for (int r2 = 0; r2 < 4; ++r2) {
diff --git a/aom_dsp/arm/intrapred_neon.c b/aom_dsp/arm/intrapred_neon.c
index c3716b3a78..2c99154fd0 100644
--- a/aom_dsp/arm/intrapred_neon.c
+++ b/aom_dsp/arm/intrapred_neon.c
@@ -1356,6 +1356,41 @@ static void dr_prediction_z1_32xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
   }
 }
 
+// clang-format off
+static const uint8_t kLoadMaxShuffles[] = {
+  15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+  14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+  13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+  12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+  11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+  10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+   9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+   8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15,
+   7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15,
+   6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15, 15,
+   5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15, 15,
+   4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15, 15,
+   3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15, 15,
+   2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15, 15,
+   1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 15,
+   0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
+};
+// clang-format on
+
+static INLINE uint8x16_t z1_load_masked_neon(const uint8_t *ptr,
+                                             int shuffle_idx) {
+  uint8x16_t shuffle = vld1q_u8(&kLoadMaxShuffles[16 * shuffle_idx]);
+  uint8x16_t src = vld1q_u8(ptr);
+#if AOM_ARCH_AARCH64
+  return vqtbl1q_u8(src, shuffle);
+#else
+  uint8x8x2_t src2 = { { vget_low_u8(src), vget_high_u8(src) } };
+  uint8x8_t lo = vtbl2_u8(src2, vget_low_u8(shuffle));
+  uint8x8_t hi = vtbl2_u8(src2, vget_high_u8(shuffle));
+  return vcombine_u8(lo, hi);
+#endif
+}
+
 static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
                                        const uint8_t *above, int dx) {
   const int frac_bits = 6;
@@ -1369,7 +1404,6 @@ static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
   //   (above[x] * 32 + 16 + (above[x+1] - above[x]) * shift) >> 5
 
   const uint8x16_t a_mbase_x = vdupq_n_u8(above[max_base_x]);
-  const uint8x16_t max_base_x128 = vdupq_n_u8(max_base_x);
 
   int x = dx;
   for (int r = 0; r < N; r++, dst += stride) {
@@ -1391,12 +1425,24 @@ static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
                                                vcreate_u8(0x0F0E0D0C0B0A0908)));
 
     for (int j = 0; j < 64; j += 16) {
-      int mdif = max_base_x - (base + j);
-      if (mdif <= 0) {
+      if (base + j >= max_base_x) {
         vst1q_u8(dst + j, a_mbase_x);
       } else {
-        uint8x16_t a0_128 = vld1q_u8(above + base + j);
-        uint8x16_t a1_128 = vld1q_u8(above + base + 1 + j);
+        uint8x16_t a0_128;
+        uint8x16_t a1_128;
+        if (base + j + 15 >= max_base_x) {
+          int shuffle_idx = max_base_x - base - j;
+          a0_128 = z1_load_masked_neon(above + (max_base_x - 15), shuffle_idx);
+        } else {
+          a0_128 = vld1q_u8(above + base + j);
+        }
+        if (base + j + 16 >= max_base_x) {
+          int shuffle_idx = max_base_x - base - j - 1;
+          a1_128 = z1_load_masked_neon(above + (max_base_x - 15), shuffle_idx);
+        } else {
+          a1_128 = vld1q_u8(above + base + j + 1);
+        }
+
         uint16x8_t diff_lo = vsubl_u8(vget_low_u8(a1_128), vget_low_u8(a0_128));
         uint16x8_t diff_hi =
             vsubl_u8(vget_high_u8(a1_128), vget_high_u8(a0_128));
@@ -1406,13 +1452,8 @@ static void dr_prediction_z1_64xN_neon(int N, uint8_t *dst, ptrdiff_t stride,
             vmlal_u8(vdupq_n_u16(16), vget_high_u8(a0_128), vdup_n_u8(32));
         uint16x8_t res_lo = vmlaq_u16(a32_lo, diff_lo, shift);
         uint16x8_t res_hi = vmlaq_u16(a32_hi, diff_hi, shift);
-        uint8x16_t v_temp =
-            vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5));
-
-        uint8x16_t mask128 =
-            vcgtq_u8(vqsubq_u8(max_base_x128, base_inc128), vdupq_n_u8(0));
-        uint8x16_t res128 = vbslq_u8(mask128, v_temp, a_mbase_x);
-        vst1q_u8(dst + j, res128);
+        vst1q_u8(dst + j,
+                 vcombine_u8(vshrn_n_u16(res_lo, 5), vshrn_n_u16(res_hi, 5)));
 
         base_inc128 = vaddq_u8(base_inc128, vdupq_n_u8(16));
       }