Formatting changes.
[radsecproxy.git] / tls.c
1 /* Copyright (c) 2006-2010, UNINETT AS
2  * Copyright (c) 2010-2012, NORDUnet A/S */
3 /* See LICENSE for licensing information. */
4
5 #include <signal.h>
6 #include <sys/socket.h>
7 #include <netinet/in.h>
8 #include <netdb.h>
9 #include <string.h>
10 #include <unistd.h>
11 #include <limits.h>
12 #ifdef SYS_SOLARIS9
13 #include <fcntl.h>
14 #endif
15 #include <sys/time.h>
16 #include <sys/types.h>
17 #include <sys/select.h>
18 #include <ctype.h>
19 #include <sys/wait.h>
20 #include <arpa/inet.h>
21 #include <regex.h>
22 #include <pthread.h>
23 #include <openssl/ssl.h>
24 #include <openssl/err.h>
25 #include "radsecproxy.h"
26 #include "hostport.h"
27
28 #ifdef RADPROT_TLS
29 #include "debug.h"
30 #include "util.h"
31
32 static void setprotoopts(struct commonprotoopts *opts);
33 static char **getlistenerargs();
34 void *tlslistener(void *arg);
35 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
36 void *tlsclientrd(void *arg);
37 int clientradputtls(struct server *server, unsigned char *rad);
38 void tlssetsrcres();
39
40 static const struct protodefs protodefs = {
41     "tls",
42     "radsec", /* secretdefault */
43     SOCK_STREAM, /* socktype */
44     "2083", /* portdefault */
45     0, /* retrycountdefault */
46     0, /* retrycountmax */
47     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
48     60, /* retryintervalmax */
49     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
50     setprotoopts, /* setprotoopts */
51     getlistenerargs, /* getlistenerargs */
52     tlslistener, /* listener */
53     tlsconnect, /* connecter */
54     tlsclientrd, /* clientconnreader */
55     clientradputtls, /* clientradput */
56     NULL, /* addclient */
57     NULL, /* addserverextra */
58     tlssetsrcres, /* setsrcres */
59     NULL /* initextra */
60 };
61
62 static struct addrinfo *srcres = NULL;
63 static uint8_t handle;
64 static struct commonprotoopts *protoopts = NULL;
65
66 const struct protodefs *tlsinit(uint8_t h) {
67     handle = h;
68     return &protodefs;
69 }
70
71 static void setprotoopts(struct commonprotoopts *opts) {
72     protoopts = opts;
73 }
74
75 static char **getlistenerargs() {
76     return protoopts ? protoopts->listenargs : NULL;
77 }
78
79 void tlssetsrcres() {
80     if (!srcres)
81         srcres =
82             resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL,
83                                    AF_UNSPEC, NULL, protodefs.socktype);
84 }
85
86 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
87     struct timeval now;
88     time_t elapsed;
89     X509 *cert;
90     SSL_CTX *ctx = NULL;
91     unsigned long error;
92
93     debug(DBG_DBG, "tlsconnect: called from %s", text);
94     pthread_mutex_lock(&server->lock);
95     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
96         /* already reconnected, nothing to do */
97         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
98         pthread_mutex_unlock(&server->lock);
99         return 1;
100     }
101
102     for (;;) {
103         gettimeofday(&now, NULL);
104         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
105         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
106             debug(DBG_DBG, "tlsconnect: timeout");
107             if (server->sock >= 0)
108                 close(server->sock);
109             SSL_free(server->ssl);
110             server->ssl = NULL;
111             pthread_mutex_unlock(&server->lock);
112             return 0;
113         }
114         if (server->connectionok) {
115             server->connectionok = 0;
116             sleep(2);
117         } else if (elapsed < 1)
118             sleep(2);
119         else if (elapsed < 60) {
120             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
121             sleep(elapsed);
122         } else if (elapsed < 100000) {
123             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
124             sleep(60);
125         } else
126             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
127
128         if (server->sock >= 0)
129             close(server->sock);
130         if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) < 0)
131             continue;
132
133         SSL_free(server->ssl);
134         server->ssl = NULL;
135         ctx = tlsgetctx(handle, server->conf->tlsconf);
136         if (!ctx)
137             continue;
138         server->ssl = SSL_new(ctx);
139         if (!server->ssl)
140             continue;
141
142         SSL_set_fd(server->ssl, server->sock);
143         if (SSL_connect(server->ssl) <= 0) {
144             while ((error = ERR_get_error()))
145                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
146             continue;
147         }
148         cert = verifytlscert(server->ssl);
149         if (!cert)
150             continue;
151         if (verifyconfcert(cert, server->conf)) {
152             X509_free(cert);
153             break;
154         }
155         X509_free(cert);
156     }
157     debug(DBG_WARN, "tlsconnect: TLS connection to %s up", server->conf->name);
158     server->connectionok = 1;
159     gettimeofday(&server->lastconnecttry, NULL);
160     pthread_mutex_unlock(&server->lock);
161     return 1;
162 }
163
164 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
165 /* returns 0 on timeout, -1 on error and num if ok */
166 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
167     int s, ndesc, cnt, len;
168     fd_set readfds, writefds;
169     struct timeval timer;
170
171     s = SSL_get_fd(ssl);
172     if (s < 0)
173         return -1;
174     /* make socket non-blocking? */
175     for (len = 0; len < num; len += cnt) {
176         FD_ZERO(&readfds);
177         FD_SET(s, &readfds);
178         writefds = readfds;
179         if (timeout) {
180             timer.tv_sec = timeout;
181             timer.tv_usec = 0;
182         }
183         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
184         if (ndesc < 1)
185             return ndesc;
186
187         cnt = SSL_read(ssl, buf + len, num - len);
188         if (cnt <= 0)
189             switch (SSL_get_error(ssl, cnt)) {
190             case SSL_ERROR_WANT_READ:
191             case SSL_ERROR_WANT_WRITE:
192                 cnt = 0;
193                 continue;
194             case SSL_ERROR_ZERO_RETURN:
195                 /* remote end sent close_notify, send one back */
196                 SSL_shutdown(ssl);
197                 return -1;
198             default:
199                 return -1;
200             }
201     }
202     return num;
203 }
204
205 /* timeout in seconds, 0 means no timeout (blocking) */
206 unsigned char *radtlsget(SSL *ssl, int timeout) {
207     int cnt, len;
208     unsigned char buf[4], *rad;
209
210     for (;;) {
211         cnt = sslreadtimeout(ssl, buf, 4, timeout);
212         if (cnt < 1) {
213             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
214             return NULL;
215         }
216
217         len = RADLEN(buf);
218         rad = malloc(len);
219         if (!rad) {
220             debug(DBG_ERR, "radtlsget: malloc failed");
221             continue;
222         }
223         memcpy(rad, buf, 4);
224
225         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
226         if (cnt < 1) {
227             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
228             free(rad);
229             return NULL;
230         }
231
232         if (len >= 20)
233             break;
234
235         free(rad);
236         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
237     }
238
239     debug(DBG_DBG, "radtlsget: got %d bytes", len);
240     return rad;
241 }
242
243 int clientradputtls(struct server *server, unsigned char *rad) {
244     int cnt;
245     size_t len;
246     unsigned long error;
247     struct clsrvconf *conf = server->conf;
248
249     if (!server->connectionok)
250         return 0;
251     len = RADLEN(rad);
252     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
253         while ((error = ERR_get_error()))
254             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
255         return 0;
256     }
257
258     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->name);
259     return 1;
260 }
261
262 void *tlsclientrd(void *arg) {
263     struct server *server = (struct server *)arg;
264     unsigned char *buf;
265     struct timeval now, lastconnecttry;
266
267     for (;;) {
268         /* yes, lastconnecttry is really necessary */
269         lastconnecttry = server->lastconnecttry;
270         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
271         if (!buf) {
272             if (server->dynamiclookuparg)
273                 break;
274             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
275             continue;
276         }
277
278         replyh(server, buf);
279
280         if (server->dynamiclookuparg) {
281             gettimeofday(&now, NULL);
282             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
283                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
284                 break;
285             }
286         }
287     }
288     ERR_remove_state(0);
289     server->clientrdgone = 1;
290     return NULL;
291 }
292
293 void *tlsserverwr(void *arg) {
294     int cnt;
295     unsigned long error;
296     struct client *client = (struct client *)arg;
297     struct gqueue *replyq;
298     struct request *reply;
299
300     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
301     replyq = client->replyq;
302     for (;;) {
303         pthread_mutex_lock(&replyq->mutex);
304         while (!list_first(replyq->entries)) {
305             if (client->ssl) {
306                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
307                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
308                 debug(DBG_DBG, "tlsserverwr: got signal");
309             }
310             if (!client->ssl) {
311                 /* ssl might have changed while waiting */
312                 pthread_mutex_unlock(&replyq->mutex);
313                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
314                 ERR_remove_state(0);
315                 pthread_exit(NULL);
316             }
317         }
318         reply = (struct request *)list_shift(replyq->entries);
319         pthread_mutex_unlock(&replyq->mutex);
320         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
321         if (cnt > 0)
322             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
323                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
324         else
325             while ((error = ERR_get_error()))
326                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
327         freerq(reply);
328     }
329 }
330
331 void tlsserverrd(struct client *client) {
332     struct request *rq;
333     uint8_t *buf;
334     pthread_t tlsserverwrth;
335
336     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
337
338     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
339         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
340         return;
341     }
342
343     for (;;) {
344         buf = radtlsget(client->ssl, 0);
345         if (!buf) {
346             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
347             break;
348         }
349         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
350         rq = newrequest();
351         if (!rq) {
352             free(buf);
353             continue;
354         }
355         rq->buf = buf;
356         rq->from = client;
357         if (!radsrv(rq)) {
358             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
359             break;
360         }
361     }
362
363     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
364     client->ssl = NULL;
365     pthread_mutex_lock(&client->replyq->mutex);
366     pthread_cond_signal(&client->replyq->cond);
367     pthread_mutex_unlock(&client->replyq->mutex);
368     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
369     pthread_join(tlsserverwrth, NULL);
370     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
371 }
372
373 void *tlsservernew(void *arg) {
374     int s;
375     struct sockaddr_storage from;
376     socklen_t fromlen = sizeof(from);
377     struct clsrvconf *conf;
378     struct list_node *cur = NULL;
379     SSL *ssl = NULL;
380     X509 *cert = NULL;
381     SSL_CTX *ctx = NULL;
382     unsigned long error;
383     struct client *client;
384     struct tls *accepted_tls = NULL;
385
386     s = *(int *)arg;
387     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
388         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
389         goto exit;
390     }
391     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
392
393     conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
394     if (conf) {
395         ctx = tlsgetctx(handle, conf->tlsconf);
396         if (!ctx)
397             goto exit;
398         ssl = SSL_new(ctx);
399         if (!ssl)
400             goto exit;
401         SSL_set_fd(ssl, s);
402
403         if (SSL_accept(ssl) <= 0) {
404             while ((error = ERR_get_error()))
405                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
406             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
407             goto exit;
408         }
409         cert = verifytlscert(ssl);
410         if (!cert)
411             goto exit;
412         accepted_tls = conf->tlsconf;
413     }
414
415     while (conf) {
416         if (accepted_tls == conf->tlsconf && verifyconfcert(cert, conf)) {
417             X509_free(cert);
418             client = addclient(conf, 1);
419             if (client) {
420                 client->ssl = ssl;
421                 client->addr = addr_copy((struct sockaddr *)&from);
422                 tlsserverrd(client);
423                 removeclient(client);
424             } else
425                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
426             goto exit;
427         }
428         conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
429     }
430     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
431     if (cert)
432         X509_free(cert);
433
434 exit:
435     if (ssl) {
436         SSL_shutdown(ssl);
437         SSL_free(ssl);
438     }
439     ERR_remove_state(0);
440     shutdown(s, SHUT_RDWR);
441     close(s);
442     pthread_exit(NULL);
443 }
444
445 void *tlslistener(void *arg) {
446     pthread_t tlsserverth;
447     int s, *sp = (int *)arg;
448     struct sockaddr_storage from;
449     socklen_t fromlen = sizeof(from);
450
451     listen(*sp, 0);
452
453     for (;;) {
454         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
455         if (s < 0) {
456             debug(DBG_WARN, "accept failed");
457             continue;
458         }
459         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
460             debug(DBG_ERR, "tlslistener: pthread_create failed");
461             shutdown(s, SHUT_RDWR);
462             close(s);
463             continue;
464         }
465         pthread_detach(tlsserverth);
466     }
467     free(sp);
468     return NULL;
469 }
470 #else
471 const struct protodefs *tlsinit(uint8_t h) {
472     return NULL;
473 }
474 #endif
475
476 /* Local Variables: */
477 /* c-file-style: "stroustrup" */
478 /* End: */