SDL: gpu: Rework driver name queries, add GetGPUShaderFormats

From 96e147b2b9528d26a5095fb405e259bbb6484173 Mon Sep 17 00:00:00 2001
From: Ethan Lee <[EMAIL REDACTED]>
Date: Fri, 13 Sep 2024 11:16:43 -0400
Subject: [PATCH] gpu: Rework driver name queries, add GetGPUShaderFormats

---
 include/SDL3/SDL_gpu.h            | 70 +++++++++++++++++++++----------
 src/dynapi/SDL_dynapi.sym         |  3 ++
 src/dynapi/SDL_dynapi_overrides.h |  3 ++
 src/dynapi/SDL_dynapi_procs.h     |  5 ++-
 src/gpu/SDL_gpu.c                 | 57 +++++++++++++++----------
 src/gpu/SDL_sysgpu.h              |  9 ++--
 src/gpu/d3d11/SDL_gpu_d3d11.c     |  1 -
 src/gpu/d3d12/SDL_gpu_d3d12.c     |  1 -
 src/gpu/metal/SDL_gpu_metal.m     |  1 -
 src/gpu/vulkan/SDL_gpu_vulkan.c   |  1 -
 src/render/gpu/SDL_shaders_gpu.c  | 37 +++++++++++-----
 test/testgpu_spinning_cube.c      |  8 ++--
 12 files changed, 129 insertions(+), 67 deletions(-)

diff --git a/include/SDL3/SDL_gpu.h b/include/SDL3/SDL_gpu.h
index 645174a0bcd01..e6d8c9b749791 100644
--- a/include/SDL3/SDL_gpu.h
+++ b/include/SDL3/SDL_gpu.h
@@ -932,23 +932,6 @@ typedef enum SDL_GPUSwapchainComposition
     SDL_GPU_SWAPCHAINCOMPOSITION_HDR10_ST2048
 } SDL_GPUSwapchainComposition;
 
-/**
- * Specifies a backend API supported by SDL_GPU.
- *
- * Only one of these will be in use at a time.
- *
- * \since This enum is available since SDL 3.0.0
- */
-typedef enum SDL_GPUDriver
-{
-    SDL_GPU_DRIVER_INVALID,
-    SDL_GPU_DRIVER_PRIVATE, /* NDA'd platforms */
-    SDL_GPU_DRIVER_VULKAN,
-    SDL_GPU_DRIVER_D3D11,
-    SDL_GPU_DRIVER_D3D12,
-    SDL_GPU_DRIVER_METAL
-} SDL_GPUDriver;
-
 /* Structures */
 
 /**
@@ -1704,7 +1687,8 @@ typedef struct SDL_GPUStorageTextureWriteOnlyBinding
  *
  * \since This function is available since SDL 3.0.0.
  *
- * \sa SDL_GetGPUDriver
+ * \sa SDL_GetGPUShaderFormats
+ * \sa SDL_GetGPUDeviceDriver
  * \sa SDL_DestroyGPUDevice
  */
 extern SDL_DECLSPEC SDL_GPUDevice *SDLCALL SDL_CreateGPUDevice(
@@ -1749,7 +1733,8 @@ extern SDL_DECLSPEC SDL_GPUDevice *SDLCALL SDL_CreateGPUDevice(
  *
  * \since This function is available since SDL 3.0.0.
  *
- * \sa SDL_GetGPUDriver
+ * \sa SDL_GetGPUShaderFormats
+ * \sa SDL_GetGPUDeviceDriver
  * \sa SDL_DestroyGPUDevice
  */
 extern SDL_DECLSPEC SDL_GPUDevice *SDLCALL SDL_CreateGPUDeviceWithProperties(
@@ -1778,14 +1763,55 @@ extern SDL_DECLSPEC SDL_GPUDevice *SDLCALL SDL_CreateGPUDeviceWithProperties(
 extern SDL_DECLSPEC void SDLCALL SDL_DestroyGPUDevice(SDL_GPUDevice *device);
 
 /**
- * Returns the backend used to create this GPU context.
+ * Get the number of GPU drivers compiled into SDL.
+ *
+ * \returns the number of built in GPU drivers.
+ *
+ * \since This function is available since SDL 3.0.0.
+ *
+ * \sa SDL_GetGPUDriver
+ */
+extern SDL_DECLSPEC int SDLCALL SDL_GetNumGPUDrivers(void);
+
+/**
+ * Get the name of a built in GPU driver.
+ *
+ * The GPU drivers are presented in the order in which they are normally
+ * checked during initialization.
+ *
+ * The names of drivers are all simple, low-ASCII identifiers, like "vulkan",
+ * "metal" or "direct3d12". These never have Unicode characters, and are not
+ * meant to be proper names.
+ *
+ * \param index the index of a GPU driver.
+ * \returns the name of the GPU driver with the given **index**, or NULL if **index** is out of range.
+ *
+ * \since This function is available since SDL 3.0.0.
+ *
+ * \sa SDL_GetNumGPUDrivers
+ */
+extern SDL_DECLSPEC const char * SDLCALL SDL_GetGPUDriver(int index);
+
+/**
+ * Returns the name of the backend used to create this GPU context.
+ *
+ * \param device a GPU context to query.
+ * \returns the name of the device's driver, or NULL on error.
+ *
+ * \since This function is available since SDL 3.0.0.
+ */
+extern SDL_DECLSPEC const char * SDLCALL SDL_GetGPUDeviceDriver(SDL_GPUDevice *device);
+
+/**
+ * Returns the supported shader formats for this GPU context.
  *
  * \param device a GPU context to query.
- * \returns an SDL_GPUDriver value, or SDL_GPU_DRIVER_INVALID on error.
+ * \returns a bitflag indicating which shader formats the driver is able to
+ *          consume, or SDL_GPU_SHADERFORMAT_INVALID on error.
  *
  * \since This function is available since SDL 3.0.0.
  */
-extern SDL_DECLSPEC SDL_GPUDriver SDLCALL SDL_GetGPUDriver(SDL_GPUDevice *device);
+extern SDL_DECLSPEC SDL_GPUShaderFormat SDLCALL SDL_GetGPUShaderFormats(SDL_GPUDevice *device);
 
 /* State Creation */
 
diff --git a/src/dynapi/SDL_dynapi.sym b/src/dynapi/SDL_dynapi.sym
index 8b60ab7b75afb..89e3992addd02 100644
--- a/src/dynapi/SDL_dynapi.sym
+++ b/src/dynapi/SDL_dynapi.sym
@@ -276,7 +276,9 @@ SDL3_0.0.0 {
     SDL_GetFullscreenDisplayModes;
     SDL_GetGDKDefaultUser;
     SDL_GetGDKTaskQueue;
+    SDL_GetGPUDeviceDriver;
     SDL_GetGPUDriver;
+    SDL_GetGPUShaderFormats;
     SDL_GetGPUSwapchainTextureFormat;
     SDL_GetGamepadAppleSFSymbolsNameForAxis;
     SDL_GetGamepadAppleSFSymbolsNameForButton;
@@ -392,6 +394,7 @@ SDL3_0.0.0 {
     SDL_GetNumAllocations;
     SDL_GetNumAudioDrivers;
     SDL_GetNumCameraDrivers;
+    SDL_GetNumGPUDrivers;
     SDL_GetNumGamepadTouchpadFingers;
     SDL_GetNumGamepadTouchpads;
     SDL_GetNumHapticAxes;
diff --git a/src/dynapi/SDL_dynapi_overrides.h b/src/dynapi/SDL_dynapi_overrides.h
index 2a25777781f42..9bd82e2026be8 100644
--- a/src/dynapi/SDL_dynapi_overrides.h
+++ b/src/dynapi/SDL_dynapi_overrides.h
@@ -301,7 +301,9 @@
 #define SDL_GetFullscreenDisplayModes SDL_GetFullscreenDisplayModes_REAL
 #define SDL_GetGDKDefaultUser SDL_GetGDKDefaultUser_REAL
 #define SDL_GetGDKTaskQueue SDL_GetGDKTaskQueue_REAL
+#define SDL_GetGPUDeviceDriver SDL_GetGPUDeviceDriver_REAL
 #define SDL_GetGPUDriver SDL_GetGPUDriver_REAL
+#define SDL_GetGPUShaderFormats SDL_GetGPUShaderFormats_REAL
 #define SDL_GetGPUSwapchainTextureFormat SDL_GetGPUSwapchainTextureFormat_REAL
 #define SDL_GetGamepadAppleSFSymbolsNameForAxis SDL_GetGamepadAppleSFSymbolsNameForAxis_REAL
 #define SDL_GetGamepadAppleSFSymbolsNameForButton SDL_GetGamepadAppleSFSymbolsNameForButton_REAL
@@ -417,6 +419,7 @@
 #define SDL_GetNumAllocations SDL_GetNumAllocations_REAL
 #define SDL_GetNumAudioDrivers SDL_GetNumAudioDrivers_REAL
 #define SDL_GetNumCameraDrivers SDL_GetNumCameraDrivers_REAL
+#define SDL_GetNumGPUDrivers SDL_GetNumGPUDrivers_REAL
 #define SDL_GetNumGamepadTouchpadFingers SDL_GetNumGamepadTouchpadFingers_REAL
 #define SDL_GetNumGamepadTouchpads SDL_GetNumGamepadTouchpads_REAL
 #define SDL_GetNumHapticAxes SDL_GetNumHapticAxes_REAL
diff --git a/src/dynapi/SDL_dynapi_procs.h b/src/dynapi/SDL_dynapi_procs.h
index b30fe83f70722..0ad90cfb85d16 100644
--- a/src/dynapi/SDL_dynapi_procs.h
+++ b/src/dynapi/SDL_dynapi_procs.h
@@ -321,7 +321,9 @@ SDL_DYNAPI_PROC(float,SDL_GetFloatProperty,(SDL_PropertiesID a, const char *b, f
 SDL_DYNAPI_PROC(SDL_DisplayMode**,SDL_GetFullscreenDisplayModes,(SDL_DisplayID a, int *b),(a,b),return)
 SDL_DYNAPI_PROC(SDL_bool,SDL_GetGDKDefaultUser,(XUserHandle *a),(a),return)
 SDL_DYNAPI_PROC(SDL_bool,SDL_GetGDKTaskQueue,(XTaskQueueHandle *a),(a),return)
-SDL_DYNAPI_PROC(SDL_GPUDriver,SDL_GetGPUDriver,(SDL_GPUDevice *a),(a),return)
+SDL_DYNAPI_PROC(const char*,SDL_GetGPUDeviceDriver,(SDL_GPUDevice *a),(a),return)
+SDL_DYNAPI_PROC(const char*,SDL_GetGPUDriver,(int a),(a),return)
+SDL_DYNAPI_PROC(SDL_GPUShaderFormat,SDL_GetGPUShaderFormats,(SDL_GPUDevice *a),(a),return)
 SDL_DYNAPI_PROC(SDL_GPUTextureFormat,SDL_GetGPUSwapchainTextureFormat,(SDL_GPUDevice *a, SDL_Window *b),(a,b),return)
 SDL_DYNAPI_PROC(const char*,SDL_GetGamepadAppleSFSymbolsNameForAxis,(SDL_Gamepad *a, SDL_GamepadAxis b),(a,b),return)
 SDL_DYNAPI_PROC(const char*,SDL_GetGamepadAppleSFSymbolsNameForButton,(SDL_Gamepad *a, SDL_GamepadButton b),(a,b),return)
@@ -437,6 +439,7 @@ SDL_DYNAPI_PROC(SDL_DisplayOrientation,SDL_GetNaturalDisplayOrientation,(SDL_Dis
 SDL_DYNAPI_PROC(int,SDL_GetNumAllocations,(void),(),return)
 SDL_DYNAPI_PROC(int,SDL_GetNumAudioDrivers,(void),(),return)
 SDL_DYNAPI_PROC(int,SDL_GetNumCameraDrivers,(void),(),return)
+SDL_DYNAPI_PROC(int,SDL_GetNumGPUDrivers,(void),(),return)
 SDL_DYNAPI_PROC(int,SDL_GetNumGamepadTouchpadFingers,(SDL_Gamepad *a, int b),(a,b),return)
 SDL_DYNAPI_PROC(int,SDL_GetNumGamepadTouchpads,(SDL_Gamepad *a),(a),return)
 SDL_DYNAPI_PROC(int,SDL_GetNumHapticAxes,(SDL_Haptic *a),(a),return)
diff --git a/src/gpu/SDL_gpu.c b/src/gpu/SDL_gpu.c
index 97aeb81f7edd6..7a40303449afe 100644
--- a/src/gpu/SDL_gpu.c
+++ b/src/gpu/SDL_gpu.c
@@ -371,7 +371,7 @@ void SDL_GPU_BlitCommon(
 // Driver Functions
 
 #ifndef SDL_GPU_DISABLED
-static SDL_GPUDriver SDL_GPUSelectBackend(
+static const SDL_GPUBootstrap * SDL_GPUSelectBackend(
     SDL_VideoDevice *_this,
     const char *gpudriver,
     SDL_GPUShaderFormat format_flags)
@@ -384,16 +384,16 @@ static SDL_GPUDriver SDL_GPUSelectBackend(
             if (SDL_strcasecmp(gpudriver, backends[i]->name) == 0) {
                 if (!(backends[i]->shader_formats & format_flags)) {
                     SDL_LogError(SDL_LOG_CATEGORY_GPU, "Required shader format for backend %s not provided!", gpudriver);
-                    return SDL_GPU_DRIVER_INVALID;
+                    return NULL;
                 }
                 if (backends[i]->PrepareDriver(_this)) {
-                    return backends[i]->backendflag;
+                    return backends[i];
                 }
             }
         }
 
         SDL_LogError(SDL_LOG_CATEGORY_GPU, "SDL_HINT_GPU_DRIVER %s unsupported!", gpudriver);
-        return SDL_GPU_DRIVER_INVALID;
+        return NULL;
     }
 
     for (i = 0; backends[i]; i += 1) {
@@ -402,12 +402,12 @@ static SDL_GPUDriver SDL_GPUSelectBackend(
             continue;
         }
         if (backends[i]->PrepareDriver(_this)) {
-            return backends[i]->backendflag;
+            return backends[i];
         }
     }
 
     SDL_LogError(SDL_LOG_CATEGORY_GPU, "No supported SDL_GPU backend found!");
-    return SDL_GPU_DRIVER_INVALID;
+    return NULL;
 }
 #endif // SDL_GPU_DISABLED
 
@@ -455,10 +455,9 @@ SDL_GPUDevice *SDL_CreateGPUDeviceWithProperties(SDL_PropertiesID props)
     bool debug_mode;
     bool preferLowPower;
 
-    int i;
     const char *gpudriver;
     SDL_GPUDevice *result = NULL;
-    SDL_GPUDriver selectedBackend;
+    const SDL_GPUBootstrap *selectedBackend;
     SDL_VideoDevice *_this = SDL_GetVideoDevice();
 
     if (_this == NULL) {
@@ -494,17 +493,12 @@ SDL_GPUDevice *SDL_CreateGPUDeviceWithProperties(SDL_PropertiesID props)
     }
 
     selectedBackend = SDL_GPUSelectBackend(_this, gpudriver, format_flags);
-    if (selectedBackend != SDL_GPU_DRIVER_INVALID) {
-        for (i = 0; backends[i]; i += 1) {
-            if (backends[i]->backendflag == selectedBackend) {
-                result = backends[i]->CreateDevice(debug_mode, preferLowPower, props);
-                if (result != NULL) {
-                    result->backend = backends[i]->backendflag;
-                    result->shader_formats = backends[i]->shader_formats;
-                    result->debug_mode = debug_mode;
-                    break;
-                }
-            }
+    if (selectedBackend != NULL) {
+        result = selectedBackend->CreateDevice(debug_mode, preferLowPower, props);
+        if (result != NULL) {
+            result->backend = selectedBackend->name;
+            result->shader_formats = selectedBackend->shader_formats;
+            result->debug_mode = debug_mode;
         }
     }
     return result;
@@ -521,13 +515,34 @@ void SDL_DestroyGPUDevice(SDL_GPUDevice *device)
     device->DestroyDevice(device);
 }
 
-SDL_GPUDriver SDL_GetGPUDriver(SDL_GPUDevice *device)
+int SDL_GetNumGPUDrivers(void)
+{
+    return SDL_arraysize(backends) - 1;
+}
+
+const char * SDL_GetGPUDriver(int index)
+{
+    if (index < 0 || index >= SDL_GetNumGPUDrivers()) {
+        SDL_InvalidParamError("index");
+        return NULL;
+    }
+    return backends[index]->name;
+}
+
+const char * SDL_GetGPUDeviceDriver(SDL_GPUDevice *device)
 {
-    CHECK_DEVICE_MAGIC(device, SDL_GPU_DRIVER_INVALID);
+    CHECK_DEVICE_MAGIC(device, NULL);
 
     return device->backend;
 }
 
+SDL_GPUShaderFormat SDL_GetGPUShaderFormats(SDL_GPUDevice *device)
+{
+    CHECK_DEVICE_MAGIC(device, SDL_GPU_SHADERFORMAT_INVALID);
+
+    return device->shader_formats;
+}
+
 Uint32 SDL_GPUTextureFormatTexelBlockSize(
     SDL_GPUTextureFormat format)
 {
diff --git a/src/gpu/SDL_sysgpu.h b/src/gpu/SDL_sysgpu.h
index 79eef1964c584..5ae9b95a84767 100644
--- a/src/gpu/SDL_sysgpu.h
+++ b/src/gpu/SDL_sysgpu.h
@@ -693,12 +693,14 @@ struct SDL_GPUDevice
     // Opaque pointer for the Driver
     SDL_GPURenderer *driverData;
 
-    // Store this for SDL_GetGPUDriver()
-    SDL_GPUDriver backend;
+    // Store this for SDL_GetGPUDeviceDriver()
+    const char *backend;
+
+    // Store this for SDL_GetGPUShaderFormats()
+    SDL_GPUShaderFormat shader_formats;
 
     // Store this for SDL_gpu.c's debug layer
     bool debug_mode;
-    SDL_GPUShaderFormat shader_formats;
 };
 
 #define ASSIGN_DRIVER_FUNC(func, name) \
@@ -786,7 +788,6 @@ struct SDL_GPUDevice
 typedef struct SDL_GPUBootstrap
 {
     const char *name;
-    const SDL_GPUDriver backendflag;
     const SDL_GPUShaderFormat shader_formats;
     bool (*PrepareDriver)(SDL_VideoDevice *_this);
     SDL_GPUDevice *(*CreateDevice)(bool debug_mode, bool prefer_low_power, SDL_PropertiesID props);
diff --git a/src/gpu/d3d11/SDL_gpu_d3d11.c b/src/gpu/d3d11/SDL_gpu_d3d11.c
index dac8a20e4a105..ef18ce6aebc1c 100644
--- a/src/gpu/d3d11/SDL_gpu_d3d11.c
+++ b/src/gpu/d3d11/SDL_gpu_d3d11.c
@@ -6432,7 +6432,6 @@ static SDL_GPUDevice *D3D11_CreateDevice(bool debugMode, bool preferLowPower, SD
 
 SDL_GPUBootstrap D3D11Driver = {
     "direct3d11",
-    SDL_GPU_DRIVER_D3D11,
     SDL_GPU_SHADERFORMAT_DXBC,
     D3D11_PrepareDriver,
     D3D11_CreateDevice
diff --git a/src/gpu/d3d12/SDL_gpu_d3d12.c b/src/gpu/d3d12/SDL_gpu_d3d12.c
index 2f71bac227e9f..f3c9e18e9920f 100644
--- a/src/gpu/d3d12/SDL_gpu_d3d12.c
+++ b/src/gpu/d3d12/SDL_gpu_d3d12.c
@@ -8348,7 +8348,6 @@ static SDL_GPUDevice *D3D12_CreateDevice(bool debugMode, bool preferLowPower, SD
 
 SDL_GPUBootstrap D3D12Driver = {
     "direct3d12",
-    SDL_GPU_DRIVER_D3D12,
     SDL_GPU_SHADERFORMAT_DXIL,
     D3D12_PrepareDriver,
     D3D12_CreateDevice
diff --git a/src/gpu/metal/SDL_gpu_metal.m b/src/gpu/metal/SDL_gpu_metal.m
index f2361c82dd3ba..7dd33a21be69d 100644
--- a/src/gpu/metal/SDL_gpu_metal.m
+++ b/src/gpu/metal/SDL_gpu_metal.m
@@ -4120,7 +4120,6 @@ static void METAL_INTERNAL_DestroyBlitResources(
 
 SDL_GPUBootstrap MetalDriver = {
     "metal",
-    SDL_GPU_DRIVER_METAL,
     SDL_GPU_SHADERFORMAT_MSL | SDL_GPU_SHADERFORMAT_METALLIB,
     METAL_PrepareDriver,
     METAL_CreateDevice
diff --git a/src/gpu/vulkan/SDL_gpu_vulkan.c b/src/gpu/vulkan/SDL_gpu_vulkan.c
index 2763ec8b44db9..156c02f241656 100644
--- a/src/gpu/vulkan/SDL_gpu_vulkan.c
+++ b/src/gpu/vulkan/SDL_gpu_vulkan.c
@@ -11889,7 +11889,6 @@ static SDL_GPUDevice *VULKAN_CreateDevice(bool debugMode, bool preferLowPower, S
 
 SDL_GPUBootstrap VulkanDriver = {
     "vulkan",
-    SDL_GPU_DRIVER_VULKAN,
     SDL_GPU_SHADERFORMAT_SPIRV,
     VULKAN_PrepareDriver,
     VULKAN_CreateDevice
diff --git a/src/render/gpu/SDL_shaders_gpu.c b/src/render/gpu/SDL_shaders_gpu.c
index e71b49c2668eb..49be1e47d30b7 100644
--- a/src/render/gpu/SDL_shaders_gpu.c
+++ b/src/render/gpu/SDL_shaders_gpu.c
@@ -150,17 +150,28 @@ static const GPU_ShaderSources frag_shader_sources[NUM_FRAG_SHADERS] = {
 static SDL_GPUShader *CompileShader(const GPU_ShaderSources *sources, SDL_GPUDevice *device, SDL_GPUShaderStage stage)
 {
     const GPU_ShaderModuleSource *sms = NULL;
-    SDL_GPUDriver driver = SDL_GetGPUDriver(device);
+    SDL_GPUShaderFormat formats = SDL_GetGPUShaderFormats(device);
 
-    switch (driver) {
-        // clang-format off
-        IF_VULKAN(  case SDL_GPU_DRIVER_VULKAN: sms = &sources->spirv;  break;)
-        IF_D3D11(   case SDL_GPU_DRIVER_D3D11:  sms = &sources->dxbc50; break;)
-        IF_D3D12(   case SDL_GPU_DRIVER_D3D12:  sms = &sources->dxil60; break;)
-        IF_METAL(   case SDL_GPU_DRIVER_METAL:  sms = &sources->msl;    break;)
-        // clang-format on
-
-    default:
+    if (formats == SDL_GPU_SHADERFORMAT_INVALID) {
+        // SDL_GetGPUShaderFormats already set the error
+        return NULL;
+#if HAVE_SPIRV_SHADERS
+    } else if (formats & SDL_GPU_SHADERFORMAT_SPIRV) {
+        sms = &sources->spirv;
+#endif // HAVE_SPIRV_SHADERS
+#if HAVE_DXBC50_SHADERS
+    } else if (formats & SDL_GPU_SHADERFORMAT_DXBC) {
+        sms = &sources->dxbc50;
+#endif // HAVE_DXBC50_SHADERS
+#if HAVE_DXIL60_SHADERS
+    } else if (formats & SDL_GPU_SHADERFORMAT_DXIL) {
+        sms = &sources->dxil60;
+#endif // HAVE_DXIL60_SHADERS
+#if HAVE_METAL_SHADERS
+    } else if (formats & SDL_GPU_SHADERFORMAT_MSL) {
+        sms = &sources->msl;
+#endif // HAVE_METAL_SHADERS
+    } else {
         SDL_SetError("Unsupported GPU backend");
         return NULL;
     }
@@ -170,7 +181,11 @@ static SDL_GPUShader *CompileShader(const GPU_ShaderSources *sources, SDL_GPUDev
     sci.code_size = sms->code_len;
     sci.format = sms->format;
     // FIXME not sure if this is correct
-    sci.entrypoint = driver == SDL_GPU_DRIVER_METAL ? "main0" : "main";
+    sci.entrypoint =
+#if HAVE_METAL_SHADERS
+        (sms == &sources->msl) ? "main0" :
+#endif // HAVE_METAL_SHADERS
+        "main";
     sci.num_samplers = sources->num_samplers;
     sci.num_uniform_buffers = sources->num_uniform_buffers;
     sci.stage = stage;
diff --git a/test/testgpu_spinning_cube.c b/test/testgpu_spinning_cube.c
index 14e8c631f3240..7c28d19302f09 100644
--- a/test/testgpu_spinning_cube.c
+++ b/test/testgpu_spinning_cube.c
@@ -423,18 +423,18 @@ load_shader(SDL_bool is_vertex)
     createinfo.num_uniform_buffers = is_vertex ? 1 : 0;
     createinfo.props = 0;
 
-    SDL_GPUDriver backend = SDL_GetGPUDriver(gpu_device);
-    if (backend == SDL_GPU_DRIVER_D3D11) {
+    SDL_GPUShaderFormat format = SDL_GetGPUShaderFormats(gpu_device);
+    if (format & SDL_GPU_SHADERFORMAT_DXBC) {
         createinfo.format = SDL_GPU_SHADERFORMAT_DXBC;
         createinfo.code = is_vertex ? D3D11_CubeVert : D3D11_CubeFrag;
         createinfo.code_size = is_vertex ? SDL_arraysize(D3D11_CubeVert) : SDL_arraysize(D3D11_CubeFrag);
         createinfo.entrypoint = is_vertex ? "VSMain" : "PSMain";
-    } else if (backend == SDL_GPU_DRIVER_D3D12) {
+    } else if (format & SDL_GPU_SHADERFORMAT_DXIL) {
         createinfo.format = SDL_GPU_SHADERFORMAT_DXIL;
         createinfo.code = is_vertex ? D3D12_CubeVert : D3D12_CubeFrag;
         createinfo.code_size = is_vertex ? SDL_arraysize(D3D12_CubeVert) : SDL_arraysize(D3D12_CubeFrag);
         createinfo.entrypoint = is_vertex ? "VSMain" : "PSMain";
-    } else if (backend == SDL_GPU_DRIVER_METAL) {
+    } else if (format & SDL_GPU_SHADERFORMAT_METALLIB) {
         createinfo.format = SDL_GPU_SHADERFORMAT_METALLIB;
         createinfo.code = is_vertex ? cube_vert_metallib : cube_frag_metallib;
         createinfo.code_size = is_vertex ? cube_vert_metallib_len : cube_frag_metallib_len;