SDL_shadercross: Add support for HLSL defines

From 622e56960dc4a2d79c7a0e61f175428817f1be95 Mon Sep 17 00:00:00 2001
From: Evan Hemsley <[EMAIL REDACTED]>
Date: Tue, 19 Nov 2024 18:00:38 -0800
Subject: [PATCH] Add support for HLSL defines

---
 include/SDL3_shadercross/SDL_shadercross.h | 22 +++++-
 src/SDL_shadercross.c                      | 78 ++++++++++++++++++----
 src/cli.c                                  | 24 ++++++-
 3 files changed, 108 insertions(+), 16 deletions(-)

diff --git a/include/SDL3_shadercross/SDL_shadercross.h b/include/SDL3_shadercross/SDL_shadercross.h
index 66d8dcd..9ca3804 100644
--- a/include/SDL3_shadercross/SDL_shadercross.h
+++ b/include/SDL3_shadercross/SDL_shadercross.h
@@ -187,6 +187,8 @@ extern SDL_DECLSPEC SDL_GPUShaderFormat SDLCALL SDL_ShaderCross_GetHLSLShaderFor
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
  * \param includeDir the include directory for shader code. Optional, can be NULL.
+ * \param defines an array of define arguments passed verbatim to DXC, each including the -D prefix (e.g. "-DNAME=VALUE"). Optional, can be NULL.
+ * \param numDefines the number of strings in the defines array.
  * \param shaderStage the shader stage to compile the shader with.
  * \param size filled in with the bytecode buffer size.
  * \returns an SDL_malloc'd buffer containing DXBC bytecode.
@@ -197,6 +199,8 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXBCFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size);
 
@@ -208,6 +212,8 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXBCFromHLSL(
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
  * \param includeDir the include directory for shader code. Optional, can be NULL.
+ * \param defines an array of define arguments passed verbatim to DXC, each including the -D prefix (e.g. "-DNAME=VALUE"). Optional, can be NULL.
+ * \param numDefines the number of strings in the defines array.
  * \param shaderStage the shader stage to compile the shader with.
  * \param size filled in with the bytecode buffer size.
  * \returns an SDL_malloc'd buffer containing DXIL bytecode.
@@ -218,6 +224,8 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXILFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size);
 
@@ -229,6 +237,8 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXILFromHLSL(
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
  * \param includeDir the include directory for shader code. Optional, can be NULL.
+ * \param defines an array of define arguments passed verbatim to DXC, each including the -D prefix (e.g. "-DNAME=VALUE"). Optional, can be NULL.
+ * \param numDefines the number of strings in the defines array.
  * \param shaderStage the shader stage to compile the shader with.
  * \param size filled in with the bytecode buffer size.
  * \returns an SDL_malloc'd buffer containing SPIRV bytecode.
@@ -239,6 +249,8 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileSPIRVFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size);
 
@@ -249,6 +261,8 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileSPIRVFromHLSL(
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
  * \param includeDir the include directory for shader code. Optional, can be NULL.
+ * \param defines an array of define arguments passed verbatim to DXC, each including the -D prefix (e.g. "-DNAME=VALUE"). Optional, can be NULL.
+ * \param numDefines the number of strings in the defines array.
  * \param graphicsShaderStage the shader stage to compile the shader with.
  * \returns a compiled SDL_GPUShader
  *
@@ -259,6 +273,8 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_GPUShaderStage graphicsShaderStage);
 
 /**
@@ -268,6 +284,8 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
  * \param includeDir the include directory for shader code. Optional, can be NULL.
+ * \param defines an array of define arguments passed verbatim to DXC, each including the -D prefix (e.g. "-DNAME=VALUE"). Optional, can be NULL.
+ * \param numDefines the number of strings in the defines array.
  * \returns a compiled SDL_GPUComputePipeline
  *
  * \threadsafety It is safe to call this function from any thread.
@@ -276,7 +294,9 @@ extern SDL_DECLSPEC SDL_GPUComputePipeline * SDLCALL SDL_ShaderCross_CompileComp
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
-    const char *includeDir);
+    const char *includeDir,
+    const char **defines,
+    Uint32 numDefines);
 
 #ifdef __cplusplus
 }
diff --git a/src/SDL_shadercross.c b/src/SDL_shadercross.c
index 079dcca..087db48 100644
--- a/src/SDL_shadercross.c
+++ b/src/SDL_shadercross.c
@@ -332,6 +332,8 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     bool spirv,
     size_t *size) // filled in with number of bytes of returned buffer
@@ -405,17 +407,40 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
         return NULL;
     }
 
-    LPCWSTR args[] = {
-        (LPCWSTR)L"-E",
-        (LPCWSTR)entryPointUtf16,
-        NULL,
-        NULL,
-        NULL,
-        NULL,
-        NULL,
-        NULL
-    };
-    Uint32 argCount = 2;
+    // Defines are placed first, followed by -E <entrypoint>. The + 8 keeps the
+    // eight slots of the former fixed-size argument array for -E, the
+    // entrypoint and the options appended below.
+    // NOTE(review): later early returns in this function must also free the
+    // converted defines (args[0..numDefines)) and args itself.
+    LPCWSTR *args = SDL_malloc(sizeof(LPCWSTR) * (numDefines + 8));
+    if (args == NULL) {
+        dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
+        return NULL;
+    }
+    Uint32 argCount = 0;
+
+    for (Uint32 i = 0; i < numDefines; i += 1) {
+        // SDL_iconv_string takes a byte count: use SDL_strlen, not
+        // SDL_utf8strlen (which counts codepoints and would truncate
+        // non-ASCII defines).
+        LPCWSTR defineUtf16 = (LPCWSTR)SDL_iconv_string("WCHAR_T", "UTF-8", defines[i], SDL_strlen(defines[i]) + 1);
+        if (defineUtf16 == NULL) {
+            SDL_SetError("%s", "Failed to convert define argument to WCHAR_T!");
+            // Free the defines converted so far before bailing out.
+            for (Uint32 j = 0; j < argCount; j += 1) {
+                SDL_free((void *)args[j]);
+            }
+            SDL_free(args);
+            dxcInstance->lpVtbl->Release(dxcInstance);
+            utils->lpVtbl->Release(utils);
+            return NULL;
+        }
+        args[argCount++] = defineUtf16;
+    }
+
+    args[argCount++] = (LPCWSTR)L"-E";
+    args[argCount++] = (LPCWSTR)entryPointUtf16;
 
     if (includeDir != NULL) {
         includeDirLength = SDL_utf8strlen(includeDir) + 1;
@@ -521,6 +528,12 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
     dxcInstance->lpVtbl->Release(dxcInstance);
     utils->lpVtbl->Release(utils);
 
+    // args[0..numDefines) are the define strings converted with SDL_iconv_string above.
+    for (Uint32 i = 0; i < numDefines; i += 1) {
+        SDL_free((void *)args[i]);
+    }
+    SDL_free(args);
+
     return buffer;
 #else
     SDL_SetError("%s", "Shadercross was not built with DXC support, cannot compile using DXC!");
 #else
     SDL_SetError("%s", "Shadercross was not built with DXC support, cannot compile using DXC!");
@@ -532,6 +544,8 @@ void *SDL_ShaderCross_CompileDXILFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size)
 {
@@ -541,6 +555,8 @@ void *SDL_ShaderCross_CompileDXILFromHLSL(
         hlslSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         shaderStage,
         &spirvSize);
 
@@ -563,6 +579,8 @@ void *SDL_ShaderCross_CompileDXILFromHLSL(
         translatedSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         shaderStage,
         false,
         size);
@@ -572,6 +590,8 @@ void *SDL_ShaderCross_CompileSPIRVFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size)
 {
@@ -579,6 +599,8 @@ void *SDL_ShaderCross_CompileSPIRVFromHLSL(
         hlslSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         shaderStage,
         true,
         size);
@@ -702,6 +724,8 @@ void *SDL_ShaderCross_INTERNAL_CompileDXBCFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     bool enableRoundtrip,
     size_t *size) // filled in with number of bytes of returned buffer
@@ -715,6 +739,8 @@ void *SDL_ShaderCross_INTERNAL_CompileDXBCFromHLSL(
             hlslSource,
             entrypoint,
             includeDir,
+            defines,
+            numDefines,
             shaderStage,
             &spirv_size);
 
@@ -770,6 +796,8 @@ void *SDL_ShaderCross_CompileDXBCFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size) // filled in with number of bytes of returned buffer
 {
@@ -777,6 +805,8 @@ void *SDL_ShaderCross_CompileDXBCFromHLSL(
         hlslSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         shaderStage,
         true,
         size);
@@ -787,6 +817,8 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_ShaderCross_ShaderStage shaderStage)
 {
     size_t bytecodeSize;
@@ -796,6 +828,8 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
         hlslSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         shaderStage,
         &bytecodeSize);
 
@@ -828,6 +862,8 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
     const char *includeDir,
+    const char **defines,
+    Uint32 numDefines,
     SDL_GPUShaderStage graphicsShaderStage)
 {
     return (SDL_GPUShader *)SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
@@ -835,6 +871,8 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(
         hlslSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         (SDL_ShaderCross_ShaderStage)graphicsShaderStage);
 }
 
@@ -842,13 +880,17 @@ SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
-    const char *includeDir)
+    const char *includeDir,
+    const char **defines,
+    Uint32 numDefines)
 {
     return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
         device,
         hlslSource,
         entrypoint,
         includeDir,
+        defines,
+        numDefines,
         SDL_SHADERCROSS_SHADERSTAGE_COMPUTE);
 }
 
@@ -1759,6 +1801,8 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
                 NULL,
+                NULL,
+                0,
                 shaderStage,
                 false,
                 &createInfo.code_size);
@@ -1767,6 +1811,8 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
                 NULL,
+                NULL,
+                0,
                 shaderStage,
                 &createInfo.code_size);
         } else { // MSL
@@ -1791,6 +1837,8 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
                 NULL,
+                NULL,
+                0,
                 shaderStage,
                 false,
                 &createInfo.code_size);
@@ -1799,6 +1847,8 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
                 NULL,
+                NULL,
+                0,
                 shaderStage,
                 &createInfo.code_size);
         } else { // MSL
@@ -1890,6 +1940,8 @@ void *SDL_ShaderCross_CompileDXBCFromSPIRV(
         context->translated_source,
         context->cleansed_entrypoint,
         NULL,
+        NULL,
+        0,
         shaderStage,
         false,
         size);
@@ -1926,6 +1978,8 @@ void *SDL_ShaderCross_CompileDXILFromSPIRV(
         context->translated_source,
         context->cleansed_entrypoint,
         NULL,
+        NULL,
+        0,
         shaderStage,
         size);
 
diff --git a/src/cli.c b/src/cli.c
index 5a29602..6478515 100644
--- a/src/cli.c
+++ b/src/cli.c
@@ -43,12 +43,12 @@ void print_help(void)
     SDL_Log("  %-*s %s", column_width, "-t | --stage <value>", "Shader stage. May be inferred from the filename. Values: [vertex, fragment, compute]");
     SDL_Log("  %-*s %s", column_width, "-e | --entrypoint <value>", "Entrypoint function name. Default: \"main\".");
     SDL_Log("  %-*s %s", column_width, "-I | --include <value>", "HLSL include directory. Only used with HLSL source. Optional.");
+    SDL_Log("  %-*s %s", column_width, "-D<value>", "HLSL define. Only used with HLSL source. Optional. Can be repeated.");
     SDL_Log("  %-*s %s", column_width, "-o | --output <value>", "Output file.");
 }
 
 int main(int argc, char *argv[])
 {
-
     bool sourceValid = false;
     bool destinationValid = false;
     bool stageValid = false;
@@ -65,6 +65,9 @@ int main(int argc, char *argv[])
     void *fileData = NULL;
     bool accept_optionals = true;
 
+    Uint32 numDefines = 0;
+    const char **defines = NULL;
+
     for (int i = 1; i < argc; i += 1) {
         char *arg = argv[i];
 
@@ -167,6 +170,24 @@ int main(int argc, char *argv[])
                 }
                 i += 1;
                 outputFilename = argv[i];
+            } else if (SDL_strncmp(arg, "-D", 2) == 0) {
+                // Defines are forwarded to DXC verbatim (e.g. "-DNAME=VALUE"); a bare
+                // "-D" would presumably make DXC consume the following argument.
+                if (arg[2] == '\0') {
+                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "-D requires a value, e.g. -DNAME or -DNAME=VALUE");
+                    print_help();
+                    SDL_free(defines);
+                    return 1;
+                }
+                // Grow via a temporary so the existing array is not lost if SDL_realloc fails.
+                const char **newDefines = SDL_realloc(defines, sizeof(const char *) * (numDefines + 1));
+                if (newDefines == NULL) {
+                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Failed to store define: %s", SDL_GetError());
+                    SDL_free(defines);
+                    return 1;
+                }
+                defines = newDefines;
+                defines[numDefines++] = arg;
             } else if (SDL_strcmp(arg, "--") == 0) {
                 accept_optionals = false;
             } else {
@@ -345,6 +352,8 @@ int main(int argc, char *argv[])
                     fileData,
                     entrypointName,
                     includeDir,
+                    defines,
+                    numDefines,
                     shaderStage,
                     &bytecodeSize);
                 if (buffer == NULL) {
@@ -362,6 +371,8 @@ int main(int argc, char *argv[])
                     fileData,
                     entrypointName,
                     includeDir,
+                    defines,
+                    numDefines,
                     shaderStage,
                     &bytecodeSize);
                 if (buffer == NULL) {
@@ -380,6 +391,8 @@ int main(int argc, char *argv[])
                     fileData,
                     entrypointName,
                     includeDir,
+                    defines,
+                    numDefines,
                     shaderStage,
                     &bytecodeSize);
                 if (spirv == NULL) {
@@ -408,6 +421,8 @@ int main(int argc, char *argv[])
                     fileData,
                     entrypointName,
                     includeDir,
+                    defines,
+                    numDefines,
                     shaderStage,
                     &bytecodeSize);
                 if (buffer == NULL) {
@@ -425,11 +440,13 @@ int main(int argc, char *argv[])
                     fileData,
                     entrypointName,
                     includeDir,
+                    defines,
+                    numDefines,
                     shaderStage,
                     &bytecodeSize);
 
                 if (spirv == NULL) {
-                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Failed to compile HLSL to SPIRV!");
+                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Failed to compile HLSL to SPIRV: %s", SDL_GetError());
                     result = 1;
                     break;
                 }
@@ -441,7 +458,7 @@ int main(int argc, char *argv[])
                     shaderStage);
 
                 if (buffer == NULL) {
-                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Failed to transpile HLSL from SPIRV!");
+                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Failed to transpile HLSL from SPIRV: %s", SDL_GetError());
                     result = 1;
                     break;
                 }
@@ -462,6 +479,7 @@ int main(int argc, char *argv[])
 
     SDL_CloseIO(outputIO);
     SDL_free(fileData);
+    SDL_free(defines);
     SDL_ShaderCross_Quit();
     return result;
 }