SDL: Use RoInitialize/RoUninitialize for Windows.Gaming.Input

From 8ebef12d31bff35aec6ac659ae7d2d8fcb6ea5b0 Mon Sep 17 00:00:00 2001
From: Sam Lantinga <[EMAIL REDACTED]>
Date: Fri, 1 Apr 2022 14:58:33 -0700
Subject: [PATCH] Use RoInitialize/RoUninitialize for Windows.Gaming.Input

Thanks @walbourn!

Fixes https://github.com/libsdl-org/SDL/issues/5270
---
 src/core/windows/SDL_windows.c                | 94 ++++++++++---------
 src/core/windows/SDL_windows.h                |  7 ++
 src/joystick/windows/SDL_rawinputjoystick.c   | 14 +--
 .../windows/SDL_windows_gaming_input.c        | 33 ++-----
 4 files changed, 69 insertions(+), 79 deletions(-)

diff --git a/src/core/windows/SDL_windows.c b/src/core/windows/SDL_windows.c
index fcbf8e847b3..767fa13c407 100644
--- a/src/core/windows/SDL_windows.c
+++ b/src/core/windows/SDL_windows.c
@@ -25,7 +25,8 @@
 #include "SDL_windows.h"
 #include "SDL_error.h"
 
-#include <objbase.h>  /* for CoInitialize/CoUninitialize (Win32 only) */
+#include <objbase.h>    /* for CoInitialize/CoUninitialize (Win32 only) */
+#include <roapi.h>      /* For RoInitialize/RoUninitialize (Win32 only) */
 
 #ifndef _WIN32_WINNT_VISTA
 #define _WIN32_WINNT_VISTA  0x0600
@@ -104,51 +105,52 @@ void
 WIN_CoUninitialize(void)
 {
 #ifndef __WINRT__
-    /* Don't uninitialize COM because of what appears to be a bug in Microsoft WGI reference counting.
-     *
-     * If you plug in a non-Xbox controller and let the application run for 30 seconds, then it crashes in CoUninitialize()
-     * with this stack trace:
-
-        Windows.Gaming.Input.dll!GameController::~GameController(void)	Unknown
-        Windows.Gaming.Input.dll!GameController::`vector deleting destructor'(unsigned int)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::Details::RuntimeClassImpl<struct Microsoft::WRL::RuntimeClassFlags<1>,1,1,0,struct Windows::Gaming::Input::IGameController,struct Windows::Gaming::Input::IGameControllerBatteryInfo,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Internal::IGameControllerPrivate>,class Microsoft::WRL::FtmBase>::Release(void)	Unknown
-        Windows.Gaming.Input.dll!Windows::Gaming::Input::Custom::Details::AggregableRuntimeClass<struct Windows::Gaming::Input::IGamepad,struct Windows::Gaming::Input::IGamepad2,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Custom::IGameControllerInputSink>,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Custom::IGipGameControllerInputSink>,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Custom::IHidGameControllerInputSink>,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Custom::IXusbGameControllerInputSink>,class Microsoft::WRL::Details::Nil,class Microsoft::WRL::Details::Nil,class Microsoft::WRL::Details::Nil>::Release(void)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::ComPtr<`WaitForCompletion<Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<Windows::Storage::Streams::IBuffer *,unsigned int>,Windows::Foundation::IAsyncOperationWithProgress<Windows::Storage::Streams::IBuffer *,unsigned int>>'::`2'::FTMEventDelegate>::~ComPtr<`WaitForCompletion<Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<Windows::Storage::Streams::IBuffer *,unsigned int>,Windows::Foundation::IAsyncOperationWithProgress<Windows::Storage::Streams::IBuffer *,unsigned int>>'::`2'::FTMEventDelegate>()	Unknown
-        Windows.Gaming.Input.dll!`eh vector destructor iterator'(void *,unsigned int,int,void (*)(void *))	Unknown
-        Windows.Gaming.Input.dll!Windows::Gaming::Input::Custom::Details::GameControllerCollection<class Windows::Gaming::Input::RawGameController,struct Windows::Gaming::Input::IRawGameController>::~GameControllerCollection<class Windows::Gaming::Input::RawGameController,struct Windows::Gaming::Input::IRawGameController>(void)	Unknown
-        Windows.Gaming.Input.dll!Windows::Gaming::Input::Custom::Details::GameControllerCollection<class Windows::Gaming::Input::RawGameController,struct Windows::Gaming::Input::IRawGameController>::`vector deleting destructor'(unsigned int)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::Details::RuntimeClassImpl<struct Microsoft::WRL::RuntimeClassFlags<1>,1,1,0,struct Windows::Foundation::Collections::IIterable<class Windows::Gaming::Input::ArcadeStick *>,struct Windows::Foundation::Collections::IVectorView<class Windows::Gaming::Input::ArcadeStick *>,class Microsoft::WRL::FtmBase>::Release(void)	Unknown
-        Windows.Gaming.Input.dll!Windows::Gaming::Input::Custom::Details::CustomGameControllerFactoryBase<class Windows::Gaming::Input::FlightStick,class Windows::Gaming::Input::FlightStick,struct Windows::Gaming::Input::IFlightStick,struct Windows::Gaming::Input::IFlightStickStatics,class Microsoft::WRL::Details::Nil>::~CustomGameControllerFactoryBase<class Windows::Gaming::Input::FlightStick,class Windows::Gaming::Input::FlightStick,struct Windows::Gaming::Input::IFlightStick,struct Windows::Gaming::Input::IFlightStickStatics,class Microsoft::WRL::Details::Nil>(void)	Unknown
-        Windows.Gaming.Input.dll!Windows::Gaming::Input::Custom::Details::CustomGameControllerFactoryBase<class Windows::Gaming::Input::FlightStick,class Windows::Gaming::Input::FlightStick,struct Windows::Gaming::Input::IFlightStick,struct Windows::Gaming::Input::IFlightStickStatics,class Microsoft::WRL::Details::Nil>::`vector deleting destructor'(unsigned int)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::ActivationFactory<struct Microsoft::WRL::Implements<class Microsoft::WRL::FtmBase,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Custom::ICustomGameControllerFactory> >,struct Windows::Gaming::Input::IFlightStickStatics,class Microsoft::WRL::Details::Nil,0>::Release(void)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::ComPtr<`WaitForCompletion<Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<Windows::Storage::Streams::IBuffer *,unsigned int>,Windows::Foundation::IAsyncOperationWithProgress<Windows::Storage::Streams::IBuffer *,unsigned int>>'::`2'::FTMEventDelegate>::~ComPtr<`WaitForCompletion<Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<Windows::Storage::Streams::IBuffer *,unsigned int>,Windows::Foundation::IAsyncOperationWithProgress<Windows::Storage::Streams::IBuffer *,unsigned int>>'::`2'::FTMEventDelegate>()	Unknown
-        Windows.Gaming.Input.dll!NtList<struct FactoryManager::FactoryListEntry>::~NtList<struct FactoryManager::FactoryListEntry>(void)	Unknown
-        Windows.Gaming.Input.dll!FactoryManager::`vector deleting destructor'(unsigned int)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::ActivationFactory<struct Microsoft::WRL::Implements<class Microsoft::WRL::FtmBase,struct Windows::Gaming::Input::Custom::IGameControllerFactoryManagerStatics>,struct Windows::Gaming::Input::Custom::IGameControllerFactoryManagerStatics2,struct Microsoft::WRL::CloakedIid<struct Windows::Gaming::Input::Internal::IGameControllerFactoryManagerStaticsPrivate>,0>::Release(void)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::Details::TerminateMap(class Microsoft::WRL::Details::ModuleBase *,unsigned short const *,bool)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::Module<1,class Microsoft::WRL::Details::DefaultModule<1> >::~Module<1,class Microsoft::WRL::Details::DefaultModule<1> >(void)	Unknown
-        Windows.Gaming.Input.dll!Microsoft::WRL::Details::DefaultModule<1>::`vector deleting destructor'(unsigned int)	Unknown
-        Windows.Gaming.Input.dll!`dynamic atexit destructor for 'Microsoft::WRL::Details::StaticStorage<Microsoft::WRL::Details::DefaultModule<1>,0,int>::instance_''()	Unknown
-        Windows.Gaming.Input.dll!__CRT_INIT@12()	Unknown
-        Windows.Gaming.Input.dll!__DllMainCRTStartup()	Unknown
-        ntdll.dll!_LdrxCallInitRoutine@16()	Unknown
-        ntdll.dll!LdrpCallInitRoutine()	Unknown
-        ntdll.dll!LdrpProcessDetachNode()	Unknown
-        ntdll.dll!LdrpUnloadNode()	Unknown
-        ntdll.dll!LdrpDecrementModuleLoadCountEx()	Unknown
-        ntdll.dll!LdrUnloadDll()	Unknown
-        KernelBase.dll!FreeLibrary()	Unknown
-        combase.dll!FreeLibraryWithLogging(LoadOrFreeWhy why, HINSTANCE__ * hMod, const wchar_t * pswzOptionalFileName) Line 193	C++
-        combase.dll!CClassCache::CDllPathEntry::CFinishObject::Finish() Line 3311	C++
-        combase.dll!CClassCache::CFinishComposite::Finish() Line 3421	C++
-        combase.dll!CClassCache::CleanUpDllsForProcess() Line 7009	C++
-        [Inline Frame] combase.dll!CCCleanUpDllsForProcess() Line 8773	C++
-        combase.dll!ProcessUninitialize() Line 2243	C++
-        combase.dll!DecrementProcessInitializeCount() Line 993	C++
-        combase.dll!wCoUninitialize(COleTls & Tls, int fHostThread) Line 4126	C++
-        combase.dll!CoUninitialize() Line 3945	C++
-    */
-    /*CoUninitialize();*/
+    CoUninitialize();
+#endif
+}
+
+void *
+WIN_LoadComBaseFunction(const char *name)
+{
+    static SDL_bool s_bLoaded;   /* set after the first load attempt, even a failed one, so the load is never retried */
+    static HMODULE s_hComBase;   /* cached combase.dll handle; this function never frees it */
+
+    if (!s_bLoaded) {
+        s_hComBase = LoadLibraryEx(TEXT("combase.dll"), NULL, LOAD_LIBRARY_SEARCH_SYSTEM32);
+        s_bLoaded = SDL_TRUE;
+    }
+    if (s_hComBase) {
+        return (void *)GetProcAddress(s_hComBase, name);   /* NULL when the export is missing */
+    } else {
+        return NULL;
+    }
+}
+
+HRESULT
+WIN_RoInitialize(void)
+{
+#ifdef __WINRT__
+    return S_OK;
+#else
+    typedef HRESULT (WINAPI *RoInitialize_t)(RO_INIT_TYPE initType);   /* RoInitialize is a WINAPI (stdcall) export */
+    RoInitialize_t RoInitializeFunc = (RoInitialize_t)WIN_LoadComBaseFunction("RoInitialize");
+    if (RoInitializeFunc) {
+        return RoInitializeFunc(RO_INIT_MULTITHREADED);
+    } else {
+        return E_NOINTERFACE;   /* combase.dll or its RoInitialize export could not be loaded */
+    }
+#endif
+}
+
+void
+WIN_RoUninitialize(void)
+{
+#ifndef __WINRT__
+    typedef void (WINAPI *RoUninitialize_t)(void);   /* RoUninitialize is a WINAPI (stdcall) export */
+    RoUninitialize_t RoUninitializeFunc = (RoUninitialize_t)WIN_LoadComBaseFunction("RoUninitialize");
+    if (RoUninitializeFunc) {
+        RoUninitializeFunc();
+    }
 #endif
 }
 
diff --git a/src/core/windows/SDL_windows.h b/src/core/windows/SDL_windows.h
index 221c3bd87a7..640769da13c 100644
--- a/src/core/windows/SDL_windows.h
+++ b/src/core/windows/SDL_windows.h
@@ -63,10 +63,17 @@ extern int WIN_SetErrorFromHRESULT(const char *prefix, HRESULT hr);
 /* Sets an error message based on GetLastError(). Always return -1. */
 extern int WIN_SetError(const char *prefix);
 
+/* Load a function from combase.dll. Returns NULL if the DLL or the function can't be found. */
+extern void *WIN_LoadComBaseFunction(const char *name);
+
 /* Wrap up the oddities of CoInitialize() into a common function. */
 extern HRESULT WIN_CoInitialize(void);
 extern void WIN_CoUninitialize(void);
 
+/* Wrap up the oddities of RoInitialize() into a common function. */
+extern HRESULT WIN_RoInitialize(void);
+extern void WIN_RoUninitialize(void);
+
 /* Returns SDL_TRUE if we're running on Windows Vista and newer */
 extern BOOL WIN_IsWindowsVistaOrGreater(void);
 
diff --git a/src/joystick/windows/SDL_rawinputjoystick.c b/src/joystick/windows/SDL_rawinputjoystick.c
index 32b63782735..8c96ae63676 100644
--- a/src/joystick/windows/SDL_rawinputjoystick.c
+++ b/src/joystick/windows/SDL_rawinputjoystick.c
@@ -565,22 +565,19 @@ RAWINPUT_InitWindowsGamingInput(RAWINPUT_DeviceContext *ctx)
     if (!wgi_state.initialized) {
         static const IID SDL_IID_IGamepadStatics = { 0x8BBCE529, 0xD49C, 0x39E9, { 0x95, 0x60, 0xE4, 0x7D, 0xDE, 0x96, 0xB7, 0xC8 } };
         HRESULT hr;
-        HMODULE hModule;
 
-        /* I think this takes care of RoInitialize() in a way that is compatible with the rest of SDL */
-        if (FAILED(WIN_CoInitialize())) {
+        if (FAILED(WIN_RoInitialize())) {
             return;
         }
         wgi_state.initialized = SDL_TRUE;
         wgi_state.dirty = SDL_TRUE;
 
-        hModule = LoadLibraryA("combase.dll");
-        if (hModule != NULL) {
+        {
             typedef HRESULT (WINAPI *WindowsCreateStringReference_t)(PCWSTR sourceString, UINT32 length, HSTRING_HEADER *hstringHeader, HSTRING* string);
             typedef HRESULT (WINAPI *RoGetActivationFactory_t)(HSTRING activatableClassId, REFIID iid, void** factory);
 
-            WindowsCreateStringReference_t WindowsCreateStringReferenceFunc = (WindowsCreateStringReference_t)GetProcAddress(hModule, "WindowsCreateStringReference");
-            RoGetActivationFactory_t RoGetActivationFactoryFunc = (RoGetActivationFactory_t)GetProcAddress(hModule, "RoGetActivationFactory");
+            WindowsCreateStringReference_t WindowsCreateStringReferenceFunc = (WindowsCreateStringReference_t)WIN_LoadComBaseFunction("WindowsCreateStringReference");
+            RoGetActivationFactory_t RoGetActivationFactoryFunc = (RoGetActivationFactory_t)WIN_LoadComBaseFunction("RoGetActivationFactory");
             if (WindowsCreateStringReferenceFunc && RoGetActivationFactoryFunc) {
                 PCWSTR pNamespace = L"Windows.Gaming.Input.Gamepad";
                 HSTRING_HEADER hNamespaceStringHeader;
@@ -591,7 +588,6 @@ RAWINPUT_InitWindowsGamingInput(RAWINPUT_DeviceContext *ctx)
                     RoGetActivationFactoryFunc(hNamespaceString, &SDL_IID_IGamepadStatics, (void **)&wgi_state.gamepad_statics);
                 }
             }
-            FreeLibrary(hModule);
         }
     }
 }
@@ -657,7 +653,7 @@ RAWINPUT_QuitWindowsGamingInput(RAWINPUT_DeviceContext *ctx)
             __x_ABI_CWindows_CGaming_CInput_CIGamepadStatics_Release(wgi_state.gamepad_statics);
             wgi_state.gamepad_statics = NULL;
         }
-        WIN_CoUninitialize();
+        WIN_RoUninitialize();
         wgi_state.initialized = SDL_FALSE;
     }
 }
diff --git a/src/joystick/windows/SDL_windows_gaming_input.c b/src/joystick/windows/SDL_windows_gaming_input.c
index a731dad6341..5300cfa5bc0 100644
--- a/src/joystick/windows/SDL_windows_gaming_input.c
+++ b/src/joystick/windows/SDL_windows_gaming_input.c
@@ -260,10 +260,9 @@ static HRESULT STDMETHODCALLTYPE IEventHandler_CRawGameControllerVtbl_InvokeAdde
             WindowsGetStringRawBufferFunc = WindowsGetStringRawBuffer;
             WindowsDeleteStringFunc = WindowsDeleteString;
 #else
-            HMODULE hModule = LoadLibraryA("combase.dll");
-            if (hModule != NULL) {
-                WindowsGetStringRawBufferFunc = (WindowsGetStringRawBuffer_t)GetProcAddress(hModule, "WindowsGetStringRawBuffer");
-                WindowsDeleteStringFunc = (WindowsDeleteString_t)GetProcAddress(hModule, "WindowsDeleteString");
+            {
+                WindowsGetStringRawBufferFunc = (WindowsGetStringRawBuffer_t)WIN_LoadComBaseFunction("WindowsGetStringRawBuffer");
+                WindowsDeleteStringFunc = (WindowsDeleteString_t)WIN_LoadComBaseFunction("WindowsDeleteString");
             }
 #endif /* __WINRT__ */
             if (WindowsGetStringRawBufferFunc && WindowsDeleteStringFunc) {
@@ -277,11 +276,6 @@ static HRESULT STDMETHODCALLTYPE IEventHandler_CRawGameControllerVtbl_InvokeAdde
                     WindowsDeleteStringFunc(hString);
                 }
             }
-#ifndef __WINRT__
-            if (hModule != NULL) {
-                FreeLibrary(hModule);
-            }
-#endif
             __x_ABI_CWindows_CGaming_CInput_CIRawGameController2_Release(controller2);
         }
         if (!name) {
@@ -444,23 +438,19 @@ WGI_JoystickInit(void)
 
     WindowsCreateStringReference_t WindowsCreateStringReferenceFunc = NULL;
     RoGetActivationFactory_t RoGetActivationFactoryFunc = NULL;
-#ifndef __WINRT__
-    HMODULE hModule;
-#endif
     HRESULT hr;
 
-    if (FAILED(WIN_CoInitialize())) {
-        return SDL_SetError("CoInitialize() failed");
+    if (FAILED(WIN_RoInitialize())) {
+        return SDL_SetError("RoInitialize() failed");
     }
 
 #ifdef __WINRT__
     WindowsCreateStringReferenceFunc = WindowsCreateStringReference;
     RoGetActivationFactoryFunc = RoGetActivationFactory;
 #else
-    hModule = LoadLibraryA("combase.dll");
-    if (hModule != NULL) {
-        WindowsCreateStringReferenceFunc = (WindowsCreateStringReference_t)GetProcAddress(hModule, "WindowsCreateStringReference");
-        RoGetActivationFactoryFunc = (RoGetActivationFactory_t)GetProcAddress(hModule, "RoGetActivationFactory");
+    {
+        WindowsCreateStringReferenceFunc = (WindowsCreateStringReference_t)WIN_LoadComBaseFunction("WindowsCreateStringReference");
+        RoGetActivationFactoryFunc = (RoGetActivationFactory_t)WIN_LoadComBaseFunction("RoGetActivationFactory");
     }
 #endif /* __WINRT__ */
     if (WindowsCreateStringReferenceFunc && RoGetActivationFactoryFunc) {
@@ -519,11 +509,6 @@ WGI_JoystickInit(void)
             }
         }
     }
-#ifndef __WINRT__
-    if (hModule != NULL) {
-        FreeLibrary(hModule);
-    }
-#endif
 
     if (wgi.statics) {
         __FIVectorView_1_Windows__CGaming__CInput__CRawGameController *controllers;
@@ -865,7 +850,7 @@ WGI_JoystickQuit(void)
     }
     SDL_zero(wgi);
 
-    WIN_CoUninitialize();
+    WIN_RoUninitialize();
 }
 
 static SDL_bool