some code improvemetns, more efficiently removing outstanding requests when removing...
[libradsec.git] / tls.c
1 /*
2  * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include <openssl/err.h>
29 #include "debug.h"
30 #include "list.h"
31 #include "util.h"
32 #include "radsecproxy.h"
33 #include "tls.h"
34
35 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
36     struct timeval now;
37     time_t elapsed;
38     X509 *cert;
39     SSL_CTX *ctx = NULL;
40     unsigned long error;
41     
42     debug(DBG_DBG, "tlsconnect: called from %s", text);
43     pthread_mutex_lock(&server->lock);
44     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
45         /* already reconnected, nothing to do */
46         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
47         pthread_mutex_unlock(&server->lock);
48         return 1;
49     }
50
51     for (;;) {
52         gettimeofday(&now, NULL);
53         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
54         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
55             debug(DBG_DBG, "tlsconnect: timeout");
56             if (server->sock >= 0)
57                 close(server->sock);
58             SSL_free(server->ssl);
59             server->ssl = NULL;
60             pthread_mutex_unlock(&server->lock);
61             return 0;
62         }
63         if (server->connectionok) {
64             server->connectionok = 0;
65             sleep(2);
66         } else if (elapsed < 1)
67             sleep(2);
68         else if (elapsed < 60) {
69             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
70             sleep(elapsed);
71         } else if (elapsed < 100000) {
72             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
73             sleep(60);
74         } else
75             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
76         debug(DBG_WARN, "tlsconnect: trying to open TLS connection to %s port %s", server->conf->host, server->conf->port);
77         if (server->sock >= 0)
78             close(server->sock);
79         if ((server->sock = connecttcp(server->conf->addrinfo, getsrcprotores(RAD_TLS))) < 0) {
80             debug(DBG_ERR, "tlsconnect: connecttcp failed");
81             continue;
82         }
83         
84         SSL_free(server->ssl);
85         server->ssl = NULL;
86         ctx = tlsgetctx(RAD_TLS, server->conf->tlsconf);
87         if (!ctx)
88             continue;
89         server->ssl = SSL_new(ctx);
90         if (!server->ssl)
91             continue;
92
93         SSL_set_fd(server->ssl, server->sock);
94         if (SSL_connect(server->ssl) <= 0) {
95             while ((error = ERR_get_error()))
96                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
97             continue;
98         }
99         cert = verifytlscert(server->ssl);
100         if (!cert)
101             continue;
102         if (verifyconfcert(cert, server->conf)) {
103             X509_free(cert);
104             break;
105         }
106         X509_free(cert);
107     }
108     debug(DBG_WARN, "tlsconnect: TLS connection to %s port %s up", server->conf->host, server->conf->port);
109     server->connectionok = 1;
110     gettimeofday(&server->lastconnecttry, NULL);
111     pthread_mutex_unlock(&server->lock);
112     return 1;
113 }
114
115 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
116 /* returns 0 on timeout, -1 on error and num if ok */
117 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
118     int s, ndesc, cnt, len;
119     fd_set readfds, writefds;
120     struct timeval timer;
121     
122     s = SSL_get_fd(ssl);
123     if (s < 0)
124         return -1;
125     /* make socket non-blocking? */
126     for (len = 0; len < num; len += cnt) {
127         FD_ZERO(&readfds);
128         FD_SET(s, &readfds);
129         writefds = readfds;
130         if (timeout) {
131             timer.tv_sec = timeout;
132             timer.tv_usec = 0;
133         }
134         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
135         if (ndesc < 1)
136             return ndesc;
137
138         cnt = SSL_read(ssl, buf + len, num - len);
139         if (cnt <= 0)
140             switch (SSL_get_error(ssl, cnt)) {
141             case SSL_ERROR_WANT_READ:
142             case SSL_ERROR_WANT_WRITE:
143                 cnt = 0;
144                 continue;
145             case SSL_ERROR_ZERO_RETURN:
146                 /* remote end sent close_notify, send one back */
147                 SSL_shutdown(ssl);
148                 return -1;
149             default:
150                 return -1;
151             }
152     }
153     return num;
154 }
155
156 /* timeout in seconds, 0 means no timeout (blocking) */
157 unsigned char *radtlsget(SSL *ssl, int timeout) {
158     int cnt, len;
159     unsigned char buf[4], *rad;
160
161     for (;;) {
162         cnt = sslreadtimeout(ssl, buf, 4, timeout);
163         if (cnt < 1) {
164             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
165             return NULL;
166         }
167
168         len = RADLEN(buf);
169         rad = malloc(len);
170         if (!rad) {
171             debug(DBG_ERR, "radtlsget: malloc failed");
172             continue;
173         }
174         memcpy(rad, buf, 4);
175         
176         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
177         if (cnt < 1) {
178             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
179             free(rad);
180             return NULL;
181         }
182         
183         if (len >= 20)
184             break;
185         
186         free(rad);
187         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
188     }
189     
190     debug(DBG_DBG, "radtlsget: got %d bytes", len);
191     return rad;
192 }
193
194 int clientradputtls(struct server *server, unsigned char *rad) {
195     int cnt;
196     size_t len;
197     unsigned long error;
198     struct clsrvconf *conf = server->conf;
199
200     if (!server->connectionok)
201         return 0;
202     len = RADLEN(rad);
203     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
204         while ((error = ERR_get_error()))
205             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
206         return 0;
207     }
208
209     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->host);
210     return 1;
211 }
212
213 void *tlsclientrd(void *arg) {
214     struct server *server = (struct server *)arg;
215     unsigned char *buf;
216     struct timeval now, lastconnecttry;
217     
218     for (;;) {
219         /* yes, lastconnecttry is really necessary */
220         lastconnecttry = server->lastconnecttry;
221         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
222         if (!buf) {
223             if (server->dynamiclookuparg)
224                 break;
225             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
226             continue;
227         }
228
229         replyh(server, buf);
230
231         if (server->dynamiclookuparg) {
232             gettimeofday(&now, NULL);
233             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
234                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
235                 break;
236             }
237         }
238     }
239     ERR_remove_state(0);
240     server->clientrdgone = 1;
241     return NULL;
242 }
243
244 void *tlsserverwr(void *arg) {
245     int cnt;
246     unsigned long error;
247     struct client *client = (struct client *)arg;
248     struct queue *replyq;
249     struct request *reply;
250     
251     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
252     replyq = client->replyq;
253     for (;;) {
254         pthread_mutex_lock(&replyq->mutex);
255         while (!list_first(replyq->entries)) {
256             if (client->ssl) {      
257                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
258                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
259                 debug(DBG_DBG, "tlsserverwr: got signal");
260             }
261             if (!client->ssl) {
262                 /* ssl might have changed while waiting */
263                 pthread_mutex_unlock(&replyq->mutex);
264                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
265                 ERR_remove_state(0);
266                 pthread_exit(NULL);
267             }
268         }
269         reply = (struct request *)list_shift(replyq->entries);
270         pthread_mutex_unlock(&replyq->mutex);
271         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
272         if (cnt > 0)
273             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
274                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
275         else
276             while ((error = ERR_get_error()))
277                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
278         freerq(reply);
279     }
280 }
281
282 void tlsserverrd(struct client *client) {
283     struct request *rq;
284     uint8_t *buf;
285     pthread_t tlsserverwrth;
286     
287     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
288     
289     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
290         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
291         return;
292     }
293
294     for (;;) {
295         buf = radtlsget(client->ssl, 0);
296         if (!buf) {
297             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
298             break;
299         }
300         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
301         rq = newrequest();
302         if (!rq) {
303             free(buf);
304             continue;
305         }
306         rq->buf = buf;
307         rq->from = client;
308         if (!radsrv(rq)) {
309             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
310             break;
311         }
312     }
313     
314     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
315     client->ssl = NULL;
316     pthread_mutex_lock(&client->replyq->mutex);
317     pthread_cond_signal(&client->replyq->cond);
318     pthread_mutex_unlock(&client->replyq->mutex);
319     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
320     pthread_join(tlsserverwrth, NULL);
321     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
322 }
323
324 void *tlsservernew(void *arg) {
325     int s;
326     struct sockaddr_storage from;
327     size_t fromlen = sizeof(from);
328     struct clsrvconf *conf;
329     struct list_node *cur = NULL;
330     SSL *ssl = NULL;
331     X509 *cert = NULL;
332     SSL_CTX *ctx = NULL;
333     unsigned long error;
334     struct client *client;
335
336     s = *(int *)arg;
337     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
338         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
339         goto exit;
340     }
341     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
342
343     conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
344     if (conf) {
345         ctx = tlsgetctx(RAD_TLS, conf->tlsconf);
346         if (!ctx)
347             goto exit;
348         ssl = SSL_new(ctx);
349         if (!ssl)
350             goto exit;
351         SSL_set_fd(ssl, s);
352
353         if (SSL_accept(ssl) <= 0) {
354             while ((error = ERR_get_error()))
355                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
356             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
357             goto exit;
358         }
359         cert = verifytlscert(ssl);
360         if (!cert)
361             goto exit;
362     }
363     
364     while (conf) {
365         if (verifyconfcert(cert, conf)) {
366             X509_free(cert);
367             client = addclient(conf, 1);
368             if (client) {
369                 client->ssl = ssl;
370                 client->addr = addr_copy((struct sockaddr *)&from);
371                 tlsserverrd(client);
372                 removeclient(client);
373             } else
374                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
375             goto exit;
376         }
377         conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
378     }
379     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
380     if (cert)
381         X509_free(cert);
382
383  exit:
384     if (ssl) {
385         SSL_shutdown(ssl);
386         SSL_free(ssl);
387     }
388     ERR_remove_state(0);
389     shutdown(s, SHUT_RDWR);
390     close(s);
391     pthread_exit(NULL);
392 }
393
394 void *tlslistener(void *arg) {
395     pthread_t tlsserverth;
396     int s, *sp = (int *)arg;
397     struct sockaddr_storage from;
398     size_t fromlen = sizeof(from);
399
400     listen(*sp, 0);
401
402     for (;;) {
403         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
404         if (s < 0) {
405             debug(DBG_WARN, "accept failed");
406             continue;
407         }
408         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
409             debug(DBG_ERR, "tlslistener: pthread_create failed");
410             shutdown(s, SHUT_RDWR);
411             close(s);
412             continue;
413         }
414         pthread_detach(tlsserverth);
415     }
416     free(sp);
417     return NULL;
418 }