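SDL: Do more robust validation of devices passed to the SDL HIDAPI functions

Each HIDDeviceWrapper is now tagged with a magic pointer that the SDL_hid_*
device functions in this patch check before use: a NULL or untagged device
fails an SDL_assert and the call returns with an "Invalid device" error
instead of dereferencing the pointer.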

From f61b10dcf127645c2ce6cab12d2c71d28d1192d6 Mon Sep 17 00:00:00 2001
From: Sam Lantinga <[EMAIL REDACTED]>
Date: Mon, 8 Nov 2021 06:34:32 -0800
Subject: [PATCH] Do more robust validation of devices passed to the SDL HIDAPI
 functions

---
 src/hidapi/SDL_hidapi.c | 79 +++++++++++++++++++++--------------------
 1 file changed, 40 insertions(+), 39 deletions(-)

diff --git a/src/hidapi/SDL_hidapi.c b/src/hidapi/SDL_hidapi.c
index 80a0ee931b..70a53f2998 100644
--- a/src/hidapi/SDL_hidapi.c
+++ b/src/hidapi/SDL_hidapi.c
@@ -423,19 +423,22 @@ static const struct hidapi_backend LIBUSB_Backend = {
 typedef struct _HIDDeviceWrapper HIDDeviceWrapper;
 struct _HIDDeviceWrapper
 {
-    SDL_hid_device *device; /* must be first field */
+    const void *magic; /* &device_magic while the wrapper is live; set to NULL just before it is freed */
+    SDL_hid_device *device;
     const struct hidapi_backend *backend;
 };
+static char device_magic; /* only its address is used, as the wrapper tag */
 
 #if HAVE_PLATFORM_BACKEND || HAVE_DRIVER_BACKEND || defined(SDL_LIBUSB_DYNAMIC)
 
 static HIDDeviceWrapper *
 CreateHIDDeviceWrapper(SDL_hid_device *device, const struct hidapi_backend *backend)
 {
-    HIDDeviceWrapper *ret = (HIDDeviceWrapper *)SDL_malloc(sizeof(*ret));
-    ret->device = device;
-    ret->backend = backend;
-    return ret;
+    HIDDeviceWrapper *wrapper = (HIDDeviceWrapper *)SDL_malloc(sizeof(*wrapper)); /* NOTE(review): result is used without a NULL check */
+    wrapper->magic = &device_magic;
+    wrapper->device = device;
+    wrapper->backend = backend;
+    return wrapper;
 }
 
 static SDL_hid_device *
@@ -455,9 +458,17 @@ UnwrapHIDDevice(SDL_hid_device *device)
 static void
 DeleteHIDDeviceWrapper(HIDDeviceWrapper *device)
 {
+    device->magic = NULL; /* NOTE(review): reading ->magic after SDL_free is undefined behavior, so the magic check cannot reliably catch a closed handle */
     SDL_free(device);
 }
 
+#define CHECK_DEVICE_MAGIC(device, retval) do { /* NULL or untagged wrapper: SDL_assert, set "Invalid device", return retval from the caller */ \
+    SDL_assert((device) && (device)->magic == &device_magic); \
+    if (!(device) || (device)->magic != &device_magic) { \
+        SDL_SetError("Invalid device"); \
+        return retval; \
+    } \
+} while (0)
 #ifndef SDL_DISABLE_HIDAPI
 
 #define COPY_IF_EXISTS(var) \
@@ -848,9 +859,8 @@ int SDL_hid_write(SDL_hid_device *device, const unsigned char *data, size_t leng
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_write(wrapper->device, data, length);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -863,9 +873,8 @@ int SDL_hid_read_timeout(SDL_hid_device *device, unsigned char *data, size_t len
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_read_timeout(wrapper->device, data, length, milliseconds);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -878,9 +887,8 @@ int SDL_hid_read(SDL_hid_device *device, unsigned char *data, size_t length)
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_read(wrapper->device, data, length);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -893,9 +901,8 @@ int SDL_hid_set_nonblocking(SDL_hid_device *device, int nonblock)
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_set_nonblocking(wrapper->device, nonblock);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -908,9 +915,8 @@ int SDL_hid_send_feature_report(SDL_hid_device *device, const unsigned char *dat
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_send_feature_report(wrapper->device, data, length);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -923,9 +929,8 @@ int SDL_hid_get_feature_report(SDL_hid_device *device, unsigned char *data, size
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_get_feature_report(wrapper->device, data, length);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -937,10 +942,10 @@ void SDL_hid_close(SDL_hid_device *device)
 {
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
 
-    if (wrapper) {
-        wrapper->backend->hid_close(wrapper->device);
-        DeleteHIDDeviceWrapper(wrapper);
-    }
+    CHECK_DEVICE_MAGIC(wrapper,); /* NOTE(review): a NULL device now trips SDL_assert and returns early instead of being silently ignored */
+
+    wrapper->backend->hid_close(wrapper->device);
+    DeleteHIDDeviceWrapper(wrapper);
 }
 
 int SDL_hid_get_manufacturer_string(SDL_hid_device *device, wchar_t *string, size_t maxlen)
@@ -948,9 +953,8 @@ int SDL_hid_get_manufacturer_string(SDL_hid_device *device, wchar_t *string, siz
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_get_manufacturer_string(wrapper->device, string, maxlen);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -963,9 +967,8 @@ int SDL_hid_get_product_string(SDL_hid_device *device, wchar_t *string, size_t m
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_get_product_string(wrapper->device, string, maxlen);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -978,9 +981,8 @@ int SDL_hid_get_serial_number_string(SDL_hid_device *device, wchar_t *string, si
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_get_serial_number_string(wrapper->device, string, maxlen);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));
@@ -993,9 +995,8 @@ int SDL_hid_get_indexed_string(SDL_hid_device *device, int string_index, wchar_t
     HIDDeviceWrapper *wrapper = UnwrapHIDDevice(device);
     int result;
 
-    if (!wrapper) {
-        return -1;
-    }
+    CHECK_DEVICE_MAGIC(wrapper, -1); /* may return -1 (invalid device) */
+
     result = wrapper->backend->hid_get_indexed_string(wrapper->device, string_index, string, maxlen);
     if (result < 0) {
         SDL_SetHIDAPIError(wrapper->backend->hid_error(wrapper->device));