Commit Diff


commit - b1b2091b92cf99c8f0fe87488f2757f4d712e094
commit + 7a0564e3ba8d55d4f066d3ba0f35ff64fd6a8d60
blob - fe0c0107c18a0d38f7565091ea1d3badb537969e
blob + d17cc8a6c43063e47af1fb6b1ce517a12eb909c5
--- gotd/gotd.h
+++ gotd/gotd.h
@@ -23,6 +23,7 @@
 #define GOTD_EMPTY_PATH	"/var/empty"
 
 #define GOTD_MAXCLIENTS		1024
+#define GOTD_MAX_CONN_PER_UID	4	/* max concurrent connections per uid */
 #define GOTD_FD_RESERVE		5
 #define GOTD_FD_NEEDED		6
 #define GOTD_FILENO_MSG_PIPE	3
blob - 075110e34e7f07475aa0f1df07f74fac98521044
blob + bcc891f386e858648339da2b211be2df91530d4b
--- gotd/listen.c
+++ gotd/listen.c
@@ -46,6 +46,7 @@ struct gotd_listen_client {
 	STAILQ_ENTRY(gotd_listen_client)	 entry;
 	uint32_t			 id;
 	int				 fd;
+	uid_t				 euid;
 };
 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
 
@@ -54,6 +55,20 @@ static SIPHASH_KEY clients_hash_key;
 static volatile int listen_client_cnt;
 static int inflight;
 
+/*
+ * Number of open connections per client uid, used to enforce
+ * GOTD_MAX_CONN_PER_UID.  Counters live in a hash table keyed by uid
+ * and are freed once their count drops back to zero.
+ */
+struct gotd_uid_connection_counter {
+	STAILQ_ENTRY(gotd_uid_connection_counter) entry;
+	uid_t euid;
+	int nconnections;
+};
+STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
+static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
+static SIPHASH_KEY uid_hash_key;
+
 static struct {
 	pid_t pid;
 	const char *title;
@@ -130,11 +145,52 @@ get_client_id(void)
 	} while (duplicate || id == 0);
 
 	return id;
+}
+
+/*
+ * Helpers for the gotd_client_uids hash table.  A uid's slot is chosen
+ * by SipHash24 keyed with uid_hash_key, which listen_main() randomizes.
+ */
+static uint64_t
+uid_hash(uid_t euid)
+{
+	return SipHash24(&uid_hash_key, &euid, sizeof(euid));
+}
+
+static void
+add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
+{
+	uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
+	STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
+}
+
+static void
+remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
+{
+	uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
+	STAILQ_REMOVE(&gotd_client_uids[slot], counter,
+	    gotd_uid_connection_counter, entry);
+}
+
+static struct gotd_uid_connection_counter *
+find_uid_connection_counter(uid_t euid)
+{
+	uint64_t slot;
+	struct gotd_uid_connection_counter *c;
+
+	slot = uid_hash(euid) % nitems(gotd_client_uids);
+	STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
+		if (c->euid == euid)
+			return c;
+	}
+
+	return NULL;
 }
 
 static const struct got_error *
 disconnect(struct gotd_listen_client *client)
 {
+	struct gotd_uid_connection_counter *counter;
 	uint64_t slot;
 	int client_fd;
 
@@ -143,6 +199,18 @@ disconnect(struct gotd_listen_client *client)
 	slot = client_hash(client->id) % nitems(gotd_listen_clients);
 	STAILQ_REMOVE(&gotd_listen_clients[slot], client,
 	    gotd_listen_client, entry);
+
+	/* Drop this client's share of its uid's connection count. */
+	counter = find_uid_connection_counter(client->euid);
+	if (counter) {
+		if (counter->nconnections > 0)
+			counter->nconnections--;
+		if (counter->nconnections == 0) {
+			remove_uid_connection_counter(counter);
+			free(counter);
+		}
+	}
+
 	client_fd = client->fd;
 	free(client);
 	inflight--;
@@ -189,6 +257,7 @@ gotd_accept(int fd, short event, void *arg)
 	socklen_t len;
 	int s = -1;
 	struct gotd_listen_client *client = NULL;
+	struct gotd_uid_connection_counter *counter = NULL;
 	struct gotd_imsg_connect iconn;
 	uid_t euid;
 	gid_t egid;
@@ -233,13 +302,42 @@ gotd_accept(int fd, short event, void *arg)
 		goto err;
 	}
 
+	counter = find_uid_connection_counter(euid);
+	if (counter != NULL &&
+	    counter->nconnections >= GOTD_MAX_CONN_PER_UID) {
+		log_warnx("maximum connections exceeded for uid %d",
+		    euid);
+		goto err;
+	}
+
 	client = calloc(1, sizeof(*client));
 	if (client == NULL) {
 		log_warn("%s: calloc", __func__);
 		goto err;
 	}
+
+	/*
+	 * Charge the connection to its uid only now that the client has
+	 * been allocated, so that a failed allocation cannot leak a slot
+	 * and eventually lock the uid out.  The client is not in the
+	 * client table yet and must be released here if counting fails.
+	 */
+	if (counter == NULL) {
+		counter = calloc(1, sizeof(*counter));
+		if (counter == NULL) {
+			log_warn("%s: calloc", __func__);
+			free(client);
+			client = NULL;
+			goto err;
+		}
+		counter->euid = euid;
+		add_uid_connection_counter(counter);
+	}
+	counter->nconnections++;
+
 	client->id = get_client_id();
 	client->fd = s;
+	client->euid = euid;
 	s = -1;
 	add_client(client);
 	log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
@@ -353,6 +451,7 @@ listen_main(const char *title, int gotd_socket)
 	struct event evsigint, evsigterm, evsighup, evsigusr1;
 
 	arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
+	arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
 
 	gotd_listen.title = title;
 	gotd_listen.pid = getpid();