SDL: hashtable: reimplement as open-addressed robin hood hashtable

From ba7b346e522d90d2baae0047bcdda053df136fa1 Mon Sep 17 00:00:00 2001
From: Andrei Alexeyev <[EMAIL REDACTED]>
Date: Wed, 18 Sep 2024 16:27:20 +0200
Subject: [PATCH] hashtable: reimplement as open-addressed robin hood hashtable

This is mostly ported from Taisei Project
---
 src/SDL_hashtable.c | 428 ++++++++++++++++++++++++++++++++------------
 1 file changed, 313 insertions(+), 115 deletions(-)

diff --git a/src/SDL_hashtable.c b/src/SDL_hashtable.c
index feaa08f52af2b..da76b03eb60ea 100644
--- a/src/SDL_hashtable.c
+++ b/src/SDL_hashtable.c
@@ -18,26 +18,42 @@
      misrepresented as being the original software.
   3. This notice may not be removed or altered from any source distribution.
 */
+
 #include "SDL_internal.h"
 #include "SDL_hashtable.h"
 
+// XXX: We can't use SDL_assert here because it's going to call into hashtable code
+#include <assert.h>
+#define HT_ASSERT(x) assert(x)
+
 typedef struct SDL_HashItem
 {
+    // TODO: Splitting off values into a separate array might be more cache-friendly
     const void *key;
     const void *value;
-    struct SDL_HashItem *next;
+    Uint32 hash; // full 32-bit hash from calc_hash(), cached so probing and resizing reuse it instead of calling the hash callback again
+    Uint32 probe_len : 31; // distance in slots (with wraparound) between the item's home slot (hash & hash_mask) and the slot it occupies
+    Uint32 live : 1; // 1 if this slot holds an item; an all-zero slot is empty
 } SDL_HashItem;
 
+// Must be a power of 2 >= sizeof(SDL_HashItem)
+#define MAX_HASHITEM_SIZEOF 32u
+SDL_COMPILE_TIME_ASSERT(sizeof_SDL_HashItem, sizeof(SDL_HashItem) <= MAX_HASHITEM_SIZEOF);
+
+// Anything larger than this will cause integer overflows
+#define MAX_HASHTABLE_SIZE (0x80000000u / (MAX_HASHITEM_SIZEOF))
+
 struct SDL_HashTable
 {
-    SDL_HashItem **table;
-    Uint32 table_len;
-    int hash_shift;
-    bool stackable;
-    void *data;
+    SDL_HashItem *table; // flat array of (hash_mask + 1) slots
     SDL_HashTable_HashFn hash;
     SDL_HashTable_KeyMatchFn keymatch;
     SDL_HashTable_NukeFn nuke;
+    void *data; // passed unchanged to the hash, keymatch and nuke callbacks
+    Uint32 hash_mask; // slot count minus one (the slot count is a power of two)
+    Uint32 max_probe_len; // largest probe_len recorded by insert_item; only resize() resets it
+    Uint32 num_occupied_slots; // number of live items in the table
+    bool stackable; // when false, inserting an already-present key fails
 };
 
 SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const SDL_HashTable_HashFn hashfn,
@@ -47,26 +63,29 @@ SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const S
 {
     SDL_HashTable *table;
 
-    // num_buckets must be a power of two so we can derive the bucket index with just a bitshift.
-    // Need at least two buckets, otherwise hash_shift would be 32, which is UB!
-    if ((num_buckets < 2) || !SDL_HasExactlyOneBitSet32(num_buckets)) {
+    // num_buckets must be a power of two so we can derive the bucket index with just a bit-and.
+    if ((num_buckets < 1) || !SDL_HasExactlyOneBitSet32(num_buckets)) {
         SDL_SetError("num_buckets must be a power of two");
         return NULL;
     }
 
-    table = (SDL_HashTable *) SDL_calloc(1, sizeof (SDL_HashTable));
+    if (num_buckets > MAX_HASHTABLE_SIZE) {
+        SDL_SetError("num_buckets is too large");
+        return NULL;
+    }
+
+    table = (SDL_HashTable *)SDL_calloc(1, sizeof(SDL_HashTable));
     if (!table) {
         return NULL;
     }
 
-    table->table = (SDL_HashItem **) SDL_calloc(num_buckets, sizeof (SDL_HashItem *));
+    table->table = (SDL_HashItem *)SDL_calloc(num_buckets, sizeof(SDL_HashItem));
     if (!table->table) {
         SDL_free(table);
         return NULL;
     }
 
-    table->table_len = num_buckets;
-    table->hash_shift = 32 - SDL_MostSignificantBitIndex32(num_buckets);
+    table->hash_mask = num_buckets - 1;
     table->stackable = stackable;
     table->data = data;
     table->hash = hashfn;
@@ -75,47 +94,232 @@ SDL_HashTable *SDL_CreateHashTable(void *data, const Uint32 num_buckets, const S
     return table;
 }
 
-static SDL_INLINE Uint32 calc_hash(const SDL_HashTable *table, const void *key)
+static SDL_INLINE Uint32 calc_hash(const SDL_HashTable *restrict table, const void *key)
 {
-    // Mix the bits together, and use the highest bits as the bucket index.
     const Uint32 BitMixer = 0x9E3779B1u;
-    return (table->hash(key, table->data) * BitMixer) >> table->hash_shift;
+    return table->hash(key, table->data) * BitMixer; // full 32-bit mixed hash; callers AND it with hash_mask to get the home slot
 }
 
+static SDL_INLINE Uint32 get_probe_length(Uint32 zero_idx, Uint32 actual_idx, Uint32 num_buckets)
+{
+    // Returns the probe sequence length from zero_idx (an item's home slot) to actual_idx (the slot it occupies).
+
+    if (actual_idx < zero_idx) { // the probe sequence wrapped around the end of the table
+        return num_buckets - zero_idx + actual_idx;
+    }
 
-bool SDL_InsertIntoHashTable(SDL_HashTable *table, const void *key, const void *value)
+    return actual_idx - zero_idx;
+}
+
+static SDL_HashItem *find_item(const SDL_HashTable *restrict ht, const void *key, Uint32 hash, Uint32 *restrict i, Uint32 *restrict probe_len)
 {
-    SDL_HashItem *item;
-    Uint32 hash;
+    Uint32 hash_mask = ht->hash_mask; // Probes from slot *i (at probe distance *probe_len) for an item matching hash/key; returns it or NULL
+    Uint32 max_probe_len = ht->max_probe_len; // *i and *probe_len are advanced in place as the probe moves along
 
-    if (!table) {
+    SDL_HashItem *table = ht->table;
+
+    for (;;) {
+        SDL_HashItem *item = table + *i;
+        Uint32 item_hash = item->hash;
+
+        if (!item->live) { // reached an empty slot: the key is not in the table
+            return NULL;
+        }
+
+        if (item_hash == hash && ht->keymatch(item->key, key, ht->data)) { // compare cached hashes first so keymatch only runs on equal hashes
+            return item;
+        }
+
+        Uint32 item_probe_len = item->probe_len;
+        HT_ASSERT(item_probe_len == get_probe_length(item_hash & hash_mask, (Uint32)(item - table), hash_mask + 1));
+
+        if (*probe_len > item_probe_len) { // insert_item would have evicted this item in favour of our key, so the key is absent
+            return NULL;
+        }
+
+        if (++*probe_len > max_probe_len) { // insert_item never recorded a longer probe length than this
+            return NULL;
+        }
+
+        *i = (*i + 1) & hash_mask; // next slot, wrapping around
+    }
+}
+
+static SDL_HashItem *find_first_item(const SDL_HashTable *restrict ht, const void *key, Uint32 hash)
+{
+    Uint32 i = hash & ht->hash_mask; // start probing at the key's home slot
+    Uint32 probe_len = 0;
+    return find_item(ht, key, hash, &i, &probe_len); // first match in probe order, or NULL if there is none
+}
+
+static SDL_HashItem *insert_item(SDL_HashItem *restrict item_to_insert, SDL_HashItem *restrict table, Uint32 hash_mask, Uint32 *max_probe_len_ptr)
+{
+    Uint32 idx = item_to_insert->hash & hash_mask; // start at the item's home slot
+    SDL_HashItem temp_item, *target = NULL; // target: slot that receives the caller's original item (this is the return value)
+    Uint32 num_buckets = hash_mask + 1;
+
+    for (;;) { // exits only through the empty-slot branch below, so the table must contain at least one empty slot
+        SDL_HashItem *candidate = table + idx;
+
+        if (!candidate->live) {
+            // Found an empty slot. Put it here and we're done.
+
+            *candidate = *item_to_insert;
+
+            if (target == NULL) { // no eviction happened, so this slot holds the caller's item
+                target = candidate;
+            }
+
+            Uint32 probe_len = get_probe_length(candidate->hash & hash_mask, idx, num_buckets);
+            candidate->probe_len = probe_len;
+
+            if (*max_probe_len_ptr < probe_len) {
+                *max_probe_len_ptr = probe_len;
+            }
+
+            break;
+        }
+
+        Uint32 candidate_probe_len = candidate->probe_len;
+        HT_ASSERT(candidate_probe_len == get_probe_length(candidate->hash & hash_mask, idx, num_buckets));
+        Uint32 new_probe_len = get_probe_length(item_to_insert->hash & hash_mask, idx, num_buckets);
+
+        if (candidate_probe_len < new_probe_len) {
+            // Robin Hood hashing: the item at idx has a better probe length than our item would at this position.
+            // Evict it and put our item in its place, then continue looking for a new spot for the displaced item.
+            // This algorithm significantly reduces clustering in the table, making lookups take very few probes.
+
+            temp_item = *candidate;
+            *candidate = *item_to_insert;
+
+            if (target == NULL) { // first eviction: this is where the caller's item landed
+                target = candidate;
+            }
+
+            *item_to_insert = temp_item; // carry on with the evicted item (this overwrites the caller's struct)
+
+            HT_ASSERT(new_probe_len == get_probe_length(candidate->hash & hash_mask, idx, num_buckets));
+            candidate->probe_len = new_probe_len;
+
+            if (*max_probe_len_ptr < new_probe_len) {
+                *max_probe_len_ptr = new_probe_len;
+            }
+        }
+
+        idx = (idx + 1) & hash_mask; // next slot, wrapping around
+    }
+
+    return target;
+}
+
+static void delete_item(SDL_HashTable *restrict ht, SDL_HashItem *item)
+{
+    Uint32 hash_mask = ht->hash_mask; // Removes `item` (a slot inside ht->table) and shifts the items that follow it back by one slot
+    SDL_HashItem *table = ht->table;
+
+    if (ht->nuke) { // let the owner release the key/value before the slot is overwritten
+        ht->nuke(item->key, item->value, ht->data);
+    }
+    ht->num_occupied_slots--;
+
+    Uint32 idx = (Uint32)(item - ht->table);
+
+    for (;;) {
+        idx = (idx + 1) & hash_mask;
+        SDL_HashItem *next_item = table + idx;
+
+        if (next_item->probe_len < 1) { // next slot is empty (zeroed) or holds an item already in its home slot: stop shifting
+            SDL_zerop(item);
+            return;
+        }
+
+        *item = *next_item; // move the follower one slot closer to its home slot
+        item->probe_len -= 1;
+        HT_ASSERT(item->probe_len < ht->max_probe_len);
+        item = next_item; // the follower's old slot is now the hole to fill
+    }
+}
+
+static bool resize(SDL_HashTable *restrict ht, Uint32 new_size)
+{
+    SDL_HashItem *old_table = ht->table; // Rehashes every live item into a freshly allocated table of new_size slots
+    Uint32 old_size = ht->hash_mask + 1;
+    Uint32 new_hash_mask = new_size - 1; // valid as a mask only if new_size is a power of two (maybe_resize passes capacity * 2)
+    SDL_HashItem *new_table = SDL_calloc(new_size, sizeof(*new_table)); // zeroed, so every slot starts out empty (live == 0)
+
+    if (!new_table) { // allocation failed: ht is left untouched
         return false;
     }
 
-    if ( (!table->stackable) && (SDL_FindInHashTable(table, key, NULL)) ) {
+    ht->max_probe_len = 0; // rebuilt by insert_item as the items are re-inserted
+    ht->hash_mask = new_hash_mask;
+    ht->table = new_table;
+
+    for (Uint32 i = 0; i < old_size; ++i) {
+        SDL_HashItem *item = old_table + i;
+        if (item->live) {
+            insert_item(item, new_table, new_hash_mask, &ht->max_probe_len); // may overwrite *item during evictions; old_table is discarded below
+        }
+    }
+
+    SDL_free(old_table);
+    return true;
+}
+
+static bool maybe_resize(SDL_HashTable *restrict ht)
+{
+    Uint32 capacity = ht->hash_mask + 1; // Doubles the table when the load factor is exceeded; returns false if the table cannot accept another item
+
+    if (capacity >= MAX_HASHTABLE_SIZE) { // cannot grow further; NOTE(review): this reports failure even if free slots remain
         return false;
     }
 
-    // !!! FIXME: grow and rehash table if it gets too saturated.
-    item = (SDL_HashItem *) SDL_malloc(sizeof (SDL_HashItem));
-    if (!item) {
+    Uint32 max_load_factor = 217; // range: 0-255; 217 is roughly 85%
+    Uint32 resize_threshold = (Uint32)((max_load_factor * (Uint64)capacity) >> 8); // explicit cast: the result is always < capacity, so it fits
+
+    if (ht->num_occupied_slots > resize_threshold) { // the caller has already counted the item it is about to insert
+        return resize(ht, capacity * 2);
+    }
+
+    return true;
+}
+
+bool SDL_InsertIntoHashTable(SDL_HashTable *restrict table, const void *key, const void *value)
+{
+    SDL_HashItem *item;
+    Uint32 hash;
+
+    if (!table) {
         return false;
     }
 
     hash = calc_hash(table, key);
+    item = find_first_item(table, key, hash); // look for an existing entry with this key
 
-    item->key = key;
-    item->value = value;
-    item->next = table->table[hash];
-    table->table[hash] = item;
+    if (item && !table->stackable) { // duplicate keys are only accepted by stackable tables
+        // TODO: Maybe allow overwrites? We could do it more efficiently here than unset followed by insert.
+        return false;
+    }
 
-    return true;
+    SDL_HashItem new_item; // probe_len is left unset here; insert_item assigns it when the item is placed
+    new_item.key = key;
+    new_item.value = value;
+    new_item.hash = hash;
+    new_item.live = true;
+
+    table->num_occupied_slots++; // count the new item first so maybe_resize checks the load it will cause
+
+    if (!maybe_resize(table)) { // table is at its maximum size or the reallocation failed
+        table->num_occupied_slots--;
+        return false;
+    }
+
+    return insert_item(&new_item, table->table, table->hash_mask, &table->max_probe_len); // returns the (non-NULL) slot, which converts to true
 }
 
 bool SDL_FindInHashTable(const SDL_HashTable *table, const void *key, const void **_value)
 {
     Uint32 hash;
-    void *data;
     SDL_HashItem *i;
 
     if (!table) {
@@ -123,104 +327,104 @@ bool SDL_FindInHashTable(const SDL_HashTable *table, const void *key, const void
     }
 
     hash = calc_hash(table, key);
-    data = table->data;
-
-    for (i = table->table[hash]; i; i = i->next) {
-        if (table->keymatch(key, i->key, data)) {
-            if (_value) {
-                *_value = i->value;
-            }
-            return true;
-        }
-    }
+    i = find_first_item(table, key, hash); // first item matching key in probe order, or NULL
+    // _value is optional: callers may pass NULL when they only need to know whether the key exists.
+    if (_value) {
+        *_value = i ? i->value : NULL; // cleared to NULL when the key is not found
+    }
 
-    return false;
+    return i != NULL; // true if the key was found
 }
 
 bool SDL_RemoveFromHashTable(SDL_HashTable *table, const void *key)
 {
     Uint32 hash;
-    SDL_HashItem *item = NULL;
-    SDL_HashItem *prev = NULL;
-    void *data;
+    SDL_HashItem *item;
 
     if (!table) {
         return false;
     }
 
-    hash = calc_hash(table, key);
-    data = table->data;
-
-    for (item = table->table[hash]; item; item = item->next) {
-        if (table->keymatch(key, item->key, data)) {
-            if (prev) {
-                prev->next = item->next;
-            } else {
-                table->table[hash] = item->next;
-            }
+    // FIXME: what to do for stacking hashtables?
+    // The original code removes just one item.
+    // This hashtable happens to preserve the insertion order of multi-value keys,
+    // so deleting the first one will always delete the least-recently inserted one.
+    // But maybe it makes more sense to remove all matching items?
 
-            if (table->nuke) {
-                table->nuke(item->key, item->value, data);
-            }
-            SDL_free(item);
-            return true;
-        }
+    hash = calc_hash(table, key);
+    item = find_first_item(table, key, hash); // only the first match in probe order is removed
 
-        prev = item;
+    if (!item) { // key not present: nothing to remove
+        return false;
     }
 
-    return false;
+    delete_item(table, item); // runs the nuke callback (if set) and shifts the following items back
+    return true;
 }
 
 bool SDL_IterateHashTableKey(const SDL_HashTable *table, const void *key, const void **_value, void **iter)
 {
-    SDL_HashItem *item;
+    SDL_HashItem *item = (SDL_HashItem *)*iter; // NULL starts a new iteration; otherwise the item stored in *iter by the previous call
 
     if (!table) {
         return false;
     }
 
-    item = *iter ? ((SDL_HashItem *)*iter)->next : table->table[calc_hash(table, key)];
+    Uint32 i, probe_len, hash;
 
-    while (item) {
-        if (table->keymatch(key, item->key, table->data)) {
-            *_value = item->value;
-            *iter = item;
-            return true;
-        }
-        item = item->next;
+    if (item) {
+        HT_ASSERT(item >= table->table);
+        HT_ASSERT(item < table->table + (table->hash_mask + 1));
+
+        hash = item->hash; // reuse the previous match's cached hash instead of hashing key again
+        probe_len = item->probe_len + 1; // resume probing one slot past the previous match
+        i = ((Uint32)(item - table->table) + 1) & table->hash_mask;
+        item = table->table + i;
+    } else {
+        hash = calc_hash(table, key);
+        i = hash & table->hash_mask; // start at the key's home slot
+        probe_len = 0;
     }
 
-    // no more matches.
-    *_value = NULL;
-    *iter = NULL;
-    return false;
+    item = find_item(table, key, hash, &i, &probe_len);
+
+    if (!item) { // no more matches; note that *iter is left unchanged here
+        *_value = NULL;
+        return false;
+    }
+
+    *_value = item->value;
+    *iter = item; // remember this match so the next call continues after it
+
+    return true;
 }
 
 bool SDL_IterateHashTable(const SDL_HashTable *table, const void **_key, const void **_value, void **iter)
 {
-    SDL_HashItem *item = (SDL_HashItem *) *iter;
-    Uint32 idx = 0;
+    SDL_HashItem *item = (SDL_HashItem *)*iter;
 
     if (!table) {
         return false;
     }
 
-    if (item) {
-        const SDL_HashItem *orig = item;
-        item = item->next;
-        if (!item) {
-            idx = calc_hash(table, orig->key) + 1;  // !!! FIXME: we probably shouldn't rehash each time.
-        }
+    if (!item) {
+        item = table->table;
+    } else {
+        item++;
     }
 
-    while (!item && (idx < table->table_len)) {
-        item = table->table[idx++];  // skip empty buckets...
+    HT_ASSERT(item >= table->table);
+    SDL_HashItem *end = table->table + (table->hash_mask + 1);
+
+    while (item < end && !item->live) {
+        ++item;
     }
 
-    if (!item) {  // no more matches?
+    HT_ASSERT(item <= end);
+
+    if (item == end) {
         *_key = NULL;
-        *iter = NULL;
+        *_value = NULL;
         return false;
     }
 
@@ -233,44 +434,41 @@ bool SDL_IterateHashTable(const SDL_HashTable *table, const void **_key, const v
 
 bool SDL_HashTableEmpty(SDL_HashTable *table)
 {
-    if (table) {
-        Uint32 i;
+    return !(table && table->num_occupied_slots); // true for a NULL table or a table with no live items
+}
 
-        for (i = 0; i < table->table_len; i++) {
-            SDL_HashItem *item = table->table[i];
-            if (item) {
-                return false;
-            }
+static void nuke_all(SDL_HashTable *restrict table)
+{
+    void *data = table->data; // Calls table->nuke on every live item; table->nuke is not checked here, so callers must ensure it is set
+    SDL_HashItem *end = table->table + (table->hash_mask + 1); // one past the last slot
+    SDL_HashItem *i;
+
+    for (i = table->table; i < end; ++i) {
+        if (i->live) { // empty slots carry no key/value to release
+            table->nuke(i->key, i->value, data);
         }
     }
-    return true;
 }
 
-void SDL_EmptyHashTable(SDL_HashTable *table)
+void SDL_EmptyHashTable(SDL_HashTable *restrict table)
 {
     if (table) {
-        void *data = table->data;
-        Uint32 i;
-
-        for (i = 0; i < table->table_len; i++) {
-            SDL_HashItem *item = table->table[i];
-            while (item) {
-                SDL_HashItem *next = item->next;
-                if (table->nuke) {
-                    table->nuke(item->key, item->value, data);
-                }
-                SDL_free(item);
-                item = next;
-            }
-            table->table[i] = NULL;
+        if (table->nuke) { // release every stored key/value through the owner's callback first
+            nuke_all(table);
         }
+
+        SDL_memset(table->table, 0, sizeof(*table->table) * (table->hash_mask + 1)); // mark every slot empty; the allocation and its capacity are kept
+        table->num_occupied_slots = 0; // note: max_probe_len is not reset here, so it may overstate the bound afterwards
     }
 }
 
 void SDL_DestroyHashTable(SDL_HashTable *table)
 {
     if (table) {
-        SDL_EmptyHashTable(table);
+        if (table->nuke) {
+            nuke_all(table);
+        }
+
         SDL_free(table->table);
         SDL_free(table);
     }
@@ -298,13 +496,13 @@ bool SDL_KeyMatchString(const void *a, const void *b, void *data)
     const char *b_string = (const char *)b;
 
     if (a == b) {
-        return true;  // same pointer, must match.
+        return true; // same pointer, must match.
     } else if (!a || !b) {
-        return false;  // one pointer is NULL (and first test shows they aren't the same pointer), must not match.
+        return false; // one pointer is NULL (and first test shows they aren't the same pointer), must not match.
     } else if (a_string[0] != b_string[0]) {
-        return false;  // we know they don't match
+        return false; // we know they don't match
     }
-    return (SDL_strcmp(a_string, b_string) == 0);  // Check against actual string contents.
+    return (SDL_strcmp(a_string, b_string) == 0); // Check against actual string contents.
 }
 
 // We assume we can fit the ID in the key directly