SDL_gpu_shadercross: MSL: Remap MSL resources based on descriptor set input

From 955d27e6115fb58b3c4f529fe12dde4fe2668f51 Mon Sep 17 00:00:00 2001
From: cosmonaut <[EMAIL REDACTED]>
Date: Sun, 27 Oct 2024 12:37:33 -0700
Subject: [PATCH] MSL: Remap MSL resources based on descriptor set input

---
 .../SDL_gpu_shadercross.h                     |   3 +-
 src/SDL_gpu_shadercross.c                     | 356 +++++++++++++++++-
 src/cli.c                                     |   6 +-
 3 files changed, 353 insertions(+), 12 deletions(-)

diff --git a/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h b/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h
index dd857e5..5d29bd5 100644
--- a/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h
+++ b/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h
@@ -88,7 +88,8 @@ extern SDL_DECLSPEC SDL_GPUShaderFormat SDLCALL SDL_ShaderCross_GetSPIRVShaderFo
 extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_TranspileMSLFromSPIRV(
     const Uint8 *bytecode,
     size_t bytecodeSize,
-    const char *entrypoint);
+    const char *entrypoint,
+    SDL_ShaderCross_ShaderStage shaderStage); // compute vs. graphics selects how descriptor sets are remapped to MSL indices
 
 /**
  * Transpile to HLSL code from SPIRV code.
diff --git a/src/SDL_gpu_shadercross.c b/src/SDL_gpu_shadercross.c
index 4d3e25e..275d432 100644
--- a/src/SDL_gpu_shadercross.c
+++ b/src/SDL_gpu_shadercross.c
@@ -712,8 +712,13 @@ typedef spvc_result (*pfn_spvc_context_parse_spirv)(spvc_context, const SpvId *,
 typedef spvc_result (*pfn_spvc_context_create_compiler)(spvc_context, spvc_backend, spvc_parsed_ir, spvc_capture_mode, spvc_compiler *);
 typedef spvc_result (*pfn_spvc_compiler_create_compiler_options)(spvc_compiler, spvc_compiler_options *);
 typedef spvc_result (*pfn_spvc_compiler_options_set_uint)(spvc_compiler_options, spvc_compiler_option, unsigned);
+typedef spvc_result (*pfn_spvc_compiler_create_shader_resources)(spvc_compiler, spvc_resources *);
+typedef spvc_result (*pfn_spvc_compiler_msl_add_resource_binding)(spvc_compiler, const spvc_msl_resource_binding *);
+typedef spvc_bool (*pfn_spvc_compiler_has_decoration)(spvc_compiler, SpvId, SpvDecoration); // must match spirv_cross_c.h exactly
+typedef unsigned (*pfn_spvc_compiler_get_decoration)(spvc_compiler, SpvId, SpvDecoration); // returns the decoration value, not a result code
 typedef spvc_result (*pfn_spvc_compiler_install_compiler_options)(spvc_compiler, spvc_compiler_options);
 typedef spvc_result (*pfn_spvc_compiler_compile)(spvc_compiler, const char **);
+typedef spvc_result (*pfn_spvc_resources_get_resource_list_for_type)(spvc_resources, spvc_resource_type, const spvc_reflected_resource **, size_t *);
 typedef const char *(*pfn_spvc_context_get_last_error_string)(spvc_context);
 typedef SpvExecutionModel (*pfn_spvc_compiler_get_execution_model)(spvc_compiler compiler);
 typedef const char *(*pfn_spvc_compiler_get_cleansed_entry_point_name)(spvc_compiler compiler, const char *name, SpvExecutionModel model);
@@ -724,8 +729,13 @@ static pfn_spvc_context_parse_spirv SDL_spvc_context_parse_spirv = NULL;
 static pfn_spvc_context_create_compiler SDL_spvc_context_create_compiler = NULL;
 static pfn_spvc_compiler_create_compiler_options SDL_spvc_compiler_create_compiler_options = NULL;
 static pfn_spvc_compiler_options_set_uint SDL_spvc_compiler_options_set_uint = NULL;
+static pfn_spvc_compiler_create_shader_resources SDL_spvc_compiler_create_shader_resources = NULL;
+static pfn_spvc_compiler_msl_add_resource_binding SDL_spvc_compiler_msl_add_resource_binding = NULL;
+static pfn_spvc_compiler_has_decoration SDL_spvc_compiler_has_decoration = NULL;
+static pfn_spvc_compiler_get_decoration SDL_spvc_compiler_get_decoration = NULL;
 static pfn_spvc_compiler_install_compiler_options SDL_spvc_compiler_install_compiler_options = NULL;
 static pfn_spvc_compiler_compile SDL_spvc_compiler_compile = NULL;
+static pfn_spvc_resources_get_resource_list_for_type SDL_spvc_resources_get_resource_list_for_type = NULL;
 static pfn_spvc_context_get_last_error_string SDL_spvc_context_get_last_error_string = NULL;
 static pfn_spvc_compiler_get_execution_model SDL_spvc_compiler_get_execution_model = NULL;
 static pfn_spvc_compiler_get_cleansed_entry_point_name SDL_spvc_compiler_get_cleansed_entry_point_name = NULL;
@@ -738,8 +748,13 @@ static pfn_spvc_compiler_get_cleansed_entry_point_name SDL_spvc_compiler_get_cle
 #define SDL_spvc_context_create_compiler                spvc_context_create_compiler
 #define SDL_spvc_compiler_create_compiler_options       spvc_compiler_create_compiler_options
 #define SDL_spvc_compiler_options_set_uint              spvc_compiler_options_set_uint
+#define SDL_spvc_compiler_create_shader_resources       spvc_compiler_create_shader_resources
+#define SDL_spvc_compiler_msl_add_resource_binding      spvc_compiler_msl_add_resource_binding
+#define SDL_spvc_compiler_has_decoration                spvc_compiler_has_decoration
+#define SDL_spvc_compiler_get_decoration                spvc_compiler_get_decoration
 #define SDL_spvc_compiler_install_compiler_options      spvc_compiler_install_compiler_options
 #define SDL_spvc_compiler_compile                       spvc_compiler_compile
+#define SDL_spvc_resources_get_resource_list_for_type   spvc_resources_get_resource_list_for_type
 #define SDL_spvc_context_get_last_error_string          spvc_context_get_last_error_string
 #define SDL_spvc_compiler_get_execution_model           spvc_compiler_get_execution_model
 #define SDL_spvc_compiler_get_cleansed_entry_point_name spvc_compiler_get_cleansed_entry_point_name
@@ -764,7 +779,8 @@ static void SDL_ShaderCross_INTERNAL_DestroyTranspileContext(
 
 static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
     spvc_backend backend,
-    unsigned shadermodel,
+    unsigned shadermodel, // only used for HLSL
+    SDL_ShaderCross_ShaderStage shaderStage, // only used for MSL: picks the graphics or compute resource remapping
     const Uint8 *code,
     size_t codeSize,
     const char *entrypoint
@@ -815,6 +831,319 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         SDL_spvc_compiler_options_set_uint(options, SPVC_COMPILER_OPTION_HLSL_FLATTEN_MATRIX_VERTEX_INPUT_SEMANTICS, 1);
     }
 
+    if (backend == SPVC_BACKEND_MSL && shaderStage != SDL_SHADERCROSS_SHADERSTAGE_COMPUTE) { // graphics: sets 0/1 = vertex resources/uniforms, sets 2/3 = fragment
+        spvc_resources resources;
+        const spvc_reflected_resource *reflected_resources;
+        size_t num_texture_samplers;
+        size_t num_storage_textures;
+        size_t num_storage_buffers;
+        size_t num_uniform_buffers;
+
+        result = SDL_spvc_compiler_create_shader_resources(compiler, &resources);
+        if (result < 0) {
+            SPVC_ERROR(spvc_compiler_create_shader_resources);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Combined texture-samplers
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_SAMPLED_IMAGE,
+            &reflected_resources,
+            &num_texture_samplers);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Storage textures
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_STORAGE_IMAGE,
+            &reflected_resources,
+            &num_storage_textures);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Storage buffers
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_STORAGE_BUFFER,
+            &reflected_resources,
+            &num_storage_buffers);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Uniform buffers
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_UNIFORM_BUFFER,
+            &reflected_resources,
+            &num_uniform_buffers);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Create mappings. MSL slots: [[texture]] = sampled then storage textures; [[buffer]] = uniform then storage buffers
+        const SpvExecutionModel execution_model = SDL_spvc_compiler_get_execution_model(compiler);
+        spvc_msl_resource_binding binding = { 0 }; // TODO: validate descriptor sets and binding order?
+        for (size_t i = 0; i < num_texture_samplers; i += 1) {
+            binding.stage = execution_model; // must match the shader's execution model or the remap is never applied
+            binding.desc_set = shaderStage == SDL_SHADERCROSS_SHADERSTAGE_VERTEX ? 0 : 2;
+            binding.binding = (unsigned)i;
+            binding.msl_texture = (unsigned)i;
+            binding.msl_sampler = (unsigned)i;
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_storage_textures; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = shaderStage == SDL_SHADERCROSS_SHADERSTAGE_VERTEX ? 0 : 2;
+            binding.binding = (unsigned)(num_texture_samplers + i);
+            binding.msl_texture = (unsigned)(num_texture_samplers + i);
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_storage_buffers; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = shaderStage == SDL_SHADERCROSS_SHADERSTAGE_VERTEX ? 0 : 2;
+            binding.binding = (unsigned)(num_texture_samplers + num_storage_textures + i);
+            binding.msl_buffer = (unsigned)(num_uniform_buffers + i); // storage buffers follow the uniform buffers
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_uniform_buffers; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = shaderStage == SDL_SHADERCROSS_SHADERSTAGE_VERTEX ? 1 : 3;
+            binding.binding = (unsigned)i;
+            binding.msl_buffer = (unsigned)i; // uniform buffers occupy the first [[buffer]] slots
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+    }
+
+    if (backend == SPVC_BACKEND_MSL && shaderStage == SDL_SHADERCROSS_SHADERSTAGE_COMPUTE) { // compute: set 0 = read-only, set 1 = read-write, set 2 = uniforms
+        spvc_resources resources;
+        const spvc_reflected_resource *reflected_resources;
+        size_t num_texture_samplers;
+        size_t num_storage_textures; // total storage textures
+        size_t num_storage_buffers; // total storage buffers
+        size_t num_uniform_buffers;
+
+        size_t num_readonly_storage_textures = 0;
+        size_t num_readonly_storage_buffers = 0;
+        size_t num_readwrite_storage_textures = 0;
+        size_t num_readwrite_storage_buffers = 0;
+
+        result = SDL_spvc_compiler_create_shader_resources(compiler, &resources);
+        if (result < 0) {
+            SPVC_ERROR(spvc_compiler_create_shader_resources);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Combined texture-samplers
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_SAMPLED_IMAGE,
+            &reflected_resources,
+            &num_texture_samplers);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Storage textures
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_STORAGE_IMAGE,
+            &reflected_resources,
+            &num_storage_textures);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Determine readonly (set 0) vs readwrite (set 1) storage textures
+        for (size_t i = 0; i < num_storage_textures; i += 1) {
+            if (!SDL_spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet) || !SDL_spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding)) {
+                SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Shader resources must have descriptor set and binding index!");
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+
+            unsigned int descriptor_set_index = SDL_spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet);
+
+            if (descriptor_set_index == 0) {
+                num_readonly_storage_textures += 1;
+            } else if (descriptor_set_index == 1) {
+                num_readwrite_storage_textures += 1;
+            } else {
+                SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Storage texture descriptor set index for compute shader must be 0 or 1!");
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        // Storage buffers
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_STORAGE_BUFFER,
+            &reflected_resources,
+            &num_storage_buffers);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Determine readonly (set 0) vs readwrite (set 1) storage buffers
+        for (size_t i = 0; i < num_storage_buffers; i += 1) {
+            if (!SDL_spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet) || !SDL_spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding)) {
+                SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Shader resources must have descriptor set and binding index!");
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+
+            unsigned int descriptor_set_index = SDL_spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet);
+
+            if (descriptor_set_index == 0) {
+                num_readonly_storage_buffers += 1;
+            } else if (descriptor_set_index == 1) {
+                num_readwrite_storage_buffers += 1;
+            } else {
+                SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Storage buffer descriptor set index for compute shader must be 0 or 1!");
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        // Uniform buffers
+        result = SDL_spvc_resources_get_resource_list_for_type(
+            resources,
+            SPVC_RESOURCE_TYPE_UNIFORM_BUFFER,
+            &reflected_resources,
+            &num_uniform_buffers);
+        if (result < 0) {
+            SPVC_ERROR(spvc_resources_get_resource_list_for_type);
+            SDL_spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Create mappings. MSL slots: [[texture]] = sampled, read-only, read-write; [[buffer]] = uniform, read-only, read-write
+        const SpvExecutionModel execution_model = SDL_spvc_compiler_get_execution_model(compiler);
+        spvc_msl_resource_binding binding = { 0 }; // TODO: validate descriptor sets and binding order?
+        for (size_t i = 0; i < num_texture_samplers; i += 1) {
+            binding.stage = execution_model; // must match the shader's execution model or the remap is never applied
+            binding.desc_set = 0;
+            binding.binding = (unsigned)i;
+            binding.msl_texture = (unsigned)i;
+            binding.msl_sampler = (unsigned)i;
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_readonly_storage_textures; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = 0;
+            binding.binding = (unsigned)(num_texture_samplers + i);
+            binding.msl_texture = (unsigned)(num_texture_samplers + i);
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_readonly_storage_buffers; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = 0;
+            binding.binding = (unsigned)(num_texture_samplers + num_readonly_storage_textures + i);
+            binding.msl_buffer = (unsigned)(num_uniform_buffers + i); // read-only storage buffers follow the uniform buffers
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_readwrite_storage_textures; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = 1;
+            binding.binding = (unsigned)i;
+            binding.msl_texture = (unsigned)(num_texture_samplers + num_readonly_storage_textures + i);
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_readwrite_storage_buffers; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = 1;
+            binding.binding = (unsigned)(num_readwrite_storage_textures + i);
+            binding.msl_buffer = (unsigned)(num_uniform_buffers + num_readonly_storage_buffers + i);
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+
+        for (size_t i = 0; i < num_uniform_buffers; i += 1) {
+            binding.stage = execution_model;
+            binding.desc_set = 2;
+            binding.binding = (unsigned)i;
+            binding.msl_buffer = (unsigned)i; // uniform buffers occupy the first [[buffer]] slots
+            result = SDL_spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SDL_spvc_context_destroy(context);
+                return NULL;
+            }
+        }
+    }
+
     result = SDL_spvc_compiler_install_compiler_options(compiler, options);
     if (result < 0) {
         SPVC_ERROR(spvc_compiler_install_compiler_options);
@@ -847,7 +1176,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
     SDL_GPUDevice *device,
     SDL_GPUShaderFormat shaderFormat,
     const void *originalCreateInfo,
-    bool isCompute
+    SDL_ShaderCross_ShaderStage shaderStage // COMPUTE => originalCreateInfo is an SDL_GPUComputePipelineCreateInfo
 ) {
     const SDL_GPUShaderCreateInfo *createInfo;
     spvc_backend backend;
@@ -875,6 +1204,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
     SPIRVTranspileContext *transpileContext = SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         backend,
         shadermodel,
+        shaderStage, // forwarded so MSL output can remap descriptor sets per stage
         createInfo->code,
         createInfo->code_size,
         createInfo->entrypoint);
@@ -882,7 +1212,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
     void *shaderObject = NULL;
 
     /* Copy the original create info, but with the new source code */
-    if (isCompute) {
+    if (shaderStage == SDL_SHADERCROSS_SHADERSTAGE_COMPUTE) { // replaces the former isCompute flag
         SDL_GPUComputePipelineCreateInfo newCreateInfo;
         newCreateInfo = *(const SDL_GPUComputePipelineCreateInfo *)createInfo;
 
@@ -948,11 +1278,13 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
 void *SDL_ShaderCross_TranspileMSLFromSPIRV(
     const Uint8 *bytecode,
     size_t bytecodeSize,
-    const char *entrypoint)
+    const char *entrypoint,
+    SDL_ShaderCross_ShaderStage shaderStage)
 {
     SPIRVTranspileContext *context = SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         SPVC_BACKEND_MSL,
         0,
+        shaderStage, // drives the MSL descriptor-set remapping
         bytecode,
         bytecodeSize,
         entrypoint
@@ -986,6 +1318,7 @@ void *SDL_ShaderCross_TranspileHLSLFromSPIRV(
     SPIRVTranspileContext *context = SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         SPVC_BACKEND_HLSL,
         shadermodel,
+        SDL_SHADERCROSS_SHADERSTAGE_VERTEX, // ignored: only the MSL backend reads the stage
         bytecode,
         bytecodeSize,
         entrypoint
@@ -1009,6 +1342,7 @@ void *SDL_ShaderCross_CompileDXBCFromSPIRV(
     SPIRVTranspileContext *context = SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         SPVC_BACKEND_HLSL,
         50,
+        SDL_SHADERCROSS_SHADERSTAGE_VERTEX, // ignored: only the MSL backend reads the stage
         bytecode,
         bytecodeSize,
         entrypoint);
@@ -1042,6 +1376,7 @@ void *SDL_ShaderCross_CompileDXILFromSPIRV(
     SPIRVTranspileContext *context = SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         SPVC_BACKEND_HLSL,
         60,
+        SDL_SHADERCROSS_SHADERSTAGE_VERTEX, // ignored: only the MSL backend reads the stage
         bytecode,
         bytecodeSize,
         entrypoint);
@@ -1068,14 +1403,14 @@ void *SDL_ShaderCross_CompileDXILFromSPIRV(
 static void *SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
     SDL_GPUDevice *device,
     const void *originalCreateInfo,
-    bool isCompute)
+    SDL_ShaderCross_ShaderStage shaderStage) // COMPUTE => originalCreateInfo is an SDL_GPUComputePipelineCreateInfo
 {
     SDL_GPUShaderFormat format;
 
     SDL_GPUShaderFormat shader_formats = SDL_GetGPUShaderFormats(device);
 
     if (shader_formats & SDL_GPU_SHADERFORMAT_SPIRV) {
-        if (isCompute) {
+        if (shaderStage == SDL_SHADERCROSS_SHADERSTAGE_COMPUTE) { // SPIR-V devices get the original create info unchanged
             return SDL_CreateGPUComputePipeline(device, (const SDL_GPUComputePipelineCreateInfo *)originalCreateInfo);
         } else {
             return SDL_CreateGPUShader(device, (const SDL_GPUShaderCreateInfo *)originalCreateInfo);
@@ -1095,21 +1430,21 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(
         device,
         format,
         originalCreateInfo,
-        isCompute);
+        shaderStage); // forwarded so the transpile step can pick the right resource remapping
 }
 
 SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromSPIRV(
     SDL_GPUDevice *device,
     const SDL_GPUShaderCreateInfo *createInfo)
 {
-    return (SDL_GPUShader *)SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(device, createInfo, false);
+    return (SDL_GPUShader *)SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(device, createInfo, createInfo->stage == SDL_GPU_SHADERSTAGE_FRAGMENT ? SDL_SHADERCROSS_SHADERSTAGE_FRAGMENT : SDL_SHADERCROSS_SHADERSTAGE_VERTEX); // map explicitly rather than casting between unrelated enum types
 }
 
 SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromSPIRV(
     SDL_GPUDevice *device,
     const SDL_GPUComputePipelineCreateInfo *createInfo)
 {
-    return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(device, createInfo, true);
+    return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CreateShaderFromSPIRV(device, createInfo, SDL_SHADERCROSS_SHADERSTAGE_COMPUTE); // compute pipelines always take the compute path
 }
 
 #endif /* SDL_GPU_SHADERCROSS_SPIRVCROSS */
@@ -1179,6 +1514,11 @@ bool SDL_ShaderCross_Init(void)
         CHECK_FUNC(spvc_context_create_compiler)
         CHECK_FUNC(spvc_compiler_create_compiler_options)
         CHECK_FUNC(spvc_compiler_options_set_uint)
+        CHECK_FUNC(spvc_compiler_create_shader_resources)
+        CHECK_FUNC(spvc_resources_get_resource_list_for_type)
+        CHECK_FUNC(spvc_compiler_msl_add_resource_binding)
+        CHECK_FUNC(spvc_compiler_has_decoration)
+        CHECK_FUNC(spvc_compiler_get_decoration)
         CHECK_FUNC(spvc_compiler_install_compiler_options)
         CHECK_FUNC(spvc_compiler_compile)
         CHECK_FUNC(spvc_context_get_last_error_string)
diff --git a/src/cli.c b/src/cli.c
index 396acfc..51210f2 100644
--- a/src/cli.c
+++ b/src/cli.c
@@ -283,7 +283,8 @@ int main(int argc, char *argv[])
                 char *buffer = SDL_ShaderCross_TranspileMSLFromSPIRV(
                     fileData,
                     fileSize,
-                    entrypointName);
+                    entrypointName,
+                    shaderStage); // selects graphics vs. compute MSL resource remapping
                 SDL_IOprintf(outputIO, "%s", buffer);
                 SDL_free(buffer);
                 break;
@@ -393,7 +394,8 @@ int main(int argc, char *argv[])
                 char *buffer = SDL_ShaderCross_TranspileMSLFromSPIRV(
                     spirv,
                     bytecodeSize,
-                    entrypointName);
+                    entrypointName,
+                    shaderStage); // selects graphics vs. compute MSL resource remapping
                 SDL_IOprintf(outputIO, "%s", buffer);
                 SDL_free(spirv);
                 SDL_free(buffer);