git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@2 e88ac4ed-0b26-0410...
[libradsec.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * Among other things:
15  * timer based client retrans or maybe no retrans and just a timer...
16  * make our server ignore client retrans?
17  * tls keep alives
18  * routing based on id....
19  * tls certificate validation
20 */
21
22 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
23  *              rd is responsible for init and launching wr
24  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
25  *          each tlsserverrd launches tlsserverwr
26  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
27  *          for init and launching rd
28  *
29  * serverrd will receive a request, processes it and puts it in the requestq of
30  *          the appropriate clientwr
31  * clientwr monitors its requestq and sends requests
32  * clientrd looks for responses, processes them and puts them in the replyq of
33  *          the peer the request came from
34  * serverwr monitors its reply and sends replies
35  *
36  * In addition to the main thread, we have:
37  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
38  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
39  * For each TLS peer connecting to us there will be 2 more TLS threads
40  *       This is only for connected peers
41  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
42  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
43 */
44
45 #include <netdb.h>
46 #include <unistd.h>
47 #include <pthread.h>
48 #include <openssl/ssl.h>
49 #include <openssl/rand.h>
50 #include <openssl/err.h>
51 #include <openssl/md5.h>
52 #include "radsecproxy.h"
53
54 static struct peer peers[MAX_PEERS];
55 static int peer_count = 0;
56
57 static struct replyq udp_server_replyq;
58 static int udp_server_sock = -1;
59 static char *udp_server_port = DEFAULT_UDP_PORT;
60 static pthread_mutex_t *ssl_locks;
61 static long *ssl_lock_count;
62 static SSL_CTX *ssl_ctx_cl;
63 extern int optind;
64 extern char *optarg;
65
66 /* callbacks for making OpenSSL thread safe */
67 unsigned long ssl_thread_id() {
68         return (unsigned long)pthread_self();
69 };
70
71 void ssl_locking_callback(int mode, int type, const char *file, int line) {
72     if (mode & CRYPTO_LOCK) {
73         pthread_mutex_lock(&ssl_locks[type]);
74         ssl_lock_count[type]++;
75     } else
76         pthread_mutex_unlock(&ssl_locks[type]);
77 }
78
79 void ssl_locks_setup() {
80     int i;
81
82     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
83     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
84     for (i = 0; i < CRYPTO_num_locks(); i++) {
85         ssl_lock_count[i] = 0;
86         pthread_mutex_init(&ssl_locks[i], NULL);
87     }
88
89     CRYPTO_set_id_callback(ssl_thread_id);
90     CRYPTO_set_locking_callback(ssl_locking_callback);
91 }
92
93 int resolvepeer(struct peer *peer) {
94     struct addrinfo hints;
95     
96     pthread_mutex_lock(&peer->lock);
97     if (peer->addrinfo) {
98         /* assume we should re-resolve */
99         freeaddrinfo(peer->addrinfo);
100         peer->addrinfo = NULL;
101     }
102     
103     memset(&hints, 0, sizeof(hints));
104     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
105     hints.ai_family = AF_UNSPEC;
106     if (getaddrinfo(peer->host, peer->port, &hints, &peer->addrinfo)) {
107         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
108         peer->addrinfo = NULL; /* probably don't need this */
109         pthread_mutex_unlock(&peer->lock);
110         return 0;
111     }
112     pthread_mutex_unlock(&peer->lock);
113     return 1;
114 }         
115
116 int connecttopeer(struct peer *peer) {
117     int s;
118     struct addrinfo *res;
119     
120     if (!peer->addrinfo) {
121         resolvepeer(peer);
122         if (!peer->addrinfo) {
123             printf("connecttopeer: can't resolve %s into address to connect to\n", peer->host);
124             return -1;
125         }
126     }
127
128     for (res = peer->addrinfo; res; res = res->ai_next) {
129         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
130         if (s < 0) {
131             err("connecttopeer: socket failed");
132             continue;
133         }
134         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
135             break;
136         err("connecttopeer: connect failed");
137         close(s);
138         s = -1;
139     }
140     return s;
141 }         
142
143 /* returns the peer with matching address, or NULL */
144 /* if peer argument is not NULL, we only check that one peer */
145 struct peer *find_peer(char type, struct sockaddr *addr, struct peer *peer) {
146     struct sockaddr_in6 *sa6;
147     struct in_addr *a4 = NULL;
148     struct peer *p;
149     int i;
150     struct addrinfo *res;
151
152     if (addr->sa_family == AF_INET6) {
153         sa6 = (struct sockaddr_in6 *)addr;
154         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
155             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
156     } else
157         a4 = &((struct sockaddr_in *)addr)->sin_addr;
158
159     p = (peer ? peer : peers);
160     for (i = 0; i < peer_count; i++) {
161         if (p->type == type)
162             for (res = p->addrinfo; res; res = res->ai_next)
163                 if ((a4 && res->ai_family == AF_INET &&
164                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
165                     (res->ai_family == AF_INET6 &&
166                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
167                     return p;
168         if (peer)
169             break;
170         p++;
171     }
172     return NULL;
173 }
174
175 /* if *peer == NULL we return who we received from, else require it to be from peer */
176 /* return from in sa if not NULL */
177 unsigned char *radudpget(int s, struct peer **peer, struct sockaddr_storage *sa) {
178     int cnt, len;
179     struct peer *f;
180     unsigned char buf[65536], *rad;
181     struct sockaddr_storage from;
182     socklen_t fromlen = sizeof(from);
183
184     for (;;) {
185         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
186         if (cnt == -1) {
187             err("radudpget: recv failed");
188             continue;
189         }
190         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
191
192         if (cnt < 20) {
193             printf("radudpget: packet too small\n");
194             continue;
195         }
196     
197         len = RADLEN(buf);
198
199         if (cnt < len) {
200             printf("radudpget: packet smaller than length field in radius header\n");
201             continue;
202         }
203         if (cnt > len)
204             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
205
206         f = find_peer('U', (struct sockaddr *)&from, *peer);
207         if (!f) {
208             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
209             continue;
210         }
211
212         rad = malloc(len);
213         if (rad)
214             break;
215         err("radudpget: malloc failed");
216     }
217     memcpy(rad, buf, len);
218     *peer = f; /* only need this if *peer == NULL, but if not NULL *peer == f here */
219     if (sa)
220         *sa = from;
221     return rad;
222 }
223
224 void tlsconnect(struct peer *peer, int oldsock, char *text) {
225     unsigned int sleeptime, try = 0;
226     
227     pthread_mutex_lock(&peer->lock);
228     if (peer->sockcl != oldsock) {
229         /* already reconnected, nothing to do */
230         printf("reconnect: seems already reconnected\n");
231         pthread_mutex_unlock(&peer->lock);
232         return;
233     }
234
235     printf("tlsconnect %d %s\n", oldsock, text);
236     sleep(1);
237     for (;;) {
238         printf("tlsconnect: trying to open TLS connection to %s port %s\n", peer->host, peer->port);
239         if ((peer->sockcl = connecttopeer(peer)) >= 0)
240             break;
241         try++;
242         if (try < 6)
243             sleeptime = 10;
244         else if (try < 20)
245             sleeptime = 60;
246         else
247             sleeptime = 900;
248         /* should possibly re-resolve host addresses at some point */
249         printf("tlsconnect: can't connect, retry #%d in %ds\n", try, sleeptime);
250         sleep(sleeptime);
251     }
252         
253     SSL_free(peer->sslcl);
254     peer->sslcl = SSL_new(ssl_ctx_cl);
255     SSL_set_fd(peer->sslcl, peer->sockcl);
256     /* must close oldsock after get new socket so that they are always different */
257     if (oldsock >= 0)
258         close(oldsock);
259     SSL_connect(peer->sslcl);
260     printf("tlsconnect: TLS connection to %s port %s up\n", peer->host, peer->port);
261     pthread_mutex_unlock(&peer->lock);
262 }
263
264 unsigned char *radtlsget(SSL *ssl) {
265     int cnt, total, len;
266     unsigned char buf[4], *rad;
267
268     for (;;) {
269         for (total = 0; total < 4; total += cnt) {
270             cnt = SSL_read(ssl, buf + total, 4 - total);
271             if (cnt <= 0) {
272                 printf("radtlsget: connection lost\n");
273                 return NULL;
274             }
275         }
276
277         len = RADLEN(buf);
278         rad = malloc(len);
279         if (!rad) {
280             err("radtlsget: malloc failed");
281             continue;
282         }
283         memcpy(rad, buf, 4);
284
285         for (; total < len; total += cnt) {
286             cnt = SSL_read(ssl, rad + total, len - total);
287             if (cnt <= 0) {
288                 printf("radtlsget: connection lost\n");
289                 free(rad);
290                 return NULL;
291             }
292         }
293     
294         if (total >= 20)
295             break;
296         
297         free(rad);
298         printf("radtlsget: packet smaller than minimum radius size\n");
299     }
300     
301     printf("radtlsget: got %d bytes\n", total);
302     return rad;
303 }
304
305 int clientradput(struct peer *peer, unsigned char *rad) {
306     int cnt, s;
307     size_t len;
308     unsigned long error;
309
310     len = RADLEN(rad);
311     if (peer->type == 'U') {
312         if (send(peer->sockcl, rad, len, 0) >= 0) {
313             printf("clienradput: sent UDP of length %d to %s port %s\n", len, peer->host, peer->port);
314             return 1;
315         }
316         err("clientradput: send failed");
317         return 0;
318     }
319
320     s = peer->sockcl;
321     while ((cnt = SSL_write(peer->sslcl, rad, len)) <= 0) {
322         while ((error = ERR_get_error()))
323             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
324         tlsconnect(peer, s, "clientradput");
325         s = peer->sockcl;
326     }
327            
328     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
329            cnt, len, peer->host);
330     return 1;
331 }
332
333 int radsign(unsigned char *rad, unsigned char *sec) {
334     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
335     static unsigned char first = 1;
336     static EVP_MD_CTX mdctx;
337     unsigned int md_len;
338     int result;
339     
340     pthread_mutex_lock(&lock);
341     if (first) {
342         EVP_MD_CTX_init(&mdctx);
343         first = 0;
344     }
345
346     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
347         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
348         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
349         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
350         md_len == 16);
351     pthread_mutex_unlock(&lock);
352     return result;
353 }
354
355 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
356     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
357     static unsigned char first = 1;
358     static EVP_MD_CTX mdctx;
359     unsigned char hash[EVP_MAX_MD_SIZE];
360     unsigned int len;
361     int result;
362     
363     pthread_mutex_lock(&lock);
364     if (first) {
365         EVP_MD_CTX_init(&mdctx);
366         first = 0;
367     }
368
369     len = RADLEN(rad);
370     
371     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
372               EVP_DigestUpdate(&mdctx, rad, 4) &&
373               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
374               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
375               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
376               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
377               len == 16 &&
378               !memcmp(hash, rad + 4, 16));
379     pthread_mutex_unlock(&lock);
380     return result;
381 }
382               
383 void sendrq(struct peer *to, struct peer *from, struct request *rq) {
384     int i;
385
386     pthread_mutex_lock(&to->newrq_mutex);
387     for (i = 0; i < MAX_REQUESTS; i++)
388         if (!to->requests[i].buf)
389             break;
390     if (i == MAX_REQUESTS) {
391         printf("No room in queue, dropping request\n");
392         pthread_mutex_unlock(&to->newrq_mutex);
393         return;
394     }
395     
396     rq->buf[1] = (char)i;
397     to->requests[i] = *rq;
398
399     if (!to->newrq) {
400         to->newrq = 1;
401         printf("signalling client writer\n");
402         pthread_cond_signal(&to->newrq_cond);
403     }
404     pthread_mutex_unlock(&to->newrq_mutex);
405 }
406
407 void sendreply(struct peer *to, struct peer *from, char *buf, struct sockaddr_storage *tosa) {
408     struct replyq *replyq = to->replyq;
409     
410     pthread_mutex_lock(&replyq->count_mutex);
411     if (replyq->count == replyq->size) {
412         printf("No room in queue, dropping request\n");
413         pthread_mutex_unlock(&replyq->count_mutex);
414         return;
415     }
416
417     replyq->replies[replyq->count].buf = buf;
418     if (tosa)
419         replyq->replies[replyq->count].tosa = *tosa;
420     replyq->count++;
421
422     if (replyq->count == 1) {
423         printf("signalling client writer\n");
424         pthread_cond_signal(&replyq->count_cond);
425     }
426     pthread_mutex_unlock(&replyq->count_mutex);
427 }
428
429 int pwdcrypt(uint8_t *plain, uint8_t *enc, uint8_t enclen, uint8_t *shared, uint8_t sharedlen,
430                 uint8_t *auth) {
431     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
432     static unsigned char first = 1;
433     static EVP_MD_CTX mdctx;
434     unsigned char hash[EVP_MAX_MD_SIZE], *input;
435     unsigned int md_len;
436     uint8_t i, offset = 0;
437     
438     pthread_mutex_lock(&lock);
439     if (first) {
440         EVP_MD_CTX_init(&mdctx);
441         first = 0;
442     }
443
444     input = auth;
445     for (;;) {
446         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
447             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
448             !EVP_DigestUpdate(&mdctx, input, 16) ||
449             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
450             md_len != 16) {
451             pthread_mutex_unlock(&lock);
452             return 0;
453         }
454         for (i = 0; i < 16; i++)
455             plain[offset + i] = hash[i] ^ enc[offset + i];
456         offset += 16;
457         if (offset == enclen)
458             break;
459         input = enc + offset - 16;
460     }
461     pthread_mutex_unlock(&lock);
462     return 1;
463 }
464
465 struct peer *radsrv(struct request *rq, char *buf, struct peer *from) {
466     uint8_t code, id, *auth, *attr, *usernameattr = NULL, *userpwdattr = NULL, pwd[128], pwdlen;
467     int i;
468     uint16_t len;
469     int left;
470     struct peer *to;
471     unsigned char newauth[16];
472     
473     code = *(uint8_t *)buf;
474     id = *(uint8_t *)(buf + 1);
475     len = RADLEN(buf);
476     auth = (uint8_t *)(buf + 4);
477
478     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
479     
480     if (code != RAD_Access_Request) {
481         printf("radsrv: server currently accepts only access-requests, ignoring\n");
482         return NULL;
483     }
484
485     left = len - 20;
486     attr = buf + 20;
487     
488     while (left > 1) {
489         left -= attr[RAD_Attr_Length];
490         if (left < 0) {
491             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
492             return NULL;
493         }
494         switch (attr[RAD_Attr_Type]) {
495         case RAD_Attr_User_Name:
496             usernameattr = attr;
497             break;
498         case RAD_Attr_User_Password:
499             userpwdattr = attr;
500             break;
501         }
502         attr += attr[RAD_Attr_Length];
503     }
504     if (left)
505         printf("radsrv: malformed packet? remaining byte after last attribute\n");
506
507     if (usernameattr) {
508         printf("radsrv: Username: ");
509         for (i = 0; i < usernameattr[RAD_Attr_Length]; i++)
510             printf("%c", usernameattr[RAD_Attr_Value + i]);
511         printf("\n");
512     }
513
514     /* find out where to send the packet, for now we send to first connected
515        TLS peer if UDP, and first UDP peer if TLS */
516     
517     i = peer_count;
518     switch (from->type) {
519     case 'U':
520         for (i = 0; i < peer_count; i++)
521             if (peers[i].type == 'T' && peers[i].sockcl >= 0)
522                 break;
523         break;
524     case 'T':
525         for (i = 0; i < peer_count; i++)
526             if (peers[i].type == 'U')
527                 break;
528         break;
529     }
530
531     if (i == peer_count) {
532         printf("radsrv: ignoring request, don't know where to send it\n");
533         return NULL;
534     }
535
536     to = &peers[i];
537     
538     if (!RAND_bytes(newauth, 16)) {
539         printf("radsrv: failed to generate random auth\n");
540         return NULL;
541     }
542
543     if (userpwdattr) {
544         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
545         pwdlen = userpwdattr[RAD_Attr_Length] - 2;
546         if (pwdlen < 16 || pwdlen > 128 || pwdlen % 16) {
547             printf("radsrv: invalid user password length\n");
548             return NULL;
549         }
550         
551         if (!pwdcrypt(pwd, &userpwdattr[RAD_Attr_Value], pwdlen, from->secret, strlen(from->secret), auth)) {
552             printf("radsrv: cannot decrypt password\n");
553             return NULL;
554         }
555         printf("radsrv: password: ");
556         for (i = 0; i < pwdlen; i++)
557             printf("%02x ", pwd[i]);
558         printf("\n");
559         if (!pwdcrypt(&userpwdattr[RAD_Attr_Value], pwd, pwdlen, to->secret, strlen(to->secret), newauth)) {
560             printf("radsrv: cannot encrypt password\n");
561             return NULL;
562         }
563     }
564
565     rq->buf = buf;
566     rq->from = from;
567     rq->origid = id;
568     memcpy(rq->origauth, auth, 16);
569     memcpy(rq->buf + 4, newauth, 16);
570     return to;
571 }
572
573 void *clientrd(void *arg) {
574     struct peer *from, *peer = (struct peer *)arg;
575     int i, s;
576     unsigned char *buf;
577     struct sockaddr_storage fromsa;
578     
579     for (;;) {
580         s = peer->sockcl;
581         buf = (peer->type == 'U' ? radudpget(s, &peer, NULL) : radtlsget(peer->sslcl));
582         if (!buf && peer->type == 'T') {
583             printf("retry in 60s\n");
584             sleep(60); /* should have exponential backoff perhaps, better do it inside radtlsget */
585             tlsconnect(peer, s, "clientrd");
586             continue;
587         }
588         
589         i = buf[1]; /* i is the id */
590
591         pthread_mutex_lock(&peer->newrq_mutex);
592         if (!peer->requests[i].buf || !peer->requests[i].tries) {
593             pthread_mutex_unlock(&peer->newrq_mutex);
594             printf("clientrd: no matching request sent with this id, ignoring\n");
595             continue;
596         }
597         
598         if (peer->requests[i].received) {
599             pthread_mutex_unlock(&peer->newrq_mutex);
600             printf("clientrd: already received, ignoring\n");
601             continue;
602         }
603
604         if (!validauth(buf, peer->requests[i].buf + 4, peer->secret)) {
605             pthread_mutex_unlock(&peer->newrq_mutex);
606             printf("clientrd: invalid auth, ignoring\n");
607             continue;
608         }
609
610         /* once we set received = 1, requests[i] may be reused */
611         buf[1] = (char)peer->requests[i].origid;
612         memcpy(buf + 4, peer->requests[i].origauth, 16);
613         from = peer->requests[i].from;
614         if (from->type == 'U')
615             fromsa = peer->requests[i].fromsa;
616         peer->requests[i].received = 1;
617         pthread_mutex_unlock(&peer->newrq_mutex);
618
619         if (!radsign(buf, from->secret)) {
620             printf("clientrd: failed to sign message\n");
621             continue;
622         }
623         
624         printf("clientrd: giving packet back to where it came from\n");
625         sendreply(from, peer, buf, from->type == 'U' ? &fromsa : NULL);
626     }
627 }
628
629 void *clientwr(void *arg) {
630     struct peer *peer = (struct peer *)arg;
631     pthread_t clientrdth;
632     int i;
633
634     if (peer->type == 'U') {
635         if ((peer->sockcl = connecttopeer(peer)) < 0) {
636             printf("clientwr: connecttopeer failed\n");
637             exit(1);
638         }
639     } else
640         tlsconnect(peer, -1, "new client");
641     
642     if (pthread_create(&clientrdth, NULL, clientrd, (void *)peer))
643         errx("clientwr: pthread_create failed");
644
645     for (;;) {
646         pthread_mutex_lock(&peer->newrq_mutex);
647         while (!peer->newrq) {
648             printf("clientwr: waiting for signal\n");
649             pthread_cond_wait(&peer->newrq_cond, &peer->newrq_mutex);
650             printf("clientwr: got signal\n");
651         }
652         peer->newrq = 0;
653         pthread_mutex_unlock(&peer->newrq_mutex);
654                
655         for (i = 0; i < MAX_REQUESTS; i++) {
656             pthread_mutex_lock(&peer->newrq_mutex);
657             while (!peer->requests[i].buf && i < MAX_REQUESTS)
658                 i++;
659             if (i == MAX_REQUESTS) {
660                 pthread_mutex_unlock(&peer->newrq_mutex);
661                 break;
662             }
663
664             /* already received or too many tries */
665             if (peer->requests[i].received || peer->requests[i].tries > 2) {
666                 free(peer->requests[i].buf);
667                 /* setting this to NULL means that it can be reused */
668                 peer->requests[i].buf = NULL;
669                 pthread_mutex_unlock(&peer->newrq_mutex);
670                 continue;
671             }
672             pthread_mutex_unlock(&peer->newrq_mutex);
673             
674             peer->requests[i].tries++;
675             clientradput(peer, peer->requests[i].buf);
676         }
677     }
678     /* should do more work to maintain TLS connections, keepalives etc */
679 }
680
681 void *udpserverwr(void *arg) {
682     struct replyq *replyq = &udp_server_replyq;
683     struct reply *reply = replyq->replies;
684     
685     pthread_mutex_lock(&replyq->count_mutex);
686     for (;;) {
687         while (!replyq->count) {
688             printf("udp server writer, waiting for signal\n");
689             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
690             printf("udp server writer, got signal\n");
691         }
692         pthread_mutex_unlock(&replyq->count_mutex);
693         
694         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
695                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
696             err("sendudp: send failed");
697         free(reply->buf);
698         
699         pthread_mutex_lock(&replyq->count_mutex);
700         replyq->count--;
701         memmove(replyq->replies, replyq->replies + 1,
702                 replyq->count * sizeof(struct reply));
703     }
704 }
705
706 void *udpserverrd(void *arg) {
707     struct request rq;
708     unsigned char *buf;
709     struct peer *to, *fr;
710     pthread_t udpserverwrth;
711     
712     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
713         printf("udpserverrd: socket/bind failed\n");
714         exit(1);
715     }
716     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
717
718     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
719         errx("pthread_create failed");
720     
721     for (;;) {
722         fr = NULL;
723         memset(&rq, 0, sizeof(struct request));
724         buf = radudpget(udp_server_sock, &fr, &rq.fromsa);
725         to = radsrv(&rq, buf, fr);
726         if (!to) {
727             printf("udpserverrd: ignoring request, no place to send it\n");
728             continue;
729         }
730         sendrq(to, fr, &rq);
731     }
732 }
733
734 void *tlsserverwr(void *arg) {
735     int cnt;
736     unsigned long error;
737     struct peer *peer = (struct peer *)arg;
738     struct replyq *replyq;
739     
740     pthread_mutex_lock(&peer->replycount_mutex);
741     for (;;) {
742         replyq = peer->replyq;
743         while (!replyq->count) {
744             printf("tls server writer, waiting for signal\n");
745             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
746             printf("tls server writer, got signal\n");
747         }
748         pthread_mutex_unlock(&replyq->count_mutex);
749         cnt = SSL_write(peer->sslsrv, replyq->replies->buf, RADLEN(replyq->replies->buf));
750         if (cnt > 0)
751             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
752                    cnt, RADLEN(replyq->replies->buf));
753         else
754             while ((error = ERR_get_error()))
755                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
756         free(replyq->replies->buf);
757
758         pthread_mutex_lock(&replyq->count_mutex);
759         replyq->count--;
760         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
761     }
762 }
763
764 void *tlsserverrd(void *arg) {
765     struct request rq;
766     char unsigned *buf;
767     unsigned long error;
768     struct peer *to;
769     int s;
770     struct peer *peer = (struct peer *)arg;
771     pthread_t tlsserverwrth;
772
773     printf("tlsserverrd starting\n");
774     if (SSL_accept(peer->sslsrv) <= 0) {
775         while ((error = ERR_get_error()))
776             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
777         errx("accept failed, child exiting");
778     }
779
780     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)peer))
781         errx("pthread_create failed");
782     
783     for (;;) {
784         buf = radtlsget(peer->sslsrv);
785         if (!buf) {
786             printf("tlsserverrd: connection lost\n");
787             s = SSL_get_fd(peer->sslsrv);
788             SSL_free(peer->sslsrv);
789             peer->sslsrv = NULL;
790             if (s >= 0)
791                 close(s);
792             pthread_exit(NULL);
793         }
794         printf("tlsserverrd: got Radius message from %s\n", peer->host);
795         memset(&rq, 0, sizeof(struct request));
796         to = radsrv(&rq, buf, peer);
797         if (!to) {
798             printf("ignoring request, no place to send it\n");
799             continue;
800         }
801         sendrq(to, peer, &rq);
802     }
803 }
804
805 int tlslistener(SSL_CTX *ssl_ctx) {
806     pthread_t tlsserverth;
807     int s, snew;
808     struct sockaddr_storage from;
809     size_t fromlen = sizeof(from);
810     struct peer *peer;
811
812     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
813         printf("tlslistener: socket/bind failed\n");
814         exit(1);
815     }
816
817     listen(s, 0);
818     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
819
820     for (;;) {
821         snew = accept(s, (struct sockaddr *)&from, &fromlen);
822         if (snew < 0)
823             errx("accept failed");
824         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
825
826         peer = find_peer('T', (struct sockaddr *)&from, NULL);
827         if (!peer) {
828             printf("ignoring request, not a known TLS peer\n");
829             close(snew);
830             continue;
831         }
832
833         if (peer->sslsrv) {
834             printf("Ignoring incoming connection, already have one from this peer\n");
835             close(snew);
836             continue;
837         }
838         peer->sslsrv = SSL_new(ssl_ctx);
839         SSL_set_fd(peer->sslsrv, snew);
840         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)peer))
841             errx("pthread_create failed");
842         
843         for (;;) {
844             /* currently only one server thread, so just halt here */
845             sleep(1000);
846         }
847     }
848     return 0;
849 }
850
851 char *parsehostport(char *s, char **host, char **port) {
852     char *p, *field;
853     int ipv6 = 0;
854
855     p = s;
856     // allow literal addresses and port, e.g. [2001:db8::1]:1812
857     if (*p == '[') {
858         p++;
859         field = p;
860         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
861         if (*p != ']') {
862             printf("no ] matching initial [\n");
863             exit(1);
864         }
865         ipv6 = 1;
866     } else {
867         field = p;
868         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
869     }
870     if (field == p) {
871         printf("missing host/address\n");
872         exit(1);
873     }
874     *host = malloc(p - field + 1);
875     if (!*host)
876         errx("malloc failed");
877     memcpy(*host, field, p - field);
878     (*host)[p - field] = '\0';
879     if (ipv6) {
880         p++;
881         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
882             printf("unexpected character after ]\n");
883             exit(1);
884         }
885     }
886     if (*p == ':') {
887             /* port number or service name is specified */;
888             field = p++;
889             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
890             if (field == p) {
891                 printf("syntax error, : but no following port\n");
892                 exit(1);
893             }
894             *port = malloc(p - field + 1);
895             if (!*port)
896                 errx("malloc failed");
897             memcpy(*port, field, p - field);
898             (*port)[p - field ] = '\0';
899     } else
900         *port = NULL;
901     return p;
902 }
903
904 void getconfig(const char *filename) {
905     FILE *f;
906     char line[1024];
907     char *p, *field;
908     struct peer *peer;
909     
910     peer_count = 0;
911     
912     udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
913     if (!udp_server_replyq.replies)
914         errx("malloc failed");
915     udp_server_replyq.size = 4 * MAX_REQUESTS;
916     udp_server_replyq.count = 0;
917     pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
918     pthread_cond_init(&udp_server_replyq.count_cond, NULL);
919     
920     f = fopen(filename, "r");
921     if (!f)
922         errx("getconfig failed to open %s for reading", filename);
923
924     while (fgets(line, 1024, f) && peer_count < MAX_PEERS) {
925         peer = &peers[peer_count];
926         memset(peer, 0, sizeof(struct peer));
927
928         for (p = line; *p == ' ' || *p == '\t'; p++);
929         if (*p == '#' || *p == '\n')
930             continue;
931         if (*p != 'U' && *p != 'T') {
932             printf("server type must be U or T, got %c\n", *p);
933             exit(1);
934         }
935         peer->type = *p;
936         for (p++; *p == ' ' || *p == '\t'; p++);
937         p = parsehostport(p, &peer->host, &peer->port);
938         if (!peer->port)
939             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
940         for (; *p == ' ' || *p == '\t'; p++);
941         field = p;
942         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
943         if (field == p) {
944             /* no secret set and end of line, line is complete if TLS */
945             if (peer->type == 'U') {
946                 printf("secret must be specified for UDP\n");
947                 exit(1);
948             }
949             peer->secret = DEFAULT_TLS_SECRET;
950         } else {
951             peer->secret = malloc(p - field + 1);
952             if (!peer->secret)
953                 errx("malloc failed");
954             memcpy(peer->secret, field, p - field);
955             peer->secret[p - field] = '\0';
956             /* check that rest of line only white space */
957             for (; *p == ' ' || *p == '\t'; p++);
958             if (*p && *p != '\n') {
959                 printf("max 3 fields per line, found a 4th\n");
960                 exit(1);
961             }
962         }
963         peer->sockcl = -1;
964         peer->sslsrv = NULL;
965         peer->sslcl = NULL;
966         pthread_mutex_init(&peer->lock, NULL);
967         if (!resolvepeer(peer)) {
968             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
969             exit(1);
970         }
971         peer->requests = malloc(MAX_REQUESTS * sizeof(struct request));
972         if (!peer->requests)
973             errx("malloc failed");
974         memset(peer->requests, 0, MAX_REQUESTS * sizeof(struct request));
975         peer->newrq = 0;
976         pthread_mutex_init(&peer->newrq_mutex, NULL);
977         pthread_cond_init(&peer->newrq_cond, NULL);
978
979         if (peer->type == 'U')
980             peer->replyq = &udp_server_replyq;
981         else {
982             peer->replyq = malloc(sizeof(struct replyq));
983             if (!peer->replyq)
984                 errx("malloc failed");
985             peer->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
986             if (!peer->replyq->replies)
987                 errx("malloc failed");
988             peer->replyq->size = MAX_REQUESTS;
989             peer->replyq->count = 0;
990             pthread_mutex_init(&peer->replyq->count_mutex, NULL);
991             pthread_cond_init(&peer->replyq->count_cond, NULL);
992         }
993         printf("got type %c, host %s, port %s, secret %s\n", peers[peer_count].type,
994                peers[peer_count].host, peers[peer_count].port, peers[peer_count].secret);
995         peer_count++;
996     }
997     fclose(f);
998 }
999
1000 void parseargs(int argc, char **argv) {
1001     int c;
1002
1003     while ((c = getopt(argc, argv, "p:")) != -1) {
1004         switch (c) {
1005         case 'p':
1006             udp_server_port = optarg;
1007             break;
1008         default:
1009             goto usage;
1010         }
1011     }
1012
1013     return;
1014
1015  usage:
1016     printf("radsecproxy [ -p UDP-port ]\n");
1017     exit(1);
1018 }
1019                
1020 int main(int argc, char **argv) {
1021     SSL_CTX *ssl_ctx_srv;
1022     unsigned long error;
1023     pthread_t udpserverth;
1024     pthread_attr_t joinable;
1025     int i;
1026     
1027     parseargs(argc, argv);
1028     getconfig("radsecproxy.conf");
1029     
1030     ssl_locks_setup();
1031
1032     pthread_attr_init(&joinable);
1033     pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1034    
1035     /* listen on UDP if at least one UDP peer */
1036     
1037     for (i = 0; i < peer_count; i++)
1038         if (peers[i].type == 'U') {
1039             if (pthread_create(&udpserverth, &joinable, udpserverrd, NULL))
1040                 errx("pthread_create failed");
1041             break;
1042         }
1043     
1044     for (i = 0; i < peer_count; i++)
1045         if (peers[i].type == 'T')
1046             break;
1047
1048     if (i == peer_count) {
1049         printf("No TLS peers defined, just doing UDP proxying\n");
1050         /* just hang around doing nothing, anything to do here? */
1051         pthread_join(udpserverth, NULL);
1052         return 0;
1053     }
1054     
1055     /* SSL setup */
1056     SSL_load_error_strings();
1057     SSL_library_init();
1058
1059     while (!RAND_status()) {
1060         time_t t = time(NULL);
1061         pid_t pid = getpid();
1062         RAND_seed((unsigned char *)&t, sizeof(time_t));
1063         RAND_seed((unsigned char *)&pid, sizeof(pid));
1064     }
1065     
1066     /* initialise client part and start clients */
1067     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1068     if (!ssl_ctx_cl)
1069         errx("no ssl ctx");
1070     
1071     for (i = 0; i < peer_count; i++) {
1072         if (pthread_create(&peers[i].clientth, NULL, clientwr, (void *)&peers[i]))
1073             errx("pthread_create failed");
1074     }
1075
1076     /* setting up server/daemon part */
1077     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1078     if (!ssl_ctx_srv)
1079         errx("no ssl ctx");
1080     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1081         while ((error = ERR_get_error()))
1082             err("SSL: %s", ERR_error_string(error, NULL));
1083         errx("Failed to load certificate");
1084     }
1085     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1086         while ((error = ERR_get_error()))
1087             err("SSL: %s", ERR_error_string(error, NULL));
1088         errx("Failed to load private key");
1089     }
1090     return tlslistener(ssl_ctx_srv);
1091 }