fixed minor bugs
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * Among other things:
15  * timer based client retrans or maybe no retrans and just a timer...
16  * make our server ignore client retrans?
17  * tls keep alives
18  * routing based on id....
19  * need to also encrypt Tunnel-Password and Message-Authenticator attrs
20  * tls certificate validation
21 */
22
23 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
24  *              rd is responsible for init and launching wr
25  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
26  *          each tlsserverrd launches tlsserverwr
27  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
28  *          for init and launching rd
29  *
30  * serverrd will receive a request, processes it and puts it in the requestq of
31  *          the appropriate clientwr
32  * clientwr monitors its requestq and sends requests
33  * clientrd looks for responses, processes them and puts them in the replyq of
34  *          the peer the request came from
35  * serverwr monitors its reply and sends replies
36  *
37  * In addition to the main thread, we have:
38  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
39  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
40  * For each TLS peer connecting to us there will be 2 more TLS threads
41  *       This is only for connected peers
42  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
43  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
44 */
45
46 #include <netdb.h>
47 #include <unistd.h>
48 #include <sys/time.h>
49 #include <pthread.h>
50 #include <openssl/ssl.h>
51 #include <openssl/rand.h>
52 #include <openssl/err.h>
53 #include <openssl/md5.h>
54 #include "radsecproxy.h"
55
56 static struct client clients[MAX_PEERS];
57 static struct server servers[MAX_PEERS];
58
59 static int client_count = 0;
60 static int server_count = 0;
61
62 static struct replyq udp_server_replyq;
63 static int udp_server_sock = -1;
64 static char *udp_server_port = DEFAULT_UDP_PORT;
65 static pthread_mutex_t *ssl_locks;
66 static long *ssl_lock_count;
67 static SSL_CTX *ssl_ctx_cl;
68 extern int optind;
69 extern char *optarg;
70
71 /* callbacks for making OpenSSL thread safe */
72 unsigned long ssl_thread_id() {
73         return (unsigned long)pthread_self();
74 };
75
76 void ssl_locking_callback(int mode, int type, const char *file, int line) {
77     if (mode & CRYPTO_LOCK) {
78         pthread_mutex_lock(&ssl_locks[type]);
79         ssl_lock_count[type]++;
80     } else
81         pthread_mutex_unlock(&ssl_locks[type]);
82 }
83
84 void ssl_locks_setup() {
85     int i;
86
87     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
88     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
89     for (i = 0; i < CRYPTO_num_locks(); i++) {
90         ssl_lock_count[i] = 0;
91         pthread_mutex_init(&ssl_locks[i], NULL);
92     }
93
94     CRYPTO_set_id_callback(ssl_thread_id);
95     CRYPTO_set_locking_callback(ssl_locking_callback);
96 }
97
98 int resolvepeer(struct peer *peer) {
99     struct addrinfo hints, *addrinfo;
100     
101     memset(&hints, 0, sizeof(hints));
102     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
103     hints.ai_family = AF_UNSPEC;
104     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
105         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
106         return 0;
107     }
108
109     if (peer->addrinfo)
110         freeaddrinfo(peer->addrinfo);
111     peer->addrinfo = addrinfo;
112     return 1;
113 }         
114
115 int connecttoserver(struct addrinfo *addrinfo) {
116     int s;
117     struct addrinfo *res;
118     
119     for (res = addrinfo; res; res = res->ai_next) {
120         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
121         if (s < 0) {
122             err("connecttoserver: socket failed");
123             continue;
124         }
125         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
126             break;
127         err("connecttoserver: connect failed");
128         close(s);
129         s = -1;
130     }
131     return s;
132 }         
133
134 /* returns the client with matching address, or NULL */
135 /* if client argument is not NULL, we only check that one client */
136 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
137     struct sockaddr_in6 *sa6;
138     struct in_addr *a4 = NULL;
139     struct client *c;
140     int i;
141     struct addrinfo *res;
142
143     if (addr->sa_family == AF_INET6) {
144         sa6 = (struct sockaddr_in6 *)addr;
145         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
146             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
147     } else
148         a4 = &((struct sockaddr_in *)addr)->sin_addr;
149
150     c = (client ? client : clients);
151     for (i = 0; i < client_count; i++) {
152         if (c->peer.type == type)
153             for (res = c->peer.addrinfo; res; res = res->ai_next)
154                 if ((a4 && res->ai_family == AF_INET &&
155                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
156                     (res->ai_family == AF_INET6 &&
157                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
158                     return c;
159         if (client)
160             break;
161         c++;
162     }
163     return NULL;
164 }
165
166 /* returns the server with matching address, or NULL */
167 /* if server argument is not NULL, we only check that one server */
168 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
169     struct sockaddr_in6 *sa6;
170     struct in_addr *a4 = NULL;
171     struct server *s;
172     int i;
173     struct addrinfo *res;
174
175     if (addr->sa_family == AF_INET6) {
176         sa6 = (struct sockaddr_in6 *)addr;
177         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
178             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
179     } else
180         a4 = &((struct sockaddr_in *)addr)->sin_addr;
181
182     s = (server ? server : servers);
183     for (i = 0; i < server_count; i++) {
184         if (s->peer.type == type)
185             for (res = s->peer.addrinfo; res; res = res->ai_next)
186                 if ((a4 && res->ai_family == AF_INET &&
187                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
188                     (res->ai_family == AF_INET6 &&
189                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
190                     return s;
191         if (server)
192             break;
193         s++;
194     }
195     return NULL;
196 }
197
198 /* exactly one of client and server must be non-NULL */
199 /* if *peer == NULL we return who we received from, else require it to be from peer */
200 /* return from in sa if not NULL */
201 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
202     int cnt, len;
203     void *f;
204     unsigned char buf[65536], *rad;
205     struct sockaddr_storage from;
206     socklen_t fromlen = sizeof(from);
207
208     for (;;) {
209         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
210         if (cnt == -1) {
211             err("radudpget: recv failed");
212             continue;
213         }
214         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
215
216         if (cnt < 20) {
217             printf("radudpget: packet too small\n");
218             continue;
219         }
220     
221         len = RADLEN(buf);
222
223         if (cnt < len) {
224             printf("radudpget: packet smaller than length field in radius header\n");
225             continue;
226         }
227         if (cnt > len)
228             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
229
230         f = (client
231              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
232              : (void *)find_server('U', (struct sockaddr *)&from, *server));
233         if (!f) {
234             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
235             continue;
236         }
237
238         rad = malloc(len);
239         if (rad)
240             break;
241         err("radudpget: malloc failed");
242     }
243     memcpy(rad, buf, len);
244     if (client)
245         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
246     else
247         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
248     if (sa)
249         *sa = from;
250     return rad;
251 }
252
253 void tlsconnect(struct server *server, struct timeval *when, char *text) {
254     struct timeval now;
255     time_t elapsed;
256     unsigned long error;
257
258     printf("tlsconnect called from %s\n", text);
259     pthread_mutex_lock(&server->lock);
260     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
261         /* already reconnected, nothing to do */
262         printf("tlsconnect(%s): seems already reconnected\n", text);
263         pthread_mutex_unlock(&server->lock);
264         return;
265     }
266
267     printf("tlsconnect %s\n", text);
268
269     for (;;) {
270         gettimeofday(&now, NULL);
271         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
272         if (server->connectionok) {
273             server->connectionok = 0;
274             sleep(10);
275         } else if (elapsed < 5)
276             sleep(10);
277         else if (elapsed < 600)
278             sleep(elapsed * 2);
279         else if (elapsed < 10000) /* no sleep at startup */
280                 sleep(900);
281         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
282         if (server->sock >= 0)
283             close(server->sock);
284         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
285             continue;
286         SSL_free(server->peer.ssl);
287         server->peer.ssl = SSL_new(ssl_ctx_cl);
288         SSL_set_fd(server->peer.ssl, server->sock);
289         if (SSL_connect(server->peer.ssl) > 0)
290             break;
291         while ((error = ERR_get_error()))
292             err("tlsconnect: TLS: %s", ERR_error_string(error, NULL));
293     }
294     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
295     gettimeofday(&server->lastconnecttry, NULL);
296     pthread_mutex_unlock(&server->lock);
297 }
298
299 unsigned char *radtlsget(SSL *ssl) {
300     int cnt, total, len;
301     unsigned char buf[4], *rad;
302
303     for (;;) {
304         for (total = 0; total < 4; total += cnt) {
305             cnt = SSL_read(ssl, buf + total, 4 - total);
306             if (cnt <= 0) {
307                 printf("radtlsget: connection lost\n");
308                 return NULL;
309             }
310         }
311
312         len = RADLEN(buf);
313         rad = malloc(len);
314         if (!rad) {
315             err("radtlsget: malloc failed");
316             continue;
317         }
318         memcpy(rad, buf, 4);
319
320         for (; total < len; total += cnt) {
321             cnt = SSL_read(ssl, rad + total, len - total);
322             if (cnt <= 0) {
323                 printf("radtlsget: connection lost\n");
324                 free(rad);
325                 return NULL;
326             }
327         }
328     
329         if (total >= 20)
330             break;
331         
332         free(rad);
333         printf("radtlsget: packet smaller than minimum radius size\n");
334     }
335     
336     printf("radtlsget: got %d bytes\n", total);
337     return rad;
338 }
339
340 int clientradput(struct server *server, unsigned char *rad) {
341     int cnt;
342     size_t len;
343     unsigned long error;
344     struct timeval lastconnecttry;
345     
346     len = RADLEN(rad);
347     if (server->peer.type == 'U') {
348         if (send(server->sock, rad, len, 0) >= 0) {
349             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
350             return 1;
351         }
352         err("clientradput: send failed");
353         return 0;
354     }
355
356     lastconnecttry = server->lastconnecttry;
357     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
358         while ((error = ERR_get_error()))
359             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
360         tlsconnect(server, &lastconnecttry, "clientradput");
361         lastconnecttry = server->lastconnecttry;
362     }
363
364     server->connectionok = 1;
365     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
366            cnt, len, server->peer.host);
367     return 1;
368 }
369
370 int radsign(unsigned char *rad, unsigned char *sec) {
371     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
372     static unsigned char first = 1;
373     static EVP_MD_CTX mdctx;
374     unsigned int md_len;
375     int result;
376     
377     pthread_mutex_lock(&lock);
378     if (first) {
379         EVP_MD_CTX_init(&mdctx);
380         first = 0;
381     }
382
383     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
384         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
385         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
386         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
387         md_len == 16);
388     pthread_mutex_unlock(&lock);
389     return result;
390 }
391
392 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
393     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
394     static unsigned char first = 1;
395     static EVP_MD_CTX mdctx;
396     unsigned char hash[EVP_MAX_MD_SIZE];
397     unsigned int len;
398     int result;
399     
400     pthread_mutex_lock(&lock);
401     if (first) {
402         EVP_MD_CTX_init(&mdctx);
403         first = 0;
404     }
405
406     len = RADLEN(rad);
407     
408     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
409               EVP_DigestUpdate(&mdctx, rad, 4) &&
410               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
411               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
412               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
413               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
414               len == 16 &&
415               !memcmp(hash, rad + 4, 16));
416     pthread_mutex_unlock(&lock);
417     return result;
418 }
419               
420 void sendrq(struct server *to, struct client *from, struct request *rq) {
421     int i;
422
423     pthread_mutex_lock(&to->newrq_mutex);
424     for (i = 0; i < MAX_REQUESTS; i++)
425         if (!to->requests[i].buf)
426             break;
427     if (i == MAX_REQUESTS) {
428         printf("No room in queue, dropping request\n");
429         pthread_mutex_unlock(&to->newrq_mutex);
430         return;
431     }
432     
433     rq->buf[1] = (char)i;
434     to->requests[i] = *rq;
435
436     if (!to->newrq) {
437         to->newrq = 1;
438         printf("signalling client writer\n");
439         pthread_cond_signal(&to->newrq_cond);
440     }
441     pthread_mutex_unlock(&to->newrq_mutex);
442 }
443
444 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
445     struct replyq *replyq = to->replyq;
446     
447     pthread_mutex_lock(&replyq->count_mutex);
448     if (replyq->count == replyq->size) {
449         printf("No room in queue, dropping request\n");
450         pthread_mutex_unlock(&replyq->count_mutex);
451         return;
452     }
453
454     replyq->replies[replyq->count].buf = buf;
455     if (tosa)
456         replyq->replies[replyq->count].tosa = *tosa;
457     replyq->count++;
458
459     if (replyq->count == 1) {
460         printf("signalling client writer\n");
461         pthread_cond_signal(&replyq->count_cond);
462     }
463     pthread_mutex_unlock(&replyq->count_mutex);
464 }
465
466 int pwdcrypt(uint8_t *plain, uint8_t *enc, uint8_t enclen, uint8_t *shared, uint8_t sharedlen,
467                 uint8_t *auth) {
468     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
469     static unsigned char first = 1;
470     static EVP_MD_CTX mdctx;
471     unsigned char hash[EVP_MAX_MD_SIZE], *input;
472     unsigned int md_len;
473     uint8_t i, offset = 0;
474     
475     pthread_mutex_lock(&lock);
476     if (first) {
477         EVP_MD_CTX_init(&mdctx);
478         first = 0;
479     }
480
481     input = auth;
482     for (;;) {
483         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
484             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
485             !EVP_DigestUpdate(&mdctx, input, 16) ||
486             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
487             md_len != 16) {
488             pthread_mutex_unlock(&lock);
489             return 0;
490         }
491         for (i = 0; i < 16; i++)
492             plain[offset + i] = hash[i] ^ enc[offset + i];
493         offset += 16;
494         if (offset == enclen)
495             break;
496         input = enc + offset - 16;
497     }
498     pthread_mutex_unlock(&lock);
499     return 1;
500 }
501
502 struct server *id2server(char *id, uint8_t len) {
503     int i;
504     char **realm, *idrealm;
505
506     idrealm = strchr(id, '@');
507     if (idrealm) {
508         idrealm++;
509         len -= idrealm - id;
510     } else {
511         idrealm = "-";
512         len = 1;
513     }
514     for (i = 0; i < server_count; i++) {
515         for (realm = servers[i].realms; *realm; realm++) {
516             if ((strlen(*realm) == 1 && **realm == '*') ||
517                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
518                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
519                 return servers + i;
520             }
521         }
522     }
523     return NULL;
524 }
525
526 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
527     uint8_t code, id, *auth, *attr, *usernameattr = NULL, *userpwdattr = NULL, pwd[128], pwdlen;
528     int i;
529     uint16_t len;
530     int left;
531     struct server *to;
532     unsigned char newauth[16];
533     
534     code = *(uint8_t *)buf;
535     id = *(uint8_t *)(buf + 1);
536     len = RADLEN(buf);
537     auth = (uint8_t *)(buf + 4);
538
539     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
540     
541     if (code != RAD_Access_Request) {
542         printf("radsrv: server currently accepts only access-requests, ignoring\n");
543         return NULL;
544     }
545
546     left = len - 20;
547     attr = buf + 20;
548     
549     while (left > 1) {
550         left -= attr[RAD_Attr_Length];
551         if (left < 0) {
552             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
553             return NULL;
554         }
555         switch (attr[RAD_Attr_Type]) {
556         case RAD_Attr_User_Name:
557             usernameattr = attr;
558             break;
559         case RAD_Attr_User_Password:
560             userpwdattr = attr;
561             break;
562         }
563         attr += attr[RAD_Attr_Length];
564     }
565     if (left)
566         printf("radsrv: malformed packet? remaining byte after last attribute\n");
567
568     if (usernameattr) {
569         printf("radsrv: Username: ");
570         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
571             printf("%c", usernameattr[RAD_Attr_Value + i]);
572         printf("\n");
573     }
574
575     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
576     if (!to) {
577         printf("radsrv: ignoring request, don't know where to send it\n");
578         return NULL;
579     }
580
581     if (!RAND_bytes(newauth, 16)) {
582         printf("radsrv: failed to generate random auth\n");
583         return NULL;
584     }
585
586     if (userpwdattr) {
587         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
588         pwdlen = userpwdattr[RAD_Attr_Length] - 2;
589         if (pwdlen < 16 || pwdlen > 128 || pwdlen % 16) {
590             printf("radsrv: invalid user password length\n");
591             return NULL;
592         }
593         
594         if (!pwdcrypt(pwd, &userpwdattr[RAD_Attr_Value], pwdlen, from->peer.secret, strlen(from->peer.secret), auth)) {
595             printf("radsrv: cannot decrypt password\n");
596             return NULL;
597         }
598         printf("radsrv: password: ");
599         for (i = 0; i < pwdlen; i++)
600             printf("%02x ", pwd[i]);
601         printf("\n");
602         if (!pwdcrypt(&userpwdattr[RAD_Attr_Value], pwd, pwdlen, to->peer.secret, strlen(to->peer.secret), newauth)) {
603             printf("radsrv: cannot encrypt password\n");
604             return NULL;
605         }
606     }
607
608     rq->buf = buf;
609     rq->from = from;
610     rq->origid = id;
611     memcpy(rq->origauth, auth, 16);
612     memcpy(rq->buf + 4, newauth, 16);
613     return to;
614 }
615
616 void *clientrd(void *arg) {
617     struct server *server = (struct server *)arg;
618     struct client *from;
619     int i;
620     unsigned char *buf;
621     struct sockaddr_storage fromsa;
622     struct timeval lastconnecttry;
623     
624     for (;;) {
625         lastconnecttry = server->lastconnecttry;
626         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
627         if (!buf && server->peer.type == 'T') {
628             tlsconnect(server, &lastconnecttry, "clientrd");
629             continue;
630         }
631     
632         server->connectionok = 1;
633         
634         i = buf[1]; /* i is the id */
635
636         pthread_mutex_lock(&server->newrq_mutex);
637         if (!server->requests[i].buf || !server->requests[i].tries) {
638             pthread_mutex_unlock(&server->newrq_mutex);
639             printf("clientrd: no matching request sent with this id, ignoring\n");
640             continue;
641         }
642         
643         if (server->requests[i].received) {
644             pthread_mutex_unlock(&server->newrq_mutex);
645             printf("clientrd: already received, ignoring\n");
646             continue;
647         }
648
649         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
650             pthread_mutex_unlock(&server->newrq_mutex);
651             printf("clientrd: invalid auth, ignoring\n");
652             continue;
653         }
654
655         /* once we set received = 1, requests[i] may be reused */
656         buf[1] = (char)server->requests[i].origid;
657         memcpy(buf + 4, server->requests[i].origauth, 16);
658         from = server->requests[i].from;
659         if (from->peer.type == 'U')
660             fromsa = server->requests[i].fromsa;
661         server->requests[i].received = 1;
662         pthread_mutex_unlock(&server->newrq_mutex);
663
664         if (!radsign(buf, from->peer.secret)) {
665             printf("clientrd: failed to sign message\n");
666             continue;
667         }
668         
669         printf("clientrd: giving packet back to where it came from\n");
670         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
671     }
672 }
673
674 void *clientwr(void *arg) {
675     struct server *server = (struct server *)arg;
676     pthread_t clientrdth;
677     int i;
678
679     if (server->peer.type == 'U') {
680         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
681             printf("clientwr: connecttoserver failed\n");
682             exit(1);
683         }
684     } else
685         tlsconnect(server, NULL, "new client");
686     
687     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
688         errx("clientwr: pthread_create failed");
689
690     for (;;) {
691         pthread_mutex_lock(&server->newrq_mutex);
692         while (!server->newrq) {
693             printf("clientwr: waiting for signal\n");
694             pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
695             printf("clientwr: got signal\n");
696         }
697         server->newrq = 0;
698         pthread_mutex_unlock(&server->newrq_mutex);
699                
700         for (i = 0; i < MAX_REQUESTS; i++) {
701             pthread_mutex_lock(&server->newrq_mutex);
702             while (!server->requests[i].buf && i < MAX_REQUESTS)
703                 i++;
704             if (i == MAX_REQUESTS) {
705                 pthread_mutex_unlock(&server->newrq_mutex);
706                 break;
707             }
708
709             /* already received or too many tries */
710             if (server->requests[i].received || server->requests[i].tries > 2) {
711                 free(server->requests[i].buf);
712                 /* setting this to NULL means that it can be reused */
713                 server->requests[i].buf = NULL;
714                 pthread_mutex_unlock(&server->newrq_mutex);
715                 continue;
716             }
717             pthread_mutex_unlock(&server->newrq_mutex);
718             
719             server->requests[i].tries++;
720             clientradput(server, server->requests[i].buf);
721         }
722     }
723     /* should do more work to maintain TLS connections, keepalives etc */
724 }
725
726 void *udpserverwr(void *arg) {
727     struct replyq *replyq = &udp_server_replyq;
728     struct reply *reply = replyq->replies;
729     
730     pthread_mutex_lock(&replyq->count_mutex);
731     for (;;) {
732         while (!replyq->count) {
733             printf("udp server writer, waiting for signal\n");
734             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
735             printf("udp server writer, got signal\n");
736         }
737         pthread_mutex_unlock(&replyq->count_mutex);
738         
739         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
740                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
741             err("sendudp: send failed");
742         free(reply->buf);
743         
744         pthread_mutex_lock(&replyq->count_mutex);
745         replyq->count--;
746         memmove(replyq->replies, replyq->replies + 1,
747                 replyq->count * sizeof(struct reply));
748     }
749 }
750
751 void *udpserverrd(void *arg) {
752     struct request rq;
753     unsigned char *buf;
754     struct server *to;
755     struct client *fr;
756     pthread_t udpserverwrth;
757     
758     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
759         printf("udpserverrd: socket/bind failed\n");
760         exit(1);
761     }
762     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
763
764     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
765         errx("pthread_create failed");
766     
767     for (;;) {
768         fr = NULL;
769         memset(&rq, 0, sizeof(struct request));
770         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
771         to = radsrv(&rq, buf, fr);
772         if (!to) {
773             printf("udpserverrd: ignoring request, no place to send it\n");
774             continue;
775         }
776         sendrq(to, fr, &rq);
777     }
778 }
779
780 void *tlsserverwr(void *arg) {
781     int cnt;
782     unsigned long error;
783     struct client *client = (struct client *)arg;
784     struct replyq *replyq;
785     
786     pthread_mutex_lock(&client->replycount_mutex);
787     for (;;) {
788         replyq = client->replyq;
789         while (!replyq->count) {
790             printf("tls server writer, waiting for signal\n");
791             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
792             printf("tls server writer, got signal\n");
793         }
794         pthread_mutex_unlock(&replyq->count_mutex);
795         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
796         if (cnt > 0)
797             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
798                    cnt, RADLEN(replyq->replies->buf));
799         else
800             while ((error = ERR_get_error()))
801                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
802         free(replyq->replies->buf);
803
804         pthread_mutex_lock(&replyq->count_mutex);
805         replyq->count--;
806         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
807     }
808 }
809
810 void *tlsserverrd(void *arg) {
811     struct request rq;
812     char unsigned *buf;
813     unsigned long error;
814     struct server *to;
815     int s;
816     struct client *client = (struct client *)arg;
817     pthread_t tlsserverwrth;
818
819     printf("tlsserverrd starting\n");
820     if (SSL_accept(client->peer.ssl) <= 0) {
821         while ((error = ERR_get_error()))
822             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
823         errx("accept failed, child exiting");
824     }
825
826     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
827         errx("pthread_create failed");
828     
829     for (;;) {
830         buf = radtlsget(client->peer.ssl);
831         if (!buf) {
832             printf("tlsserverrd: connection lost\n");
833             s = SSL_get_fd(client->peer.ssl);
834             SSL_free(client->peer.ssl);
835             client->peer.ssl = NULL;
836             if (s >= 0)
837                 close(s);
838             pthread_exit(NULL);
839         }
840         printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
841         memset(&rq, 0, sizeof(struct request));
842         to = radsrv(&rq, buf, client);
843         if (!to) {
844             printf("ignoring request, no place to send it\n");
845             continue;
846         }
847         sendrq(to, client, &rq);
848     }
849 }
850
851 int tlslistener(SSL_CTX *ssl_ctx) {
852     pthread_t tlsserverth;
853     int s, snew;
854     struct sockaddr_storage from;
855     size_t fromlen = sizeof(from);
856     struct client *client;
857
858     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
859         printf("tlslistener: socket/bind failed\n");
860         exit(1);
861     }
862     
863     listen(s, 0);
864     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
865
866     for (;;) {
867         snew = accept(s, (struct sockaddr *)&from, &fromlen);
868         if (snew < 0)
869             errx("accept failed");
870         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
871
872         client = find_client('T', (struct sockaddr *)&from, NULL);
873         if (!client) {
874             printf("ignoring request, not a known TLS client\n");
875             close(snew);
876             continue;
877         }
878
879         if (client->peer.ssl) {
880             printf("Ignoring incoming connection, already have one from this client\n");
881             close(snew);
882             continue;
883         }
884         client->peer.ssl = SSL_new(ssl_ctx);
885         SSL_set_fd(client->peer.ssl, snew);
886         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
887             errx("pthread_create failed");
888     }
889     return 0;
890 }
891
892 char *parsehostport(char *s, struct peer *peer) {
893     char *p, *field;
894     int ipv6 = 0;
895
896     p = s;
897     // allow literal addresses and port, e.g. [2001:db8::1]:1812
898     if (*p == '[') {
899         p++;
900         field = p;
901         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
902         if (*p != ']') {
903             printf("no ] matching initial [\n");
904             exit(1);
905         }
906         ipv6 = 1;
907     } else {
908         field = p;
909         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
910     }
911     if (field == p) {
912         printf("missing host/address\n");
913         exit(1);
914     }
915     peer->host = malloc(p - field + 1);
916     if (!peer->host)
917         errx("malloc failed");
918     memcpy(peer->host, field, p - field);
919     peer->host[p - field] = '\0';
920     if (ipv6) {
921         p++;
922         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
923             printf("unexpected character after ]\n");
924             exit(1);
925         }
926     }
927     if (*p == ':') {
928             /* port number or service name is specified */;
929             field = p++;
930             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
931             if (field == p) {
932                 printf("syntax error, : but no following port\n");
933                 exit(1);
934             }
935             peer->port = malloc(p - field + 1);
936             if (!peer->port)
937                 errx("malloc failed");
938             memcpy(peer->port, field, p - field);
939             peer->port[p - field] = '\0';
940     } else
941         peer->port = NULL;
942     return p;
943 }
944
945 // * is default, else longest match ... ";" used for separator
946 char *parserealmlist(char *s, struct server *server) {
947     char *p;
948     int i, n, l;
949
950     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
951         if (*p == ';')
952             n++;
953     l = p - s;
954     if (!l) {
955         server->realms = NULL;
956         return p;
957     }
958     server->realmdata = malloc(l + 1);
959     if (!server->realmdata)
960         errx("malloc failed");
961     memcpy(server->realmdata, s, l);
962     server->realmdata[l] = '\0';
963     server->realms = malloc((1+n) * sizeof(char *));
964     if (!server->realms)
965         errx("malloc failed");
966     server->realms[0] = server->realmdata;
967     for (n = 1, i = 0; i < l; i++)
968         if (server->realmdata[i] == ';') {
969             server->realmdata[i] = '\0';
970             server->realms[n++] = server->realmdata + i + 1;
971         }       
972     server->realms[n] = NULL;
973     return p;
974 }
975
976 /* exactly one argument must be non-NULL */
977 void getconfig(const char *serverfile, const char *clientfile) {
978     FILE *f;
979     char line[1024];
980     char *p, *field, **r;
981     struct client *client;
982     struct server *server;
983     struct peer *peer;
984     int *count;
985     
986     if (serverfile) {
987         printf("opening file %s for reading\n", serverfile);
988         f = fopen(serverfile, "r");
989         if (!f)
990             errx("getconfig failed to open %s for reading", serverfile);
991         count = &server_count;
992     } else {
993         printf("opening file %s for reading\n", clientfile);
994         f = fopen(clientfile, "r");
995         if (!f)
996             errx("getconfig failed to open %s for reading", clientfile);
997         udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
998         if (!udp_server_replyq.replies)
999             errx("malloc failed");
1000         udp_server_replyq.size = 4 * MAX_REQUESTS;
1001         udp_server_replyq.count = 0;
1002         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1003         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1004         count = &client_count;
1005     }    
1006     
1007     *count = 0;
1008     while (fgets(line, 1024, f) && *count < MAX_PEERS) {
1009         if (serverfile) {
1010             server = &servers[*count];
1011             memset(server, 0, sizeof(struct server));
1012             peer = &server->peer;
1013         } else {
1014             client = &clients[*count];
1015             memset(client, 0, sizeof(struct client));
1016             peer = &client->peer;
1017         }
1018         for (p = line; *p == ' ' || *p == '\t'; p++);
1019         if (*p == '#' || *p == '\n')
1020             continue;
1021         if (*p != 'U' && *p != 'T') {
1022             printf("server type must be U or T, got %c\n", *p);
1023             exit(1);
1024         }
1025         peer->type = *p;
1026         for (p++; *p == ' ' || *p == '\t'; p++);
1027         p = parsehostport(p, peer);
1028         if (!peer->port)
1029             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
1030         for (; *p == ' ' || *p == '\t'; p++);
1031         if (serverfile) {
1032             p = parserealmlist(p, server);
1033             if (!server->realms) {
1034                 printf("realm list must be specified\n");
1035                 exit(1);
1036             }
1037             for (; *p == ' ' || *p == '\t'; p++);
1038         }
1039         field = p;
1040         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1041         if (field == p) {
1042             /* no secret set and end of line, line is complete if TLS */
1043             if (peer->type == 'U') {
1044                 printf("secret must be specified for UDP\n");
1045                 exit(1);
1046             }
1047             peer->secret = DEFAULT_TLS_SECRET;
1048         } else {
1049             peer->secret = malloc(p - field + 1);
1050             if (!peer->secret)
1051                 errx("malloc failed");
1052             memcpy(peer->secret, field, p - field);
1053             peer->secret[p - field] = '\0';
1054             /* check that rest of line only white space */
1055             for (; *p == ' ' || *p == '\t'; p++);
1056             if (*p && *p != '\n') {
1057                 printf("max 4 fields per line, found a 5th\n");
1058                 exit(1);
1059             }
1060         }
1061
1062         if ((serverfile && !resolvepeer(&server->peer)) ||
1063             (clientfile && !resolvepeer(&client->peer))) {
1064             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1065             exit(1);
1066         }
1067
1068         if (serverfile) {
1069             pthread_mutex_init(&server->lock, NULL);
1070             server->sock = -1;
1071             server->requests = malloc(MAX_REQUESTS * sizeof(struct request));
1072             if (!server->requests)
1073                 errx("malloc failed");
1074             memset(server->requests, 0, MAX_REQUESTS * sizeof(struct request));
1075             server->newrq = 0;
1076             pthread_mutex_init(&server->newrq_mutex, NULL);
1077             pthread_cond_init(&server->newrq_cond, NULL);
1078         } else {
1079             if (peer->type == 'U')
1080                 client->replyq = &udp_server_replyq;
1081             else {
1082                 client->replyq = malloc(sizeof(struct replyq));
1083                 if (!client->replyq)
1084                     errx("malloc failed");
1085                 client->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
1086                 if (!client->replyq->replies)
1087                     errx("malloc failed");
1088                 client->replyq->size = MAX_REQUESTS;
1089                 client->replyq->count = 0;
1090                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1091                 pthread_cond_init(&client->replyq->count_cond, NULL);
1092             }
1093         }
1094         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1095         if (serverfile) {
1096             printf("    with realms:");
1097             for (r = server->realms; *r; r++)
1098                 printf(" %s", *r);
1099             printf("\n");
1100         }
1101         (*count)++;
1102     }
1103     fclose(f);
1104 }
1105
1106 void parseargs(int argc, char **argv) {
1107     int c;
1108
1109     while ((c = getopt(argc, argv, "p:")) != -1) {
1110         switch (c) {
1111         case 'p':
1112             udp_server_port = optarg;
1113             break;
1114         default:
1115             goto usage;
1116         }
1117     }
1118
1119     return;
1120
1121  usage:
1122     printf("radsecproxy [ -p UDP-port ]\n");
1123     exit(1);
1124 }
1125                
1126 int main(int argc, char **argv) {
1127     SSL_CTX *ssl_ctx_srv;
1128     unsigned long error;
1129     pthread_t udpserverth;
1130     pthread_attr_t joinable;
1131     int i;
1132     
1133     parseargs(argc, argv);
1134     getconfig("servers.conf", NULL);
1135     getconfig(NULL, "clients.conf");
1136     
1137     ssl_locks_setup();
1138
1139     pthread_attr_init(&joinable);
1140     pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1141    
1142     /* listen on UDP if at least one UDP client */
1143     
1144     for (i = 0; i < client_count; i++)
1145         if (clients[i].peer.type == 'U') {
1146             if (pthread_create(&udpserverth, &joinable, udpserverrd, NULL))
1147                 errx("pthread_create failed");
1148             break;
1149         }
1150     
1151     /* SSL setup */
1152     SSL_load_error_strings();
1153     SSL_library_init();
1154
1155     while (!RAND_status()) {
1156         time_t t = time(NULL);
1157         pid_t pid = getpid();
1158         RAND_seed((unsigned char *)&t, sizeof(time_t));
1159         RAND_seed((unsigned char *)&pid, sizeof(pid));
1160     }
1161     
1162     /* initialise client part and start clients */
1163     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1164     if (!ssl_ctx_cl)
1165         errx("no ssl ctx");
1166     
1167     for (i = 0; i < server_count; i++) {
1168         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1169             errx("pthread_create failed");
1170     }
1171
1172     for (i = 0; i < client_count; i++)
1173         if (clients[i].peer.type == 'T')
1174             break;
1175
1176     if (i == client_count) {
1177         printf("No TLS clients defined, not starting TLS listener\n");
1178         /* just hang around doing nothing, anything to do here? */
1179         for (;;)
1180             sleep(1000);
1181     }
1182     
1183     /* setting up server/daemon part */
1184     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1185     if (!ssl_ctx_srv)
1186         errx("no ssl ctx");
1187     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1188         while ((error = ERR_get_error()))
1189             err("SSL: %s", ERR_error_string(error, NULL));
1190         errx("Failed to load certificate");
1191     }
1192     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1193         while ((error = ERR_get_error()))
1194             err("SSL: %s", ERR_error_string(error, NULL));
1195         errx("Failed to load private key");
1196     }
1197
1198     return tlslistener(ssl_ctx_srv);
1199 }