commit 7a0564e3ba8d55d4f066d3ba0f35ff64fd6a8d60
from: Stefan Sperling
date: Fri Dec 30 22:30:18 2022 UTC

enforce a per-uid connection limit in the gotd listen process

For now the limit is set at compile-time. It will become
configurable via gotd.conf soon.

ok op@

commit - b1b2091b92cf99c8f0fe87488f2757f4d712e094
commit + 7a0564e3ba8d55d4f066d3ba0f35ff64fd6a8d60
blob - fe0c0107c18a0d38f7565091ea1d3badb537969e
blob + d17cc8a6c43063e47af1fb6b1ce517a12eb909c5
--- gotd/gotd.h
+++ gotd/gotd.h
@@ -23,6 +23,7 @@
 #define GOTD_EMPTY_PATH "/var/empty"
 
 #define GOTD_MAXCLIENTS 1024
+#define GOTD_MAX_CONN_PER_UID 4
 #define GOTD_FD_RESERVE 5
 #define GOTD_FD_NEEDED 6
 #define GOTD_FILENO_MSG_PIPE 3
blob - 075110e34e7f07475aa0f1df07f74fac98521044
blob + bcc891f386e858648339da2b211be2df91530d4b
--- gotd/listen.c
+++ gotd/listen.c
@@ -46,6 +46,7 @@ struct gotd_listen_client {
 	STAILQ_ENTRY(gotd_listen_client) entry;
 	uint32_t id;
 	int fd;
+	uid_t euid;
 };
 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
 
@@ -54,6 +55,15 @@ static SIPHASH_KEY clients_hash_key;
 static volatile int listen_client_cnt;
 static int inflight;
 
+struct gotd_uid_connection_counter {
+	STAILQ_ENTRY(gotd_uid_connection_counter) entry;
+	uid_t euid;
+	int nconnections;
+};
+STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
+static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
+static SIPHASH_KEY uid_hash_key;
+
 static struct {
 	pid_t pid;
 	const char *title;
@@ -130,11 +140,48 @@ get_client_id(void)
 	} while (duplicate || id == 0);
 
 	return id;
+}
+
+static uint64_t
+uid_hash(uid_t euid)
+{
+	return SipHash24(&uid_hash_key, &euid, sizeof(euid));
+}
+
+static void
+add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
+{
+	uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
+	STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
+}
+
+static void
+remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
+{
+	uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
+	STAILQ_REMOVE(&gotd_client_uids[slot], counter,
+	    gotd_uid_connection_counter, entry);
+}
+
+static struct gotd_uid_connection_counter *
+find_uid_connection_counter(uid_t euid)
+{
+	uint64_t slot;
+	struct gotd_uid_connection_counter *c;
+
+	slot = uid_hash(euid) % nitems(gotd_client_uids);
+	STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
+		if (c->euid == euid)
+			return c;
+	}
+
+	return NULL;
 }
 
 static const struct got_error *
 disconnect(struct gotd_listen_client *client)
 {
+	struct gotd_uid_connection_counter *counter;
 	uint64_t slot;
 	int client_fd;
 
@@ -143,6 +190,17 @@ disconnect(struct gotd_listen_client *client)
 	slot = client_hash(client->id) % nitems(gotd_listen_clients);
 	STAILQ_REMOVE(&gotd_listen_clients[slot], client, gotd_listen_client,
 	    entry);
+
+	counter = find_uid_connection_counter(client->euid);
+	if (counter) {
+		if (counter->nconnections > 0)
+			counter->nconnections--;
+		if (counter->nconnections == 0) {
+			remove_uid_connection_counter(counter);
+			free(counter);
+		}
+	}
+
 	client_fd = client->fd;
 	free(client);
 	inflight--;
@@ -189,6 +247,7 @@ gotd_accept(int fd, short event, void *arg)
 	socklen_t len;
 	int s = -1;
 	struct gotd_listen_client *client = NULL;
+	struct gotd_uid_connection_counter *counter = NULL;
 	struct gotd_imsg_connect iconn;
 	uid_t euid;
 	gid_t egid;
@@ -233,13 +292,38 @@ gotd_accept(int fd, short event, void *arg)
 		goto err;
 	}
 
+	counter = find_uid_connection_counter(euid);
+	if (counter == NULL) {
+		counter = calloc(1, sizeof(*counter));
+		if (counter == NULL) {
+			log_warn("%s: calloc", __func__);
+			goto err;
+		}
+		counter->euid = euid;
+		counter->nconnections = 1;
+		add_uid_connection_counter(counter);
+	} else {
+		if (counter->nconnections >= GOTD_MAX_CONN_PER_UID) {
+			log_warnx("maximum connections exceeded for uid %d",
+			    euid);
+			goto err;
+		}
+		counter->nconnections++;
+	}
+
 	client = calloc(1, sizeof(*client));
 	if (client == NULL) {
 		log_warn("%s: calloc", __func__);
+		/* Undo the per-uid connection count taken above. */
+		if (--counter->nconnections == 0) {
+			remove_uid_connection_counter(counter);
+			free(counter);
+		}
 		goto err;
 	}
 	client->id = get_client_id();
 	client->fd = s;
+	client->euid = euid;
 	s = -1;
 	add_client(client);
 	log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
@@ -353,6 +437,7 @@ listen_main(const char *title, int gotd_socket)
 	struct event evsigint, evsigterm, evsighup, evsigusr1;
 
 	arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
+	arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
 
 	gotd_listen.title = title;
 	gotd_listen.pid = getpid();