Match function prototypes with definitions.
[libradsec.git] / tls.c
1 /*
2  * Copyright (C) 2006-2009 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include <openssl/err.h>
29 #include "radsecproxy.h"
30 #include "hostport.h"
31
32 #ifdef RADPROT_TLS
33 #include "debug.h"
34 #include "util.h"
35
36 static void setprotoopts(struct commonprotoopts *opts);
37 static char **getlistenerargs();
38 void *tlslistener(void *arg);
39 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
40 void *tlsclientrd(void *arg);
41 int clientradputtls(struct server *server, unsigned char *rad);
42 void tlssetsrcres();
43
44 static const struct protodefs protodefs = {
45     "tls",
46     "mysecret", /* secretdefault */
47     SOCK_STREAM, /* socktype */
48     "2083", /* portdefault */
49     0, /* retrycountdefault */
50     0, /* retrycountmax */
51     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
52     60, /* retryintervalmax */
53     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
54     setprotoopts, /* setprotoopts */
55     getlistenerargs, /* getlistenerargs */
56     tlslistener, /* listener */
57     tlsconnect, /* connecter */
58     tlsclientrd, /* clientconnreader */
59     clientradputtls, /* clientradput */
60     NULL, /* addclient */
61     NULL, /* addserverextra */
62     tlssetsrcres, /* setsrcres */
63     NULL /* initextra */
64 };
65
66 static struct addrinfo *srcres = NULL;
67 static uint8_t handle;
68 static struct commonprotoopts *protoopts = NULL;
69
70 const struct protodefs *tlsinit(uint8_t h) {
71     handle = h;
72     return &protodefs;
73 }
74
75 static void setprotoopts(struct commonprotoopts *opts) {
76     protoopts = opts;
77 }
78
79 static char **getlistenerargs() {
80     return protoopts ? protoopts->listenargs : NULL;
81 }
82
83 void tlssetsrcres() {
84     if (!srcres)
85         srcres = resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL, NULL, protodefs.socktype);
86 }
87
88 int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text) {
89     struct timeval now;
90     time_t elapsed;
91     X509 *cert;
92     SSL_CTX *ctx = NULL;
93     unsigned long error;
94
95     debug(DBG_DBG, "tlsconnect: called from %s", text);
96     pthread_mutex_lock(&server->lock);
97     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
98         /* already reconnected, nothing to do */
99         debug(DBG_DBG, "tlsconnect(%s): seems already reconnected", text);
100         pthread_mutex_unlock(&server->lock);
101         return 1;
102     }
103
104     for (;;) {
105         gettimeofday(&now, NULL);
106         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
107         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
108             debug(DBG_DBG, "tlsconnect: timeout");
109             if (server->sock >= 0)
110                 close(server->sock);
111             SSL_free(server->ssl);
112             server->ssl = NULL;
113             pthread_mutex_unlock(&server->lock);
114             return 0;
115         }
116         if (server->connectionok) {
117             server->connectionok = 0;
118             sleep(2);
119         } else if (elapsed < 1)
120             sleep(2);
121         else if (elapsed < 60) {
122             debug(DBG_INFO, "tlsconnect: sleeping %lds", elapsed);
123             sleep(elapsed);
124         } else if (elapsed < 100000) {
125             debug(DBG_INFO, "tlsconnect: sleeping %ds", 60);
126             sleep(60);
127         } else
128             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
129
130         if (server->sock >= 0)
131             close(server->sock);
132         if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) < 0)
133             continue;
134
135         SSL_free(server->ssl);
136         server->ssl = NULL;
137         ctx = tlsgetctx(handle, server->conf->tlsconf);
138         if (!ctx)
139             continue;
140         server->ssl = SSL_new(ctx);
141         if (!server->ssl)
142             continue;
143
144         SSL_set_fd(server->ssl, server->sock);
145         if (SSL_connect(server->ssl) <= 0) {
146             while ((error = ERR_get_error()))
147                 debug(DBG_ERR, "tlsconnect: TLS: %s", ERR_error_string(error, NULL));
148             continue;
149         }
150         cert = verifytlscert(server->ssl);
151         if (!cert)
152             continue;
153         if (verifyconfcert(cert, server->conf)) {
154             X509_free(cert);
155             break;
156         }
157         X509_free(cert);
158     }
159     debug(DBG_WARN, "tlsconnect: TLS connection to %s up", server->conf->name);
160     server->connectionok = 1;
161     gettimeofday(&server->lastconnecttry, NULL);
162     pthread_mutex_unlock(&server->lock);
163     return 1;
164 }
165
166 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
167 /* returns 0 on timeout, -1 on error and num if ok */
168 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
169     int s, ndesc, cnt, len;
170     fd_set readfds, writefds;
171     struct timeval timer;
172
173     s = SSL_get_fd(ssl);
174     if (s < 0)
175         return -1;
176     /* make socket non-blocking? */
177     for (len = 0; len < num; len += cnt) {
178         FD_ZERO(&readfds);
179         FD_SET(s, &readfds);
180         writefds = readfds;
181         if (timeout) {
182             timer.tv_sec = timeout;
183             timer.tv_usec = 0;
184         }
185         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
186         if (ndesc < 1)
187             return ndesc;
188
189         cnt = SSL_read(ssl, buf + len, num - len);
190         if (cnt <= 0)
191             switch (SSL_get_error(ssl, cnt)) {
192             case SSL_ERROR_WANT_READ:
193             case SSL_ERROR_WANT_WRITE:
194                 cnt = 0;
195                 continue;
196             case SSL_ERROR_ZERO_RETURN:
197                 /* remote end sent close_notify, send one back */
198                 SSL_shutdown(ssl);
199                 return -1;
200             default:
201                 return -1;
202             }
203     }
204     return num;
205 }
206
207 /* timeout in seconds, 0 means no timeout (blocking) */
208 unsigned char *radtlsget(SSL *ssl, int timeout) {
209     int cnt, len;
210     unsigned char buf[4], *rad;
211
212     for (;;) {
213         cnt = sslreadtimeout(ssl, buf, 4, timeout);
214         if (cnt < 1) {
215             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
216             return NULL;
217         }
218
219         len = RADLEN(buf);
220         rad = malloc(len);
221         if (!rad) {
222             debug(DBG_ERR, "radtlsget: malloc failed");
223             continue;
224         }
225         memcpy(rad, buf, 4);
226
227         cnt = sslreadtimeout(ssl, rad + 4, len - 4, timeout);
228         if (cnt < 1) {
229             debug(DBG_DBG, cnt ? "radtlsget: connection lost" : "radtlsget: timeout");
230             free(rad);
231             return NULL;
232         }
233
234         if (len >= 20)
235             break;
236
237         free(rad);
238         debug(DBG_WARN, "radtlsget: packet smaller than minimum radius size");
239     }
240
241     debug(DBG_DBG, "radtlsget: got %d bytes", len);
242     return rad;
243 }
244
245 int clientradputtls(struct server *server, unsigned char *rad) {
246     int cnt;
247     size_t len;
248     unsigned long error;
249     struct clsrvconf *conf = server->conf;
250
251     if (!server->connectionok)
252         return 0;
253     len = RADLEN(rad);
254     if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
255         while ((error = ERR_get_error()))
256             debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
257         return 0;
258     }
259
260     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->name);
261     return 1;
262 }
263
264 void *tlsclientrd(void *arg) {
265     struct server *server = (struct server *)arg;
266     unsigned char *buf;
267     struct timeval now, lastconnecttry;
268
269     for (;;) {
270         /* yes, lastconnecttry is really necessary */
271         lastconnecttry = server->lastconnecttry;
272         buf = radtlsget(server->ssl, server->dynamiclookuparg ? IDLE_TIMEOUT : 0);
273         if (!buf) {
274             if (server->dynamiclookuparg)
275                 break;
276             tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
277             continue;
278         }
279
280         replyh(server, buf);
281
282         if (server->dynamiclookuparg) {
283             gettimeofday(&now, NULL);
284             if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
285                 debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
286                 break;
287             }
288         }
289     }
290     ERR_remove_state(0);
291     server->clientrdgone = 1;
292     return NULL;
293 }
294
295 void *tlsserverwr(void *arg) {
296     int cnt;
297     unsigned long error;
298     struct client *client = (struct client *)arg;
299     struct gqueue *replyq;
300     struct request *reply;
301
302     debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
303     replyq = client->replyq;
304     for (;;) {
305         pthread_mutex_lock(&replyq->mutex);
306         while (!list_first(replyq->entries)) {
307             if (client->ssl) {
308                 debug(DBG_DBG, "tlsserverwr: waiting for signal");
309                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
310                 debug(DBG_DBG, "tlsserverwr: got signal");
311             }
312             if (!client->ssl) {
313                 /* ssl might have changed while waiting */
314                 pthread_mutex_unlock(&replyq->mutex);
315                 debug(DBG_DBG, "tlsserverwr: exiting as requested");
316                 ERR_remove_state(0);
317                 pthread_exit(NULL);
318             }
319         }
320         reply = (struct request *)list_shift(replyq->entries);
321         pthread_mutex_unlock(&replyq->mutex);
322         cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
323         if (cnt > 0)
324             debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
325                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
326         else
327             while ((error = ERR_get_error()))
328                 debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
329         freerq(reply);
330     }
331 }
332
333 void tlsserverrd(struct client *client) {
334     struct request *rq;
335     uint8_t *buf;
336     pthread_t tlsserverwrth;
337
338     debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
339
340     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
341         debug(DBG_ERR, "tlsserverrd: pthread_create failed");
342         return;
343     }
344
345     for (;;) {
346         buf = radtlsget(client->ssl, 0);
347         if (!buf) {
348             debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
349             break;
350         }
351         debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
352         rq = newrequest();
353         if (!rq) {
354             free(buf);
355             continue;
356         }
357         rq->buf = buf;
358         rq->from = client;
359         if (!radsrv(rq)) {
360             debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
361             break;
362         }
363     }
364
365     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
366     client->ssl = NULL;
367     pthread_mutex_lock(&client->replyq->mutex);
368     pthread_cond_signal(&client->replyq->cond);
369     pthread_mutex_unlock(&client->replyq->mutex);
370     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
371     pthread_join(tlsserverwrth, NULL);
372     debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
373 }
374
375 void *tlsservernew(void *arg) {
376     int s;
377     struct sockaddr_storage from;
378     socklen_t fromlen = sizeof(from);
379     struct clsrvconf *conf;
380     struct list_node *cur = NULL;
381     SSL *ssl = NULL;
382     X509 *cert = NULL;
383     SSL_CTX *ctx = NULL;
384     unsigned long error;
385     struct client *client;
386
387     s = *(int *)arg;
388     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
389         debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
390         goto exit;
391     }
392     debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
393
394     conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
395     if (conf) {
396         ctx = tlsgetctx(handle, conf->tlsconf);
397         if (!ctx)
398             goto exit;
399         ssl = SSL_new(ctx);
400         if (!ssl)
401             goto exit;
402         SSL_set_fd(ssl, s);
403
404         if (SSL_accept(ssl) <= 0) {
405             while ((error = ERR_get_error()))
406                 debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
407             debug(DBG_ERR, "tlsservernew: SSL_accept failed");
408             goto exit;
409         }
410         cert = verifytlscert(ssl);
411         if (!cert)
412             goto exit;
413     }
414
415     while (conf) {
416         if (verifyconfcert(cert, conf)) {
417             X509_free(cert);
418             client = addclient(conf, 1);
419             if (client) {
420                 client->ssl = ssl;
421                 client->addr = addr_copy((struct sockaddr *)&from);
422                 tlsserverrd(client);
423                 removeclient(client);
424             } else
425                 debug(DBG_WARN, "tlsservernew: failed to create new client instance");
426             goto exit;
427         }
428         conf = find_clconf(handle, (struct sockaddr *)&from, &cur);
429     }
430     debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
431     if (cert)
432         X509_free(cert);
433
434 exit:
435     if (ssl) {
436         SSL_shutdown(ssl);
437         SSL_free(ssl);
438     }
439     ERR_remove_state(0);
440     shutdown(s, SHUT_RDWR);
441     close(s);
442     pthread_exit(NULL);
443 }
444
445 void *tlslistener(void *arg) {
446     pthread_t tlsserverth;
447     int s, *sp = (int *)arg;
448     struct sockaddr_storage from;
449     socklen_t fromlen = sizeof(from);
450
451     listen(*sp, 0);
452
453     for (;;) {
454         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
455         if (s < 0) {
456             debug(DBG_WARN, "accept failed");
457             continue;
458         }
459         if (pthread_create(&tlsserverth, NULL, tlsservernew, (void *)&s)) {
460             debug(DBG_ERR, "tlslistener: pthread_create failed");
461             shutdown(s, SHUT_RDWR);
462             close(s);
463             continue;
464         }
465         pthread_detach(tlsserverth);
466     }
467     free(sp);
468     return NULL;
469 }
470 #else
471 const struct protodefs *tlsinit(uint8_t h) {
472     return NULL;
473 }
474 #endif
475
476 /* Local Variables: */
477 /* c-file-style: "stroustrup" */
478 /* End: */