SDL: Implement visually accurate SIMD blitters

From 9590a47629be6fa52d1dfb2104298ff2306776e9 Mon Sep 17 00:00:00 2001
From: Isaac Aronson <[EMAIL REDACTED]>
Date: Sun, 10 Sep 2023 17:45:53 -0500
Subject: [PATCH] Implement visually accurate SIMD blitters

---
 src/video/SDL_blit_A_avx2.c   | 166 ++++++++++++++++++++++++----------
 src/video/SDL_blit_A_sse4_1.c | 159 +++++++++++++++++---------------
 src/video/SDL_blit_A_sse4_1.h |   8 +-
 3 files changed, 211 insertions(+), 122 deletions(-)

diff --git a/src/video/SDL_blit_A_avx2.c b/src/video/SDL_blit_A_avx2.c
index 78bdf9ecc71b8..ed2bfc1bfe8eb 100644
--- a/src/video/SDL_blit_A_avx2.c
+++ b/src/video/SDL_blit_A_avx2.c
@@ -9,45 +9,92 @@
 #include "SDL_blit.h"
 #include "SDL_blit_A_sse4_1.h"
 
-__m256i SDL_TARGETING("avx2") GetSDL_PixelFormatAlphaMask_AVX2(const SDL_PixelFormat* dstfmt) {
-    Uint8 index = dstfmt->Ashift / 4;
-    /* Handle case where bad input sent */
-    if (dstfmt->Ashift == dstfmt->Bshift && dstfmt->Ashift == 0) {
-        index = 6;
+__m256i SDL_TARGETING("avx2") GetSDL_PixelFormatAlphaSplatMask_AVX2(const SDL_PixelFormat* dstfmt) {
+    Uint8 index = dstfmt->Ashift / 8; /* byte offset of the alpha channel within each 32-bit pixel */
+    return _mm256_set_epi8( /* pshufb is per 128-bit lane and uses only the low 4 index bits: +16..+28 act as +0..+12 */
+            index + 28, index + 28, index + 28, index + 28, index + 24, index + 24, index + 24, index + 24,
+            index + 20, index + 20, index + 20, index + 20, index + 16, index + 16, index + 16, index + 16,
+            index + 12, index + 12, index + 12, index + 12, index + 8, index + 8, index + 8, index + 8,
+            index + 4, index + 4, index + 4, index + 4, index, index, index, index);
+}
+
+__m256i SDL_TARGETING("avx2") GetSDL_PixelFormatAlphaSaturateMask_AVX2(const SDL_PixelFormat* dstfmt) {
+    const Uint8 bin = dstfmt->Ashift / 8; /* alpha byte of each pixel; OR-ing 0xFF there forces src alpha to 255 */
+    return _mm256_set_epi8(
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0);
+}
+
+__m256i SDL_TARGETING("avx2") GetSDL_PixelFormatShuffleMask_AVX2(const SDL_PixelFormat* srcfmt,
+                                                              const SDL_PixelFormat* dstfmt) {
+    /* Map each destination byte to its source byte. NOTE: entries stay uninitialized unless the four dst shifts are distinct */
+    Uint8 shuffleIndices[32];
+    Uint8 dstAshift = dstfmt->Ashift / 8;
+    Uint8 dstRshift = dstfmt->Rshift / 8;
+    Uint8 dstGshift = dstfmt->Gshift / 8;
+    Uint8 dstBshift = dstfmt->Bshift / 8;
+    for (int i = 0; i < 8; ++i) {
+        shuffleIndices[dstAshift + i * 4] = srcfmt->Ashift / 8 + i * 4;
+        shuffleIndices[dstRshift + i * 4] = srcfmt->Rshift / 8 + i * 4;
+        shuffleIndices[dstGshift + i * 4] = srcfmt->Gshift / 8 + i * 4;
+        shuffleIndices[dstBshift + i * 4] = srcfmt->Bshift / 8 + i * 4;
     }
+
+    /* Build the mask; indices 16..31 land in the upper 128-bit lane, where pshufb only uses their low 4 bits */
     return _mm256_set_epi8(
-            -1, index + 24, -1, index + 24, -1, index + 24, -1, index + 24,
-            -1, index + 16, -1, index + 16, -1, index + 16, -1, index + 16,
-            -1, index + 8, -1, index + 8, -1, index + 8, -1, index + 8,
-            -1, index, -1, index, -1, index, -1, index);
+            shuffleIndices[31], shuffleIndices[30], shuffleIndices[29], shuffleIndices[28],
+            shuffleIndices[27], shuffleIndices[26], shuffleIndices[25], shuffleIndices[24],
+            shuffleIndices[23], shuffleIndices[22], shuffleIndices[21], shuffleIndices[20],
+            shuffleIndices[19], shuffleIndices[18], shuffleIndices[17], shuffleIndices[16],
+            shuffleIndices[15], shuffleIndices[14], shuffleIndices[13], shuffleIndices[12],
+            shuffleIndices[11], shuffleIndices[10], shuffleIndices[9], shuffleIndices[8],
+            shuffleIndices[7], shuffleIndices[6], shuffleIndices[5], shuffleIndices[4],
+            shuffleIndices[3], shuffleIndices[2], shuffleIndices[1], shuffleIndices[0]
+    );
 }
 
 /**
- * Using the AVX2 instruction set, blit eight pixels with alpha blending
- * @param src A pointer to four 32-bit pixels of ARGB format to blit into dst
- * @param dst A pointer to four 32-bit pixels of ARGB format to retain visual data for while alpha blending
- * @return A 128-bit wide vector of four alpha-blended pixels in ARGB format
+ * Using the AVX2 instruction set, alpha-blend eight 32-bit source pixels into eight destination pixels
  */
-__m128i SDL_TARGETING("avx2") MixRGBA_AVX2(const __m128i src, const __m128i dst, const __m256i alphaMask) {
-    __m256i src_color = _mm256_cvtepu8_epi16(src);
-    __m256i dst_color = _mm256_cvtepu8_epi16(dst);
-    __m256i alpha = _mm256_shuffle_epi8(src_color, alphaMask);
-    __m256i sub = _mm256_sub_epi16(src_color, dst_color);
-    __m256i mul = _mm256_mullo_epi16(sub, alpha);
-    /**
-     * With an 8-bit shuffle, one can only move integers within a lane. The 256-bit AVX2 lane is actually 4 64-bit
-     * lanes. We pack the integers into the start of each lane. The second shuffle operates on these 64-bit integers to
-     * put them into the correct order for transport back to the surface in the correct format.
-     */
-    const __m256i SHUFFLE_REDUCE = _mm256_set_epi8(
-            -1, -1, -1, -1, -1, -1, -1, -1,
-            31, 29, 27, 25, 23, 21, 19, 17,
-            -1, -1, -1, -1, -1, -1, -1, -1,
-            15, 13, 11, 9, 7, 5, 3, 1);
-    __m256i reduced = _mm256_shuffle_epi8(mul, SHUFFLE_REDUCE);
-    __m256i packed = _mm256_permute4x64_epi64(reduced, _MM_SHUFFLE(3, 1, 2, 0));
-    __m128i mix = _mm256_castsi256_si128(packed);
-    return _mm_add_epi8(mix, dst);
+__m256i SDL_TARGETING("avx2") MixRGBA_AVX2(__m256i src, __m256i dst, const __m256i alpha_shuffle,
+                                           const __m256i alpha_saturate) {
+    // SIMD implementation of blend_mul2.
+    // dstRGB                            = (srcRGB * srcA) + (dstRGB * (1-srcA))
+    // dstA   = srcA + (dstA * (1-srcA)) = (1      * srcA) + (dstA   * (1-srcA))
+
+    // Splat the alpha into all channels for each pixel
+    __m256i srca = _mm256_shuffle_epi8(src, alpha_shuffle);
+
+    // Set the alpha channels of src to 255
+    src = _mm256_or_si256(src, alpha_saturate);
+
+    __m256i src_lo = _mm256_unpacklo_epi8(src, _mm256_setzero_si256());
+    __m256i src_hi = _mm256_unpackhi_epi8(src, _mm256_setzero_si256());
+
+    __m256i dst_lo = _mm256_unpacklo_epi8(dst, _mm256_setzero_si256());
+    __m256i dst_hi = _mm256_unpackhi_epi8(dst, _mm256_setzero_si256());
+
+    __m256i srca_lo = _mm256_unpacklo_epi8(srca, _mm256_setzero_si256());
+    __m256i srca_hi = _mm256_unpackhi_epi8(srca, _mm256_setzero_si256());
+
+    // dst = ((src - dst) * srcA) + ((dst << 8) - dst)
+    dst_lo = _mm256_add_epi16(_mm256_mullo_epi16(_mm256_sub_epi16(src_lo, dst_lo), srca_lo),
+                              _mm256_sub_epi16(_mm256_slli_epi16(dst_lo, 8), dst_lo));
+    dst_hi = _mm256_add_epi16(_mm256_mullo_epi16(_mm256_sub_epi16(src_hi, dst_hi), srca_hi),
+                              _mm256_sub_epi16(_mm256_slli_epi16(dst_hi, 8), dst_hi));
+
+    // dst = (dst * 0x8081) >> 23
+    dst_lo = _mm256_srli_epi16(_mm256_mulhi_epu16(dst_lo, _mm256_set1_epi16(-0x7F7F)), 7);
+    dst_hi = _mm256_srli_epi16(_mm256_mulhi_epu16(dst_hi, _mm256_set1_epi16(-0x7F7F)), 7);
+
+    dst = _mm256_packus_epi16(dst_lo, dst_hi);
+    return dst;
 }
 
 void SDL_TARGETING("avx2") BlitNtoNPixelAlpha_AVX2(SDL_BlitInfo *info)
@@ -61,32 +108,54 @@ void SDL_TARGETING("avx2") BlitNtoNPixelAlpha_AVX2(SDL_BlitInfo *info)
     SDL_PixelFormat *srcfmt = info->src_fmt;
     SDL_PixelFormat *dstfmt = info->dst_fmt;
 
-    int chunks = width / 4;
-    const __m128i colorShiftMask = GetSDL_PixelFormatShuffleMask(srcfmt, dstfmt);
-    const __m256i alphaMask = GetSDL_PixelFormatAlphaMask_AVX2(dstfmt);
-    const __m128i sse4_1AlphaMask = GetSDL_PixelFormatAlphaMask_SSE4_1(dstfmt);
+    int chunks = width / 8;
+    SDL_PixelFormat alpha_fmt;
+    /* Without an alpha channel (Amask == 0) Ashift can alias a color byte (XRGB8888: Ashift == Bshift == 0).
+     * Blend into the unused padding byte instead, keeping R/G/B in place: 0 + 8 + 16 + 24 == 48. */
+    if (dstfmt->Amask == 0) {
+        alpha_fmt = *dstfmt;
+        alpha_fmt.Ashift = (Uint8)(48 - dstfmt->Rshift - dstfmt->Gshift - dstfmt->Bshift);
+        dstfmt = &alpha_fmt;
+    }
+    const __m256i shift_mask = GetSDL_PixelFormatShuffleMask_AVX2(srcfmt, dstfmt);
+    const __m256i splat_mask = GetSDL_PixelFormatAlphaSplatMask_AVX2(dstfmt);
+    const __m256i saturate_mask = GetSDL_PixelFormatAlphaSaturateMask_AVX2(dstfmt);
+    const __m128i sse4_1_shift_mask = GetSDL_PixelFormatShuffleMask_SSE4_1(srcfmt, dstfmt);
+    const __m128i sse4_1_splat_mask = GetSDL_PixelFormatAlphaSplatMask_SSE4_1(dstfmt);
+    const __m128i sse4_1_saturate_mask = GetSDL_PixelFormatAlphaSaturateMask_SSE4_1(dstfmt);
 
     while (height--) {
-        /* Process 4-wide chunks of source color data that may be in wrong format */
+        /* Process 8-wide chunks of source color data that may be in wrong format */
         for (int i = 0; i < chunks; i += 1) {
-            __m128i c_src = _mm_shuffle_epi8(_mm_loadu_si128((__m128i *) (src + i * 16)), colorShiftMask);
-            /* Alpha-blend in 4-wide chunk from src into destination */
-            __m128i c_dst = _mm_loadu_si128((__m128i*) (dst + i * 16));
-            __m128i c_mix = MixRGBA_AVX2(c_src, c_dst, alphaMask);
-            _mm_storeu_si128((__m128i*) (dst + i * 16), c_mix);
+            __m256i c_src = _mm256_shuffle_epi8(_mm256_loadu_si256((__m256i *) (src + i * 32)), shift_mask);
+            /* Alpha-blend in 8-wide chunk from src into destination */
+            __m256i c_dst = _mm256_loadu_si256((__m256i*) (dst + i * 32));
+            __m256i c_mix = MixRGBA_AVX2(c_src, c_dst, splat_mask, saturate_mask);
+            _mm256_storeu_si256((__m256i*) (dst + i * 32), c_mix);
         }
 
-        /* Handle remaining pixels when width is not a multiple of 4 */
-        if (width % 4 != 0) {
-            int remaining_pixels = width % 4;
+        /* Handle remaining pixels when width is not a multiple of 8 */
+        if (width % 8 != 0) {
+            int remaining_pixels = width % 8;
             int offset = width - remaining_pixels;
+            if (remaining_pixels >= 4) {
+                Uint32 *src_ptr = ((Uint32*)(src + (offset * 4)));
+                Uint32 *dst_ptr = ((Uint32*)(dst + (offset * 4)));
+                __m128i c_src = _mm_loadu_si128((__m128i*)src_ptr);
+                c_src = _mm_shuffle_epi8(c_src, sse4_1_shift_mask);
+                __m128i c_dst = _mm_loadu_si128((__m128i*)dst_ptr);
+                __m128i c_mix = MixRGBA_SSE4_1(c_src, c_dst, sse4_1_splat_mask, sse4_1_saturate_mask);
+                _mm_storeu_si128((__m128i*)dst_ptr, c_mix);
+                remaining_pixels -= 4;
+                offset += 4;
+            }
             if (remaining_pixels >= 2) {
                 Uint32 *src_ptr = ((Uint32*)(src + (offset * 4)));
                 Uint32 *dst_ptr = ((Uint32*)(dst + (offset * 4)));
                 __m128i c_src = _mm_loadu_si64(src_ptr);
-                c_src = _mm_shuffle_epi8(c_src, colorShiftMask);
+                c_src = _mm_shuffle_epi8(c_src, sse4_1_shift_mask);
                 __m128i c_dst = _mm_loadu_si64(dst_ptr);
-                __m128i c_mix = MixRGBA_SSE4_1(c_src, c_dst, sse4_1AlphaMask);
+                __m128i c_mix = MixRGBA_SSE4_1(c_src, c_dst, sse4_1_splat_mask, sse4_1_saturate_mask);
                 _mm_storeu_si64(dst_ptr, c_mix);
                 remaining_pixels -= 2;
                 offset += 2;
@@ -103,7 +172,7 @@ void SDL_TARGETING("avx2") BlitNtoNPixelAlpha_AVX2(SDL_BlitInfo *info)
                 __m128i c_src = _mm_loadu_si32(&pixel);
                 __m128i c_dst = _mm_loadu_si32(dst_ptr);
                 #endif
-                __m128i mixed_pixel = MixRGBA_SSE4_1(c_src, c_dst, sse4_1AlphaMask);
+                __m128i mixed_pixel = MixRGBA_SSE4_1(c_src, c_dst, sse4_1_splat_mask, sse4_1_saturate_mask);
                 /* Old GCC has bad or no _mm_storeu_si32 */
                 #if defined(__GNUC__) && (__GNUC__ < 11)
                 *dst_ptr = _mm_extract_epi32(mixed_pixel, 0);
diff --git a/src/video/SDL_blit_A_sse4_1.c b/src/video/SDL_blit_A_sse4_1.c
index 5348879277630..34355e8c950a3 100644
--- a/src/video/SDL_blit_A_sse4_1.c
+++ b/src/video/SDL_blit_A_sse4_1.c
@@ -10,75 +10,40 @@
 #include "SDL_blit_A_sse4_1.h"
 
 /**
- * A helper function to create an alpha mask for use with MixRGBA_SSE4_1 based on pixel format
+ * Build a _mm_shuffle_epi8 mask for MixRGBA_SSE4_1 that copies each pixel's alpha byte (at dstfmt->Ashift) into all four of its bytes
  */
-__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatAlphaMask_SSE4_1(const SDL_PixelFormat* dstfmt) {
-    Uint8 index = dstfmt->Ashift / 8;
-    /* Handle case where bad input sent */
-    if (dstfmt->Ashift == dstfmt->Bshift && dstfmt->Ashift == 0) {
-        index = 3;
-    }
+__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatAlphaSplatMask_SSE4_1(const SDL_PixelFormat* dstfmt) {
+    const Uint8 index = dstfmt->Ashift / 8; /* byte offset of the alpha channel within each 32-bit pixel */
     return _mm_set_epi8(
-            -1, index + 4, -1, index + 4, -1, index + 4, -1, index + 4,
-            -1, index, -1, index, -1, index, -1, index);
+            index + 12, index + 12, index + 12, index + 12,
+            index + 8, index + 8, index + 8, index + 8,
+            index + 4, index + 4, index + 4, index + 4,
+            index, index, index, index);
 }
 
 /**
- * Using the SSE4.1 instruction set, blit four pixels with alpha blending
- * @param src A pointer to two 32-bit pixels of ARGB format to blit into dst
- * @param dst A pointer to two 32-bit pixels of ARGB format to retain visual data for while alpha blending
- * @return A 128-bit wide vector of two alpha-blended pixels in ARGB format
+ * Build an OR mask for MixRGBA_SSE4_1 with 0xFF in each pixel's alpha byte, used to force the source alpha to 255
  */
-__m128i SDL_TARGETING("sse4.1") MixRGBA_SSE4_1(const __m128i src, const __m128i dst, const __m128i alphaMask) {
-    __m128i src_color = _mm_cvtepu8_epi16(src);
-    __m128i dst_color = _mm_cvtepu8_epi16(dst);
-    /**
-     * Combines a shuffle and an _mm_cvtepu8_epi16 operation into one operation by moving the lower 8 bits of the alpha
-     * channel around to create 16-bit integers.
-     */
-    __m128i alpha = _mm_shuffle_epi8(src, alphaMask);
-    __m128i sub = _mm_sub_epi16(src_color, dst_color);
-    __m128i mul = _mm_mullo_epi16(sub, alpha);
-    const __m128i SHUFFLE_REDUCE = _mm_set_epi8(
-        -1, -1, -1, -1, -1, -1, -1, -1,
-        15, 13, 11, 9, 7, 5, 3, 1);
-    __m128i reduced = _mm_shuffle_epi8(mul, SHUFFLE_REDUCE);
-
-    return _mm_add_epi8(reduced, dst);
-}
-
-Uint32 AlignPixelToSDL_PixelFormat(Uint32 color, const SDL_PixelFormat* srcfmt, const SDL_PixelFormat* dstfmt) {
-    Uint8 a = (color >> srcfmt->Ashift) & 0xFF;
-    Uint8 r = (color >> srcfmt->Rshift) & 0xFF;
-    Uint8 g = (color >> srcfmt->Gshift) & 0xFF;
-    Uint8 b = (color >> srcfmt->Bshift) & 0xFF;
-
-    /* Handle case where bad input sent */
-    Uint8 aShift = dstfmt->Ashift;
-    if (aShift == dstfmt->Bshift && aShift == 0) {
-        aShift = 24;
-    }
-    return (a << aShift) |
-           (r << dstfmt->Rshift) |
-           (g << dstfmt->Gshift) |
-           (b << dstfmt->Bshift);
+__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatAlphaSaturateMask_SSE4_1(const SDL_PixelFormat* dstfmt) {
+    const Uint8 bin = dstfmt->Ashift / 8; /* alpha byte of each pixel; only that byte gets 0xFF */
+    return _mm_set_epi8(
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0,
+            bin == 3 ? 0xFF : 0, bin == 2 ? 0xFF : 0, bin == 1 ? 0xFF : 0, bin == 0 ? 0xFF : 0);
 }
 
-/*
+/**
  * This helper function converts arbitrary pixel formats into a shuffle mask for _mm_shuffle_epi8
  */
-__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatShuffleMask(const SDL_PixelFormat* srcfmt,
-                                                              const SDL_PixelFormat* dstfmt) {
+__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatShuffleMask_SSE4_1(const SDL_PixelFormat* srcfmt,
+                                                                     const SDL_PixelFormat* dstfmt) {
     /* Calculate shuffle indices based on the source and destination SDL_PixelFormat */
     Uint8 shuffleIndices[16];
     Uint8 dstAshift = dstfmt->Ashift / 8;
     Uint8 dstRshift = dstfmt->Rshift / 8;
     Uint8 dstGshift = dstfmt->Gshift / 8;
     Uint8 dstBshift = dstfmt->Bshift / 8;
-    /* Handle case where bad input sent */
-    if (dstAshift == dstBshift && dstAshift == 0) {
-        dstAshift = 3;
-    }
     for (int i = 0; i < 4; ++i) {
         shuffleIndices[dstAshift + i * 4] = srcfmt->Ashift / 8 + i * 4;
         shuffleIndices[dstRshift + i * 4] = srcfmt->Rshift / 8 + i * 4;
@@ -95,6 +60,56 @@ __m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatShuffleMask(const SDL_PixelFor
     );
 }
 
+/**
+ * Using the SSE4.1 instruction set, alpha-blend four 32-bit source pixels into four destination pixels
+ */
+__m128i SDL_TARGETING("sse4.1") MixRGBA_SSE4_1(__m128i src, __m128i dst,
+                                               const __m128i alpha_splat, const __m128i alpha_saturate) {
+    // SIMD implementation of blend_mul2.
+    // dstRGB                            = (srcRGB * srcA) + (dstRGB * (1-srcA))
+    // dstA   = srcA + (dstA * (1-srcA)) = (1      * srcA) + (dstA   * (1-srcA))
+
+    // Splat the alpha into all channels for each pixel
+    __m128i srca = _mm_shuffle_epi8(src, alpha_splat);
+
+    // Set the alpha channels of src to 255
+    src = _mm_or_si128(src, alpha_saturate);
+
+    __m128i src_lo = _mm_unpacklo_epi8(src, _mm_setzero_si128());
+    __m128i src_hi = _mm_unpackhi_epi8(src, _mm_setzero_si128());
+
+    __m128i dst_lo = _mm_unpacklo_epi8(dst, _mm_setzero_si128());
+    __m128i dst_hi = _mm_unpackhi_epi8(dst, _mm_setzero_si128());
+
+    __m128i srca_lo = _mm_unpacklo_epi8(srca, _mm_setzero_si128());
+    __m128i srca_hi = _mm_unpackhi_epi8(srca, _mm_setzero_si128());
+
+    // dst = ((src - dst) * srcA) + ((dst << 8) - dst)
+    dst_lo = _mm_add_epi16(_mm_mullo_epi16(_mm_sub_epi16(src_lo, dst_lo), srca_lo),
+                           _mm_sub_epi16(_mm_slli_epi16(dst_lo, 8), dst_lo));
+    dst_hi = _mm_add_epi16(_mm_mullo_epi16(_mm_sub_epi16(src_hi, dst_hi), srca_hi),
+                           _mm_sub_epi16(_mm_slli_epi16(dst_hi, 8), dst_hi));
+
+    // dst = (dst * 0x8081) >> 23
+    dst_lo = _mm_srli_epi16(_mm_mulhi_epu16(dst_lo, _mm_set1_epi16(-0x7F7F)), 7);
+    dst_hi = _mm_srli_epi16(_mm_mulhi_epu16(dst_hi, _mm_set1_epi16(-0x7F7F)), 7);
+
+    dst = _mm_packus_epi16(dst_lo, dst_hi);
+    return dst;
+}
+
+/** Repack one pixel's 8-bit A/R/G/B channels from srcfmt's bit positions to dstfmt's bit positions. */
+Uint32 AlignPixelToSDL_PixelFormat(Uint32 color, const SDL_PixelFormat* srcfmt, const SDL_PixelFormat* dstfmt) {
+    Uint8 a = (color >> srcfmt->Ashift) & 0xFF;
+    Uint8 r = (color >> srcfmt->Rshift) & 0xFF;
+    Uint8 g = (color >> srcfmt->Gshift) & 0xFF;
+    Uint8 b = (color >> srcfmt->Bshift) & 0xFF;
+    /* Widen before shifting: a Uint8 promotes to int, and e.g. 0xFF << 24 overflows int (undefined behavior) */
+    return ((Uint32)a << dstfmt->Ashift) |
+           ((Uint32)r << dstfmt->Rshift) |
+           ((Uint32)g << dstfmt->Gshift) |
+           ((Uint32)b << dstfmt->Bshift);
+}
 
 void SDL_TARGETING("sse4.1") BlitNtoNPixelAlpha_SSE4_1(SDL_BlitInfo* info) {
     int width = info->dst_w;
@@ -106,24 +121,27 @@ void SDL_TARGETING("sse4.1") BlitNtoNPixelAlpha_SSE4_1(SDL_BlitInfo* info) {
     SDL_PixelFormat *srcfmt = info->src_fmt;
     SDL_PixelFormat *dstfmt = info->dst_fmt;
 
-    int chunks = width / 4;
-    Uint8 *buffer = (Uint8*)SDL_malloc(chunks * 16 * sizeof(Uint8));
-    const __m128i colorShiftMask = GetSDL_PixelFormatShuffleMask(srcfmt, dstfmt);
-    const __m128i alphaMask = GetSDL_PixelFormatAlphaMask_SSE4_1(dstfmt);
+    const int chunks = width / 4;
+    SDL_PixelFormat alpha_fmt;
+    /* Without an alpha channel (Amask == 0) Ashift can alias a color byte (XRGB8888: Ashift == Bshift == 0).
+     * Blend into the unused padding byte instead, keeping R/G/B in place: 0 + 8 + 16 + 24 == 48. */
+    if (dstfmt->Amask == 0) {
+        alpha_fmt = *dstfmt;
+        alpha_fmt.Ashift = (Uint8)(48 - dstfmt->Rshift - dstfmt->Gshift - dstfmt->Bshift);
+        dstfmt = &alpha_fmt;
+    }
+    const __m128i shift_mask = GetSDL_PixelFormatShuffleMask_SSE4_1(srcfmt, dstfmt);
+    const __m128i splat_mask = GetSDL_PixelFormatAlphaSplatMask_SSE4_1(dstfmt);
+    const __m128i saturate_mask = GetSDL_PixelFormatAlphaSaturateMask_SSE4_1(dstfmt);
 
     while (height--) {
-        /* Process 4-wide chunks of source color data that may be in wrong format into buffer */
+        /* Reorder 4-wide chunks of source pixels into the destination format and blend them in place */
         for (int i = 0; i < chunks; i += 1) {
             __m128i colors = _mm_loadu_si128((__m128i*)(src + i * 16));
-            _mm_storeu_si128((__m128i*)(buffer + i * 16), _mm_shuffle_epi8(colors, colorShiftMask));
-        }
-
-        /* Alpha-blend in 2-wide chunks from buffer into destination */
-        for (int i = 0; i < chunks * 2; i += 1) {
-            __m128i c_src = _mm_loadu_si64((buffer + (i * 8)));
-            __m128i c_dst = _mm_loadu_si64((dst + i * 8));
-            __m128i c_mix = MixRGBA_SSE4_1(c_src, c_dst, alphaMask);
-            _mm_storeu_si64(dst + i * 8, c_mix);
+            colors = _mm_shuffle_epi8(colors, shift_mask);
+            colors = MixRGBA_SSE4_1(colors, _mm_loadu_si128((__m128i*)(dst + i * 16)),
+                                    splat_mask, saturate_mask);
+            _mm_storeu_si128((__m128i*)(dst + i * 16), colors);
         }
 
         /* Handle remaining pixels when width is not a multiple of 4 */
@@ -134,9 +152,9 @@ void SDL_TARGETING("sse4.1") BlitNtoNPixelAlpha_SSE4_1(SDL_BlitInfo* info) {
                 Uint32 *src_ptr = ((Uint32*)(src + (offset * 4)));
                 Uint32 *dst_ptr = ((Uint32*)(dst + (offset * 4)));
                 __m128i c_src = _mm_loadu_si64(src_ptr);
-                c_src = _mm_shuffle_epi8(c_src, colorShiftMask);
+                c_src = _mm_shuffle_epi8(c_src, shift_mask);
                 __m128i c_dst = _mm_loadu_si64(dst_ptr);
-                __m128i c_mix = MixRGBA_SSE4_1(c_src, c_dst, alphaMask);
+                __m128i c_mix = MixRGBA_SSE4_1(c_src, c_dst, splat_mask, saturate_mask);
                 _mm_storeu_si64(dst_ptr, c_mix);
                 remaining_pixels -= 2;
                 offset += 2;
@@ -153,7 +171,7 @@ void SDL_TARGETING("sse4.1") BlitNtoNPixelAlpha_SSE4_1(SDL_BlitInfo* info) {
                 __m128i c_src = _mm_loadu_si32(&pixel);
                 __m128i c_dst = _mm_loadu_si32(dst_ptr);
                 #endif
-                __m128i mixed_pixel = MixRGBA_SSE4_1(c_src, c_dst, alphaMask);
+                __m128i mixed_pixel = MixRGBA_SSE4_1(c_src, c_dst, splat_mask, saturate_mask);
                 /* Old GCC has bad or no _mm_storeu_si32 */
                 #if defined(__GNUC__) && (__GNUC__ < 11)
                 *dst_ptr = _mm_extract_epi32(mixed_pixel, 0);
@@ -169,7 +187,6 @@ void SDL_TARGETING("sse4.1") BlitNtoNPixelAlpha_SSE4_1(SDL_BlitInfo* info) {
         src += srcskip;
         dst += dstskip;
     }
-    SDL_free(buffer);
 }
 
 #endif
diff --git a/src/video/SDL_blit_A_sse4_1.h b/src/video/SDL_blit_A_sse4_1.h
index 132120d051769..c6c8dec729a8a 100644
--- a/src/video/SDL_blit_A_sse4_1.h
+++ b/src/video/SDL_blit_A_sse4_1.h
@@ -4,11 +4,13 @@
 #ifdef SDL_SSE4_1_INTRINSICS
 Uint32 AlignPixelToSDL_PixelFormat(Uint32 color, const SDL_PixelFormat* srcfmt, const SDL_PixelFormat* dstfmt);
 
-__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatAlphaMask_SSE4_1(const SDL_PixelFormat* dstfmt);
+__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatAlphaSplatMask_SSE4_1(const SDL_PixelFormat* dstfmt);
 
-__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatShuffleMask(const SDL_PixelFormat* srcfmt, const SDL_PixelFormat* dstfmt);
+__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatAlphaSaturateMask_SSE4_1(const SDL_PixelFormat* dstfmt);
 
-__m128i SDL_TARGETING("sse4.1") MixRGBA_SSE4_1(__m128i src, __m128i dst, __m128i alphaMask);
+__m128i SDL_TARGETING("sse4.1") GetSDL_PixelFormatShuffleMask_SSE4_1(const SDL_PixelFormat* srcfmt, const SDL_PixelFormat* dstfmt);
+
+__m128i SDL_TARGETING("sse4.1") MixRGBA_SSE4_1(__m128i src, __m128i dst, __m128i alpha_splat, __m128i alpha_saturate);
 
 void SDL_TARGETING("sse4.1") BlitNtoNPixelAlpha_SSE4_1(SDL_BlitInfo *info);