allowing build with only specific transports
[libradsec.git] / tls.c
1 /*
2  * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #ifdef RADPROT_TLS
10 #include <signal.h>
11 #include <sys/socket.h>
12 #include <netinet/in.h>
13 #include <netdb.h>
14 #include <string.h>
15 #include <unistd.h>
16 #include <limits.h>
17 #ifdef SYS_SOLARIS9
18 #include <fcntl.h>
19 #endif
20 #include <sys/time.h>
21 #include <sys/types.h>
22 #include <sys/select.h>
23 #include <ctype.h>
24 #include <sys/wait.h>
25 #include <arpa/inet.h>
26 #include <regex.h>
27 #include <pthread.h>
28 #include <openssl/ssl.h>
29 #include <openssl/err.h>
30 #include "debug.h"
31 #include "list.h"
32 #include "util.h"
33 #include "radsecproxy.h"
34
35 static void setprotoopts(struct commonprotoopts *opts);
36 static char **getlistenerargs();
37 void *tlslistener(void *arg);
38 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
39 void *tlsclientrd(void *arg);
40 int clientradputtls(struct server *server, unsigned char *rad);
41 void tlssetsrcres();
42
43 static const struct protodefs protodefs = {
44     "tls",
45     "mysecret", /* secretdefault */
46     SOCK_STREAM, /* socktype */
47     "2083", /* portdefault */
48     0, /* retrycountdefault */
49     0, /* retrycountmax */
50     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
51     60, /* retryintervalmax */
52     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
53     setprotoopts, /* setprotoopts */
54     getlistenerargs, /* getlistenerargs */
55     tlslistener, /* listener */
56     tlsconnect, /* connecter */
57     tlsclientrd, /* clientconnreader */
58     clientradputtls, /* clientradput */
59     NULL, /* addclient */
60     NULL, /* addserverextra */
61     tlssetsrcres, /* setsrcres */
62     NULL /* initextra */
63 };
64
65 static struct addrinfo *srcres = NULL;
66 static uint8_t handle;
67 static struct commonprotoopts *protoopts = NULL;
68
69 const struct protodefs *tlsinit(uint8_t h) {
70     handle = h;
71     return &protodefs;
72 }
73
74 static void setprotoopts(struct commonprotoopts *opts) {
75     protoopts = opts;
76 }
77
78 static char **getlistenerargs() {
79     return protoopts ? protoopts->listenargs : NULL;
80 }
81
82 void tlssetsrcres() {
83     if (!srcres)
84         srcres = resolve_hostport_addrinfo(handle, protoopts ? protoopts->sourcearg : NULL);
85     
86 }
87
88 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
89     struct timeval now;
90     time_t elapsed;
91     X509 *cert;
92     SSL_CTX *ctx = NULL;
93     unsigned long error;
94     
95     debug(DBG_DBG, "tlsconnect: called from %s", text);
96     pthread_mutex_lock(&server->lock);
97     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
98         /* already reconnected, nothing to do */
99         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
100         pthread_mutex_unlock(&server->lock);
101         return 1;
102     }
103
104     for (;;) {
105         gettimeofday(&now, NULL);
106         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
107         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
108             debug(DBG_DBG, "tlsconnect: timeout");
109             if (server->sock >= 0)
110                 close(server->sock);
111             SSL_free(server->ssl);
112             server->ssl = NULL;
113             pthread_mutex_unlock(&server->lock);
114             return 0;
115         }
116         if (server->connectionok) {
117             server->connectionok = 0;
118             sleep(2);
119         } else if (elapsed < 1)
120             sleep(2);
121         else if (elapsed < 60) {
122             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
123             sleep(elapsed);
124         } else if (elapsed < 100000) {
125             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
126             sleep(60);
127         } else
128             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
129         debug(DBG_WARN, "tlsconnect: trying to open TLS connection to %s port %s", server->conf->host, server->conf->port);
130         if (server->sock >= 0)
131             close(server->sock);
132         if ((server->sock = connecttcp(server->conf->addrinfo, srcres)) < 0) {
133             debug(DBG_ERR, "tlsconnect: connecttcp failed");
134             continue;
135         }
136         
137         SSL_free(server->ssl);
138         server->ssl = NULL;
139         ctx = tlsgetctx(handle, server->conf->tlsconf);
140         if (!ctx)
141             continue;
142         server->ssl = SSL_new(ctx);
143         if (!server->ssl)
144             continue;
145
146         SSL_set_fd(server->ssl, server->sock);
147         if (SSL_connect(server->ssl) <= 0) {
148             while ((error = ERR_get_error()))
149                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
150             continue;
151         }
152         cert = verifytlscert(server->ssl);
153         if (!cert)
154             continue;
155         if (verifyconfcert(cert, server->conf)) {
156             X509_free(cert);
157             break;
158         }
159         X509_free(cert);
160     }
161     debug(DBG_WARN, "tlsconnect: TLS connection to %s port %s up", server->conf->host, server->conf->port);
162     server->connectionok = 1;
163     gettimeofday(&server->lastconnecttry, NULL);
164     pthread_mutex_unlock(&server->lock);
165     return 1;
166 }
167
168 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
169 /* returns 0 on timeout, -1 on error and num if ok */
170 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
171     int s, ndesc, cnt, len;
172     fd_set readfds, writefds;
173     struct timeval timer;
174     
175     s = SSL_get_fd(ssl);
176     if (s < 0)
177         return -1;
178     /* make socket non-blocking? */
179     for (len = 0; len < num; len += cnt) {
180         FD_ZERO(&readfds);
181         FD_SET(s, &readfds);
182         writefds = readfds;
183         if (timeout) {
184             timer.tv_sec = timeout;
185             timer.tv_usec = 0;
186         }
187         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
188         if (ndesc < 1)
189             return ndesc;
190
191         cnt = SSL_read(ssl, buf + len, num - len);
192         if (cnt <= 0)
193             switch (SSL_get_error(ssl, cnt)) {
194             case SSL_ERROR_WANT_READ:
195             case SSL_ERROR_WANT_WRITE:
196                 cnt = 0;
197                 continue;
198             case SSL_ERROR_ZERO_RETURN:
199                 /* remote end sent close_notify, send one back */
200                 SSL_shutdown(ssl);
201                 return -1;
202             default:
203                 return -1;
204             }
205     }
206     return num;
207 }
208
209 /* timeout in seconds, 0 means no timeout (blocking) */
210 unsigned char *radtlsget(SSL *ssl, int timeout) {
211     int cnt, len;
212     unsigned char buf[4], *rad;
213
214     for (;;) {
215         cnt = sslreadtimeout(ssl, buf, 4, timeout);
216         if (cnt < 1) {
217             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
218             return NULL;
219         }
220
221         len = RADLEN(buf);
222         rad = malloc(len);
223         if (!rad) {
224             debug(DBG_ERR, "radtlsget: malloc failed");
225             continue;
226         }
227         memcpy(rad, buf, 4);
228         
229         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
230         if (cnt < 1) {
231             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
232             free(rad);
233             return NULL;
234         }
235         
236         if (len >= 20)
237             break;
238         
239         free(rad);
240         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
241     }
242     
243     debug(DBG_DBG, "radtlsget: got %d bytes", len);
244     return rad;
245 }
246
247 int clientradputtls(struct server *server, unsigned char *rad) {
248     int cnt;
249     size_t len;
250     unsigned long error;
251     struct clsrvconf *conf = server->conf;
252
253     if (!server->connectionok)
254         return 0;
255     len = RADLEN(rad);
256     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
257         while ((error = ERR_get_error()))
258             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
259         return 0;
260     }
261
262     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->host);
263     return 1;
264 }
265
266 void *tlsclientrd(void *arg) {
267     struct server *server = (struct server *)arg;
268     unsigned char *buf;
269     struct timeval now, lastconnecttry;
270     
271     for (;;) {
272         /* yes, lastconnecttry is really necessary */
273         lastconnecttry = server->lastconnecttry;
274         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
275         if (!buf) {
276             if (server->dynamiclookuparg)
277                 break;
278             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
279             continue;
280         }
281
282         replyh(server, buf);
283
284         if (server->dynamiclookuparg) {
285             gettimeofday(&now, NULL);
286             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
287                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
288                 break;
289             }
290         }
291     }
292     ERR_remove_state(0);
293     server->clientrdgone = 1;
294     return NULL;
295 }
296
297 void *tlsserverwr(void *arg) {
298     int cnt;
299     unsigned long error;
300     struct client *client = (struct client *)arg;
301     struct queue *replyq;
302     struct request *reply;
303     
304     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
305     replyq = client->replyq;
306     for (;;) {
307         pthread_mutex_lock(&replyq->mutex);
308         while (!list_first(replyq->entries)) {
309             if (client->ssl) {      
310                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
311                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
312                 debug(DBG_DBG, "tlsserverwr: got signal");
313             }
314             if (!client->ssl) {
315                 /* ssl might have changed while waiting */
316                 pthread_mutex_unlock(&replyq->mutex);
317                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
318                 ERR_remove_state(0);
319                 pthread_exit(NULL);
320             }
321         }
322         reply = (struct request *)list_shift(replyq->entries);
323         pthread_mutex_unlock(&replyq->mutex);
324         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
325         if (cnt > 0)
326             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
327                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
328         else
329             while ((error = ERR_get_error()))
330                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
331         freerq(reply);
332     }
333 }
334
335 void tlsserverrd(struct client *client) {
336     struct request *rq;
337     uint8_t *buf;
338     pthread_t tlsserverwrth;
339     
340     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
341     
342     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
343         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
344         return;
345     }
346
347     for (;;) {
348         buf = radtlsget(client->ssl, 0);
349         if (!buf) {
350             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
351             break;
352         }
353         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
354         rq = newrequest();
355         if (!rq) {
356             free(buf);
357             continue;
358         }
359         rq->buf = buf;
360         rq->from = client;
361         if (!radsrv(rq)) {
362             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
363             break;
364         }
365     }
366     
367     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
368     client->ssl = NULL;
369     pthread_mutex_lock(&client->replyq->mutex);
370     pthread_cond_signal(&client->replyq->cond);
371     pthread_mutex_unlock(&client->replyq->mutex);
372     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
373     pthread_join(tlsserverwrth, NULL);
374     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
375 }
376
377 void *tlsservernew(void *arg) {
378     int s;
379     struct sockaddr_storage from;
380     socklen_t fromlen = sizeof(from);
381     struct clsrvconf *conf;
382     struct list_node *cur = NULL;
383     SSL *ssl = NULL;
384     X509 *cert = NULL;
385     SSL_CTX *ctx = NULL;
386     unsigned long error;
387     struct client *client;
388
389     s = *(int *)arg;
390     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
391         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
392         goto exit;
393     }
394     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
395
396     conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
397     if (conf) {
398         ctx = tlsgetctx(handle, conf->tlsconf);
399         if (!ctx)
400             goto exit;
401         ssl = SSL_new(ctx);
402         if (!ssl)
403             goto exit;
404         SSL_set_fd(ssl, s);
405
406         if (SSL_accept(ssl) <= 0) {
407             while ((error = ERR_get_error()))
408                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
409             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
410             goto exit;
411         }
412         cert = verifytlscert(ssl);
413         if (!cert)
414             goto exit;
415     }
416     
417     while (conf) {
418         if (verifyconfcert(cert, conf)) {
419             X509_free(cert);
420             client = addclient(conf, 1);
421             if (client) {
422                 client->ssl = ssl;
423                 client->addr = addr_copy((struct sockaddr *)&from);
424                 tlsserverrd(client);
425                 removeclient(client);
426             } else
427                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
428             goto exit;
429         }
430         conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
431     }
432     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
433     if (cert)
434         X509_free(cert);
435
436  exit:
437     if (ssl) {
438         SSL_shutdown(ssl);
439         SSL_free(ssl);
440     }
441     ERR_remove_state(0);
442     shutdown(s, SHUT_RDWR);
443     close(s);
444     pthread_exit(NULL);
445 }
446
447 void *tlslistener(void *arg) {
448     pthread_t tlsserverth;
449     int s, *sp = (int *)arg;
450     struct sockaddr_storage from;
451     socklen_t fromlen = sizeof(from);
452
453     listen(*sp, 0);
454
455     for (;;) {
456         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
457         if (s < 0) {
458             debug(DBG_WARN, "accept failed");
459             continue;
460         }
461         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
462             debug(DBG_ERR, "tlslistener: pthread_create failed");
463             shutdown(s, SHUT_RDWR);
464             close(s);
465             continue;
466         }
467         pthread_detach(tlsserverth);
468     }
469     free(sp);
470     return NULL;
471 }
472 #else
473 const struct protodefs *tlsinit(uint8_t h) {
474     return NULL;
475 }
476 #endif