Blob


/*
 * Copyright (c) 2022 Stefan Sperling <stsp@openbsd.org>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
17 #include <sys/types.h>
18 #include <sys/queue.h>
19 #include <sys/socket.h>
20 #include <sys/uio.h>
22 #include <errno.h>
23 #include <event.h>
24 #include <siphash.h>
25 #include <stdint.h>
26 #include <stdio.h>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <imsg.h>
30 #include <limits.h>
31 #include <signal.h>
32 #include <unistd.h>
34 #include "got_error.h"
35 #include "got_path.h"
37 #include "gotd.h"
38 #include "log.h"
39 #include "listen.h"
#if !defined(nitems)
/* Element count of a true array object; does not work on pointers. */
#define nitems(arr) (sizeof(arr) / sizeof((arr)[0]))
#endif
45 struct gotd_listen_client {
46 STAILQ_ENTRY(gotd_listen_client) entry;
47 uint32_t id;
48 int fd;
49 uid_t euid;
50 };
51 STAILQ_HEAD(gotd_listen_clients, gotd_listen_client);
53 static struct gotd_listen_clients gotd_listen_clients[GOTD_CLIENT_TABLE_SIZE];
54 static SIPHASH_KEY clients_hash_key;
55 static volatile int listen_client_cnt;
56 static int inflight;
58 struct gotd_uid_connection_counter {
59 STAILQ_ENTRY(gotd_uid_connection_counter) entry;
60 uid_t euid;
61 int nconnections;
62 };
63 STAILQ_HEAD(gotd_client_uids, gotd_uid_connection_counter);
64 static struct gotd_client_uids gotd_client_uids[GOTD_CLIENT_TABLE_SIZE];
65 static SIPHASH_KEY uid_hash_key;
67 static struct {
68 pid_t pid;
69 const char *title;
70 int fd;
71 struct gotd_imsgev iev;
72 struct gotd_imsgev pause;
73 struct gotd_uid_connection_limit *connection_limits;
74 size_t nconnection_limits;
75 } gotd_listen;
77 static int inflight;
79 static void listen_shutdown(void);
/*
 * Signal callback registered through libevent's signal_set().  libevent
 * runs it from the event loop rather than from an asynchronous handler,
 * so the usual async-signal-safety restrictions do not apply here.
 * SIGTERM and SIGINT shut the process down; SIGHUP and SIGUSR1 are
 * ignored; any other signal is fatal.
 */
static void
listen_sighdlr(int sig, short event, void *arg)
{
	if (sig == SIGTERM || sig == SIGINT) {
		listen_shutdown();
		/* NOTREACHED */
		return;
	}

	if (sig == SIGHUP || sig == SIGUSR1)
		return;

	fatalx("unexpected signal");
}
104 static uint64_t
105 client_hash(uint32_t client_id)
107 return SipHash24(&clients_hash_key, &client_id, sizeof(client_id));
110 static void
111 add_client(struct gotd_listen_client *client)
113 uint64_t slot = client_hash(client->id) % nitems(gotd_listen_clients);
114 STAILQ_INSERT_HEAD(&gotd_listen_clients[slot], client, entry);
115 listen_client_cnt++;
118 static struct gotd_listen_client *
119 find_client(uint32_t client_id)
121 uint64_t slot;
122 struct gotd_listen_client *c;
124 slot = client_hash(client_id) % nitems(gotd_listen_clients);
125 STAILQ_FOREACH(c, &gotd_listen_clients[slot], entry) {
126 if (c->id == client_id)
127 return c;
130 return NULL;
/* Draw random IDs until one is nonzero and not used by any client. */
static uint32_t
get_client_id(void)
{
	uint32_t candidate;

	for (;;) {
		candidate = arc4random();
		if (candidate != 0 && find_client(candidate) == NULL)
			return candidate;
	}
}
147 static uint64_t
148 uid_hash(uid_t euid)
150 return SipHash24(&uid_hash_key, &euid, sizeof(euid));
153 static void
154 add_uid_connection_counter(struct gotd_uid_connection_counter *counter)
156 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
157 STAILQ_INSERT_HEAD(&gotd_client_uids[slot], counter, entry);
160 static void
161 remove_uid_connection_counter(struct gotd_uid_connection_counter *counter)
163 uint64_t slot = uid_hash(counter->euid) % nitems(gotd_client_uids);
164 STAILQ_REMOVE(&gotd_client_uids[slot], counter,
165 gotd_uid_connection_counter, entry);
168 static struct gotd_uid_connection_counter *
169 find_uid_connection_counter(uid_t euid)
171 uint64_t slot;
172 struct gotd_uid_connection_counter *c;
174 slot = uid_hash(euid) % nitems(gotd_client_uids);
175 STAILQ_FOREACH(c, &gotd_client_uids[slot], entry) {
176 if (c->euid == euid)
177 return c;
180 return NULL;
183 static const struct got_error *
184 disconnect(struct gotd_listen_client *client)
186 struct gotd_uid_connection_counter *counter;
187 uint64_t slot;
188 int client_fd;
190 log_debug("client on fd %d disconnecting", client->fd);
192 slot = client_hash(client->id) % nitems(gotd_listen_clients);
193 STAILQ_REMOVE(&gotd_listen_clients[slot], client,
194 gotd_listen_client, entry);
196 counter = find_uid_connection_counter(client->euid);
197 if (counter) {
198 if (counter->nconnections > 0)
199 counter->nconnections--;
200 if (counter->nconnections == 0) {
201 remove_uid_connection_counter(counter);
202 free(counter);
206 client_fd = client->fd;
207 free(client);
208 inflight--;
209 listen_client_cnt--;
210 if (close(client_fd) == -1)
211 return got_error_from_errno("close");
213 return NULL;
216 static int
217 accept_reserve(int fd, struct sockaddr *addr, socklen_t *addrlen,
218 int reserve, volatile int *counter)
220 int ret;
222 if (getdtablecount() + reserve +
223 ((*counter + 1) * GOTD_FD_NEEDED) >= getdtablesize()) {
224 log_debug("inflight fds exceeded");
225 errno = EMFILE;
226 return -1;
229 if ((ret = accept4(fd, addr, addrlen,
230 SOCK_NONBLOCK | SOCK_CLOEXEC)) > -1) {
231 (*counter)++;
234 return ret;
237 static void
238 gotd_accept_paused(int fd, short event, void *arg)
240 event_add(&gotd_listen.iev.ev, NULL);
243 static void
244 gotd_accept(int fd, short event, void *arg)
246 struct gotd_imsgev *iev = arg;
247 struct sockaddr_storage ss;
248 struct timeval backoff;
249 socklen_t len;
250 int s = -1;
251 struct gotd_listen_client *client = NULL;
252 struct gotd_uid_connection_counter *counter = NULL;
253 struct gotd_imsg_connect iconn;
254 uid_t euid;
255 gid_t egid;
257 backoff.tv_sec = 1;
258 backoff.tv_usec = 0;
260 if (event_add(&gotd_listen.iev.ev, NULL) == -1) {
261 log_warn("event_add");
262 return;
264 if (event & EV_TIMEOUT)
265 return;
267 len = sizeof(ss);
269 /* Other backoff conditions apart from EMFILE/ENFILE? */
270 s = accept_reserve(fd, (struct sockaddr *)&ss, &len, GOTD_FD_RESERVE,
271 &inflight);
272 if (s == -1) {
273 switch (errno) {
274 case EINTR:
275 case EWOULDBLOCK:
276 case ECONNABORTED:
277 return;
278 case EMFILE:
279 case ENFILE:
280 event_del(&gotd_listen.iev.ev);
281 evtimer_add(&gotd_listen.pause.ev, &backoff);
282 return;
283 default:
284 log_warn("accept");
285 return;
289 if (listen_client_cnt >= GOTD_MAXCLIENTS)
290 goto err;
292 if (getpeereid(s, &euid, &egid) == -1) {
293 log_warn("getpeerid");
294 goto err;
297 counter = find_uid_connection_counter(euid);
298 if (counter == NULL) {
299 counter = calloc(1, sizeof(*counter));
300 if (counter == NULL) {
301 log_warn("%s: calloc", __func__);
302 goto err;
304 counter->euid = euid;
305 counter->nconnections = 1;
306 add_uid_connection_counter(counter);
307 } else {
308 int max_connections = GOTD_MAX_CONN_PER_UID;
309 struct gotd_uid_connection_limit *limit;
311 limit = gotd_find_uid_connection_limit(
312 gotd_listen.connection_limits,
313 gotd_listen.nconnection_limits, euid);
314 if (limit)
315 max_connections = limit->max_connections;
317 if (counter->nconnections >= max_connections) {
318 log_warnx("maximum connections exceeded for uid %d",
319 euid);
320 goto err;
322 counter->nconnections++;
325 client = calloc(1, sizeof(*client));
326 if (client == NULL) {
327 log_warn("%s: calloc", __func__);
328 goto err;
330 client->id = get_client_id();
331 client->fd = s;
332 client->euid = euid;
333 s = -1;
334 add_client(client);
335 log_debug("%s: new client connected on fd %d uid %d gid %d", __func__,
336 client->fd, euid, egid);
338 memset(&iconn, 0, sizeof(iconn));
339 iconn.client_id = client->id;
340 iconn.euid = euid;
341 iconn.egid = egid;
342 s = dup(client->fd);
343 if (s == -1) {
344 log_warn("%s: dup", __func__);
345 goto err;
347 if (gotd_imsg_compose_event(iev, GOTD_IMSG_CONNECT, PROC_LISTEN, s,
348 &iconn, sizeof(iconn)) == -1) {
349 log_warn("imsg compose CONNECT");
350 goto err;
353 return;
354 err:
355 inflight--;
356 if (client)
357 disconnect(client);
358 if (s != -1)
359 close(s);
362 static const struct got_error *
363 recv_disconnect(struct imsg *imsg)
365 struct gotd_imsg_disconnect idisconnect;
366 size_t datalen;
367 struct gotd_listen_client *client = NULL;
369 datalen = imsg->hdr.len - IMSG_HEADER_SIZE;
370 if (datalen != sizeof(idisconnect))
371 return got_error(GOT_ERR_PRIVSEP_LEN);
372 memcpy(&idisconnect, imsg->data, sizeof(idisconnect));
374 log_debug("client disconnecting");
376 client = find_client(idisconnect.client_id);
377 if (client == NULL)
378 return got_error(GOT_ERR_CLIENT_ID);
380 return disconnect(client);
383 static void
384 listen_dispatch(int fd, short event, void *arg)
386 const struct got_error *err = NULL;
387 struct gotd_imsgev *iev = arg;
388 struct imsgbuf *ibuf = &iev->ibuf;
389 struct imsg imsg;
390 ssize_t n;
391 int shut = 0;
393 if (event & EV_READ) {
394 if ((n = imsg_read(ibuf)) == -1 && errno != EAGAIN)
395 fatal("imsg_read error");
396 if (n == 0) /* Connection closed. */
397 shut = 1;
400 if (event & EV_WRITE) {
401 n = msgbuf_write(&ibuf->w);
402 if (n == -1 && errno != EAGAIN)
403 fatal("msgbuf_write");
404 if (n == 0) /* Connection closed. */
405 shut = 1;
408 for (;;) {
409 if ((n = imsg_get(ibuf, &imsg)) == -1)
410 fatal("%s: imsg_get", __func__);
411 if (n == 0) /* No more messages. */
412 break;
414 switch (imsg.hdr.type) {
415 case GOTD_IMSG_DISCONNECT:
416 err = recv_disconnect(&imsg);
417 if (err)
418 log_warnx("disconnect: %s", err->msg);
419 break;
420 default:
421 log_debug("unexpected imsg %d", imsg.hdr.type);
422 break;
425 imsg_free(&imsg);
428 if (!shut) {
429 gotd_imsg_event_add(iev);
430 } else {
431 /* This pipe is dead. Remove its event handler */
432 event_del(&iev->ev);
433 event_loopexit(NULL);
437 void
438 listen_main(const char *title, int gotd_socket,
439 struct gotd_uid_connection_limit *connection_limits,
440 size_t nconnection_limits)
442 struct gotd_imsgev iev;
443 struct event evsigint, evsigterm, evsighup, evsigusr1;
445 arc4random_buf(&clients_hash_key, sizeof(clients_hash_key));
446 arc4random_buf(&uid_hash_key, sizeof(uid_hash_key));
448 gotd_listen.title = title;
449 gotd_listen.pid = getpid();
450 gotd_listen.fd = gotd_socket;
451 gotd_listen.connection_limits = connection_limits;
452 gotd_listen.nconnection_limits = nconnection_limits;
454 signal_set(&evsigint, SIGINT, listen_sighdlr, NULL);
455 signal_set(&evsigterm, SIGTERM, listen_sighdlr, NULL);
456 signal_set(&evsighup, SIGHUP, listen_sighdlr, NULL);
457 signal_set(&evsigusr1, SIGUSR1, listen_sighdlr, NULL);
458 signal(SIGPIPE, SIG_IGN);
460 signal_add(&evsigint, NULL);
461 signal_add(&evsigterm, NULL);
462 signal_add(&evsighup, NULL);
463 signal_add(&evsigusr1, NULL);
465 imsg_init(&iev.ibuf, GOTD_FILENO_MSG_PIPE);
466 iev.handler = listen_dispatch;
467 iev.events = EV_READ;
468 iev.handler_arg = NULL;
469 event_set(&iev.ev, iev.ibuf.fd, EV_READ, listen_dispatch, &iev);
470 if (event_add(&iev.ev, NULL) == -1)
471 fatalx("event add");
473 event_set(&gotd_listen.iev.ev, gotd_listen.fd, EV_READ | EV_PERSIST,
474 gotd_accept, &iev);
475 if (event_add(&gotd_listen.iev.ev, NULL))
476 fatalx("event add");
477 evtimer_set(&gotd_listen.pause.ev, gotd_accept_paused, NULL);
479 event_dispatch();
481 listen_shutdown();
484 static void
485 listen_shutdown(void)
487 log_debug("shutting down");
489 free(gotd_listen.connection_limits);
490 if (gotd_listen.fd != -1)
491 close(gotd_listen.fd);
493 exit(0);