From 6e8719ab2f646758c945200493d970a19da959bd Mon Sep 17 00:00:00 2001
From: "Ryan C. Gordon" <[EMAIL REDACTED]>
Date: Mon, 11 Sep 2023 01:28:06 -0400
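Subject: [PATCH] Changed Wait functions.
Now they take a timeout, and there's a generic "wait for a list of sockets"
option, SDLNet_WaitUntilInputAvailable(), to sleep the calling thread until
new incoming data is available. The WaitFor* functions are renamed to
WaitUntil*, and SDLNet_WaitForClientConnection() is removed in favor of the
generic wait.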
---
SDL_net.c | 274 ++++++++++++++++++++++++++++++++------------
SDL_net.h | 9 +-
examples/voipchat.c | 2 +-
3 files changed, 208 insertions(+), 77 deletions(-)
diff --git a/SDL_net.c b/SDL_net.c
index 9e43f54..2081efa 100644
--- a/SDL_net.c
+++ b/SDL_net.c
@@ -43,6 +43,14 @@ typedef socklen_t SockLen;
typedef struct sockaddr_storage AddressStorage;
#endif
+typedef enum SDLNet_SocketType // tag stored as the first field of every socket struct so SDLNet_WaitUntilInputAvailable can tell them apart.
+{
+ SOCKETTYPE_STREAM, // SDLNet_StreamSocket: made by SDLNet_CreateClient or SDLNet_AcceptClient.
+ SOCKETTYPE_DATAGRAM, // SDLNet_DatagramSocket: made by SDLNet_CreateDatagramSocket.
+ SOCKETTYPE_SERVER // SDLNet_Server: the listen socket made by SDLNet_CreateServer.
+} SDLNet_SocketType;
+
+
const SDL_version *SDLNet_Linked_Version(void)
{
static const SDL_version linked_version = {
@@ -360,18 +368,32 @@ SDLNet_Address *SDLNet_ResolveHostname(const char *host)
return addr;
}
-int SDLNet_WaitForResolution(SDLNet_Address *addr)
+int SDLNet_WaitUntilResolved(SDLNet_Address *addr, Sint32 timeout) // timeout: 0 = don't block, < 0 = no limit, > 0 = milliseconds. Returns SDLNet_GetAddressStatus(addr).
{
if (!addr) {
return SDL_InvalidParamError("address"); // obviously nothing to wait for.
}
// we _could_ use a different lock for this, but this is Good Enough.
- SDL_LockMutex(resolver_lock);
- while (SDL_AtomicGet(&addr->status) == 0) {
- SDL_WaitCondition(resolver_condition, resolver_lock);
+
+ if (timeout) { // timeout == 0: skip waiting and just report the current status below.
+ SDL_LockMutex(resolver_lock); // locked exactly once for both branches; a second lock here would never be released.
+ if (timeout < 0) {
+ while (SDL_AtomicGet(&addr->status) == 0) {
+ SDL_WaitCondition(resolver_condition, resolver_lock);
+ }
+ } else {
+ const Uint64 endtime = (SDL_GetTicks() + timeout);
+ while (SDL_AtomicGet(&addr->status) == 0) {
+ const Uint64 now = SDL_GetTicks();
+ if (now >= endtime) {
+ break; // timed out; status stays 0 (still working).
+ }
+ SDL_WaitConditionTimeout(resolver_condition, resolver_lock, (Sint32) (endtime - now)); // may wake early; the loop rechecks status and time.
+ }
+ }
+ SDL_UnlockMutex(resolver_lock);
}
- SDL_UnlockMutex(resolver_lock);
return SDLNet_GetAddressStatus(addr); // so we set the error string if necessary.
}
@@ -495,6 +518,7 @@ static struct addrinfo *MakeAddrInfoWithPort(const SDLNet_Address *addr, const i
struct SDLNet_StreamSocket
{
+ SDLNet_SocketType socktype; // keep as the first member: SDLNet_WaitUntilInputAvailable reads it through a union of all socket types.
SDLNet_Address *addr;
Uint16 port;
Socket handle;
@@ -541,6 +565,7 @@ SDLNet_StreamSocket *SDLNet_CreateClient(SDLNet_Address *addr, Uint16 port)
return NULL;
}
+ sock->socktype = SOCKETTYPE_STREAM; // outgoing client connection.
sock->addr = addr;
sock->port = port;
@@ -591,33 +616,21 @@ static int CheckClientConnection(SDLNet_StreamSocket *sock, int timeoutms)
if (!sock) {
return SDL_InvalidParamError("sock");
} else if (sock->status == 0) { // still pending?
- /*!!! FIXME: add this later? if (sock->simulated_failure_ticks) {
+ /*!!! FIXME: add this later?
+ if (sock->simulated_failure_ticks) {
if (SDL_GetTicks() >= sock->simulated_failure_ticks) {
sock->status = SDL_SetError("simulated failure");
- }
- } else*/ {
- struct pollfd pfd;
- SDL_zero(pfd);
- pfd.fd = sock->handle;
- pfd.events = POLLOUT;
- if (poll(&pfd, 1, timeoutms) == SOCKET_ERROR) {
- sock->status = SDL_SetError("Failed to poll socket: %s", strerror(errno));
- } else if ((pfd.revents & (POLLERR|POLLHUP|POLLNVAL)) != 0) {
- int err = 0;
- SockLen errsize = sizeof (err);
- getsockopt(sock->handle, SOL_SOCKET, SO_ERROR, (char*)&err, &errsize);
- sock->status = SDL_SetError("Socket failed to connect: %s", strerror(err));
- } else if (pfd.revents & POLLOUT) {
- sock->status = 1; // good to go!
- }
+ } else */
+ if (SDLNet_WaitUntilInputAvailable((void **) &sock, 1, timeoutms) == -1) { // while status == 0 this polls for connect completion and updates sock->status itself.
+ sock->status = -1; // the wait itself failed (error string already set), so give up on this connection attempt.
}
}
return sock->status;
}
-int SDLNet_WaitForConnection(SDLNet_StreamSocket *sock)
+int SDLNet_WaitUntilConnected(SDLNet_StreamSocket *sock, Sint32 timeout) // returns sock's status: 0 still connecting, 1 connected, -1 failed.
{
- return CheckClientConnection(sock, -1); // infinite wait
+ return CheckClientConnection(sock, (int) timeout); // timeout reaches poll(): 0 checks once without blocking, negative has no time limit.
}
int SDLNet_GetConnectionStatus(SDLNet_StreamSocket *sock)
@@ -628,6 +641,7 @@ int SDLNet_GetConnectionStatus(SDLNet_StreamSocket *sock)
struct SDLNet_Server
{
+ SDLNet_SocketType socktype; // keep as the first member: SDLNet_WaitUntilInputAvailable reads it through a union of all socket types.
SDLNet_Address *addr; // bound to this address (NULL for any).
Uint16 port;
Socket handle;
@@ -646,6 +660,7 @@ SDLNet_Server *SDLNet_CreateServer(SDLNet_Address *addr, Uint16 port)
return NULL;
}
+ server->socktype = SOCKETTYPE_SERVER; // listen socket; becomes readable when a client can be accepted.
server->addr = addr;
server->port = port;
@@ -697,31 +712,6 @@ SDLNet_Server *SDLNet_CreateServer(SDLNet_Address *addr, Uint16 port)
return server;
}
-int SDLNet_WaitForClientConnection(SDLNet_Server *server)
-{
- if (!server) {
- return SDL_InvalidParamError("server");
- }
-
- struct pollfd pfd;
- SDL_zero(pfd);
- pfd.fd = server->handle;
- pfd.events = POLLIN;
- if (poll(&pfd, 1, -1) == SOCKET_ERROR) {
- return SDL_SetError("Failed to poll listen socket: %s", strerror(errno));
- } else if ((pfd.revents & (POLLERR|POLLHUP|POLLNVAL)) != 0) {
- int err = 0;
- SockLen errsize = sizeof (err);
- getsockopt(server->handle, SOL_SOCKET, SO_ERROR, (char*)&err, &errsize);
- return SDL_SetError("Listen socket has failed: %s", strerror(err));
- } else if (pfd.revents & POLLIN) {
- return 0;
- }
-
- SDL_assert(!"This shouldn't happen, right...?");
- return 0; // just pretend it's time to check again.
-}
-
int SDLNet_AcceptClient(SDLNet_Server *server, SDLNet_StreamSocket **client_stream)
{
if (!client_stream) {
@@ -794,6 +784,7 @@ int SDLNet_AcceptClient(SDLNet_Server *server, SDLNet_StreamSocket **client_stre
return SDL_OutOfMemory();
}
+ sock->socktype = SOCKETTYPE_STREAM; // an accepted connection is an ordinary stream socket.
sock->addr = fromaddr;
sock->port = (Uint16) SDL_atoi(portbuf);
sock->handle = handle;
@@ -919,32 +910,43 @@ int SDLNet_GetStreamSocketPendingWrites(SDLNet_StreamSocket *sock)
return sock->pending_output_len;
}
-int SDLNet_WaitForStreamPendingWrites(SDLNet_StreamSocket *sock)
+int SDLNet_WaitUntilStreamSocketDrained(SDLNet_StreamSocket *sock, int timeoutms) // must match the SDL_net.h name. Returns 0 once drained, -1 on error, else bytes still pending at timeout.
{
if (!sock) {
return SDL_InvalidParamError("sock");
}
- while (sock->pending_output_len > 0) {
- struct pollfd pfd;
- SDL_zero(pfd);
- pfd.fd = sock->handle;
- pfd.events = POLLOUT;
- if (poll(&pfd, 1, -1) == SOCKET_ERROR) {
- return SDL_SetError("Failed to poll socket: %s", strerror(errno));
- } else if ((pfd.revents & (POLLERR|POLLHUP|POLLNVAL)) != 0) {
- int err = 0;
- SockLen errsize = sizeof (err);
- getsockopt(sock->handle, SOL_SOCKET, SO_ERROR, (char*)&err, &errsize);
- return SDL_SetError("Socket has failed: %s", strerror(err));
- } else if (pfd.revents & POLLOUT) {
- if (PumpStreamSocket(sock) < 0) {
- return -1;
+ if (timeoutms != 0) { // timeoutms == 0: don't wait, just report what is still pending.
+ const Uint64 endtime = (timeoutms > 0) ? (SDL_GetTicks() + timeoutms) : 0;
+ while (SDLNet_GetStreamSocketPendingWrites(sock) > 0) { // stops when drained (0) or on a dead socket (-1).
+ struct pollfd pfd;
+ SDL_zero(pfd);
+ pfd.fd = sock->handle;
+ pfd.events = POLLOUT;
+ const int rc = poll(&pfd, 1, timeoutms);
+ if (rc == SOCKET_ERROR) {
+ return SDL_SetError("Socket poll failed: %s", strerror(errno));
+ } else if (rc == 0) {
+ break; // timed out
+ } else if ((pfd.revents & (POLLERR|POLLHUP|POLLNVAL)) != 0) {
+ int err = 0;
+ SockLen errsize = sizeof (err);
+ getsockopt(sock->handle, SOL_SOCKET, SO_ERROR, (char*)&err, &errsize);
+ return SDL_SetError("Socket has failed: %s", strerror(err)); // report now: POLLERR/POLLHUP make poll() return immediately, so looping could spin.
}
+
+ if (timeoutms > 0) { // We must have woken up for a pending write, etc. Figure out remaining wait time.
+ const Uint64 now = SDL_GetTicks();
+ if (now < endtime) {
+ timeoutms = (int) (endtime - now);
+ } else {
+ break; // time has expired, break out.
+ }
+ } // else timeout is meant to be infinite, but we woke up for a write, etc, so go back to an infinite poll until we fail or buffer is drained.
}
}
- return 0;
+ return SDLNet_GetStreamSocketPendingWrites(sock);
}
int SDLNet_ReadStreamSocket(SDLNet_StreamSocket *sock, void *buf, int buflen)
@@ -1010,6 +1007,7 @@ void SDLNet_DestroyStreamSocket(SDLNet_StreamSocket *sock)
struct SDLNet_DatagramSocket
{
+ SDLNet_SocketType socktype; // keep as the first member: SDLNet_WaitUntilInputAvailable reads it through a union of all socket types.
SDLNet_Address *addr; // bound to this address (NULL for any).
Uint16 port;
Socket handle;
@@ -1036,6 +1034,7 @@ SDLNet_DatagramSocket *SDLNet_CreateDatagramSocket(SDLNet_Address *addr, Uint16
return NULL;
}
+ sock->socktype = SOCKETTYPE_DATAGRAM; // datagram socket.
sock->addr = addr;
sock->port = port;
@@ -1336,6 +1335,153 @@ void SDLNet_DestroyDatagramSocket(SDLNet_DatagramSocket *sock)
}
}
+typedef union SDLNet_GenericSocket
+{
+ SDLNet_SocketType socktype; // every socket struct begins with this field, so it can be read before knowing the real type.
+ SDLNet_StreamSocket stream;
+ SDLNet_DatagramSocket dgram;
+ SDLNet_Server server;
+} SDLNet_GenericSocket;
+
+
+// Sleep until at least one socket in `vsockets` has something the caller should handle:
+// new input, a client waiting to be accepted, a finished connection attempt, or a failure.
+// Pending output is pumped while waiting. timeoutms: 0 polls once, < 0 waits with no limit,
+// > 0 waits up to that many milliseconds. Returns the number of such sockets, 0 on timeout, -1 on error.
+int SDLNet_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms)
+{
+ SDLNet_GenericSocket **sockets = (SDLNet_GenericSocket **) vsockets;
+ if (!sockets) {
+ return SDL_InvalidParamError("sockets");
+ } else if (numsockets < 0) {
+ return SDL_InvalidParamError("numsockets");
+ } else if (numsockets == 0) {
+ return 0;
+ }
+
+ for (int i = 0; i < numsockets; i++) {
+ if (!sockets[i]) {
+ return SDL_InvalidParamError("sockets"); // a NULL entry would be dereferenced below.
+ }
+ }
+
+ struct pollfd stack_pfds[32];
+ struct pollfd *pfds = stack_pfds;
+ struct pollfd *malloced_pfds = NULL;
+
+ if (numsockets > (int) SDL_arraysize(stack_pfds)) { // allocate if there's a _ton_ of these.
+ malloced_pfds = (struct pollfd *) SDL_malloc(numsockets * sizeof (*pfds));
+ if (!malloced_pfds) {
+ return SDL_OutOfMemory();
+ }
+ pfds = malloced_pfds;
+ }
+
+ int retval = 0;
+ const Uint64 endtime = (timeoutms > 0) ? (SDL_GetTicks() + timeoutms) : 0;
+
+ while (SDL_TRUE) {
+ SDL_memset(pfds, '\0', sizeof (*pfds) * numsockets);
+
+ for (int i = 0; i < numsockets; i++) {
+ SDLNet_GenericSocket *sock = sockets[i];
+ struct pollfd *pfd = &pfds[i];
+
+ switch (sock->socktype) {
+ case SOCKETTYPE_STREAM:
+ pfd->fd = sock->stream.handle;
+ if (sock->stream.status == 0) {
+ pfd->events = POLLOUT; // marked as writable when connection is complete.
+ } else if (sock->stream.pending_output_len > 0) {
+ pfd->events = POLLIN|POLLOUT; // poll for input or when we can write more of the pending buffer.
+ } else {
+ pfd->events = POLLIN; // nothing queued to send, so only poll for input.
+ }
+ break;
+
+ case SOCKETTYPE_DATAGRAM:
+ pfd->fd = sock->dgram.handle;
+ if (sock->dgram.pending_output_len > 0) {
+ pfd->events = POLLIN|POLLOUT; // poll for input or when we can write more of the pending buffer.
+ } else {
+ pfd->events = POLLIN; // nothing queued to send, so only poll for input.
+ }
+ break;
+
+ case SOCKETTYPE_SERVER:
+ pfd->fd = sock->server.handle;
+ pfd->events = POLLIN; // poll for new connections.
+ break;
+ }
+ }
+
+ const int rc = poll(pfds, numsockets, timeoutms);
+
+ if (rc == SOCKET_ERROR) {
+ SDL_free(malloced_pfds);
+ return SDL_SetError("Socket poll failed: %s", strerror(errno));
+ }
+
+ for (int i = 0; i < numsockets; i++) {
+ SDLNet_GenericSocket *sock = sockets[i];
+ const struct pollfd *pfd = &pfds[i];
+ const SDL_bool failed = ((pfd->revents & (POLLERR|POLLHUP|POLLNVAL)) != 0) ? SDL_TRUE : SDL_FALSE;
+ const SDL_bool writable = (pfd->revents & POLLOUT) ? SDL_TRUE : SDL_FALSE;
+ const SDL_bool readable = (pfd->revents & POLLIN) ? SDL_TRUE : SDL_FALSE;
+ SDL_bool ready = (readable || failed) ? SDL_TRUE : SDL_FALSE;
+
+ switch (sock->socktype) {
+ case SOCKETTYPE_STREAM:
+ if (sock->stream.status == 0) {
+ if (failed) {
+ int err = 0;
+ SockLen errsize = sizeof (err);
+ getsockopt(pfd->fd, SOL_SOCKET, SO_ERROR, (char*)&err, &errsize);
+ sock->stream.status = SDL_SetError("Socket failed to connect: %s", strerror(err));
+ } else if (writable) {
+ sock->stream.status = 1;
+ ready = SDL_TRUE; // connection finished: report it, or connect waits would keep sleeping until input arrives.
+ }
+ } else if (writable) {
+ if (PumpStreamSocket(&sock->stream) < 0) {
+ ready = SDL_TRUE; // flushing failed; let the caller discover the dead socket.
+ }
+ }
+ break;
+
+ case SOCKETTYPE_DATAGRAM:
+ if (writable) {
+ PumpDatagramSocket(&sock->dgram);
+ }
+ break;
+
+ case SOCKETTYPE_SERVER:
+ // `readable` (a client is waiting to be accepted) is already folded into `ready`.
+ break;
+ }
+
+ if (ready) {
+ retval++;
+ }
+ }
+
+ if ((retval > 0) || (timeoutms == 0)) {
+ break; // something needs attention, or we are doing a no-block poll.
+ } else if (timeoutms > 0) { // We must have woken up for a pending write, etc. Figure out remaining wait time.
+ const Uint64 now = SDL_GetTicks();
+ if (now < endtime) {
+ timeoutms = (int) (endtime - now);
+ } else {
+ break; // time has expired, break out.
+ }
+ } // else timeout is infinite, but we only woke up to flush a pending write, so go back to an infinite poll.
+ }
+
+ SDL_free(malloced_pfds);
+
+ return retval;
+}
+
#if 0 // some test code.
#include <stdio.h>
@@ -1351,13 +1481,13 @@ int main(int argc, char **argv)
SDLNet_StreamSocket *stream;
if (argc > 1) {
SDLNet_Server *server = SDLNet_CreateServer(NULL, 7997);
- SDLNet_WaitForClientConnection(server);
+ SDLNet_WaitUntilInputAvailable((void **) &server, 1, -1); // a listen socket polls readable when a client is waiting to be accepted.
SDLNet_AcceptClient(server, &stream);
} else {
SDLNet_Address *addr = SDLNet_ResolveHostname("localhost");
- SDLNet_WaitForResolution(addr);
+ SDLNet_WaitUntilResolved(addr, -1);
stream = SDLNet_CreateClient(addr, 7997);
- SDLNet_WaitForConnection(stream);
+ SDLNet_WaitUntilConnected(stream, -1);
}
printf("\n\nConnected!\n\n");
@@ -1382,7 +1512,7 @@ int main(int argc, char **argv)
for (int i = 1; i < argc; i++) {
SDL_Log("Looking up %s ...", argv[i]);
SDLNet_Address *addr = SDLNet_ResolveHostname(argv[i]);
- if (SDLNet_WaitForResolution(addr) == -1) {
+ if (SDLNet_WaitUntilResolved(addr, -1) == -1) { // -1: block until the lookup succeeds or fails.
SDL_Log("Failed to lookup %s: %s", argv[i], SDL_GetError());
} else {
SDL_Log("%s is %s", argv[i], SDLNet_GetAddressString(addr));
@@ -1390,7 +1520,7 @@ int main(int argc, char **argv)
SDLNet_StreamSocket *sock = SDLNet_CreateClient(addr, 80);
if (!sock) {
SDL_Log("Failed to create stream socket to %s: %s\n", argv[i], SDL_GetError());
- } else if (SDLNet_WaitForConnection(sock) < 0) {
+ } else if (SDLNet_WaitUntilConnected(sock, -1) < 0) { // -1: no time limit on the connection attempt.
SDL_Log("Failed to connect to %s: %s", argv[i], SDL_GetError());
} else if (SDLNet_WriteToStreamSocket(sock, req, SDL_strlen(req)) < 0) {
SDL_Log("Failed to write to %s: %s", argv[i], SDL_GetError());
@@ -1423,7 +1553,7 @@ int main(int argc, char **argv)
}
for (int i = 1; i < argc; i++) {
- SDLNet_WaitForResolution(addrs[i]);
+ SDLNet_WaitUntilResolved(addrs[i], -1); // block until each lookup succeeds or fails.
}
#endif
diff --git a/SDL_net.h b/SDL_net.h
index dfa4c5f..ed488ce 100644
--- a/SDL_net.h
+++ b/SDL_net.h
@@ -26,7 +26,7 @@ extern DECLSPEC void SDLCALL SDLNet_Quit(void);
typedef struct SDLNet_Address SDLNet_Address;
extern DECLSPEC SDLNet_Address * SDLCALL SDLNet_ResolveHostname(const char *host); /* does not block! */
-extern DECLSPEC int SDLCALL SDLNet_WaitForResolution(SDLNet_Address *address); /* blocks until success or failure. Optional. */
+extern DECLSPEC int SDLCALL SDLNet_WaitUntilResolved(SDLNet_Address *address, Sint32 timeout); /* blocks until success or failure. Optional. timeout: 0: check once and don't block, -1: block until there's a definite answer, else: block for `timeout` milliseconds. Returns the same value as SDLNet_GetAddressStatus. */
extern DECLSPEC int SDLCALL SDLNet_GetAddressStatus(SDLNet_Address *address); /* 0: still working, -1: failed (check SDL_GetError), 1: ready */
extern DECLSPEC const char * SDLCALL SDLNet_GetAddressString(SDLNet_Address *address); /* human-readable string, like "127.0.0.1" or "::1" or whatever. NULL if GetAddressStatus != 1. String owned by address! */
extern DECLSPEC SDLNet_Address *SDLCALL SDLNet_RefAddress(SDLNet_Address *address); /* +1 refcount; SDLNet_ResolveHost starts at 1. Returns `address` for convenience. */
@@ -42,12 +42,11 @@ typedef struct SDLNet_StreamSocket SDLNet_StreamSocket; /* a TCP socket. Reliab
/* Clients connect to servers, and then send/receive data on a stream socket. */
extern DECLSPEC SDLNet_StreamSocket * SDLCALL SDLNet_CreateClient(SDLNet_Address *address, Uint16 port); /* Start connection to address:port. does not block! */
-extern DECLSPEC int SDLCALL SDLNet_WaitForConnection(SDLNet_StreamSocket *sock); /* blocks until success or failure. Optional. */
+extern DECLSPEC int SDLCALL SDLNet_WaitUntilConnected(SDLNet_StreamSocket *sock, Sint32 timeout); /* blocks until success or failure. Optional. timeout: 0: check once and don't block, -1: block until there's a definite answer, else: block for `timeout` milliseconds. Returns 1 when connected, 0 if still connecting when the wait ends, -1 on failure (check SDL_GetError). */
/* Servers listen for and accept connections from clients, and then send/receive data on a stream socket. */
typedef struct SDLNet_Server SDLNet_Server; /* a listen socket internally. Binds to a port, accepts connections. */
extern DECLSPEC SDLNet_Server * SDLCALL SDLNet_CreateServer(SDLNet_Address *addr, Uint16 port); /* Specify NULL for any/all interfaces, or something from GetLocalAddresses */
-extern DECLSPEC int SDLCALL SDLNet_WaitForClientConnection(SDLNet_Server *server); /* blocks until a client is ready for to be accepted or there's a serious error. Optional. */
extern DECLSPEC int SDLCALL SDLNet_AcceptClient(SDLNet_Server *server, SDLNet_StreamSocket **client_stream); /* Accept pending connection. Does not block, returns 0 and sets *client_stream=NULL if none available. -1 on errors, zero otherwise. */
extern DECLSPEC void SDLCALL SDLNet_DestroyServer(SDLNet_Server *server);
@@ -56,7 +55,7 @@ extern DECLSPEC SDLNet_Address * SDLCALL SDLNet_GetStreamSocketAddress(SDLNet_St
-extern DECLSPEC int SDLCALL SDLNet_GetConnectionStatus(SDLNet_StreamSocket *sock); /* -1: connecting, 0: failed/dropped (check SDL_GetError), 1: okay */
+extern DECLSPEC int SDLCALL SDLNet_GetConnectionStatus(SDLNet_StreamSocket *sock); /* 0: still connecting, -1: failed/dropped (check SDL_GetError), 1: okay */
extern DECLSPEC int SDLCALL SDLNet_WriteToStreamSocket(SDLNet_StreamSocket *sock, const void *buf, int buflen); /* always queues what it can't send immediately. Does not block, -1 on out of memory, dead socket, etc */
extern DECLSPEC int SDLCALL SDLNet_GetStreamSocketPendingWrites(SDLNet_StreamSocket *sock); /* returns number of bytes still pending to write, or -1 on dead socket, etc. 0 if no data pending to send. */
-extern DECLSPEC int SDLCALL SDLNet_WaitForStreamPendingWrites(SDLNet_StreamSocket *sock); /* blocks until all pending data is sent. returns 0 on success, -1 on dead socket, etc. Optional. */
+extern DECLSPEC int SDLCALL SDLNet_WaitUntilStreamSocketDrained(SDLNet_StreamSocket *sock, int timeoutms); /* blocks until all pending data is sent. returns 0 on success, -1 on dead socket, bytes remaining to send on timeout. Optional. timeout: 0: check once and don't block, -1: block until all sent or error, else: block for `timeout` milliseconds. */
extern DECLSPEC int SDLCALL SDLNet_ReadStreamSocket(SDLNet_StreamSocket *sock, void *buf, int buflen); /* read up to buflen bytes. Does not block, -1 on dead socket, etc, 0 if no data available. */
extern DECLSPEC void SDLCALL SDLNet_SimulateStreamPacketLoss(SDLNet_StreamSocket *sock, int percent_loss); /* since streams are reliable, this holds back data and connections for some amount of time, and maybe even drops connections. */
extern DECLSPEC void SDLCALL SDLNet_DestroyStreamSocket(SDLNet_StreamSocket *sock); /* Destroy your sockets when finished with them. Does not block, handles shutdown internally. */
@@ -83,3 +82,5 @@ extern DECLSPEC void SDLCALL SDLNet_SimulateDatagramPacketLoss(SDLNet_DatagramSo
extern DECLSPEC void SDLCALL SDLNet_DestroyDatagramSocket(SDLNet_DatagramSocket *sock); /* Destroy your sockets when finished with them. Does not block. */
+extern DECLSPEC int SDLCALL SDLNet_WaitUntilInputAvailable(void **vsockets, int numsockets, int timeoutms); /* put thread to sleep until one of sockets has new input. Optional. vsockets is an array of `numsockets` pointers, each an SDLNet_StreamSocket*, SDLNet_DatagramSocket* or SDLNet_Server* (types may be mixed). Pending writes are pumped while waiting. Returns > 0 if something is ready, -1 on error, 0 on timeout. timeout: 0: check once and don't block, -1: block until there's a definite answer, else: block for `timeout` milliseconds. */
+
diff --git a/examples/voipchat.c b/examples/voipchat.c
index 7a2d7fc..3c18855 100644
--- a/examples/voipchat.c
+++ b/examples/voipchat.c
@@ -282,7 +282,7 @@ static void run_voipchat(int argc, char **argv)
SDL_Log("CLIENT: Resolving server hostname '%s' ...", hostname);
server_addr = SDLNet_ResolveHostname(hostname);
if (server_addr) {
- if (SDLNet_WaitForResolution(server_addr) < 0) {
+ if (SDLNet_WaitUntilResolved(server_addr, -1) < 0) { // -1: block until the lookup succeeds or fails.
SDLNet_UnrefAddress(server_addr);
server_addr = NULL;
}