SDL: GPU: Validate shader bytecode

From 6e4ace310cb7130a868574288050a21d09e51a71 Mon Sep 17 00:00:00 2001
From: Lucas Murray <[EMAIL REDACTED]>
Date: Wed, 2 Apr 2025 21:32:15 +1100
Subject: [PATCH] GPU: Validate shader bytecode

---
 src/gpu/d3d12/SDL_gpu_d3d12.c   | 25 +++++++++++++++++++++++++
 src/gpu/metal/SDL_gpu_metal.m   | 16 ++++++++++++++++
 src/gpu/vulkan/SDL_gpu_vulkan.c | 27 +++++++++++++++++++++++++++
 3 files changed, 68 insertions(+)

diff --git a/src/gpu/d3d12/SDL_gpu_d3d12.c b/src/gpu/d3d12/SDL_gpu_d3d12.c
index 628ca16c7d4c3..b848790df0e20 100644
--- a/src/gpu/d3d12/SDL_gpu_d3d12.c
+++ b/src/gpu/d3d12/SDL_gpu_d3d12.c
@@ -2482,12 +2482,32 @@ static D3D12GraphicsRootSignature *D3D12_INTERNAL_CreateGraphicsRootSignature(
     return d3d12GraphicsRootSignature;
 }
 
+// Returns true when `code` starts with the 4 byte `DXBC` header.
+static bool D3D12_INTERNAL_IsValidShaderBytecode(
+    const Uint8 *code,
+    size_t codeSize)
+{
+    // DXIL and DXBC bytecode share this header, so one check covers both.
+    static const char signature[4] = { 'D', 'X', 'B', 'C' };
+    const bool hasHeader = (code != NULL) && (codeSize >= sizeof(signature));
+    return hasHeader && SDL_memcmp(code, signature, sizeof(signature)) == 0;
+}
+
 static bool D3D12_INTERNAL_CreateShaderBytecode(
+    D3D12Renderer *renderer,
     const Uint8 *code,
     size_t codeSize,
+    SDL_GPUShaderFormat format,
     void **pBytecode,
     size_t *pBytecodeSize)
 {
+    if (!D3D12_INTERNAL_IsValidShaderBytecode(code, codeSize)) {
+        if (format == SDL_GPU_SHADERFORMAT_DXBC) {
+            SET_STRING_ERROR_AND_RETURN("The provided shader code is not valid DXBC!", false);
+        }
+        SET_STRING_ERROR_AND_RETURN("The provided shader code is not valid DXIL!", false);
+    }
+
     if (pBytecode != NULL) {
         *pBytecode = SDL_malloc(codeSize);
         if (!*pBytecode) {
@@ -2705,8 +2725,10 @@ static SDL_GPUComputePipeline *D3D12_CreateComputePipeline(
     ID3D12PipelineState *pipelineState;
 
     if (!D3D12_INTERNAL_CreateShaderBytecode(
+            renderer,
             createinfo->code,
             createinfo->code_size,
+            createinfo->format,
             &bytecode,
             &bytecodeSize)) {
         return NULL;
@@ -3113,13 +3135,16 @@ static SDL_GPUShader *D3D12_CreateShader(
     SDL_GPURenderer *driverData,
     const SDL_GPUShaderCreateInfo *createinfo)
 {
+    D3D12Renderer *renderer = (D3D12Renderer *)driverData;
     void *bytecode;
     size_t bytecodeSize;
     D3D12Shader *shader;
 
     if (!D3D12_INTERNAL_CreateShaderBytecode(
+            renderer,
             createinfo->code,
             createinfo->code_size,
+            createinfo->format,
             &bytecode,
             &bytecodeSize)) {
         return NULL;
diff --git a/src/gpu/metal/SDL_gpu_metal.m b/src/gpu/metal/SDL_gpu_metal.m
index d83079f5f581c..c8b894f3ba85c 100644
--- a/src/gpu/metal/SDL_gpu_metal.m
+++ b/src/gpu/metal/SDL_gpu_metal.m
@@ -836,6 +836,17 @@ static void METAL_INTERNAL_TrackUniformBuffer(
     id<MTLFunction> function;
 } MetalLibraryFunction;
 
+// Metal libraries open with a 4 byte `MTLB` header; anything else is rejected.
+static bool METAL_INTERNAL_IsValidMetalLibrary(
+    const Uint8 *code,
+    size_t codeSize)
+{
+    if (code != NULL && codeSize >= 4) {
+        return SDL_memcmp(code, "MTLB", 4) == 0;
+    }
+    return false;
+}
+
 // This function assumes that it's called from within an autorelease pool
 static MetalLibraryFunction METAL_INTERNAL_CompileShader(
     MetalRenderer *renderer,
@@ -864,6 +875,11 @@ static MetalLibraryFunction METAL_INTERNAL_CompileShader(
                          options:nil
                            error:&error];
     } else if (format == SDL_GPU_SHADERFORMAT_METALLIB) {
+        if (!METAL_INTERNAL_IsValidMetalLibrary(code, codeSize)) {
+            SET_STRING_ERROR_AND_RETURN(
+                "The provided shader code is not a valid Metal library!",
+                libraryFunction);
+        }
         data = dispatch_data_create(
             code,
             codeSize,
diff --git a/src/gpu/vulkan/SDL_gpu_vulkan.c b/src/gpu/vulkan/SDL_gpu_vulkan.c
index ebfeb2dd7465a..5de1e62b8cfbe 100644
--- a/src/gpu/vulkan/SDL_gpu_vulkan.c
+++ b/src/gpu/vulkan/SDL_gpu_vulkan.c
@@ -6511,6 +6511,25 @@ static SDL_GPUGraphicsPipeline *VULKAN_CreateGraphicsPipeline(
     return (SDL_GPUGraphicsPipeline *)graphicsPipeline;
 }
 
+// Reports whether `code` starts with the SPIR-V magic number in either byte order.
+static bool VULKAN_INTERNAL_IsValidShaderBytecode(
+    const Uint8 *code,
+    size_t codeSize)
+{
+    // SPIR-V is a stream of words, not bytes, and its 4 byte header contains
+    // 0x07230203, so both byte orders of that word are accepted here.
+    //
+    // FIXME: It is uncertain if drivers are able to load both byte orders. If
+    // needed we may need to do an optional swizzle internally so apps can
+    // continue to treat shader code as an opaque blob.
+    const Uint32 spirvMagic = 0x07230203;
+    const Uint32 spirvMagicSwapped = 0x03022307;
+    if (code == NULL || codeSize < 4) {
+        return false;
+    }
+    return !SDL_memcmp(code, &spirvMagic, 4) || !SDL_memcmp(code, &spirvMagicSwapped, 4);
+}
+
 static SDL_GPUComputePipeline *VULKAN_CreateComputePipeline(
     SDL_GPURenderer *driverData,
     const SDL_GPUComputePipelineCreateInfo *createinfo)
@@ -6526,6 +6545,10 @@ static SDL_GPUComputePipeline *VULKAN_CreateComputePipeline(
         SET_STRING_ERROR_AND_RETURN("Incompatible shader format for Vulkan!", NULL);
     }
 
+    if (!VULKAN_INTERNAL_IsValidShaderBytecode(createinfo->code, createinfo->code_size)) {
+        SET_STRING_ERROR_AND_RETURN("The provided shader code is not valid SPIR-V!", NULL);
+    }
+
     vulkanComputePipeline = SDL_malloc(sizeof(VulkanComputePipeline));
     shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
     shaderModuleCreateInfo.pNext = NULL;
@@ -6671,6 +6694,10 @@ static SDL_GPUShader *VULKAN_CreateShader(
     VkShaderModuleCreateInfo vkShaderModuleCreateInfo;
     VulkanRenderer *renderer = (VulkanRenderer *)driverData;
 
+    if (!VULKAN_INTERNAL_IsValidShaderBytecode(createinfo->code, createinfo->code_size)) {
+        SET_STRING_ERROR_AND_RETURN("The provided shader code is not valid SPIR-V!", NULL);
+    }
+
     vulkanShader = SDL_malloc(sizeof(VulkanShader));
     vkShaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
     vkShaderModuleCreateInfo.pNext = NULL;