cleaning up code
[libradsec.git] / tls.c
1 /*
2  * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include <openssl/err.h>
29 #include "debug.h"
30 #include "list.h"
31 #include "util.h"
32 #include "radsecproxy.h"
33
34 void *tlslistener(void *arg);
35 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
36 void *tlsclientrd(void *arg);
37 int clientradputtls(struct server *server, unsigned char *rad);
38 void tlssetsrcres(char *source);
39
40 static const struct protodefs protodefs = {
41     "tls",
42     "mysecret", /* secretdefault */
43     SOCK_STREAM, /* socktype */
44     "2083", /* portdefault */
45     0, /* retrycountdefault */
46     0, /* retrycountmax */
47     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
48     60, /* retryintervalmax */
49     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
50     tlslistener, /* listener */
51     tlsconnect, /* connecter */
52     tlsclientrd, /* clientconnreader */
53     clientradputtls, /* clientradput */
54     NULL, /* addclient */
55     NULL, /* addserverextra */
56     tlssetsrcres, /* setsrcres */
57     NULL /* initextra */
58 };
59
60 static struct addrinfo *srcres = NULL;
61 static uint8_t handle;
62
63 const struct protodefs *tlsinit(uint8_t h) {
64     handle = h;
65     return &protodefs;
66 }
67
68 void tlssetsrcres(char *source) {
69     if (!srcres)
70         srcres = resolve_hostport_addrinfo(handle, source);
71 }
72
73 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
74     struct timeval now;
75     time_t elapsed;
76     X509 *cert;
77     SSL_CTX *ctx = NULL;
78     unsigned long error;
79     
80     debug(DBG_DBG, "tlsconnect: called from %s", text);
81     pthread_mutex_lock(&server->lock);
82     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
83         /* already reconnected, nothing to do */
84         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
85         pthread_mutex_unlock(&server->lock);
86         return 1;
87     }
88
89     for (;;) {
90         gettimeofday(&now, NULL);
91         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
92         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
93             debug(DBG_DBG, "tlsconnect: timeout");
94             if (server->sock >= 0)
95                 close(server->sock);
96             SSL_free(server->ssl);
97             server->ssl = NULL;
98             pthread_mutex_unlock(&server->lock);
99             return 0;
100         }
101         if (server->connectionok) {
102             server->connectionok = 0;
103             sleep(2);
104         } else if (elapsed < 1)
105             sleep(2);
106         else if (elapsed < 60) {
107             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
108             sleep(elapsed);
109         } else if (elapsed < 100000) {
110             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
111             sleep(60);
112         } else
113             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
114         debug(DBG_WARN, "tlsconnect: trying to open TLS connection to %s port %s", server->conf->host, server->conf->port);
115         if (server->sock >= 0)
116             close(server->sock);
117         if ((server->sock = connecttcp(server->conf->addrinfo, srcres)) < 0) {
118             debug(DBG_ERR, "tlsconnect: connecttcp failed");
119             continue;
120         }
121         
122         SSL_free(server->ssl);
123         server->ssl = NULL;
124         ctx = tlsgetctx(handle, server->conf->tlsconf);
125         if (!ctx)
126             continue;
127         server->ssl = SSL_new(ctx);
128         if (!server->ssl)
129             continue;
130
131         SSL_set_fd(server->ssl, server->sock);
132         if (SSL_connect(server->ssl) <= 0) {
133             while ((error = ERR_get_error()))
134                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
135             continue;
136         }
137         cert = verifytlscert(server->ssl);
138         if (!cert)
139             continue;
140         if (verifyconfcert(cert, server->conf)) {
141             X509_free(cert);
142             break;
143         }
144         X509_free(cert);
145     }
146     debug(DBG_WARN, "tlsconnect: TLS connection to %s port %s up", server->conf->host, server->conf->port);
147     server->connectionok = 1;
148     gettimeofday(&server->lastconnecttry, NULL);
149     pthread_mutex_unlock(&server->lock);
150     return 1;
151 }
152
153 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
154 /* returns 0 on timeout, -1 on error and num if ok */
155 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
156     int s, ndesc, cnt, len;
157     fd_set readfds, writefds;
158     struct timeval timer;
159     
160     s = SSL_get_fd(ssl);
161     if (s < 0)
162         return -1;
163     /* make socket non-blocking? */
164     for (len = 0; len < num; len += cnt) {
165         FD_ZERO(&readfds);
166         FD_SET(s, &readfds);
167         writefds = readfds;
168         if (timeout) {
169             timer.tv_sec = timeout;
170             timer.tv_usec = 0;
171         }
172         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
173         if (ndesc < 1)
174             return ndesc;
175
176         cnt = SSL_read(ssl, buf + len, num - len);
177         if (cnt <= 0)
178             switch (SSL_get_error(ssl, cnt)) {
179             case SSL_ERROR_WANT_READ:
180             case SSL_ERROR_WANT_WRITE:
181                 cnt = 0;
182                 continue;
183             case SSL_ERROR_ZERO_RETURN:
184                 /* remote end sent close_notify, send one back */
185                 SSL_shutdown(ssl);
186                 return -1;
187             default:
188                 return -1;
189             }
190     }
191     return num;
192 }
193
194 /* timeout in seconds, 0 means no timeout (blocking) */
195 unsigned char *radtlsget(SSL *ssl, int timeout) {
196     int cnt, len;
197     unsigned char buf[4], *rad;
198
199     for (;;) {
200         cnt = sslreadtimeout(ssl, buf, 4, timeout);
201         if (cnt < 1) {
202             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
203             return NULL;
204         }
205
206         len = RADLEN(buf);
207         rad = malloc(len);
208         if (!rad) {
209             debug(DBG_ERR, "radtlsget: malloc failed");
210             continue;
211         }
212         memcpy(rad, buf, 4);
213         
214         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
215         if (cnt < 1) {
216             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
217             free(rad);
218             return NULL;
219         }
220         
221         if (len >= 20)
222             break;
223         
224         free(rad);
225         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
226     }
227     
228     debug(DBG_DBG, "radtlsget: got %d bytes", len);
229     return rad;
230 }
231
232 int clientradputtls(struct server *server, unsigned char *rad) {
233     int cnt;
234     size_t len;
235     unsigned long error;
236     struct clsrvconf *conf = server->conf;
237
238     if (!server->connectionok)
239         return 0;
240     len = RADLEN(rad);
241     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
242         while ((error = ERR_get_error()))
243             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
244         return 0;
245     }
246
247     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->host);
248     return 1;
249 }
250
251 void *tlsclientrd(void *arg) {
252     struct server *server = (struct server *)arg;
253     unsigned char *buf;
254     struct timeval now, lastconnecttry;
255     
256     for (;;) {
257         /* yes, lastconnecttry is really necessary */
258         lastconnecttry = server->lastconnecttry;
259         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
260         if (!buf) {
261             if (server->dynamiclookuparg)
262                 break;
263             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
264             continue;
265         }
266
267         replyh(server, buf);
268
269         if (server->dynamiclookuparg) {
270             gettimeofday(&now, NULL);
271             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
272                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
273                 break;
274             }
275         }
276     }
277     ERR_remove_state(0);
278     server->clientrdgone = 1;
279     return NULL;
280 }
281
282 void *tlsserverwr(void *arg) {
283     int cnt;
284     unsigned long error;
285     struct client *client = (struct client *)arg;
286     struct queue *replyq;
287     struct request *reply;
288     
289     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
290     replyq = client->replyq;
291     for (;;) {
292         pthread_mutex_lock(&replyq->mutex);
293         while (!list_first(replyq->entries)) {
294             if (client->ssl) {      
295                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
296                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
297                 debug(DBG_DBG, "tlsserverwr: got signal");
298             }
299             if (!client->ssl) {
300                 /* ssl might have changed while waiting */
301                 pthread_mutex_unlock(&replyq->mutex);
302                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
303                 ERR_remove_state(0);
304                 pthread_exit(NULL);
305             }
306         }
307         reply = (struct request *)list_shift(replyq->entries);
308         pthread_mutex_unlock(&replyq->mutex);
309         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
310         if (cnt > 0)
311             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
312                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
313         else
314             while ((error = ERR_get_error()))
315                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
316         freerq(reply);
317     }
318 }
319
320 void tlsserverrd(struct client *client) {
321     struct request *rq;
322     uint8_t *buf;
323     pthread_t tlsserverwrth;
324     
325     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
326     
327     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
328         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
329         return;
330     }
331
332     for (;;) {
333         buf = radtlsget(client->ssl, 0);
334         if (!buf) {
335             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
336             break;
337         }
338         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
339         rq = newrequest();
340         if (!rq) {
341             free(buf);
342             continue;
343         }
344         rq->buf = buf;
345         rq->from = client;
346         if (!radsrv(rq)) {
347             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
348             break;
349         }
350     }
351     
352     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
353     client->ssl = NULL;
354     pthread_mutex_lock(&client->replyq->mutex);
355     pthread_cond_signal(&client->replyq->cond);
356     pthread_mutex_unlock(&client->replyq->mutex);
357     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
358     pthread_join(tlsserverwrth, NULL);
359     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
360 }
361
362 void *tlsservernew(void *arg) {
363     int s;
364     struct sockaddr_storage from;
365     socklen_t fromlen = sizeof(from);
366     struct clsrvconf *conf;
367     struct list_node *cur = NULL;
368     SSL *ssl = NULL;
369     X509 *cert = NULL;
370     SSL_CTX *ctx = NULL;
371     unsigned long error;
372     struct client *client;
373
374     s = *(int *)arg;
375     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
376         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
377         goto exit;
378     }
379     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
380
381     conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
382     if (conf) {
383         ctx = tlsgetctx(handle, conf->tlsconf);
384         if (!ctx)
385             goto exit;
386         ssl = SSL_new(ctx);
387         if (!ssl)
388             goto exit;
389         SSL_set_fd(ssl, s);
390
391         if (SSL_accept(ssl) <= 0) {
392             while ((error = ERR_get_error()))
393                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
394             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
395             goto exit;
396         }
397         cert = verifytlscert(ssl);
398         if (!cert)
399             goto exit;
400     }
401     
402     while (conf) {
403         if (verifyconfcert(cert, conf)) {
404             X509_free(cert);
405             client = addclient(conf, 1);
406             if (client) {
407                 client->ssl = ssl;
408                 client->addr = addr_copy((struct sockaddr *)&from);
409                 tlsserverrd(client);
410                 removeclient(client);
411             } else
412                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
413             goto exit;
414         }
415         conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
416     }
417     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
418     if (cert)
419         X509_free(cert);
420
421  exit:
422     if (ssl) {
423         SSL_shutdown(ssl);
424         SSL_free(ssl);
425     }
426     ERR_remove_state(0);
427     shutdown(s, SHUT_RDWR);
428     close(s);
429     pthread_exit(NULL);
430 }
431
432 void *tlslistener(void *arg) {
433     pthread_t tlsserverth;
434     int s, *sp = (int *)arg;
435     struct sockaddr_storage from;
436     socklen_t fromlen = sizeof(from);
437
438     listen(*sp, 0);
439
440     for (;;) {
441         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
442         if (s < 0) {
443             debug(DBG_WARN, "accept failed");
444             continue;
445         }
446         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
447             debug(DBG_ERR, "tlslistener: pthread_create failed");
448             shutdown(s, SHUT_RDWR);
449             close(s);
450             continue;
451         }
452         pthread_detach(tlsserverth);
453     }
454     free(sp);
455     return NULL;
456 }