SDL: SDL_snprintf now returns the number of characters that would have been written (excluding the terminating NUL) had the buffer been large enough, matching C99 snprintf.

From 79b0aae86c9ea97b4ad8eaf385dfc3b1b804656e Mon Sep 17 00:00:00 2001
From: Sam Lantinga <[EMAIL REDACTED]>
Date: Wed, 22 Sep 2021 11:42:10 -0700
Subject: [PATCH] The return value of SDL_snprintf is the number of characters
 that would have been written.

Fixes https://github.com/libsdl-org/SDL/issues/4762
---
 src/stdlib/SDL_string.c      | 184 +++++++++++++++--------------------
 test/testautomation_stdlib.c |  18 ++++
 2 files changed, 95 insertions(+), 107 deletions(-)

diff --git a/src/stdlib/SDL_string.c b/src/stdlib/SDL_string.c
index ec43b1ecaf..76005e2049 100644
--- a/src/stdlib/SDL_string.c
+++ b/src/stdlib/SDL_string.c
@@ -1472,6 +1472,8 @@ int SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *f
     return vsnprintf(text, maxlen, fmt, ap);
 }
 #else
+#define TEXT_AND_LEN_ARGS   (length < maxlen) ? &text[length] : NULL, (length < maxlen) ? (maxlen - length) : 0
+
  /* FIXME: implement more of the format specifiers */
 typedef enum
 {
@@ -1514,25 +1516,27 @@ SDL_PrintString(char *text, size_t maxlen, SDL_FormatInfo *info, const char *str
         filllen = SDL_min(width, maxlen);
         SDL_memset(text, fill, filllen);
         text += filllen;
-        length += filllen;
         maxlen -= filllen;
+        length += width;
     }
 
-    slen = SDL_strlcpy(text, string, maxlen);
-    length += SDL_min(slen, maxlen);
+    SDL_strlcpy(text, string, maxlen);
+    length += sz;
 
     if (info) {
         if (info->precision >= 0 && (size_t)info->precision < sz) {
             slen = (size_t)info->precision;
             if (slen < maxlen) {
-                text[slen] = 0;
-                length -= (sz - slen);
+                text[slen] = '\0';
             }
+            length -= (sz - slen);
         }
-        if (info->force_case == SDL_CASE_LOWER) {
-            SDL_strlwr(text);
-        } else if (info->force_case == SDL_CASE_UPPER) {
-            SDL_strupr(text);
+        if (maxlen > 1) {
+            if (info->force_case == SDL_CASE_LOWER) {
+                SDL_strlwr(text);
+            } else if (info->force_case == SDL_CASE_UPPER) {
+                SDL_strupr(text);
+            }
         }
     }
     return length;
@@ -1585,7 +1589,7 @@ SDL_PrintLong(char *text, size_t maxlen, SDL_FormatInfo *info, long value)
     }
 
     SDL_ltoa(value, p, info ? info->radix : 10);
-    SDL_IntPrecisionAdjust(num, maxlen, info);
+    SDL_IntPrecisionAdjust(num, sizeof(num), info);
     return SDL_PrintString(text, maxlen, info, num);
 }
 
@@ -1595,7 +1599,7 @@ SDL_PrintUnsignedLong(char *text, size_t maxlen, SDL_FormatInfo *info, unsigned
     char num[130];
 
     SDL_ultoa(value, num, info ? info->radix : 10);
-    SDL_IntPrecisionAdjust(num, maxlen, info);
+    SDL_IntPrecisionAdjust(num, sizeof(num), info);
     return SDL_PrintString(text, maxlen, info, num);
 }
 
@@ -1609,7 +1613,7 @@ SDL_PrintLongLong(char *text, size_t maxlen, SDL_FormatInfo *info, Sint64 value)
     }
 
     SDL_lltoa(value, p, info ? info->radix : 10);
-    SDL_IntPrecisionAdjust(num, maxlen, info);
+    SDL_IntPrecisionAdjust(num, sizeof(num), info);
     return SDL_PrintString(text, maxlen, info, num);
 }
 
@@ -1619,126 +1623,92 @@ SDL_PrintUnsignedLongLong(char *text, size_t maxlen, SDL_FormatInfo *info, Uint6
     char num[130];
 
     SDL_ulltoa(value, num, info ? info->radix : 10);
-    SDL_IntPrecisionAdjust(num, maxlen, info);
+    SDL_IntPrecisionAdjust(num, sizeof(num), info);
     return SDL_PrintString(text, maxlen, info, num);
 }
 
 static size_t
 SDL_PrintFloat(char *text, size_t maxlen, SDL_FormatInfo *info, double arg)
 {
-    int width;
-    size_t len;
-    size_t left = maxlen;
-    char *textstart = text;
+    size_t length = 0;
 
     if (arg) {
         /* This isn't especially accurate, but hey, it's easy. :) */
         unsigned long value;
 
         if (arg < 0) {
-            if (left > 1) {
-                *text = '-';
-                --left;
+            if (length < maxlen) {
+                text[length] = '-';
             }
-            ++text;
+            ++length;
             arg = -arg;
         } else if (info->force_sign) {
-            if (left > 1) {
-                *text = '+';
-                --left;
+            if (length < maxlen) {
+                text[length] = '+';
             }
-            ++text;
+            ++length;
         }
         value = (unsigned long) arg;
-        len = SDL_PrintUnsignedLong(text, left, NULL, value);
-        if (len >= left) {
-            text += (left > 1) ? left - 1 : 0;
-            left = SDL_min(left, 1);
-        } else {
-            text += len;
-            left -= len;
-        }
+        length += SDL_PrintUnsignedLong(TEXT_AND_LEN_ARGS, NULL, value);
         arg -= value;
         if (info->precision < 0) {
             info->precision = 6;
         }
         if (info->force_type || info->precision > 0) {
             int mult = 10;
-            if (left > 1) {
-                *text = '.';
-                --left;
+            if (length < maxlen) {
+                text[length] = '.';
             }
-            ++text;
+            ++length;
             while (info->precision-- > 0) {
                 value = (unsigned long) (arg * mult);
-                len = SDL_PrintUnsignedLong(text, left, NULL, value);
-                if (len >= left) {
-                    text += (left > 1) ? left - 1 : 0;
-                    left = SDL_min(left, 1);
-                } else {
-                    text += len;
-                    left -= len;
-                }
+                length += SDL_PrintUnsignedLong(TEXT_AND_LEN_ARGS, NULL, value);
                 arg -= (double) value / mult;
                 mult *= 10;
             }
         }
     } else {
-        if (left > 1) {
-            *text = '0';
-            --left;
+        if (length < maxlen) {
+            text[length] = '0';
         }
-        ++text;
+        ++length;
         if (info->force_type) {
-            if (left > 1) {
-                *text = '.';
-                --left;
+            if (length < maxlen) {
+                text[length] = '.';
             }
-            ++text;
+            ++length;
         }
     }
 
-    width = info->width - (int)(text - textstart);
-    if (width > 0) {
+    if (info->width > 0 && (size_t)info->width > length) {
         const char fill = info->pad_zeroes ? '0' : ' ';
-        char *end = text+left-1;
-        len = (text - textstart);
-        for (len = (text - textstart); len--; ) {
-            if ((textstart+len+width) < end) {
-                *(textstart+len+width) = *(textstart+len);
-            }
-        }
-        len = (size_t)width;
-        if (len >= left) {
-            text += (left > 1) ? left - 1 : 0;
-            left = SDL_min(left, 1);
-        } else {
-            text += len;
-            left -= len;
-        }
+        size_t width = info->width - length;
+        size_t filllen, movelen;
 
-        if (end != textstart) {
-            const size_t filllen = SDL_min(len, ((size_t) (end - textstart)) - 1);
-            SDL_memset(textstart, fill, filllen);
-        }
+        filllen = SDL_min(width, maxlen);
+        movelen = SDL_min(length, (maxlen - filllen));
+        SDL_memmove(&text[filllen], text, movelen);
+        SDL_memset(text, fill, filllen);
+        length += width;
     }
 
-    return (text - textstart);
+    return length;
 }
 
 int
 SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *fmt, va_list ap)
 {
-    size_t left = maxlen;
-    char *textstart = text;
+    size_t length = 0;
 
+    if (!text) {
+        maxlen = 0;
+    }
     if (!fmt) {
         fmt = "";
     }
-    while (*fmt && left > 1) {
+    while (*fmt) {
         if (*fmt == '%') {
             SDL_bool done = SDL_FALSE;
-            size_t len = 0;
             SDL_bool check_flag;
             SDL_FormatInfo info;
             enum
@@ -1800,18 +1770,18 @@ SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *fmt,
             while (!done) {
                 switch (*fmt) {
                 case '%':
-                    if (left > 1) {
-                        *text = '%';
+                    if (length < maxlen) {
+                        text[length] = '%';
                     }
-                    len = 1;
+                    ++length;
                     done = SDL_TRUE;
                     break;
                 case 'c':
                     /* char is promoted to int when passed through (...) */
-                    if (left > 1) {
-                        *text = (char) va_arg(ap, int);
+                    if (length < maxlen) {
+                        text[length] = (char) va_arg(ap, int);
                     }
-                    len = 1;
+                    ++length;
                     done = SDL_TRUE;
                     break;
                 case 'h':
@@ -1835,15 +1805,15 @@ SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *fmt,
                     }
                     switch (inttype) {
                     case DO_INT:
-                        len = SDL_PrintLong(text, left, &info,
+                        length += SDL_PrintLong(TEXT_AND_LEN_ARGS, &info,
                                             (long) va_arg(ap, int));
                         break;
                     case DO_LONG:
-                        len = SDL_PrintLong(text, left, &info,
+                        length += SDL_PrintLong(TEXT_AND_LEN_ARGS, &info,
                                             va_arg(ap, long));
                         break;
                     case DO_LONGLONG:
-                        len = SDL_PrintLongLong(text, left, &info,
+                        length += SDL_PrintLongLong(TEXT_AND_LEN_ARGS, &info,
                                                 va_arg(ap, Sint64));
                         break;
                     }
@@ -1876,23 +1846,23 @@ SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *fmt,
                     }
                     switch (inttype) {
                     case DO_INT:
-                        len = SDL_PrintUnsignedLong(text, left, &info,
+                        length += SDL_PrintUnsignedLong(TEXT_AND_LEN_ARGS, &info,
                                                     (unsigned long)
                                                     va_arg(ap, unsigned int));
                         break;
                     case DO_LONG:
-                        len = SDL_PrintUnsignedLong(text, left, &info,
+                        length += SDL_PrintUnsignedLong(TEXT_AND_LEN_ARGS, &info,
                                                     va_arg(ap, unsigned long));
                         break;
                     case DO_LONGLONG:
-                        len = SDL_PrintUnsignedLongLong(text, left, &info,
+                        length += SDL_PrintUnsignedLongLong(TEXT_AND_LEN_ARGS, &info,
                                                         va_arg(ap, Uint64));
                         break;
                     }
                     done = SDL_TRUE;
                     break;
                 case 'f':
-                    len = SDL_PrintFloat(text, left, &info, va_arg(ap, double));
+                    length += SDL_PrintFloat(TEXT_AND_LEN_ARGS, &info, va_arg(ap, double));
                     done = SDL_TRUE;
                     break;
                 case 'S':
@@ -1902,18 +1872,18 @@ SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *fmt,
                         if (wide_arg) {
                             char *arg = SDL_iconv_string("UTF-8", "UTF-16LE", (char *)(wide_arg), (SDL_wcslen(wide_arg)+1)*sizeof(*wide_arg));
                             info.pad_zeroes = SDL_FALSE;
-                            len = SDL_PrintString(text, left, &info, arg);
+                            length += SDL_PrintString(TEXT_AND_LEN_ARGS, &info, arg);
                             SDL_free(arg);
                         } else {
                             info.pad_zeroes = SDL_FALSE;
-                            len = SDL_PrintString(text, left, &info, NULL);
+                            length += SDL_PrintString(TEXT_AND_LEN_ARGS, &info, NULL);
                         }
                         done = SDL_TRUE;
                     }
                     break;
                 case 's':
                     info.pad_zeroes = SDL_FALSE;
-                    len = SDL_PrintString(text, left, &info, va_arg(ap, char *));
+                    length += SDL_PrintString(TEXT_AND_LEN_ARGS, &info, va_arg(ap, char *));
                     done = SDL_TRUE;
                     break;
                 default:
@@ -1922,23 +1892,23 @@ SDL_vsnprintf(SDL_OUT_Z_CAP(maxlen) char *text, size_t maxlen, const char *fmt,
                 }
                 ++fmt;
             }
-            if (len >= left) {
-                text += (left > 1) ? left - 1 : 0;
-                left = SDL_min(left, 1);
-            } else {
-                text += len;
-                left -= len;
-            }
         } else {
-            *text++ = *fmt++;
-            --left;
+            if (length < maxlen) {
+                text[length] = *fmt;
+            }
+            ++fmt;  /* always consume the literal, even once the buffer is full, or the loop never ends */
+            ++length;
         }
     }
-    if (left > 0) {
-        *text = '\0';
+    if (length < maxlen) {
+        text[length] = '\0';
+    } else if (maxlen > 0) {
+        text[maxlen - 1] = '\0';
     }
-    return (int)(text - textstart);
+    return (int)length;
 }
+
+#undef TEXT_AND_LEN_ARGS
 #endif /* HAVE_VSNPRINTF */
 
 /* vi: set ts=4 sw=4 expandtab: */
diff --git a/test/testautomation_stdlib.c b/test/testautomation_stdlib.c
index e741d547e9..9ea19649f9 100644
--- a/test/testautomation_stdlib.c
+++ b/test/testautomation_stdlib.c
@@ -44,6 +44,7 @@ int
 stdlib_snprintf(void *arg)
 {
   int result;
+  int predicted;
   char text[1024];
   const char *expected;
 
@@ -60,55 +61,72 @@ stdlib_snprintf(void *arg)
   SDLTest_AssertCheck(result == 3, "Check result value, expected: 3, got: %d", result);
 
   result = SDL_snprintf(NULL, 0, "%s", "foo");
+  SDLTest_AssertPass("Call to SDL_snprintf(NULL, 0, \"%%s\", \"foo\")");
   SDLTest_AssertCheck(result == 3, "Check result value, expected: 3, got: %d", result);
 
   result = SDL_snprintf(text, sizeof(text), "%f", 1.0);
+  predicted = SDL_snprintf(NULL, 0, "%f", 1.0);
   expected = "1.000000";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%f\", 1.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: %s, got: %s", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%.f", 1.0);
+  predicted = SDL_snprintf(NULL, 0, "%.f", 1.0);
   expected = "1";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%.f\", 1.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: %s, got: %s", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%#.f", 1.0);
+  predicted = SDL_snprintf(NULL, 0, "%#.f", 1.0);
   expected = "1.";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%#.f\", 1.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: %s, got: %s", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%f", 1.0 + 1.0 / 3.0);
+  predicted = SDL_snprintf(NULL, 0, "%f", 1.0 + 1.0 / 3.0);
   expected = "1.333333";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%f\", 1.0 + 1.0 / 3.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: %s, got: %s", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%+f", 1.0 + 1.0 / 3.0);
+  predicted = SDL_snprintf(NULL, 0, "%+f", 1.0 + 1.0 / 3.0);
   expected = "+1.333333";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%+f\", 1.0 + 1.0 / 3.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: %s, got: %s", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%.2f", 1.0 + 1.0 / 3.0);
+  predicted = SDL_snprintf(NULL, 0, "%.2f", 1.0 + 1.0 / 3.0);
   expected = "1.33";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%.2f\", 1.0 + 1.0 / 3.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: %s, got: %s", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%6.2f", 1.0 + 1.0 / 3.0);
+  predicted = SDL_snprintf(NULL, 0, "%6.2f", 1.0 + 1.0 / 3.0);
   expected = "  1.33";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%6.2f\", 1.0 + 1.0 / 3.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: '%s', got: '%s'", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, sizeof(text), "%06.2f", 1.0 + 1.0 / 3.0);
+  predicted = SDL_snprintf(NULL, 0, "%06.2f", 1.0 + 1.0 / 3.0);
   expected = "001.33";
   SDLTest_AssertPass("Call to SDL_snprintf(\"%%06.2f\", 1.0 + 1.0 / 3.0)");
   SDLTest_AssertCheck(SDL_strcmp(text, expected) == 0, "Check text, expected: '%s', got: '%s'", expected, text);
   SDLTest_AssertCheck(result == SDL_strlen(text), "Check result value, expected: %d, got: %d", (int) SDL_strlen(text), result);
+  SDLTest_AssertCheck(predicted == result, "Check predicted value, expected: %d, got: %d", result, predicted);
 
   result = SDL_snprintf(text, 5, "%06.2f", 1.0 + 1.0 / 3.0);
   expected = "001.";