SDL_shadercross: Resource reflection out params

From 6177929166ef235210235443a99fc71d10b27e74 Mon Sep 17 00:00:00 2001
From: Evan Hemsley <[EMAIL REDACTED]>
Date: Tue, 19 Nov 2024 19:58:59 -0800
Subject: [PATCH] Resource reflection out params

---
 include/SDL3_shadercross/SDL_shadercross.h | 37 ++++++++++++++--
 src/SDL_shadercross.c                      | 51 +++++++++++++++++-----
 2 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/include/SDL3_shadercross/SDL_shadercross.h b/include/SDL3_shadercross/SDL_shadercross.h
index 9ca3804..ccd16f0 100644
--- a/include/SDL3_shadercross/SDL_shadercross.h
+++ b/include/SDL3_shadercross/SDL_shadercross.h
@@ -45,6 +45,39 @@ typedef enum SDL_ShaderCross_ShaderStage
    SDL_SHADERCROSS_SHADERSTAGE_COMPUTE
 } SDL_ShaderCross_ShaderStage;
 
+/**
+ * Resource counts reflected from a graphics shader.
+ *
+ * Filled in by SDL_ShaderCross_CompileGraphicsShaderFromSPIRV() and
+ * SDL_ShaderCross_CompileGraphicsShaderFromHLSL().
+ */
+typedef struct SDL_ShaderCross_GraphicsShaderInfo
+{
+    Uint32 numSamplers;         /**< The number of samplers defined in the shader. */
+    Uint32 numStorageTextures;  /**< The number of storage textures defined in the shader. */
+    Uint32 numStorageBuffers;   /**< The number of storage buffers defined in the shader. */
+    Uint32 numUniformBuffers;   /**< The number of uniform buffers defined in the shader. */
+} SDL_ShaderCross_GraphicsShaderInfo;
+
+/**
+ * Resource counts and thread counts reflected from a compute shader.
+ *
+ * Filled in by SDL_ShaderCross_CompileComputePipelineFromSPIRV() and
+ * SDL_ShaderCross_CompileComputePipelineFromHLSL().
+ */
+typedef struct SDL_ShaderCross_ComputePipelineInfo
+{
+    Uint32 numSamplers;                  /**< The number of samplers defined in the shader. */
+    Uint32 numReadOnlyStorageTextures;   /**< The number of readonly storage textures defined in the shader. */
+    Uint32 numReadOnlyStorageBuffers;    /**< The number of readonly storage buffers defined in the shader. */
+    Uint32 numReadWriteStorageTextures;  /**< The number of read-write storage textures defined in the shader. */
+    Uint32 numReadWriteStorageBuffers;   /**< The number of read-write storage buffers defined in the shader. */
+    Uint32 numUniformBuffers;            /**< The number of uniform buffers defined in the shader. */
+    Uint32 threadCountX;                 /**< The number of threads in the X dimension. */
+    Uint32 threadCountY;                 /**< The number of threads in the Y dimension. */
+    Uint32 threadCountZ;                 /**< The number of threads in the Z dimension. */
+} SDL_ShaderCross_ComputePipelineInfo;
+
 /**
  * Initializes SDL_gpu_shadercross
  *
@@ -144,6 +177,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXILFromSPIRV(
  * \param bytecodeSize the length of the SPIRV bytecode.
  * \param entrypoint the entry point function name for the shader in UTF-8.
  * \param shaderStage the shader stage to compile the shader with.
+ * \param info a pointer filled in with the shader's reflected resource counts.
  * \returns a compiled SDL_GPUShader
  *
  * \threadsafety It is safe to call this function from any thread.
@@ -153,7 +187,8 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
     const Uint8 *bytecode,
     size_t bytecodeSize,
     const char *entrypoint,
-    SDL_GPUShaderStage shaderStage);
+    SDL_GPUShaderStage shaderStage,
+    SDL_ShaderCross_GraphicsShaderInfo *info);
 
 /**
  * Compile an SDL GPU compute pipeline from SPIRV code.
@@ -162,6 +197,7 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
  * \param bytecode the SPIRV bytecode.
  * \param bytecodeSize the length of the SPIRV bytecode.
  * \param entrypoint the entry point function name for the shader in UTF-8.
+ * \param info a pointer filled in with the pipeline's reflected resource and thread counts.
  * \returns a compiled SDL_GPUComputePipeline
  *
  * \threadsafety It is safe to call this function from any thread.
@@ -170,7 +206,8 @@ extern SDL_DECLSPEC SDL_GPUComputePipeline * SDLCALL SDL_ShaderCross_CompileComp
     SDL_GPUDevice *device,
     const Uint8 *bytecode,
     size_t bytecodeSize,
-    const char *entrypoint);
+    const char *entrypoint,
+    SDL_ShaderCross_ComputePipelineInfo *info);
 
 /**
  * Get the supported shader formats that HLSL cross-compilation can output
@@ -264,6 +301,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileSPIRVFromHLSL(
  * \param defines an array of define strings. Optional, can be NULL.
  * \param numDefines the number of strings in the defines array.
  * \param graphicsShaderStage the shader stage to compile the shader with.
+ * \param info a pointer filled in with the shader's reflected resource counts.
  * \returns a compiled SDL_GPUShader
  *
  * \threadsafety It is safe to call this function from any thread.
@@ -275,7 +313,8 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
     const char *includeDir,
     const char **defines,
     Uint32 numDefines,
-    SDL_GPUShaderStage graphicsShaderStage);
+    SDL_GPUShaderStage graphicsShaderStage,
+    SDL_ShaderCross_GraphicsShaderInfo *info);
 
 /**
  * Compile an SDL GPU compute pipeline from code.
@@ -286,6 +325,7 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
  * \param includeDir the include directory for shader code. Optional, can be NULL.
  * \param defines an array of define strings. Optional, can be NULL.
  * \param numDefines the number of strings in the defines array.
+ * \param info a pointer filled in with the pipeline's reflected resource and thread counts.
  * \returns a compiled SDL_GPUComputePipeline
  *
  * \threadsafety It is safe to call this function from any thread.
@@ -296,7 +336,8 @@ extern SDL_DECLSPEC SDL_GPUComputePipeline * SDLCALL SDL_ShaderCross_CompileComp
     const char *entrypoint,
     const char *includeDir,
     const char **defines,
-    Uint32 numDefines);
+    Uint32 numDefines,
+    SDL_ShaderCross_ComputePipelineInfo *info);
 
 #ifdef __cplusplus
 }
diff --git a/src/SDL_shadercross.c b/src/SDL_shadercross.c
index 087db48..76b0adf 100644
--- a/src/SDL_shadercross.c
+++ b/src/SDL_shadercross.c
@@ -819,7 +819,8 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
     const char *includeDir,
     const char **defines,
     Uint32 numDefines,
-    SDL_ShaderCross_ShaderStage shaderStage)
+    SDL_ShaderCross_ShaderStage shaderStage,
+    void *info) /* GraphicsShaderInfo * or ComputePipelineInfo *, selected by shaderStage */
 {
     size_t bytecodeSize;
 
@@ -844,14 +845,16 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
             device,
             spirv,
             bytecodeSize,
-            entrypoint);
+            entrypoint,
+            info);
     } else {
         result = SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(
             device,
             spirv,
             bytecodeSize,
             entrypoint,
-            (SDL_GPUShaderStage)shaderStage);
+            (SDL_GPUShaderStage)shaderStage,
+            info);
     }
     SDL_free(spirv);
     return result;
@@ -864,7 +867,8 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(
     const char *includeDir,
     const char **defines,
     Uint32 numDefines,
-    SDL_GPUShaderStage graphicsShaderStage)
+    SDL_GPUShaderStage graphicsShaderStage,
+    SDL_ShaderCross_GraphicsShaderInfo *info)
 {
     return (SDL_GPUShader *)SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
         device,
@@ -873,7 +877,8 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(
         includeDir,
         defines,
         numDefines,
-        (SDL_ShaderCross_ShaderStage)graphicsShaderStage);
+        (SDL_ShaderCross_ShaderStage)graphicsShaderStage,
+        info);
 }
 
 SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(
@@ -882,7 +887,8 @@ SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(
     const char *entrypoint,
     const char *includeDir,
     const char **defines,
-    Uint32 numDefines)
+    Uint32 numDefines,
+    SDL_ShaderCross_ComputePipelineInfo *info)
 {
     return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
         device,
@@ -891,7 +897,8 @@ SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(
         includeDir,
         defines,
         numDefines,
-        SDL_SHADERCROSS_SHADERSTAGE_COMPUTE);
+        SDL_SHADERCROSS_SHADERSTAGE_COMPUTE,
+        (void *)info); /* helper is shared with the graphics path; the SPIRV stage casts info back by stage */
 }
 
 #include <spirv_cross_c.h>
@@ -1992,7 +1999,8 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
     const Uint8 *bytecode,
     size_t bytecodeSize,
     const char *entrypoint,
-    SDL_ShaderCross_ShaderStage shaderStage)
+    SDL_ShaderCross_ShaderStage shaderStage,
+    void *info)
 {
     SDL_GPUShaderFormat format;
 
@@ -2001,6 +2009,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
     if (shader_formats & SDL_GPU_SHADERFORMAT_SPIRV) {
         if (shaderStage == SDL_SHADERCROSS_SHADERSTAGE_COMPUTE) {
             SDL_GPUComputePipelineCreateInfo createInfo;
+            SDL_ShaderCross_ComputePipelineInfo *pipelineInfo = (SDL_ShaderCross_ComputePipelineInfo *)info;
             SDL_ShaderCross_INTERNAL_ReflectComputeSPIRV(
                 bytecode,
                 bytecodeSize,
@@ -2010,9 +2019,22 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
             createInfo.entrypoint = entrypoint;
             createInfo.format = SDL_GPU_SHADERFORMAT_SPIRV;
             createInfo.props = 0;
+            /* info is an optional out-param: skip the write-back when the caller passes NULL. */
+            if (pipelineInfo != NULL) {
+                pipelineInfo->numSamplers = createInfo.num_samplers;
+                pipelineInfo->numReadOnlyStorageTextures = createInfo.num_readonly_storage_textures;
+                pipelineInfo->numReadOnlyStorageBuffers = createInfo.num_readonly_storage_buffers;
+                pipelineInfo->numReadWriteStorageTextures = createInfo.num_readwrite_storage_textures;
+                pipelineInfo->numReadWriteStorageBuffers = createInfo.num_readwrite_storage_buffers;
+                pipelineInfo->numUniformBuffers = createInfo.num_uniform_buffers;
+                pipelineInfo->threadCountX = createInfo.threadcount_x;
+                pipelineInfo->threadCountY = createInfo.threadcount_y;
+                pipelineInfo->threadCountZ = createInfo.threadcount_z;
+            }
             return SDL_CreateGPUComputePipeline(device, &createInfo);
         } else {
             SDL_GPUShaderCreateInfo createInfo;
+            SDL_ShaderCross_GraphicsShaderInfo *shaderInfo = (SDL_ShaderCross_GraphicsShaderInfo *)info;
             SDL_ShaderCross_INTERNAL_ReflectGraphicsSPIRV(
                 bytecode,
                 bytecodeSize,
@@ -2023,6 +2045,13 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
             createInfo.format = SDL_GPU_SHADERFORMAT_SPIRV;
             createInfo.stage = (SDL_GPUShaderStage)shaderStage;
             createInfo.props = 0;
+            /* info is an optional out-param: skip the write-back when the caller passes NULL. */
+            if (shaderInfo != NULL) {
+                shaderInfo->numSamplers = createInfo.num_samplers;
+                shaderInfo->numStorageTextures = createInfo.num_storage_textures;
+                shaderInfo->numStorageBuffers = createInfo.num_storage_buffers;
+                shaderInfo->numUniformBuffers = createInfo.num_uniform_buffers;
+            }
             return SDL_CreateGPUShader(device, &createInfo);
         }
     } else if (shader_formats & SDL_GPU_SHADERFORMAT_MSL) {
@@ -2056,28 +2085,32 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(
     const Uint8 *bytecode,
     size_t bytecodeSize,
     const char *entrypoint,
-    SDL_GPUShaderStage shaderStage)
+    SDL_GPUShaderStage shaderStage,
+    SDL_ShaderCross_GraphicsShaderInfo *info)
 {
     return (SDL_GPUShader *)SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
         device,
         bytecode,
         bytecodeSize,
         entrypoint,
-        (SDL_ShaderCross_ShaderStage)shaderStage);
+        (SDL_ShaderCross_ShaderStage)shaderStage,
+        info);
 }
 
 SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromSPIRV(
     SDL_GPUDevice *device,
     const Uint8 *bytecode,
     size_t bytecodeSize,
-    const char *entrypoint)
+    const char *entrypoint,
+    SDL_ShaderCross_ComputePipelineInfo *info)
 {
     return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
         device,
         bytecode,
         bytecodeSize,
         entrypoint,
-        SDL_SHADERCROSS_SHADERSTAGE_COMPUTE);
+        SDL_SHADERCROSS_SHADERSTAGE_COMPUTE,
+        info);
 }
 
 bool SDL_ShaderCross_Init(void)