SDL_gpu_shadercross: Add support for HLSL include dirs (#55)

From 80626699abd04430a537685ede8c4b53ac628dcc Mon Sep 17 00:00:00 2001
From: Evan Hemsley <[EMAIL REDACTED]>
Date: Tue, 12 Nov 2024 10:27:01 -0800
Subject: [PATCH] Add support for HLSL include dirs (#55)

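Add an optional includeDir parameter to the HLSL compile entry points and an
-I/--include option to the CLI. DXC receives the directory as an -I argument,
together with its default include handler, so #include directives in HLSL
source can be resolved.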

Co-authored-by: Anonymous Maarten <madebr@users.noreply.github.com>
---
 .../SDL_gpu_shadercross.h                     |  10 +
 src/SDL_gpu_shadercross.c                     | 172 ++++++++++++++++--
 src/cli.c                                     |  20 ++
 3 files changed, 191 insertions(+), 11 deletions(-)

diff --git a/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h b/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h
index c5dc456..f19dfaf 100644
--- a/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h
+++ b/include/SDL3_gpu_shadercross/SDL_gpu_shadercross.h
@@ -229,6 +229,7 @@ extern SDL_DECLSPEC SDL_GPUShaderFormat SDLCALL SDL_ShaderCross_GetHLSLShaderFor
  *
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
+ * \param includeDir an optional directory searched when resolving #include directives. Can be NULL.
  * \param shaderStage the shader stage to compile the shader with.
  * \param size filled in with the bytecode buffer size.
  * \returns an SDL_malloc'd buffer containing DXBC bytecode.
@@ -238,6 +239,7 @@ extern SDL_DECLSPEC SDL_GPUShaderFormat SDLCALL SDL_ShaderCross_GetHLSLShaderFor
 extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXBCFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size);
 
@@ -248,6 +250,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXBCFromHLSL(
  *
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
+ * \param includeDir an optional directory searched when resolving #include directives. Can be NULL.
  * \param shaderStage the shader stage to compile the shader with.
  * \param size filled in with the bytecode buffer size.
  * \returns an SDL_malloc'd buffer containing DXIL bytecode.
@@ -257,6 +260,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXBCFromHLSL(
 extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXILFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size);
 
@@ -267,6 +271,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXILFromHLSL(
  *
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
+ * \param includeDir an optional directory searched when resolving #include directives. Can be NULL.
  * \param shaderStage the shader stage to compile the shader with.
  * \param size filled in with the bytecode buffer size.
  * \returns an SDL_malloc'd buffer containing SPIRV bytecode.
@@ -276,6 +281,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileDXILFromHLSL(
 extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileSPIRVFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size);
 
@@ -285,6 +291,7 @@ extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_CompileSPIRVFromHLSL(
  * \param device the SDL GPU device.
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
+ * \param includeDir an optional directory searched when resolving #include directives. Can be NULL.
  * \param graphicsShaderStage the shader stage to compile the shader with.
  * \param resourceInfo a pointer to an SDL_ShaderCross_ShaderResourceInfo.
  * \returns a compiled SDL_GPUShader
@@ -295,6 +302,7 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_GPUShaderStage graphicsShaderStage,
     const SDL_ShaderCross_ShaderResourceInfo *resourceInfo);
 
@@ -304,6 +312,7 @@ extern SDL_DECLSPEC SDL_GPUShader * SDLCALL SDL_ShaderCross_CompileGraphicsShade
  * \param device the SDL GPU device.
  * \param hlslSource the HLSL source code for the shader.
  * \param entrypoint the entry point function name for the shader in UTF-8.
+ * \param includeDir an optional directory searched when resolving #include directives. Can be NULL.
  * \param resourceInfo a pointer to an SDL_ShaderCross_ComputeResourceInfo.
  * \returns a compiled SDL_GPUComputePipeline
  *
@@ -313,6 +322,7 @@ extern SDL_DECLSPEC SDL_GPUComputePipeline * SDLCALL SDL_ShaderCross_CompileComp
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     const SDL_ShaderCross_ComputeResourceInfo *resourceInfo);
 
 #endif /* SDL_GPU_SHADERCROSS_HLSL */
diff --git a/src/SDL_gpu_shadercross.c b/src/SDL_gpu_shadercross.c
index d0e7c25..43d6ec7 100644
--- a/src/SDL_gpu_shadercross.c
+++ b/src/SDL_gpu_shadercross.c
@@ -256,6 +256,63 @@ struct IDxcCompiler3
     const IDxcCompiler3Vtbl *lpVtbl;
 };
 
+// Minimal C declarations of DXC's IDxcUtils COM interface: DXC needs an include handler passed to Compile() to resolve -I include dirs, and IDxcUtils::CreateDefaultIncludeHandler provides one.
+typedef struct IMalloc IMalloc; // forward-declared only: used solely as a pointer type in IDxcUtilsVtbl
+typedef struct IStream IStream; // forward-declared only: used solely as a pointer type in IDxcUtilsVtbl
+typedef struct DxcDefine DxcDefine; // forward-declared only: used solely as a pointer type in IDxcUtilsVtbl
+typedef struct IDxcCompilerArgs IDxcCompilerArgs; // forward-declared only: used solely as a pointer type in IDxcUtilsVtbl
+
+static struct // same field layout as a Win32 GUID (Data1..Data4), so &CLSID_DxcUtils can be passed where a REFCLSID is expected
+{
+    Uint32 Data1;
+    Uint16 Data2;
+    Uint16 Data3;
+    Uint8 Data4[8];
+} CLSID_DxcUtils = { // {6245D6AF-66E0-48FD-80B4-4D271796748C}
+    .Data1 = 0x6245d6af,
+    .Data2 = 0x66e0,
+    .Data3 = 0x48fd,
+    .Data4 = {0x80, 0xb4, 0x4d, 0x27, 0x17, 0x96, 0x74, 0x8c}};
+static Uint8 IID_IDxcUtils[] = { // {4605C4CB-2019-492A-ADA4-65F20BB7D67F} as raw GUID bytes; Data1..Data3 are little-endian, so this assumes a little-endian target
+    0xcb, 0xc4, 0x05, 0x46, // Data1 = 0x4605c4cb
+    0x19, 0x20, // Data2 = 0x2019
+    0x2a, 0x49, // Data3 = 0x492a
+    0xad, // Data4[0..7] follow in order, no byte swapping
+    0xa4,
+    0x65,
+    0xf2,
+    0x0b,
+    0xb7,
+    0xd6,
+    0x7f
+};
+typedef struct IDxcUtilsVtbl // C mirror of the IDxcUtils COM vtable: entries are reached by slot position, so member order must not change
+{
+    HRESULT (__stdcall *QueryInterface)(void *pSelf, REFIID riid, void **ppvObject); // IUnknown methods come first
+    ULONG (__stdcall *AddRef)(void *pSelf);
+    ULONG (__stdcall *Release)(void *pSelf);
+
+    HRESULT (__stdcall *CreateBlobFromBlob)(void *pSelf, IDxcBlob *pBlob, UINT offset, UINT length, IDxcBlob **ppResult); // IDxcUtils methods, in declaration order
+    HRESULT (__stdcall *CreateBlobFromPinned)(void *pSelf, LPCVOID pData, UINT size, UINT codePage, IDxcBlobEncoding **pBlobEncoding);
+    HRESULT (__stdcall *MoveToBlob)(void *pSelf, LPCVOID pData, IMalloc *pIMalloc, UINT size, UINT codePage, IDxcBlobEncoding **pBlobEncoding);
+    HRESULT (__stdcall *CreateBlob)(void *pSelf, LPCVOID pData, UINT size, UINT codePage, IDxcBlobEncoding **pBlobEncoding);
+    HRESULT (__stdcall *LoadFile)(void *pSelf, LPCWSTR pFileName, UINT *pCodePage, IDxcBlobEncoding **pBlobEncoding);
+    HRESULT (__stdcall *CreateReadOnlyStreamFromBlob)(void *pSelf, IDxcBlob *pBlob, IStream **ppStream);
+    HRESULT (__stdcall *CreateDefaultIncludeHandler)(void *pSelf, IDxcIncludeHandler **ppResult); // used for -I support; per COM convention the caller owns *ppResult and must Release() it
+    HRESULT (__stdcall *GetBlobAsUtf8)(void *pSelf, IDxcBlob *pBlob, IDxcBlobUtf8 **pBlobEncoding);
+    HRESULT (__stdcall *GetBlobAsWide)(void *pSelf, IDxcBlob *pBlob, IDxcBlobWide **pBlobEncoding);
+    HRESULT (__stdcall *GetDxilContainerPart)(void *pSelf, const DxcBuffer *pShader, UINT DxcPart, void **ppPartData, UINT *pPartSizeInBytes);
+    HRESULT (__stdcall *CreateReflection)(void *pSelf, const DxcBuffer *pData, REFIID iid, void **ppvReflection);
+    HRESULT (__stdcall *BuildArguments)(void *pSelf, LPCWSTR pSourceName, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, LPCWSTR *pArguments, UINT argCount, const DxcDefine *pDefines, UINT defineCount, IDxcCompilerArgs **ppArgs);
+    HRESULT (__stdcall *GetPDBContents)(void *pSelf, IDxcBlob *pPDBBlob, IDxcBlob **ppHash, IDxcBlob **ppContainer);
+} IDxcUtilsVtbl;
+
+typedef struct IDxcUtils IDxcUtils;
+struct IDxcUtils // COM object as seen from C: only the vtable pointer is declared, and methods are called as utils->lpVtbl->Method(utils, ...)
+{
+    const IDxcUtilsVtbl *lpVtbl;
+};
+
 /* *INDENT-ON* */ // clang-format on
 
 /* DXCompiler */
@@ -273,6 +330,7 @@ HRESULT DxcCreateInstance(REFCLSID rclsid, REFIID riid, LPVOID *ppv);
 static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     bool spirv,
     size_t *size) // filled in with number of bytes of returned buffer
@@ -282,20 +340,15 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
     IDxcBlob *blob;
     IDxcBlobUtf8 *errors;
     size_t entryPointLength = SDL_utf8strlen(entrypoint) + 1;
-    wchar_t *entryPointUtf16 = (wchar_t *)SDL_iconv_string("WCHAR_T", "UTF-8", entrypoint, entryPointLength);
-    LPCWSTR args[] = {
-        (LPCWSTR)L"-E",
-        (LPCWSTR)entryPointUtf16,
-        NULL,
-        NULL,
-        NULL,
-        NULL
-    };
-    Uint32 argCount = 2;
+    wchar_t *entryPointUtf16 = NULL;
+    size_t includeDirLength = 0;
+    wchar_t *includeDirUtf16 = NULL;
     HRESULT ret;
 
     /* Non-static DxcInstance, since the functions we call on it are not thread-safe */
     IDxcCompiler3 *dxcInstance = NULL;
+    IDxcUtils *utils = NULL;
+    IDxcIncludeHandler *includeHandler = NULL;
 
     #if defined(SDL_PLATFORM_XBOXONE) || defined(SDL_PLATFORM_XBOXSERIES)
     if (SDL_DxcCreateInstance == NULL) {
@@ -306,17 +359,78 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
         &CLSID_DxcCompiler,
         IID_IDxcCompiler3,
         (void **)&dxcInstance);
+
+    SDL_DxcCreateInstance(
+        &CLSID_DxcUtils,
+        &IID_IDxcUtils,
+        (void **)(&utils));
     #else
     DxcCreateInstance(
         &CLSID_DxcCompiler,
         IID_IDxcCompiler3,
         (void **)&dxcInstance);
+
+    DxcCreateInstance(
+        &CLSID_DxcUtils,
+        &IID_IDxcUtils,
+        (void **)(&utils));
     #endif
 
     if (dxcInstance == NULL) {
+        SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Could not create DXC instance!");
         return NULL;
     }
 
+    if (utils == NULL) {
+        SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Could not create DXC utils instance!");
+        dxcInstance->lpVtbl->Release(dxcInstance);
+        return NULL;
+    }
+
+    utils->lpVtbl->CreateDefaultIncludeHandler(utils, &includeHandler);
+    if (includeHandler == NULL) {
+        SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Failed to create a default include handler!");
+        dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
+        return NULL;
+    }
+
+    entryPointUtf16 = (wchar_t *)SDL_iconv_string("WCHAR_T", "UTF-8", entrypoint, entryPointLength);
+    if (entryPointUtf16 == NULL) {
+        SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Failed to convert entrypoint to WCHAR_T!");
+        dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
+        return NULL;
+    }
+
+    LPCWSTR args[] = {
+        (LPCWSTR)L"-E",
+        (LPCWSTR)entryPointUtf16,
+        NULL,
+        NULL,
+        NULL,
+        NULL,
+        NULL,
+        NULL
+    };
+    Uint32 argCount = 2;
+
+    if (includeDir != NULL) {
+        includeDirLength = SDL_utf8strlen(includeDir) + 1;
+        includeDirUtf16 = (wchar_t *)SDL_iconv_string("WCHAR_T", "UTF-8", includeDir, includeDirLength);
+
+        if (includeDirUtf16 == NULL) {
+            SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s", "Failed to convert include dir to WCHAR_T!");
+            dxcInstance->lpVtbl->Release(dxcInstance);
+            utils->lpVtbl->Release(utils);
+            SDL_free(entryPointUtf16);
+            return NULL;
+        }
+        args[2] = (LPCWSTR)L"-I";
+        args[3] = includeDirUtf16;
+        argCount += 2;
+    }
+
     source.Ptr = hlslSource;
     source.Size = SDL_strlen(hlslSource) + 1;
     source.Encoding = DXC_CP_ACP;
@@ -345,22 +459,27 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
         &source,
         args,
         argCount,
-        NULL,
+        includeHandler,
         IID_IDxcResult,
         (void **)&dxcResult);
 
     SDL_free(entryPointUtf16);
+    if (includeDirUtf16 != NULL) {
+        SDL_free(includeDirUtf16);
+    }
 
     if (ret < 0) {
         SDL_LogError(SDL_LOG_CATEGORY_GPU,
                      "IDxcShaderCompiler3::Compile failed: %X",
                      ret);
         dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
         return NULL;
     } else if (dxcResult == NULL) {
         SDL_LogError(SDL_LOG_CATEGORY_GPU,
                      "HLSL compilation failed with no IDxcResult");
         dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
         return NULL;
     }
 
@@ -375,6 +494,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
                      (char *)errors->lpVtbl->GetBufferPointer(errors));
         dxcResult->lpVtbl->Release(dxcResult);
         dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
         return NULL;
     }
 
@@ -387,6 +507,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
         SDL_LogError(SDL_LOG_CATEGORY_GPU, "IDxcBlob fetch failed");
         dxcResult->lpVtbl->Release(dxcResult);
         dxcInstance->lpVtbl->Release(dxcInstance);
+        utils->lpVtbl->Release(utils);
         return NULL;
     }
 
@@ -394,8 +515,10 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
     void *buffer = SDL_malloc(*size);
     SDL_memcpy(buffer, blob->lpVtbl->GetBufferPointer(blob), *size);
 
+    blob->lpVtbl->Release(blob);
     dxcResult->lpVtbl->Release(dxcResult);
     dxcInstance->lpVtbl->Release(dxcInstance);
+    utils->lpVtbl->Release(utils);
 
     return buffer;
 }
@@ -403,6 +526,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileUsingDXC(
 void *SDL_ShaderCross_CompileDXILFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size)
 {
@@ -411,6 +535,7 @@ void *SDL_ShaderCross_CompileDXILFromHLSL(
     void *spirv = SDL_ShaderCross_CompileSPIRVFromHLSL(
         hlslSource,
         entrypoint,
+        includeDir,
         shaderStage,
         &spirvSize);
 
@@ -433,6 +558,7 @@ void *SDL_ShaderCross_CompileDXILFromHLSL(
     return SDL_ShaderCross_INTERNAL_CompileUsingDXC(
         translatedSource,
         entrypoint,
+        includeDir,
         shaderStage,
         false,
         size);
@@ -441,12 +567,14 @@ void *SDL_ShaderCross_CompileDXILFromHLSL(
 void *SDL_ShaderCross_CompileSPIRVFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size)
 {
     return SDL_ShaderCross_INTERNAL_CompileUsingDXC(
         hlslSource,
         entrypoint,
+        includeDir,
         shaderStage,
         true,
         size);
@@ -456,6 +584,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromDXC(
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     const void *resourceInfo,
     bool spirv)
@@ -469,6 +598,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromDXC(
         bytecode = SDL_ShaderCross_CompileDXILFromHLSL(
             hlslSource,
             entrypoint,
+            includeDir,
             shaderStage,
             &bytecodeSize);
     } else {
@@ -476,6 +606,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromDXC(
         bytecode = SDL_ShaderCross_INTERNAL_CompileUsingDXC(
             hlslSource,
             entrypoint,
+            includeDir,
             shaderStage,
             spirv,
             &bytecodeSize);
@@ -640,6 +771,7 @@ static ID3DBlob *SDL_ShaderCross_INTERNAL_CompileDXBC(
 void *SDL_ShaderCross_CompileDXBCFromHLSL(
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     size_t *size) // filled in with number of bytes of returned buffer
 {
@@ -648,6 +780,7 @@ void *SDL_ShaderCross_CompileDXBCFromHLSL(
     void *spirv = SDL_ShaderCross_CompileSPIRVFromHLSL(
         hlslSource,
         entrypoint,
+        includeDir,
         shaderStage,
         &spirv_size);
 
@@ -701,6 +834,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromDXBC(
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     const void *resourceInfo)
 {
@@ -710,6 +844,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromDXBC(
     void *bytecode = SDL_ShaderCross_CompileDXBCFromHLSL(
         hlslSource,
         entrypoint,
+        includeDir,
         shaderStage,
         &bytecodeSize);
 
@@ -761,6 +896,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_ShaderCross_ShaderStage shaderStage,
     const void *resourceInfo)
 {
@@ -770,6 +906,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
             device,
             hlslSource,
             entrypoint,
+            includeDir,
             shaderStage,
             resourceInfo);
     }
@@ -778,6 +915,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
             device,
             hlslSource,
             entrypoint,
+            includeDir,
             shaderStage,
             resourceInfo,
             false);
@@ -787,6 +925,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
             device,
             hlslSource,
             entrypoint,
+            includeDir,
             shaderStage,
             resourceInfo,
             true);
@@ -796,6 +935,7 @@ static void *SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
         void *spirv = SDL_ShaderCross_CompileSPIRVFromHLSL(
             hlslSource,
             entrypoint,
+            includeDir,
             shaderStage,
             &bytecodeSize);
         if (spirv == NULL) {
@@ -831,6 +971,7 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     SDL_GPUShaderStage graphicsShaderStage,
     const SDL_ShaderCross_ShaderResourceInfo *resourceInfo)
 {
@@ -838,6 +979,7 @@ SDL_GPUShader *SDL_ShaderCross_CompileGraphicsShaderFromHLSL(
         device,
         hlslSource,
         entrypoint,
+        includeDir,
         (SDL_ShaderCross_ShaderStage)graphicsShaderStage,
         (const void *)resourceInfo);
 }
@@ -846,12 +988,14 @@ SDL_GPUComputePipeline *SDL_ShaderCross_CompileComputePipelineFromHLSL(
     SDL_GPUDevice *device,
     const char *hlslSource,
     const char *entrypoint,
+    const char *includeDir,
     const SDL_ShaderCross_ComputeResourceInfo *resourceInfo)
 {
     return (SDL_GPUComputePipeline *)SDL_ShaderCross_INTERNAL_CreateShaderFromHLSL(
         device,
         hlslSource,
         entrypoint,
+        includeDir,
         SDL_SHADERCROSS_SHADERSTAGE_COMPUTE,
         (const void *)resourceInfo);
 }
@@ -1449,12 +1593,14 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
             createInfo.code = SDL_ShaderCross_CompileDXBCFromHLSL(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
+                NULL,
                 shaderStage,
                 &createInfo.code_size);
         } else if (targetFormat == SDL_GPU_SHADERFORMAT_DXIL) {
             createInfo.code = SDL_ShaderCross_CompileDXILFromHLSL(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
+                NULL,
                 shaderStage,
                 &createInfo.code_size);
         } else { // MSL
@@ -1479,12 +1625,14 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
             createInfo.code = SDL_ShaderCross_CompileDXBCFromHLSL(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
+                NULL,
                 shaderStage,
                 &createInfo.code_size);
         } else if (targetFormat == SDL_GPU_SHADERFORMAT_DXIL) {
             createInfo.code = SDL_ShaderCross_CompileDXILFromHLSL(
                 transpileContext->translated_source,
                 transpileContext->cleansed_entrypoint,
+                NULL,
                 shaderStage,
                 &createInfo.code_size);
         } else { // MSL
@@ -1575,6 +1723,7 @@ void *SDL_ShaderCross_CompileDXBCFromSPIRV(
     void *result = SDL_ShaderCross_CompileDXBCFromHLSL(
         context->translated_source,
         context->cleansed_entrypoint,
+        NULL,
         shaderStage,
         size);
 
@@ -1600,6 +1749,7 @@ void *SDL_ShaderCross_CompileDXILFromSPIRV(
     void *result = SDL_ShaderCross_CompileDXILFromHLSL(
         context->translated_source,
         context->cleansed_entrypoint,
+        NULL,
         shaderStage,
         size);
 
diff --git a/src/cli.c b/src/cli.c
index 0ea23dc..91d7a13 100644
--- a/src/cli.c
+++ b/src/cli.c
@@ -43,6 +43,7 @@ void print_help(void)
     SDL_Log("  %-*s %s", column_width, "-t | --stage <value>", "Shader stage. May be inferred from the filename. Values: [vertex, fragment, compute]");
     SDL_Log("  %-*s %s", column_width, "-e | --entrypoint <value>", "Entrypoint function name. Default: \"main\".");
     SDL_Log("  %-*s %s", column_width, "-m | --shadermodel <value>", "HLSL Shader Model. Only used with HLSL destination. Values: [5.0, 6.0]");
+    SDL_Log("  %-*s %s", column_width, "-I | --include <value>", "HLSL include directory. Only used with HLSL source. Optional.");
     SDL_Log("  %-*s %s", column_width, "-o | --output <value>", "Output file.");
 }
 
@@ -60,6 +61,7 @@ int main(int argc, char *argv[])
     SDL_ShaderCross_ShaderModel shaderModel = SDL_SHADERCROSS_SHADERMODEL_INVALID;
     char *outputFilename = NULL;
     char *entrypointName = "main";
+    char *includeDir = NULL;
 
     char *filename = NULL;
     size_t fileSize = 0;
@@ -165,6 +167,19 @@ int main(int argc, char *argv[])
                     print_help();
                     return 1;
                 }
+            } else if (SDL_strcmp(arg, "-I") == 0 || SDL_strcmp(arg, "--include") == 0) {
+                if (includeDir) {
+                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "'%s' can only be used once", arg);
+                    print_help();
+                    return 1;
+                }
+                if (i + 1 >= argc) {
+                    SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s requires an argument", arg);
+                    print_help();
+                    return 1;
+                }
+                i += 1;
+                includeDir = argv[i];
             } else if (SDL_strcmp(arg, "-o") == 0 || SDL_strcmp(arg, "--output") == 0) {
                 if (i + 1 >= argc) {
                     SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "%s requires an argument", arg);
@@ -339,6 +354,7 @@ int main(int argc, char *argv[])
                 Uint8 *buffer = SDL_ShaderCross_CompileDXBCFromHLSL(
                     fileData,
                     entrypointName,
+                    includeDir,
                     shaderStage,
                     &bytecodeSize);
                 SDL_WriteIO(outputIO, buffer, bytecodeSize);
@@ -350,6 +366,7 @@ int main(int argc, char *argv[])
                 Uint8 *buffer = SDL_ShaderCross_CompileDXILFromHLSL(
                     fileData,
                     entrypointName,
+                    includeDir,
                     shaderStage,
                     &bytecodeSize);
                 SDL_WriteIO(outputIO, buffer, bytecodeSize);
@@ -361,6 +378,7 @@ int main(int argc, char *argv[])
                 void *spirv = SDL_ShaderCross_CompileSPIRVFromHLSL(
                     fileData,
                     entrypointName,
+                    includeDir,
                     shaderStage,
                     &bytecodeSize);
                 if (spirv == NULL) {
@@ -382,6 +400,7 @@ int main(int argc, char *argv[])
                 Uint8 *buffer = SDL_ShaderCross_CompileSPIRVFromHLSL(
                     fileData,
                     entrypointName,
+                    includeDir,
                     shaderStage,
                     &bytecodeSize);
                 SDL_WriteIO(outputIO, buffer, bytecodeSize);
@@ -399,6 +418,7 @@ int main(int argc, char *argv[])
                 void *spirv = SDL_ShaderCross_CompileSPIRVFromHLSL(
                     fileData,
                     entrypointName,
+                    includeDir,
                     shaderStage,
                     &bytecodeSize);