4953e98f769bedfd1d1d12e4a5b175988af58a6f
[radsecproxy.git] / tls.c
1 /* Copyright (c) 2006-2010, UNINETT AS.
2  * Copyright (c) 2010, UNINETT AS, NORDUnet A/S.
3  * Copyright (c) 2010-2012, NORDUnet A/S. */
4 /* See LICENSE for licensing information. */
5
6 #include <signal.h>
7 #include <sys/socket.h>
8 #include <netinet/in.h>
9 #include <netdb.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <limits.h>
13 #ifdef SYS_SOLARIS9
14 #include <fcntl.h>
15 #endif
16 #include <sys/time.h>
17 #include <sys/types.h>
18 #include <sys/select.h>
19 #include <ctype.h>
20 #include <sys/wait.h>
21 #include <arpa/inet.h>
22 #include <regex.h>
23 #include <pthread.h>
24 #include <openssl/ssl.h>
25 #include <openssl/err.h>
26 #include "radsecproxy.h"
27 #include "hostport.h"
28
29 #ifdef RADPROT_TLS
30 #include "debug.h"
31 #include "util.h"
32
33 static void setprotoopts(struct commonprotoopts *opts);
34 static char **getlistenerargs();
35 void *tlslistener(void *arg);
36 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
37 void *tlsclientrd(void *arg);
38 int clientradputtls(struct server *server, unsigned char *rad);
39 void tlssetsrcres();
40
41 static const struct protodefs protodefs = {
42     "tls",
43     "radsec", /* secretdefault */
44     SOCK_STREAM, /* socktype */
45     "2083", /* portdefault */
46     0, /* retrycountdefault */
47     0, /* retrycountmax */
48     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
49     60, /* retryintervalmax */
50     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
51     setprotoopts, /* setprotoopts */
52     getlistenerargs, /* getlistenerargs */
53     tlslistener, /* listener */
54     tlsconnect, /* connecter */
55     tlsclientrd, /* clientconnreader */
56     clientradputtls, /* clientradput */
57     NULL, /* addclient */
58     NULL, /* addserverextra */
59     tlssetsrcres, /* setsrcres */
60     NULL /* initextra */
61 };
62
63 static struct addrinfo *srcres = NULL;
64 static uint8_t handle;
65 static struct commonprotoopts *protoopts = NULL;
66
67 const struct protodefs *tlsinit(uint8_t h) {
68     handle = h;
69     return &protodefs;
70 }
71
72 static void setprotoopts(struct commonprotoopts *opts) {
73     protoopts = opts;
74 }
75
76 static char **getlistenerargs() {
77     return protoopts ? protoopts->listenargs : NULL;
78 }
79
80 void tlssetsrcres() {
81     if (!srcres)
82         srcres =
83             resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL,
84                                    AF_UNSPEC, NULL, protodefs.socktype);
85 }
86
87 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
88     struct timeval now;
89     time_t elapsed;
90     X509 *cert;
91     SSL_CTX *ctx = NULL;
92     unsigned long error;
93
94     debug(DBG_DBG, "tlsconnect: called from %s", text);
95     pthread_mutex_lock(&server->lock);
96     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
97         /* already reconnected, nothing to do */
98         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
99         pthread_mutex_unlock(&server->lock);
100         return 1;
101     }
102
103     for (;;) {
104         gettimeofday(&now, NULL);
105         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
106         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
107             debug(DBG_DBG, "tlsconnect: timeout");
108             if (server->sock >= 0)
109                 close(server->sock);
110             SSL_free(server->ssl);
111             server->ssl = NULL;
112             pthread_mutex_unlock(&server->lock);
113             return 0;
114         }
115         if (server->connectionok) {
116             server->connectionok = 0;
117             sleep(2);
118         } else if (elapsed < 1)
119             sleep(2);
120         else if (elapsed < 60) {
121             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
122             sleep(elapsed);
123         } else if (elapsed < 100000) {
124             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
125             sleep(60);
126         } else
127             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
128
129         if (server->sock >= 0)
130             close(server->sock);
131         if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) < 0)
132             continue;
133
134         SSL_free(server->ssl);
135         server->ssl = NULL;
136         ctx = tlsgetctx(handle, server->conf->tlsconf);
137         if (!ctx)
138             continue;
139         server->ssl = SSL_new(ctx);
140         if (!server->ssl)
141             continue;
142
143         SSL_set_fd(server->ssl, server->sock);
144         if (SSL_connect(server->ssl) <= 0) {
145             while ((error = ERR_get_error()))
146                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
147             continue;
148         }
149         cert = verifytlscert(server->ssl);
150         if (!cert)
151             continue;
152         if (verifyconfcert(cert, server->conf)) {
153             X509_free(cert);
154             break;
155         }
156         X509_free(cert);
157     }
158     debug(DBG_WARN, "tlsconnect: TLS connection to %s up", server->conf->name);
159     server->connectionok = 1;
160     gettimeofday(&server->lastconnecttry, NULL);
161     pthread_mutex_unlock(&server->lock);
162     return 1;
163 }
164
165 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
166 /* returns 0 on timeout, -1 on error and num if ok */
167 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
168     int s, ndesc, cnt, len;
169     fd_set readfds, writefds;
170     struct timeval timer;
171
172     s = SSL_get_fd(ssl);
173     if (s < 0)
174         return -1;
175     /* make socket non-blocking? */
176     for (len = 0; len < num; len += cnt) {
177         FD_ZERO(&readfds);
178         FD_SET(s, &readfds);
179         writefds = readfds;
180         if (timeout) {
181             timer.tv_sec = timeout;
182             timer.tv_usec = 0;
183         }
184         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
185         if (ndesc < 1)
186             return ndesc;
187
188         cnt = SSL_read(ssl, buf + len, num - len);
189         if (cnt <= 0)
190             switch (SSL_get_error(ssl, cnt)) {
191             case SSL_ERROR_WANT_READ:
192             case SSL_ERROR_WANT_WRITE:
193                 cnt = 0;
194                 continue;
195             case SSL_ERROR_ZERO_RETURN:
196                 /* remote end sent close_notify, send one back */
197                 SSL_shutdown(ssl);
198                 return -1;
199             default:
200                 return -1;
201             }
202     }
203     return num;
204 }
205
206 /* timeout in seconds, 0 means no timeout (blocking) */
207 unsigned char *radtlsget(SSL *ssl, int timeout) {
208     int cnt, len;
209     unsigned char buf[4], *rad;
210
211     for (;;) {
212         cnt = sslreadtimeout(ssl, buf, 4, timeout);
213         if (cnt < 1) {
214             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
215             return NULL;
216         }
217
218         len = RADLEN(buf);
219         rad = malloc(len);
220         if (!rad) {
221             debug(DBG_ERR, "radtlsget: malloc failed");
222             continue;
223         }
224         memcpy(rad, buf, 4);
225
226         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
227         if (cnt < 1) {
228             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
229             free(rad);
230             return NULL;
231         }
232
233         if (len >= 20)
234             break;
235
236         free(rad);
237         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
238     }
239
240     debug(DBG_DBG, "radtlsget: got %d bytes", len);
241     return rad;
242 }
243
244 int clientradputtls(struct server *server, unsigned char *rad) {
245     int cnt;
246     size_t len;
247     unsigned long error;
248     struct clsrvconf *conf = server->conf;
249
250     if (!server->connectionok)
251         return 0;
252     len = RADLEN(rad);
253     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
254         while ((error = ERR_get_error()))
255             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
256         return 0;
257     }
258
259     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->name);
260     return 1;
261 }
262
263 void *tlsclientrd(void *arg) {
264     struct server *server = (struct server *)arg;
265     unsigned char *buf;
266     struct timeval now, lastconnecttry;
267
268     for (;;) {
269         /* yes, lastconnecttry is really necessary */
270         lastconnecttry = server->lastconnecttry;
271         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
272         if (!buf) {
273             if (server->dynamiclookuparg)
274                 break;
275             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
276             continue;
277         }
278
279         replyh(server, buf);
280
281         if (server->dynamiclookuparg) {
282             gettimeofday(&now, NULL);
283             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
284                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
285                 break;
286             }
287         }
288     }
289     ERR_remove_state(0);
290     server->clientrdgone = 1;
291     return NULL;
292 }
293
294 void *tlsserverwr(void *arg) {
295     int cnt;
296     unsigned long error;
297     struct client *client = (struct client *)arg;
298     struct gqueue *replyq;
299     struct request *reply;
300
301     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
302     replyq = client->replyq;
303     for (;;) {
304         pthread_mutex_lock(&replyq->mutex);
305         while (!list_first(replyq->entries)) {
306             if (client->ssl) {
307                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
308                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
309                 debug(DBG_DBG, "tlsserverwr: got signal");
310             }
311             if (!client->ssl) {
312                 /* ssl might have changed while waiting */
313                 pthread_mutex_unlock(&replyq->mutex);
314                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
315                 ERR_remove_state(0);
316                 pthread_exit(NULL);
317             }
318         }
319         reply = (struct request *)list_shift(replyq->entries);
320         pthread_mutex_unlock(&replyq->mutex);
321         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
322         if (cnt > 0)
323             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
324                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
325         else
326             while ((error = ERR_get_error()))
327                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
328         freerq(reply);
329     }
330 }
331
332 void tlsserverrd(struct client *client) {
333     struct request *rq;
334     uint8_t *buf;
335     pthread_t tlsserverwrth;
336
337     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
338
339     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
340         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
341         return;
342     }
343
344     for (;;) {
345         buf = radtlsget(client->ssl, 0);
346         if (!buf) {
347             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
348             break;
349         }
350         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
351         rq = newrequest();
352         if (!rq) {
353             free(buf);
354             continue;
355         }
356         rq->buf = buf;
357         rq->from = client;
358         if (!radsrv(rq)) {
359             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
360             break;
361         }
362     }
363
364     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
365     client->ssl = NULL;
366     pthread_mutex_lock(&client->replyq->mutex);
367     pthread_cond_signal(&client->replyq->cond);
368     pthread_mutex_unlock(&client->replyq->mutex);
369     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
370     pthread_join(tlsserverwrth, NULL);
371     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
372 }
373
374 void *tlsservernew(void *arg) {
375     int s;
376     struct sockaddr_storage from;
377     socklen_t fromlen = sizeof(from);
378     struct clsrvconf *conf;
379     struct list_node *cur = NULL;
380     SSL *ssl = NULL;
381     X509 *cert = NULL;
382     SSL_CTX *ctx = NULL;
383     unsigned long error;
384     struct client *client;
385     struct tls *accepted_tls = NULL;
386
387     s = *(int *)arg;
388     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
389         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
390         goto exit;
391     }
392     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
393
394     conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
395     if (conf) {
396         ctx = tlsgetctx(handle, conf->tlsconf);
397         if (!ctx)
398             goto exit;
399         ssl = SSL_new(ctx);
400         if (!ssl)
401             goto exit;
402         SSL_set_fd(ssl, s);
403
404         if (SSL_accept(ssl) <= 0) {
405             while ((error = ERR_get_error()))
406                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
407             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
408             goto exit;
409         }
410         cert = verifytlscert(ssl);
411         if (!cert)
412             goto exit;
413         accepted_tls = conf->tlsconf;
414     }
415
416     while (conf) {
417         if (accepted_tls == conf->tlsconf && verifyconfcert(cert, conf)) {
418             X509_free(cert);
419             client = addclient(conf, 1);
420             if (client) {
421                 client->ssl = ssl;
422                 client->addr = addr_copy((struct sockaddr *)&from);
423                 tlsserverrd(client);
424                 removeclient(client);
425             } else
426                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
427             goto exit;
428         }
429         conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
430     }
431     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
432     if (cert)
433         X509_free(cert);
434
435 exit:
436     if (ssl) {
437         SSL_shutdown(ssl);
438         SSL_free(ssl);
439     }
440     ERR_remove_state(0);
441     shutdown(s, SHUT_RDWR);
442     close(s);
443     pthread_exit(NULL);
444 }
445
446 void *tlslistener(void *arg) {
447     pthread_t tlsserverth;
448     int s, *sp = (int *)arg;
449     struct sockaddr_storage from;
450     socklen_t fromlen = sizeof(from);
451
452     listen(*sp, 0);
453
454     for (;;) {
455         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
456         if (s < 0) {
457             debug(DBG_WARN, "accept failed");
458             continue;
459         }
460         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
461             debug(DBG_ERR, "tlslistener: pthread_create failed");
462             shutdown(s, SHUT_RDWR);
463             close(s);
464             continue;
465         }
466         pthread_detach(tlsserverth);
467     }
468     free(sp);
469     return NULL;
470 }
471 #else
472 const struct protodefs *tlsinit(uint8_t h) {
473     return NULL;
474 }
475 #endif
476
477 /* Local Variables: */
478 /* c-file-style: "stroustrup" */
479 /* End: */