removed tls test code
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* TODO:
10  * accounting
11  * radius keep alives (server status)
12  * tls certificate validation, see below urls
13  * clean tls shutdown, see http://www.linuxjournal.com/article/4822
14  *     and http://www.linuxjournal.com/article/5487
15  *     SSL_shutdown() and shutdown()
16  *     If shutdown() we may not need REUSEADDR
17  * when tls client goes away, ensure that all related threads and state
18  *          are removed
19  * setsockopt(keepalive...), check if openssl has some keepalive feature
20 */
21
22 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
23  *              rd is responsible for init and launching wr
24  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
25  *          each tlsserverrd launches tlsserverwr
26  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
27  *          for init and launching rd
28  *
29  * serverrd will receive a request, processes it and puts it in the requestq of
30  *          the appropriate clientwr
31  * clientwr monitors its requestq and sends requests
32  * clientrd looks for responses, processes them and puts them in the replyq of
33  *          the peer the request came from
34  * serverwr monitors its reply and sends replies
35  *
36  * In addition to the main thread, we have:
37  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
38  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
39  * For each TLS peer connecting to us there will be 2 more TLS threads
40  *       This is only for connected peers
41  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
42  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
43 */
44
45 #include <netdb.h>
46 #include <unistd.h>
47 #include <sys/time.h>
48 #include <pthread.h>
49 #include <openssl/ssl.h>
50 #include <openssl/rand.h>
51 #include <openssl/err.h>
52 #include <openssl/md5.h>
53 #include <openssl/hmac.h>
54 #include "radsecproxy.h"
55
56 static struct options options;
57 static struct client *clients;
58 static struct server *servers;
59
60 static int client_udp_count = 0;
61 static int client_tls_count = 0;
62 static int client_count = 0;
63 static int server_udp_count = 0;
64 static int server_tls_count = 0;
65 static int server_count = 0;
66
67 static struct replyq udp_server_replyq;
68 static int udp_server_sock = -1;
69 static pthread_mutex_t *ssl_locks;
70 static long *ssl_lock_count;
71 static SSL_CTX *ssl_ctx_cl = NULL;
72 static SSL_CTX *ssl_ctx_srv = NULL;
73 extern int optind;
74 extern char *optarg;
75
76 /* callbacks for making OpenSSL thread safe */
77 unsigned long ssl_thread_id() {
78         return (unsigned long)pthread_self();
79 };
80
81 void ssl_locking_callback(int mode, int type, const char *file, int line) {
82     if (mode & CRYPTO_LOCK) {
83         pthread_mutex_lock(&ssl_locks[type]);
84         ssl_lock_count[type]++;
85     } else
86         pthread_mutex_unlock(&ssl_locks[type]);
87 }
88
89 static int verify_cb(int ok, X509_STORE_CTX *ctx) {
90   char buf[256];
91   X509 *err_cert;
92   int err, depth;
93
94   err_cert = X509_STORE_CTX_get_current_cert(ctx);
95   err = X509_STORE_CTX_get_error(ctx);
96   depth = X509_STORE_CTX_get_error_depth(ctx);
97
98   if (depth > MAX_CERT_DEPTH) {
99       ok = 0;
100       err = X509_V_ERR_CERT_CHAIN_TOO_LONG;
101       X509_STORE_CTX_set_error(ctx, err);
102   }
103
104   if (!ok) {
105       X509_NAME_oneline(X509_get_subject_name(err_cert), buf, 256);
106       printf("verify error: num=%d:%s:depth=%d:%s\n", err, X509_verify_cert_error_string(err), depth, buf);
107
108       switch (err) {
109       case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
110           X509_NAME_oneline(X509_get_issuer_name(ctx->current_cert), buf, 256);
111           printf("issuer=%s\n", buf);
112           break;
113       case X509_V_ERR_CERT_NOT_YET_VALID:
114       case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
115           printf("Certificate not yet valid\n");
116           break;
117       case X509_V_ERR_CERT_HAS_EXPIRED:
118           printf("Certificate has expired\n");
119           break;
120       case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
121           printf("Certificate no longer valid (after notAfter)\n");
122           break;
123       }
124   }
125   //  printf("certificate verify returns %d\n", ok);
126   return ok;
127 }
128
129 void ssl_init(SSL_CTX **ctx_srv, SSL_CTX **ctx_cl) {
130     int i;
131     unsigned long error;
132     
133     if (!options.tlscertificatefile || !options.tlscertificatekeyfile) {
134         printf("TLSCertificateFile and TLSCertificateKeyFile must be specified for TLS\n");
135         exit(1);
136     }
137     if (!options.tlscacertificatefile && !options.tlscacertificatepath) {
138         printf("CA Certificate file/path need to be configured\n");
139         exit(1);
140     }
141
142     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
143     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
144     for (i = 0; i < CRYPTO_num_locks(); i++) {
145         ssl_lock_count[i] = 0;
146         pthread_mutex_init(&ssl_locks[i], NULL);
147     }
148     CRYPTO_set_id_callback(ssl_thread_id);
149     CRYPTO_set_locking_callback(ssl_locking_callback);
150
151     SSL_load_error_strings();
152     SSL_library_init();
153
154     while (!RAND_status()) {
155         time_t t = time(NULL);
156         pid_t pid = getpid();
157         RAND_seed((unsigned char *)&t, sizeof(time_t));
158         RAND_seed((unsigned char *)&pid, sizeof(pid));
159     }
160
161     if (ctx_srv) {
162         *ctx_srv = SSL_CTX_new(TLSv1_server_method());
163         if (!SSL_CTX_use_certificate_chain_file(*ctx_srv, options.tlscertificatefile) ||
164             !SSL_CTX_use_PrivateKey_file(*ctx_srv, options.tlscertificatekeyfile, SSL_FILETYPE_PEM) ||
165             !SSL_CTX_check_private_key(*ctx_srv))
166             goto errexit;
167         if (!SSL_CTX_load_verify_locations(*ctx_srv, options.tlscacertificatefile, options.tlscacertificatepath))
168             goto errexit;
169         SSL_CTX_set_verify(*ctx_srv, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
170         SSL_CTX_set_verify_depth(*ctx_srv, MAX_CERT_DEPTH + 1);
171     }
172     if (ctx_cl) {
173         *ctx_cl = SSL_CTX_new(TLSv1_client_method());
174         if (!SSL_CTX_use_certificate_chain_file(*ctx_cl, options.tlscertificatefile) ||
175             !SSL_CTX_use_PrivateKey_file(*ctx_cl, options.tlscertificatekeyfile, SSL_FILETYPE_PEM) ||
176             !SSL_CTX_check_private_key(*ctx_cl))
177             goto errexit;
178         if (!SSL_CTX_load_verify_locations(*ctx_cl, options.tlscacertificatefile, options.tlscacertificatepath))
179             goto errexit;
180         SSL_CTX_set_verify(*ctx_cl, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
181         SSL_CTX_set_verify_depth(*ctx_cl, MAX_CERT_DEPTH + 1);
182     }
183     return;
184     
185  errexit:
186     while ((error = ERR_get_error()))
187         err("SSL: %s", ERR_error_string(error, NULL));
188     exit(1);
189 }    
190
191 void printauth(char *s, unsigned char *t) {
192     int i;
193     printf("%s:", s);
194     for (i = 0; i < 16; i++)
195             printf("%02x ", t[i]);
196     printf("\n");
197 }
198
199 int resolvepeer(struct peer *peer) {
200     struct addrinfo hints, *addrinfo;
201     
202     memset(&hints, 0, sizeof(hints));
203     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
204     hints.ai_family = AF_UNSPEC;
205     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
206         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
207         return 0;
208     }
209
210     if (peer->addrinfo)
211         freeaddrinfo(peer->addrinfo);
212     peer->addrinfo = addrinfo;
213     return 1;
214 }         
215
216 int connecttoserver(struct addrinfo *addrinfo) {
217     int s;
218     struct addrinfo *res;
219     
220     for (res = addrinfo; res; res = res->ai_next) {
221         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
222         if (s < 0) {
223             err("connecttoserver: socket failed");
224             continue;
225         }
226         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
227             break;
228         err("connecttoserver: connect failed");
229         close(s);
230         s = -1;
231     }
232     return s;
233 }         
234
235 /* returns the client with matching address, or NULL */
236 /* if client argument is not NULL, we only check that one client */
237 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
238     struct sockaddr_in6 *sa6;
239     struct in_addr *a4 = NULL;
240     struct client *c;
241     int i;
242     struct addrinfo *res;
243
244     if (addr->sa_family == AF_INET6) {
245         sa6 = (struct sockaddr_in6 *)addr;
246         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
247             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
248     } else
249         a4 = &((struct sockaddr_in *)addr)->sin_addr;
250
251     c = (client ? client : clients);
252     for (i = 0; i < client_count; i++) {
253         if (c->peer.type == type)
254             for (res = c->peer.addrinfo; res; res = res->ai_next)
255                 if ((a4 && res->ai_family == AF_INET &&
256                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
257                     (res->ai_family == AF_INET6 &&
258                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
259                     return c;
260         if (client)
261             break;
262         c++;
263     }
264     return NULL;
265 }
266
267 /* returns the server with matching address, or NULL */
268 /* if server argument is not NULL, we only check that one server */
269 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
270     struct sockaddr_in6 *sa6;
271     struct in_addr *a4 = NULL;
272     struct server *s;
273     int i;
274     struct addrinfo *res;
275
276     if (addr->sa_family == AF_INET6) {
277         sa6 = (struct sockaddr_in6 *)addr;
278         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
279             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
280     } else
281         a4 = &((struct sockaddr_in *)addr)->sin_addr;
282
283     s = (server ? server : servers);
284     for (i = 0; i < server_count; i++) {
285         if (s->peer.type == type)
286             for (res = s->peer.addrinfo; res; res = res->ai_next)
287                 if ((a4 && res->ai_family == AF_INET &&
288                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
289                     (res->ai_family == AF_INET6 &&
290                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
291                     return s;
292         if (server)
293             break;
294         s++;
295     }
296     return NULL;
297 }
298
299 /* exactly one of client and server must be non-NULL */
300 /* if *peer == NULL we return who we received from, else require it to be from peer */
301 /* return from in sa if not NULL */
302 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
303     int cnt, len;
304     void *f;
305     unsigned char buf[65536], *rad;
306     struct sockaddr_storage from;
307     socklen_t fromlen = sizeof(from);
308
309     for (;;) {
310         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
311         if (cnt == -1) {
312             err("radudpget: recv failed");
313             continue;
314         }
315         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
316
317         if (cnt < 20) {
318             printf("radudpget: packet too small\n");
319             continue;
320         }
321     
322         len = RADLEN(buf);
323
324         if (cnt < len) {
325             printf("radudpget: packet smaller than length field in radius header\n");
326             continue;
327         }
328         if (cnt > len)
329             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
330
331         f = (client
332              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
333              : (void *)find_server('U', (struct sockaddr *)&from, *server));
334         if (!f) {
335             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
336             continue;
337         }
338
339         rad = malloc(len);
340         if (rad)
341             break;
342         err("radudpget: malloc failed");
343     }
344     memcpy(rad, buf, len);
345     if (client)
346         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
347     else
348         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
349     if (sa)
350         *sa = from;
351     return rad;
352 }
353
354 int tlsverifycert(struct peer *peer) {
355     int i, l, loc;
356     X509 *cert;
357     X509_NAME *nm;
358     X509_NAME_ENTRY *e;
359     unsigned char *v;
360     unsigned long error;
361
362 #if 1
363     if (SSL_get_verify_result(peer->ssl) != X509_V_OK) {
364         printf("tlsverifycert: basic validation failed\n");
365         while ((error = ERR_get_error()))
366             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
367         return 0;
368     }
369 #endif    
370     cert = SSL_get_peer_certificate(peer->ssl);
371     if (!cert) {
372         printf("tlsverifycert: failed to obtain certificate\n");
373         return 0;
374     }
375     nm = X509_get_subject_name(cert);
376     loc = -1;
377     for (;;) {
378         loc = X509_NAME_get_index_by_NID(nm, NID_commonName, loc);
379         if (loc == -1)
380             break;
381         e = X509_NAME_get_entry(nm, loc);
382         l = ASN1_STRING_to_UTF8(&v, X509_NAME_ENTRY_get_data(e));
383         if (l < 0)
384             continue;
385         printf("cn: ");
386         for (i = 0; i < l; i++)
387             printf("%c", v[i]);
388         printf("\n");
389         if (l == strlen(peer->host) && !strncasecmp(peer->host, v, l)) {
390             printf("tlsverifycert: Found cn matching host %s, All OK\n", peer->host);
391             return 1;
392         }
393         printf("tlsverifycert: cn not matching host %s\n", peer->host);
394     }
395     X509_free(cert);
396     return 0;
397 }
398
399 void tlsconnect(struct server *server, struct timeval *when, char *text) {
400     struct timeval now;
401     time_t elapsed;
402
403     printf("tlsconnect called from %s\n", text);
404     pthread_mutex_lock(&server->lock);
405     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
406         /* already reconnected, nothing to do */
407         printf("tlsconnect(%s): seems already reconnected\n", text);
408         pthread_mutex_unlock(&server->lock);
409         return;
410     }
411
412     printf("tlsconnect %s\n", text);
413
414     for (;;) {
415         gettimeofday(&now, NULL);
416         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
417         if (server->connectionok) {
418             server->connectionok = 0;
419             sleep(10);
420         } else if (elapsed < 5)
421             sleep(10);
422         else if (elapsed < 600)
423             sleep(elapsed * 2);
424         else if (elapsed < 10000)
425                 sleep(900);
426         else
427             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
428         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
429         if (server->sock >= 0)
430             close(server->sock);
431         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
432             continue;
433         SSL_free(server->peer.ssl);
434         server->peer.ssl = SSL_new(ssl_ctx_cl);
435         SSL_set_fd(server->peer.ssl, server->sock);
436         if (SSL_connect(server->peer.ssl) > 0 && tlsverifycert(&server->peer))
437             break;
438     }
439     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
440     gettimeofday(&server->lastconnecttry, NULL);
441     pthread_mutex_unlock(&server->lock);
442 }
443
444 unsigned char *radtlsget(SSL *ssl) {
445     int cnt, total, len;
446     unsigned char buf[4], *rad;
447
448     for (;;) {
449         for (total = 0; total < 4; total += cnt) {
450             cnt = SSL_read(ssl, buf + total, 4 - total);
451             if (cnt <= 0) {
452                 printf("radtlsget: connection lost\n");
453                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
454                     //remote end sent close_notify, send one back
455                     SSL_shutdown(ssl);
456                 }
457                 return NULL;
458             }
459         }
460
461         len = RADLEN(buf);
462         rad = malloc(len);
463         if (!rad) {
464             err("radtlsget: malloc failed");
465             continue;
466         }
467         memcpy(rad, buf, 4);
468
469         for (; total < len; total += cnt) {
470             cnt = SSL_read(ssl, rad + total, len - total);
471             if (cnt <= 0) {
472                 printf("radtlsget: connection lost\n");
473                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
474                     //remote end sent close_notify, send one back
475                     SSL_shutdown(ssl);
476                 }
477                 free(rad);
478                 return NULL;
479             }
480         }
481     
482         if (total >= 20)
483             break;
484         
485         free(rad);
486         printf("radtlsget: packet smaller than minimum radius size\n");
487     }
488     
489     printf("radtlsget: got %d bytes\n", total);
490     return rad;
491 }
492
493 int clientradput(struct server *server, unsigned char *rad) {
494     int cnt;
495     size_t len;
496     unsigned long error;
497     struct timeval lastconnecttry;
498     
499     len = RADLEN(rad);
500     if (server->peer.type == 'U') {
501         if (send(server->sock, rad, len, 0) >= 0) {
502             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
503             return 1;
504         }
505         err("clientradput: send failed");
506         return 0;
507     }
508
509     lastconnecttry = server->lastconnecttry;
510     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
511         while ((error = ERR_get_error()))
512             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
513         tlsconnect(server, &lastconnecttry, "clientradput");
514         lastconnecttry = server->lastconnecttry;
515     }
516
517     server->connectionok = 1;
518     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
519            cnt, len, server->peer.host);
520     return 1;
521 }
522
523 int radsign(unsigned char *rad, unsigned char *sec) {
524     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
525     static unsigned char first = 1;
526     static EVP_MD_CTX mdctx;
527     unsigned int md_len;
528     int result;
529     
530     pthread_mutex_lock(&lock);
531     if (first) {
532         EVP_MD_CTX_init(&mdctx);
533         first = 0;
534     }
535
536     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
537         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
538         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
539         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
540         md_len == 16);
541     pthread_mutex_unlock(&lock);
542     return result;
543 }
544
545 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
546     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
547     static unsigned char first = 1;
548     static EVP_MD_CTX mdctx;
549     unsigned char hash[EVP_MAX_MD_SIZE];
550     unsigned int len;
551     int result;
552     
553     pthread_mutex_lock(&lock);
554     if (first) {
555         EVP_MD_CTX_init(&mdctx);
556         first = 0;
557     }
558
559     len = RADLEN(rad);
560     
561     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
562               EVP_DigestUpdate(&mdctx, rad, 4) &&
563               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
564               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
565               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
566               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
567               len == 16 &&
568               !memcmp(hash, rad + 4, 16));
569     pthread_mutex_unlock(&lock);
570     return result;
571 }
572               
573 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
574     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
575     static unsigned char first = 1;
576     static HMAC_CTX hmacctx;
577     unsigned int md_len;
578     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
579     
580     pthread_mutex_lock(&lock);
581     if (first) {
582         HMAC_CTX_init(&hmacctx);
583         first = 0;
584     }
585
586     memcpy(auth, authattr, 16);
587     memset(authattr, 0, 16);
588     md_len = 0;
589     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
590     HMAC_Update(&hmacctx, rad, RADLEN(rad));
591     HMAC_Final(&hmacctx, hash, &md_len);
592     memcpy(authattr, auth, 16);
593     if (md_len != 16) {
594         printf("message auth computation failed\n");
595         pthread_mutex_unlock(&lock);
596         return 0;
597     }
598
599     if (memcmp(auth, hash, 16)) {
600         printf("message authenticator, wrong value\n");
601         pthread_mutex_unlock(&lock);
602         return 0;
603     }   
604         
605     pthread_mutex_unlock(&lock);
606     return 1;
607 }
608
609 int createmessageauth(char *rad, char *authattrval, char *secret) {
610     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
611     static unsigned char first = 1;
612     static HMAC_CTX hmacctx;
613     unsigned int md_len;
614
615     if (!authattrval)
616         return 1;
617     
618     pthread_mutex_lock(&lock);
619     if (first) {
620         HMAC_CTX_init(&hmacctx);
621         first = 0;
622     }
623
624     memset(authattrval, 0, 16);
625     md_len = 0;
626     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
627     HMAC_Update(&hmacctx, rad, RADLEN(rad));
628     HMAC_Final(&hmacctx, authattrval, &md_len);
629     if (md_len != 16) {
630         printf("message auth computation failed\n");
631         pthread_mutex_unlock(&lock);
632         return 0;
633     }
634
635     pthread_mutex_unlock(&lock);
636     return 1;
637 }
638
639 void sendrq(struct server *to, struct client *from, struct request *rq) {
640     int i;
641     
642     pthread_mutex_lock(&to->newrq_mutex);
643     /* might simplify if only try nextid, might be ok */
644     for (i = to->nextid; i < MAX_REQUESTS; i++)
645         if (!to->requests[i].buf)
646             break;
647     if (i == MAX_REQUESTS) {
648         for (i = 0; i < to->nextid; i++)
649             if (!to->requests[i].buf)
650                 break;
651         if (i == to->nextid) {
652             printf("No room in queue, dropping request\n");
653             pthread_mutex_unlock(&to->newrq_mutex);
654             return;
655         }
656     }
657     
658     to->nextid = i + 1;
659     rq->buf[1] = (char)i;
660     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
661     
662     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
663         return;
664
665     to->requests[i] = *rq;
666
667     if (!to->newrq) {
668         to->newrq = 1;
669         printf("signalling client writer\n");
670         pthread_cond_signal(&to->newrq_cond);
671     }
672     pthread_mutex_unlock(&to->newrq_mutex);
673 }
674
675 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
676     struct replyq *replyq = to->replyq;
677     
678     pthread_mutex_lock(&replyq->count_mutex);
679     if (replyq->count == replyq->size) {
680         printf("No room in queue, dropping request\n");
681         pthread_mutex_unlock(&replyq->count_mutex);
682         return;
683     }
684
685     replyq->replies[replyq->count].buf = buf;
686     if (tosa)
687         replyq->replies[replyq->count].tosa = *tosa;
688     replyq->count++;
689
690     if (replyq->count == 1) {
691         printf("signalling client writer\n");
692         pthread_cond_signal(&replyq->count_cond);
693     }
694     pthread_mutex_unlock(&replyq->count_mutex);
695 }
696
697 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
698     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
699     static unsigned char first = 1;
700     static EVP_MD_CTX mdctx;
701     unsigned char hash[EVP_MAX_MD_SIZE], *input;
702     unsigned int md_len;
703     uint8_t i, offset = 0, out[128];
704     
705     pthread_mutex_lock(&lock);
706     if (first) {
707         EVP_MD_CTX_init(&mdctx);
708         first = 0;
709     }
710
711     input = auth;
712     for (;;) {
713         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
714             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
715             !EVP_DigestUpdate(&mdctx, input, 16) ||
716             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
717             md_len != 16) {
718             pthread_mutex_unlock(&lock);
719             return 0;
720         }
721         for (i = 0; i < 16; i++)
722             out[offset + i] = hash[i] ^ in[offset + i];
723         input = out + offset - 16;
724         offset += 16;
725         if (offset == len)
726             break;
727     }
728     memcpy(in, out, len);
729     pthread_mutex_unlock(&lock);
730     return 1;
731 }
732
733 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
734     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
735     static unsigned char first = 1;
736     static EVP_MD_CTX mdctx;
737     unsigned char hash[EVP_MAX_MD_SIZE], *input;
738     unsigned int md_len;
739     uint8_t i, offset = 0, out[128];
740     
741     pthread_mutex_lock(&lock);
742     if (first) {
743         EVP_MD_CTX_init(&mdctx);
744         first = 0;
745     }
746
747     input = auth;
748     for (;;) {
749         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
750             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
751             !EVP_DigestUpdate(&mdctx, input, 16) ||
752             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
753             md_len != 16) {
754             pthread_mutex_unlock(&lock);
755             return 0;
756         }
757         for (i = 0; i < 16; i++)
758             out[offset + i] = hash[i] ^ in[offset + i];
759         input = in + offset;
760         offset += 16;
761         if (offset == len)
762             break;
763     }
764     memcpy(in, out, len);
765     pthread_mutex_unlock(&lock);
766     return 1;
767 }
768
769 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
770     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
771     static unsigned char first = 1;
772     static EVP_MD_CTX mdctx;
773     unsigned char hash[EVP_MAX_MD_SIZE];
774     unsigned int md_len;
775     uint8_t i, offset;
776     
777     pthread_mutex_lock(&lock);
778     if (first) {
779         EVP_MD_CTX_init(&mdctx);
780         first = 0;
781     }
782
783 #if 0    
784     printf("msppencrypt auth in: ");
785     for (i = 0; i < 16; i++)
786         printf("%02x ", auth[i]);
787     printf("\n");
788     
789     printf("msppencrypt salt in: ");
790     for (i = 0; i < 2; i++)
791         printf("%02x ", salt[i]);
792     printf("\n");
793     
794     printf("msppencrypt in: ");
795     for (i = 0; i < len; i++)
796         printf("%02x ", text[i]);
797     printf("\n");
798 #endif
799     
800     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
801         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
802         !EVP_DigestUpdate(&mdctx, auth, 16) ||
803         !EVP_DigestUpdate(&mdctx, salt, 2) ||
804         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
805         pthread_mutex_unlock(&lock);
806         return 0;
807     }
808
809 #if 0    
810     printf("msppencrypt hash: ");
811     for (i = 0; i < 16; i++)
812         printf("%02x ", hash[i]);
813     printf("\n");
814 #endif
815     
816     for (i = 0; i < 16; i++)
817         text[i] ^= hash[i];
818     
819     for (offset = 16; offset < len; offset += 16) {
820 #if 0   
821         printf("text + offset - 16 c(%d): ", offset / 16);
822         for (i = 0; i < 16; i++)
823             printf("%02x ", (text + offset - 16)[i]);
824         printf("\n");
825 #endif
826         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
827             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
828             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
829             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
830             md_len != 16) {
831             pthread_mutex_unlock(&lock);
832             return 0;
833         }
834 #if 0   
835         printf("msppencrypt hash: ");
836         for (i = 0; i < 16; i++)
837             printf("%02x ", hash[i]);
838         printf("\n");
839 #endif    
840         
841         for (i = 0; i < 16; i++)
842             text[offset + i] ^= hash[i];
843     }
844     
845 #if 0
846     printf("msppencrypt out: ");
847     for (i = 0; i < len; i++)
848         printf("%02x ", text[i]);
849     printf("\n");
850 #endif
851
852     pthread_mutex_unlock(&lock);
853     return 1;
854 }
855
856 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
857     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
858     static unsigned char first = 1;
859     static EVP_MD_CTX mdctx;
860     unsigned char hash[EVP_MAX_MD_SIZE];
861     unsigned int md_len;
862     uint8_t i, offset;
863     char plain[255];
864     
865     pthread_mutex_lock(&lock);
866     if (first) {
867         EVP_MD_CTX_init(&mdctx);
868         first = 0;
869     }
870
871 #if 0    
872     printf("msppdecrypt auth in: ");
873     for (i = 0; i < 16; i++)
874         printf("%02x ", auth[i]);
875     printf("\n");
876     
877     printf("msppedecrypt salt in: ");
878     for (i = 0; i < 2; i++)
879         printf("%02x ", salt[i]);
880     printf("\n");
881     
882     printf("msppedecrypt in: ");
883     for (i = 0; i < len; i++)
884         printf("%02x ", text[i]);
885     printf("\n");
886 #endif
887     
888     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
889         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
890         !EVP_DigestUpdate(&mdctx, auth, 16) ||
891         !EVP_DigestUpdate(&mdctx, salt, 2) ||
892         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
893         pthread_mutex_unlock(&lock);
894         return 0;
895     }
896
897 #if 0    
898     printf("msppedecrypt hash: ");
899     for (i = 0; i < 16; i++)
900         printf("%02x ", hash[i]);
901     printf("\n");
902 #endif
903     
904     for (i = 0; i < 16; i++)
905         plain[i] = text[i] ^ hash[i];
906     
907     for (offset = 16; offset < len; offset += 16) {
908 #if 0   
909         printf("text + offset - 16 c(%d): ", offset / 16);
910         for (i = 0; i < 16; i++)
911             printf("%02x ", (text + offset - 16)[i]);
912         printf("\n");
913 #endif
914         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
915             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
916             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
917             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
918             md_len != 16) {
919             pthread_mutex_unlock(&lock);
920             return 0;
921         }
922 #if 0   
923     printf("msppedecrypt hash: ");
924     for (i = 0; i < 16; i++)
925         printf("%02x ", hash[i]);
926     printf("\n");
927 #endif    
928
929     for (i = 0; i < 16; i++)
930         plain[offset + i] = text[offset + i] ^ hash[i];
931     }
932
933     memcpy(text, plain, len);
934 #if 0
935     printf("msppedecrypt out: ");
936     for (i = 0; i < len; i++)
937         printf("%02x ", text[i]);
938     printf("\n");
939 #endif
940
941     pthread_mutex_unlock(&lock);
942     return 1;
943 }
944
945 struct server *id2server(char *id, uint8_t len) {
946     int i;
947     char **realm, *idrealm;
948
949     idrealm = strchr(id, '@');
950     if (idrealm) {
951         idrealm++;
952         len -= idrealm - id;
953     } else {
954         idrealm = "-";
955         len = 1;
956     }
957     for (i = 0; i < server_count; i++) {
958         for (realm = servers[i].realms; *realm; realm++) {
959             if ((strlen(*realm) == 1 && **realm == '*') ||
960                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
961                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
962                 return servers + i;
963             }
964         }
965     }
966     return NULL;
967 }
968
969 int rqinqueue(struct server *to, struct client *from, uint8_t id) {
970     int i;
971     
972     pthread_mutex_lock(&to->newrq_mutex);
973     for (i = 0; i < MAX_REQUESTS; i++)
974         if (to->requests[i].buf && to->requests[i].origid == id && to->requests[i].from == from)
975             break;
976     pthread_mutex_unlock(&to->newrq_mutex);
977     
978     return i < MAX_REQUESTS;
979 }
980         
981 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
982     uint8_t code, id, *auth, *attr, attrvallen;
983     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
984     int i;
985     uint16_t len;
986     int left;
987     struct server *to;
988     unsigned char newauth[16];
989     
990     code = *(uint8_t *)buf;
991     id = *(uint8_t *)(buf + 1);
992     len = RADLEN(buf);
993     auth = (uint8_t *)(buf + 4);
994
995     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
996     
997     if (code != RAD_Access_Request) {
998         printf("radsrv: server currently accepts only access-requests, ignoring\n");
999         return NULL;
1000     }
1001
1002     left = len - 20;
1003     attr = buf + 20;
1004     
1005     while (left > 1) {
1006         left -= attr[RAD_Attr_Length];
1007         if (left < 0) {
1008             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
1009             return NULL;
1010         }
1011         switch (attr[RAD_Attr_Type]) {
1012         case RAD_Attr_User_Name:
1013             usernameattr = attr;
1014             break;
1015         case RAD_Attr_User_Password:
1016             userpwdattr = attr;
1017             break;
1018         case RAD_Attr_Tunnel_Password:
1019             tunnelpwdattr = attr;
1020             break;
1021         case RAD_Attr_Message_Authenticator:
1022             messageauthattr = attr;
1023             break;
1024         }
1025         attr += attr[RAD_Attr_Length];
1026     }
1027     if (left)
1028         printf("radsrv: malformed packet? remaining byte after last attribute\n");
1029
1030     if (usernameattr) {
1031         printf("radsrv: Username: ");
1032         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
1033             printf("%c", usernameattr[RAD_Attr_Value + i]);
1034         printf("\n");
1035     }
1036
1037     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
1038     if (!to) {
1039         printf("radsrv: ignoring request, don't know where to send it\n");
1040         return NULL;
1041     }
1042
1043     if (rqinqueue(to, from, id)) {
1044         printf("radsrv: ignoring request from host %s with id %d, already got one\n", from->peer.host, id);
1045         return NULL;
1046     }
1047     
1048     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
1049                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
1050         printf("radsrv: message authentication failed\n");
1051         return NULL;
1052     }
1053
1054     if (!RAND_bytes(newauth, 16)) {
1055         printf("radsrv: failed to generate random auth\n");
1056         return NULL;
1057     }
1058
1059     printauth("auth", auth);
1060     printauth("newauth", newauth);
1061     
1062     if (userpwdattr) {
1063         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
1064         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
1065         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1066             printf("radsrv: invalid user password length\n");
1067             return NULL;
1068         }
1069         
1070         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1071             printf("radsrv: cannot decrypt password\n");
1072             return NULL;
1073         }
1074         printf("radsrv: password: ");
1075         for (i = 0; i < attrvallen; i++)
1076             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
1077         printf("\n");
1078         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1079             printf("radsrv: cannot encrypt password\n");
1080             return NULL;
1081         }
1082     }
1083
1084     if (tunnelpwdattr) {
1085         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
1086         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
1087         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1088             printf("radsrv: invalid user password length\n");
1089             return NULL;
1090         }
1091         
1092         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1093             printf("radsrv: cannot decrypt password\n");
1094             return NULL;
1095         }
1096         printf("radsrv: password: ");
1097         for (i = 0; i < attrvallen; i++)
1098             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
1099         printf("\n");
1100         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1101             printf("radsrv: cannot encrypt password\n");
1102             return NULL;
1103         }
1104     }
1105
1106     rq->buf = buf;
1107     rq->from = from;
1108     rq->origid = id;
1109     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
1110     memcpy(rq->origauth, auth, 16);
1111     memcpy(auth, newauth, 16);
1112     printauth("rq->origauth", rq->origauth);
1113     printauth("auth", auth);
1114     return to;
1115 }
1116
1117 void *clientrd(void *arg) {
1118     struct server *server = (struct server *)arg;
1119     struct client *from;
1120     int i, left, subleft;
1121     unsigned char *buf, *messageauthattr, *subattr, *attr;
1122     struct sockaddr_storage fromsa;
1123     struct timeval lastconnecttry;
1124     char tmp[255];
1125     
1126     for (;;) {
1127     getnext:
1128         lastconnecttry = server->lastconnecttry;
1129         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
1130         if (!buf && server->peer.type == 'T') {
1131             tlsconnect(server, &lastconnecttry, "clientrd");
1132             continue;
1133         }
1134     
1135         server->connectionok = 1;
1136
1137         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
1138             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
1139             continue;
1140         }
1141         
1142         i = buf[1]; /* i is the id */
1143
1144         pthread_mutex_lock(&server->newrq_mutex);
1145         if (!server->requests[i].buf || !server->requests[i].tries) {
1146             pthread_mutex_unlock(&server->newrq_mutex);
1147             printf("clientrd: no matching request sent with this id, ignoring\n");
1148             continue;
1149         }
1150
1151         if (server->requests[i].received) {
1152             pthread_mutex_unlock(&server->newrq_mutex);
1153             printf("clientrd: already received, ignoring\n");
1154             continue;
1155         }
1156         
1157         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1158             pthread_mutex_unlock(&server->newrq_mutex);
1159             printf("clientrd: invalid auth, ignoring\n");
1160             continue;
1161         }
1162         
1163         from = server->requests[i].from;
1164
1165
1166         /* messageauthattr present? */
1167         messageauthattr = NULL;
1168         left = RADLEN(buf) - 20;
1169         attr = buf + 20;
1170         while (left > 1) {
1171             left -= attr[RAD_Attr_Length];
1172             if (left < 0) {
1173                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1174                 goto getnext;
1175             }
1176             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1177                 if (attr[RAD_Attr_Length] != 18) {
1178                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1179                     goto getnext;
1180                 }
1181                 memcpy(tmp, buf + 4, 16);
1182                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1183                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1184                     printf("clientrd: message authentication failed\n");
1185                     goto getnext;
1186                 }
1187                 memcpy(buf + 4, tmp, 16);
1188                 printf("clientrd: message auth ok\n");
1189                 messageauthattr = attr;
1190                 break;
1191             }
1192             attr += attr[RAD_Attr_Length];
1193         }
1194
1195         /* handle MS MPPE */
1196         left = RADLEN(buf) - 20;
1197         attr = buf + 20;
1198         while (left > 1) {
1199             left -= attr[RAD_Attr_Length];
1200             if (left < 0) {
1201                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1202                 goto getnext;
1203             }
1204             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1205                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1206                 subleft = attr[RAD_Attr_Length] - 6;
1207                 subattr = attr + 6;
1208                 while (subleft > 1) {
1209                     subleft -= subattr[RAD_Attr_Length];
1210                     if (subleft < 0)
1211                         break;
1212                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1213                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1214                         continue;
1215                     printf("clientrd: Got MS MPPE\n");
1216                     if (subattr[RAD_Attr_Length] < 20)
1217                         continue;
1218
1219                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1220                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1221                         printf("clientrd: failed to decrypt msppe key\n");
1222                         continue;
1223                     }
1224
1225                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1226                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1227                         printf("clientrd: failed to encrypt msppe key\n");
1228                         continue;
1229                     }
1230                 }
1231                 if (subleft < 0) {
1232                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1233                     goto getnext;
1234                 }
1235             }
1236             attr += attr[RAD_Attr_Length];
1237         }
1238
1239         /* once we set received = 1, requests[i] may be reused */
1240         buf[1] = (char)server->requests[i].origid;
1241         memcpy(buf + 4, server->requests[i].origauth, 16);
1242         printauth("origauth/buf+4", buf + 4);
1243         if (messageauthattr) {
1244             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1245                 continue;
1246             printf("clientrd: computed messageauthattr\n");
1247         }
1248
1249         if (from->peer.type == 'U')
1250             fromsa = server->requests[i].fromsa;
1251         server->requests[i].received = 1;
1252         pthread_mutex_unlock(&server->newrq_mutex);
1253
1254         if (!radsign(buf, from->peer.secret)) {
1255             printf("clientrd: failed to sign message\n");
1256             continue;
1257         }
1258         printauth("signedorigauth/buf+4", buf + 4);             
1259         printf("clientrd: giving packet back to where it came from\n");
1260         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1261     }
1262 }
1263
1264 void *clientwr(void *arg) {
1265     struct server *server = (struct server *)arg;
1266     struct request *rq;
1267     pthread_t clientrdth;
1268     int i;
1269     struct timeval now;
1270     struct timespec timeout;
1271
1272     memset(&timeout, 0, sizeof(struct timespec));
1273     
1274     if (server->peer.type == 'U') {
1275         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1276             printf("clientwr: connecttoserver failed\n");
1277             exit(1);
1278         }
1279     } else
1280         tlsconnect(server, NULL, "new client");
1281     
1282     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1283         errx("clientwr: pthread_create failed");
1284
1285     for (;;) {
1286         pthread_mutex_lock(&server->newrq_mutex);
1287         if (!server->newrq) {
1288             if (timeout.tv_nsec) {
1289                 printf("clientwr: waiting up to %ld secs for new request\n", timeout.tv_nsec);
1290                 pthread_cond_timedwait(&server->newrq_cond, &server->newrq_mutex, &timeout);
1291                 timeout.tv_nsec = 0;
1292             } else {
1293                 printf("clientwr: waiting for new request\n");
1294                 pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1295             }
1296             if (server->newrq) {
1297                 printf("clientwr: got new request\n");
1298                 server->newrq = 0;
1299             }
1300         }
1301         pthread_mutex_unlock(&server->newrq_mutex);
1302         
1303         for (i = 0; i < MAX_REQUESTS; i++) {
1304             pthread_mutex_lock(&server->newrq_mutex);
1305             while (!server->requests[i].buf && i < MAX_REQUESTS)
1306                 i++;
1307             if (i == MAX_REQUESTS) {
1308                 pthread_mutex_unlock(&server->newrq_mutex);
1309                 break;
1310             }
1311             rq = server->requests + i;
1312
1313             if (rq->received) {
1314                 printf("clientwr: removing received packet from queue\n");
1315                 free(rq->buf);
1316                 /* setting this to NULL means that it can be reused */
1317                 rq->buf = NULL;
1318                 pthread_mutex_unlock(&server->newrq_mutex);
1319                 continue;
1320             }
1321             
1322             gettimeofday(&now, NULL);
1323             if (now.tv_sec <= rq->expiry.tv_sec) {
1324                 if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1325                     timeout.tv_sec = rq->expiry.tv_sec;
1326                 pthread_mutex_unlock(&server->newrq_mutex);
1327                 continue;
1328             }
1329
1330             if (rq->tries == (server->peer.type == 'T' ? 1 : REQUEST_RETRIES)) {
1331                 printf("clientwr: removing expired packet from queue\n");
1332                 free(rq->buf);
1333                 /* setting this to NULL means that it can be reused */
1334                 rq->buf = NULL;
1335                 pthread_mutex_unlock(&server->newrq_mutex);
1336                 continue;
1337             }
1338             pthread_mutex_unlock(&server->newrq_mutex);
1339
1340             rq->expiry.tv_sec = now.tv_sec +
1341                 (server->peer.type == 'T' ? REQUEST_EXPIRY : REQUEST_EXPIRY / REQUEST_RETRIES);
1342             if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1343                 timeout.tv_sec = rq->expiry.tv_sec;
1344             rq->tries++;
1345             clientradput(server, server->requests[i].buf);
1346             usleep(200000);
1347         }
1348     }
1349     /* should do more work to maintain TLS connections, keepalives etc */
1350 }
1351
1352 void *udpserverwr(void *arg) {
1353     struct replyq *replyq = &udp_server_replyq;
1354     struct reply *reply = replyq->replies;
1355     
1356     pthread_mutex_lock(&replyq->count_mutex);
1357     for (;;) {
1358         while (!replyq->count) {
1359             printf("udp server writer, waiting for signal\n");
1360             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1361             printf("udp server writer, got signal\n");
1362         }
1363         pthread_mutex_unlock(&replyq->count_mutex);
1364         
1365         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1366                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1367             err("sendudp: send failed");
1368         free(reply->buf);
1369         
1370         pthread_mutex_lock(&replyq->count_mutex);
1371         replyq->count--;
1372         memmove(replyq->replies, replyq->replies + 1,
1373                 replyq->count * sizeof(struct reply));
1374     }
1375 }
1376
1377 void *udpserverrd(void *arg) {
1378     struct request rq;
1379     unsigned char *buf;
1380     struct server *to;
1381     struct client *fr;
1382     pthread_t udpserverwrth;
1383     
1384     if ((udp_server_sock = bindport(SOCK_DGRAM, options.udpserverport)) < 0) {
1385         printf("udpserverrd: socket/bind failed\n");
1386         exit(1);
1387     }
1388     printf("udpserverrd: listening on UDP port %s\n", options.udpserverport);
1389
1390     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1391         errx("pthread_create failed");
1392     
1393     for (;;) {
1394         fr = NULL;
1395         memset(&rq, 0, sizeof(struct request));
1396         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1397         to = radsrv(&rq, buf, fr);
1398         if (!to) {
1399             printf("udpserverrd: ignoring request, no place to send it\n");
1400             continue;
1401         }
1402         sendrq(to, fr, &rq);
1403     }
1404 }
1405
1406 void *tlsserverwr(void *arg) {
1407     int cnt;
1408     unsigned long error;
1409     struct client *client = (struct client *)arg;
1410     struct replyq *replyq;
1411     
1412     printf("tlsserverwr starting for %s\n", client->peer.host);
1413     replyq = client->replyq;
1414     pthread_mutex_lock(&replyq->count_mutex);
1415     for (;;) {
1416         while (!replyq->count) {
1417             if (client->peer.ssl) {         
1418                 printf("tls server writer, waiting for signal\n");
1419                 pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1420                 printf("tls server writer, got signal\n");
1421             }
1422             if (!client->peer.ssl) {
1423                 //ssl might have changed while waiting
1424                 pthread_mutex_unlock(&replyq->count_mutex);
1425                 printf("tlsserverwr: exiting as requested\n");
1426                 pthread_exit(NULL);
1427             }
1428         }
1429         pthread_mutex_unlock(&replyq->count_mutex);
1430         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1431         if (cnt > 0)
1432             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1433                    cnt, RADLEN(replyq->replies->buf));
1434         else
1435             while ((error = ERR_get_error()))
1436                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1437         free(replyq->replies->buf);
1438
1439         pthread_mutex_lock(&replyq->count_mutex);
1440         replyq->count--;
1441         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1442     }
1443 }
1444
1445 void *tlsserverrd(void *arg) {
1446     struct request rq;
1447     char unsigned *buf;
1448     unsigned long error;
1449     struct server *to;
1450     int s;
1451     struct client *client = (struct client *)arg;
1452     pthread_t tlsserverwrth;
1453     SSL *ssl;
1454     
1455     printf("tlsserverrd starting for %s\n", client->peer.host);
1456     ssl = client->peer.ssl;
1457
1458     if (SSL_accept(ssl) <= 0) {
1459         while ((error = ERR_get_error()))
1460             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1461         errx("accept failed, child exiting");
1462     }
1463
1464     if (tlsverifycert(&client->peer)) {
1465         if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1466             errx("pthread_create failed");
1467     
1468         for (;;) {
1469             buf = radtlsget(client->peer.ssl);
1470             if (!buf)
1471                 break;
1472             printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1473             memset(&rq, 0, sizeof(struct request));
1474             to = radsrv(&rq, buf, client);
1475             if (!to) {
1476                 printf("ignoring request, no place to send it\n");
1477                 continue;
1478             }
1479             sendrq(to, client, &rq);
1480         }
1481         printf("tlsserverrd: connection lost\n");
1482         // stop writer by setting peer.ssl to NULL and give signal in case waiting for data
1483         client->peer.ssl = NULL;
1484         pthread_mutex_lock(&client->replyq->count_mutex);
1485         pthread_cond_signal(&client->replyq->count_cond);
1486         pthread_mutex_unlock(&client->replyq->count_mutex);
1487         printf("tlsserverrd: waiting for writer to end\n");
1488         pthread_join(tlsserverwrth, NULL);
1489     }
1490     s = SSL_get_fd(ssl);
1491     SSL_free(ssl);
1492     close(s);
1493     printf("tlsserverrd thread for %s exiting\n", client->peer.host);
1494     pthread_exit(NULL);
1495 }
1496
1497 int tlslistener() {
1498     pthread_t tlsserverth;
1499     int s, snew;
1500     struct sockaddr_storage from;
1501     size_t fromlen = sizeof(from);
1502     struct client *client;
1503
1504     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1505         printf("tlslistener: socket/bind failed\n");
1506         exit(1);
1507     }
1508     
1509     listen(s, 0);
1510     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1511
1512     for (;;) {
1513         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1514         if (snew < 0)
1515             errx("accept failed");
1516         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1517
1518         client = find_client('T', (struct sockaddr *)&from, NULL);
1519         if (!client) {
1520             printf("ignoring request, not a known TLS client\n");
1521             shutdown(snew, SHUT_RDWR);
1522             close(snew);
1523             continue;
1524         }
1525
1526         if (client->peer.ssl) {
1527             printf("Ignoring incoming connection, already have one from this client\n");
1528             shutdown(snew, SHUT_RDWR);
1529             close(snew);
1530             continue;
1531         }
1532         client->peer.ssl = SSL_new(ssl_ctx_srv);
1533         SSL_set_fd(client->peer.ssl, snew);
1534         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1535             errx("pthread_create failed");
1536         pthread_detach(tlsserverth);
1537     }
1538     return 0;
1539 }
1540
1541 char *parsehostport(char *s, struct peer *peer) {
1542     char *p, *field;
1543     int ipv6 = 0;
1544
1545     p = s;
1546     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1547     if (*p == '[') {
1548         p++;
1549         field = p;
1550         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1551         if (*p != ']') {
1552             printf("no ] matching initial [\n");
1553             exit(1);
1554         }
1555         ipv6 = 1;
1556     } else {
1557         field = p;
1558         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1559     }
1560     if (field == p) {
1561         printf("missing host/address\n");
1562         exit(1);
1563     }
1564     peer->host = stringcopy(field, p - field);
1565     if (ipv6) {
1566         p++;
1567         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1568             printf("unexpected character after ]\n");
1569             exit(1);
1570         }
1571     }
1572     if (*p == ':') {
1573             /* port number or service name is specified */;
1574             field = p++;
1575             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1576             if (field == p) {
1577                 printf("syntax error, : but no following port\n");
1578                 exit(1);
1579             }
1580             peer->port = stringcopy(field, p - field);
1581     } else
1582         peer->port = stringcopy(peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT, 0);
1583     return p;
1584 }
1585
1586 // * is default, else longest match ... ";" used for separator
1587 char *parserealmlist(char *s, struct server *server) {
1588     char *p;
1589     int i, n, l;
1590
1591     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1592         if (*p == ';')
1593             n++;
1594     l = p - s;
1595     if (!l) {
1596         printf("realm list must be specified\n");
1597         exit(1);
1598     }
1599     server->realmdata = stringcopy(s, l);
1600     server->realms = malloc((1+n) * sizeof(char *));
1601     if (!server->realms)
1602         errx("malloc failed");
1603     server->realms[0] = server->realmdata;
1604     for (n = 1, i = 0; i < l; i++)
1605         if (server->realmdata[i] == ';') {
1606             server->realmdata[i] = '\0';
1607             server->realms[n++] = server->realmdata + i + 1;
1608         }       
1609     server->realms[n] = NULL;
1610     return p;
1611 }
1612
1613 /* exactly one argument must be non-NULL */
1614 void getconfig(const char *serverfile, const char *clientfile) {
1615     FILE *f;
1616     char line[1024];
1617     char *p, *field, **r;
1618     const char *file;
1619     struct client *client;
1620     struct server *server;
1621     struct peer *peer;
1622     int i, count, *ucount, *tcount;
1623  
1624     file = serverfile ? serverfile : clientfile;
1625     f = fopen(file, "r");
1626     if (!f)
1627         errx("getconfig failed to open %s for reading", file);
1628     printf("opening file %s for reading\n", file);
1629     if (serverfile) {
1630         ucount = &server_udp_count;
1631         tcount = &server_tls_count;
1632     } else {
1633         ucount = &client_udp_count;
1634         tcount = &client_tls_count;
1635     }
1636     while (fgets(line, 1024, f)) {
1637         for (p = line; *p == ' ' || *p == '\t'; p++);
1638         switch (*p) {
1639         case '#':
1640         case '\n':
1641             break;
1642         case 'T':
1643             (*tcount)++;
1644             break;
1645         case 'U':
1646             (*ucount)++;
1647             break;
1648         default:
1649             printf("type must be U or T, got %c\n", *p);
1650             exit(1);
1651         }
1652     }
1653
1654     if (serverfile) {
1655         count = server_count = server_udp_count + server_tls_count;
1656         servers = calloc(count, sizeof(struct server));
1657         if (!servers)
1658             errx("malloc failed");
1659     } else {
1660         count = client_count = client_udp_count + client_tls_count;
1661         clients = calloc(count, sizeof(struct client));
1662         if (!clients)
1663             errx("malloc failed");
1664     }
1665     
1666     if (client_udp_count) {
1667         udp_server_replyq.replies = malloc(client_udp_count * MAX_REQUESTS * sizeof(struct reply));
1668         if (!udp_server_replyq.replies)
1669             errx("malloc failed");
1670         udp_server_replyq.size = client_udp_count * MAX_REQUESTS;
1671         udp_server_replyq.count = 0;
1672         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1673         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1674     }    
1675     
1676     rewind(f);
1677     for (i = 0; i < count && fgets(line, 1024, f);) {
1678         if (serverfile) {
1679             server = &servers[i];
1680             peer = &server->peer;
1681         } else {
1682             client = &clients[i];
1683             peer = &client->peer;
1684         }
1685         for (p = line; *p == ' ' || *p == '\t'; p++);
1686         if (*p == '#' || *p == '\n')
1687             continue;
1688         peer->type = *p;        // we already know it must be U or T
1689         for (p++; *p == ' ' || *p == '\t'; p++);
1690         p = parsehostport(p, peer);
1691         for (; *p == ' ' || *p == '\t'; p++);
1692         if (serverfile) {
1693             p = parserealmlist(p, server);
1694             for (; *p == ' ' || *p == '\t'; p++);
1695         }
1696         field = p;
1697         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1698         if (field == p) {
1699             /* no secret set and end of line, line is complete if TLS */
1700             if (peer->type == 'U') {
1701                 printf("secret must be specified for UDP\n");
1702                 exit(1);
1703             }
1704             peer->secret = stringcopy(DEFAULT_TLS_SECRET, 0);
1705         } else {
1706             peer->secret = stringcopy(field, p - field);
1707             /* check that rest of line only white space */
1708             for (; *p == ' ' || *p == '\t'; p++);
1709             if (*p && *p != '\n') {
1710                 printf("max 4 fields per line, found a 5th\n");
1711                 exit(1);
1712             }
1713         }
1714
1715         if ((serverfile && !resolvepeer(&server->peer)) ||
1716             (clientfile && !resolvepeer(&client->peer))) {
1717             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1718             exit(1);
1719         }
1720
1721         if (serverfile) {
1722             pthread_mutex_init(&server->lock, NULL);
1723             server->sock = -1;
1724             server->requests = calloc(MAX_REQUESTS, sizeof(struct request));
1725             if (!server->requests)
1726                 errx("malloc failed");
1727             server->newrq = 0;
1728             pthread_mutex_init(&server->newrq_mutex, NULL);
1729             pthread_cond_init(&server->newrq_cond, NULL);
1730         } else {
1731             if (peer->type == 'U')
1732                 client->replyq = &udp_server_replyq;
1733             else {
1734                 client->replyq = malloc(sizeof(struct replyq));
1735                 if (!client->replyq)
1736                     errx("malloc failed");
1737                 client->replyq->replies = calloc(MAX_REQUESTS, sizeof(struct reply));
1738                 if (!client->replyq->replies)
1739                     errx("malloc failed");
1740                 client->replyq->size = MAX_REQUESTS;
1741                 client->replyq->count = 0;
1742                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1743                 pthread_cond_init(&client->replyq->count_cond, NULL);
1744             }
1745         }
1746         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1747         if (serverfile) {
1748             printf("    with realms:");
1749             for (r = server->realms; *r; r++)
1750                 printf(" %s", *r);
1751             printf("\n");
1752         }
1753         i++;
1754     }
1755     fclose(f);
1756 }
1757
1758 void getmainconfig(const char *configfile) {
1759     FILE *f;
1760     char line[1024];
1761     char *p, *opt, *endopt, *val, *endval;
1762     
1763     printf("opening file %s for reading\n", configfile);
1764     f = fopen(configfile, "r");
1765     if (!f)
1766         errx("getmainconfig failed to open %s for reading", configfile);
1767
1768     memset(&options, 0, sizeof(options));
1769
1770     while (fgets(line, 1024, f)) {
1771         for (p = line; *p == ' ' || *p == '\t'; p++);
1772         if (!*p || *p == '#' || *p == '\n')
1773             continue;
1774         opt = p++;
1775         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1776         endopt = p - 1;
1777         for (; *p == ' ' || *p == '\t'; p++);
1778         if (!*p || *p == '\n') {
1779             endopt[1] = '\0';
1780             printf("error in %s, option %s has no value\n", configfile, opt);
1781             exit(1);
1782         }
1783         val = p;
1784         for (; *p && *p != '\n'; p++)
1785             if (*p != ' ' && *p != '\t')
1786                 endval = p;
1787         endopt[1] = '\0';
1788         endval[1] = '\0';
1789         printf("getmainconfig: %s = %s\n", opt, val);
1790         
1791         if (!strcasecmp(opt, "TLSCACertificateFile")) {
1792             options.tlscacertificatefile = stringcopy(val, 0);
1793             continue;
1794         }
1795         if (!strcasecmp(opt, "TLSCACertificatePath")) {
1796             options.tlscacertificatepath = stringcopy(val, 0);
1797             continue;
1798         }
1799         if (!strcasecmp(opt, "TLSCertificateFile")) {
1800             options.tlscertificatefile = stringcopy(val, 0);
1801             continue;
1802         }
1803         if (!strcasecmp(opt, "TLSCertificateKeyFile")) {
1804             options.tlscertificatekeyfile = stringcopy(val, 0);
1805             continue;
1806         }
1807         if (!strcasecmp(opt, "UDPServerPort")) {
1808             options.udpserverport = stringcopy(val, 0);
1809             continue;
1810         }
1811
1812         printf("error in %s, unknown option %s\n", configfile, opt);
1813         exit(1);
1814     }
1815     fclose(f);
1816
1817     if (!options.udpserverport)
1818         options.udpserverport = stringcopy(DEFAULT_UDP_PORT, 0);
1819 }
1820
1821 #if 0
1822 void parseargs(int argc, char **argv) {
1823     int c;
1824
1825     while ((c = getopt(argc, argv, "p:")) != -1) {
1826         switch (c) {
1827         case 'p':
1828             udp_server_port = optarg;
1829             break;
1830         default:
1831             goto usage;
1832         }
1833     }
1834
1835     return;
1836
1837  usage:
1838     printf("radsecproxy [ -p UDP-port ]\n");
1839     exit(1);
1840 }
1841 #endif
1842
1843 int main(int argc, char **argv) {
1844     pthread_t udpserverth;
1845     //    pthread_attr_t joinable;
1846     int i;
1847     
1848     //    parseargs(argc, argv);
1849     getmainconfig("radsecproxy.conf");
1850     getconfig("servers.conf", NULL);
1851     getconfig(NULL, "clients.conf");
1852
1853     //    pthread_attr_init(&joinable);
1854     //    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1855    
1856     if (client_udp_count)
1857         if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
1858             errx("pthread_create failed");
1859
1860     ssl_init(client_tls_count ? &ssl_ctx_srv : NULL, server_tls_count ? &ssl_ctx_cl : NULL);
1861     
1862     for (i = 0; i < server_count; i++)
1863         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1864             errx("pthread_create failed");
1865
1866     if (client_tls_count)
1867         return tlslistener();
1868
1869     /* just hang around doing nothing, anything to do here? */
1870     for (;;)
1871         sleep(1000);
1872 }