SDL_gpu_shadercross: MSL transpilation does not need a shader stage

From 9eed880013a0c156e12ffa3b98a633a116c8382f Mon Sep 17 00:00:00 2001
From: cosmonaut <[EMAIL REDACTED]>
Date: Fri, 25 Oct 2024 11:37:26 -0700
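Subject: [PATCH] MSL transpilation does not need a shader stage

Drop the shaderStage parameter from SDL_ShaderCross_TranspileMSLFromSPIRV
and update its callers in the CLI accordingly. The CLI also no longer
infers the shader stage from the filename when the destination format is
MSL.
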
---
 include/SDL_gpu_shadercross.h | 3 +--
 src/SDL_gpu_shadercross.c     | 3 +--
 src/cli.c                     | 8 +++-----
 3 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/include/SDL_gpu_shadercross.h b/include/SDL_gpu_shadercross.h
index 27cff7e..bcce37e 100644
--- a/include/SDL_gpu_shadercross.h
+++ b/include/SDL_gpu_shadercross.h
@@ -81,8 +81,7 @@ extern SDL_DECLSPEC SDL_GPUShaderFormat SDLCALL SDL_ShaderCross_GetSPIRVShaderFo
 extern SDL_DECLSPEC void * SDLCALL SDL_ShaderCross_TranspileMSLFromSPIRV(
     const Uint8 *bytecode,
     size_t bytecodeSize,
-    const char *entrypoint,
-    SDL_ShaderCross_ShaderStage shaderStage);
+    const char *entrypoint);
 
 /**
  * Transpile to HLSL code from SPIRV code.
diff --git a/src/SDL_gpu_shadercross.c b/src/SDL_gpu_shadercross.c
index 862da93..e160a98 100644
--- a/src/SDL_gpu_shadercross.c
+++ b/src/SDL_gpu_shadercross.c
@@ -947,8 +947,7 @@ static void *SDL_ShaderCross_INTERNAL_CompileFromSPIRV(
 void *SDL_ShaderCross_TranspileMSLFromSPIRV(
     const Uint8 *bytecode,
     size_t bytecodeSize,
-    const char *entrypoint,
-    SDL_ShaderCross_ShaderStage shaderStage)
+    const char *entrypoint)
 {
     SPIRVTranspileContext *context = SDL_ShaderCross_INTERNAL_TranspileFromSPIRV(
         SPVC_BACKEND_MSL,
diff --git a/src/cli.c b/src/cli.c
index d9be75a..fc692f5 100644
--- a/src/cli.c
+++ b/src/cli.c
@@ -220,7 +220,7 @@ int main(int argc, char *argv[])
         }
     }
 
-    if (!stageValid) {
+    if (!stageValid && destinationFormat != SHADERFORMAT_MSL) {
         if (SDL_strcasestr(filename, ".vert")) {
             shaderStage = SDL_SHADERCROSS_SHADERSTAGE_VERTEX;
         } else if (SDL_strcasestr(filename, ".frag")) {
@@ -277,8 +277,7 @@ int main(int argc, char *argv[])
                 char *buffer = SDL_ShaderCross_TranspileMSLFromSPIRV(
                     fileData,
                     fileSize,
-                    entrypointName,
-                    shaderStage);
+                    entrypointName);
                 SDL_IOprintf(outputIO, "%s", buffer);
                 SDL_free(buffer);
                 break;
@@ -388,8 +387,7 @@ int main(int argc, char *argv[])
                 char *buffer = SDL_ShaderCross_TranspileMSLFromSPIRV(
                     spirv,
                     bytecodeSize,
-                    entrypointName,
-                    shaderStage);
+                    entrypointName);
                 SDL_IOprintf(outputIO, "%s", buffer);
                 SDL_free(spirv);
                 SDL_free(buffer);