SDL_gpu_shadercross: Improve type-safety of the public API

From b29a8e5f886d0c99355b43280580b83ea5d7b010 Mon Sep 17 00:00:00 2001
From: Beyley Thomas <[EMAIL REDACTED]>
Date: Mon, 7 Oct 2024 00:54:27 -0700
Subject: [PATCH] Improve type-safety of the public API

---
 SDL_gpu_shadercross.h | 98 +++++++++++++++++++++++++++++++++----------
 1 file changed, 76 insertions(+), 22 deletions(-)

diff --git a/SDL_gpu_shadercross.h b/SDL_gpu_shadercross.h
index e4fe585..8ec49f0 100644
--- a/SDL_gpu_shadercross.h
+++ b/SDL_gpu_shadercross.h
@@ -54,20 +54,28 @@ extern void SDL_ShaderCross_Quit(void);
 extern SDL_GPUShaderFormat SDL_ShaderCross_GetSPIRVShaderFormats(void);
 
 /**
- * Compile an SDL shader from SPIRV code.
+ * Compile an SDL GPU shader from SPIRV code.
  *
  * \param device the SDL GPU device.
- * \param createInfo a pointer to an SDL_GPUShaderCreateInfo or SDL_GPUComputePipelineCreateInfo structure
- *                   depending on whether the shader profile is a compute profile.
- * \param isCompute a flag for whether the shader is a compute shader or not.
- * \returns a compiled SDL_GPUShader or SDL_GPUComputePipeline depending
- *          on whether isCompute is set
+ * \param createInfo a pointer to an SDL_GPUShaderCreateInfo.
+ * \returns a compiled SDL_GPUShader, or NULL on failure; call SDL_GetError() for more information.
  *
  * \threadsafety It is safe to call this function from any thread.
  */
-extern void *SDL_ShaderCross_CompileFromSPIRV(SDL_GPUDevice *device,
-                                              const void *createInfo,
-                                              bool isCompute);
+extern SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(SDL_GPUDevice *device,
+                                              const SDL_GPUShaderCreateInfo *createInfo);
+
+/**
+ * Compile an SDL GPU compute pipeline from SPIRV code.
+ *
+ * \param device the SDL GPU device.
+ * \param createInfo a pointer to an SDL_GPUComputePipelineCreateInfo.
+ * \returns a compiled SDL_GPUComputePipeline, or NULL on failure; call SDL_GetError() for more information.
+ *
+ * \threadsafety It is safe to call this function from any thread.
+ */
+extern SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromSPIRV(SDL_GPUDevice *device,
+                                              const SDL_GPUComputePipelineCreateInfo *createInfo);
 #endif /* SDL_GPU_SHADERCROSS_SPIRVCROSS */
 
 #if SDL_GPU_SHADERCROSS_HLSL
@@ -79,20 +87,34 @@ extern void *SDL_ShaderCross_CompileFromSPIRV(SDL_GPUDevice *device,
 extern SDL_GPUShaderFormat SDL_ShaderCross_GetHLSLShaderFormats(void);
 
 /**
- * Compile an SDL shader from HLSL code.
+ * Compile an SDL GPU shader from HLSL code.
  *
  * \param device the SDL GPU device.
- * \param createInfo a pointer to an SDL_GPUShaderCreateInfo or SDL_GPUComputePipelineCreateInfo structure
- *                   depending on whether the shader profile is a compute profile.
+ * \param createInfo a pointer to an SDL_GPUShaderCreateInfo.
  * \param hlslSource the HLSL source code for the shader.
  * \param shaderProfile the shader profile to compile the shader with.
- * \returns a compiled SDL_GPUShader or SDL_GPUComputePipeline depending
- *          on whether the shader profile is a compute profile
+ * \returns a compiled SDL_GPUShader, or NULL on failure; call SDL_GetError() for more information.
  *
  * \threadsafety It is safe to call this function from any thread.
  */
-extern void *SDL_ShaderCross_CompileFromHLSL(SDL_GPUDevice *device,
-                                             const void *createInfo,
+extern SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(SDL_GPUDevice *device,
+                                             const SDL_GPUShaderCreateInfo *createInfo,
+                                             const char *hlslSource,
+                                             const char *shaderProfile);
+
+/**
+ * Compile an SDL GPU compute pipeline from HLSL code.
+ *
+ * \param device the SDL GPU device.
+ * \param createInfo a pointer to an SDL_GPUComputePipelineCreateInfo.
+ * \param hlslSource the HLSL source code for the shader.
+ * \param shaderProfile the compute shader profile to compile the shader with (e.g. "cs_5_0").
+ * \returns a compiled SDL_GPUComputePipeline, or NULL on failure; call SDL_GetError() for more information.
+ *
+ * \threadsafety It is safe to call this function from any thread.
+ */
+extern SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(SDL_GPUDevice *device,
+                                             const SDL_GPUComputePipelineCreateInfo *createInfo,
                                              const char *hlslSource,
                                              const char *shaderProfile);
 #endif /* SDL_GPU_SHADERCROSS_HLSL */
@@ -620,7 +642,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileFXC(
     return result;
 }
 
-extern void *SDL_ShaderCross_CompileFromHLSL(SDL_GPUDevice *device,
+static void *SDL_ShaderCross_INTERNAL_CompileFromHLSL(SDL_GPUDevice *device,
                                              const void *createInfo,
                                              const char *hlslSource,
                                              const char *shaderProfile)
@@ -636,10 +658,28 @@ extern void *SDL_ShaderCross_CompileFromHLSL(SDL_GPUDevice *device,
         return SDL_ShaderCross_INTERNAL_CompileDXC(device, createInfo, hlslSource, shaderProfile, DXC_CP_ACP, true);
     }
 
-    SDL_SetError("SDL_ShaderCross_CompileFromHLSL: Unexpected SDL_GPUShaderFormat");
+    SDL_SetError("SDL_ShaderCross_INTERNAL_CompileFromHLSL: Unexpected SDL_GPUShaderFormat");
     return NULL;
 }
 
+SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(SDL_GPUDevice *device, const SDL_GPUShaderCreateInfo *createInfo, const char *hlslSource, const char *shaderProfile)
+{
+    if (shaderProfile == NULL || SDL_strncmp(shaderProfile, "cs", 2) == 0) { /* a compute profile would make the internal compiler read createInfo as an SDL_GPUComputePipelineCreateInfo */
+        SDL_SetError("SDL_ShaderCross_CompileGraphicsShaderFromHLSL: shaderProfile must be a graphics (non-compute) profile");
+        return NULL;
+    }
+    return (SDL_GPUShader *)SDL_ShaderCross_INTERNAL_CompileFromHLSL(device, createInfo, hlslSource, shaderProfile);
+}
+
+SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(SDL_GPUDevice *device, const SDL_GPUComputePipelineCreateInfo *createInfo, const char *hlslSource, const char *shaderProfile)
+{
+    if (shaderProfile == NULL || SDL_strncmp(shaderProfile, "cs", 2) != 0) { /* a graphics profile would make the internal compiler read createInfo as an SDL_GPUShaderCreateInfo */
+        SDL_SetError("SDL_ShaderCross_CompileComputePipelineFromHLSL: shaderProfile must be a compute profile");
+        return NULL;
+    }
+    return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CompileFromHLSL(device, createInfo, hlslSource, shaderProfile);
+}
+
 #endif /* SDL_GPU_SHADERCROSS_HLSL */
 
 #if SDL_GPU_SHADERCROSS_SPIRVCROSS
@@ -707,7 +747,7 @@ static pfn_spvc_compiler_get_cleansed_entry_point_name SDL_spvc_compiler_get_cle
 #define SPVC_ERROR(func) \
     SDL_SetError(#func " failed: %s", SDL_spvc_context_get_last_error_string(context))
 
-void *SDL_ShaderCross_CompileFromSPIRV(
+static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
     SDL_GPUDevice *device,
     const void *originalCreateInfo,
     bool isCompute)
@@ -745,7 +785,7 @@ void *SDL_ShaderCross_CompileFromSPIRV(
         backend = SPVC_BACKEND_MSL;
         format = SDL_GPU_SHADERFORMAT_MSL;
     } else {
-        SDL_SetError("SDL_ShaderCross_CompileFromSPIRV: Unexpected SDL_GPUBackend");
+        SDL_SetError("SDL_ShaderCross_INTERNAL_CompileFromSPIRV: Unexpected SDL_GPUBackend");
         return NULL;
     }
 
@@ -823,7 +863,7 @@ void *SDL_ShaderCross_CompileFromSPIRV(
         newCreateInfo.entrypoint = cleansed_entrypoint;
 
         if (backend == SPVC_BACKEND_HLSL) {
-            compiledResult = SDL_ShaderCross_CompileFromHLSL(
+            compiledResult = SDL_ShaderCross_INTERNAL_CompileFromHLSL(
                 device,
                 &newCreateInfo,
                 translated_source,
@@ -846,7 +886,7 @@ void *SDL_ShaderCross_CompileFromSPIRV(
             } else {
                 profile = (shadermodel == 50) ? "ps_5_0" : "ps_6_0";
             }
-            compiledResult = SDL_ShaderCross_CompileFromHLSL(
+            compiledResult = SDL_ShaderCross_INTERNAL_CompileFromHLSL(
                 device,
                 &newCreateInfo,
                 translated_source,
@@ -864,6 +904,20 @@ void *SDL_ShaderCross_CompileFromSPIRV(
     return compiledResult;
 }
 
+SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(SDL_GPUDevice *device, const SDL_GPUShaderCreateInfo *createInfo)
+{
+    /* isCompute == false selects the SDL_GPUShader path of the internal compiler. */
+    void *shader = SDL_ShaderCross_INTERNAL_CompileFromSPIRV(device, createInfo, false);
+    return (SDL_GPUShader *)shader;
+}
+
+SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromSPIRV(SDL_GPUDevice *device, const SDL_GPUComputePipelineCreateInfo *createInfo)
+{
+    /* isCompute == true selects the SDL_GPUComputePipeline path of the internal compiler. */
+    void *pipeline = SDL_ShaderCross_INTERNAL_CompileFromSPIRV(device, createInfo, true);
+    return (SDL_GPUComputePipeline *)pipeline;
+}
+
 #endif /* SDL_GPU_SHADERCROSS_SPIRVCROSS */
 
 bool SDL_ShaderCross_Init(void)