bcb430141532e5a2de99393643ad8c8232695807
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * make our server ignore client retrans and do its own instead?
15  * accounting
16  * radius keep alives (server status)
17  * tls certificate validation, see below urls
18  * clean tls shutdown, see http://www.linuxjournal.com/article/4822
19  *     and http://www.linuxjournal.com/article/5487
20  *     SSL_shutdown() and shutdown()
21  *     If shutdown() we may not need REUSEADDR
22  * when tls client goes away, ensure that all related threads and state
23  *          are removed
24  * setsockopt(keepalive...), check if openssl has some keepalive feature
25 */
26
27 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
28  *              rd is responsible for init and launching wr
29  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
30  *          each tlsserverrd launches tlsserverwr
31  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
32  *          for init and launching rd
33  *
34  * serverrd will receive a request, processes it and puts it in the requestq of
35  *          the appropriate clientwr
36  * clientwr monitors its requestq and sends requests
37  * clientrd looks for responses, processes them and puts them in the replyq of
38  *          the peer the request came from
39  * serverwr monitors its reply and sends replies
40  *
41  * In addition to the main thread, we have:
42  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
43  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
44  * For each TLS peer connecting to us there will be 2 more TLS threads
45  *       This is only for connected peers
46  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
47  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
48 */
49
50 #include <netdb.h>
51 #include <unistd.h>
52 #include <sys/time.h>
53 #include <pthread.h>
54 #include <openssl/ssl.h>
55 #include <openssl/rand.h>
56 #include <openssl/err.h>
57 #include <openssl/md5.h>
58 #include <openssl/hmac.h>
59 #include "radsecproxy.h"
60
61 static struct client clients[MAX_PEERS];
62 static struct server servers[MAX_PEERS];
63
64 static int client_count = 0;
65 static int server_count = 0;
66
67 static struct replyq udp_server_replyq;
68 static int udp_server_sock = -1;
69 static char *udp_server_port = DEFAULT_UDP_PORT;
70 static pthread_mutex_t *ssl_locks;
71 static long *ssl_lock_count;
72 static SSL_CTX *ssl_ctx_cl;
73 extern int optind;
74 extern char *optarg;
75
76 /* callbacks for making OpenSSL thread safe */
77 unsigned long ssl_thread_id() {
78         return (unsigned long)pthread_self();
79 };
80
81 void ssl_locking_callback(int mode, int type, const char *file, int line) {
82     if (mode & CRYPTO_LOCK) {
83         pthread_mutex_lock(&ssl_locks[type]);
84         ssl_lock_count[type]++;
85     } else
86         pthread_mutex_unlock(&ssl_locks[type]);
87 }
88
89 void ssl_locks_setup() {
90     int i;
91
92     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
93     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
94     for (i = 0; i < CRYPTO_num_locks(); i++) {
95         ssl_lock_count[i] = 0;
96         pthread_mutex_init(&ssl_locks[i], NULL);
97     }
98
99     CRYPTO_set_id_callback(ssl_thread_id);
100     CRYPTO_set_locking_callback(ssl_locking_callback);
101 }
102
103 void printauth(char *s, unsigned char *t) {
104     int i;
105     printf("%s:", s);
106     for (i = 0; i < 16; i++)
107             printf("%02x ", t[i]);
108     printf("\n");
109 }
110
111 int resolvepeer(struct peer *peer) {
112     struct addrinfo hints, *addrinfo;
113     
114     memset(&hints, 0, sizeof(hints));
115     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
116     hints.ai_family = AF_UNSPEC;
117     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
118         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
119         return 0;
120     }
121
122     if (peer->addrinfo)
123         freeaddrinfo(peer->addrinfo);
124     peer->addrinfo = addrinfo;
125     return 1;
126 }         
127
128 int connecttoserver(struct addrinfo *addrinfo) {
129     int s;
130     struct addrinfo *res;
131     
132     for (res = addrinfo; res; res = res->ai_next) {
133         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
134         if (s < 0) {
135             err("connecttoserver: socket failed");
136             continue;
137         }
138         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
139             break;
140         err("connecttoserver: connect failed");
141         close(s);
142         s = -1;
143     }
144     return s;
145 }         
146
147 /* returns the client with matching address, or NULL */
148 /* if client argument is not NULL, we only check that one client */
149 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
150     struct sockaddr_in6 *sa6;
151     struct in_addr *a4 = NULL;
152     struct client *c;
153     int i;
154     struct addrinfo *res;
155
156     if (addr->sa_family == AF_INET6) {
157         sa6 = (struct sockaddr_in6 *)addr;
158         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
159             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
160     } else
161         a4 = &((struct sockaddr_in *)addr)->sin_addr;
162
163     c = (client ? client : clients);
164     for (i = 0; i < client_count; i++) {
165         if (c->peer.type == type)
166             for (res = c->peer.addrinfo; res; res = res->ai_next)
167                 if ((a4 && res->ai_family == AF_INET &&
168                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
169                     (res->ai_family == AF_INET6 &&
170                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
171                     return c;
172         if (client)
173             break;
174         c++;
175     }
176     return NULL;
177 }
178
179 /* returns the server with matching address, or NULL */
180 /* if server argument is not NULL, we only check that one server */
181 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
182     struct sockaddr_in6 *sa6;
183     struct in_addr *a4 = NULL;
184     struct server *s;
185     int i;
186     struct addrinfo *res;
187
188     if (addr->sa_family == AF_INET6) {
189         sa6 = (struct sockaddr_in6 *)addr;
190         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
191             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
192     } else
193         a4 = &((struct sockaddr_in *)addr)->sin_addr;
194
195     s = (server ? server : servers);
196     for (i = 0; i < server_count; i++) {
197         if (s->peer.type == type)
198             for (res = s->peer.addrinfo; res; res = res->ai_next)
199                 if ((a4 && res->ai_family == AF_INET &&
200                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
201                     (res->ai_family == AF_INET6 &&
202                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
203                     return s;
204         if (server)
205             break;
206         s++;
207     }
208     return NULL;
209 }
210
211 /* exactly one of client and server must be non-NULL */
212 /* if *peer == NULL we return who we received from, else require it to be from peer */
213 /* return from in sa if not NULL */
214 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
215     int cnt, len;
216     void *f;
217     unsigned char buf[65536], *rad;
218     struct sockaddr_storage from;
219     socklen_t fromlen = sizeof(from);
220
221     for (;;) {
222         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
223         if (cnt == -1) {
224             err("radudpget: recv failed");
225             continue;
226         }
227         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
228
229         if (cnt < 20) {
230             printf("radudpget: packet too small\n");
231             continue;
232         }
233     
234         len = RADLEN(buf);
235
236         if (cnt < len) {
237             printf("radudpget: packet smaller than length field in radius header\n");
238             continue;
239         }
240         if (cnt > len)
241             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
242
243         f = (client
244              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
245              : (void *)find_server('U', (struct sockaddr *)&from, *server));
246         if (!f) {
247             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
248             continue;
249         }
250
251         rad = malloc(len);
252         if (rad)
253             break;
254         err("radudpget: malloc failed");
255     }
256     memcpy(rad, buf, len);
257     if (client)
258         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
259     else
260         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
261     if (sa)
262         *sa = from;
263     return rad;
264 }
265
266 void tlsconnect(struct server *server, struct timeval *when, char *text) {
267     struct timeval now;
268     time_t elapsed;
269     unsigned long error;
270
271     printf("tlsconnect called from %s\n", text);
272     pthread_mutex_lock(&server->lock);
273     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
274         /* already reconnected, nothing to do */
275         printf("tlsconnect(%s): seems already reconnected\n", text);
276         pthread_mutex_unlock(&server->lock);
277         return;
278     }
279
280     printf("tlsconnect %s\n", text);
281
282     for (;;) {
283         gettimeofday(&now, NULL);
284         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
285         if (server->connectionok) {
286             server->connectionok = 0;
287             sleep(10);
288         } else if (elapsed < 5)
289             sleep(10);
290         else if (elapsed < 600)
291             sleep(elapsed * 2);
292         else if (elapsed < 10000)
293                 sleep(900);
294         else
295             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
296         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
297         if (server->sock >= 0)
298             close(server->sock);
299         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
300             continue;
301         SSL_free(server->peer.ssl);
302         server->peer.ssl = SSL_new(ssl_ctx_cl);
303         SSL_set_fd(server->peer.ssl, server->sock);
304         if (SSL_connect(server->peer.ssl) > 0)
305             break;
306         while ((error = ERR_get_error()))
307             err("tlsconnect: TLS: %s", ERR_error_string(error, NULL));
308     }
309     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
310     gettimeofday(&server->lastconnecttry, NULL);
311     pthread_mutex_unlock(&server->lock);
312 }
313
314 unsigned char *radtlsget(SSL *ssl) {
315     int cnt, total, len;
316     unsigned char buf[4], *rad;
317
318     for (;;) {
319         for (total = 0; total < 4; total += cnt) {
320             cnt = SSL_read(ssl, buf + total, 4 - total);
321             if (cnt <= 0) {
322                 printf("radtlsget: connection lost\n");
323                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
324                     //remote end sent close_notify, send one back
325                     SSL_shutdown(ssl);
326                 }
327                 return NULL;
328             }
329         }
330
331         len = RADLEN(buf);
332         rad = malloc(len);
333         if (!rad) {
334             err("radtlsget: malloc failed");
335             continue;
336         }
337         memcpy(rad, buf, 4);
338
339         for (; total < len; total += cnt) {
340             cnt = SSL_read(ssl, rad + total, len - total);
341             if (cnt <= 0) {
342                 printf("radtlsget: connection lost\n");
343                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
344                     //remote end sent close_notify, send one back
345                     SSL_shutdown(ssl);
346                 }
347                 free(rad);
348                 return NULL;
349             }
350         }
351     
352         if (total >= 20)
353             break;
354         
355         free(rad);
356         printf("radtlsget: packet smaller than minimum radius size\n");
357     }
358     
359     printf("radtlsget: got %d bytes\n", total);
360     return rad;
361 }
362
363 int clientradput(struct server *server, unsigned char *rad) {
364     int cnt;
365     size_t len;
366     unsigned long error;
367     struct timeval lastconnecttry;
368     
369     len = RADLEN(rad);
370     if (server->peer.type == 'U') {
371         if (send(server->sock, rad, len, 0) >= 0) {
372             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
373             return 1;
374         }
375         err("clientradput: send failed");
376         return 0;
377     }
378
379     lastconnecttry = server->lastconnecttry;
380     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
381         while ((error = ERR_get_error()))
382             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
383         tlsconnect(server, &lastconnecttry, "clientradput");
384         lastconnecttry = server->lastconnecttry;
385     }
386
387     server->connectionok = 1;
388     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
389            cnt, len, server->peer.host);
390     return 1;
391 }
392
393 int radsign(unsigned char *rad, unsigned char *sec) {
394     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
395     static unsigned char first = 1;
396     static EVP_MD_CTX mdctx;
397     unsigned int md_len;
398     int result;
399     
400     pthread_mutex_lock(&lock);
401     if (first) {
402         EVP_MD_CTX_init(&mdctx);
403         first = 0;
404     }
405
406     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
407         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
408         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
409         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
410         md_len == 16);
411     pthread_mutex_unlock(&lock);
412     return result;
413 }
414
415 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
416     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
417     static unsigned char first = 1;
418     static EVP_MD_CTX mdctx;
419     unsigned char hash[EVP_MAX_MD_SIZE];
420     unsigned int len;
421     int result;
422     
423     pthread_mutex_lock(&lock);
424     if (first) {
425         EVP_MD_CTX_init(&mdctx);
426         first = 0;
427     }
428
429     len = RADLEN(rad);
430     
431     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
432               EVP_DigestUpdate(&mdctx, rad, 4) &&
433               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
434               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
435               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
436               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
437               len == 16 &&
438               !memcmp(hash, rad + 4, 16));
439     pthread_mutex_unlock(&lock);
440     return result;
441 }
442               
443 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
444     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
445     static unsigned char first = 1;
446     static HMAC_CTX hmacctx;
447     unsigned int md_len;
448     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
449     
450     pthread_mutex_lock(&lock);
451     if (first) {
452         HMAC_CTX_init(&hmacctx);
453         first = 0;
454     }
455
456     memcpy(auth, authattr, 16);
457     memset(authattr, 0, 16);
458     md_len = 0;
459     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
460     HMAC_Update(&hmacctx, rad, RADLEN(rad));
461     HMAC_Final(&hmacctx, hash, &md_len);
462     memcpy(authattr, auth, 16);
463     if (md_len != 16) {
464         printf("message auth computation failed\n");
465         pthread_mutex_unlock(&lock);
466         return 0;
467     }
468
469     if (memcmp(auth, hash, 16)) {
470         printf("message authenticator, wrong value\n");
471         pthread_mutex_unlock(&lock);
472         return 0;
473     }   
474         
475     pthread_mutex_unlock(&lock);
476     return 1;
477 }
478
479 int createmessageauth(char *rad, char *authattrval, char *secret) {
480     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
481     static unsigned char first = 1;
482     static HMAC_CTX hmacctx;
483     unsigned int md_len;
484
485     if (!authattrval)
486         return 1;
487     
488     pthread_mutex_lock(&lock);
489     if (first) {
490         HMAC_CTX_init(&hmacctx);
491         first = 0;
492     }
493
494     memset(authattrval, 0, 16);
495     md_len = 0;
496     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
497     HMAC_Update(&hmacctx, rad, RADLEN(rad));
498     HMAC_Final(&hmacctx, authattrval, &md_len);
499     if (md_len != 16) {
500         printf("message auth computation failed\n");
501         pthread_mutex_unlock(&lock);
502         return 0;
503     }
504
505     pthread_mutex_unlock(&lock);
506     return 1;
507 }
508
509 void sendrq(struct server *to, struct client *from, struct request *rq) {
510     int i;
511     
512     pthread_mutex_lock(&to->newrq_mutex);
513     /* might simplify if only try nextid, might be ok */
514     for (i = to->nextid; i < MAX_REQUESTS; i++)
515         if (!to->requests[i].buf)
516             break;
517     if (i == MAX_REQUESTS) {
518         for (i = 0; i < to->nextid; i++)
519             if (!to->requests[i].buf)
520                 break;
521         if (i == to->nextid) {
522             printf("No room in queue, dropping request\n");
523             pthread_mutex_unlock(&to->newrq_mutex);
524             return;
525         }
526     }
527     
528     to->nextid = i + 1;
529     rq->buf[1] = (char)i;
530     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
531     
532     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
533         return;
534
535     gettimeofday(&rq->expiry, NULL);
536     rq->expiry.tv_sec += 30;
537     to->requests[i] = *rq;
538
539     if (!to->newrq) {
540         to->newrq = 1;
541         printf("signalling client writer\n");
542         pthread_cond_signal(&to->newrq_cond);
543     }
544     pthread_mutex_unlock(&to->newrq_mutex);
545 }
546
547 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
548     struct replyq *replyq = to->replyq;
549     
550     pthread_mutex_lock(&replyq->count_mutex);
551     if (replyq->count == replyq->size) {
552         printf("No room in queue, dropping request\n");
553         pthread_mutex_unlock(&replyq->count_mutex);
554         return;
555     }
556
557     replyq->replies[replyq->count].buf = buf;
558     if (tosa)
559         replyq->replies[replyq->count].tosa = *tosa;
560     replyq->count++;
561
562     if (replyq->count == 1) {
563         printf("signalling client writer\n");
564         pthread_cond_signal(&replyq->count_cond);
565     }
566     pthread_mutex_unlock(&replyq->count_mutex);
567 }
568
569 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
570     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
571     static unsigned char first = 1;
572     static EVP_MD_CTX mdctx;
573     unsigned char hash[EVP_MAX_MD_SIZE], *input;
574     unsigned int md_len;
575     uint8_t i, offset = 0, out[128];
576     
577     pthread_mutex_lock(&lock);
578     if (first) {
579         EVP_MD_CTX_init(&mdctx);
580         first = 0;
581     }
582
583     input = auth;
584     for (;;) {
585         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
586             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
587             !EVP_DigestUpdate(&mdctx, input, 16) ||
588             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
589             md_len != 16) {
590             pthread_mutex_unlock(&lock);
591             return 0;
592         }
593         for (i = 0; i < 16; i++)
594             out[offset + i] = hash[i] ^ in[offset + i];
595         input = out + offset - 16;
596         offset += 16;
597         if (offset == len)
598             break;
599     }
600     memcpy(in, out, len);
601     pthread_mutex_unlock(&lock);
602     return 1;
603 }
604
605 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
606     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
607     static unsigned char first = 1;
608     static EVP_MD_CTX mdctx;
609     unsigned char hash[EVP_MAX_MD_SIZE], *input;
610     unsigned int md_len;
611     uint8_t i, offset = 0, out[128];
612     
613     pthread_mutex_lock(&lock);
614     if (first) {
615         EVP_MD_CTX_init(&mdctx);
616         first = 0;
617     }
618
619     input = auth;
620     for (;;) {
621         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
622             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
623             !EVP_DigestUpdate(&mdctx, input, 16) ||
624             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
625             md_len != 16) {
626             pthread_mutex_unlock(&lock);
627             return 0;
628         }
629         for (i = 0; i < 16; i++)
630             out[offset + i] = hash[i] ^ in[offset + i];
631         input = in + offset;
632         offset += 16;
633         if (offset == len)
634             break;
635     }
636     memcpy(in, out, len);
637     pthread_mutex_unlock(&lock);
638     return 1;
639 }
640
641 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
642     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
643     static unsigned char first = 1;
644     static EVP_MD_CTX mdctx;
645     unsigned char hash[EVP_MAX_MD_SIZE];
646     unsigned int md_len;
647     uint8_t i, offset;
648     
649     pthread_mutex_lock(&lock);
650     if (first) {
651         EVP_MD_CTX_init(&mdctx);
652         first = 0;
653     }
654
655 #if 0    
656     printf("msppencrypt auth in: ");
657     for (i = 0; i < 16; i++)
658         printf("%02x ", auth[i]);
659     printf("\n");
660     
661     printf("msppencrypt salt in: ");
662     for (i = 0; i < 2; i++)
663         printf("%02x ", salt[i]);
664     printf("\n");
665     
666     printf("msppencrypt in: ");
667     for (i = 0; i < len; i++)
668         printf("%02x ", text[i]);
669     printf("\n");
670 #endif
671     
672     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
673         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
674         !EVP_DigestUpdate(&mdctx, auth, 16) ||
675         !EVP_DigestUpdate(&mdctx, salt, 2) ||
676         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
677         pthread_mutex_unlock(&lock);
678         return 0;
679     }
680
681 #if 0    
682     printf("msppencrypt hash: ");
683     for (i = 0; i < 16; i++)
684         printf("%02x ", hash[i]);
685     printf("\n");
686 #endif
687     
688     for (i = 0; i < 16; i++)
689         text[i] ^= hash[i];
690     
691     for (offset = 16; offset < len; offset += 16) {
692 #if 0   
693         printf("text + offset - 16 c(%d): ", offset / 16);
694         for (i = 0; i < 16; i++)
695             printf("%02x ", (text + offset - 16)[i]);
696         printf("\n");
697 #endif
698         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
699             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
700             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
701             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
702             md_len != 16) {
703             pthread_mutex_unlock(&lock);
704             return 0;
705         }
706 #if 0   
707         printf("msppencrypt hash: ");
708         for (i = 0; i < 16; i++)
709             printf("%02x ", hash[i]);
710         printf("\n");
711 #endif    
712         
713         for (i = 0; i < 16; i++)
714             text[offset + i] ^= hash[i];
715     }
716     
717 #if 0
718     printf("msppencrypt out: ");
719     for (i = 0; i < len; i++)
720         printf("%02x ", text[i]);
721     printf("\n");
722 #endif
723
724     pthread_mutex_unlock(&lock);
725     return 1;
726 }
727
728 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
729     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
730     static unsigned char first = 1;
731     static EVP_MD_CTX mdctx;
732     unsigned char hash[EVP_MAX_MD_SIZE];
733     unsigned int md_len;
734     uint8_t i, offset;
735     char plain[255];
736     
737     pthread_mutex_lock(&lock);
738     if (first) {
739         EVP_MD_CTX_init(&mdctx);
740         first = 0;
741     }
742
743 #if 0    
744     printf("msppdecrypt auth in: ");
745     for (i = 0; i < 16; i++)
746         printf("%02x ", auth[i]);
747     printf("\n");
748     
749     printf("msppedecrypt salt in: ");
750     for (i = 0; i < 2; i++)
751         printf("%02x ", salt[i]);
752     printf("\n");
753     
754     printf("msppedecrypt in: ");
755     for (i = 0; i < len; i++)
756         printf("%02x ", text[i]);
757     printf("\n");
758 #endif
759     
760     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
761         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
762         !EVP_DigestUpdate(&mdctx, auth, 16) ||
763         !EVP_DigestUpdate(&mdctx, salt, 2) ||
764         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
765         pthread_mutex_unlock(&lock);
766         return 0;
767     }
768
769 #if 0    
770     printf("msppedecrypt hash: ");
771     for (i = 0; i < 16; i++)
772         printf("%02x ", hash[i]);
773     printf("\n");
774 #endif
775     
776     for (i = 0; i < 16; i++)
777         plain[i] = text[i] ^ hash[i];
778     
779     for (offset = 16; offset < len; offset += 16) {
780 #if 0   
781         printf("text + offset - 16 c(%d): ", offset / 16);
782         for (i = 0; i < 16; i++)
783             printf("%02x ", (text + offset - 16)[i]);
784         printf("\n");
785 #endif
786         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
787             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
788             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
789             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
790             md_len != 16) {
791             pthread_mutex_unlock(&lock);
792             return 0;
793         }
794 #if 0   
795     printf("msppedecrypt hash: ");
796     for (i = 0; i < 16; i++)
797         printf("%02x ", hash[i]);
798     printf("\n");
799 #endif    
800
801     for (i = 0; i < 16; i++)
802         plain[offset + i] = text[offset + i] ^ hash[i];
803     }
804
805     memcpy(text, plain, len);
806 #if 0
807     printf("msppedecrypt out: ");
808     for (i = 0; i < len; i++)
809         printf("%02x ", text[i]);
810     printf("\n");
811 #endif
812
813     pthread_mutex_unlock(&lock);
814     return 1;
815 }
816
817 struct server *id2server(char *id, uint8_t len) {
818     int i;
819     char **realm, *idrealm;
820
821     idrealm = strchr(id, '@');
822     if (idrealm) {
823         idrealm++;
824         len -= idrealm - id;
825     } else {
826         idrealm = "-";
827         len = 1;
828     }
829     for (i = 0; i < server_count; i++) {
830         for (realm = servers[i].realms; *realm; realm++) {
831             if ((strlen(*realm) == 1 && **realm == '*') ||
832                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
833                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
834                 return servers + i;
835             }
836         }
837     }
838     return NULL;
839 }
840
841 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
842     uint8_t code, id, *auth, *attr, attrvallen;
843     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
844     int i;
845     uint16_t len;
846     int left;
847     struct server *to;
848     unsigned char newauth[16];
849     
850     code = *(uint8_t *)buf;
851     id = *(uint8_t *)(buf + 1);
852     len = RADLEN(buf);
853     auth = (uint8_t *)(buf + 4);
854
855     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
856     
857     if (code != RAD_Access_Request) {
858         printf("radsrv: server currently accepts only access-requests, ignoring\n");
859         return NULL;
860     }
861
862     left = len - 20;
863     attr = buf + 20;
864     
865     while (left > 1) {
866         left -= attr[RAD_Attr_Length];
867         if (left < 0) {
868             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
869             return NULL;
870         }
871         switch (attr[RAD_Attr_Type]) {
872         case RAD_Attr_User_Name:
873             usernameattr = attr;
874             break;
875         case RAD_Attr_User_Password:
876             userpwdattr = attr;
877             break;
878         case RAD_Attr_Tunnel_Password:
879             tunnelpwdattr = attr;
880             break;
881         case RAD_Attr_Message_Authenticator:
882             messageauthattr = attr;
883             break;
884         }
885         attr += attr[RAD_Attr_Length];
886     }
887     if (left)
888         printf("radsrv: malformed packet? remaining byte after last attribute\n");
889
890     if (usernameattr) {
891         printf("radsrv: Username: ");
892         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
893             printf("%c", usernameattr[RAD_Attr_Value + i]);
894         printf("\n");
895     }
896
897     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
898     if (!to) {
899         printf("radsrv: ignoring request, don't know where to send it\n");
900         return NULL;
901     }
902     
903     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
904                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
905         printf("radsrv: message authentication failed\n");
906         return NULL;
907     }
908
909     if (!RAND_bytes(newauth, 16)) {
910         printf("radsrv: failed to generate random auth\n");
911         return NULL;
912     }
913
914     printauth("auth", auth);
915     printauth("newauth", newauth);
916     
917     if (userpwdattr) {
918         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
919         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
920         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
921             printf("radsrv: invalid user password length\n");
922             return NULL;
923         }
924         
925         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
926             printf("radsrv: cannot decrypt password\n");
927             return NULL;
928         }
929         printf("radsrv: password: ");
930         for (i = 0; i < attrvallen; i++)
931             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
932         printf("\n");
933         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
934             printf("radsrv: cannot encrypt password\n");
935             return NULL;
936         }
937     }
938
939     if (tunnelpwdattr) {
940         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
941         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
942         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
943             printf("radsrv: invalid user password length\n");
944             return NULL;
945         }
946         
947         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
948             printf("radsrv: cannot decrypt password\n");
949             return NULL;
950         }
951         printf("radsrv: password: ");
952         for (i = 0; i < attrvallen; i++)
953             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
954         printf("\n");
955         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
956             printf("radsrv: cannot encrypt password\n");
957             return NULL;
958         }
959     }
960
961     rq->buf = buf;
962     rq->from = from;
963     rq->origid = id;
964     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
965     memcpy(rq->origauth, auth, 16);
966     memcpy(auth, newauth, 16);
967     printauth("rq->origauth", rq->origauth);
968     printauth("auth", auth);
969     return to;
970 }
971
972 void *clientrd(void *arg) {
973     struct server *server = (struct server *)arg;
974     struct client *from;
975     int i, left, subleft;
976     unsigned char *buf, *messageauthattr, *subattr, *attr;
977     struct sockaddr_storage fromsa;
978     struct timeval lastconnecttry;
979     char tmp[255];
980     
981     for (;;) {
982     getnext:
983         lastconnecttry = server->lastconnecttry;
984         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
985         if (!buf && server->peer.type == 'T') {
986             tlsconnect(server, &lastconnecttry, "clientrd");
987             continue;
988         }
989     
990         server->connectionok = 1;
991
992         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
993             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
994             continue;
995         }
996         
997         i = buf[1]; /* i is the id */
998
999         pthread_mutex_lock(&server->newrq_mutex);
1000         if (!server->requests[i].buf || !server->requests[i].tries) {
1001             pthread_mutex_unlock(&server->newrq_mutex);
1002             printf("clientrd: no matching request sent with this id, ignoring\n");
1003             continue;
1004         }
1005
1006         if (server->requests[i].received) {
1007             pthread_mutex_unlock(&server->newrq_mutex);
1008             printf("clientrd: already received, ignoring\n");
1009             continue;
1010         }
1011         
1012         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1013             pthread_mutex_unlock(&server->newrq_mutex);
1014             printf("clientrd: invalid auth, ignoring\n");
1015             continue;
1016         }
1017         
1018         from = server->requests[i].from;
1019
1020
1021         /* messageauthattr present? */
1022         messageauthattr = NULL;
1023         left = RADLEN(buf) - 20;
1024         attr = buf + 20;
1025         while (left > 1) {
1026             left -= attr[RAD_Attr_Length];
1027             if (left < 0) {
1028                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1029                 goto getnext;
1030             }
1031             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1032                 if (attr[RAD_Attr_Length] != 18) {
1033                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1034                     goto getnext;
1035                 }
1036                 memcpy(tmp, buf + 4, 16);
1037                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1038                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1039                     printf("clientrd: message authentication failed\n");
1040                     goto getnext;
1041                 }
1042                 memcpy(buf + 4, tmp, 16);
1043                 printf("clientrd: message auth ok\n");
1044                 messageauthattr = attr;
1045                 break;
1046             }
1047             attr += attr[RAD_Attr_Length];
1048         }
1049
1050         /* handle MS MPPE */
1051         left = RADLEN(buf) - 20;
1052         attr = buf + 20;
1053         while (left > 1) {
1054             left -= attr[RAD_Attr_Length];
1055             if (left < 0) {
1056                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1057                 goto getnext;
1058             }
1059             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1060                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1061                 subleft = attr[RAD_Attr_Length] - 6;
1062                 subattr = attr + 6;
1063                 while (subleft > 1) {
1064                     subleft -= subattr[RAD_Attr_Length];
1065                     if (subleft < 0)
1066                         break;
1067                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1068                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1069                         continue;
1070                     printf("clientrd: Got MS MPPE\n");
1071                     if (subattr[RAD_Attr_Length] < 20)
1072                         continue;
1073
1074                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1075                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1076                         printf("clientrd: failed to decrypt msppe key\n");
1077                         continue;
1078                     }
1079
1080                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1081                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1082                         printf("clientrd: failed to encrypt msppe key\n");
1083                         continue;
1084                     }
1085                 }
1086                 if (subleft < 0) {
1087                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1088                     goto getnext;
1089                 }
1090             }
1091             attr += attr[RAD_Attr_Length];
1092         }
1093
1094         /* once we set received = 1, requests[i] may be reused */
1095         buf[1] = (char)server->requests[i].origid;
1096         memcpy(buf + 4, server->requests[i].origauth, 16);
1097         printauth("origauth/buf+4", buf + 4);
1098         if (messageauthattr) {
1099             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1100                 continue;
1101             printf("clientrd: computed messageauthattr\n");
1102         }
1103
1104         if (from->peer.type == 'U')
1105             fromsa = server->requests[i].fromsa;
1106         server->requests[i].received = 1;
1107         pthread_mutex_unlock(&server->newrq_mutex);
1108
1109         if (!radsign(buf, from->peer.secret)) {
1110             printf("clientrd: failed to sign message\n");
1111             continue;
1112         }
1113         printauth("signedorigauth/buf+4", buf + 4);             
1114         printf("clientrd: giving packet back to where it came from\n");
1115         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1116     }
1117 }
1118
1119 void *clientwr(void *arg) {
1120     struct server *server = (struct server *)arg;
1121     struct request *rq;
1122     pthread_t clientrdth;
1123     int i;
1124     struct timeval now;
1125     
1126     if (server->peer.type == 'U') {
1127         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1128             printf("clientwr: connecttoserver failed\n");
1129             exit(1);
1130         }
1131     } else
1132         tlsconnect(server, NULL, "new client");
1133     
1134     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1135         errx("clientwr: pthread_create failed");
1136
1137     for (;;) {
1138         pthread_mutex_lock(&server->newrq_mutex);
1139         while (!server->newrq) {
1140             printf("clientwr: waiting for signal\n");
1141             pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1142             printf("clientwr: got signal\n");
1143         }
1144         server->newrq = 0;
1145         pthread_mutex_unlock(&server->newrq_mutex);
1146                
1147         for (i = 0; i < MAX_REQUESTS; i++) {
1148             pthread_mutex_lock(&server->newrq_mutex);
1149             while (!server->requests[i].buf && i < MAX_REQUESTS)
1150                 i++;
1151             if (i == MAX_REQUESTS) {
1152                 pthread_mutex_unlock(&server->newrq_mutex);
1153                 break;
1154             }
1155
1156             gettimeofday(&now, NULL);
1157             rq = server->requests + i;
1158
1159             if (rq->received) {
1160                 printf("clientwr: removing received packet from queue\n");
1161                 free(rq->buf);
1162                 /* setting this to NULL means that it can be reused */
1163                 rq->buf = NULL;
1164                 pthread_mutex_unlock(&server->newrq_mutex);
1165                 continue;
1166             }
1167             if (now.tv_sec > rq->expiry.tv_sec) {
1168                 printf("clientwr: removing expired packet from queue\n");
1169                 free(rq->buf);
1170                 /* setting this to NULL means that it can be reused */
1171                 rq->buf = NULL;
1172                 pthread_mutex_unlock(&server->newrq_mutex);
1173                 continue;
1174             }
1175
1176             if (rq->tries)
1177                 continue; // not re-sending (yet)
1178             
1179             rq->tries++;
1180             pthread_mutex_unlock(&server->newrq_mutex);
1181             
1182             clientradput(server, server->requests[i].buf);
1183         }
1184     }
1185     /* should do more work to maintain TLS connections, keepalives etc */
1186 }
1187
1188 void *udpserverwr(void *arg) {
1189     struct replyq *replyq = &udp_server_replyq;
1190     struct reply *reply = replyq->replies;
1191     
1192     pthread_mutex_lock(&replyq->count_mutex);
1193     for (;;) {
1194         while (!replyq->count) {
1195             printf("udp server writer, waiting for signal\n");
1196             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1197             printf("udp server writer, got signal\n");
1198         }
1199         pthread_mutex_unlock(&replyq->count_mutex);
1200         
1201         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1202                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1203             err("sendudp: send failed");
1204         free(reply->buf);
1205         
1206         pthread_mutex_lock(&replyq->count_mutex);
1207         replyq->count--;
1208         memmove(replyq->replies, replyq->replies + 1,
1209                 replyq->count * sizeof(struct reply));
1210     }
1211 }
1212
1213 void *udpserverrd(void *arg) {
1214     struct request rq;
1215     unsigned char *buf;
1216     struct server *to;
1217     struct client *fr;
1218     pthread_t udpserverwrth;
1219     
1220     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
1221         printf("udpserverrd: socket/bind failed\n");
1222         exit(1);
1223     }
1224     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
1225
1226     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1227         errx("pthread_create failed");
1228     
1229     for (;;) {
1230         fr = NULL;
1231         memset(&rq, 0, sizeof(struct request));
1232         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1233         to = radsrv(&rq, buf, fr);
1234         if (!to) {
1235             printf("udpserverrd: ignoring request, no place to send it\n");
1236             continue;
1237         }
1238         sendrq(to, fr, &rq);
1239     }
1240 }
1241
1242 void *tlsserverwr(void *arg) {
1243     int cnt;
1244     unsigned long error;
1245     struct client *client = (struct client *)arg;
1246     struct replyq *replyq;
1247     
1248     printf("tlsserverwr starting for %s\n", client->peer.host);
1249     replyq = client->replyq;
1250     pthread_mutex_lock(&replyq->count_mutex);
1251     for (;;) {
1252         while (!replyq->count) {
1253             if (client->peer.ssl) {         
1254                 printf("tls server writer, waiting for signal\n");
1255                 pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1256                 printf("tls server writer, got signal\n");
1257             }
1258             if (!client->peer.ssl) {
1259                 //ssl might have changed while waiting
1260                 pthread_mutex_unlock(&replyq->count_mutex);
1261                 printf("tlsserverwr: exiting as requested\n");
1262                 pthread_exit(NULL);
1263             }
1264         }
1265         pthread_mutex_unlock(&replyq->count_mutex);
1266         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1267         if (cnt > 0)
1268             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1269                    cnt, RADLEN(replyq->replies->buf));
1270         else
1271             while ((error = ERR_get_error()))
1272                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1273         free(replyq->replies->buf);
1274
1275         pthread_mutex_lock(&replyq->count_mutex);
1276         replyq->count--;
1277         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1278     }
1279 }
1280
1281 void *tlsserverrd(void *arg) {
1282     struct request rq;
1283     char unsigned *buf;
1284     unsigned long error;
1285     struct server *to;
1286     int s;
1287     struct client *client = (struct client *)arg;
1288     pthread_t tlsserverwrth;
1289     SSL *ssl;
1290     
1291     printf("tlsserverrd starting for %s\n", client->peer.host);
1292     ssl = client->peer.ssl;
1293
1294     if (SSL_accept(ssl) <= 0) {
1295         while ((error = ERR_get_error()))
1296             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1297         errx("accept failed, child exiting");
1298     }
1299
1300     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1301         errx("pthread_create failed");
1302     
1303     for (;;) {
1304         buf = radtlsget(client->peer.ssl);
1305         if (!buf)
1306             break;
1307         printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1308         memset(&rq, 0, sizeof(struct request));
1309         to = radsrv(&rq, buf, client);
1310         if (!to) {
1311             printf("ignoring request, no place to send it\n");
1312             continue;
1313         }
1314         sendrq(to, client, &rq);
1315     }
1316     printf("tlsserverrd: connection lost\n");
1317     // stop writer by setting peer.ssl to NULL and give signal in case waiting for data
1318     client->peer.ssl = NULL;
1319     pthread_mutex_lock(&client->replyq->count_mutex);
1320     pthread_cond_signal(&client->replyq->count_cond);
1321     pthread_mutex_unlock(&client->replyq->count_mutex);
1322     printf("tlsserverrd: waiting for writer to end\n");
1323     pthread_join(tlsserverwrth, NULL);
1324     s = SSL_get_fd(ssl);
1325     SSL_free(ssl);
1326     close(s);
1327     printf("tlsserverrd thread for %s exiting\n", client->peer.host);
1328     pthread_exit(NULL);
1329 }
1330
1331 int tlslistener(SSL_CTX *ssl_ctx) {
1332     pthread_t tlsserverth;
1333     int s, snew;
1334     struct sockaddr_storage from;
1335     size_t fromlen = sizeof(from);
1336     struct client *client;
1337
1338     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1339         printf("tlslistener: socket/bind failed\n");
1340         exit(1);
1341     }
1342     
1343     listen(s, 0);
1344     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1345
1346     for (;;) {
1347         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1348         if (snew < 0)
1349             errx("accept failed");
1350         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1351
1352         client = find_client('T', (struct sockaddr *)&from, NULL);
1353         if (!client) {
1354             printf("ignoring request, not a known TLS client\n");
1355             shutdown(snew, SHUT_RDWR);
1356             close(snew);
1357             continue;
1358         }
1359
1360         if (client->peer.ssl) {
1361             printf("Ignoring incoming connection, already have one from this client\n");
1362             shutdown(snew, SHUT_RDWR);
1363             close(snew);
1364             continue;
1365         }
1366         client->peer.ssl = SSL_new(ssl_ctx);
1367         SSL_set_fd(client->peer.ssl, snew);
1368         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1369             errx("pthread_create failed");
1370         pthread_detach(tlsserverth);
1371     }
1372     return 0;
1373 }
1374
1375 char *parsehostport(char *s, struct peer *peer) {
1376     char *p, *field;
1377     int ipv6 = 0;
1378
1379     p = s;
1380     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1381     if (*p == '[') {
1382         p++;
1383         field = p;
1384         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1385         if (*p != ']') {
1386             printf("no ] matching initial [\n");
1387             exit(1);
1388         }
1389         ipv6 = 1;
1390     } else {
1391         field = p;
1392         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1393     }
1394     if (field == p) {
1395         printf("missing host/address\n");
1396         exit(1);
1397     }
1398     peer->host = malloc(p - field + 1);
1399     if (!peer->host)
1400         errx("malloc failed");
1401     memcpy(peer->host, field, p - field);
1402     peer->host[p - field] = '\0';
1403     if (ipv6) {
1404         p++;
1405         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1406             printf("unexpected character after ]\n");
1407             exit(1);
1408         }
1409     }
1410     if (*p == ':') {
1411             /* port number or service name is specified */;
1412             field = p++;
1413             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1414             if (field == p) {
1415                 printf("syntax error, : but no following port\n");
1416                 exit(1);
1417             }
1418             peer->port = malloc(p - field + 1);
1419             if (!peer->port)
1420                 errx("malloc failed");
1421             memcpy(peer->port, field, p - field);
1422             peer->port[p - field] = '\0';
1423     } else
1424         peer->port = NULL;
1425     return p;
1426 }
1427
1428 // * is default, else longest match ... ";" used for separator
1429 char *parserealmlist(char *s, struct server *server) {
1430     char *p;
1431     int i, n, l;
1432
1433     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1434         if (*p == ';')
1435             n++;
1436     l = p - s;
1437     if (!l) {
1438         server->realms = NULL;
1439         return p;
1440     }
1441     server->realmdata = malloc(l + 1);
1442     if (!server->realmdata)
1443         errx("malloc failed");
1444     memcpy(server->realmdata, s, l);
1445     server->realmdata[l] = '\0';
1446     server->realms = malloc((1+n) * sizeof(char *));
1447     if (!server->realms)
1448         errx("malloc failed");
1449     server->realms[0] = server->realmdata;
1450     for (n = 1, i = 0; i < l; i++)
1451         if (server->realmdata[i] == ';') {
1452             server->realmdata[i] = '\0';
1453             server->realms[n++] = server->realmdata + i + 1;
1454         }       
1455     server->realms[n] = NULL;
1456     return p;
1457 }
1458
1459 /* exactly one argument must be non-NULL */
1460 void getconfig(const char *serverfile, const char *clientfile) {
1461     FILE *f;
1462     char line[1024];
1463     char *p, *field, **r;
1464     struct client *client;
1465     struct server *server;
1466     struct peer *peer;
1467     int *count;
1468     
1469     if (serverfile) {
1470         printf("opening file %s for reading\n", serverfile);
1471         f = fopen(serverfile, "r");
1472         if (!f)
1473             errx("getconfig failed to open %s for reading", serverfile);
1474         count = &server_count;
1475     } else {
1476         printf("opening file %s for reading\n", clientfile);
1477         f = fopen(clientfile, "r");
1478         if (!f)
1479             errx("getconfig failed to open %s for reading", clientfile);
1480         udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
1481         if (!udp_server_replyq.replies)
1482             errx("malloc failed");
1483         udp_server_replyq.size = 4 * MAX_REQUESTS;
1484         udp_server_replyq.count = 0;
1485         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1486         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1487         count = &client_count;
1488     }    
1489     
1490     *count = 0;
1491     while (fgets(line, 1024, f) && *count < MAX_PEERS) {
1492         if (serverfile) {
1493             server = &servers[*count];
1494             memset(server, 0, sizeof(struct server));
1495             peer = &server->peer;
1496         } else {
1497             client = &clients[*count];
1498             memset(client, 0, sizeof(struct client));
1499             peer = &client->peer;
1500         }
1501         for (p = line; *p == ' ' || *p == '\t'; p++);
1502         if (*p == '#' || *p == '\n')
1503             continue;
1504         if (*p != 'U' && *p != 'T') {
1505             printf("server type must be U or T, got %c\n", *p);
1506             exit(1);
1507         }
1508         peer->type = *p;
1509         for (p++; *p == ' ' || *p == '\t'; p++);
1510         p = parsehostport(p, peer);
1511         if (!peer->port)
1512             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
1513         for (; *p == ' ' || *p == '\t'; p++);
1514         if (serverfile) {
1515             p = parserealmlist(p, server);
1516             if (!server->realms) {
1517                 printf("realm list must be specified\n");
1518                 exit(1);
1519             }
1520             for (; *p == ' ' || *p == '\t'; p++);
1521         }
1522         field = p;
1523         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1524         if (field == p) {
1525             /* no secret set and end of line, line is complete if TLS */
1526             if (peer->type == 'U') {
1527                 printf("secret must be specified for UDP\n");
1528                 exit(1);
1529             }
1530             peer->secret = DEFAULT_TLS_SECRET;
1531         } else {
1532             peer->secret = malloc(p - field + 1);
1533             if (!peer->secret)
1534                 errx("malloc failed");
1535             memcpy(peer->secret, field, p - field);
1536             peer->secret[p - field] = '\0';
1537             /* check that rest of line only white space */
1538             for (; *p == ' ' || *p == '\t'; p++);
1539             if (*p && *p != '\n') {
1540                 printf("max 4 fields per line, found a 5th\n");
1541                 exit(1);
1542             }
1543         }
1544
1545         if ((serverfile && !resolvepeer(&server->peer)) ||
1546             (clientfile && !resolvepeer(&client->peer))) {
1547             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1548             exit(1);
1549         }
1550
1551         if (serverfile) {
1552             pthread_mutex_init(&server->lock, NULL);
1553             server->sock = -1;
1554             server->requests = malloc(MAX_REQUESTS * sizeof(struct request));
1555             if (!server->requests)
1556                 errx("malloc failed");
1557             memset(server->requests, 0, MAX_REQUESTS * sizeof(struct request));
1558             server->newrq = 0;
1559             pthread_mutex_init(&server->newrq_mutex, NULL);
1560             pthread_cond_init(&server->newrq_cond, NULL);
1561         } else {
1562             if (peer->type == 'U')
1563                 client->replyq = &udp_server_replyq;
1564             else {
1565                 client->replyq = malloc(sizeof(struct replyq));
1566                 if (!client->replyq)
1567                     errx("malloc failed");
1568                 client->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
1569                 if (!client->replyq->replies)
1570                     errx("malloc failed");
1571                 client->replyq->size = MAX_REQUESTS;
1572                 client->replyq->count = 0;
1573                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1574                 pthread_cond_init(&client->replyq->count_cond, NULL);
1575             }
1576         }
1577         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1578         if (serverfile) {
1579             printf("    with realms:");
1580             for (r = server->realms; *r; r++)
1581                 printf(" %s", *r);
1582             printf("\n");
1583         }
1584         (*count)++;
1585     }
1586     fclose(f);
1587 }
1588
1589 void parseargs(int argc, char **argv) {
1590     int c;
1591
1592     while ((c = getopt(argc, argv, "p:")) != -1) {
1593         switch (c) {
1594         case 'p':
1595             udp_server_port = optarg;
1596             break;
1597         default:
1598             goto usage;
1599         }
1600     }
1601
1602     return;
1603
1604  usage:
1605     printf("radsecproxy [ -p UDP-port ]\n");
1606     exit(1);
1607 }
1608                
1609 int main(int argc, char **argv) {
1610     SSL_CTX *ssl_ctx_srv;
1611     unsigned long error;
1612     pthread_t udpserverth;
1613     //    pthread_attr_t joinable;
1614     int i;
1615     
1616     parseargs(argc, argv);
1617     getconfig("servers.conf", NULL);
1618     getconfig(NULL, "clients.conf");
1619     
1620     ssl_locks_setup();
1621
1622     //    pthread_attr_init(&joinable);
1623     //    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1624    
1625     /* listen on UDP if at least one UDP client */
1626     
1627     for (i = 0; i < client_count; i++)
1628         if (clients[i].peer.type == 'U') {
1629             if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
1630                 errx("pthread_create failed");
1631             break;
1632         }
1633     
1634     /* SSL setup */
1635     SSL_load_error_strings();
1636     SSL_library_init();
1637
1638     while (!RAND_status()) {
1639         time_t t = time(NULL);
1640         pid_t pid = getpid();
1641         RAND_seed((unsigned char *)&t, sizeof(time_t));
1642         RAND_seed((unsigned char *)&pid, sizeof(pid));
1643     }
1644     
1645     /* initialise client part and start clients */
1646     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1647     if (!ssl_ctx_cl)
1648         errx("no ssl ctx");
1649     
1650     for (i = 0; i < server_count; i++) {
1651         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1652             errx("pthread_create failed");
1653     }
1654
1655     for (i = 0; i < client_count; i++)
1656         if (clients[i].peer.type == 'T')
1657             break;
1658
1659     if (i == client_count) {
1660         printf("No TLS clients defined, not starting TLS listener\n");
1661         /* just hang around doing nothing, anything to do here? */
1662         for (;;)
1663             sleep(1000);
1664     }
1665     
1666     /* setting up server/daemon part */
1667     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1668     if (!ssl_ctx_srv)
1669         errx("no ssl ctx");
1670     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1671         while ((error = ERR_get_error()))
1672             err("SSL: %s", ERR_error_string(error, NULL));
1673         errx("Failed to load certificate");
1674     }
1675     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1676         while ((error = ERR_get_error()))
1677             err("SSL: %s", ERR_error_string(error, NULL));
1678         errx("Failed to load private key");
1679     }
1680
1681     return tlslistener(ssl_ctx_srv);
1682 }