now using same ctx for client and server
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* TODO:
10  * accounting
11  * radius keep alives (server status)
12  * tls certificate validation, see below urls
13  * clean tls shutdown, see http://www.linuxjournal.com/article/4822
14  *     and http://www.linuxjournal.com/article/5487
15  *     SSL_shutdown() and shutdown()
16  *     If shutdown() we may not need REUSEADDR
17  * when tls client goes away, ensure that all related threads and state
18  *          are removed
19  * setsockopt(keepalive...), check if openssl has some keepalive feature
20 */
21
22 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
23  *              rd is responsible for init and launching wr
24  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
25  *          each tlsserverrd launches tlsserverwr
26  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
27  *          for init and launching rd
28  *
29  * serverrd will receive a request, processes it and puts it in the requestq of
30  *          the appropriate clientwr
31  * clientwr monitors its requestq and sends requests
32  * clientrd looks for responses, processes them and puts them in the replyq of
33  *          the peer the request came from
34  * serverwr monitors its reply and sends replies
35  *
36  * In addition to the main thread, we have:
37  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
38  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
39  * For each TLS peer connecting to us there will be 2 more TLS threads
40  *       This is only for connected peers
41  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
42  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
43 */
44
45 #include <netdb.h>
46 #include <unistd.h>
47 #include <sys/time.h>
48 #include <pthread.h>
49 #include <openssl/ssl.h>
50 #include <openssl/rand.h>
51 #include <openssl/err.h>
52 #include <openssl/md5.h>
53 #include <openssl/hmac.h>
54 #include "radsecproxy.h"
55
56 static struct options options;
57 static struct client *clients;
58 static struct server *servers;
59
60 static int client_udp_count = 0;
61 static int client_tls_count = 0;
62 static int client_count = 0;
63 static int server_udp_count = 0;
64 static int server_tls_count = 0;
65 static int server_count = 0;
66
67 static struct replyq udp_server_replyq;
68 static int udp_server_sock = -1;
69 static pthread_mutex_t *ssl_locks;
70 static long *ssl_lock_count;
71 static SSL_CTX *ssl_ctx = NULL;
72 extern int optind;
73 extern char *optarg;
74
75 /* callbacks for making OpenSSL thread safe */
76 unsigned long ssl_thread_id() {
77         return (unsigned long)pthread_self();
78 };
79
80 void ssl_locking_callback(int mode, int type, const char *file, int line) {
81     if (mode & CRYPTO_LOCK) {
82         pthread_mutex_lock(&ssl_locks[type]);
83         ssl_lock_count[type]++;
84     } else
85         pthread_mutex_unlock(&ssl_locks[type]);
86 }
87
88 static int verify_cb(int ok, X509_STORE_CTX *ctx) {
89   char buf[256];
90   X509 *err_cert;
91   int err, depth;
92
93   err_cert = X509_STORE_CTX_get_current_cert(ctx);
94   err = X509_STORE_CTX_get_error(ctx);
95   depth = X509_STORE_CTX_get_error_depth(ctx);
96
97   if (depth > MAX_CERT_DEPTH) {
98       ok = 0;
99       err = X509_V_ERR_CERT_CHAIN_TOO_LONG;
100       X509_STORE_CTX_set_error(ctx, err);
101   }
102
103   if (!ok) {
104       X509_NAME_oneline(X509_get_subject_name(err_cert), buf, 256);
105       printf("verify error: num=%d:%s:depth=%d:%s\n", err, X509_verify_cert_error_string(err), depth, buf);
106
107       switch (err) {
108       case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
109           X509_NAME_oneline(X509_get_issuer_name(ctx->current_cert), buf, 256);
110           printf("issuer=%s\n", buf);
111           break;
112       case X509_V_ERR_CERT_NOT_YET_VALID:
113       case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
114           printf("Certificate not yet valid\n");
115           break;
116       case X509_V_ERR_CERT_HAS_EXPIRED:
117           printf("Certificate has expired\n");
118           break;
119       case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
120           printf("Certificate no longer valid (after notAfter)\n");
121           break;
122       }
123   }
124   //  printf("certificate verify returns %d\n", ok);
125   return ok;
126 }
127
128 SSL_CTX *ssl_init() {
129     SSL_CTX *ctx;
130     int i;
131     unsigned long error;
132     
133     if (!options.tlscertificatefile || !options.tlscertificatekeyfile) {
134         printf("TLSCertificateFile and TLSCertificateKeyFile must be specified for TLS\n");
135         exit(1);
136     }
137     if (!options.tlscacertificatefile && !options.tlscacertificatepath) {
138         printf("CA Certificate file/path need to be configured\n");
139         exit(1);
140     }
141
142     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
143     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
144     for (i = 0; i < CRYPTO_num_locks(); i++) {
145         ssl_lock_count[i] = 0;
146         pthread_mutex_init(&ssl_locks[i], NULL);
147     }
148     CRYPTO_set_id_callback(ssl_thread_id);
149     CRYPTO_set_locking_callback(ssl_locking_callback);
150
151     SSL_load_error_strings();
152     SSL_library_init();
153
154     while (!RAND_status()) {
155         time_t t = time(NULL);
156         pid_t pid = getpid();
157         RAND_seed((unsigned char *)&t, sizeof(time_t));
158         RAND_seed((unsigned char *)&pid, sizeof(pid));
159     }
160
161     ctx = SSL_CTX_new(TLSv1_method());
162     if (!SSL_CTX_use_certificate_chain_file(ctx, options.tlscertificatefile) ||
163         !SSL_CTX_use_PrivateKey_file(ctx, options.tlscertificatekeyfile, SSL_FILETYPE_PEM) ||
164         !SSL_CTX_check_private_key(ctx))
165         goto errexit;
166     if (!SSL_CTX_load_verify_locations(ctx, options.tlscacertificatefile, options.tlscacertificatepath))
167         goto errexit;
168     SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
169     SSL_CTX_set_verify_depth(ctx, MAX_CERT_DEPTH + 1);
170     return ctx;
171     
172  errexit:
173     while ((error = ERR_get_error()))
174         err("SSL: %s", ERR_error_string(error, NULL));
175     exit(1);
176 }    
177
178 void printauth(char *s, unsigned char *t) {
179     int i;
180     printf("%s:", s);
181     for (i = 0; i < 16; i++)
182             printf("%02x ", t[i]);
183     printf("\n");
184 }
185
186 int resolvepeer(struct peer *peer) {
187     struct addrinfo hints, *addrinfo;
188     
189     memset(&hints, 0, sizeof(hints));
190     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
191     hints.ai_family = AF_UNSPEC;
192     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
193         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
194         return 0;
195     }
196
197     if (peer->addrinfo)
198         freeaddrinfo(peer->addrinfo);
199     peer->addrinfo = addrinfo;
200     return 1;
201 }         
202
203 int connecttoserver(struct addrinfo *addrinfo) {
204     int s;
205     struct addrinfo *res;
206     
207     for (res = addrinfo; res; res = res->ai_next) {
208         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
209         if (s < 0) {
210             err("connecttoserver: socket failed");
211             continue;
212         }
213         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
214             break;
215         err("connecttoserver: connect failed");
216         close(s);
217         s = -1;
218     }
219     return s;
220 }         
221
222 /* returns the client with matching address, or NULL */
223 /* if client argument is not NULL, we only check that one client */
224 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
225     struct sockaddr_in6 *sa6;
226     struct in_addr *a4 = NULL;
227     struct client *c;
228     int i;
229     struct addrinfo *res;
230
231     if (addr->sa_family == AF_INET6) {
232         sa6 = (struct sockaddr_in6 *)addr;
233         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
234             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
235     } else
236         a4 = &((struct sockaddr_in *)addr)->sin_addr;
237
238     c = (client ? client : clients);
239     for (i = 0; i < client_count; i++) {
240         if (c->peer.type == type)
241             for (res = c->peer.addrinfo; res; res = res->ai_next)
242                 if ((a4 && res->ai_family == AF_INET &&
243                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
244                     (res->ai_family == AF_INET6 &&
245                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
246                     return c;
247         if (client)
248             break;
249         c++;
250     }
251     return NULL;
252 }
253
254 /* returns the server with matching address, or NULL */
255 /* if server argument is not NULL, we only check that one server */
256 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
257     struct sockaddr_in6 *sa6;
258     struct in_addr *a4 = NULL;
259     struct server *s;
260     int i;
261     struct addrinfo *res;
262
263     if (addr->sa_family == AF_INET6) {
264         sa6 = (struct sockaddr_in6 *)addr;
265         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
266             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
267     } else
268         a4 = &((struct sockaddr_in *)addr)->sin_addr;
269
270     s = (server ? server : servers);
271     for (i = 0; i < server_count; i++) {
272         if (s->peer.type == type)
273             for (res = s->peer.addrinfo; res; res = res->ai_next)
274                 if ((a4 && res->ai_family == AF_INET &&
275                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
276                     (res->ai_family == AF_INET6 &&
277                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
278                     return s;
279         if (server)
280             break;
281         s++;
282     }
283     return NULL;
284 }
285
286 /* exactly one of client and server must be non-NULL */
287 /* if *peer == NULL we return who we received from, else require it to be from peer */
288 /* return from in sa if not NULL */
289 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
290     int cnt, len;
291     void *f;
292     unsigned char buf[65536], *rad;
293     struct sockaddr_storage from;
294     socklen_t fromlen = sizeof(from);
295
296     for (;;) {
297         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
298         if (cnt == -1) {
299             err("radudpget: recv failed");
300             continue;
301         }
302         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
303
304         if (cnt < 20) {
305             printf("radudpget: packet too small\n");
306             continue;
307         }
308     
309         len = RADLEN(buf);
310
311         if (cnt < len) {
312             printf("radudpget: packet smaller than length field in radius header\n");
313             continue;
314         }
315         if (cnt > len)
316             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
317
318         f = (client
319              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
320              : (void *)find_server('U', (struct sockaddr *)&from, *server));
321         if (!f) {
322             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
323             continue;
324         }
325
326         rad = malloc(len);
327         if (rad)
328             break;
329         err("radudpget: malloc failed");
330     }
331     memcpy(rad, buf, len);
332     if (client)
333         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
334     else
335         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
336     if (sa)
337         *sa = from;
338     return rad;
339 }
340
341 int tlsverifycert(struct peer *peer) {
342     int i, l, loc;
343     X509 *cert;
344     X509_NAME *nm;
345     X509_NAME_ENTRY *e;
346     unsigned char *v;
347     unsigned long error;
348
349 #if 1
350     if (SSL_get_verify_result(peer->ssl) != X509_V_OK) {
351         printf("tlsverifycert: basic validation failed\n");
352         while ((error = ERR_get_error()))
353             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
354         return 0;
355     }
356 #endif    
357     cert = SSL_get_peer_certificate(peer->ssl);
358     if (!cert) {
359         printf("tlsverifycert: failed to obtain certificate\n");
360         return 0;
361     }
362     nm = X509_get_subject_name(cert);
363     loc = -1;
364     for (;;) {
365         loc = X509_NAME_get_index_by_NID(nm, NID_commonName, loc);
366         if (loc == -1)
367             break;
368         e = X509_NAME_get_entry(nm, loc);
369         l = ASN1_STRING_to_UTF8(&v, X509_NAME_ENTRY_get_data(e));
370         if (l < 0)
371             continue;
372         printf("cn: ");
373         for (i = 0; i < l; i++)
374             printf("%c", v[i]);
375         printf("\n");
376         if (l == strlen(peer->host) && !strncasecmp(peer->host, v, l)) {
377             printf("tlsverifycert: Found cn matching host %s, All OK\n", peer->host);
378             return 1;
379         }
380         printf("tlsverifycert: cn not matching host %s\n", peer->host);
381     }
382     X509_free(cert);
383     return 0;
384 }
385
386 void tlsconnect(struct server *server, struct timeval *when, char *text) {
387     struct timeval now;
388     time_t elapsed;
389
390     printf("tlsconnect called from %s\n", text);
391     pthread_mutex_lock(&server->lock);
392     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
393         /* already reconnected, nothing to do */
394         printf("tlsconnect(%s): seems already reconnected\n", text);
395         pthread_mutex_unlock(&server->lock);
396         return;
397     }
398
399     printf("tlsconnect %s\n", text);
400
401     for (;;) {
402         gettimeofday(&now, NULL);
403         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
404         if (server->connectionok) {
405             server->connectionok = 0;
406             sleep(10);
407         } else if (elapsed < 5)
408             sleep(10);
409         else if (elapsed < 600)
410             sleep(elapsed * 2);
411         else if (elapsed < 10000)
412                 sleep(900);
413         else
414             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
415         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
416         if (server->sock >= 0)
417             close(server->sock);
418         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
419             continue;
420         SSL_free(server->peer.ssl);
421         server->peer.ssl = SSL_new(ssl_ctx);
422         SSL_set_fd(server->peer.ssl, server->sock);
423         if (SSL_connect(server->peer.ssl) > 0 && tlsverifycert(&server->peer))
424             break;
425     }
426     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
427     gettimeofday(&server->lastconnecttry, NULL);
428     pthread_mutex_unlock(&server->lock);
429 }
430
431 unsigned char *radtlsget(SSL *ssl) {
432     int cnt, total, len;
433     unsigned char buf[4], *rad;
434
435     for (;;) {
436         for (total = 0; total < 4; total += cnt) {
437             cnt = SSL_read(ssl, buf + total, 4 - total);
438             if (cnt <= 0) {
439                 printf("radtlsget: connection lost\n");
440                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
441                     //remote end sent close_notify, send one back
442                     SSL_shutdown(ssl);
443                 }
444                 return NULL;
445             }
446         }
447
448         len = RADLEN(buf);
449         rad = malloc(len);
450         if (!rad) {
451             err("radtlsget: malloc failed");
452             continue;
453         }
454         memcpy(rad, buf, 4);
455
456         for (; total < len; total += cnt) {
457             cnt = SSL_read(ssl, rad + total, len - total);
458             if (cnt <= 0) {
459                 printf("radtlsget: connection lost\n");
460                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
461                     //remote end sent close_notify, send one back
462                     SSL_shutdown(ssl);
463                 }
464                 free(rad);
465                 return NULL;
466             }
467         }
468     
469         if (total >= 20)
470             break;
471         
472         free(rad);
473         printf("radtlsget: packet smaller than minimum radius size\n");
474     }
475     
476     printf("radtlsget: got %d bytes\n", total);
477     return rad;
478 }
479
480 int clientradput(struct server *server, unsigned char *rad) {
481     int cnt;
482     size_t len;
483     unsigned long error;
484     struct timeval lastconnecttry;
485     
486     len = RADLEN(rad);
487     if (server->peer.type == 'U') {
488         if (send(server->sock, rad, len, 0) >= 0) {
489             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
490             return 1;
491         }
492         err("clientradput: send failed");
493         return 0;
494     }
495
496     lastconnecttry = server->lastconnecttry;
497     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
498         while ((error = ERR_get_error()))
499             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
500         tlsconnect(server, &lastconnecttry, "clientradput");
501         lastconnecttry = server->lastconnecttry;
502     }
503
504     server->connectionok = 1;
505     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
506            cnt, len, server->peer.host);
507     return 1;
508 }
509
510 int radsign(unsigned char *rad, unsigned char *sec) {
511     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
512     static unsigned char first = 1;
513     static EVP_MD_CTX mdctx;
514     unsigned int md_len;
515     int result;
516     
517     pthread_mutex_lock(&lock);
518     if (first) {
519         EVP_MD_CTX_init(&mdctx);
520         first = 0;
521     }
522
523     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
524         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
525         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
526         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
527         md_len == 16);
528     pthread_mutex_unlock(&lock);
529     return result;
530 }
531
532 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
533     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
534     static unsigned char first = 1;
535     static EVP_MD_CTX mdctx;
536     unsigned char hash[EVP_MAX_MD_SIZE];
537     unsigned int len;
538     int result;
539     
540     pthread_mutex_lock(&lock);
541     if (first) {
542         EVP_MD_CTX_init(&mdctx);
543         first = 0;
544     }
545
546     len = RADLEN(rad);
547     
548     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
549               EVP_DigestUpdate(&mdctx, rad, 4) &&
550               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
551               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
552               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
553               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
554               len == 16 &&
555               !memcmp(hash, rad + 4, 16));
556     pthread_mutex_unlock(&lock);
557     return result;
558 }
559               
560 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
561     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
562     static unsigned char first = 1;
563     static HMAC_CTX hmacctx;
564     unsigned int md_len;
565     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
566     
567     pthread_mutex_lock(&lock);
568     if (first) {
569         HMAC_CTX_init(&hmacctx);
570         first = 0;
571     }
572
573     memcpy(auth, authattr, 16);
574     memset(authattr, 0, 16);
575     md_len = 0;
576     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
577     HMAC_Update(&hmacctx, rad, RADLEN(rad));
578     HMAC_Final(&hmacctx, hash, &md_len);
579     memcpy(authattr, auth, 16);
580     if (md_len != 16) {
581         printf("message auth computation failed\n");
582         pthread_mutex_unlock(&lock);
583         return 0;
584     }
585
586     if (memcmp(auth, hash, 16)) {
587         printf("message authenticator, wrong value\n");
588         pthread_mutex_unlock(&lock);
589         return 0;
590     }   
591         
592     pthread_mutex_unlock(&lock);
593     return 1;
594 }
595
596 int createmessageauth(char *rad, char *authattrval, char *secret) {
597     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
598     static unsigned char first = 1;
599     static HMAC_CTX hmacctx;
600     unsigned int md_len;
601
602     if (!authattrval)
603         return 1;
604     
605     pthread_mutex_lock(&lock);
606     if (first) {
607         HMAC_CTX_init(&hmacctx);
608         first = 0;
609     }
610
611     memset(authattrval, 0, 16);
612     md_len = 0;
613     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
614     HMAC_Update(&hmacctx, rad, RADLEN(rad));
615     HMAC_Final(&hmacctx, authattrval, &md_len);
616     if (md_len != 16) {
617         printf("message auth computation failed\n");
618         pthread_mutex_unlock(&lock);
619         return 0;
620     }
621
622     pthread_mutex_unlock(&lock);
623     return 1;
624 }
625
626 void sendrq(struct server *to, struct client *from, struct request *rq) {
627     int i;
628     
629     pthread_mutex_lock(&to->newrq_mutex);
630     /* might simplify if only try nextid, might be ok */
631     for (i = to->nextid; i < MAX_REQUESTS; i++)
632         if (!to->requests[i].buf)
633             break;
634     if (i == MAX_REQUESTS) {
635         for (i = 0; i < to->nextid; i++)
636             if (!to->requests[i].buf)
637                 break;
638         if (i == to->nextid) {
639             printf("No room in queue, dropping request\n");
640             pthread_mutex_unlock(&to->newrq_mutex);
641             return;
642         }
643     }
644     
645     to->nextid = i + 1;
646     rq->buf[1] = (char)i;
647     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
648     
649     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
650         return;
651
652     to->requests[i] = *rq;
653
654     if (!to->newrq) {
655         to->newrq = 1;
656         printf("signalling client writer\n");
657         pthread_cond_signal(&to->newrq_cond);
658     }
659     pthread_mutex_unlock(&to->newrq_mutex);
660 }
661
662 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
663     struct replyq *replyq = to->replyq;
664     
665     pthread_mutex_lock(&replyq->count_mutex);
666     if (replyq->count == replyq->size) {
667         printf("No room in queue, dropping request\n");
668         pthread_mutex_unlock(&replyq->count_mutex);
669         return;
670     }
671
672     replyq->replies[replyq->count].buf = buf;
673     if (tosa)
674         replyq->replies[replyq->count].tosa = *tosa;
675     replyq->count++;
676
677     if (replyq->count == 1) {
678         printf("signalling client writer\n");
679         pthread_cond_signal(&replyq->count_cond);
680     }
681     pthread_mutex_unlock(&replyq->count_mutex);
682 }
683
684 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
685     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
686     static unsigned char first = 1;
687     static EVP_MD_CTX mdctx;
688     unsigned char hash[EVP_MAX_MD_SIZE], *input;
689     unsigned int md_len;
690     uint8_t i, offset = 0, out[128];
691     
692     pthread_mutex_lock(&lock);
693     if (first) {
694         EVP_MD_CTX_init(&mdctx);
695         first = 0;
696     }
697
698     input = auth;
699     for (;;) {
700         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
701             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
702             !EVP_DigestUpdate(&mdctx, input, 16) ||
703             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
704             md_len != 16) {
705             pthread_mutex_unlock(&lock);
706             return 0;
707         }
708         for (i = 0; i < 16; i++)
709             out[offset + i] = hash[i] ^ in[offset + i];
710         input = out + offset - 16;
711         offset += 16;
712         if (offset == len)
713             break;
714     }
715     memcpy(in, out, len);
716     pthread_mutex_unlock(&lock);
717     return 1;
718 }
719
720 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
721     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
722     static unsigned char first = 1;
723     static EVP_MD_CTX mdctx;
724     unsigned char hash[EVP_MAX_MD_SIZE], *input;
725     unsigned int md_len;
726     uint8_t i, offset = 0, out[128];
727     
728     pthread_mutex_lock(&lock);
729     if (first) {
730         EVP_MD_CTX_init(&mdctx);
731         first = 0;
732     }
733
734     input = auth;
735     for (;;) {
736         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
737             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
738             !EVP_DigestUpdate(&mdctx, input, 16) ||
739             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
740             md_len != 16) {
741             pthread_mutex_unlock(&lock);
742             return 0;
743         }
744         for (i = 0; i < 16; i++)
745             out[offset + i] = hash[i] ^ in[offset + i];
746         input = in + offset;
747         offset += 16;
748         if (offset == len)
749             break;
750     }
751     memcpy(in, out, len);
752     pthread_mutex_unlock(&lock);
753     return 1;
754 }
755
756 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
757     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
758     static unsigned char first = 1;
759     static EVP_MD_CTX mdctx;
760     unsigned char hash[EVP_MAX_MD_SIZE];
761     unsigned int md_len;
762     uint8_t i, offset;
763     
764     pthread_mutex_lock(&lock);
765     if (first) {
766         EVP_MD_CTX_init(&mdctx);
767         first = 0;
768     }
769
770 #if 0    
771     printf("msppencrypt auth in: ");
772     for (i = 0; i < 16; i++)
773         printf("%02x ", auth[i]);
774     printf("\n");
775     
776     printf("msppencrypt salt in: ");
777     for (i = 0; i < 2; i++)
778         printf("%02x ", salt[i]);
779     printf("\n");
780     
781     printf("msppencrypt in: ");
782     for (i = 0; i < len; i++)
783         printf("%02x ", text[i]);
784     printf("\n");
785 #endif
786     
787     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
788         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
789         !EVP_DigestUpdate(&mdctx, auth, 16) ||
790         !EVP_DigestUpdate(&mdctx, salt, 2) ||
791         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
792         pthread_mutex_unlock(&lock);
793         return 0;
794     }
795
796 #if 0    
797     printf("msppencrypt hash: ");
798     for (i = 0; i < 16; i++)
799         printf("%02x ", hash[i]);
800     printf("\n");
801 #endif
802     
803     for (i = 0; i < 16; i++)
804         text[i] ^= hash[i];
805     
806     for (offset = 16; offset < len; offset += 16) {
807 #if 0   
808         printf("text + offset - 16 c(%d): ", offset / 16);
809         for (i = 0; i < 16; i++)
810             printf("%02x ", (text + offset - 16)[i]);
811         printf("\n");
812 #endif
813         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
814             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
815             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
816             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
817             md_len != 16) {
818             pthread_mutex_unlock(&lock);
819             return 0;
820         }
821 #if 0   
822         printf("msppencrypt hash: ");
823         for (i = 0; i < 16; i++)
824             printf("%02x ", hash[i]);
825         printf("\n");
826 #endif    
827         
828         for (i = 0; i < 16; i++)
829             text[offset + i] ^= hash[i];
830     }
831     
832 #if 0
833     printf("msppencrypt out: ");
834     for (i = 0; i < len; i++)
835         printf("%02x ", text[i]);
836     printf("\n");
837 #endif
838
839     pthread_mutex_unlock(&lock);
840     return 1;
841 }
842
843 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
844     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
845     static unsigned char first = 1;
846     static EVP_MD_CTX mdctx;
847     unsigned char hash[EVP_MAX_MD_SIZE];
848     unsigned int md_len;
849     uint8_t i, offset;
850     char plain[255];
851     
852     pthread_mutex_lock(&lock);
853     if (first) {
854         EVP_MD_CTX_init(&mdctx);
855         first = 0;
856     }
857
858 #if 0    
859     printf("msppdecrypt auth in: ");
860     for (i = 0; i < 16; i++)
861         printf("%02x ", auth[i]);
862     printf("\n");
863     
864     printf("msppedecrypt salt in: ");
865     for (i = 0; i < 2; i++)
866         printf("%02x ", salt[i]);
867     printf("\n");
868     
869     printf("msppedecrypt in: ");
870     for (i = 0; i < len; i++)
871         printf("%02x ", text[i]);
872     printf("\n");
873 #endif
874     
875     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
876         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
877         !EVP_DigestUpdate(&mdctx, auth, 16) ||
878         !EVP_DigestUpdate(&mdctx, salt, 2) ||
879         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
880         pthread_mutex_unlock(&lock);
881         return 0;
882     }
883
884 #if 0    
885     printf("msppedecrypt hash: ");
886     for (i = 0; i < 16; i++)
887         printf("%02x ", hash[i]);
888     printf("\n");
889 #endif
890     
891     for (i = 0; i < 16; i++)
892         plain[i] = text[i] ^ hash[i];
893     
894     for (offset = 16; offset < len; offset += 16) {
895 #if 0   
896         printf("text + offset - 16 c(%d): ", offset / 16);
897         for (i = 0; i < 16; i++)
898             printf("%02x ", (text + offset - 16)[i]);
899         printf("\n");
900 #endif
901         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
902             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
903             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
904             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
905             md_len != 16) {
906             pthread_mutex_unlock(&lock);
907             return 0;
908         }
909 #if 0   
910     printf("msppedecrypt hash: ");
911     for (i = 0; i < 16; i++)
912         printf("%02x ", hash[i]);
913     printf("\n");
914 #endif    
915
916     for (i = 0; i < 16; i++)
917         plain[offset + i] = text[offset + i] ^ hash[i];
918     }
919
920     memcpy(text, plain, len);
921 #if 0
922     printf("msppedecrypt out: ");
923     for (i = 0; i < len; i++)
924         printf("%02x ", text[i]);
925     printf("\n");
926 #endif
927
928     pthread_mutex_unlock(&lock);
929     return 1;
930 }
931
932 struct server *id2server(char *id, uint8_t len) {
933     int i;
934     char **realm, *idrealm;
935
936     idrealm = strchr(id, '@');
937     if (idrealm) {
938         idrealm++;
939         len -= idrealm - id;
940     } else {
941         idrealm = "-";
942         len = 1;
943     }
944     for (i = 0; i < server_count; i++) {
945         for (realm = servers[i].realms; *realm; realm++) {
946             if ((strlen(*realm) == 1 && **realm == '*') ||
947                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
948                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
949                 return servers + i;
950             }
951         }
952     }
953     return NULL;
954 }
955
956 int rqinqueue(struct server *to, struct client *from, uint8_t id) {
957     int i;
958     
959     pthread_mutex_lock(&to->newrq_mutex);
960     for (i = 0; i < MAX_REQUESTS; i++)
961         if (to->requests[i].buf && to->requests[i].origid == id && to->requests[i].from == from)
962             break;
963     pthread_mutex_unlock(&to->newrq_mutex);
964     
965     return i < MAX_REQUESTS;
966 }
967         
968 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
969     uint8_t code, id, *auth, *attr, attrvallen;
970     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
971     int i;
972     uint16_t len;
973     int left;
974     struct server *to;
975     unsigned char newauth[16];
976     
977     code = *(uint8_t *)buf;
978     id = *(uint8_t *)(buf + 1);
979     len = RADLEN(buf);
980     auth = (uint8_t *)(buf + 4);
981
982     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
983     
984     if (code != RAD_Access_Request) {
985         printf("radsrv: server currently accepts only access-requests, ignoring\n");
986         return NULL;
987     }
988
989     left = len - 20;
990     attr = buf + 20;
991     
992     while (left > 1) {
993         left -= attr[RAD_Attr_Length];
994         if (left < 0) {
995             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
996             return NULL;
997         }
998         switch (attr[RAD_Attr_Type]) {
999         case RAD_Attr_User_Name:
1000             usernameattr = attr;
1001             break;
1002         case RAD_Attr_User_Password:
1003             userpwdattr = attr;
1004             break;
1005         case RAD_Attr_Tunnel_Password:
1006             tunnelpwdattr = attr;
1007             break;
1008         case RAD_Attr_Message_Authenticator:
1009             messageauthattr = attr;
1010             break;
1011         }
1012         attr += attr[RAD_Attr_Length];
1013     }
1014     if (left)
1015         printf("radsrv: malformed packet? remaining byte after last attribute\n");
1016
1017     if (usernameattr) {
1018         printf("radsrv: Username: ");
1019         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
1020             printf("%c", usernameattr[RAD_Attr_Value + i]);
1021         printf("\n");
1022     }
1023
1024     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
1025     if (!to) {
1026         printf("radsrv: ignoring request, don't know where to send it\n");
1027         return NULL;
1028     }
1029
1030     if (rqinqueue(to, from, id)) {
1031         printf("radsrv: ignoring request from host %s with id %d, already got one\n", from->peer.host, id);
1032         return NULL;
1033     }
1034     
1035     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
1036                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
1037         printf("radsrv: message authentication failed\n");
1038         return NULL;
1039     }
1040
1041     if (!RAND_bytes(newauth, 16)) {
1042         printf("radsrv: failed to generate random auth\n");
1043         return NULL;
1044     }
1045
1046     printauth("auth", auth);
1047     printauth("newauth", newauth);
1048     
1049     if (userpwdattr) {
1050         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
1051         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
1052         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1053             printf("radsrv: invalid user password length\n");
1054             return NULL;
1055         }
1056         
1057         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1058             printf("radsrv: cannot decrypt password\n");
1059             return NULL;
1060         }
1061         printf("radsrv: password: ");
1062         for (i = 0; i < attrvallen; i++)
1063             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
1064         printf("\n");
1065         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1066             printf("radsrv: cannot encrypt password\n");
1067             return NULL;
1068         }
1069     }
1070
1071     if (tunnelpwdattr) {
1072         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
1073         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
1074         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1075             printf("radsrv: invalid user password length\n");
1076             return NULL;
1077         }
1078         
1079         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1080             printf("radsrv: cannot decrypt password\n");
1081             return NULL;
1082         }
1083         printf("radsrv: password: ");
1084         for (i = 0; i < attrvallen; i++)
1085             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
1086         printf("\n");
1087         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1088             printf("radsrv: cannot encrypt password\n");
1089             return NULL;
1090         }
1091     }
1092
1093     rq->buf = buf;
1094     rq->from = from;
1095     rq->origid = id;
1096     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
1097     memcpy(rq->origauth, auth, 16);
1098     memcpy(auth, newauth, 16);
1099     printauth("rq->origauth", rq->origauth);
1100     printauth("auth", auth);
1101     return to;
1102 }
1103
1104 void *clientrd(void *arg) {
1105     struct server *server = (struct server *)arg;
1106     struct client *from;
1107     int i, left, subleft;
1108     unsigned char *buf, *messageauthattr, *subattr, *attr;
1109     struct sockaddr_storage fromsa;
1110     struct timeval lastconnecttry;
1111     char tmp[255];
1112     
1113     for (;;) {
1114     getnext:
1115         lastconnecttry = server->lastconnecttry;
1116         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
1117         if (!buf && server->peer.type == 'T') {
1118             tlsconnect(server, &lastconnecttry, "clientrd");
1119             continue;
1120         }
1121     
1122         server->connectionok = 1;
1123
1124         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
1125             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
1126             continue;
1127         }
1128         
1129         i = buf[1]; /* i is the id */
1130
1131         pthread_mutex_lock(&server->newrq_mutex);
1132         if (!server->requests[i].buf || !server->requests[i].tries) {
1133             pthread_mutex_unlock(&server->newrq_mutex);
1134             printf("clientrd: no matching request sent with this id, ignoring\n");
1135             continue;
1136         }
1137
1138         if (server->requests[i].received) {
1139             pthread_mutex_unlock(&server->newrq_mutex);
1140             printf("clientrd: already received, ignoring\n");
1141             continue;
1142         }
1143         
1144         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1145             pthread_mutex_unlock(&server->newrq_mutex);
1146             printf("clientrd: invalid auth, ignoring\n");
1147             continue;
1148         }
1149         
1150         from = server->requests[i].from;
1151
1152
1153         /* messageauthattr present? */
1154         messageauthattr = NULL;
1155         left = RADLEN(buf) - 20;
1156         attr = buf + 20;
1157         while (left > 1) {
1158             left -= attr[RAD_Attr_Length];
1159             if (left < 0) {
1160                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1161                 goto getnext;
1162             }
1163             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1164                 if (attr[RAD_Attr_Length] != 18) {
1165                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1166                     goto getnext;
1167                 }
1168                 memcpy(tmp, buf + 4, 16);
1169                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1170                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1171                     printf("clientrd: message authentication failed\n");
1172                     goto getnext;
1173                 }
1174                 memcpy(buf + 4, tmp, 16);
1175                 printf("clientrd: message auth ok\n");
1176                 messageauthattr = attr;
1177                 break;
1178             }
1179             attr += attr[RAD_Attr_Length];
1180         }
1181
1182         /* handle MS MPPE */
1183         left = RADLEN(buf) - 20;
1184         attr = buf + 20;
1185         while (left > 1) {
1186             left -= attr[RAD_Attr_Length];
1187             if (left < 0) {
1188                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1189                 goto getnext;
1190             }
1191             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1192                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1193                 subleft = attr[RAD_Attr_Length] - 6;
1194                 subattr = attr + 6;
1195                 while (subleft > 1) {
1196                     subleft -= subattr[RAD_Attr_Length];
1197                     if (subleft < 0)
1198                         break;
1199                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1200                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1201                         continue;
1202                     printf("clientrd: Got MS MPPE\n");
1203                     if (subattr[RAD_Attr_Length] < 20)
1204                         continue;
1205
1206                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1207                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1208                         printf("clientrd: failed to decrypt msppe key\n");
1209                         continue;
1210                     }
1211
1212                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1213                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1214                         printf("clientrd: failed to encrypt msppe key\n");
1215                         continue;
1216                     }
1217                 }
1218                 if (subleft < 0) {
1219                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1220                     goto getnext;
1221                 }
1222             }
1223             attr += attr[RAD_Attr_Length];
1224         }
1225
1226         /* once we set received = 1, requests[i] may be reused */
1227         buf[1] = (char)server->requests[i].origid;
1228         memcpy(buf + 4, server->requests[i].origauth, 16);
1229         printauth("origauth/buf+4", buf + 4);
1230         if (messageauthattr) {
1231             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1232                 continue;
1233             printf("clientrd: computed messageauthattr\n");
1234         }
1235
1236         if (from->peer.type == 'U')
1237             fromsa = server->requests[i].fromsa;
1238         server->requests[i].received = 1;
1239         pthread_mutex_unlock(&server->newrq_mutex);
1240
1241         if (!radsign(buf, from->peer.secret)) {
1242             printf("clientrd: failed to sign message\n");
1243             continue;
1244         }
1245         printauth("signedorigauth/buf+4", buf + 4);             
1246         printf("clientrd: giving packet back to where it came from\n");
1247         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1248     }
1249 }
1250
1251 void *clientwr(void *arg) {
1252     struct server *server = (struct server *)arg;
1253     struct request *rq;
1254     pthread_t clientrdth;
1255     int i;
1256     struct timeval now;
1257     struct timespec timeout;
1258
1259     memset(&timeout, 0, sizeof(struct timespec));
1260     
1261     if (server->peer.type == 'U') {
1262         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1263             printf("clientwr: connecttoserver failed\n");
1264             exit(1);
1265         }
1266     } else
1267         tlsconnect(server, NULL, "new client");
1268     
1269     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1270         errx("clientwr: pthread_create failed");
1271
1272     for (;;) {
1273         pthread_mutex_lock(&server->newrq_mutex);
1274         if (!server->newrq) {
1275             if (timeout.tv_nsec) {
1276                 printf("clientwr: waiting up to %ld secs for new request\n", timeout.tv_nsec);
1277                 pthread_cond_timedwait(&server->newrq_cond, &server->newrq_mutex, &timeout);
1278                 timeout.tv_nsec = 0;
1279             } else {
1280                 printf("clientwr: waiting for new request\n");
1281                 pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1282             }
1283             if (server->newrq) {
1284                 printf("clientwr: got new request\n");
1285                 server->newrq = 0;
1286             }
1287         }
1288         pthread_mutex_unlock(&server->newrq_mutex);
1289         
1290         for (i = 0; i < MAX_REQUESTS; i++) {
1291             pthread_mutex_lock(&server->newrq_mutex);
1292             while (!server->requests[i].buf && i < MAX_REQUESTS)
1293                 i++;
1294             if (i == MAX_REQUESTS) {
1295                 pthread_mutex_unlock(&server->newrq_mutex);
1296                 break;
1297             }
1298             rq = server->requests + i;
1299
1300             if (rq->received) {
1301                 printf("clientwr: removing received packet from queue\n");
1302                 free(rq->buf);
1303                 /* setting this to NULL means that it can be reused */
1304                 rq->buf = NULL;
1305                 pthread_mutex_unlock(&server->newrq_mutex);
1306                 continue;
1307             }
1308             
1309             gettimeofday(&now, NULL);
1310             if (now.tv_sec <= rq->expiry.tv_sec) {
1311                 if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1312                     timeout.tv_sec = rq->expiry.tv_sec;
1313                 pthread_mutex_unlock(&server->newrq_mutex);
1314                 continue;
1315             }
1316
1317             if (rq->tries == (server->peer.type == 'T' ? 1 : REQUEST_RETRIES)) {
1318                 printf("clientwr: removing expired packet from queue\n");
1319                 free(rq->buf);
1320                 /* setting this to NULL means that it can be reused */
1321                 rq->buf = NULL;
1322                 pthread_mutex_unlock(&server->newrq_mutex);
1323                 continue;
1324             }
1325             pthread_mutex_unlock(&server->newrq_mutex);
1326
1327             rq->expiry.tv_sec = now.tv_sec +
1328                 (server->peer.type == 'T' ? REQUEST_EXPIRY : REQUEST_EXPIRY / REQUEST_RETRIES);
1329             if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1330                 timeout.tv_sec = rq->expiry.tv_sec;
1331             rq->tries++;
1332             clientradput(server, server->requests[i].buf);
1333             usleep(200000);
1334         }
1335     }
1336     /* should do more work to maintain TLS connections, keepalives etc */
1337 }
1338
1339 void *udpserverwr(void *arg) {
1340     struct replyq *replyq = &udp_server_replyq;
1341     struct reply *reply = replyq->replies;
1342     
1343     pthread_mutex_lock(&replyq->count_mutex);
1344     for (;;) {
1345         while (!replyq->count) {
1346             printf("udp server writer, waiting for signal\n");
1347             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1348             printf("udp server writer, got signal\n");
1349         }
1350         pthread_mutex_unlock(&replyq->count_mutex);
1351         
1352         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1353                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1354             err("sendudp: send failed");
1355         free(reply->buf);
1356         
1357         pthread_mutex_lock(&replyq->count_mutex);
1358         replyq->count--;
1359         memmove(replyq->replies, replyq->replies + 1,
1360                 replyq->count * sizeof(struct reply));
1361     }
1362 }
1363
1364 void *udpserverrd(void *arg) {
1365     struct request rq;
1366     unsigned char *buf;
1367     struct server *to;
1368     struct client *fr;
1369     pthread_t udpserverwrth;
1370     
1371     if ((udp_server_sock = bindport(SOCK_DGRAM, options.udpserverport)) < 0) {
1372         printf("udpserverrd: socket/bind failed\n");
1373         exit(1);
1374     }
1375     printf("udpserverrd: listening on UDP port %s\n", options.udpserverport);
1376
1377     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1378         errx("pthread_create failed");
1379     
1380     for (;;) {
1381         fr = NULL;
1382         memset(&rq, 0, sizeof(struct request));
1383         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1384         to = radsrv(&rq, buf, fr);
1385         if (!to) {
1386             printf("udpserverrd: ignoring request, no place to send it\n");
1387             continue;
1388         }
1389         sendrq(to, fr, &rq);
1390     }
1391 }
1392
1393 void *tlsserverwr(void *arg) {
1394     int cnt;
1395     unsigned long error;
1396     struct client *client = (struct client *)arg;
1397     struct replyq *replyq;
1398     
1399     printf("tlsserverwr starting for %s\n", client->peer.host);
1400     replyq = client->replyq;
1401     pthread_mutex_lock(&replyq->count_mutex);
1402     for (;;) {
1403         while (!replyq->count) {
1404             if (client->peer.ssl) {         
1405                 printf("tls server writer, waiting for signal\n");
1406                 pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1407                 printf("tls server writer, got signal\n");
1408             }
1409             if (!client->peer.ssl) {
1410                 //ssl might have changed while waiting
1411                 pthread_mutex_unlock(&replyq->count_mutex);
1412                 printf("tlsserverwr: exiting as requested\n");
1413                 pthread_exit(NULL);
1414             }
1415         }
1416         pthread_mutex_unlock(&replyq->count_mutex);
1417         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1418         if (cnt > 0)
1419             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1420                    cnt, RADLEN(replyq->replies->buf));
1421         else
1422             while ((error = ERR_get_error()))
1423                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1424         free(replyq->replies->buf);
1425
1426         pthread_mutex_lock(&replyq->count_mutex);
1427         replyq->count--;
1428         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1429     }
1430 }
1431
1432 void *tlsserverrd(void *arg) {
1433     struct request rq;
1434     char unsigned *buf;
1435     unsigned long error;
1436     struct server *to;
1437     int s;
1438     struct client *client = (struct client *)arg;
1439     pthread_t tlsserverwrth;
1440     SSL *ssl;
1441     
1442     printf("tlsserverrd starting for %s\n", client->peer.host);
1443     ssl = client->peer.ssl;
1444
1445     if (SSL_accept(ssl) <= 0) {
1446         while ((error = ERR_get_error()))
1447             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1448         errx("accept failed, child exiting");
1449     }
1450
1451     if (tlsverifycert(&client->peer)) {
1452         if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1453             errx("pthread_create failed");
1454     
1455         for (;;) {
1456             buf = radtlsget(client->peer.ssl);
1457             if (!buf)
1458                 break;
1459             printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1460             memset(&rq, 0, sizeof(struct request));
1461             to = radsrv(&rq, buf, client);
1462             if (!to) {
1463                 printf("ignoring request, no place to send it\n");
1464                 continue;
1465             }
1466             sendrq(to, client, &rq);
1467         }
1468         printf("tlsserverrd: connection lost\n");
1469         // stop writer by setting peer.ssl to NULL and give signal in case waiting for data
1470         client->peer.ssl = NULL;
1471         pthread_mutex_lock(&client->replyq->count_mutex);
1472         pthread_cond_signal(&client->replyq->count_cond);
1473         pthread_mutex_unlock(&client->replyq->count_mutex);
1474         printf("tlsserverrd: waiting for writer to end\n");
1475         pthread_join(tlsserverwrth, NULL);
1476     }
1477     s = SSL_get_fd(ssl);
1478     SSL_free(ssl);
1479     close(s);
1480     printf("tlsserverrd thread for %s exiting\n", client->peer.host);
1481     client->peer.ssl = NULL;
1482     pthread_exit(NULL);
1483 }
1484
1485 int tlslistener() {
1486     pthread_t tlsserverth;
1487     int s, snew;
1488     struct sockaddr_storage from;
1489     size_t fromlen = sizeof(from);
1490     struct client *client;
1491
1492     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1493         printf("tlslistener: socket/bind failed\n");
1494         exit(1);
1495     }
1496     
1497     listen(s, 0);
1498     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1499
1500     for (;;) {
1501         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1502         if (snew < 0)
1503             errx("accept failed");
1504         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1505
1506         client = find_client('T', (struct sockaddr *)&from, NULL);
1507         if (!client) {
1508             printf("ignoring request, not a known TLS client\n");
1509             shutdown(snew, SHUT_RDWR);
1510             close(snew);
1511             continue;
1512         }
1513
1514         if (client->peer.ssl) {
1515             printf("Ignoring incoming connection, already have one from this client\n");
1516             shutdown(snew, SHUT_RDWR);
1517             close(snew);
1518             continue;
1519         }
1520         client->peer.ssl = SSL_new(ssl_ctx);
1521         SSL_set_fd(client->peer.ssl, snew);
1522         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1523             errx("pthread_create failed");
1524         pthread_detach(tlsserverth);
1525     }
1526     return 0;
1527 }
1528
1529 char *parsehostport(char *s, struct peer *peer) {
1530     char *p, *field;
1531     int ipv6 = 0;
1532
1533     p = s;
1534     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1535     if (*p == '[') {
1536         p++;
1537         field = p;
1538         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1539         if (*p != ']') {
1540             printf("no ] matching initial [\n");
1541             exit(1);
1542         }
1543         ipv6 = 1;
1544     } else {
1545         field = p;
1546         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1547     }
1548     if (field == p) {
1549         printf("missing host/address\n");
1550         exit(1);
1551     }
1552     peer->host = stringcopy(field, p - field);
1553     if (ipv6) {
1554         p++;
1555         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1556             printf("unexpected character after ]\n");
1557             exit(1);
1558         }
1559     }
1560     if (*p == ':') {
1561             /* port number or service name is specified */;
1562             field = p++;
1563             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1564             if (field == p) {
1565                 printf("syntax error, : but no following port\n");
1566                 exit(1);
1567             }
1568             peer->port = stringcopy(field, p - field);
1569     } else
1570         peer->port = stringcopy(peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT, 0);
1571     return p;
1572 }
1573
1574 // * is default, else longest match ... ";" used for separator
1575 char *parserealmlist(char *s, struct server *server) {
1576     char *p;
1577     int i, n, l;
1578
1579     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1580         if (*p == ';')
1581             n++;
1582     l = p - s;
1583     if (!l) {
1584         printf("realm list must be specified\n");
1585         exit(1);
1586     }
1587     server->realmdata = stringcopy(s, l);
1588     server->realms = malloc((1+n) * sizeof(char *));
1589     if (!server->realms)
1590         errx("malloc failed");
1591     server->realms[0] = server->realmdata;
1592     for (n = 1, i = 0; i < l; i++)
1593         if (server->realmdata[i] == ';') {
1594             server->realmdata[i] = '\0';
1595             server->realms[n++] = server->realmdata + i + 1;
1596         }       
1597     server->realms[n] = NULL;
1598     return p;
1599 }
1600
1601 /* exactly one argument must be non-NULL */
1602 void getconfig(const char *serverfile, const char *clientfile) {
1603     FILE *f;
1604     char line[1024];
1605     char *p, *field, **r;
1606     const char *file;
1607     struct client *client;
1608     struct server *server;
1609     struct peer *peer;
1610     int i, count, *ucount, *tcount;
1611  
1612     file = serverfile ? serverfile : clientfile;
1613     f = fopen(file, "r");
1614     if (!f)
1615         errx("getconfig failed to open %s for reading", file);
1616     printf("opening file %s for reading\n", file);
1617     if (serverfile) {
1618         ucount = &server_udp_count;
1619         tcount = &server_tls_count;
1620     } else {
1621         ucount = &client_udp_count;
1622         tcount = &client_tls_count;
1623     }
1624     while (fgets(line, 1024, f)) {
1625         for (p = line; *p == ' ' || *p == '\t'; p++);
1626         switch (*p) {
1627         case '#':
1628         case '\n':
1629             break;
1630         case 'T':
1631             (*tcount)++;
1632             break;
1633         case 'U':
1634             (*ucount)++;
1635             break;
1636         default:
1637             printf("type must be U or T, got %c\n", *p);
1638             exit(1);
1639         }
1640     }
1641
1642     if (serverfile) {
1643         count = server_count = server_udp_count + server_tls_count;
1644         servers = calloc(count, sizeof(struct server));
1645         if (!servers)
1646             errx("malloc failed");
1647     } else {
1648         count = client_count = client_udp_count + client_tls_count;
1649         clients = calloc(count, sizeof(struct client));
1650         if (!clients)
1651             errx("malloc failed");
1652     }
1653     
1654     if (client_udp_count) {
1655         udp_server_replyq.replies = malloc(client_udp_count * MAX_REQUESTS * sizeof(struct reply));
1656         if (!udp_server_replyq.replies)
1657             errx("malloc failed");
1658         udp_server_replyq.size = client_udp_count * MAX_REQUESTS;
1659         udp_server_replyq.count = 0;
1660         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1661         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1662     }    
1663     
1664     rewind(f);
1665     for (i = 0; i < count && fgets(line, 1024, f);) {
1666         if (serverfile) {
1667             server = &servers[i];
1668             peer = &server->peer;
1669         } else {
1670             client = &clients[i];
1671             peer = &client->peer;
1672         }
1673         for (p = line; *p == ' ' || *p == '\t'; p++);
1674         if (*p == '#' || *p == '\n')
1675             continue;
1676         peer->type = *p;        // we already know it must be U or T
1677         for (p++; *p == ' ' || *p == '\t'; p++);
1678         p = parsehostport(p, peer);
1679         for (; *p == ' ' || *p == '\t'; p++);
1680         if (serverfile) {
1681             p = parserealmlist(p, server);
1682             for (; *p == ' ' || *p == '\t'; p++);
1683         }
1684         field = p;
1685         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1686         if (field == p) {
1687             /* no secret set and end of line, line is complete if TLS */
1688             if (peer->type == 'U') {
1689                 printf("secret must be specified for UDP\n");
1690                 exit(1);
1691             }
1692             peer->secret = stringcopy(DEFAULT_TLS_SECRET, 0);
1693         } else {
1694             peer->secret = stringcopy(field, p - field);
1695             /* check that rest of line only white space */
1696             for (; *p == ' ' || *p == '\t'; p++);
1697             if (*p && *p != '\n') {
1698                 printf("max 4 fields per line, found a 5th\n");
1699                 exit(1);
1700             }
1701         }
1702
1703         if ((serverfile && !resolvepeer(&server->peer)) ||
1704             (clientfile && !resolvepeer(&client->peer))) {
1705             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1706             exit(1);
1707         }
1708
1709         if (serverfile) {
1710             pthread_mutex_init(&server->lock, NULL);
1711             server->sock = -1;
1712             server->requests = calloc(MAX_REQUESTS, sizeof(struct request));
1713             if (!server->requests)
1714                 errx("malloc failed");
1715             server->newrq = 0;
1716             pthread_mutex_init(&server->newrq_mutex, NULL);
1717             pthread_cond_init(&server->newrq_cond, NULL);
1718         } else {
1719             if (peer->type == 'U')
1720                 client->replyq = &udp_server_replyq;
1721             else {
1722                 client->replyq = malloc(sizeof(struct replyq));
1723                 if (!client->replyq)
1724                     errx("malloc failed");
1725                 client->replyq->replies = calloc(MAX_REQUESTS, sizeof(struct reply));
1726                 if (!client->replyq->replies)
1727                     errx("malloc failed");
1728                 client->replyq->size = MAX_REQUESTS;
1729                 client->replyq->count = 0;
1730                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1731                 pthread_cond_init(&client->replyq->count_cond, NULL);
1732             }
1733         }
1734         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1735         if (serverfile) {
1736             printf("    with realms:");
1737             for (r = server->realms; *r; r++)
1738                 printf(" %s", *r);
1739             printf("\n");
1740         }
1741         i++;
1742     }
1743     fclose(f);
1744 }
1745
1746 void getmainconfig(const char *configfile) {
1747     FILE *f;
1748     char line[1024];
1749     char *p, *opt, *endopt, *val, *endval;
1750     
1751     printf("opening file %s for reading\n", configfile);
1752     f = fopen(configfile, "r");
1753     if (!f)
1754         errx("getmainconfig failed to open %s for reading", configfile);
1755
1756     memset(&options, 0, sizeof(options));
1757
1758     while (fgets(line, 1024, f)) {
1759         for (p = line; *p == ' ' || *p == '\t'; p++);
1760         if (!*p || *p == '#' || *p == '\n')
1761             continue;
1762         opt = p++;
1763         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1764         endopt = p - 1;
1765         for (; *p == ' ' || *p == '\t'; p++);
1766         if (!*p || *p == '\n') {
1767             endopt[1] = '\0';
1768             printf("error in %s, option %s has no value\n", configfile, opt);
1769             exit(1);
1770         }
1771         val = p;
1772         for (; *p && *p != '\n'; p++)
1773             if (*p != ' ' && *p != '\t')
1774                 endval = p;
1775         endopt[1] = '\0';
1776         endval[1] = '\0';
1777         printf("getmainconfig: %s = %s\n", opt, val);
1778         
1779         if (!strcasecmp(opt, "TLSCACertificateFile")) {
1780             options.tlscacertificatefile = stringcopy(val, 0);
1781             continue;
1782         }
1783         if (!strcasecmp(opt, "TLSCACertificatePath")) {
1784             options.tlscacertificatepath = stringcopy(val, 0);
1785             continue;
1786         }
1787         if (!strcasecmp(opt, "TLSCertificateFile")) {
1788             options.tlscertificatefile = stringcopy(val, 0);
1789             continue;
1790         }
1791         if (!strcasecmp(opt, "TLSCertificateKeyFile")) {
1792             options.tlscertificatekeyfile = stringcopy(val, 0);
1793             continue;
1794         }
1795         if (!strcasecmp(opt, "UDPServerPort")) {
1796             options.udpserverport = stringcopy(val, 0);
1797             continue;
1798         }
1799
1800         printf("error in %s, unknown option %s\n", configfile, opt);
1801         exit(1);
1802     }
1803     fclose(f);
1804
1805     if (!options.udpserverport)
1806         options.udpserverport = stringcopy(DEFAULT_UDP_PORT, 0);
1807 }
1808
1809 #if 0
1810 void parseargs(int argc, char **argv) {
1811     int c;
1812
1813     while ((c = getopt(argc, argv, "p:")) != -1) {
1814         switch (c) {
1815         case 'p':
1816             udp_server_port = optarg;
1817             break;
1818         default:
1819             goto usage;
1820         }
1821     }
1822
1823     return;
1824
1825  usage:
1826     printf("radsecproxy [ -p UDP-port ]\n");
1827     exit(1);
1828 }
1829 #endif
1830
1831 int main(int argc, char **argv) {
1832     pthread_t udpserverth;
1833     //    pthread_attr_t joinable;
1834     int i;
1835     
1836     //    parseargs(argc, argv);
1837     getmainconfig("radsecproxy.conf");
1838     getconfig("servers.conf", NULL);
1839     getconfig(NULL, "clients.conf");
1840
1841     //    pthread_attr_init(&joinable);
1842     //    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1843    
1844     if (client_udp_count)
1845         if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
1846             errx("pthread_create failed");
1847
1848     if (client_tls_count || server_tls_count)
1849         ssl_ctx = ssl_init();
1850     
1851     for (i = 0; i < server_count; i++)
1852         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1853             errx("pthread_create failed");
1854
1855     if (client_tls_count)
1856         return tlslistener();
1857
1858     /* just hang around doing nothing, anything to do here? */
1859     for (;;)
1860         sleep(1000);
1861 }