added some todo comments
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * make our server ignore client retrans and do its own instead?
15  * accounting
16  * radius keep alives (server status)
17  * tls certificate validation, see below urls
18  * clean tls shutdown, see http://www.linuxjournal.com/article/4822
19  *     and http://www.linuxjournal.com/article/5487
20  *     SSL_shutdown() and shutdown()
21  *     If shutdown() we may not need REUSEADDR
22  * when tls client goes away, ensure that all related threads and state
23  *          are removed
24  * setsockopt(keepalive...), check if openssl has some keepalive feature
25 */
26
27 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
28  *              rd is responsible for init and launching wr
29  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
30  *          each tlsserverrd launches tlsserverwr
31  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
32  *          for init and launching rd
33  *
34  * serverrd will receive a request, processes it and puts it in the requestq of
35  *          the appropriate clientwr
36  * clientwr monitors its requestq and sends requests
37  * clientrd looks for responses, processes them and puts them in the replyq of
38  *          the peer the request came from
39  * serverwr monitors its reply and sends replies
40  *
41  * In addition to the main thread, we have:
42  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
43  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
44  * For each TLS peer connecting to us there will be 2 more TLS threads
45  *       This is only for connected peers
46  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
47  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
48 */
49
50 #include <netdb.h>
51 #include <unistd.h>
52 #include <sys/time.h>
53 #include <pthread.h>
54 #include <openssl/ssl.h>
55 #include <openssl/rand.h>
56 #include <openssl/err.h>
57 #include <openssl/md5.h>
58 #include <openssl/hmac.h>
59 #include "radsecproxy.h"
60
61 static struct client clients[MAX_PEERS];
62 static struct server servers[MAX_PEERS];
63
64 static int client_count = 0;
65 static int server_count = 0;
66
67 static struct replyq udp_server_replyq;
68 static int udp_server_sock = -1;
69 static char *udp_server_port = DEFAULT_UDP_PORT;
70 static pthread_mutex_t *ssl_locks;
71 static long *ssl_lock_count;
72 static SSL_CTX *ssl_ctx_cl;
73 extern int optind;
74 extern char *optarg;
75
76 /* callbacks for making OpenSSL thread safe */
77 unsigned long ssl_thread_id() {
78         return (unsigned long)pthread_self();
79 };
80
81 void ssl_locking_callback(int mode, int type, const char *file, int line) {
82     if (mode & CRYPTO_LOCK) {
83         pthread_mutex_lock(&ssl_locks[type]);
84         ssl_lock_count[type]++;
85     } else
86         pthread_mutex_unlock(&ssl_locks[type]);
87 }
88
89 void ssl_locks_setup() {
90     int i;
91
92     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
93     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
94     for (i = 0; i < CRYPTO_num_locks(); i++) {
95         ssl_lock_count[i] = 0;
96         pthread_mutex_init(&ssl_locks[i], NULL);
97     }
98
99     CRYPTO_set_id_callback(ssl_thread_id);
100     CRYPTO_set_locking_callback(ssl_locking_callback);
101 }
102
103 void printauth(char *s, unsigned char *t) {
104     int i;
105     printf("%s:", s);
106     for (i = 0; i < 16; i++)
107             printf("%02x ", t[i]);
108     printf("\n");
109 }
110
111 int resolvepeer(struct peer *peer) {
112     struct addrinfo hints, *addrinfo;
113     
114     memset(&hints, 0, sizeof(hints));
115     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
116     hints.ai_family = AF_UNSPEC;
117     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
118         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
119         return 0;
120     }
121
122     if (peer->addrinfo)
123         freeaddrinfo(peer->addrinfo);
124     peer->addrinfo = addrinfo;
125     return 1;
126 }         
127
128 int connecttoserver(struct addrinfo *addrinfo) {
129     int s;
130     struct addrinfo *res;
131     
132     for (res = addrinfo; res; res = res->ai_next) {
133         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
134         if (s < 0) {
135             err("connecttoserver: socket failed");
136             continue;
137         }
138         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
139             break;
140         err("connecttoserver: connect failed");
141         close(s);
142         s = -1;
143     }
144     return s;
145 }         
146
147 /* returns the client with matching address, or NULL */
148 /* if client argument is not NULL, we only check that one client */
149 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
150     struct sockaddr_in6 *sa6;
151     struct in_addr *a4 = NULL;
152     struct client *c;
153     int i;
154     struct addrinfo *res;
155
156     if (addr->sa_family == AF_INET6) {
157         sa6 = (struct sockaddr_in6 *)addr;
158         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
159             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
160     } else
161         a4 = &((struct sockaddr_in *)addr)->sin_addr;
162
163     c = (client ? client : clients);
164     for (i = 0; i < client_count; i++) {
165         if (c->peer.type == type)
166             for (res = c->peer.addrinfo; res; res = res->ai_next)
167                 if ((a4 && res->ai_family == AF_INET &&
168                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
169                     (res->ai_family == AF_INET6 &&
170                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
171                     return c;
172         if (client)
173             break;
174         c++;
175     }
176     return NULL;
177 }
178
179 /* returns the server with matching address, or NULL */
180 /* if server argument is not NULL, we only check that one server */
181 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
182     struct sockaddr_in6 *sa6;
183     struct in_addr *a4 = NULL;
184     struct server *s;
185     int i;
186     struct addrinfo *res;
187
188     if (addr->sa_family == AF_INET6) {
189         sa6 = (struct sockaddr_in6 *)addr;
190         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
191             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
192     } else
193         a4 = &((struct sockaddr_in *)addr)->sin_addr;
194
195     s = (server ? server : servers);
196     for (i = 0; i < server_count; i++) {
197         if (s->peer.type == type)
198             for (res = s->peer.addrinfo; res; res = res->ai_next)
199                 if ((a4 && res->ai_family == AF_INET &&
200                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
201                     (res->ai_family == AF_INET6 &&
202                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
203                     return s;
204         if (server)
205             break;
206         s++;
207     }
208     return NULL;
209 }
210
211 /* exactly one of client and server must be non-NULL */
212 /* if *peer == NULL we return who we received from, else require it to be from peer */
213 /* return from in sa if not NULL */
214 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
215     int cnt, len;
216     void *f;
217     unsigned char buf[65536], *rad;
218     struct sockaddr_storage from;
219     socklen_t fromlen = sizeof(from);
220
221     for (;;) {
222         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
223         if (cnt == -1) {
224             err("radudpget: recv failed");
225             continue;
226         }
227         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
228
229         if (cnt < 20) {
230             printf("radudpget: packet too small\n");
231             continue;
232         }
233     
234         len = RADLEN(buf);
235
236         if (cnt < len) {
237             printf("radudpget: packet smaller than length field in radius header\n");
238             continue;
239         }
240         if (cnt > len)
241             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
242
243         f = (client
244              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
245              : (void *)find_server('U', (struct sockaddr *)&from, *server));
246         if (!f) {
247             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
248             continue;
249         }
250
251         rad = malloc(len);
252         if (rad)
253             break;
254         err("radudpget: malloc failed");
255     }
256     memcpy(rad, buf, len);
257     if (client)
258         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
259     else
260         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
261     if (sa)
262         *sa = from;
263     return rad;
264 }
265
266 void tlsconnect(struct server *server, struct timeval *when, char *text) {
267     struct timeval now;
268     time_t elapsed;
269     unsigned long error;
270
271     printf("tlsconnect called from %s\n", text);
272     pthread_mutex_lock(&server->lock);
273     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
274         /* already reconnected, nothing to do */
275         printf("tlsconnect(%s): seems already reconnected\n", text);
276         pthread_mutex_unlock(&server->lock);
277         return;
278     }
279
280     printf("tlsconnect %s\n", text);
281
282     for (;;) {
283         gettimeofday(&now, NULL);
284         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
285         if (server->connectionok) {
286             server->connectionok = 0;
287             sleep(10);
288         } else if (elapsed < 5)
289             sleep(10);
290         else if (elapsed < 600)
291             sleep(elapsed * 2);
292         else if (elapsed < 10000)
293                 sleep(900);
294         else
295             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
296         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
297         if (server->sock >= 0)
298             close(server->sock);
299         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
300             continue;
301         SSL_free(server->peer.ssl);
302         server->peer.ssl = SSL_new(ssl_ctx_cl);
303         SSL_set_fd(server->peer.ssl, server->sock);
304         if (SSL_connect(server->peer.ssl) > 0)
305             break;
306         while ((error = ERR_get_error()))
307             err("tlsconnect: TLS: %s", ERR_error_string(error, NULL));
308     }
309     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
310     gettimeofday(&server->lastconnecttry, NULL);
311     pthread_mutex_unlock(&server->lock);
312 }
313
314 unsigned char *radtlsget(SSL *ssl) {
315     int cnt, total, len;
316     unsigned char buf[4], *rad;
317
318     for (;;) {
319         for (total = 0; total < 4; total += cnt) {
320             cnt = SSL_read(ssl, buf + total, 4 - total);
321             if (cnt <= 0) {
322                 printf("radtlsget: connection lost\n");
323                 return NULL;
324             }
325         }
326
327         len = RADLEN(buf);
328         rad = malloc(len);
329         if (!rad) {
330             err("radtlsget: malloc failed");
331             continue;
332         }
333         memcpy(rad, buf, 4);
334
335         for (; total < len; total += cnt) {
336             cnt = SSL_read(ssl, rad + total, len - total);
337             if (cnt <= 0) {
338                 printf("radtlsget: connection lost\n");
339                 free(rad);
340                 return NULL;
341             }
342         }
343     
344         if (total >= 20)
345             break;
346         
347         free(rad);
348         printf("radtlsget: packet smaller than minimum radius size\n");
349     }
350     
351     printf("radtlsget: got %d bytes\n", total);
352     return rad;
353 }
354
355 int clientradput(struct server *server, unsigned char *rad) {
356     int cnt;
357     size_t len;
358     unsigned long error;
359     struct timeval lastconnecttry;
360     
361     len = RADLEN(rad);
362     if (server->peer.type == 'U') {
363         if (send(server->sock, rad, len, 0) >= 0) {
364             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
365             return 1;
366         }
367         err("clientradput: send failed");
368         return 0;
369     }
370
371     lastconnecttry = server->lastconnecttry;
372     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
373         while ((error = ERR_get_error()))
374             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
375         tlsconnect(server, &lastconnecttry, "clientradput");
376         lastconnecttry = server->lastconnecttry;
377     }
378
379     server->connectionok = 1;
380     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
381            cnt, len, server->peer.host);
382     return 1;
383 }
384
385 int radsign(unsigned char *rad, unsigned char *sec) {
386     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
387     static unsigned char first = 1;
388     static EVP_MD_CTX mdctx;
389     unsigned int md_len;
390     int result;
391     
392     pthread_mutex_lock(&lock);
393     if (first) {
394         EVP_MD_CTX_init(&mdctx);
395         first = 0;
396     }
397
398     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
399         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
400         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
401         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
402         md_len == 16);
403     pthread_mutex_unlock(&lock);
404     return result;
405 }
406
407 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
408     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
409     static unsigned char first = 1;
410     static EVP_MD_CTX mdctx;
411     unsigned char hash[EVP_MAX_MD_SIZE];
412     unsigned int len;
413     int result;
414     
415     pthread_mutex_lock(&lock);
416     if (first) {
417         EVP_MD_CTX_init(&mdctx);
418         first = 0;
419     }
420
421     len = RADLEN(rad);
422     
423     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
424               EVP_DigestUpdate(&mdctx, rad, 4) &&
425               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
426               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
427               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
428               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
429               len == 16 &&
430               !memcmp(hash, rad + 4, 16));
431     pthread_mutex_unlock(&lock);
432     return result;
433 }
434               
435 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
436     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
437     static unsigned char first = 1;
438     static HMAC_CTX hmacctx;
439     unsigned int md_len;
440     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
441     
442     pthread_mutex_lock(&lock);
443     if (first) {
444         HMAC_CTX_init(&hmacctx);
445         first = 0;
446     }
447
448     memcpy(auth, authattr, 16);
449     memset(authattr, 0, 16);
450     md_len = 0;
451     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
452     HMAC_Update(&hmacctx, rad, RADLEN(rad));
453     HMAC_Final(&hmacctx, hash, &md_len);
454     memcpy(authattr, auth, 16);
455     if (md_len != 16) {
456         printf("message auth computation failed\n");
457         pthread_mutex_unlock(&lock);
458         return 0;
459     }
460
461     if (memcmp(auth, hash, 16)) {
462         printf("message authenticator, wrong value\n");
463         pthread_mutex_unlock(&lock);
464         return 0;
465     }   
466         
467     pthread_mutex_unlock(&lock);
468     return 1;
469 }
470
471 int createmessageauth(char *rad, char *authattrval, char *secret) {
472     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
473     static unsigned char first = 1;
474     static HMAC_CTX hmacctx;
475     unsigned int md_len;
476
477     if (!authattrval)
478         return 1;
479     
480     pthread_mutex_lock(&lock);
481     if (first) {
482         HMAC_CTX_init(&hmacctx);
483         first = 0;
484     }
485
486     memset(authattrval, 0, 16);
487     md_len = 0;
488     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
489     HMAC_Update(&hmacctx, rad, RADLEN(rad));
490     HMAC_Final(&hmacctx, authattrval, &md_len);
491     if (md_len != 16) {
492         printf("message auth computation failed\n");
493         pthread_mutex_unlock(&lock);
494         return 0;
495     }
496
497     pthread_mutex_unlock(&lock);
498     return 1;
499 }
500
501 void sendrq(struct server *to, struct client *from, struct request *rq) {
502     int i;
503     
504     pthread_mutex_lock(&to->newrq_mutex);
505     /* might simplify if only try nextid, might be ok */
506     for (i = to->nextid; i < MAX_REQUESTS; i++)
507         if (!to->requests[i].buf)
508             break;
509     if (i == MAX_REQUESTS) {
510         for (i = 0; i < to->nextid; i++)
511             if (!to->requests[i].buf)
512                 break;
513         if (i == to->nextid) {
514             printf("No room in queue, dropping request\n");
515             pthread_mutex_unlock(&to->newrq_mutex);
516             return;
517         }
518     }
519     
520     to->nextid = i + 1;
521     rq->buf[1] = (char)i;
522     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
523     
524     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
525         return;
526
527     gettimeofday(&rq->expiry, NULL);
528     rq->expiry.tv_sec += 30;
529     to->requests[i] = *rq;
530
531     if (!to->newrq) {
532         to->newrq = 1;
533         printf("signalling client writer\n");
534         pthread_cond_signal(&to->newrq_cond);
535     }
536     pthread_mutex_unlock(&to->newrq_mutex);
537 }
538
539 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
540     struct replyq *replyq = to->replyq;
541     
542     pthread_mutex_lock(&replyq->count_mutex);
543     if (replyq->count == replyq->size) {
544         printf("No room in queue, dropping request\n");
545         pthread_mutex_unlock(&replyq->count_mutex);
546         return;
547     }
548
549     replyq->replies[replyq->count].buf = buf;
550     if (tosa)
551         replyq->replies[replyq->count].tosa = *tosa;
552     replyq->count++;
553
554     if (replyq->count == 1) {
555         printf("signalling client writer\n");
556         pthread_cond_signal(&replyq->count_cond);
557     }
558     pthread_mutex_unlock(&replyq->count_mutex);
559 }
560
561 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
562     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
563     static unsigned char first = 1;
564     static EVP_MD_CTX mdctx;
565     unsigned char hash[EVP_MAX_MD_SIZE], *input;
566     unsigned int md_len;
567     uint8_t i, offset = 0, out[128];
568     
569     pthread_mutex_lock(&lock);
570     if (first) {
571         EVP_MD_CTX_init(&mdctx);
572         first = 0;
573     }
574
575     input = auth;
576     for (;;) {
577         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
578             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
579             !EVP_DigestUpdate(&mdctx, input, 16) ||
580             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
581             md_len != 16) {
582             pthread_mutex_unlock(&lock);
583             return 0;
584         }
585         for (i = 0; i < 16; i++)
586             out[offset + i] = hash[i] ^ in[offset + i];
587         input = out + offset - 16;
588         offset += 16;
589         if (offset == len)
590             break;
591     }
592     memcpy(in, out, len);
593     pthread_mutex_unlock(&lock);
594     return 1;
595 }
596
597 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
598     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
599     static unsigned char first = 1;
600     static EVP_MD_CTX mdctx;
601     unsigned char hash[EVP_MAX_MD_SIZE], *input;
602     unsigned int md_len;
603     uint8_t i, offset = 0, out[128];
604     
605     pthread_mutex_lock(&lock);
606     if (first) {
607         EVP_MD_CTX_init(&mdctx);
608         first = 0;
609     }
610
611     input = auth;
612     for (;;) {
613         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
614             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
615             !EVP_DigestUpdate(&mdctx, input, 16) ||
616             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
617             md_len != 16) {
618             pthread_mutex_unlock(&lock);
619             return 0;
620         }
621         for (i = 0; i < 16; i++)
622             out[offset + i] = hash[i] ^ in[offset + i];
623         input = in + offset;
624         offset += 16;
625         if (offset == len)
626             break;
627     }
628     memcpy(in, out, len);
629     pthread_mutex_unlock(&lock);
630     return 1;
631 }
632
633 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
634     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
635     static unsigned char first = 1;
636     static EVP_MD_CTX mdctx;
637     unsigned char hash[EVP_MAX_MD_SIZE];
638     unsigned int md_len;
639     uint8_t i, offset;
640     
641     pthread_mutex_lock(&lock);
642     if (first) {
643         EVP_MD_CTX_init(&mdctx);
644         first = 0;
645     }
646
647 #if 0    
648     printf("msppencrypt auth in: ");
649     for (i = 0; i < 16; i++)
650         printf("%02x ", auth[i]);
651     printf("\n");
652     
653     printf("msppencrypt salt in: ");
654     for (i = 0; i < 2; i++)
655         printf("%02x ", salt[i]);
656     printf("\n");
657     
658     printf("msppencrypt in: ");
659     for (i = 0; i < len; i++)
660         printf("%02x ", text[i]);
661     printf("\n");
662 #endif
663     
664     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
665         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
666         !EVP_DigestUpdate(&mdctx, auth, 16) ||
667         !EVP_DigestUpdate(&mdctx, salt, 2) ||
668         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
669         pthread_mutex_unlock(&lock);
670         return 0;
671     }
672
673 #if 0    
674     printf("msppencrypt hash: ");
675     for (i = 0; i < 16; i++)
676         printf("%02x ", hash[i]);
677     printf("\n");
678 #endif
679     
680     for (i = 0; i < 16; i++)
681         text[i] ^= hash[i];
682     
683     for (offset = 16; offset < len; offset += 16) {
684 #if 0   
685         printf("text + offset - 16 c(%d): ", offset / 16);
686         for (i = 0; i < 16; i++)
687             printf("%02x ", (text + offset - 16)[i]);
688         printf("\n");
689 #endif
690         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
691             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
692             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
693             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
694             md_len != 16) {
695             pthread_mutex_unlock(&lock);
696             return 0;
697         }
698 #if 0   
699         printf("msppencrypt hash: ");
700         for (i = 0; i < 16; i++)
701             printf("%02x ", hash[i]);
702         printf("\n");
703 #endif    
704         
705         for (i = 0; i < 16; i++)
706             text[offset + i] ^= hash[i];
707     }
708     
709 #if 0
710     printf("msppencrypt out: ");
711     for (i = 0; i < len; i++)
712         printf("%02x ", text[i]);
713     printf("\n");
714 #endif
715
716     pthread_mutex_unlock(&lock);
717     return 1;
718 }
719
720 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
721     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
722     static unsigned char first = 1;
723     static EVP_MD_CTX mdctx;
724     unsigned char hash[EVP_MAX_MD_SIZE];
725     unsigned int md_len;
726     uint8_t i, offset;
727     char plain[255];
728     
729     pthread_mutex_lock(&lock);
730     if (first) {
731         EVP_MD_CTX_init(&mdctx);
732         first = 0;
733     }
734
735 #if 0    
736     printf("msppdecrypt auth in: ");
737     for (i = 0; i < 16; i++)
738         printf("%02x ", auth[i]);
739     printf("\n");
740     
741     printf("msppedecrypt salt in: ");
742     for (i = 0; i < 2; i++)
743         printf("%02x ", salt[i]);
744     printf("\n");
745     
746     printf("msppedecrypt in: ");
747     for (i = 0; i < len; i++)
748         printf("%02x ", text[i]);
749     printf("\n");
750 #endif
751     
752     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
753         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
754         !EVP_DigestUpdate(&mdctx, auth, 16) ||
755         !EVP_DigestUpdate(&mdctx, salt, 2) ||
756         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
757         pthread_mutex_unlock(&lock);
758         return 0;
759     }
760
761 #if 0    
762     printf("msppedecrypt hash: ");
763     for (i = 0; i < 16; i++)
764         printf("%02x ", hash[i]);
765     printf("\n");
766 #endif
767     
768     for (i = 0; i < 16; i++)
769         plain[i] = text[i] ^ hash[i];
770     
771     for (offset = 16; offset < len; offset += 16) {
772 #if 0   
773         printf("text + offset - 16 c(%d): ", offset / 16);
774         for (i = 0; i < 16; i++)
775             printf("%02x ", (text + offset - 16)[i]);
776         printf("\n");
777 #endif
778         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
779             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
780             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
781             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
782             md_len != 16) {
783             pthread_mutex_unlock(&lock);
784             return 0;
785         }
786 #if 0   
787     printf("msppedecrypt hash: ");
788     for (i = 0; i < 16; i++)
789         printf("%02x ", hash[i]);
790     printf("\n");
791 #endif    
792
793     for (i = 0; i < 16; i++)
794         plain[offset + i] = text[offset + i] ^ hash[i];
795     }
796
797     memcpy(text, plain, len);
798 #if 0
799     printf("msppedecrypt out: ");
800     for (i = 0; i < len; i++)
801         printf("%02x ", text[i]);
802     printf("\n");
803 #endif
804
805     pthread_mutex_unlock(&lock);
806     return 1;
807 }
808
809 struct server *id2server(char *id, uint8_t len) {
810     int i;
811     char **realm, *idrealm;
812
813     idrealm = strchr(id, '@');
814     if (idrealm) {
815         idrealm++;
816         len -= idrealm - id;
817     } else {
818         idrealm = "-";
819         len = 1;
820     }
821     for (i = 0; i < server_count; i++) {
822         for (realm = servers[i].realms; *realm; realm++) {
823             if ((strlen(*realm) == 1 && **realm == '*') ||
824                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
825                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
826                 return servers + i;
827             }
828         }
829     }
830     return NULL;
831 }
832
833 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
834     uint8_t code, id, *auth, *attr, attrvallen;
835     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
836     int i;
837     uint16_t len;
838     int left;
839     struct server *to;
840     unsigned char newauth[16];
841     
842     code = *(uint8_t *)buf;
843     id = *(uint8_t *)(buf + 1);
844     len = RADLEN(buf);
845     auth = (uint8_t *)(buf + 4);
846
847     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
848     
849     if (code != RAD_Access_Request) {
850         printf("radsrv: server currently accepts only access-requests, ignoring\n");
851         return NULL;
852     }
853
854     left = len - 20;
855     attr = buf + 20;
856     
857     while (left > 1) {
858         left -= attr[RAD_Attr_Length];
859         if (left < 0) {
860             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
861             return NULL;
862         }
863         switch (attr[RAD_Attr_Type]) {
864         case RAD_Attr_User_Name:
865             usernameattr = attr;
866             break;
867         case RAD_Attr_User_Password:
868             userpwdattr = attr;
869             break;
870         case RAD_Attr_Tunnel_Password:
871             tunnelpwdattr = attr;
872             break;
873         case RAD_Attr_Message_Authenticator:
874             messageauthattr = attr;
875             break;
876         }
877         attr += attr[RAD_Attr_Length];
878     }
879     if (left)
880         printf("radsrv: malformed packet? remaining byte after last attribute\n");
881
882     if (usernameattr) {
883         printf("radsrv: Username: ");
884         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
885             printf("%c", usernameattr[RAD_Attr_Value + i]);
886         printf("\n");
887     }
888
889     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
890     if (!to) {
891         printf("radsrv: ignoring request, don't know where to send it\n");
892         return NULL;
893     }
894     
895     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
896                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
897         printf("radsrv: message authentication failed\n");
898         return NULL;
899     }
900
901     if (!RAND_bytes(newauth, 16)) {
902         printf("radsrv: failed to generate random auth\n");
903         return NULL;
904     }
905
906     printauth("auth", auth);
907     printauth("newauth", newauth);
908     
909     if (userpwdattr) {
910         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
911         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
912         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
913             printf("radsrv: invalid user password length\n");
914             return NULL;
915         }
916         
917         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
918             printf("radsrv: cannot decrypt password\n");
919             return NULL;
920         }
921         printf("radsrv: password: ");
922         for (i = 0; i < attrvallen; i++)
923             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
924         printf("\n");
925         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
926             printf("radsrv: cannot encrypt password\n");
927             return NULL;
928         }
929     }
930
931     if (tunnelpwdattr) {
932         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
933         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
934         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
935             printf("radsrv: invalid user password length\n");
936             return NULL;
937         }
938         
939         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
940             printf("radsrv: cannot decrypt password\n");
941             return NULL;
942         }
943         printf("radsrv: password: ");
944         for (i = 0; i < attrvallen; i++)
945             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
946         printf("\n");
947         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
948             printf("radsrv: cannot encrypt password\n");
949             return NULL;
950         }
951     }
952
953     rq->buf = buf;
954     rq->from = from;
955     rq->origid = id;
956     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
957     memcpy(rq->origauth, auth, 16);
958     memcpy(auth, newauth, 16);
959     printauth("rq->origauth", rq->origauth);
960     printauth("auth", auth);
961     return to;
962 }
963
964 void *clientrd(void *arg) {
965     struct server *server = (struct server *)arg;
966     struct client *from;
967     int i, left, subleft;
968     unsigned char *buf, *messageauthattr, *subattr, *attr;
969     struct sockaddr_storage fromsa;
970     struct timeval lastconnecttry;
971     char tmp[255];
972     
973     for (;;) {
974     getnext:
975         lastconnecttry = server->lastconnecttry;
976         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
977         if (!buf && server->peer.type == 'T') {
978             tlsconnect(server, &lastconnecttry, "clientrd");
979             continue;
980         }
981     
982         server->connectionok = 1;
983
984         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
985             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
986             continue;
987         }
988         
989         i = buf[1]; /* i is the id */
990
991         pthread_mutex_lock(&server->newrq_mutex);
992         if (!server->requests[i].buf || !server->requests[i].tries) {
993             pthread_mutex_unlock(&server->newrq_mutex);
994             printf("clientrd: no matching request sent with this id, ignoring\n");
995             continue;
996         }
997
998         if (server->requests[i].received) {
999             pthread_mutex_unlock(&server->newrq_mutex);
1000             printf("clientrd: already received, ignoring\n");
1001             continue;
1002         }
1003         
1004         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1005             pthread_mutex_unlock(&server->newrq_mutex);
1006             printf("clientrd: invalid auth, ignoring\n");
1007             continue;
1008         }
1009         
1010         from = server->requests[i].from;
1011
1012
1013         /* messageauthattr present? */
1014         messageauthattr = NULL;
1015         left = RADLEN(buf) - 20;
1016         attr = buf + 20;
1017         while (left > 1) {
1018             left -= attr[RAD_Attr_Length];
1019             if (left < 0) {
1020                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1021                 goto getnext;
1022             }
1023             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1024                 if (attr[RAD_Attr_Length] != 18) {
1025                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1026                     goto getnext;
1027                 }
1028                 memcpy(tmp, buf + 4, 16);
1029                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1030                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1031                     printf("clientrd: message authentication failed\n");
1032                     goto getnext;
1033                 }
1034                 memcpy(buf + 4, tmp, 16);
1035                 printf("clientrd: message auth ok\n");
1036                 messageauthattr = attr;
1037                 break;
1038             }
1039             attr += attr[RAD_Attr_Length];
1040         }
1041
1042         /* handle MS MPPE */
1043         left = RADLEN(buf) - 20;
1044         attr = buf + 20;
1045         while (left > 1) {
1046             left -= attr[RAD_Attr_Length];
1047             if (left < 0) {
1048                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1049                 goto getnext;
1050             }
1051             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1052                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1053                 subleft = attr[RAD_Attr_Length] - 6;
1054                 subattr = attr + 6;
1055                 while (subleft > 1) {
1056                     subleft -= subattr[RAD_Attr_Length];
1057                     if (subleft < 0)
1058                         break;
1059                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1060                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1061                         continue;
1062                     printf("clientrd: Got MS MPPE\n");
1063                     if (subattr[RAD_Attr_Length] < 20)
1064                         continue;
1065
1066                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1067                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1068                         printf("clientrd: failed to decrypt msppe key\n");
1069                         continue;
1070                     }
1071
1072                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1073                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1074                         printf("clientrd: failed to encrypt msppe key\n");
1075                         continue;
1076                     }
1077                 }
1078                 if (subleft < 0) {
1079                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1080                     goto getnext;
1081                 }
1082             }
1083             attr += attr[RAD_Attr_Length];
1084         }
1085
1086         /* once we set received = 1, requests[i] may be reused */
1087         buf[1] = (char)server->requests[i].origid;
1088         memcpy(buf + 4, server->requests[i].origauth, 16);
1089         printauth("origauth/buf+4", buf + 4);
1090         if (messageauthattr) {
1091             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1092                 continue;
1093             printf("clientrd: computed messageauthattr\n");
1094         }
1095
1096         if (from->peer.type == 'U')
1097             fromsa = server->requests[i].fromsa;
1098         server->requests[i].received = 1;
1099         pthread_mutex_unlock(&server->newrq_mutex);
1100
1101         if (!radsign(buf, from->peer.secret)) {
1102             printf("clientrd: failed to sign message\n");
1103             continue;
1104         }
1105         printauth("signedorigauth/buf+4", buf + 4);             
1106         printf("clientrd: giving packet back to where it came from\n");
1107         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1108     }
1109 }
1110
1111 void *clientwr(void *arg) {
1112     struct server *server = (struct server *)arg;
1113     struct request *rq;
1114     pthread_t clientrdth;
1115     int i;
1116     struct timeval now;
1117     
1118     if (server->peer.type == 'U') {
1119         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1120             printf("clientwr: connecttoserver failed\n");
1121             exit(1);
1122         }
1123     } else
1124         tlsconnect(server, NULL, "new client");
1125     
1126     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1127         errx("clientwr: pthread_create failed");
1128
1129     for (;;) {
1130         pthread_mutex_lock(&server->newrq_mutex);
1131         while (!server->newrq) {
1132             printf("clientwr: waiting for signal\n");
1133             pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1134             printf("clientwr: got signal\n");
1135         }
1136         server->newrq = 0;
1137         pthread_mutex_unlock(&server->newrq_mutex);
1138                
1139         for (i = 0; i < MAX_REQUESTS; i++) {
1140             pthread_mutex_lock(&server->newrq_mutex);
1141             while (!server->requests[i].buf && i < MAX_REQUESTS)
1142                 i++;
1143             if (i == MAX_REQUESTS) {
1144                 pthread_mutex_unlock(&server->newrq_mutex);
1145                 break;
1146             }
1147
1148             gettimeofday(&now, NULL);
1149             rq = server->requests + i;
1150
1151             if (rq->received) {
1152                 printf("clientwr: removing received packet from queue\n");
1153                 free(rq->buf);
1154                 /* setting this to NULL means that it can be reused */
1155                 rq->buf = NULL;
1156                 pthread_mutex_unlock(&server->newrq_mutex);
1157                 continue;
1158             }
1159             if (now.tv_sec > rq->expiry.tv_sec) {
1160                 printf("clientwr: removing expired packet from queue\n");
1161                 free(rq->buf);
1162                 /* setting this to NULL means that it can be reused */
1163                 rq->buf = NULL;
1164                 pthread_mutex_unlock(&server->newrq_mutex);
1165                 continue;
1166             }
1167
1168             if (rq->tries)
1169                 continue; // not re-sending (yet)
1170             
1171             rq->tries++;
1172             pthread_mutex_unlock(&server->newrq_mutex);
1173             
1174             clientradput(server, server->requests[i].buf);
1175         }
1176     }
1177     /* should do more work to maintain TLS connections, keepalives etc */
1178 }
1179
1180 void *udpserverwr(void *arg) {
1181     struct replyq *replyq = &udp_server_replyq;
1182     struct reply *reply = replyq->replies;
1183     
1184     pthread_mutex_lock(&replyq->count_mutex);
1185     for (;;) {
1186         while (!replyq->count) {
1187             printf("udp server writer, waiting for signal\n");
1188             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1189             printf("udp server writer, got signal\n");
1190         }
1191         pthread_mutex_unlock(&replyq->count_mutex);
1192         
1193         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1194                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1195             err("sendudp: send failed");
1196         free(reply->buf);
1197         
1198         pthread_mutex_lock(&replyq->count_mutex);
1199         replyq->count--;
1200         memmove(replyq->replies, replyq->replies + 1,
1201                 replyq->count * sizeof(struct reply));
1202     }
1203 }
1204
1205 void *udpserverrd(void *arg) {
1206     struct request rq;
1207     unsigned char *buf;
1208     struct server *to;
1209     struct client *fr;
1210     pthread_t udpserverwrth;
1211     
1212     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
1213         printf("udpserverrd: socket/bind failed\n");
1214         exit(1);
1215     }
1216     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
1217
1218     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1219         errx("pthread_create failed");
1220     
1221     for (;;) {
1222         fr = NULL;
1223         memset(&rq, 0, sizeof(struct request));
1224         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1225         to = radsrv(&rq, buf, fr);
1226         if (!to) {
1227             printf("udpserverrd: ignoring request, no place to send it\n");
1228             continue;
1229         }
1230         sendrq(to, fr, &rq);
1231     }
1232 }
1233
1234 void *tlsserverwr(void *arg) {
1235     int cnt;
1236     unsigned long error;
1237     struct client *client = (struct client *)arg;
1238     struct replyq *replyq;
1239     
1240     pthread_mutex_lock(&client->replycount_mutex);
1241     for (;;) {
1242         replyq = client->replyq;
1243         while (!replyq->count) {
1244             printf("tls server writer, waiting for signal\n");
1245             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1246             printf("tls server writer, got signal\n");
1247         }
1248         pthread_mutex_unlock(&replyq->count_mutex);
1249         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1250         if (cnt > 0)
1251             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1252                    cnt, RADLEN(replyq->replies->buf));
1253         else
1254             while ((error = ERR_get_error()))
1255                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1256         free(replyq->replies->buf);
1257
1258         pthread_mutex_lock(&replyq->count_mutex);
1259         replyq->count--;
1260         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1261     }
1262 }
1263
1264 void *tlsserverrd(void *arg) {
1265     struct request rq;
1266     char unsigned *buf;
1267     unsigned long error;
1268     struct server *to;
1269     int s;
1270     struct client *client = (struct client *)arg;
1271     pthread_t tlsserverwrth;
1272
1273     printf("tlsserverrd starting\n");
1274     if (SSL_accept(client->peer.ssl) <= 0) {
1275         while ((error = ERR_get_error()))
1276             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1277         errx("accept failed, child exiting");
1278     }
1279
1280     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1281         errx("pthread_create failed");
1282     
1283     for (;;) {
1284         buf = radtlsget(client->peer.ssl);
1285         if (!buf) {
1286             printf("tlsserverrd: connection lost\n");
1287             s = SSL_get_fd(client->peer.ssl);
1288             SSL_free(client->peer.ssl);
1289             client->peer.ssl = NULL;
1290             if (s >= 0)
1291                 close(s);
1292             pthread_exit(NULL);
1293         }
1294         printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1295         memset(&rq, 0, sizeof(struct request));
1296         to = radsrv(&rq, buf, client);
1297         if (!to) {
1298             printf("ignoring request, no place to send it\n");
1299             continue;
1300         }
1301         sendrq(to, client, &rq);
1302     }
1303 }
1304
1305 int tlslistener(SSL_CTX *ssl_ctx) {
1306     pthread_t tlsserverth;
1307     int s, snew;
1308     struct sockaddr_storage from;
1309     size_t fromlen = sizeof(from);
1310     struct client *client;
1311
1312     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1313         printf("tlslistener: socket/bind failed\n");
1314         exit(1);
1315     }
1316     
1317     listen(s, 0);
1318     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1319
1320     for (;;) {
1321         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1322         if (snew < 0)
1323             errx("accept failed");
1324         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1325
1326         client = find_client('T', (struct sockaddr *)&from, NULL);
1327         if (!client) {
1328             printf("ignoring request, not a known TLS client\n");
1329             close(snew);
1330             continue;
1331         }
1332
1333         if (client->peer.ssl) {
1334             printf("Ignoring incoming connection, already have one from this client\n");
1335             close(snew);
1336             continue;
1337         }
1338         client->peer.ssl = SSL_new(ssl_ctx);
1339         SSL_set_fd(client->peer.ssl, snew);
1340         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1341             errx("pthread_create failed");
1342     }
1343     return 0;
1344 }
1345
1346 char *parsehostport(char *s, struct peer *peer) {
1347     char *p, *field;
1348     int ipv6 = 0;
1349
1350     p = s;
1351     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1352     if (*p == '[') {
1353         p++;
1354         field = p;
1355         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1356         if (*p != ']') {
1357             printf("no ] matching initial [\n");
1358             exit(1);
1359         }
1360         ipv6 = 1;
1361     } else {
1362         field = p;
1363         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1364     }
1365     if (field == p) {
1366         printf("missing host/address\n");
1367         exit(1);
1368     }
1369     peer->host = malloc(p - field + 1);
1370     if (!peer->host)
1371         errx("malloc failed");
1372     memcpy(peer->host, field, p - field);
1373     peer->host[p - field] = '\0';
1374     if (ipv6) {
1375         p++;
1376         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1377             printf("unexpected character after ]\n");
1378             exit(1);
1379         }
1380     }
1381     if (*p == ':') {
1382             /* port number or service name is specified */;
1383             field = p++;
1384             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1385             if (field == p) {
1386                 printf("syntax error, : but no following port\n");
1387                 exit(1);
1388             }
1389             peer->port = malloc(p - field + 1);
1390             if (!peer->port)
1391                 errx("malloc failed");
1392             memcpy(peer->port, field, p - field);
1393             peer->port[p - field] = '\0';
1394     } else
1395         peer->port = NULL;
1396     return p;
1397 }
1398
1399 // * is default, else longest match ... ";" used for separator
1400 char *parserealmlist(char *s, struct server *server) {
1401     char *p;
1402     int i, n, l;
1403
1404     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1405         if (*p == ';')
1406             n++;
1407     l = p - s;
1408     if (!l) {
1409         server->realms = NULL;
1410         return p;
1411     }
1412     server->realmdata = malloc(l + 1);
1413     if (!server->realmdata)
1414         errx("malloc failed");
1415     memcpy(server->realmdata, s, l);
1416     server->realmdata[l] = '\0';
1417     server->realms = malloc((1+n) * sizeof(char *));
1418     if (!server->realms)
1419         errx("malloc failed");
1420     server->realms[0] = server->realmdata;
1421     for (n = 1, i = 0; i < l; i++)
1422         if (server->realmdata[i] == ';') {
1423             server->realmdata[i] = '\0';
1424             server->realms[n++] = server->realmdata + i + 1;
1425         }       
1426     server->realms[n] = NULL;
1427     return p;
1428 }
1429
1430 /* exactly one argument must be non-NULL */
1431 void getconfig(const char *serverfile, const char *clientfile) {
1432     FILE *f;
1433     char line[1024];
1434     char *p, *field, **r;
1435     struct client *client;
1436     struct server *server;
1437     struct peer *peer;
1438     int *count;
1439     
1440     if (serverfile) {
1441         printf("opening file %s for reading\n", serverfile);
1442         f = fopen(serverfile, "r");
1443         if (!f)
1444             errx("getconfig failed to open %s for reading", serverfile);
1445         count = &server_count;
1446     } else {
1447         printf("opening file %s for reading\n", clientfile);
1448         f = fopen(clientfile, "r");
1449         if (!f)
1450             errx("getconfig failed to open %s for reading", clientfile);
1451         udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
1452         if (!udp_server_replyq.replies)
1453             errx("malloc failed");
1454         udp_server_replyq.size = 4 * MAX_REQUESTS;
1455         udp_server_replyq.count = 0;
1456         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1457         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1458         count = &client_count;
1459     }    
1460     
1461     *count = 0;
1462     while (fgets(line, 1024, f) && *count < MAX_PEERS) {
1463         if (serverfile) {
1464             server = &servers[*count];
1465             memset(server, 0, sizeof(struct server));
1466             peer = &server->peer;
1467         } else {
1468             client = &clients[*count];
1469             memset(client, 0, sizeof(struct client));
1470             peer = &client->peer;
1471         }
1472         for (p = line; *p == ' ' || *p == '\t'; p++);
1473         if (*p == '#' || *p == '\n')
1474             continue;
1475         if (*p != 'U' && *p != 'T') {
1476             printf("server type must be U or T, got %c\n", *p);
1477             exit(1);
1478         }
1479         peer->type = *p;
1480         for (p++; *p == ' ' || *p == '\t'; p++);
1481         p = parsehostport(p, peer);
1482         if (!peer->port)
1483             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
1484         for (; *p == ' ' || *p == '\t'; p++);
1485         if (serverfile) {
1486             p = parserealmlist(p, server);
1487             if (!server->realms) {
1488                 printf("realm list must be specified\n");
1489                 exit(1);
1490             }
1491             for (; *p == ' ' || *p == '\t'; p++);
1492         }
1493         field = p;
1494         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1495         if (field == p) {
1496             /* no secret set and end of line, line is complete if TLS */
1497             if (peer->type == 'U') {
1498                 printf("secret must be specified for UDP\n");
1499                 exit(1);
1500             }
1501             peer->secret = DEFAULT_TLS_SECRET;
1502         } else {
1503             peer->secret = malloc(p - field + 1);
1504             if (!peer->secret)
1505                 errx("malloc failed");
1506             memcpy(peer->secret, field, p - field);
1507             peer->secret[p - field] = '\0';
1508             /* check that rest of line only white space */
1509             for (; *p == ' ' || *p == '\t'; p++);
1510             if (*p && *p != '\n') {
1511                 printf("max 4 fields per line, found a 5th\n");
1512                 exit(1);
1513             }
1514         }
1515
1516         if ((serverfile && !resolvepeer(&server->peer)) ||
1517             (clientfile && !resolvepeer(&client->peer))) {
1518             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1519             exit(1);
1520         }
1521
1522         if (serverfile) {
1523             pthread_mutex_init(&server->lock, NULL);
1524             server->sock = -1;
1525             server->requests = malloc(MAX_REQUESTS * sizeof(struct request));
1526             if (!server->requests)
1527                 errx("malloc failed");
1528             memset(server->requests, 0, MAX_REQUESTS * sizeof(struct request));
1529             server->newrq = 0;
1530             pthread_mutex_init(&server->newrq_mutex, NULL);
1531             pthread_cond_init(&server->newrq_cond, NULL);
1532         } else {
1533             if (peer->type == 'U')
1534                 client->replyq = &udp_server_replyq;
1535             else {
1536                 client->replyq = malloc(sizeof(struct replyq));
1537                 if (!client->replyq)
1538                     errx("malloc failed");
1539                 client->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
1540                 if (!client->replyq->replies)
1541                     errx("malloc failed");
1542                 client->replyq->size = MAX_REQUESTS;
1543                 client->replyq->count = 0;
1544                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1545                 pthread_cond_init(&client->replyq->count_cond, NULL);
1546             }
1547         }
1548         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1549         if (serverfile) {
1550             printf("    with realms:");
1551             for (r = server->realms; *r; r++)
1552                 printf(" %s", *r);
1553             printf("\n");
1554         }
1555         (*count)++;
1556     }
1557     fclose(f);
1558 }
1559
1560 void parseargs(int argc, char **argv) {
1561     int c;
1562
1563     while ((c = getopt(argc, argv, "p:")) != -1) {
1564         switch (c) {
1565         case 'p':
1566             udp_server_port = optarg;
1567             break;
1568         default:
1569             goto usage;
1570         }
1571     }
1572
1573     return;
1574
1575  usage:
1576     printf("radsecproxy [ -p UDP-port ]\n");
1577     exit(1);
1578 }
1579                
1580 int main(int argc, char **argv) {
1581     SSL_CTX *ssl_ctx_srv;
1582     unsigned long error;
1583     pthread_t udpserverth;
1584     pthread_attr_t joinable;
1585     int i;
1586     
1587     parseargs(argc, argv);
1588     getconfig("servers.conf", NULL);
1589     getconfig(NULL, "clients.conf");
1590     
1591     ssl_locks_setup();
1592
1593     pthread_attr_init(&joinable);
1594     pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1595    
1596     /* listen on UDP if at least one UDP client */
1597     
1598     for (i = 0; i < client_count; i++)
1599         if (clients[i].peer.type == 'U') {
1600             if (pthread_create(&udpserverth, &joinable, udpserverrd, NULL))
1601                 errx("pthread_create failed");
1602             break;
1603         }
1604     
1605     /* SSL setup */
1606     SSL_load_error_strings();
1607     SSL_library_init();
1608
1609     while (!RAND_status()) {
1610         time_t t = time(NULL);
1611         pid_t pid = getpid();
1612         RAND_seed((unsigned char *)&t, sizeof(time_t));
1613         RAND_seed((unsigned char *)&pid, sizeof(pid));
1614     }
1615     
1616     /* initialise client part and start clients */
1617     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1618     if (!ssl_ctx_cl)
1619         errx("no ssl ctx");
1620     
1621     for (i = 0; i < server_count; i++) {
1622         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1623             errx("pthread_create failed");
1624     }
1625
1626     for (i = 0; i < client_count; i++)
1627         if (clients[i].peer.type == 'T')
1628             break;
1629
1630     if (i == client_count) {
1631         printf("No TLS clients defined, not starting TLS listener\n");
1632         /* just hang around doing nothing, anything to do here? */
1633         for (;;)
1634             sleep(1000);
1635     }
1636     
1637     /* setting up server/daemon part */
1638     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1639     if (!ssl_ctx_srv)
1640         errx("no ssl ctx");
1641     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1642         while ((error = ERR_get_error()))
1643             err("SSL: %s", ERR_error_string(error, NULL));
1644         errx("Failed to load certificate");
1645     }
1646     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1647         while ((error = ERR_get_error()))
1648             err("SSL: %s", ERR_error_string(error, NULL));
1649         errx("Failed to load private key");
1650     }
1651
1652     return tlslistener(ssl_ctx_srv);
1653 }