SDL: wasapi: Handle disconnect notifications from the management thread, too.

From 468c386686010c127ba2f06ef251682a294614e8 Mon Sep 17 00:00:00 2001
From: "Ryan C. Gordon" <[EMAIL REDACTED]>
Date: Fri, 27 Oct 2023 01:28:51 -0400
Subject: [PATCH] wasapi: Handle disconnect notifications from the management
 thread, too.

These are also pretty heavyweight; don't do them from the notification
thread, where they can deadlock everything.
---
 src/audio/wasapi/SDL_wasapi_win32.c | 15 ++++++++++++++-
 src/core/windows/SDL_immdevice.c    | 25 +++++++++++++++++--------
 src/core/windows/SDL_immdevice.h    |  8 ++++++--
 3 files changed, 37 insertions(+), 11 deletions(-)

diff --git a/src/audio/wasapi/SDL_wasapi_win32.c b/src/audio/wasapi/SDL_wasapi_win32.c
index 3021a6fb042c..db8397c41667 100644
--- a/src/audio/wasapi/SDL_wasapi_win32.c
+++ b/src/audio/wasapi/SDL_wasapi_win32.c
@@ -48,6 +48,18 @@ static SDL_bool immdevice_initialized = SDL_FALSE;
 // Some GUIDs we need to know without linking to libraries that aren't available before Vista.
 static const IID SDL_IID_IAudioClient = { 0x1cb9ad4c, 0xdbfa, 0x4c32, { 0xb1, 0x78, 0xc2, 0xf5, 0x68, 0xa7, 0x03, 0xb2 } };
 
+static int mgmtthrtask_AudioDeviceDisconnected(void *userdata)
+{
+    SDL_AudioDeviceDisconnected((SDL_AudioDevice *) userdata);
+    return 0;
+}
+
+static void WASAPI_AudioDeviceDisconnected(SDL_AudioDevice *device)
+{
+    // don't wait on this: the IMMDevice notification thread that called us must return promptly, or everything can deadlock.
+    WASAPI_ProxyToManagementThread(mgmtthrtask_AudioDeviceDisconnected, device, NULL);
+}
+
 static int mgmtthrtask_DefaultAudioDeviceChanged(void *userdata)
 {
     SDL_DefaultAudioDeviceChanged((SDL_AudioDevice *) userdata);
@@ -62,9 +74,10 @@ static void WASAPI_DefaultAudioDeviceChanged(SDL_AudioDevice *new_default_device
 
 int WASAPI_PlatformInit(void)
 {
+    const SDL_IMMDevice_callbacks callbacks = { .audio_device_disconnected = WASAPI_AudioDeviceDisconnected, .default_audio_device_changed = WASAPI_DefaultAudioDeviceChanged };
     if (FAILED(WIN_CoInitialize())) {
         return SDL_SetError("CoInitialize() failed");
-    } else if (SDL_IMMDevice_Init(WASAPI_DefaultAudioDeviceChanged) < 0) {
+    } else if (SDL_IMMDevice_Init(&callbacks) < 0) {
         return -1; // Error string is set by SDL_IMMDevice_Init
     }
 
diff --git a/src/core/windows/SDL_immdevice.c b/src/core/windows/SDL_immdevice.c
index 583c71ff2ef4..3ed2d5f2dc88 100644
--- a/src/core/windows/SDL_immdevice.c
+++ b/src/core/windows/SDL_immdevice.c
@@ -37,7 +37,7 @@ static const ERole SDL_IMMDevice_role = eConsole; /* !!! FIXME: should this be e
 
 /* This is global to the WASAPI target, to handle hotplug and default device lookup. */
 static IMMDeviceEnumerator *enumerator = NULL;
-static SDL_IMMDevice_DefaultAudioDeviceChanged devchangecallback = NULL;
+static SDL_IMMDevice_callbacks immcallbacks; /* copied and defaulted by SDL_IMMDevice_Init, zeroed by SDL_IMMDevice_Quit */
 
 /* PropVariantInit() is an inline function/macro in PropIdl.h that calls the C runtime's memset() directly. Use ours instead, to avoid dependency. */
 #ifdef PropVariantInit
@@ -205,9 +205,9 @@ static ULONG STDMETHODCALLTYPE SDLMMNotificationClient_Release(IMMNotificationCl
 static HRESULT STDMETHODCALLTYPE SDLMMNotificationClient_OnDefaultDeviceChanged(IMMNotificationClient *iclient, EDataFlow flow, ERole role, LPCWSTR pwstrDeviceId)
 {
     if (role == SDL_IMMDevice_role) {
-        if (devchangecallback) {
-            devchangecallback(SDL_IMMDevice_FindByDevID(pwstrDeviceId));
-        }
+        if (immcallbacks.default_audio_device_changed) { /* NULL once SDL_IMMDevice_Quit has zeroed immcallbacks */
+            immcallbacks.default_audio_device_changed(SDL_IMMDevice_FindByDevID(pwstrDeviceId));
+        }
     }
     return S_OK;
 }
@@ -247,7 +245,9 @@ static HRESULT STDMETHODCALLTYPE SDLMMNotificationClient_OnDeviceStateChanged(IM
                         SDL_free(utf8dev);
                     }
                 } else {
-                    SDL_AudioDeviceDisconnected(SDL_IMMDevice_FindByDevID(pwstrDeviceId));
+                    if (immcallbacks.audio_device_disconnected) { /* NULL once SDL_IMMDevice_Quit has zeroed immcallbacks */
+                        immcallbacks.audio_device_disconnected(SDL_IMMDevice_FindByDevID(pwstrDeviceId));
+                    }
                 }
             }
             IMMEndpoint_Release(endpoint);
@@ -276,7 +274,7 @@ static const IMMNotificationClientVtbl notification_client_vtbl = {
 
 static SDLMMNotificationClient notification_client = { &notification_client_vtbl, { 1 } };
 
-int SDL_IMMDevice_Init(SDL_IMMDevice_DefaultAudioDeviceChanged devchanged)
+int SDL_IMMDevice_Init(const SDL_IMMDevice_callbacks *callbacks)
 {
     HRESULT ret;
 
@@ -295,7 +293,18 @@ int SDL_IMMDevice_Init(SDL_IMMDevice_DefaultAudioDeviceChanged devchanged)
         return WIN_SetErrorFromHRESULT("IMMDevice CoCreateInstance(MMDeviceEnumerator)", ret);
     }
 
-    devchangecallback = devchanged ? devchanged : SDL_DefaultAudioDeviceChanged;
+    SDL_zero(immcallbacks);
+    if (callbacks != NULL) {
+        immcallbacks = *callbacks; /* keep our own copy; the caller's struct may be a short-lived local */
+    }
+
+    /* Any handler the caller left NULL falls back to the audio subsystem's default. */
+    if (immcallbacks.audio_device_disconnected == NULL) {
+        immcallbacks.audio_device_disconnected = SDL_AudioDeviceDisconnected;
+    }
+    if (immcallbacks.default_audio_device_changed == NULL) {
+        immcallbacks.default_audio_device_changed = SDL_DefaultAudioDeviceChanged;
+    }
 
     return 0;
 }
@@ -308,7 +317,7 @@ void SDL_IMMDevice_Quit(void)
         enumerator = NULL;
     }
 
-    devchangecallback = NULL;
+    SDL_zero(immcallbacks); /* forget the handlers installed by SDL_IMMDevice_Init */
 
     WIN_CoUninitialize();
 }
diff --git a/src/core/windows/SDL_immdevice.h b/src/core/windows/SDL_immdevice.h
index 0b64d2f5566d..5ed0c5ddc8b6 100644
--- a/src/core/windows/SDL_immdevice.h
+++ b/src/core/windows/SDL_immdevice.h
@@ -28,9 +28,13 @@
 
 typedef struct SDL_AudioDevice SDL_AudioDevice; // this is defined in src/audio/SDL_sysaudio.h
 
-typedef void (*SDL_IMMDevice_DefaultAudioDeviceChanged)(SDL_AudioDevice *new_default_device);
+typedef struct SDL_IMMDevice_callbacks // handlers for IMMDevice notifications, passed to SDL_IMMDevice_Init
+{
+    void (*audio_device_disconnected)(SDL_AudioDevice *device); // NULL: SDL_AudioDeviceDisconnected is used
+    void (*default_audio_device_changed)(SDL_AudioDevice *new_default_device); // NULL: SDL_DefaultAudioDeviceChanged is used
+} SDL_IMMDevice_callbacks;
 
-int SDL_IMMDevice_Init(SDL_IMMDevice_DefaultAudioDeviceChanged devchanged);
+int SDL_IMMDevice_Init(const SDL_IMMDevice_callbacks *callbacks); // callbacks may be NULL (all defaults); the struct is copied
 void SDL_IMMDevice_Quit(void);
 int SDL_IMMDevice_Get(SDL_AudioDevice *device, IMMDevice **immdevice, SDL_bool iscapture);
 void SDL_IMMDevice_EnumerateEndpoints(SDL_AudioDevice **default_output, SDL_AudioDevice **default_capture);