SDL_shadercross: Refactor MSL resource index calculation (#92)

From d3e304a9279c921bbba5b36c1f8b105825ca6333 Mon Sep 17 00:00:00 2001
From: Evan Hemsley <[EMAIL REDACTED]>
Date: Mon, 20 Jan 2025 08:43:33 -0800
Subject: [PATCH] Refactor MSL resource index calculation (#92)

---
 src/SDL_shadercross.c | 351 ++++++++++++++++++++++++------------------
 1 file changed, 200 insertions(+), 151 deletions(-)

diff --git a/src/SDL_shadercross.c b/src/SDL_shadercross.c
index 3197880..e2bd5ae 100644
--- a/src/SDL_shadercross.c
+++ b/src/SDL_shadercross.c
@@ -967,9 +967,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         size_t num_uniform_buffers;
         size_t num_separate_samplers = 0;
         size_t num_separate_images = 0;
-        spvc_msl_resource_binding binding;
-        unsigned int num_textures = 0;
-        unsigned int num_buffers = 0;
+
+        spvc_msl_resource_binding_2 bufferBindings[32];
+        SDL_zeroa(bufferBindings);
+        Uint32 numBufferBindings = 0;
+
+        spvc_msl_resource_binding_2 textureBindings[32];
+        SDL_zeroa(textureBindings);
+        Uint32 numTextureBindings = 0;
 
         result = spvc_compiler_create_shader_resources(compiler, &resources);
         if (result < 0) {
@@ -1021,20 +1026,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = binding_index;
-            binding.msl_sampler = binding_index;
-            result = spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
-            num_textures += 1;
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
-        num_textures += num_texture_samplers;
 
         // Storage textures
         result = spvc_resources_get_resource_list_for_type(
@@ -1064,18 +1063,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = num_textures + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
-        num_textures += num_storage_textures;
 
         // If source is HLSL, storage images might be marked as separate images
         result = spvc_resources_get_resource_list_for_type(
@@ -1106,18 +1101,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = num_textures + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
-        num_textures += (num_separate_images - num_separate_samplers);
 
         // Uniform buffers
         result = spvc_resources_get_resource_list_for_type(
@@ -1147,18 +1138,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_buffer = binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
+            bufferBindings[numBufferBindings].stage = executionModel;
+            bufferBindings[numBufferBindings].desc_set = descriptor_set_index;
+            bufferBindings[numBufferBindings].binding = binding_index;
+            bufferBindings[numBufferBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numBufferBindings += 1;
         }
-        num_buffers += num_uniform_buffers;
 
         // Storage buffers
         result = spvc_resources_get_resource_list_for_type(
@@ -1188,18 +1175,55 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_buffer = num_buffers + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            bufferBindings[numBufferBindings].stage = executionModel;
+            bufferBindings[numBufferBindings].desc_set = descriptor_set_index;
+            bufferBindings[numBufferBindings].binding = binding_index;
+            bufferBindings[numBufferBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numBufferBindings += 1;
+        }
+
+        // Textures come first so we can just use the binding slot
+        for (Uint32 i = 0; i < numTextureBindings; i += 1) {
+            textureBindings[i].msl_texture = textureBindings[i].binding;
+            textureBindings[i].msl_sampler = textureBindings[i].binding;
+            result = spvc_compiler_msl_add_resource_binding_2(compiler, &textureBindings[i]);
+        }
+
+        if (result < 0) {
+            SPVC_ERROR(spvc_compiler_msl_add_resource_binding_2);
+            spvc_context_destroy(context);
+            return NULL;
+        }
+
+        // Calculate number of uniform buffers
+        Uint32 uniformBufferCount = 0;
+
+        for (Uint32 i = 0; i < numBufferBindings; i += 1) {
+            if (bufferBindings[i].desc_set == 1 || bufferBindings[i].desc_set == 3) {
+                uniformBufferCount += 1;
+            }
+        }
+
+        // Calculate resource indices
+        for (Uint32 i = 0; i < numBufferBindings; i += 1) {
+            if (bufferBindings[i].desc_set == 1 || bufferBindings[i].desc_set == 3) {
+                // Uniform buffers are alone in the descriptor set
+                bufferBindings[i].msl_buffer = bufferBindings[i].binding;
+            } else {
+                // Subtract by the texture count because the textures precede the storage buffers in the descriptor set
+                bufferBindings[i].msl_buffer = uniformBufferCount + (bufferBindings[i].binding - numTextureBindings);
+            }
+
+            result = spvc_compiler_msl_add_resource_binding_2(compiler, &bufferBindings[i]);
+
             if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding_2);
                 spvc_context_destroy(context);
                 return NULL;
             }
         }
-        num_buffers += num_storage_buffers;
     }
 
     if (backend == SPVC_BACKEND_MSL && shaderStage == SDL_SHADERCROSS_SHADERSTAGE_COMPUTE) {
@@ -1211,9 +1235,12 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         size_t num_uniform_buffers;
         size_t num_separate_samplers = 0;
         size_t num_separate_images = 0;
-        spvc_msl_resource_binding binding;
-        unsigned int num_textures = 0;
-        unsigned int num_buffers = 0;
+
+        spvc_msl_resource_binding_2 bufferBindings[32];
+        Uint32 numBufferBindings = 0;
+
+        spvc_msl_resource_binding_2 textureBindings[32];
+        Uint32 numTextureBindings = 0;
 
         result = spvc_compiler_create_shader_resources(compiler, &resources);
         if (result < 0) {
@@ -1265,19 +1292,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = binding_index;
-            binding.msl_sampler = binding_index;
-            result = spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
-        num_textures += num_texture_samplers;
 
         // Readonly storage textures
         result = spvc_resources_get_resource_list_for_type(
@@ -1291,7 +1313,6 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
             return NULL;
         }
 
-        size_t current_num_textures = num_textures;
         for (size_t i = 0; i < num_storage_textures; i += 1) {
             if (!spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet) || !spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding)) {
                 SDL_SetError("%s", "Shader resources must have descriptor set and binding index!");
@@ -1311,17 +1332,13 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = current_num_textures + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
-            num_textures += 1;
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
 
         // If source is HLSL, storage images might be marked as separate images
@@ -1337,7 +1354,6 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         }
 
         // We only want to iterate the images that don't have an associated sampler
-        current_num_textures = num_textures;
         for (size_t i = num_separate_samplers; i < num_separate_images; i += 1) {
             if (!spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet) || !spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding)) {
                 SDL_SetError("%s", "Shader resources must have descriptor set and binding index!");
@@ -1357,17 +1373,13 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = current_num_textures + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
-            num_textures += 1;
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
 
         // Readwrite storage textures
@@ -1382,7 +1394,6 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
             return NULL;
         }
 
-        current_num_textures = num_textures;
         for (size_t i = 0; i < num_storage_textures; i += 1) {
             unsigned int descriptor_set_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet);
 
@@ -1391,17 +1402,13 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = current_num_textures + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
-            num_textures += 1;
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
 
         // If source is HLSL, storage images might be marked as separate images
@@ -1417,7 +1424,6 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         }
 
         // We only want to iterate the images that don't have an associated sampler
-        current_num_textures = num_textures;
         for (size_t i = num_separate_samplers; i < num_separate_images; i += 1) {
             unsigned int descriptor_set_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet);
 
@@ -1426,17 +1432,13 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_texture = current_num_textures + binding_index;
-            spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
-            num_textures += 1;
+            textureBindings[numTextureBindings].stage = executionModel;
+            textureBindings[numTextureBindings].desc_set = descriptor_set_index;
+            textureBindings[numTextureBindings].binding = binding_index;
+            textureBindings[numTextureBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numTextureBindings += 1;
         }
 
         // Uniform buffers
@@ -1467,18 +1469,14 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_buffer = binding_index;
-            result = spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
+            bufferBindings[numBufferBindings].stage = executionModel;
+            bufferBindings[numBufferBindings].desc_set = descriptor_set_index;
+            bufferBindings[numBufferBindings].binding = binding_index;
+            bufferBindings[numBufferBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numBufferBindings += 1;
         }
-        num_buffers += num_uniform_buffers;
 
         // Storage buffers
         result = spvc_resources_get_resource_list_for_type(
@@ -1493,7 +1491,6 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         }
 
         // Readonly storage buffers
-        size_t current_num_buffers = num_buffers;
         for (size_t i = 0; i < num_storage_buffers; i += 1) {
             if (!spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet) || !spvc_compiler_has_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding)) {
                 SDL_SetError("%s", "Shader resources must have descriptor set and binding index!");
@@ -1513,22 +1510,16 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_buffer = current_num_buffers + binding_index;
-            result = spvc_compiler_msl_add_resource_binding(compiler, &binding);
-            if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
-                spvc_context_destroy(context);
-                return NULL;
-            }
+            bufferBindings[numBufferBindings].stage = executionModel;
+            bufferBindings[numBufferBindings].desc_set = descriptor_set_index;
+            bufferBindings[numBufferBindings].binding = binding_index;
+            bufferBindings[numBufferBindings].count = 1;
+            // assign binding index after we have collected all resources
 
-            num_buffers += 1;
+            numBufferBindings += 1;
         }
 
         // Readwrite storage buffers
-        current_num_buffers = num_buffers;
         for (size_t i = 0; i < num_storage_buffers; i += 1) {
             unsigned int descriptor_set_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationDescriptorSet);
 
@@ -1537,20 +1528,78 @@ static SPIRVTranspileContext *SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
 
             unsigned int binding_index = spvc_compiler_get_decoration(compiler, reflected_resources[i].id, SpvDecorationBinding);
 
-            binding.stage = executionModel;
-            binding.desc_set = descriptor_set_index;
-            binding.binding = binding_index;
-            binding.msl_buffer = current_num_buffers + binding_index;
-            result = spvc_compiler_msl_add_resource_binding(compiler, &binding);
+            bufferBindings[numBufferBindings].stage = executionModel;
+            bufferBindings[numBufferBindings].desc_set = descriptor_set_index;
+            bufferBindings[numBufferBindings].binding = binding_index;
+            bufferBindings[numBufferBindings].count = 1;
+            // assign binding index after we have collected all resources
+
+            numBufferBindings += 1;
+        }
+
+        // Calculate binding offsets
+
+        Uint32 readonlyTextureCount = 0;
+        Uint32 readwriteTextureCount = 0;
+
+        for (Uint32 i = 0; i < numTextureBindings; i += 1) {
+            if (textureBindings[i].desc_set == 0) {
+                readonlyTextureCount += 1;
+            } else if (textureBindings[i].desc_set == 1) {
+                readwriteTextureCount += 1;
+            }
+        }
+
+        Uint32 uniformBufferCount = 0;
+        Uint32 readonlyBufferCount = 0;
+
+        for (Uint32 i = 0; i < numBufferBindings; i += 1) {
+            if (bufferBindings[i].desc_set == 0) {
+                readonlyBufferCount += 1;
+            } else if (bufferBindings[i].desc_set == 2) {
+                uniformBufferCount += 1;
+            }
+        }
+
+        // Calculate resource indices
+
+        for (Uint32 i = 0; i < numTextureBindings; i += 1) {
+            if (textureBindings[i].desc_set == 0) {
+                // readonly textures
+                textureBindings[i].msl_texture = textureBindings[i].binding;
+                textureBindings[i].msl_sampler = textureBindings[i].binding;
+            } else {
+                // readwrite textures
+                textureBindings[i].msl_texture = readonlyTextureCount + textureBindings[i].binding;
+                textureBindings[i].msl_sampler = readonlyTextureCount + textureBindings[i].binding;
+            }
+            result = spvc_compiler_msl_add_resource_binding_2(compiler, &textureBindings[i]);
             if (result < 0) {
-                SPVC_ERROR(spvc_compiler_msl_add_resource_binding);
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding_2);
                 spvc_context_destroy(context);
                 return NULL;
             }
-
-            num_buffers += 1;
         }
 
+        for (Uint32 i = 0; i < numBufferBindings; i += 1) {
+            if (bufferBindings[i].desc_set == 0) {
+                // Subtract by the readonly texture count because they precede readonly buffers in the descriptor set
+                bufferBindings[i].msl_buffer = uniformBufferCount + (bufferBindings[i].binding - readonlyTextureCount);
+            } else if (bufferBindings[i].desc_set == 1) {
+                // Subtract by the readwrite texture count because they precede readwrite buffers in the descriptor set
+                bufferBindings[i].msl_buffer = uniformBufferCount + readonlyBufferCount + (bufferBindings[i].binding - readwriteTextureCount);
+            } else {
+                // Uniform buffers are alone in the descriptor set
+                bufferBindings[i].msl_buffer = bufferBindings[i].binding;
+            }
+            result = spvc_compiler_msl_add_resource_binding_2(compiler, &bufferBindings[i]);
+
+            if (result < 0) {
+                SPVC_ERROR(spvc_compiler_msl_add_resource_binding_2);
+                spvc_context_destroy(context);
+                return NULL;
+            }
+        }
     }
 
     result = spvc_compiler_install_compiler_options(compiler, options);