better logging, fixed bug with crash when removing client
[libradsec.git] / tls.c
1 /*
2  * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include <openssl/err.h>
29 #include "debug.h"
30 #include "list.h"
31 #include "util.h"
32 #include "radsecproxy.h"
33 #include "tls.h"
34
35 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
36     struct timeval now;
37     time_t elapsed;
38     X509 *cert;
39     unsigned long error;
40     
41     debug(DBG_DBG, "tlsconnect: called from %s", text);
42     pthread_mutex_lock(&server->lock);
43     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
44         /* already reconnected, nothing to do */
45         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
46         pthread_mutex_unlock(&server->lock);
47         return 1;
48     }
49
50     for (;;) {
51         gettimeofday(&now, NULL);
52         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
53         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
54             debug(DBG_DBG, "tlsconnect: timeout");
55             if (server->sock >= 0)
56                 close(server->sock);
57             SSL_free(server->ssl);
58             server->ssl = NULL;
59             pthread_mutex_unlock(&server->lock);
60             return 0;
61         }
62         if (server->connectionok) {
63             server->connectionok = 0;
64             sleep(2);
65         } else if (elapsed < 1)
66             sleep(2);
67         else if (elapsed < 60) {
68             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
69             sleep(elapsed);
70         } else if (elapsed < 100000) {
71             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
72             sleep(60);
73         } else
74             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
75         debug(DBG_WARN, "tlsconnect: trying to open TLS connection to %s port %s", server->conf->host, server->conf->port);
76         if (server->sock >= 0)
77             close(server->sock);
78         if ((server->sock = connecttcp(server->conf->addrinfo, getsrcprotores(RAD_TLS))) < 0) {
79             debug(DBG_ERR, "tlsconnect: connecttcp failed");
80             continue;
81         }
82         
83         SSL_free(server->ssl);
84         server->ssl = SSL_new(server->conf->ssl_ctx);
85         SSL_set_fd(server->ssl, server->sock);
86         if (SSL_connect(server->ssl) <= 0) {
87             while ((error = ERR_get_error()))
88                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
89             continue;
90         }
91         cert = verifytlscert(server->ssl);
92         if (!cert)
93             continue;
94         if (verifyconfcert(cert, server->conf)) {
95             X509_free(cert);
96             break;
97         }
98         X509_free(cert);
99     }
100     debug(DBG_WARN, "tlsconnect: TLS connection to %s port %s up", server->conf->host, server->conf->port);
101     gettimeofday(&server->lastconnecttry, NULL);
102     pthread_mutex_unlock(&server->lock);
103     return 1;
104 }
105
106 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
107 /* returns 0 on timeout, -1 on error and num if ok */
108 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
109     int s, ndesc, cnt, len;
110     fd_set readfds, writefds;
111     struct timeval timer;
112     
113     s = SSL_get_fd(ssl);
114     if (s < 0)
115         return -1;
116     /* make socket non-blocking? */
117     for (len = 0; len < num; len += cnt) {
118         FD_ZERO(&readfds);
119         FD_SET(s, &readfds);
120         writefds = readfds;
121         if (timeout) {
122             timer.tv_sec = timeout;
123             timer.tv_usec = 0;
124         }
125         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
126         if (ndesc < 1)
127             return ndesc;
128
129         cnt = SSL_read(ssl, buf + len, num - len);
130         if (cnt <= 0)
131             switch (SSL_get_error(ssl, cnt)) {
132             case SSL_ERROR_WANT_READ:
133             case SSL_ERROR_WANT_WRITE:
134                 cnt = 0;
135                 continue;
136             case SSL_ERROR_ZERO_RETURN:
137                 /* remote end sent close_notify, send one back */
138                 SSL_shutdown(ssl);
139                 return -1;
140             default:
141                 return -1;
142             }
143     }
144     return num;
145 }
146
147 /* timeout in seconds, 0 means no timeout (blocking) */
148 unsigned char *radtlsget(SSL *ssl, int timeout) {
149     int cnt, len;
150     unsigned char buf[4], *rad;
151
152     for (;;) {
153         cnt = sslreadtimeout(ssl, buf, 4, timeout);
154         if (cnt < 1) {
155             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
156             return NULL;
157         }
158
159         len = RADLEN(buf);
160         rad = malloc(len);
161         if (!rad) {
162             debug(DBG_ERR, "radtlsget: malloc failed");
163             continue;
164         }
165         memcpy(rad, buf, 4);
166         
167         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
168         if (cnt < 1) {
169             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
170             free(rad);
171             return NULL;
172         }
173         
174         if (len >= 20)
175             break;
176         
177         free(rad);
178         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
179     }
180     
181     debug(DBG_DBG, "radtlsget: got %d bytes", len);
182     return rad;
183 }
184
185 int clientradputtls(struct server *server, unsigned char *rad) {
186     int cnt;
187     size_t len;
188     unsigned long error;
189     struct timeval lastconnecttry;
190     struct clsrvconf *conf = server->conf;
191     
192     len = RADLEN(rad);
193     lastconnecttry = server->lastconnecttry;
194     while ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
195         while ((error = ERR_get_error()))
196             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
197         if (server->dynamiclookuparg)
198             return 0;
199         tlsconnect(server, &lastconnecttry, 0, "clientradputtls");
200         lastconnecttry = server->lastconnecttry;
201     }
202
203     server->connectionok = 1;
204     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->host);
205     return 1;
206 }
207
208 void *tlsclientrd(void *arg) {
209     struct server *server = (struct server *)arg;
210     unsigned char *buf;
211     struct timeval now, lastconnecttry;
212     
213     for (;;) {
214         /* yes, lastconnecttry is really necessary */
215         lastconnecttry = server->lastconnecttry;
216         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
217         if (!buf) {
218             if (server->dynamiclookuparg)
219                 break;
220             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
221             continue;
222         }
223
224         replyh(server, buf);
225
226         if (server->dynamiclookuparg) {
227             gettimeofday(&now, NULL);
228             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
229                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
230                 break;
231             }
232         }
233     }
234     ERR_remove_state(0);
235     server->clientrdgone = 1;
236     return NULL;
237 }
238
239 void *tlsserverwr(void *arg) {
240     int cnt;
241     unsigned long error;
242     struct client *client = (struct client *)arg;
243     struct queue *replyq;
244     struct request *reply;
245     
246     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
247     replyq = client->replyq;
248     for (;;) {
249         pthread_mutex_lock(&replyq->mutex);
250         while (!list_first(replyq->entries)) {
251             if (client->ssl) {      
252                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
253                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
254                 debug(DBG_DBG, "tlsserverwr: got signal");
255             }
256             if (!client->ssl) {
257                 /* ssl might have changed while waiting */
258                 pthread_mutex_unlock(&replyq->mutex);
259                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
260                 ERR_remove_state(0);
261                 pthread_exit(NULL);
262             }
263         }
264         reply = (struct request *)list_shift(replyq->entries);
265         pthread_mutex_unlock(&replyq->mutex);
266         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
267         if (cnt > 0)
268             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
269                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
270         else
271             while ((error = ERR_get_error()))
272                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
273         freerq(reply);
274     }
275 }
276
277 void tlsserverrd(struct client *client) {
278     struct request *rq;
279     uint8_t *buf;
280     pthread_t tlsserverwrth;
281     
282     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
283     
284     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
285         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
286         return;
287     }
288
289     for (;;) {
290         buf = radtlsget(client->ssl, 0);
291         if (!buf) {
292             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
293             break;
294         }
295         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
296         rq = newrequest();
297         if (!rq) {
298             free(buf);
299             continue;
300         }
301         rq->buf = buf;
302         rq->from = client;
303         if (!radsrv(rq)) {
304             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
305             break;
306         }
307     }
308     
309     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
310     client->ssl = NULL;
311     pthread_mutex_lock(&client->replyq->mutex);
312     pthread_cond_signal(&client->replyq->cond);
313     pthread_mutex_unlock(&client->replyq->mutex);
314     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
315     pthread_join(tlsserverwrth, NULL);
316     removeclientrqs(client);
317     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
318 }
319
320 void *tlsservernew(void *arg) {
321     int s;
322     struct sockaddr_storage from;
323     size_t fromlen = sizeof(from);
324     struct clsrvconf *conf;
325     struct list_node *cur = NULL;
326     SSL *ssl = NULL;
327     X509 *cert = NULL;
328     unsigned long error;
329     struct client *client;
330
331     s = *(int *)arg;
332     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
333         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
334         goto exit;
335     }
336     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
337
338     conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
339     if (conf) {
340         ssl = SSL_new(conf->ssl_ctx);
341         SSL_set_fd(ssl, s);
342
343         if (SSL_accept(ssl) <= 0) {
344             while ((error = ERR_get_error()))
345                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
346             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
347             goto exit;
348         }
349         cert = verifytlscert(ssl);
350         if (!cert)
351             goto exit;
352     }
353     
354     while (conf) {
355         if (verifyconfcert(cert, conf)) {
356             X509_free(cert);
357             client = addclient(conf, 1);
358             if (client) {
359                 client->ssl = ssl;
360                 client->addr = addr_copy((struct sockaddr *)&from);
361                 tlsserverrd(client);
362                 removeclient(client);
363             } else
364                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
365             goto exit;
366         }
367         conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
368     }
369     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
370     if (cert)
371         X509_free(cert);
372
373  exit:
374     if (ssl) {
375         SSL_shutdown(ssl);
376         SSL_free(ssl);
377     }
378     ERR_remove_state(0);
379     shutdown(s, SHUT_RDWR);
380     close(s);
381     pthread_exit(NULL);
382 }
383
384 void *tlslistener(void *arg) {
385     pthread_t tlsserverth;
386     int s, *sp = (int *)arg;
387     struct sockaddr_storage from;
388     size_t fromlen = sizeof(from);
389
390     listen(*sp, 0);
391
392     for (;;) {
393         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
394         if (s < 0) {
395             debug(DBG_WARN, "accept failed");
396             continue;
397         }
398         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
399             debug(DBG_ERR, "tlslistener: pthread_create failed");
400             shutdown(s, SHUT_RDWR);
401             close(s);
402             continue;
403         }
404         pthread_detach(tlsserverth);
405     }
406     free(sp);
407     return NULL;
408 }