cleaning up code
[libradsec.git] / tls.c
1 /*
2  * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include <openssl/err.h>
29 #include "debug.h"
30 #include "list.h"
31 #include "util.h"
32 #include "radsecproxy.h"
33 #include "tls.h"
34
35 static struct addrinfo *srcres = NULL;
36
37 void tlssetsrcres(char *source) {
38     if (!srcres)
39         srcres = resolve_hostport_addrinfo(RAD_TLS, source);
40 }
41
42 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
43     struct timeval now;
44     time_t elapsed;
45     X509 *cert;
46     SSL_CTX *ctx = NULL;
47     unsigned long error;
48     
49     debug(DBG_DBG, "tlsconnect: called from %s", text);
50     pthread_mutex_lock(&server->lock);
51     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
52         /* already reconnected, nothing to do */
53         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
54         pthread_mutex_unlock(&server->lock);
55         return 1;
56     }
57
58     for (;;) {
59         gettimeofday(&now, NULL);
60         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
61         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
62             debug(DBG_DBG, "tlsconnect: timeout");
63             if (server->sock >= 0)
64                 close(server->sock);
65             SSL_free(server->ssl);
66             server->ssl = NULL;
67             pthread_mutex_unlock(&server->lock);
68             return 0;
69         }
70         if (server->connectionok) {
71             server->connectionok = 0;
72             sleep(2);
73         } else if (elapsed < 1)
74             sleep(2);
75         else if (elapsed < 60) {
76             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
77             sleep(elapsed);
78         } else if (elapsed < 100000) {
79             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
80             sleep(60);
81         } else
82             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
83         debug(DBG_WARN, "tlsconnect: trying to open TLS connection to %s port %s", server->conf->host, server->conf->port);
84         if (server->sock >= 0)
85             close(server->sock);
86         if ((server->sock = connecttcp(server->conf->addrinfo, srcres)) < 0) {
87             debug(DBG_ERR, "tlsconnect: connecttcp failed");
88             continue;
89         }
90         
91         SSL_free(server->ssl);
92         server->ssl = NULL;
93         ctx = tlsgetctx(RAD_TLS, server->conf->tlsconf);
94         if (!ctx)
95             continue;
96         server->ssl = SSL_new(ctx);
97         if (!server->ssl)
98             continue;
99
100         SSL_set_fd(server->ssl, server->sock);
101         if (SSL_connect(server->ssl) <= 0) {
102             while ((error = ERR_get_error()))
103                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
104             continue;
105         }
106         cert = verifytlscert(server->ssl);
107         if (!cert)
108             continue;
109         if (verifyconfcert(cert, server->conf)) {
110             X509_free(cert);
111             break;
112         }
113         X509_free(cert);
114     }
115     debug(DBG_WARN, "tlsconnect: TLS connection to %s port %s up", server->conf->host, server->conf->port);
116     server->connectionok = 1;
117     gettimeofday(&server->lastconnecttry, NULL);
118     pthread_mutex_unlock(&server->lock);
119     return 1;
120 }
121
122 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
123 /* returns 0 on timeout, -1 on error and num if ok */
124 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
125     int s, ndesc, cnt, len;
126     fd_set readfds, writefds;
127     struct timeval timer;
128     
129     s = SSL_get_fd(ssl);
130     if (s < 0)
131         return -1;
132     /* make socket non-blocking? */
133     for (len = 0; len < num; len += cnt) {
134         FD_ZERO(&readfds);
135         FD_SET(s, &readfds);
136         writefds = readfds;
137         if (timeout) {
138             timer.tv_sec = timeout;
139             timer.tv_usec = 0;
140         }
141         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
142         if (ndesc < 1)
143             return ndesc;
144
145         cnt = SSL_read(ssl, buf + len, num - len);
146         if (cnt <= 0)
147             switch (SSL_get_error(ssl, cnt)) {
148             case SSL_ERROR_WANT_READ:
149             case SSL_ERROR_WANT_WRITE:
150                 cnt = 0;
151                 continue;
152             case SSL_ERROR_ZERO_RETURN:
153                 /* remote end sent close_notify, send one back */
154                 SSL_shutdown(ssl);
155                 return -1;
156             default:
157                 return -1;
158             }
159     }
160     return num;
161 }
162
163 /* timeout in seconds, 0 means no timeout (blocking) */
164 unsigned char *radtlsget(SSL *ssl, int timeout) {
165     int cnt, len;
166     unsigned char buf[4], *rad;
167
168     for (;;) {
169         cnt = sslreadtimeout(ssl, buf, 4, timeout);
170         if (cnt < 1) {
171             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
172             return NULL;
173         }
174
175         len = RADLEN(buf);
176         rad = malloc(len);
177         if (!rad) {
178             debug(DBG_ERR, "radtlsget: malloc failed");
179             continue;
180         }
181         memcpy(rad, buf, 4);
182         
183         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
184         if (cnt < 1) {
185             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
186             free(rad);
187             return NULL;
188         }
189         
190         if (len >= 20)
191             break;
192         
193         free(rad);
194         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
195     }
196     
197     debug(DBG_DBG, "radtlsget: got %d bytes", len);
198     return rad;
199 }
200
201 int clientradputtls(struct server *server, unsigned char *rad) {
202     int cnt;
203     size_t len;
204     unsigned long error;
205     struct clsrvconf *conf = server->conf;
206
207     if (!server->connectionok)
208         return 0;
209     len = RADLEN(rad);
210     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
211         while ((error = ERR_get_error()))
212             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
213         return 0;
214     }
215
216     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->host);
217     return 1;
218 }
219
220 void *tlsclientrd(void *arg) {
221     struct server *server = (struct server *)arg;
222     unsigned char *buf;
223     struct timeval now, lastconnecttry;
224     
225     for (;;) {
226         /* yes, lastconnecttry is really necessary */
227         lastconnecttry = server->lastconnecttry;
228         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
229         if (!buf) {
230             if (server->dynamiclookuparg)
231                 break;
232             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
233             continue;
234         }
235
236         replyh(server, buf);
237
238         if (server->dynamiclookuparg) {
239             gettimeofday(&now, NULL);
240             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
241                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
242                 break;
243             }
244         }
245     }
246     ERR_remove_state(0);
247     server->clientrdgone = 1;
248     return NULL;
249 }
250
251 void *tlsserverwr(void *arg) {
252     int cnt;
253     unsigned long error;
254     struct client *client = (struct client *)arg;
255     struct queue *replyq;
256     struct request *reply;
257     
258     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
259     replyq = client->replyq;
260     for (;;) {
261         pthread_mutex_lock(&replyq->mutex);
262         while (!list_first(replyq->entries)) {
263             if (client->ssl) {      
264                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
265                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
266                 debug(DBG_DBG, "tlsserverwr: got signal");
267             }
268             if (!client->ssl) {
269                 /* ssl might have changed while waiting */
270                 pthread_mutex_unlock(&replyq->mutex);
271                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
272                 ERR_remove_state(0);
273                 pthread_exit(NULL);
274             }
275         }
276         reply = (struct request *)list_shift(replyq->entries);
277         pthread_mutex_unlock(&replyq->mutex);
278         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
279         if (cnt > 0)
280             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
281                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
282         else
283             while ((error = ERR_get_error()))
284                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
285         freerq(reply);
286     }
287 }
288
289 void tlsserverrd(struct client *client) {
290     struct request *rq;
291     uint8_t *buf;
292     pthread_t tlsserverwrth;
293     
294     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
295     
296     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
297         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
298         return;
299     }
300
301     for (;;) {
302         buf = radtlsget(client->ssl, 0);
303         if (!buf) {
304             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
305             break;
306         }
307         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
308         rq = newrequest();
309         if (!rq) {
310             free(buf);
311             continue;
312         }
313         rq->buf = buf;
314         rq->from = client;
315         if (!radsrv(rq)) {
316             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
317             break;
318         }
319     }
320     
321     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
322     client->ssl = NULL;
323     pthread_mutex_lock(&client->replyq->mutex);
324     pthread_cond_signal(&client->replyq->cond);
325     pthread_mutex_unlock(&client->replyq->mutex);
326     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
327     pthread_join(tlsserverwrth, NULL);
328     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
329 }
330
331 void *tlsservernew(void *arg) {
332     int s;
333     struct sockaddr_storage from;
334     socklen_t fromlen = sizeof(from);
335     struct clsrvconf *conf;
336     struct list_node *cur = NULL;
337     SSL *ssl = NULL;
338     X509 *cert = NULL;
339     SSL_CTX *ctx = NULL;
340     unsigned long error;
341     struct client *client;
342
343     s = *(int *)arg;
344     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
345         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
346         goto exit;
347     }
348     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
349
350     conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
351     if (conf) {
352         ctx = tlsgetctx(RAD_TLS, conf->tlsconf);
353         if (!ctx)
354             goto exit;
355         ssl = SSL_new(ctx);
356         if (!ssl)
357             goto exit;
358         SSL_set_fd(ssl, s);
359
360         if (SSL_accept(ssl) <= 0) {
361             while ((error = ERR_get_error()))
362                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
363             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
364             goto exit;
365         }
366         cert = verifytlscert(ssl);
367         if (!cert)
368             goto exit;
369     }
370     
371     while (conf) {
372         if (verifyconfcert(cert, conf)) {
373             X509_free(cert);
374             client = addclient(conf, 1);
375             if (client) {
376                 client->ssl = ssl;
377                 client->addr = addr_copy((struct sockaddr *)&from);
378                 tlsserverrd(client);
379                 removeclient(client);
380             } else
381                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
382             goto exit;
383         }
384         conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
385     }
386     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
387     if (cert)
388         X509_free(cert);
389
390  exit:
391     if (ssl) {
392         SSL_shutdown(ssl);
393         SSL_free(ssl);
394     }
395     ERR_remove_state(0);
396     shutdown(s, SHUT_RDWR);
397     close(s);
398     pthread_exit(NULL);
399 }
400
401 void *tlslistener(void *arg) {
402     pthread_t tlsserverth;
403     int s, *sp = (int *)arg;
404     struct sockaddr_storage from;
405     socklen_t fromlen = sizeof(from);
406
407     listen(*sp, 0);
408
409     for (;;) {
410         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
411         if (s < 0) {
412             debug(DBG_WARN, "accept failed");
413             continue;
414         }
415         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
416             debug(DBG_ERR, "tlslistener: pthread_create failed");
417             shutdown(s, SHUT_RDWR);
418             close(s);
419             continue;
420         }
421         pthread_detach(tlsserverth);
422     }
423     free(sp);
424     return NULL;
425 }