SDL: gpu: Show a debug error when pipelines are not given the right shader stages

https://github.com/libsdl-org/SDL/commit/90aff306c1ce83055d7c0cc98fe4b9c4ac007c35

From 90aff306c1ce83055d7c0cc98fe4b9c4ac007c35 Mon Sep 17 00:00:00 2001
From: Ethan Lee <[EMAIL REDACTED]>
Date: Thu, 9 Jan 2025 14:53:17 -0500
Subject: [PATCH] gpu: Show a debug error when pipelines are not given the
 right shader stages

---
 src/gpu/d3d12/SDL_gpu_d3d12.c   | 11 +++++++++++
 src/gpu/metal/SDL_gpu_metal.m   | 10 ++++++++++
 src/gpu/vulkan/SDL_gpu_vulkan.c | 11 +++++++++++
 3 files changed, 32 insertions(+)

diff --git a/src/gpu/d3d12/SDL_gpu_d3d12.c b/src/gpu/d3d12/SDL_gpu_d3d12.c
index d073b88f196fd..c613e29bc7805 100644
--- a/src/gpu/d3d12/SDL_gpu_d3d12.c
+++ b/src/gpu/d3d12/SDL_gpu_d3d12.c
@@ -962,6 +962,7 @@ struct D3D12Shader
     void *bytecode;
     size_t bytecodeSize;
 
+    SDL_GPUShaderStage stage;
     Uint32 num_samplers;
     Uint32 numUniformBuffers;
     Uint32 numStorageBuffers;
@@ -2858,6 +2859,15 @@ static SDL_GPUGraphicsPipeline *D3D12_CreateGraphicsPipeline(
     D3D12Shader *vertShader = (D3D12Shader *)createinfo->vertex_shader;
     D3D12Shader *fragShader = (D3D12Shader *)createinfo->fragment_shader;
 
+    if (renderer->debug_mode) {
+        if (vertShader->stage != SDL_GPU_SHADERSTAGE_VERTEX) {
+            SDL_assert_release(!"CreateGraphicsPipeline was passed a fragment shader for the vertex stage");
+        }
+        if (fragShader->stage != SDL_GPU_SHADERSTAGE_FRAGMENT) {
+            SDL_assert_release(!"CreateGraphicsPipeline was passed a vertex shader for the fragment stage");
+        }
+    }
+
     D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc;
     SDL_zero(psoDesc);
     psoDesc.VS.pShaderBytecode = vertShader->bytecode;
@@ -3028,6 +3038,7 @@ static SDL_GPUShader *D3D12_CreateShader(
         SDL_free(bytecode);
         return NULL;
     }
+    shader->stage = createinfo->stage;
     shader->num_samplers = createinfo->num_samplers;
     shader->numStorageBuffers = createinfo->num_storage_buffers;
     shader->numStorageTextures = createinfo->num_storage_textures;
diff --git a/src/gpu/metal/SDL_gpu_metal.m b/src/gpu/metal/SDL_gpu_metal.m
index 415b38c5bde2c..5aa8ec3e386b1 100644
--- a/src/gpu/metal/SDL_gpu_metal.m
+++ b/src/gpu/metal/SDL_gpu_metal.m
@@ -471,6 +471,7 @@ static MTLDepthClipMode SDLToMetal_DepthClipMode(
     id<MTLLibrary> library;
     id<MTLFunction> function;
 
+    SDL_GPUShaderStage stage;
     Uint32 numSamplers;
     Uint32 numUniformBuffers;
     Uint32 numStorageBuffers;
@@ -1083,6 +1084,14 @@ static void METAL_ReleaseGraphicsPipeline(
         NSError *error = NULL;
         MetalGraphicsPipeline *result = NULL;
 
+        if (renderer->debugMode) {
+            if (vertexShader->stage != SDL_GPU_SHADERSTAGE_VERTEX) {
+                SDL_assert_release(!"CreateGraphicsPipeline was passed a fragment shader for the vertex stage");
+            }
+            if (fragmentShader->stage != SDL_GPU_SHADERSTAGE_FRAGMENT) {
+                SDL_assert_release(!"CreateGraphicsPipeline was passed a vertex shader for the fragment stage");
+            }
+        }
         pipelineDescriptor = [MTLRenderPipelineDescriptor new];
 
         // Blend
@@ -1380,6 +1389,7 @@ static void METAL_PopDebugGroup(
         result = SDL_calloc(1, sizeof(MetalShader));
         result->library = libraryFunction.library;
         result->function = libraryFunction.function;
+        result->stage = createinfo->stage;
         result->numSamplers = createinfo->num_samplers;
         result->numStorageBuffers = createinfo->num_storage_buffers;
         result->numStorageTextures = createinfo->num_storage_textures;
diff --git a/src/gpu/vulkan/SDL_gpu_vulkan.c b/src/gpu/vulkan/SDL_gpu_vulkan.c
index 98217ca4370f4..d04e89ec684f4 100644
--- a/src/gpu/vulkan/SDL_gpu_vulkan.c
+++ b/src/gpu/vulkan/SDL_gpu_vulkan.c
@@ -607,6 +607,7 @@ typedef struct VulkanShader
 {
     VkShaderModule shaderModule;
     const char *entrypointName;
+    SDL_GPUShaderStage stage;
     Uint32 numSamplers;
     Uint32 numStorageTextures;
     Uint32 numStorageBuffers;
@@ -6229,6 +6230,15 @@ static SDL_GPUGraphicsPipeline *VULKAN_CreateGraphicsPipeline(
     shaderStageCreateInfos[1].pName = graphicsPipeline->fragmentShader->entrypointName;
     shaderStageCreateInfos[1].pSpecializationInfo = NULL;
 
+    if (renderer->debugMode) {
+        if (graphicsPipeline->vertexShader->stage != SDL_GPU_SHADERSTAGE_VERTEX) {
+            SDL_assert_release(!"CreateGraphicsPipeline was passed a fragment shader for the vertex stage");
+        }
+        if (graphicsPipeline->fragmentShader->stage != SDL_GPU_SHADERSTAGE_FRAGMENT) {
+            SDL_assert_release(!"CreateGraphicsPipeline was passed a vertex shader for the fragment stage");
+        }
+    }
+
     // Vertex input
 
     for (i = 0; i < createinfo->vertex_input_state.num_vertex_buffers; i += 1) {
@@ -6635,6 +6645,7 @@ static SDL_GPUShader *VULKAN_CreateShader(
     vulkanShader->entrypointName = SDL_malloc(entryPointNameLength);
     SDL_utf8strlcpy((char *)vulkanShader->entrypointName, createinfo->entrypoint, entryPointNameLength);
 
+    vulkanShader->stage = createinfo->stage;
     vulkanShader->numSamplers = createinfo->num_samplers;
     vulkanShader->numStorageTextures = createinfo->num_storage_textures;
     vulkanShader->numStorageBuffers = createinfo->num_storage_buffers;