simplified ssl_init
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* TODO:
10  * accounting
11  * radius keep alives (server status)
12  * tls certificate validation, see below urls
13  * clean tls shutdown, see http://www.linuxjournal.com/article/4822
14  *     and http://www.linuxjournal.com/article/5487
15  *     SSL_shutdown() and shutdown()
16  *     If shutdown() we may not need REUSEADDR
17  * when tls client goes away, ensure that all related threads and state
18  *          are removed
19  * setsockopt(keepalive...), check if openssl has some keepalive feature
20 */
21
22 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
23  *              rd is responsible for init and launching wr
24  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
25  *          each tlsserverrd launches tlsserverwr
26  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
27  *          for init and launching rd
28  *
29  * serverrd will receive a request, processes it and puts it in the requestq of
30  *          the appropriate clientwr
31  * clientwr monitors its requestq and sends requests
32  * clientrd looks for responses, processes them and puts them in the replyq of
33  *          the peer the request came from
34  * serverwr monitors its reply and sends replies
35  *
36  * In addition to the main thread, we have:
37  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
38  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
39  * For each TLS peer connecting to us there will be 2 more TLS threads
40  *       This is only for connected peers
41  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
42  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
43 */
44
45 #include <netdb.h>
46 #include <unistd.h>
47 #include <sys/time.h>
48 #include <pthread.h>
49 #include <openssl/ssl.h>
50 #include <openssl/rand.h>
51 #include <openssl/err.h>
52 #include <openssl/md5.h>
53 #include <openssl/hmac.h>
54 #include "radsecproxy.h"
55
56 static struct options options;
57 static struct client *clients;
58 static struct server *servers;
59
60 static int client_udp_count = 0;
61 static int client_tls_count = 0;
62 static int client_count = 0;
63 static int server_udp_count = 0;
64 static int server_tls_count = 0;
65 static int server_count = 0;
66
67 static struct replyq udp_server_replyq;
68 static int udp_server_sock = -1;
69 static pthread_mutex_t *ssl_locks;
70 static long *ssl_lock_count;
71 static SSL_CTX *ssl_ctx = NULL;
72 extern int optind;
73 extern char *optarg;
74
75 /* callbacks for making OpenSSL thread safe */
76 unsigned long ssl_thread_id() {
77         return (unsigned long)pthread_self();
78 };
79
80 void ssl_locking_callback(int mode, int type, const char *file, int line) {
81     if (mode & CRYPTO_LOCK) {
82         pthread_mutex_lock(&ssl_locks[type]);
83         ssl_lock_count[type]++;
84     } else
85         pthread_mutex_unlock(&ssl_locks[type]);
86 }
87
88 static int verify_cb(int ok, X509_STORE_CTX *ctx) {
89   char buf[256];
90   X509 *err_cert;
91   int err, depth;
92
93   err_cert = X509_STORE_CTX_get_current_cert(ctx);
94   err = X509_STORE_CTX_get_error(ctx);
95   depth = X509_STORE_CTX_get_error_depth(ctx);
96
97   if (depth > MAX_CERT_DEPTH) {
98       ok = 0;
99       err = X509_V_ERR_CERT_CHAIN_TOO_LONG;
100       X509_STORE_CTX_set_error(ctx, err);
101   }
102
103   if (!ok) {
104       X509_NAME_oneline(X509_get_subject_name(err_cert), buf, 256);
105       printf("verify error: num=%d:%s:depth=%d:%s\n", err, X509_verify_cert_error_string(err), depth, buf);
106
107       switch (err) {
108       case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
109           X509_NAME_oneline(X509_get_issuer_name(ctx->current_cert), buf, 256);
110           printf("issuer=%s\n", buf);
111           break;
112       case X509_V_ERR_CERT_NOT_YET_VALID:
113       case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
114           printf("Certificate not yet valid\n");
115           break;
116       case X509_V_ERR_CERT_HAS_EXPIRED:
117           printf("Certificate has expired\n");
118           break;
119       case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
120           printf("Certificate no longer valid (after notAfter)\n");
121           break;
122       }
123   }
124   //  printf("certificate verify returns %d\n", ok);
125   return ok;
126 }
127
128 SSL_CTX *ssl_init() {
129     SSL_CTX *ctx;
130     int i;
131     unsigned long error;
132     
133     if (!options.tlscertificatefile || !options.tlscertificatekeyfile) {
134         printf("TLSCertificateFile and TLSCertificateKeyFile must be specified for TLS\n");
135         exit(1);
136     }
137     if (!options.tlscacertificatefile && !options.tlscacertificatepath) {
138         printf("CA Certificate file/path need to be configured\n");
139         exit(1);
140     }
141
142     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
143     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
144     for (i = 0; i < CRYPTO_num_locks(); i++) {
145         ssl_lock_count[i] = 0;
146         pthread_mutex_init(&ssl_locks[i], NULL);
147     }
148     CRYPTO_set_id_callback(ssl_thread_id);
149     CRYPTO_set_locking_callback(ssl_locking_callback);
150
151     SSL_load_error_strings();
152     SSL_library_init();
153
154     while (!RAND_status()) {
155         time_t t = time(NULL);
156         pid_t pid = getpid();
157         RAND_seed((unsigned char *)&t, sizeof(time_t));
158         RAND_seed((unsigned char *)&pid, sizeof(pid));
159     }
160
161     ctx = SSL_CTX_new(TLSv1_method());
162     if (SSL_CTX_use_certificate_chain_file(ctx, options.tlscertificatefile) &&
163         SSL_CTX_use_PrivateKey_file(ctx, options.tlscertificatekeyfile, SSL_FILETYPE_PEM) &&
164         SSL_CTX_check_private_key(ctx) &&
165         SSL_CTX_load_verify_locations(ctx, options.tlscacertificatefile, options.tlscacertificatepath)) {
166         SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
167         SSL_CTX_set_verify_depth(ctx, MAX_CERT_DEPTH + 1);
168         return ctx;
169     }
170
171     while ((error = ERR_get_error()))
172         err("SSL: %s", ERR_error_string(error, NULL));
173     exit(1);
174 }    
175
176 void printauth(char *s, unsigned char *t) {
177     int i;
178     printf("%s:", s);
179     for (i = 0; i < 16; i++)
180             printf("%02x ", t[i]);
181     printf("\n");
182 }
183
184 int resolvepeer(struct peer *peer) {
185     struct addrinfo hints, *addrinfo;
186     
187     memset(&hints, 0, sizeof(hints));
188     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
189     hints.ai_family = AF_UNSPEC;
190     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
191         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
192         return 0;
193     }
194
195     if (peer->addrinfo)
196         freeaddrinfo(peer->addrinfo);
197     peer->addrinfo = addrinfo;
198     return 1;
199 }         
200
201 int connecttoserver(struct addrinfo *addrinfo) {
202     int s;
203     struct addrinfo *res;
204     
205     for (res = addrinfo; res; res = res->ai_next) {
206         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
207         if (s < 0) {
208             err("connecttoserver: socket failed");
209             continue;
210         }
211         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
212             break;
213         err("connecttoserver: connect failed");
214         close(s);
215         s = -1;
216     }
217     return s;
218 }         
219
220 /* returns the client with matching address, or NULL */
221 /* if client argument is not NULL, we only check that one client */
222 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
223     struct sockaddr_in6 *sa6;
224     struct in_addr *a4 = NULL;
225     struct client *c;
226     int i;
227     struct addrinfo *res;
228
229     if (addr->sa_family == AF_INET6) {
230         sa6 = (struct sockaddr_in6 *)addr;
231         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
232             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
233     } else
234         a4 = &((struct sockaddr_in *)addr)->sin_addr;
235
236     c = (client ? client : clients);
237     for (i = 0; i < client_count; i++) {
238         if (c->peer.type == type)
239             for (res = c->peer.addrinfo; res; res = res->ai_next)
240                 if ((a4 && res->ai_family == AF_INET &&
241                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
242                     (res->ai_family == AF_INET6 &&
243                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
244                     return c;
245         if (client)
246             break;
247         c++;
248     }
249     return NULL;
250 }
251
252 /* returns the server with matching address, or NULL */
253 /* if server argument is not NULL, we only check that one server */
254 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
255     struct sockaddr_in6 *sa6;
256     struct in_addr *a4 = NULL;
257     struct server *s;
258     int i;
259     struct addrinfo *res;
260
261     if (addr->sa_family == AF_INET6) {
262         sa6 = (struct sockaddr_in6 *)addr;
263         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
264             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
265     } else
266         a4 = &((struct sockaddr_in *)addr)->sin_addr;
267
268     s = (server ? server : servers);
269     for (i = 0; i < server_count; i++) {
270         if (s->peer.type == type)
271             for (res = s->peer.addrinfo; res; res = res->ai_next)
272                 if ((a4 && res->ai_family == AF_INET &&
273                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
274                     (res->ai_family == AF_INET6 &&
275                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
276                     return s;
277         if (server)
278             break;
279         s++;
280     }
281     return NULL;
282 }
283
284 /* exactly one of client and server must be non-NULL */
285 /* if *peer == NULL we return who we received from, else require it to be from peer */
286 /* return from in sa if not NULL */
287 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
288     int cnt, len;
289     void *f;
290     unsigned char buf[65536], *rad;
291     struct sockaddr_storage from;
292     socklen_t fromlen = sizeof(from);
293
294     for (;;) {
295         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
296         if (cnt == -1) {
297             err("radudpget: recv failed");
298             continue;
299         }
300         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
301
302         if (cnt < 20) {
303             printf("radudpget: packet too small\n");
304             continue;
305         }
306     
307         len = RADLEN(buf);
308
309         if (cnt < len) {
310             printf("radudpget: packet smaller than length field in radius header\n");
311             continue;
312         }
313         if (cnt > len)
314             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
315
316         f = (client
317              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
318              : (void *)find_server('U', (struct sockaddr *)&from, *server));
319         if (!f) {
320             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
321             continue;
322         }
323
324         rad = malloc(len);
325         if (rad)
326             break;
327         err("radudpget: malloc failed");
328     }
329     memcpy(rad, buf, len);
330     if (client)
331         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
332     else
333         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
334     if (sa)
335         *sa = from;
336     return rad;
337 }
338
339 int tlsverifycert(struct peer *peer) {
340     int i, l, loc;
341     X509 *cert;
342     X509_NAME *nm;
343     X509_NAME_ENTRY *e;
344     unsigned char *v;
345     unsigned long error;
346
347 #if 1
348     if (SSL_get_verify_result(peer->ssl) != X509_V_OK) {
349         printf("tlsverifycert: basic validation failed\n");
350         while ((error = ERR_get_error()))
351             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
352         return 0;
353     }
354 #endif    
355     cert = SSL_get_peer_certificate(peer->ssl);
356     if (!cert) {
357         printf("tlsverifycert: failed to obtain certificate\n");
358         return 0;
359     }
360     nm = X509_get_subject_name(cert);
361     loc = -1;
362     for (;;) {
363         loc = X509_NAME_get_index_by_NID(nm, NID_commonName, loc);
364         if (loc == -1)
365             break;
366         e = X509_NAME_get_entry(nm, loc);
367         l = ASN1_STRING_to_UTF8(&v, X509_NAME_ENTRY_get_data(e));
368         if (l < 0)
369             continue;
370         printf("cn: ");
371         for (i = 0; i < l; i++)
372             printf("%c", v[i]);
373         printf("\n");
374         if (l == strlen(peer->host) && !strncasecmp(peer->host, v, l)) {
375             printf("tlsverifycert: Found cn matching host %s, All OK\n", peer->host);
376             return 1;
377         }
378         printf("tlsverifycert: cn not matching host %s\n", peer->host);
379     }
380     X509_free(cert);
381     return 0;
382 }
383
384 void tlsconnect(struct server *server, struct timeval *when, char *text) {
385     struct timeval now;
386     time_t elapsed;
387
388     printf("tlsconnect called from %s\n", text);
389     pthread_mutex_lock(&server->lock);
390     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
391         /* already reconnected, nothing to do */
392         printf("tlsconnect(%s): seems already reconnected\n", text);
393         pthread_mutex_unlock(&server->lock);
394         return;
395     }
396
397     printf("tlsconnect %s\n", text);
398
399     for (;;) {
400         gettimeofday(&now, NULL);
401         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
402         if (server->connectionok) {
403             server->connectionok = 0;
404             sleep(10);
405         } else if (elapsed < 5)
406             sleep(10);
407         else if (elapsed < 600)
408             sleep(elapsed * 2);
409         else if (elapsed < 10000)
410                 sleep(900);
411         else
412             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
413         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
414         if (server->sock >= 0)
415             close(server->sock);
416         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
417             continue;
418         SSL_free(server->peer.ssl);
419         server->peer.ssl = SSL_new(ssl_ctx);
420         SSL_set_fd(server->peer.ssl, server->sock);
421         if (SSL_connect(server->peer.ssl) > 0 && tlsverifycert(&server->peer))
422             break;
423     }
424     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
425     gettimeofday(&server->lastconnecttry, NULL);
426     pthread_mutex_unlock(&server->lock);
427 }
428
429 unsigned char *radtlsget(SSL *ssl) {
430     int cnt, total, len;
431     unsigned char buf[4], *rad;
432
433     for (;;) {
434         for (total = 0; total < 4; total += cnt) {
435             cnt = SSL_read(ssl, buf + total, 4 - total);
436             if (cnt <= 0) {
437                 printf("radtlsget: connection lost\n");
438                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
439                     //remote end sent close_notify, send one back
440                     SSL_shutdown(ssl);
441                 }
442                 return NULL;
443             }
444         }
445
446         len = RADLEN(buf);
447         rad = malloc(len);
448         if (!rad) {
449             err("radtlsget: malloc failed");
450             continue;
451         }
452         memcpy(rad, buf, 4);
453
454         for (; total < len; total += cnt) {
455             cnt = SSL_read(ssl, rad + total, len - total);
456             if (cnt <= 0) {
457                 printf("radtlsget: connection lost\n");
458                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
459                     //remote end sent close_notify, send one back
460                     SSL_shutdown(ssl);
461                 }
462                 free(rad);
463                 return NULL;
464             }
465         }
466     
467         if (total >= 20)
468             break;
469         
470         free(rad);
471         printf("radtlsget: packet smaller than minimum radius size\n");
472     }
473     
474     printf("radtlsget: got %d bytes\n", total);
475     return rad;
476 }
477
478 int clientradput(struct server *server, unsigned char *rad) {
479     int cnt;
480     size_t len;
481     unsigned long error;
482     struct timeval lastconnecttry;
483     
484     len = RADLEN(rad);
485     if (server->peer.type == 'U') {
486         if (send(server->sock, rad, len, 0) >= 0) {
487             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
488             return 1;
489         }
490         err("clientradput: send failed");
491         return 0;
492     }
493
494     lastconnecttry = server->lastconnecttry;
495     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
496         while ((error = ERR_get_error()))
497             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
498         tlsconnect(server, &lastconnecttry, "clientradput");
499         lastconnecttry = server->lastconnecttry;
500     }
501
502     server->connectionok = 1;
503     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
504            cnt, len, server->peer.host);
505     return 1;
506 }
507
508 int radsign(unsigned char *rad, unsigned char *sec) {
509     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
510     static unsigned char first = 1;
511     static EVP_MD_CTX mdctx;
512     unsigned int md_len;
513     int result;
514     
515     pthread_mutex_lock(&lock);
516     if (first) {
517         EVP_MD_CTX_init(&mdctx);
518         first = 0;
519     }
520
521     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
522         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
523         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
524         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
525         md_len == 16);
526     pthread_mutex_unlock(&lock);
527     return result;
528 }
529
530 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
531     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
532     static unsigned char first = 1;
533     static EVP_MD_CTX mdctx;
534     unsigned char hash[EVP_MAX_MD_SIZE];
535     unsigned int len;
536     int result;
537     
538     pthread_mutex_lock(&lock);
539     if (first) {
540         EVP_MD_CTX_init(&mdctx);
541         first = 0;
542     }
543
544     len = RADLEN(rad);
545     
546     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
547               EVP_DigestUpdate(&mdctx, rad, 4) &&
548               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
549               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
550               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
551               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
552               len == 16 &&
553               !memcmp(hash, rad + 4, 16));
554     pthread_mutex_unlock(&lock);
555     return result;
556 }
557               
558 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
559     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
560     static unsigned char first = 1;
561     static HMAC_CTX hmacctx;
562     unsigned int md_len;
563     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
564     
565     pthread_mutex_lock(&lock);
566     if (first) {
567         HMAC_CTX_init(&hmacctx);
568         first = 0;
569     }
570
571     memcpy(auth, authattr, 16);
572     memset(authattr, 0, 16);
573     md_len = 0;
574     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
575     HMAC_Update(&hmacctx, rad, RADLEN(rad));
576     HMAC_Final(&hmacctx, hash, &md_len);
577     memcpy(authattr, auth, 16);
578     if (md_len != 16) {
579         printf("message auth computation failed\n");
580         pthread_mutex_unlock(&lock);
581         return 0;
582     }
583
584     if (memcmp(auth, hash, 16)) {
585         printf("message authenticator, wrong value\n");
586         pthread_mutex_unlock(&lock);
587         return 0;
588     }   
589         
590     pthread_mutex_unlock(&lock);
591     return 1;
592 }
593
594 int createmessageauth(char *rad, char *authattrval, char *secret) {
595     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
596     static unsigned char first = 1;
597     static HMAC_CTX hmacctx;
598     unsigned int md_len;
599
600     if (!authattrval)
601         return 1;
602     
603     pthread_mutex_lock(&lock);
604     if (first) {
605         HMAC_CTX_init(&hmacctx);
606         first = 0;
607     }
608
609     memset(authattrval, 0, 16);
610     md_len = 0;
611     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
612     HMAC_Update(&hmacctx, rad, RADLEN(rad));
613     HMAC_Final(&hmacctx, authattrval, &md_len);
614     if (md_len != 16) {
615         printf("message auth computation failed\n");
616         pthread_mutex_unlock(&lock);
617         return 0;
618     }
619
620     pthread_mutex_unlock(&lock);
621     return 1;
622 }
623
624 void sendrq(struct server *to, struct client *from, struct request *rq) {
625     int i;
626     
627     pthread_mutex_lock(&to->newrq_mutex);
628     /* might simplify if only try nextid, might be ok */
629     for (i = to->nextid; i < MAX_REQUESTS; i++)
630         if (!to->requests[i].buf)
631             break;
632     if (i == MAX_REQUESTS) {
633         for (i = 0; i < to->nextid; i++)
634             if (!to->requests[i].buf)
635                 break;
636         if (i == to->nextid) {
637             printf("No room in queue, dropping request\n");
638             pthread_mutex_unlock(&to->newrq_mutex);
639             return;
640         }
641     }
642     
643     to->nextid = i + 1;
644     rq->buf[1] = (char)i;
645     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
646     
647     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
648         return;
649
650     to->requests[i] = *rq;
651
652     if (!to->newrq) {
653         to->newrq = 1;
654         printf("signalling client writer\n");
655         pthread_cond_signal(&to->newrq_cond);
656     }
657     pthread_mutex_unlock(&to->newrq_mutex);
658 }
659
660 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
661     struct replyq *replyq = to->replyq;
662     
663     pthread_mutex_lock(&replyq->count_mutex);
664     if (replyq->count == replyq->size) {
665         printf("No room in queue, dropping request\n");
666         pthread_mutex_unlock(&replyq->count_mutex);
667         return;
668     }
669
670     replyq->replies[replyq->count].buf = buf;
671     if (tosa)
672         replyq->replies[replyq->count].tosa = *tosa;
673     replyq->count++;
674
675     if (replyq->count == 1) {
676         printf("signalling client writer\n");
677         pthread_cond_signal(&replyq->count_cond);
678     }
679     pthread_mutex_unlock(&replyq->count_mutex);
680 }
681
682 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
683     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
684     static unsigned char first = 1;
685     static EVP_MD_CTX mdctx;
686     unsigned char hash[EVP_MAX_MD_SIZE], *input;
687     unsigned int md_len;
688     uint8_t i, offset = 0, out[128];
689     
690     pthread_mutex_lock(&lock);
691     if (first) {
692         EVP_MD_CTX_init(&mdctx);
693         first = 0;
694     }
695
696     input = auth;
697     for (;;) {
698         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
699             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
700             !EVP_DigestUpdate(&mdctx, input, 16) ||
701             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
702             md_len != 16) {
703             pthread_mutex_unlock(&lock);
704             return 0;
705         }
706         for (i = 0; i < 16; i++)
707             out[offset + i] = hash[i] ^ in[offset + i];
708         input = out + offset - 16;
709         offset += 16;
710         if (offset == len)
711             break;
712     }
713     memcpy(in, out, len);
714     pthread_mutex_unlock(&lock);
715     return 1;
716 }
717
718 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
719     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
720     static unsigned char first = 1;
721     static EVP_MD_CTX mdctx;
722     unsigned char hash[EVP_MAX_MD_SIZE], *input;
723     unsigned int md_len;
724     uint8_t i, offset = 0, out[128];
725     
726     pthread_mutex_lock(&lock);
727     if (first) {
728         EVP_MD_CTX_init(&mdctx);
729         first = 0;
730     }
731
732     input = auth;
733     for (;;) {
734         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
735             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
736             !EVP_DigestUpdate(&mdctx, input, 16) ||
737             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
738             md_len != 16) {
739             pthread_mutex_unlock(&lock);
740             return 0;
741         }
742         for (i = 0; i < 16; i++)
743             out[offset + i] = hash[i] ^ in[offset + i];
744         input = in + offset;
745         offset += 16;
746         if (offset == len)
747             break;
748     }
749     memcpy(in, out, len);
750     pthread_mutex_unlock(&lock);
751     return 1;
752 }
753
754 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
755     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
756     static unsigned char first = 1;
757     static EVP_MD_CTX mdctx;
758     unsigned char hash[EVP_MAX_MD_SIZE];
759     unsigned int md_len;
760     uint8_t i, offset;
761     
762     pthread_mutex_lock(&lock);
763     if (first) {
764         EVP_MD_CTX_init(&mdctx);
765         first = 0;
766     }
767
768 #if 0    
769     printf("msppencrypt auth in: ");
770     for (i = 0; i < 16; i++)
771         printf("%02x ", auth[i]);
772     printf("\n");
773     
774     printf("msppencrypt salt in: ");
775     for (i = 0; i < 2; i++)
776         printf("%02x ", salt[i]);
777     printf("\n");
778     
779     printf("msppencrypt in: ");
780     for (i = 0; i < len; i++)
781         printf("%02x ", text[i]);
782     printf("\n");
783 #endif
784     
785     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
786         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
787         !EVP_DigestUpdate(&mdctx, auth, 16) ||
788         !EVP_DigestUpdate(&mdctx, salt, 2) ||
789         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
790         pthread_mutex_unlock(&lock);
791         return 0;
792     }
793
794 #if 0    
795     printf("msppencrypt hash: ");
796     for (i = 0; i < 16; i++)
797         printf("%02x ", hash[i]);
798     printf("\n");
799 #endif
800     
801     for (i = 0; i < 16; i++)
802         text[i] ^= hash[i];
803     
804     for (offset = 16; offset < len; offset += 16) {
805 #if 0   
806         printf("text + offset - 16 c(%d): ", offset / 16);
807         for (i = 0; i < 16; i++)
808             printf("%02x ", (text + offset - 16)[i]);
809         printf("\n");
810 #endif
811         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
812             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
813             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
814             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
815             md_len != 16) {
816             pthread_mutex_unlock(&lock);
817             return 0;
818         }
819 #if 0   
820         printf("msppencrypt hash: ");
821         for (i = 0; i < 16; i++)
822             printf("%02x ", hash[i]);
823         printf("\n");
824 #endif    
825         
826         for (i = 0; i < 16; i++)
827             text[offset + i] ^= hash[i];
828     }
829     
830 #if 0
831     printf("msppencrypt out: ");
832     for (i = 0; i < len; i++)
833         printf("%02x ", text[i]);
834     printf("\n");
835 #endif
836
837     pthread_mutex_unlock(&lock);
838     return 1;
839 }
840
841 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
842     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
843     static unsigned char first = 1;
844     static EVP_MD_CTX mdctx;
845     unsigned char hash[EVP_MAX_MD_SIZE];
846     unsigned int md_len;
847     uint8_t i, offset;
848     char plain[255];
849     
850     pthread_mutex_lock(&lock);
851     if (first) {
852         EVP_MD_CTX_init(&mdctx);
853         first = 0;
854     }
855
856 #if 0    
857     printf("msppdecrypt auth in: ");
858     for (i = 0; i < 16; i++)
859         printf("%02x ", auth[i]);
860     printf("\n");
861     
862     printf("msppedecrypt salt in: ");
863     for (i = 0; i < 2; i++)
864         printf("%02x ", salt[i]);
865     printf("\n");
866     
867     printf("msppedecrypt in: ");
868     for (i = 0; i < len; i++)
869         printf("%02x ", text[i]);
870     printf("\n");
871 #endif
872     
873     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
874         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
875         !EVP_DigestUpdate(&mdctx, auth, 16) ||
876         !EVP_DigestUpdate(&mdctx, salt, 2) ||
877         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
878         pthread_mutex_unlock(&lock);
879         return 0;
880     }
881
882 #if 0    
883     printf("msppedecrypt hash: ");
884     for (i = 0; i < 16; i++)
885         printf("%02x ", hash[i]);
886     printf("\n");
887 #endif
888     
889     for (i = 0; i < 16; i++)
890         plain[i] = text[i] ^ hash[i];
891     
892     for (offset = 16; offset < len; offset += 16) {
893 #if 0   
894         printf("text + offset - 16 c(%d): ", offset / 16);
895         for (i = 0; i < 16; i++)
896             printf("%02x ", (text + offset - 16)[i]);
897         printf("\n");
898 #endif
899         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
900             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
901             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
902             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
903             md_len != 16) {
904             pthread_mutex_unlock(&lock);
905             return 0;
906         }
907 #if 0   
908     printf("msppedecrypt hash: ");
909     for (i = 0; i < 16; i++)
910         printf("%02x ", hash[i]);
911     printf("\n");
912 #endif    
913
914     for (i = 0; i < 16; i++)
915         plain[offset + i] = text[offset + i] ^ hash[i];
916     }
917
918     memcpy(text, plain, len);
919 #if 0
920     printf("msppedecrypt out: ");
921     for (i = 0; i < len; i++)
922         printf("%02x ", text[i]);
923     printf("\n");
924 #endif
925
926     pthread_mutex_unlock(&lock);
927     return 1;
928 }
929
930 struct server *id2server(char *id, uint8_t len) {
931     int i;
932     char **realm, *idrealm;
933
934     idrealm = strchr(id, '@');
935     if (idrealm) {
936         idrealm++;
937         len -= idrealm - id;
938     } else {
939         idrealm = "-";
940         len = 1;
941     }
942     for (i = 0; i < server_count; i++) {
943         for (realm = servers[i].realms; *realm; realm++) {
944             if ((strlen(*realm) == 1 && **realm == '*') ||
945                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
946                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
947                 return servers + i;
948             }
949         }
950     }
951     return NULL;
952 }
953
954 int rqinqueue(struct server *to, struct client *from, uint8_t id) {
955     int i;
956     
957     pthread_mutex_lock(&to->newrq_mutex);
958     for (i = 0; i < MAX_REQUESTS; i++)
959         if (to->requests[i].buf && to->requests[i].origid == id && to->requests[i].from == from)
960             break;
961     pthread_mutex_unlock(&to->newrq_mutex);
962     
963     return i < MAX_REQUESTS;
964 }
965         
966 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
967     uint8_t code, id, *auth, *attr, attrvallen;
968     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
969     int i;
970     uint16_t len;
971     int left;
972     struct server *to;
973     unsigned char newauth[16];
974     
975     code = *(uint8_t *)buf;
976     id = *(uint8_t *)(buf + 1);
977     len = RADLEN(buf);
978     auth = (uint8_t *)(buf + 4);
979
980     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
981     
982     if (code != RAD_Access_Request) {
983         printf("radsrv: server currently accepts only access-requests, ignoring\n");
984         return NULL;
985     }
986
987     left = len - 20;
988     attr = buf + 20;
989     
990     while (left > 1) {
991         left -= attr[RAD_Attr_Length];
992         if (left < 0) {
993             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
994             return NULL;
995         }
996         switch (attr[RAD_Attr_Type]) {
997         case RAD_Attr_User_Name:
998             usernameattr = attr;
999             break;
1000         case RAD_Attr_User_Password:
1001             userpwdattr = attr;
1002             break;
1003         case RAD_Attr_Tunnel_Password:
1004             tunnelpwdattr = attr;
1005             break;
1006         case RAD_Attr_Message_Authenticator:
1007             messageauthattr = attr;
1008             break;
1009         }
1010         attr += attr[RAD_Attr_Length];
1011     }
1012     if (left)
1013         printf("radsrv: malformed packet? remaining byte after last attribute\n");
1014
1015     if (usernameattr) {
1016         printf("radsrv: Username: ");
1017         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
1018             printf("%c", usernameattr[RAD_Attr_Value + i]);
1019         printf("\n");
1020     }
1021
1022     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
1023     if (!to) {
1024         printf("radsrv: ignoring request, don't know where to send it\n");
1025         return NULL;
1026     }
1027
1028     if (rqinqueue(to, from, id)) {
1029         printf("radsrv: ignoring request from host %s with id %d, already got one\n", from->peer.host, id);
1030         return NULL;
1031     }
1032     
1033     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
1034                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
1035         printf("radsrv: message authentication failed\n");
1036         return NULL;
1037     }
1038
1039     if (!RAND_bytes(newauth, 16)) {
1040         printf("radsrv: failed to generate random auth\n");
1041         return NULL;
1042     }
1043
1044     printauth("auth", auth);
1045     printauth("newauth", newauth);
1046     
1047     if (userpwdattr) {
1048         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
1049         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
1050         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1051             printf("radsrv: invalid user password length\n");
1052             return NULL;
1053         }
1054         
1055         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1056             printf("radsrv: cannot decrypt password\n");
1057             return NULL;
1058         }
1059         printf("radsrv: password: ");
1060         for (i = 0; i < attrvallen; i++)
1061             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
1062         printf("\n");
1063         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1064             printf("radsrv: cannot encrypt password\n");
1065             return NULL;
1066         }
1067     }
1068
1069     if (tunnelpwdattr) {
1070         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
1071         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
1072         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1073             printf("radsrv: invalid user password length\n");
1074             return NULL;
1075         }
1076         
1077         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1078             printf("radsrv: cannot decrypt password\n");
1079             return NULL;
1080         }
1081         printf("radsrv: password: ");
1082         for (i = 0; i < attrvallen; i++)
1083             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
1084         printf("\n");
1085         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1086             printf("radsrv: cannot encrypt password\n");
1087             return NULL;
1088         }
1089     }
1090
1091     rq->buf = buf;
1092     rq->from = from;
1093     rq->origid = id;
1094     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
1095     memcpy(rq->origauth, auth, 16);
1096     memcpy(auth, newauth, 16);
1097     printauth("rq->origauth", rq->origauth);
1098     printauth("auth", auth);
1099     return to;
1100 }
1101
1102 void *clientrd(void *arg) {
1103     struct server *server = (struct server *)arg;
1104     struct client *from;
1105     int i, left, subleft;
1106     unsigned char *buf, *messageauthattr, *subattr, *attr;
1107     struct sockaddr_storage fromsa;
1108     struct timeval lastconnecttry;
1109     char tmp[255];
1110     
1111     for (;;) {
1112     getnext:
1113         lastconnecttry = server->lastconnecttry;
1114         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
1115         if (!buf && server->peer.type == 'T') {
1116             tlsconnect(server, &lastconnecttry, "clientrd");
1117             continue;
1118         }
1119     
1120         server->connectionok = 1;
1121
1122         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
1123             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
1124             continue;
1125         }
1126         
1127         i = buf[1]; /* i is the id */
1128
1129         pthread_mutex_lock(&server->newrq_mutex);
1130         if (!server->requests[i].buf || !server->requests[i].tries) {
1131             pthread_mutex_unlock(&server->newrq_mutex);
1132             printf("clientrd: no matching request sent with this id, ignoring\n");
1133             continue;
1134         }
1135
1136         if (server->requests[i].received) {
1137             pthread_mutex_unlock(&server->newrq_mutex);
1138             printf("clientrd: already received, ignoring\n");
1139             continue;
1140         }
1141         
1142         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1143             pthread_mutex_unlock(&server->newrq_mutex);
1144             printf("clientrd: invalid auth, ignoring\n");
1145             continue;
1146         }
1147         
1148         from = server->requests[i].from;
1149
1150
1151         /* messageauthattr present? */
1152         messageauthattr = NULL;
1153         left = RADLEN(buf) - 20;
1154         attr = buf + 20;
1155         while (left > 1) {
1156             left -= attr[RAD_Attr_Length];
1157             if (left < 0) {
1158                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1159                 goto getnext;
1160             }
1161             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1162                 if (attr[RAD_Attr_Length] != 18) {
1163                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1164                     goto getnext;
1165                 }
1166                 memcpy(tmp, buf + 4, 16);
1167                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1168                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1169                     printf("clientrd: message authentication failed\n");
1170                     goto getnext;
1171                 }
1172                 memcpy(buf + 4, tmp, 16);
1173                 printf("clientrd: message auth ok\n");
1174                 messageauthattr = attr;
1175                 break;
1176             }
1177             attr += attr[RAD_Attr_Length];
1178         }
1179
1180         /* handle MS MPPE */
1181         left = RADLEN(buf) - 20;
1182         attr = buf + 20;
1183         while (left > 1) {
1184             left -= attr[RAD_Attr_Length];
1185             if (left < 0) {
1186                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1187                 goto getnext;
1188             }
1189             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1190                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1191                 subleft = attr[RAD_Attr_Length] - 6;
1192                 subattr = attr + 6;
1193                 while (subleft > 1) {
1194                     subleft -= subattr[RAD_Attr_Length];
1195                     if (subleft < 0)
1196                         break;
1197                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1198                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1199                         continue;
1200                     printf("clientrd: Got MS MPPE\n");
1201                     if (subattr[RAD_Attr_Length] < 20)
1202                         continue;
1203
1204                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1205                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1206                         printf("clientrd: failed to decrypt msppe key\n");
1207                         continue;
1208                     }
1209
1210                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1211                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1212                         printf("clientrd: failed to encrypt msppe key\n");
1213                         continue;
1214                     }
1215                 }
1216                 if (subleft < 0) {
1217                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1218                     goto getnext;
1219                 }
1220             }
1221             attr += attr[RAD_Attr_Length];
1222         }
1223
1224         /* once we set received = 1, requests[i] may be reused */
1225         buf[1] = (char)server->requests[i].origid;
1226         memcpy(buf + 4, server->requests[i].origauth, 16);
1227         printauth("origauth/buf+4", buf + 4);
1228         if (messageauthattr) {
1229             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1230                 continue;
1231             printf("clientrd: computed messageauthattr\n");
1232         }
1233
1234         if (from->peer.type == 'U')
1235             fromsa = server->requests[i].fromsa;
1236         server->requests[i].received = 1;
1237         pthread_mutex_unlock(&server->newrq_mutex);
1238
1239         if (!radsign(buf, from->peer.secret)) {
1240             printf("clientrd: failed to sign message\n");
1241             continue;
1242         }
1243         printauth("signedorigauth/buf+4", buf + 4);             
1244         printf("clientrd: giving packet back to where it came from\n");
1245         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1246     }
1247 }
1248
1249 void *clientwr(void *arg) {
1250     struct server *server = (struct server *)arg;
1251     struct request *rq;
1252     pthread_t clientrdth;
1253     int i;
1254     struct timeval now;
1255     struct timespec timeout;
1256
1257     memset(&timeout, 0, sizeof(struct timespec));
1258     
1259     if (server->peer.type == 'U') {
1260         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1261             printf("clientwr: connecttoserver failed\n");
1262             exit(1);
1263         }
1264     } else
1265         tlsconnect(server, NULL, "new client");
1266     
1267     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1268         errx("clientwr: pthread_create failed");
1269
1270     for (;;) {
1271         pthread_mutex_lock(&server->newrq_mutex);
1272         if (!server->newrq) {
1273             if (timeout.tv_nsec) {
1274                 printf("clientwr: waiting up to %ld secs for new request\n", timeout.tv_nsec);
1275                 pthread_cond_timedwait(&server->newrq_cond, &server->newrq_mutex, &timeout);
1276                 timeout.tv_nsec = 0;
1277             } else {
1278                 printf("clientwr: waiting for new request\n");
1279                 pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1280             }
1281             if (server->newrq) {
1282                 printf("clientwr: got new request\n");
1283                 server->newrq = 0;
1284             }
1285         }
1286         pthread_mutex_unlock(&server->newrq_mutex);
1287         
1288         for (i = 0; i < MAX_REQUESTS; i++) {
1289             pthread_mutex_lock(&server->newrq_mutex);
1290             while (!server->requests[i].buf && i < MAX_REQUESTS)
1291                 i++;
1292             if (i == MAX_REQUESTS) {
1293                 pthread_mutex_unlock(&server->newrq_mutex);
1294                 break;
1295             }
1296             rq = server->requests + i;
1297
1298             if (rq->received) {
1299                 printf("clientwr: removing received packet from queue\n");
1300                 free(rq->buf);
1301                 /* setting this to NULL means that it can be reused */
1302                 rq->buf = NULL;
1303                 pthread_mutex_unlock(&server->newrq_mutex);
1304                 continue;
1305             }
1306             
1307             gettimeofday(&now, NULL);
1308             if (now.tv_sec <= rq->expiry.tv_sec) {
1309                 if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1310                     timeout.tv_sec = rq->expiry.tv_sec;
1311                 pthread_mutex_unlock(&server->newrq_mutex);
1312                 continue;
1313             }
1314
1315             if (rq->tries == (server->peer.type == 'T' ? 1 : REQUEST_RETRIES)) {
1316                 printf("clientwr: removing expired packet from queue\n");
1317                 free(rq->buf);
1318                 /* setting this to NULL means that it can be reused */
1319                 rq->buf = NULL;
1320                 pthread_mutex_unlock(&server->newrq_mutex);
1321                 continue;
1322             }
1323             pthread_mutex_unlock(&server->newrq_mutex);
1324
1325             rq->expiry.tv_sec = now.tv_sec +
1326                 (server->peer.type == 'T' ? REQUEST_EXPIRY : REQUEST_EXPIRY / REQUEST_RETRIES);
1327             if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1328                 timeout.tv_sec = rq->expiry.tv_sec;
1329             rq->tries++;
1330             clientradput(server, server->requests[i].buf);
1331             usleep(200000);
1332         }
1333     }
1334     /* should do more work to maintain TLS connections, keepalives etc */
1335 }
1336
1337 void *udpserverwr(void *arg) {
1338     struct replyq *replyq = &udp_server_replyq;
1339     struct reply *reply = replyq->replies;
1340     
1341     pthread_mutex_lock(&replyq->count_mutex);
1342     for (;;) {
1343         while (!replyq->count) {
1344             printf("udp server writer, waiting for signal\n");
1345             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1346             printf("udp server writer, got signal\n");
1347         }
1348         pthread_mutex_unlock(&replyq->count_mutex);
1349         
1350         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1351                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1352             err("sendudp: send failed");
1353         free(reply->buf);
1354         
1355         pthread_mutex_lock(&replyq->count_mutex);
1356         replyq->count--;
1357         memmove(replyq->replies, replyq->replies + 1,
1358                 replyq->count * sizeof(struct reply));
1359     }
1360 }
1361
1362 void *udpserverrd(void *arg) {
1363     struct request rq;
1364     unsigned char *buf;
1365     struct server *to;
1366     struct client *fr;
1367     pthread_t udpserverwrth;
1368     
1369     if ((udp_server_sock = bindport(SOCK_DGRAM, options.udpserverport)) < 0) {
1370         printf("udpserverrd: socket/bind failed\n");
1371         exit(1);
1372     }
1373     printf("udpserverrd: listening on UDP port %s\n", options.udpserverport);
1374
1375     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1376         errx("pthread_create failed");
1377     
1378     for (;;) {
1379         fr = NULL;
1380         memset(&rq, 0, sizeof(struct request));
1381         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1382         to = radsrv(&rq, buf, fr);
1383         if (!to) {
1384             printf("udpserverrd: ignoring request, no place to send it\n");
1385             continue;
1386         }
1387         sendrq(to, fr, &rq);
1388     }
1389 }
1390
1391 void *tlsserverwr(void *arg) {
1392     int cnt;
1393     unsigned long error;
1394     struct client *client = (struct client *)arg;
1395     struct replyq *replyq;
1396     
1397     printf("tlsserverwr starting for %s\n", client->peer.host);
1398     replyq = client->replyq;
1399     pthread_mutex_lock(&replyq->count_mutex);
1400     for (;;) {
1401         while (!replyq->count) {
1402             if (client->peer.ssl) {         
1403                 printf("tls server writer, waiting for signal\n");
1404                 pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1405                 printf("tls server writer, got signal\n");
1406             }
1407             if (!client->peer.ssl) {
1408                 //ssl might have changed while waiting
1409                 pthread_mutex_unlock(&replyq->count_mutex);
1410                 printf("tlsserverwr: exiting as requested\n");
1411                 pthread_exit(NULL);
1412             }
1413         }
1414         pthread_mutex_unlock(&replyq->count_mutex);
1415         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1416         if (cnt > 0)
1417             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1418                    cnt, RADLEN(replyq->replies->buf));
1419         else
1420             while ((error = ERR_get_error()))
1421                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1422         free(replyq->replies->buf);
1423
1424         pthread_mutex_lock(&replyq->count_mutex);
1425         replyq->count--;
1426         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1427     }
1428 }
1429
1430 void *tlsserverrd(void *arg) {
1431     struct request rq;
1432     char unsigned *buf;
1433     unsigned long error;
1434     struct server *to;
1435     int s;
1436     struct client *client = (struct client *)arg;
1437     pthread_t tlsserverwrth;
1438     SSL *ssl;
1439     
1440     printf("tlsserverrd starting for %s\n", client->peer.host);
1441     ssl = client->peer.ssl;
1442
1443     if (SSL_accept(ssl) <= 0) {
1444         while ((error = ERR_get_error()))
1445             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1446         errx("accept failed, child exiting");
1447     }
1448
1449     if (tlsverifycert(&client->peer)) {
1450         if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1451             errx("pthread_create failed");
1452     
1453         for (;;) {
1454             buf = radtlsget(client->peer.ssl);
1455             if (!buf)
1456                 break;
1457             printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1458             memset(&rq, 0, sizeof(struct request));
1459             to = radsrv(&rq, buf, client);
1460             if (!to) {
1461                 printf("ignoring request, no place to send it\n");
1462                 continue;
1463             }
1464             sendrq(to, client, &rq);
1465         }
1466         printf("tlsserverrd: connection lost\n");
1467         // stop writer by setting peer.ssl to NULL and give signal in case waiting for data
1468         client->peer.ssl = NULL;
1469         pthread_mutex_lock(&client->replyq->count_mutex);
1470         pthread_cond_signal(&client->replyq->count_cond);
1471         pthread_mutex_unlock(&client->replyq->count_mutex);
1472         printf("tlsserverrd: waiting for writer to end\n");
1473         pthread_join(tlsserverwrth, NULL);
1474     }
1475     s = SSL_get_fd(ssl);
1476     SSL_free(ssl);
1477     close(s);
1478     printf("tlsserverrd thread for %s exiting\n", client->peer.host);
1479     client->peer.ssl = NULL;
1480     pthread_exit(NULL);
1481 }
1482
1483 int tlslistener() {
1484     pthread_t tlsserverth;
1485     int s, snew;
1486     struct sockaddr_storage from;
1487     size_t fromlen = sizeof(from);
1488     struct client *client;
1489
1490     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1491         printf("tlslistener: socket/bind failed\n");
1492         exit(1);
1493     }
1494     
1495     listen(s, 0);
1496     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1497
1498     for (;;) {
1499         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1500         if (snew < 0)
1501             errx("accept failed");
1502         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1503
1504         client = find_client('T', (struct sockaddr *)&from, NULL);
1505         if (!client) {
1506             printf("ignoring request, not a known TLS client\n");
1507             shutdown(snew, SHUT_RDWR);
1508             close(snew);
1509             continue;
1510         }
1511
1512         if (client->peer.ssl) {
1513             printf("Ignoring incoming connection, already have one from this client\n");
1514             shutdown(snew, SHUT_RDWR);
1515             close(snew);
1516             continue;
1517         }
1518         client->peer.ssl = SSL_new(ssl_ctx);
1519         SSL_set_fd(client->peer.ssl, snew);
1520         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1521             errx("pthread_create failed");
1522         pthread_detach(tlsserverth);
1523     }
1524     return 0;
1525 }
1526
1527 char *parsehostport(char *s, struct peer *peer) {
1528     char *p, *field;
1529     int ipv6 = 0;
1530
1531     p = s;
1532     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1533     if (*p == '[') {
1534         p++;
1535         field = p;
1536         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1537         if (*p != ']') {
1538             printf("no ] matching initial [\n");
1539             exit(1);
1540         }
1541         ipv6 = 1;
1542     } else {
1543         field = p;
1544         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1545     }
1546     if (field == p) {
1547         printf("missing host/address\n");
1548         exit(1);
1549     }
1550     peer->host = stringcopy(field, p - field);
1551     if (ipv6) {
1552         p++;
1553         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1554             printf("unexpected character after ]\n");
1555             exit(1);
1556         }
1557     }
1558     if (*p == ':') {
1559             /* port number or service name is specified */;
1560             field = p++;
1561             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1562             if (field == p) {
1563                 printf("syntax error, : but no following port\n");
1564                 exit(1);
1565             }
1566             peer->port = stringcopy(field, p - field);
1567     } else
1568         peer->port = stringcopy(peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT, 0);
1569     return p;
1570 }
1571
1572 // * is default, else longest match ... ";" used for separator
1573 char *parserealmlist(char *s, struct server *server) {
1574     char *p;
1575     int i, n, l;
1576
1577     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1578         if (*p == ';')
1579             n++;
1580     l = p - s;
1581     if (!l) {
1582         printf("realm list must be specified\n");
1583         exit(1);
1584     }
1585     server->realmdata = stringcopy(s, l);
1586     server->realms = malloc((1+n) * sizeof(char *));
1587     if (!server->realms)
1588         errx("malloc failed");
1589     server->realms[0] = server->realmdata;
1590     for (n = 1, i = 0; i < l; i++)
1591         if (server->realmdata[i] == ';') {
1592             server->realmdata[i] = '\0';
1593             server->realms[n++] = server->realmdata + i + 1;
1594         }       
1595     server->realms[n] = NULL;
1596     return p;
1597 }
1598
1599 /* exactly one argument must be non-NULL */
1600 void getconfig(const char *serverfile, const char *clientfile) {
1601     FILE *f;
1602     char line[1024];
1603     char *p, *field, **r;
1604     const char *file;
1605     struct client *client;
1606     struct server *server;
1607     struct peer *peer;
1608     int i, count, *ucount, *tcount;
1609  
1610     file = serverfile ? serverfile : clientfile;
1611     f = fopen(file, "r");
1612     if (!f)
1613         errx("getconfig failed to open %s for reading", file);
1614     printf("opening file %s for reading\n", file);
1615     if (serverfile) {
1616         ucount = &server_udp_count;
1617         tcount = &server_tls_count;
1618     } else {
1619         ucount = &client_udp_count;
1620         tcount = &client_tls_count;
1621     }
1622     while (fgets(line, 1024, f)) {
1623         for (p = line; *p == ' ' || *p == '\t'; p++);
1624         switch (*p) {
1625         case '#':
1626         case '\n':
1627             break;
1628         case 'T':
1629             (*tcount)++;
1630             break;
1631         case 'U':
1632             (*ucount)++;
1633             break;
1634         default:
1635             printf("type must be U or T, got %c\n", *p);
1636             exit(1);
1637         }
1638     }
1639
1640     if (serverfile) {
1641         count = server_count = server_udp_count + server_tls_count;
1642         servers = calloc(count, sizeof(struct server));
1643         if (!servers)
1644             errx("malloc failed");
1645     } else {
1646         count = client_count = client_udp_count + client_tls_count;
1647         clients = calloc(count, sizeof(struct client));
1648         if (!clients)
1649             errx("malloc failed");
1650     }
1651     
1652     if (client_udp_count) {
1653         udp_server_replyq.replies = malloc(client_udp_count * MAX_REQUESTS * sizeof(struct reply));
1654         if (!udp_server_replyq.replies)
1655             errx("malloc failed");
1656         udp_server_replyq.size = client_udp_count * MAX_REQUESTS;
1657         udp_server_replyq.count = 0;
1658         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1659         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1660     }    
1661     
1662     rewind(f);
1663     for (i = 0; i < count && fgets(line, 1024, f);) {
1664         if (serverfile) {
1665             server = &servers[i];
1666             peer = &server->peer;
1667         } else {
1668             client = &clients[i];
1669             peer = &client->peer;
1670         }
1671         for (p = line; *p == ' ' || *p == '\t'; p++);
1672         if (*p == '#' || *p == '\n')
1673             continue;
1674         peer->type = *p;        // we already know it must be U or T
1675         for (p++; *p == ' ' || *p == '\t'; p++);
1676         p = parsehostport(p, peer);
1677         for (; *p == ' ' || *p == '\t'; p++);
1678         if (serverfile) {
1679             p = parserealmlist(p, server);
1680             for (; *p == ' ' || *p == '\t'; p++);
1681         }
1682         field = p;
1683         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1684         if (field == p) {
1685             /* no secret set and end of line, line is complete if TLS */
1686             if (peer->type == 'U') {
1687                 printf("secret must be specified for UDP\n");
1688                 exit(1);
1689             }
1690             peer->secret = stringcopy(DEFAULT_TLS_SECRET, 0);
1691         } else {
1692             peer->secret = stringcopy(field, p - field);
1693             /* check that rest of line only white space */
1694             for (; *p == ' ' || *p == '\t'; p++);
1695             if (*p && *p != '\n') {
1696                 printf("max 4 fields per line, found a 5th\n");
1697                 exit(1);
1698             }
1699         }
1700
1701         if ((serverfile && !resolvepeer(&server->peer)) ||
1702             (clientfile && !resolvepeer(&client->peer))) {
1703             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1704             exit(1);
1705         }
1706
1707         if (serverfile) {
1708             pthread_mutex_init(&server->lock, NULL);
1709             server->sock = -1;
1710             server->requests = calloc(MAX_REQUESTS, sizeof(struct request));
1711             if (!server->requests)
1712                 errx("malloc failed");
1713             server->newrq = 0;
1714             pthread_mutex_init(&server->newrq_mutex, NULL);
1715             pthread_cond_init(&server->newrq_cond, NULL);
1716         } else {
1717             if (peer->type == 'U')
1718                 client->replyq = &udp_server_replyq;
1719             else {
1720                 client->replyq = malloc(sizeof(struct replyq));
1721                 if (!client->replyq)
1722                     errx("malloc failed");
1723                 client->replyq->replies = calloc(MAX_REQUESTS, sizeof(struct reply));
1724                 if (!client->replyq->replies)
1725                     errx("malloc failed");
1726                 client->replyq->size = MAX_REQUESTS;
1727                 client->replyq->count = 0;
1728                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1729                 pthread_cond_init(&client->replyq->count_cond, NULL);
1730             }
1731         }
1732         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1733         if (serverfile) {
1734             printf("    with realms:");
1735             for (r = server->realms; *r; r++)
1736                 printf(" %s", *r);
1737             printf("\n");
1738         }
1739         i++;
1740     }
1741     fclose(f);
1742 }
1743
1744 void getmainconfig(const char *configfile) {
1745     FILE *f;
1746     char line[1024];
1747     char *p, *opt, *endopt, *val, *endval;
1748     
1749     printf("opening file %s for reading\n", configfile);
1750     f = fopen(configfile, "r");
1751     if (!f)
1752         errx("getmainconfig failed to open %s for reading", configfile);
1753
1754     memset(&options, 0, sizeof(options));
1755
1756     while (fgets(line, 1024, f)) {
1757         for (p = line; *p == ' ' || *p == '\t'; p++);
1758         if (!*p || *p == '#' || *p == '\n')
1759             continue;
1760         opt = p++;
1761         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1762         endopt = p - 1;
1763         for (; *p == ' ' || *p == '\t'; p++);
1764         if (!*p || *p == '\n') {
1765             endopt[1] = '\0';
1766             printf("error in %s, option %s has no value\n", configfile, opt);
1767             exit(1);
1768         }
1769         val = p;
1770         for (; *p && *p != '\n'; p++)
1771             if (*p != ' ' && *p != '\t')
1772                 endval = p;
1773         endopt[1] = '\0';
1774         endval[1] = '\0';
1775         printf("getmainconfig: %s = %s\n", opt, val);
1776         
1777         if (!strcasecmp(opt, "TLSCACertificateFile")) {
1778             options.tlscacertificatefile = stringcopy(val, 0);
1779             continue;
1780         }
1781         if (!strcasecmp(opt, "TLSCACertificatePath")) {
1782             options.tlscacertificatepath = stringcopy(val, 0);
1783             continue;
1784         }
1785         if (!strcasecmp(opt, "TLSCertificateFile")) {
1786             options.tlscertificatefile = stringcopy(val, 0);
1787             continue;
1788         }
1789         if (!strcasecmp(opt, "TLSCertificateKeyFile")) {
1790             options.tlscertificatekeyfile = stringcopy(val, 0);
1791             continue;
1792         }
1793         if (!strcasecmp(opt, "UDPServerPort")) {
1794             options.udpserverport = stringcopy(val, 0);
1795             continue;
1796         }
1797
1798         printf("error in %s, unknown option %s\n", configfile, opt);
1799         exit(1);
1800     }
1801     fclose(f);
1802
1803     if (!options.udpserverport)
1804         options.udpserverport = stringcopy(DEFAULT_UDP_PORT, 0);
1805 }
1806
1807 #if 0
1808 void parseargs(int argc, char **argv) {
1809     int c;
1810
1811     while ((c = getopt(argc, argv, "p:")) != -1) {
1812         switch (c) {
1813         case 'p':
1814             udp_server_port = optarg;
1815             break;
1816         default:
1817             goto usage;
1818         }
1819     }
1820
1821     return;
1822
1823  usage:
1824     printf("radsecproxy [ -p UDP-port ]\n");
1825     exit(1);
1826 }
1827 #endif
1828
1829 int main(int argc, char **argv) {
1830     pthread_t udpserverth;
1831     //    pthread_attr_t joinable;
1832     int i;
1833     
1834     //    parseargs(argc, argv);
1835     getmainconfig("radsecproxy.conf");
1836     getconfig("servers.conf", NULL);
1837     getconfig(NULL, "clients.conf");
1838
1839     //    pthread_attr_init(&joinable);
1840     //    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1841    
1842     if (client_udp_count)
1843         if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
1844             errx("pthread_create failed");
1845
1846     if (client_tls_count || server_tls_count)
1847         ssl_ctx = ssl_init();
1848     
1849     for (i = 0; i < server_count; i++)
1850         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1851             errx("pthread_create failed");
1852
1853     if (client_tls_count)
1854         return tlslistener();
1855
1856     /* just hang around doing nothing, anything to do here? */
1857     for (;;)
1858         sleep(1000);
1859 }