Blob


1 /*
2 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
3 *
4 * Permission to use, copy, modify, and distribute this software for any
5 * purpose with or without fee is hereby granted, provided that the above
6 * copyright notice and this permission notice appear in all copies.
7 *
8 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15 */
17 #include <sys/types.h>
18 #include <sys/queue.h>
19 #include <sys/socket.h>
20 #include <sys/uio.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <siphash.h>
25 #include <stdint.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <imsg.h>
30 #include <limits.h>
31 #include <sha1.h>
32 #include <sha2.h>
33 #include <signal.h>
34 #include <unistd.h>
36 #include "got_error.h"
37 #include "got_path.h"
39 #include "gotd.h"
40 #include "log.h"
41 #include "listen.h"
/*
 * Number of elements in a true array.  The argument must be an array,
 * not a pointer, or the result is meaningless.
 */
#ifndef nitems
#define nitems(_a)	(sizeof(_a) / sizeof(*(_a)))
#endif
47 struct gotd_listen_client {
48 STAILQ_ENTRY(gotd_listen_client) entry;
49 uint32_t id;
50 int fd;
51 uid_t euid;
52 };
53 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
55 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
56 static SIPHASH_KEY clients_hash_key;
57 static volatile int listen_client_cnt;
58 static int inflight;
60 struct gotd_uid_connection_counter {
61 STAILQ_ENTRY(gotd_uid_connection_counter) entry;
62 uid_t euid;
63 int nconnections;
64 };
65 STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
66 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
67 static SIPHASH_KEY uid_hash_key;
69 static struct {
70 pid_t pid;
71 const char *title;
72 int fd;
73 struct gotd_imsgev iev;
74 struct gotd_imsgev pause;
75 struct gotd_uid_connection_limit *connection_limits;
76 size_t nconnection_limits;
77 } gotd_listen;
79 static int inflight;
81 static void listen_shutdown(void);
/*
 * Signal callback registered through libevent's signal_set().
 *
 * libevent runs this from the event loop rather than from real signal
 * context, so the usual async-signal-safety restrictions do not apply.
 * SIGHUP and SIGUSR1 are accepted and ignored, SIGTERM and SIGINT shut
 * the process down, and anything else is fatal.
 */
static void
listen_sighdlr(int sig, short event, void *arg)
{
	if (sig == SIGHUP || sig == SIGUSR1)
		return;

	if (sig == SIGTERM || sig == SIGINT) {
		listen_shutdown();
		/* NOTREACHED */
		return;
	}

	fatalx("unexpected signal");
}
106 static uint64_t
107 client_hash(uint32_t client_id)
109 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
112 static void
113 add_client(struct gotd_listen_client *client)
115 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
116 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
117 listen_client_cnt++;
120 static struct gotd_listen_client *
121 find_client(uint32_t client_id)
123 uint64_t slot;
124 struct gotd_listen_client *c;
126 slot = client_hash(client_id) % nitems(gotd_listen_clients);
127 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
128 if (c->id == client_id)
129 return c;
132 return NULL;
/*
 * Pick a random client ID which is non-zero and not used by any
 * client currently in the table.
 */
static uint32_t
get_client_id(void)
{
	uint32_t candidate;

	for (;;) {
		candidate = arc4random();
		if (find_client(candidate) == NULL && candidate != 0)
			return candidate;
	}
}
149 static uint64_t
150 uid_hash(uid_t euid)
152 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
155 static void
156 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
158 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
159 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
162 static void
163 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
165 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
166 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
167 gotd_uid_connection_counter, entry);
170 static struct gotd_uid_connection_counter *
171 find_uid_connection_counter(uid_t euid)
173 uint64_t slot;
174 struct gotd_uid_connection_counter *c;
176 slot = uid_hash(euid) % nitems(gotd_client_uids);
177 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
178 if (c->euid == euid)
179 return c;
182 return NULL;
185 struct gotd_uid_connection_limit *
186 gotd_find_uid_connection_limit(struct gotd_uid_connection_limit *limits,
187 size_t nlimits, uid_t uid)
189 /* This array is always sorted to allow for binary search. */
190 int i, left = 0, right = nlimits - 1;
192 while (left <= right) {
193 i = ((left + right) / 2);
194 if (limits[i].uid == uid)
195 return &limits[i];
196 if (limits[i].uid > uid)
197 left = i + 1;
198 else
199 right = i - 1;
202 return NULL;
205 static const struct got_error *
206 disconnect(struct gotd_listen_client *client)
208 struct gotd_uid_connection_counter *counter;
209 uint64_t slot;
210 int client_fd;
212 log_debug("client on fd %d disconnecting", client->fd);
214 slot = client_hash(client->id) % nitems(gotd_listen_clients);
215 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
216 gotd_listen_client, entry);
218 counter = find_uid_connection_counter(client->euid);
219 if (counter) {
220 if (counter->nconnections > 0)
221 counter->nconnections--;
222 if (counter->nconnections == 0) {
223 remove_uid_connection_counter(counter);
224 free(counter);
228 client_fd = client->fd;
229 free(client);
230 inflight--;
231 listen_client_cnt--;
232 if (close(client_fd) == -1)
233 return got_error_from_errno("close");
235 return NULL;
238 static int
239 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
240 int reserve, volatile int *counter)
242 int ret;
244 if (getdtablecount() + reserve +
245 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
246 log_debug("inflight fds exceeded");
247 errno = EMFILE;
248 return -1;
251 if ((ret = accept4(fd, addr, addrlen,
252 SOCK_NONBLOCK | SOCK_CLOEXEC)) > -1) {
253 (*counter)++;
256 return ret;
259 static void
260 gotd_accept_paused(int fd, short event, void *arg)
262 event_add(&gotd_listen.iev.ev, NULL);
265 static void
266 gotd_accept(int fd, short event, void *arg)
268 struct gotd_imsgev *iev = arg;
269 struct sockaddr_storage ss;
270 struct timeval backoff;
271 socklen_t len;
272 int s = -1;
273 struct gotd_listen_client *client = NULL;
274 struct gotd_uid_connection_counter *counter = NULL;
275 struct gotd_imsg_connect iconn;
276 uid_t euid;
277 gid_t egid;
279 backoff.tv_sec = 1;
280 backoff.tv_usec = 0;
282 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
283 log_warn("event_add");
284 return;
286 if (event & EV_TIMEOUT)
287 return;
289 len = sizeof(ss);
291 /* Other backoff conditions apart from EMFILE/ENFILE? */
292 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
293 &inflight);
294 if (s == -1) {
295 switch (errno) {
296 case EINTR:
297 case EWOULDBLOCK:
298 case ECONNABORTED:
299 return;
300 case EMFILE:
301 case ENFILE:
302 event_del(&gotd_listen.iev.ev);
303 evtimer_add(&gotd_listen.pause.ev, &backoff);
304 return;
305 default:
306 log_warn("accept");
307 return;
311 if (listen_client_cnt >= GOTD_MAXCLIENTS)
312 goto err;
314 if (getpeereid(s, &euid, &egid) == -1) {
315 log_warn("getpeerid");
316 goto err;
319 counter = find_uid_connection_counter(euid);
320 if (counter == NULL) {
321 counter = calloc(1, sizeof(*counter));
322 if (counter == NULL) {
323 log_warn("%s: calloc", __func__);
324 goto err;
326 counter->euid = euid;
327 counter->nconnections = 1;
328 add_uid_connection_counter(counter);
329 } else {
330 int max_connections = GOTD_MAX_CONN_PER_UID;
331 struct gotd_uid_connection_limit *limit;
333 limit = gotd_find_uid_connection_limit(
334 gotd_listen.connection_limits,
335 gotd_listen.nconnection_limits, euid);
336 if (limit)
337 max_connections = limit->max_connections;
339 if (counter->nconnections >= max_connections) {
340 log_warnx("maximum connections exceeded for uid %d",
341 euid);
342 goto err;
344 counter->nconnections++;
347 client = calloc(1, sizeof(*client));
348 if (client == NULL) {
349 log_warn("%s: calloc", __func__);
350 goto err;
352 client->id = get_client_id();
353 client->fd = s;
354 client->euid = euid;
355 s = -1;
356 add_client(client);
357 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
358 client->fd, euid, egid);
360 memset(&iconn, 0, sizeof(iconn));
361 iconn.client_id = client->id;
362 iconn.euid = euid;
363 iconn.egid = egid;
364 s = dup(client->fd);
365 if (s == -1) {
366 log_warn("%s: dup", __func__);
367 goto err;
369 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
370 &iconn, sizeof(iconn)) == -1) {
371 log_warn("imsg compose CONNECT");
372 goto err;
375 return;
376 err:
377 inflight--;
378 if (client)
379 disconnect(client);
380 if (s != -1)
381 close(s);
384 static const struct got_error *
385 recv_disconnect(struct imsg *imsg)
387 struct gotd_imsg_disconnect idisconnect;
388 size_t datalen;
389 struct gotd_listen_client *client = NULL;
391 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
392 if (datalen != sizeof(idisconnect))
393 return got_error(GOT_ERR_PRIVSEP_LEN);
394 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
396 log_debug("client disconnecting");
398 client = find_client(idisconnect.client_id);
399 if (client == NULL)
400 return got_error(GOT_ERR_CLIENT_ID);
402 return disconnect(client);
405 static void
406 listen_dispatch(int fd, short event, void *arg)
408 const struct got_error *err = NULL;
409 struct gotd_imsgev *iev = arg;
410 struct imsgbuf *ibuf = &iev->ibuf;
411 struct imsg imsg;
412 ssize_t n;
413 int shut = 0;
415 if (event & EV_READ) {
416 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
417 fatal("imsg_read error");
418 if (n == 0) /* Connection closed. */
419 shut = 1;
422 if (event & EV_WRITE) {
423 n = msgbuf_write(&ibuf->w);
424 if (n == -1 && errno != EAGAIN)
425 fatal("msgbuf_write");
426 if (n == 0) /* Connection closed. */
427 shut = 1;
430 for (;;) {
431 if ((n = imsg_get(ibuf, &imsg)) == -1)
432 fatal("%s: imsg_get", __func__);
433 if (n == 0) /* No more messages. */
434 break;
436 switch (imsg.hdr.type) {
437 case GOTD_IMSG_DISCONNECT:
438 err = recv_disconnect(&imsg);
439 if (err)
440 log_warnx("disconnect: %s", err->msg);
441 break;
442 default:
443 log_debug("unexpected imsg %d", imsg.hdr.type);
444 break;
447 imsg_free(&imsg);
450 if (!shut) {
451 gotd_imsg_event_add(iev);
452 } else {
453 /* This pipe is dead. Remove its event handler */
454 event_del(&iev->ev);
455 event_loopexit(NULL);
459 void
460 listen_main(const char *title, int gotd_socket,
461 struct gotd_uid_connection_limit *connection_limits,
462 size_t nconnection_limits)
464 struct gotd_imsgev iev;
465 struct event evsigint, evsigterm, evsighup, evsigusr1;
467 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
468 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
470 gotd_listen.title = title;
471 gotd_listen.pid = getpid();
472 gotd_listen.fd = gotd_socket;
473 gotd_listen.connection_limits = connection_limits;
474 gotd_listen.nconnection_limits = nconnection_limits;
476 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
477 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
478 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
479 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
480 signal(SIGPIPE, SIG_IGN);
482 signal_add(&evsigint, NULL);
483 signal_add(&evsigterm, NULL);
484 signal_add(&evsighup, NULL);
485 signal_add(&evsigusr1, NULL);
487 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
488 iev.handler = listen_dispatch;
489 iev.events = EV_READ;
490 iev.handler_arg = NULL;
491 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
492 if (event_add(&iev.ev, NULL) == -1)
493 fatalx("event add");
495 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
496 gotd_accept, &iev);
497 if (event_add(&gotd_listen.iev.ev, NULL))
498 fatalx("event add");
499 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
501 event_dispatch();
503 listen_shutdown();
506 static void
507 listen_shutdown(void)
509 log_debug("shutting down");
511 free(gotd_listen.connection_limits);
512 if (gotd_listen.fd != -1)
513 close(gotd_listen.fd);
515 exit(0);