SDL_image: Added support for loading the PNG palette when using STB_IMAGE

From 164a13a8fb07da9a642c3a1814876555ae6795e4 Mon Sep 17 00:00:00 2001
From: Sam Lantinga <[EMAIL REDACTED]>
Date: Mon, 16 Oct 2023 10:19:29 -0700
Subject: [PATCH] Added support for loading the PNG palette when using
 STB_IMAGE

Based on https://github.com/nothings/stb/pull/788

Fixes https://github.com/libsdl-org/SDL_image/issues/298
Fixes https://github.com/libsdl-org/SDL_image/issues/300
---
 src/IMG_stb.c   |  78 +++++++++++++++++++++---
 src/stb_image.h | 154 +++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 195 insertions(+), 37 deletions(-)

diff --git a/src/IMG_stb.c b/src/IMG_stb.c
index acfc3fa3..be3479a7 100644
--- a/src/IMG_stb.c
+++ b/src/IMG_stb.c
@@ -77,10 +77,13 @@ static int IMG_LoadSTB_RW_eof(void *user)
 SDL_Surface *IMG_LoadSTB_RW(SDL_RWops *src)
 {
     Sint64 start;
+    Uint8 magic[26]; /* PNG signature (8) + IHDR length/type (8) + width, height, depth, color type */
     int w, h, format;
     stbi_uc *pixels;
     stbi_io_callbacks rw_callbacks;
     SDL_Surface *surface = NULL;
+    SDL_bool use_palette = SDL_FALSE;
+    unsigned int palette_colors[256] = { 0 }; /* stb writes only the PLTE entries; keep the rest defined */
 
     if (!src) {
         /* The error message has been set in SDL_RWFromFile */
@@ -88,25 +91,82 @@ SDL_Surface *IMG_LoadSTB_RW(SDL_RWops *src)
     }
     start = SDL_RWtell(src);
 
+    if (SDL_RWread(src, magic, sizeof(magic)) == sizeof(magic)) { /* peek only; rewound below */
+        const Uint8 PNG_COLOR_INDEXED = 3;
+        if (magic[0] == 0x89 &&
+            magic[1] == 'P' &&
+            magic[2] == 'N' &&
+            magic[3] == 'G' &&
+            magic[12] == 'I' &&
+            magic[13] == 'H' &&
+            magic[14] == 'D' &&
+            magic[15] == 'R' &&
+            magic[25] == PNG_COLOR_INDEXED) { /* byte 25 is the IHDR color type */
+            use_palette = SDL_TRUE;
+        }
+    }
+    SDL_RWseek(src, start, SDL_RW_SEEK_SET);
+
     /* Load the image data */
     rw_callbacks.read = IMG_LoadSTB_RW_read;
     rw_callbacks.skip = IMG_LoadSTB_RW_skip;
     rw_callbacks.eof = IMG_LoadSTB_RW_eof;
     w = h = format = 0; /* silence warning */
-    pixels = stbi_load_from_callbacks(
-        &rw_callbacks,
-        src,
-        &w,
-        &h,
-        &format,
-        STBI_default
-    );
+    if (use_palette) {
+        pixels = stbi_load_from_callbacks_with_palette(
+            &rw_callbacks,
+            src,
+            &w,
+            &h,
+            palette_colors,
+            SDL_arraysize(palette_colors)
+        );
+    } else {
+        pixels = stbi_load_from_callbacks(
+            &rw_callbacks,
+            src,
+            &w,
+            &h,
+            &format,
+            STBI_default
+        );
+    }
     if (!pixels) {
         SDL_RWseek(src, start, SDL_RW_SEEK_SET);
         return NULL;
     }
 
-    if (format == STBI_grey || format == STBI_rgb || format == STBI_rgb_alpha) {
+    if (use_palette) {
+        surface = SDL_CreateSurfaceFrom(
+            pixels,
+            w,
+            h,
+            w,
+            SDL_PIXELFORMAT_INDEX8
+        );
+        if (surface) {
+            SDL_Palette *palette = surface->format->palette;
+            if (palette) {
+                int i;
+                Uint8 *palette_bytes = (Uint8 *)palette_colors;
+                /* each palette_colors entry holds R, G, B, A bytes in that order */
+                for (i = 0; i < palette->ncolors && i < (int)SDL_arraysize(palette_colors); i++) {
+                    palette->colors[i].r = *palette_bytes++;
+                    palette->colors[i].g = *palette_bytes++;
+                    palette->colors[i].b = *palette_bytes++;
+                    palette->colors[i].a = *palette_bytes++;
+                }
+            }
+
+            /* FIXME: This sucks. It'd be better to allocate the surface first, then
+             * write directly to the pixel buffer:
+             * https://github.com/nothings/stb/issues/58
+             * -flibit
+             */
+            surface->flags &= ~SDL_PREALLOC;
+        }
+
+    } else if (format == STBI_grey || format == STBI_rgb || format == STBI_rgb_alpha) {
         surface = SDL_CreateSurfaceFrom(
             pixels,
             w,
diff --git a/src/stb_image.h b/src/stb_image.h
index 6c559fb7..eaeff8a3 100644
--- a/src/stb_image.h
+++ b/src/stb_image.h
@@ -92,7 +92,7 @@ RECENT REVISION HISTORY:
  Optimizations & bugfixes                  Mikhail Morozov (1-bit BMP)
     Fabian "ryg" Giesen                    Anael Seghezzi (is-16-bit query)
     Arseny Kapoulkine                      Simon Breuss (16-bit PNM)
-    John-Mark Allen
+    John-Mark Allen                        Katelyn Gadd (indexed color loading)
     Carmelo J Fdez-Aguera
 
  Bug & warning fixes
@@ -442,6 +442,18 @@ STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int *
 STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
 #endif
 
+////////////////////////////////////
+//
+// 8-bits-per-channel indexed color: returns one palette index byte per pixel.
+// palette_buffer receives the palette as RGBA byte quads, one per unsigned int;
+// palette_buffer_len is that entry count and must be at least 256 for PNG.
+// NOTE(review): non-paletted PNGs are not rejected; check the IHDR color type first.
+
+#if 0 /* not used in SDL_image */
+STBIDEF stbi_uc *stbi_load_from_memory_with_palette   (stbi_uc           const *buffer, int len , int *x, int *y, unsigned int *palette_buffer, int palette_buffer_len);
+#endif
+STBIDEF stbi_uc *stbi_load_from_callbacks_with_palette(stbi_io_callbacks const *clbk, void *user, int *x, int *y, unsigned int *palette_buffer, int palette_buffer_len);
+
 ////////////////////////////////////
 //
 // 16-bits-per-channel interface
@@ -694,6 +706,10 @@ typedef Uint32 stbi__uint32;
 typedef Sint32 stbi__int32;
 #endif
 
+#ifndef STBI_BUFFER_SIZE // NOTE(review): nothing in this patch references STBI_BUFFER_SIZE; confirm it is needed
+#define STBI_BUFFER_SIZE 128
+#endif
+
 // should produce compiler error if size is wrong
 typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1];
 
@@ -960,7 +976,7 @@ static int      stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp);
 
 #ifndef STBI_NO_PNG
 static int      stbi__png_test(stbi__context *s);
-static void    *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);
+static void    *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, unsigned int *palette_buffer, int palette_buffer_len, stbi__result_info *ri);
 #if 0 /* not used in SDL_image */
 static int      stbi__png_info(stbi__context *s, int *x, int *y, int *comp);
 static int      stbi__png_is16(stbi__context *s);
@@ -975,7 +991,7 @@ static int      stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp);
 
 #ifndef STBI_NO_TGA
 static int      stbi__tga_test(stbi__context *s);
-static void    *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri);
+static void    *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, unsigned int *palette_buffer, int palette_buffer_len, stbi__result_info *ri);
 static int      stbi__tga_info(stbi__context *s, int *x, int *y, int *comp);
 #endif
 
@@ -1193,7 +1209,7 @@ STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_fli
                                          : stbi__vertically_flip_on_load_global)
 #endif // STBI_THREAD_LOCAL
 
-static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc)
+static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc, unsigned int *palette_buffer, int palette_buffer_len) // palette_buffer: NULL for normal loads; forwarded to the PNG and TGA loaders
 {
    memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields
    ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed
@@ -1203,7 +1219,7 @@ static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int re
    // test the formats with a very explicit header first (at least a FOURCC
    // or distinctive magic number first)
    #ifndef STBI_NO_PNG
-   if (stbi__png_test(s))  return stbi__png_load(s,x,y,comp,req_comp, ri);
+   if (stbi__png_test(s))  return stbi__png_load(s,x,y,comp,req_comp, palette_buffer, palette_buffer_len, ri);
    #endif
    #ifndef STBI_NO_BMP
    if (stbi__bmp_test(s))  return stbi__bmp_load(s,x,y,comp,req_comp, ri);
@@ -1240,7 +1256,7 @@ static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int re
    #ifndef STBI_NO_TGA
    // test tga last because it's a crappy test!
    if (stbi__tga_test(s))
-      return stbi__tga_load(s,x,y,comp,req_comp, ri);
+      return stbi__tga_load(s,x,y,comp,req_comp, palette_buffer, palette_buffer_len, ri);
    #endif
 
    return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt");
@@ -1318,10 +1334,42 @@ static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int byt
 }
 #endif
 
+// Decode an image as raw 8-bit palette indices, filling palette_buffer with
+// its palette. Returns NULL (with a failure reason set) unless the loader
+// produced single-channel, 8-bit data.
+static unsigned char *stbi__load_indexed(stbi__context *s, int *x, int *y, unsigned int *palette_buffer, int palette_buffer_len)
+{
+   stbi__result_info ri;
+   int comp = 0;
+   void *result;
+
+   if (!palette_buffer)
+      return stbi__errpuc("bad palette buffer", "palette_buffer must not be NULL");
+
+   result = stbi__load_main(s, x, y, &comp, 1, &ri, 8, palette_buffer, palette_buffer_len);
+   if (result == NULL)
+      return NULL;
+
+   // only one 8-bit index per pixel can be handed back to the caller
+   if (comp != 1 || ri.bits_per_channel != 8) {
+      stbi_image_free(result);
+      return stbi__errpuc("not indexed", "Image is not an 8-bit indexed image");
+   }
+
+   // @TODO: move stbi__convert_format to here
+
+   if (stbi__vertically_flip_on_load) {
+      int channels = 1; // one index byte per pixel
+      stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc));
+   }
+
+   return (unsigned char *) result;
+}
+
 static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp)
 {
    stbi__result_info ri;
-   void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8);
+   void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8, NULL, 0);
 
    if (result == NULL)
       return NULL;
@@ -1348,7 +1396,7 @@ static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x,
 static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp)
 {
    stbi__result_info ri;
-   void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16);
+   void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16, NULL, 0);
 
    if (result == NULL)
       return NULL;
@@ -1505,6 +1553,22 @@ STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *u
    return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp);
 }
 
+#if 0 /* not used in SDL_image */
+STBIDEF stbi_uc *stbi_load_from_memory_with_palette(stbi_uc const *buffer, int len, int *x, int *y, unsigned int *palette_buffer, int palette_buffer_len)
+{
+    stbi__context s;
+    stbi__start_mem(&s, buffer, len);
+    return stbi__load_indexed(&s, x, y, palette_buffer, palette_buffer_len);
+}
+#endif
+
+STBIDEF stbi_uc *stbi_load_from_callbacks_with_palette(stbi_io_callbacks const *clbk, void *user, int *x, int *y, unsigned int *palette_buffer, int palette_buffer_len)
+{
+    stbi__context s;
+    stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user);
+    return stbi__load_indexed(&s, x, y, palette_buffer, palette_buffer_len);
+}
+
 #ifndef STBI_NO_GIF
 STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp)
 {
@@ -5180,15 +5244,25 @@ static void stbi__de_iphone(stbi__png *z)
 
 #define STBI__PNG_TYPE(a,b,c,d)  (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d))
 
-static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp)
+static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp, unsigned int *palette_buffer, int palette_buffer_len)
 {
-   stbi_uc palette[1024], pal_img_n=0;
+   stbi_uc _palette[1024], pal_img_n=0;
+   stbi_uc *palette = _palette;
    stbi_uc has_trans=0, tc[3]={0};
    stbi__uint16 tc16[3];
    stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0;
    int first=1,k,interlace=0, color=0, is_iphone=0;
    stbi__context *s = z->s;
 
+   if (palette_buffer) {
+      // caller-owned buffer receives the palette as 256 RGBA entries (1024 bytes)
+      if (palette_buffer_len < 256)
+         return stbi__err("palette buffer too small", "palette buffer len must be at least 256");
+      if (req_comp != 1)
+         return stbi__err("invalid req_comp", "req_comp must be 1 when loading paletted");
+      palette = (stbi_uc *)(void *)palette_buffer;
+   }
+
    z->expanded = NULL;
    z->idata = NULL;
    z->out = NULL;
@@ -5329,8 +5403,9 @@ static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp)
                s->img_n = pal_img_n; // record the actual colors we had
                s->img_out_n = pal_img_n;
                if (req_comp >= 3) s->img_out_n = req_comp;
-               if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n))
-                  return 0;
+               // with a caller palette buffer, keep the raw indices (no expansion)
+               if (!palette_buffer && !stbi__expand_png_palette(z, palette, pal_len, s->img_out_n))
+                  return 0;
             } else if (has_trans) {
                // non-paletted image with tRNS -> source image has (constant) alpha
                ++s->img_n;
@@ -5364,11 +5439,18 @@ static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp)
    }
 }
 
-static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri)
+static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, unsigned int *palette_buffer, int palette_buffer_len, stbi__result_info *ri)
 {
    void *result=NULL;
-   if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error");
-   if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) {
+   if (palette_buffer && req_comp != 1) {
+      stbi__err("bad req_comp", "req_comp must be 1 if loading paletted image without expansion");
+      return NULL;
+   }
+   if (req_comp < 0 || req_comp > 4) {
+      stbi__err("bad req_comp", "Internal error");
+      return NULL;
+   }
+   if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp, palette_buffer, palette_buffer_len)) {
       if (p->depth <= 8)
          ri->bits_per_channel = 8;
       else if (p->depth == 16)
@@ -5378,7 +5460,9 @@ static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, st
       result = p->out;
       p->out = NULL;
       if (req_comp && req_comp != p->s->img_out_n) {
-         if (ri->bits_per_channel == 8)
+         if (palette_buffer) {
+            // paletted load: keep the decoder output as-is (no channel conversion)
+         } else if (ri->bits_per_channel == 8)
             result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y);
          else
             result = stbi__convert_format16((stbi__uint16 *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y);
@@ -5387,7 +5471,12 @@ static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, st
       }
       *x = p->s->img_x;
       *y = p->s->img_y;
-      if (n) *n = p->s->img_n;
+      if (n) {
+         // NOTE(review): forced to 1 even for non-paletted PNGs, whose pixels
+         // were not converted above; callers must pre-check the IHDR color
+         // type (IMG_stb.c does) until such images are rejected here.
+         *n = palette_buffer ? 1 : p->s->img_n;
+      }
    }
    STBI_FREE(p->out);      p->out      = NULL;
    STBI_FREE(p->expanded); p->expanded = NULL;
@@ -5396,11 +5485,11 @@ static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, st
    return result;
 }
 
-static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)
+static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, unsigned int *palette_buffer, int palette_buffer_len, stbi__result_info *ri)
 {
    stbi__png p;
    p.s = s;
-   return stbi__do_png(&p, x,y,comp,req_comp, ri);
+   return stbi__do_png(&p, x,y,comp,req_comp, palette_buffer, palette_buffer_len, ri);
 }
 
 static int stbi__png_test(stbi__context *s)
@@ -5414,7 +5503,7 @@ static int stbi__png_test(stbi__context *s)
 #if 0 /* not used in SDL_image */
 static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp)
 {
-   if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) {
+   if (!stbi__parse_png_file(p, STBI__SCAN_header, 0, NULL, 0)) {
       stbi__rewind( p->s );
       return 0;
    }
@@ -5974,7 +6063,7 @@ static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out)
    // so let's treat all 15 and 16bit TGAs as RGB with no alpha.
 }
 
-static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)
+static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, unsigned int *palette_buffer, int palette_buffer_len, stbi__result_info *ri)
 {
    //   read in the TGA header stuff
    int tga_offset = stbi__get8(s);
@@ -6054,10 +6143,18 @@ static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req
          //   any data to skip? (offset usually = 0)
          stbi__skip(s, tga_palette_start );
          //   load the palette
-         tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0);
-         if (!tga_palette) {
-            STBI_FREE(tga_data);
-            return stbi__errpuc("outofmem", "Out of memory");
+         if (palette_buffer) { // palette_buffer_len counts unsigned int entries; compare in bytes
+            if (palette_buffer_len < 0 || (size_t)palette_buffer_len * sizeof(unsigned int) < (size_t)tga_palette_len * tga_comp) {
+               STBI_FREE(tga_data);
+               return stbi__errpuc("buffer too small", "Palette buffer too small");
+            }
+            tga_palette = (unsigned char*)(void*)palette_buffer;
+         } else {
+            tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0);
+            if (!tga_palette) {
+               STBI_FREE(tga_data);
+               return stbi__errpuc("outofmem", "Out of memory");
+            }
          }
          if (tga_rgb16) {
             stbi_uc *pal_entry = tga_palette;
@@ -6068,7 +6165,8 @@ static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req
             }
          } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) {
                STBI_FREE(tga_data);
-               STBI_FREE(tga_palette);
+               if (!palette_buffer) // caller owns palette_buffer
+                  STBI_FREE(tga_palette);
                return stbi__errpuc("bad palette", "Corrupt TGA");
          }
       }
@@ -6097,7 +6195,7 @@ static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req
          if ( read_next_pixel )
          {
             //   load however much data we did have
-            if ( tga_indexed )
+            if ( tga_indexed && !palette_buffer ) // NOTE(review): with palette_buffer, indexed pixels take the non-lookup branch (not shown); confirm it stores the index, not tga_comp raw bytes
             {
                // read in index, then perform the lookup
                int pal_idx = (tga_bits_per_pixel == 8) ? stbi__get8(s) : stbi__get16le(s);
@@ -6147,7 +6245,7 @@ static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req
          }
       }
       //   clear my palette, if I had one
-      if ( tga_palette != NULL )
+      if ( tga_palette != NULL && !palette_buffer )
       {
          STBI_FREE( tga_palette );
       }