SDL: Add neon SIMD alpha blending blitter

From d9b3b9ad91f0883b7840a589ecbc7ed27f8a7852 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C4=81rti=C5=86=C5=A1=20Mo=C5=BEeiko?=
 <martins.mozeiko@gmail.com>
Date: Wed, 28 Aug 2024 20:29:05 -0700
Subject: [PATCH] Add neon SIMD alpha blending blitter

---
 src/video/SDL_blit_A.c | 115 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)

diff --git a/src/video/SDL_blit_A.c b/src/video/SDL_blit_A.c
index db13e74cc42df..1211bb015b3fa 100644
--- a/src/video/SDL_blit_A.c
+++ b/src/video/SDL_blit_A.c
@@ -1176,6 +1176,114 @@ static void SDL_TARGETING("avx2") Blit8888to8888PixelAlphaSwizzleAVX2(SDL_BlitIn
 
 #endif
 
+#if defined(SDL_NEON_INTRINSICS) && (__ARM_ARCH >= 8)
+
+static void Blit8888to8888PixelAlphaSwizzleNEON(SDL_BlitInfo *info)
+{
+    int width = info->dst_w;
+    int height = info->dst_h;
+    Uint8 *src = info->src;
+    int srcskip = info->src_skip;
+    Uint8 *dst = info->dst;
+    int dstskip = info->dst_skip;
+    const SDL_PixelFormatDetails *srcfmt = info->src_fmt;
+    const SDL_PixelFormatDetails *dstfmt = info->dst_fmt;
+    // Blends src over dst using each src pixel's own alpha, reordering src channels into dst order on the way
+    // Byte offset of each of the 4 pixels in a vector (0, 4, 8, 12), repeated in all 4 byte lanes of that pixel
+    const uint8x16_t mask_offsets = vreinterpretq_u8_u64(vcombine_u64(
+        vcreate_u64(0x0404040400000000), vcreate_u64(0x0c0c0c0c08080808)));
+
+    const uint8x16_t convert_mask = vreinterpretq_u8_u32(vaddq_u32(
+        vreinterpretq_u32_u8(mask_offsets),
+        vdupq_n_u32( // per channel: src byte index, placed at that channel's position in the dst pixel
+            ((srcfmt->Rshift >> 3) << dstfmt->Rshift) |
+            ((srcfmt->Gshift >> 3) << dstfmt->Gshift) |
+            ((srcfmt->Bshift >> 3) << dstfmt->Bshift)))); // NOTE(review): presumes byte-aligned channel shifts - confirm
+
+    const uint8x16_t alpha_splat_mask = vaddq_u8(vdupq_n_u8(srcfmt->Ashift >> 3), mask_offsets); // index of each pixel's alpha byte
+    const uint8x16_t alpha_fill_mask = vreinterpretq_u8_u32(vdupq_n_u32(dstfmt->Amask)); // dstfmt->Amask in every 32-bit lane
+
+    while (height--) {
+        int i = 0;
+
+        for (; i + 4 <= width; i += 4) {
+            // Load 4 src pixels (16 bytes)
+            uint8x16_t src128 = vld1q_u8(src);
+
+            // Load 4 dst pixels (16 bytes)
+            uint8x16_t dst128 = vld1q_u8(dst);
+
+            // Table lookup: broadcast each pixel's alpha byte into all 4 of its channel bytes
+            uint8x16_t srcA = vqtbl1q_u8(src128, alpha_splat_mask);
+
+            // Table lookup: move the src channel bytes into the dst format's channel order
+            src128 = vqtbl1q_u8(src128, convert_mask);
+
+            // Turn on the dst-format alpha bits (dstfmt->Amask) in every src pixel
+            src128 = vorrq_u8(src128, alpha_fill_mask);
+
+            // 255 - srcA = ~srcA
+            uint8x16_t srcInvA = vmvnq_u8(srcA);
+
+            // Accumulators start at 1; this bias is what the truncated divide further down relies on
+            uint16x8_t res_lo = vdupq_n_u16(1);
+            uint16x8_t res_hi = vdupq_n_u16(1);
+
+            // res = 1 + alpha * src + (255 - alpha) * dst, widened to 16 bits per channel
+            res_lo = vmlal_u8(res_lo, vget_low_u8(srcA),    vget_low_u8(src128));
+            res_lo = vmlal_u8(res_lo, vget_low_u8(srcInvA), vget_low_u8(dst128));
+            res_hi = vmlal_high_u8(res_hi, srcA,    src128);
+            res_hi = vmlal_high_u8(res_hi, srcInvA, dst128);
+
+            // res already contains the +1 bias; vaddhn keeps the high byte of each 16-bit sum, which gives
+            // dst = (res + (res >> 8)) >> 8
+            uint8x8_t temp;
+            temp   = vaddhn_u16(res_lo, vshrq_n_u16(res_lo, 8));
+            dst128 = vaddhn_high_u16(temp, res_hi, vshrq_n_u16(res_hi, 8));
+
+            // For rounded division instead: drop the initial 1, make the first vmlal into each accumulator
+            // a plain vmull (vmull_u8 / vmull_high_u8), and replace the two lines above with:
+            // temp   = vraddhn_u16(res_lo, vrshrq_n_u16(res_lo, 8));
+            // dst128 = vraddhn_high_u16(temp, res_hi, vrshrq_n_u16(res_hi, 8));
+
+            // Store the 4 blended pixels back to dst
+            vst1q_u8(dst, dst128);
+
+            src += 16;
+            dst += 16;
+        }
+
+        // Tail: 1 pixel per iteration (at most 3 remain), same calculations as above on 64-bit vectors
+        for (; i < width; ++i) {
+            // The pixel is duplicated into both 32-bit lanes; only the low lane's result is stored below
+            uint8x8_t src32 = vreinterpret_u8_u32(vld1_dup_u32((Uint32*)src));
+            uint8x8_t dst32 = vreinterpret_u8_u32(vld1_dup_u32((Uint32*)dst));
+
+            uint8x8_t srcA = vtbl1_u8(src32, vget_low_u8(alpha_splat_mask));
+            src32 = vtbl1_u8(src32, vget_low_u8(convert_mask));
+            src32 = vorr_u8(src32, vget_low_u8(alpha_fill_mask));
+            uint8x8_t srcInvA = vmvn_u8(srcA);
+
+            uint16x8_t res = vdupq_n_u16(1);
+            res = vmlal_u8(res, srcA,    src32);
+            res = vmlal_u8(res, srcInvA, dst32);
+
+            dst32 = vaddhn_u16(res, vshrq_n_u16(res, 8));
+
+            // Store only the low 32 bits (one pixel)
+            vst1_lane_u32((Uint32*)dst, vreinterpret_u32_u8(dst32), 0);
+
+            src += 4;
+            dst += 4;
+        }
+
+        src += srcskip;
+        dst += dstskip;
+    }
+}
+
+#endif
+
 // General (slow) N->N blending with pixel alpha
 static void BlitNtoNPixelAlpha(SDL_BlitInfo *info)
 {
@@ -1256,11 +1364,18 @@ SDL_BlitFunc SDL_CalculateBlitA(SDL_Surface *surface)
                     return Blit8888to8888PixelAlphaSwizzleSSE41;
                 }
 #endif
+#if defined(SDL_NEON_INTRINSICS) && (__ARM_ARCH >= 8)
+                // To prevent "unused function" compiler warnings/errors
+                (void)Blit8888to8888PixelAlpha;
+                (void)Blit8888to8888PixelAlphaSwizzle;
+                return Blit8888to8888PixelAlphaSwizzleNEON;
+#else
                 if (sf->format == df->format) {
                     return Blit8888to8888PixelAlpha;
                 } else {
                     return Blit8888to8888PixelAlphaSwizzle;
                 }
+#endif
             }
             return BlitNtoNPixelAlpha;