minor changes
[radsecproxy.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* TODO:
10  * accounting
11  * radius keep alives (server status)
12  * setsockopt(keepalive...), check if openssl has some keepalive feature
13 */
14
15 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
16  *              rd is responsible for init and launching wr
17  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
18  *          each tlsserverrd launches tlsserverwr
19  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
20  *          for init and launching rd
21  *
22  * serverrd will receive a request, processes it and puts it in the requestq of
23  *          the appropriate clientwr
24  * clientwr monitors its requestq and sends requests
25  * clientrd looks for responses, processes them and puts them in the replyq of
26  *          the peer the request came from
27  * serverwr monitors its reply and sends replies
28  *
29  * In addition to the main thread, we have:
30  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
31  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
32  * For each TLS peer connecting to us there will be 2 more TLS threads
33  *       This is only for connected peers
34  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
35  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
36 */
37
38 #include <netdb.h>
39 #include <unistd.h>
40 #include <sys/time.h>
41 #include <pthread.h>
42 #include <openssl/ssl.h>
43 #include <openssl/rand.h>
44 #include <openssl/err.h>
45 #include <openssl/md5.h>
46 #include <openssl/hmac.h>
47 #include "radsecproxy.h"
48
49 static struct options options;
50 static struct client *clients;
51 static struct server *servers;
52
53 static int client_udp_count = 0;
54 static int client_tls_count = 0;
55 static int client_count = 0;
56 static int server_udp_count = 0;
57 static int server_tls_count = 0;
58 static int server_count = 0;
59
60 static struct replyq udp_server_replyq;
61 static int udp_server_sock = -1;
62 static pthread_mutex_t *ssl_locks;
63 static long *ssl_lock_count;
64 static SSL_CTX *ssl_ctx = NULL;
65 extern int optind;
66 extern char *optarg;
67
68 /* callbacks for making OpenSSL thread safe */
69 unsigned long ssl_thread_id() {
70         return (unsigned long)pthread_self();
71 };
72
73 void ssl_locking_callback(int mode, int type, const char *file, int line) {
74     if (mode & CRYPTO_LOCK) {
75         pthread_mutex_lock(&ssl_locks[type]);
76         ssl_lock_count[type]++;
77     } else
78         pthread_mutex_unlock(&ssl_locks[type]);
79 }
80
81 static int verify_cb(int ok, X509_STORE_CTX *ctx) {
82   char buf[256];
83   X509 *err_cert;
84   int err, depth;
85
86   err_cert = X509_STORE_CTX_get_current_cert(ctx);
87   err = X509_STORE_CTX_get_error(ctx);
88   depth = X509_STORE_CTX_get_error_depth(ctx);
89
90   if (depth > MAX_CERT_DEPTH) {
91       ok = 0;
92       err = X509_V_ERR_CERT_CHAIN_TOO_LONG;
93       X509_STORE_CTX_set_error(ctx, err);
94   }
95
96   if (!ok) {
97       X509_NAME_oneline(X509_get_subject_name(err_cert), buf, 256);
98       printf("verify error: num=%d:%s:depth=%d:%s\n", err, X509_verify_cert_error_string(err), depth, buf);
99
100       switch (err) {
101       case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT:
102           X509_NAME_oneline(X509_get_issuer_name(ctx->current_cert), buf, 256);
103           printf("issuer=%s\n", buf);
104           break;
105       case X509_V_ERR_CERT_NOT_YET_VALID:
106       case X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD:
107           printf("Certificate not yet valid\n");
108           break;
109       case X509_V_ERR_CERT_HAS_EXPIRED:
110           printf("Certificate has expired\n");
111           break;
112       case X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD:
113           printf("Certificate no longer valid (after notAfter)\n");
114           break;
115       }
116   }
117   //  printf("certificate verify returns %d\n", ok);
118   return ok;
119 }
120
121 SSL_CTX *ssl_init() {
122     SSL_CTX *ctx;
123     int i;
124     unsigned long error;
125     
126     if (!options.tlscertificatefile || !options.tlscertificatekeyfile) {
127         printf("TLSCertificateFile and TLSCertificateKeyFile must be specified for TLS\n");
128         exit(1);
129     }
130     if (!options.tlscacertificatefile && !options.tlscacertificatepath) {
131         printf("CA Certificate file/path need to be configured\n");
132         exit(1);
133     }
134
135     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
136     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
137     for (i = 0; i < CRYPTO_num_locks(); i++) {
138         ssl_lock_count[i] = 0;
139         pthread_mutex_init(&ssl_locks[i], NULL);
140     }
141     CRYPTO_set_id_callback(ssl_thread_id);
142     CRYPTO_set_locking_callback(ssl_locking_callback);
143
144     SSL_load_error_strings();
145     SSL_library_init();
146
147     while (!RAND_status()) {
148         time_t t = time(NULL);
149         pid_t pid = getpid();
150         RAND_seed((unsigned char *)&t, sizeof(time_t));
151         RAND_seed((unsigned char *)&pid, sizeof(pid));
152     }
153
154     ctx = SSL_CTX_new(TLSv1_method());
155     if (SSL_CTX_use_certificate_chain_file(ctx, options.tlscertificatefile) &&
156         SSL_CTX_use_PrivateKey_file(ctx, options.tlscertificatekeyfile, SSL_FILETYPE_PEM) &&
157         SSL_CTX_check_private_key(ctx) &&
158         SSL_CTX_load_verify_locations(ctx, options.tlscacertificatefile, options.tlscacertificatepath)) {
159         SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
160         SSL_CTX_set_verify_depth(ctx, MAX_CERT_DEPTH + 1);
161         return ctx;
162     }
163
164     while ((error = ERR_get_error()))
165         err("SSL: %s", ERR_error_string(error, NULL));
166     exit(1);
167 }    
168
169 void printauth(char *s, unsigned char *t) {
170     int i;
171     printf("%s:", s);
172     for (i = 0; i < 16; i++)
173             printf("%02x ", t[i]);
174     printf("\n");
175 }
176
177 int resolvepeer(struct peer *peer) {
178     struct addrinfo hints, *addrinfo;
179     
180     memset(&hints, 0, sizeof(hints));
181     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
182     hints.ai_family = AF_UNSPEC;
183     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
184         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
185         return 0;
186     }
187
188     if (peer->addrinfo)
189         freeaddrinfo(peer->addrinfo);
190     peer->addrinfo = addrinfo;
191     return 1;
192 }         
193
194 int connecttoserver(struct addrinfo *addrinfo) {
195     int s;
196     struct addrinfo *res;
197     
198     for (res = addrinfo; res; res = res->ai_next) {
199         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
200         if (s < 0) {
201             err("connecttoserver: socket failed");
202             continue;
203         }
204         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
205             break;
206         err("connecttoserver: connect failed");
207         close(s);
208         s = -1;
209     }
210     return s;
211 }         
212
213 /* returns the client with matching address, or NULL */
214 /* if client argument is not NULL, we only check that one client */
215 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
216     struct sockaddr_in6 *sa6;
217     struct in_addr *a4 = NULL;
218     struct client *c;
219     int i;
220     struct addrinfo *res;
221
222     if (addr->sa_family == AF_INET6) {
223         sa6 = (struct sockaddr_in6 *)addr;
224         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
225             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
226     } else
227         a4 = &((struct sockaddr_in *)addr)->sin_addr;
228
229     c = (client ? client : clients);
230     for (i = 0; i < client_count; i++) {
231         if (c->peer.type == type)
232             for (res = c->peer.addrinfo; res; res = res->ai_next)
233                 if ((a4 && res->ai_family == AF_INET &&
234                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
235                     (res->ai_family == AF_INET6 &&
236                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
237                     return c;
238         if (client)
239             break;
240         c++;
241     }
242     return NULL;
243 }
244
245 /* returns the server with matching address, or NULL */
246 /* if server argument is not NULL, we only check that one server */
247 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
248     struct sockaddr_in6 *sa6;
249     struct in_addr *a4 = NULL;
250     struct server *s;
251     int i;
252     struct addrinfo *res;
253
254     if (addr->sa_family == AF_INET6) {
255         sa6 = (struct sockaddr_in6 *)addr;
256         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
257             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
258     } else
259         a4 = &((struct sockaddr_in *)addr)->sin_addr;
260
261     s = (server ? server : servers);
262     for (i = 0; i < server_count; i++) {
263         if (s->peer.type == type)
264             for (res = s->peer.addrinfo; res; res = res->ai_next)
265                 if ((a4 && res->ai_family == AF_INET &&
266                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
267                     (res->ai_family == AF_INET6 &&
268                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
269                     return s;
270         if (server)
271             break;
272         s++;
273     }
274     return NULL;
275 }
276
277 /* exactly one of client and server must be non-NULL */
278 /* if *peer == NULL we return who we received from, else require it to be from peer */
279 /* return from in sa if not NULL */
280 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
281     int cnt, len;
282     void *f;
283     unsigned char buf[65536], *rad;
284     struct sockaddr_storage from;
285     socklen_t fromlen = sizeof(from);
286
287     for (;;) {
288         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
289         if (cnt == -1) {
290             err("radudpget: recv failed");
291             continue;
292         }
293         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
294
295         if (cnt < 20) {
296             printf("radudpget: packet too small\n");
297             continue;
298         }
299     
300         len = RADLEN(buf);
301
302         if (cnt < len) {
303             printf("radudpget: packet smaller than length field in radius header\n");
304             continue;
305         }
306         if (cnt > len)
307             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
308
309         f = (client
310              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
311              : (void *)find_server('U', (struct sockaddr *)&from, *server));
312         if (!f) {
313             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
314             continue;
315         }
316
317         rad = malloc(len);
318         if (rad)
319             break;
320         err("radudpget: malloc failed");
321     }
322     memcpy(rad, buf, len);
323     if (client)
324         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
325     else
326         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
327     if (sa)
328         *sa = from;
329     return rad;
330 }
331
332 int tlsverifycert(struct peer *peer) {
333     int i, l, loc;
334     X509 *cert;
335     X509_NAME *nm;
336     X509_NAME_ENTRY *e;
337     unsigned char *v;
338     unsigned long error;
339
340 #if 1
341     if (SSL_get_verify_result(peer->ssl) != X509_V_OK) {
342         printf("tlsverifycert: basic validation failed\n");
343         while ((error = ERR_get_error()))
344             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
345         return 0;
346     }
347 #endif    
348     cert = SSL_get_peer_certificate(peer->ssl);
349     if (!cert) {
350         printf("tlsverifycert: failed to obtain certificate\n");
351         return 0;
352     }
353     nm = X509_get_subject_name(cert);
354     loc = -1;
355     for (;;) {
356         loc = X509_NAME_get_index_by_NID(nm, NID_commonName, loc);
357         if (loc == -1)
358             break;
359         e = X509_NAME_get_entry(nm, loc);
360         l = ASN1_STRING_to_UTF8(&v, X509_NAME_ENTRY_get_data(e));
361         if (l < 0)
362             continue;
363         printf("cn: ");
364         for (i = 0; i < l; i++)
365             printf("%c", v[i]);
366         printf("\n");
367         if (l == strlen(peer->host) && !strncasecmp(peer->host, v, l)) {
368             printf("tlsverifycert: Found cn matching host %s, All OK\n", peer->host);
369             return 1;
370         }
371         printf("tlsverifycert: cn not matching host %s\n", peer->host);
372     }
373     X509_free(cert);
374     return 0;
375 }
376
377 void tlsconnect(struct server *server, struct timeval *when, char *text) {
378     struct timeval now;
379     time_t elapsed;
380
381     printf("tlsconnect called from %s\n", text);
382     pthread_mutex_lock(&server->lock);
383     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
384         /* already reconnected, nothing to do */
385         printf("tlsconnect(%s): seems already reconnected\n", text);
386         pthread_mutex_unlock(&server->lock);
387         return;
388     }
389
390     printf("tlsconnect %s\n", text);
391
392     for (;;) {
393         gettimeofday(&now, NULL);
394         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
395         if (server->connectionok) {
396             server->connectionok = 0;
397             sleep(10);
398         } else if (elapsed < 5)
399             sleep(10);
400         else if (elapsed < 600) {
401             printf("tlsconnect: sleeping %lds\n", elapsed);
402             sleep(elapsed);
403         } else if (elapsed < 1000) {
404             printf("tlsconnect: sleeping %ds\n", 900);
405             sleep(900);
406         } else
407             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
408         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
409         if (server->sock >= 0)
410             close(server->sock);
411         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
412             continue;
413         SSL_free(server->peer.ssl);
414         server->peer.ssl = SSL_new(ssl_ctx);
415         SSL_set_fd(server->peer.ssl, server->sock);
416         if (SSL_connect(server->peer.ssl) > 0 && tlsverifycert(&server->peer))
417             break;
418     }
419     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
420     gettimeofday(&server->lastconnecttry, NULL);
421     pthread_mutex_unlock(&server->lock);
422 }
423
424 unsigned char *radtlsget(SSL *ssl) {
425     int cnt, total, len;
426     unsigned char buf[4], *rad;
427
428     for (;;) {
429         for (total = 0; total < 4; total += cnt) {
430             cnt = SSL_read(ssl, buf + total, 4 - total);
431             if (cnt <= 0) {
432                 printf("radtlsget: connection lost\n");
433                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
434                     //remote end sent close_notify, send one back
435                     SSL_shutdown(ssl);
436                 }
437                 return NULL;
438             }
439         }
440
441         len = RADLEN(buf);
442         rad = malloc(len);
443         if (!rad) {
444             err("radtlsget: malloc failed");
445             continue;
446         }
447         memcpy(rad, buf, 4);
448
449         for (; total < len; total += cnt) {
450             cnt = SSL_read(ssl, rad + total, len - total);
451             if (cnt <= 0) {
452                 printf("radtlsget: connection lost\n");
453                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
454                     //remote end sent close_notify, send one back
455                     SSL_shutdown(ssl);
456                 }
457                 free(rad);
458                 return NULL;
459             }
460         }
461     
462         if (total >= 20)
463             break;
464         
465         free(rad);
466         printf("radtlsget: packet smaller than minimum radius size\n");
467     }
468     
469     printf("radtlsget: got %d bytes\n", total);
470     return rad;
471 }
472
473 int clientradput(struct server *server, unsigned char *rad) {
474     int cnt;
475     size_t len;
476     unsigned long error;
477     struct timeval lastconnecttry;
478     
479     len = RADLEN(rad);
480     if (server->peer.type == 'U') {
481         if (send(server->sock, rad, len, 0) >= 0) {
482             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
483             return 1;
484         }
485         err("clientradput: send failed");
486         return 0;
487     }
488
489     lastconnecttry = server->lastconnecttry;
490     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
491         while ((error = ERR_get_error()))
492             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
493         tlsconnect(server, &lastconnecttry, "clientradput");
494         lastconnecttry = server->lastconnecttry;
495     }
496
497     server->connectionok = 1;
498     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
499            cnt, len, server->peer.host);
500     return 1;
501 }
502
503 int radsign(unsigned char *rad, unsigned char *sec) {
504     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
505     static unsigned char first = 1;
506     static EVP_MD_CTX mdctx;
507     unsigned int md_len;
508     int result;
509     
510     pthread_mutex_lock(&lock);
511     if (first) {
512         EVP_MD_CTX_init(&mdctx);
513         first = 0;
514     }
515
516     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
517         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
518         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
519         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
520         md_len == 16);
521     pthread_mutex_unlock(&lock);
522     return result;
523 }
524
525 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
526     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
527     static unsigned char first = 1;
528     static EVP_MD_CTX mdctx;
529     unsigned char hash[EVP_MAX_MD_SIZE];
530     unsigned int len;
531     int result;
532     
533     pthread_mutex_lock(&lock);
534     if (first) {
535         EVP_MD_CTX_init(&mdctx);
536         first = 0;
537     }
538
539     len = RADLEN(rad);
540     
541     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
542               EVP_DigestUpdate(&mdctx, rad, 4) &&
543               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
544               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
545               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
546               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
547               len == 16 &&
548               !memcmp(hash, rad + 4, 16));
549     pthread_mutex_unlock(&lock);
550     return result;
551 }
552               
553 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
554     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
555     static unsigned char first = 1;
556     static HMAC_CTX hmacctx;
557     unsigned int md_len;
558     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
559     
560     pthread_mutex_lock(&lock);
561     if (first) {
562         HMAC_CTX_init(&hmacctx);
563         first = 0;
564     }
565
566     memcpy(auth, authattr, 16);
567     memset(authattr, 0, 16);
568     md_len = 0;
569     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
570     HMAC_Update(&hmacctx, rad, RADLEN(rad));
571     HMAC_Final(&hmacctx, hash, &md_len);
572     memcpy(authattr, auth, 16);
573     if (md_len != 16) {
574         printf("message auth computation failed\n");
575         pthread_mutex_unlock(&lock);
576         return 0;
577     }
578
579     if (memcmp(auth, hash, 16)) {
580         printf("message authenticator, wrong value\n");
581         pthread_mutex_unlock(&lock);
582         return 0;
583     }   
584         
585     pthread_mutex_unlock(&lock);
586     return 1;
587 }
588
589 int createmessageauth(char *rad, char *authattrval, char *secret) {
590     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
591     static unsigned char first = 1;
592     static HMAC_CTX hmacctx;
593     unsigned int md_len;
594
595     if (!authattrval)
596         return 1;
597     
598     pthread_mutex_lock(&lock);
599     if (first) {
600         HMAC_CTX_init(&hmacctx);
601         first = 0;
602     }
603
604     memset(authattrval, 0, 16);
605     md_len = 0;
606     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
607     HMAC_Update(&hmacctx, rad, RADLEN(rad));
608     HMAC_Final(&hmacctx, authattrval, &md_len);
609     if (md_len != 16) {
610         printf("message auth computation failed\n");
611         pthread_mutex_unlock(&lock);
612         return 0;
613     }
614
615     pthread_mutex_unlock(&lock);
616     return 1;
617 }
618
619 void sendrq(struct server *to, struct client *from, struct request *rq) {
620     int i;
621     
622     pthread_mutex_lock(&to->newrq_mutex);
623     /* might simplify if only try nextid, might be ok */
624     for (i = to->nextid; i < MAX_REQUESTS; i++)
625         if (!to->requests[i].buf)
626             break;
627     if (i == MAX_REQUESTS) {
628         for (i = 0; i < to->nextid; i++)
629             if (!to->requests[i].buf)
630                 break;
631         if (i == to->nextid) {
632             printf("No room in queue, dropping request\n");
633             pthread_mutex_unlock(&to->newrq_mutex);
634             return;
635         }
636     }
637     
638     to->nextid = i + 1;
639     rq->buf[1] = (char)i;
640     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
641     
642     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
643         return;
644
645     to->requests[i] = *rq;
646
647     if (!to->newrq) {
648         to->newrq = 1;
649         printf("signalling client writer\n");
650         pthread_cond_signal(&to->newrq_cond);
651     }
652     pthread_mutex_unlock(&to->newrq_mutex);
653 }
654
655 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
656     struct replyq *replyq = to->replyq;
657     
658     pthread_mutex_lock(&replyq->count_mutex);
659     if (replyq->count == replyq->size) {
660         printf("No room in queue, dropping request\n");
661         pthread_mutex_unlock(&replyq->count_mutex);
662         return;
663     }
664
665     replyq->replies[replyq->count].buf = buf;
666     if (tosa)
667         replyq->replies[replyq->count].tosa = *tosa;
668     replyq->count++;
669
670     if (replyq->count == 1) {
671         printf("signalling client writer\n");
672         pthread_cond_signal(&replyq->count_cond);
673     }
674     pthread_mutex_unlock(&replyq->count_mutex);
675 }
676
677 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
678     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
679     static unsigned char first = 1;
680     static EVP_MD_CTX mdctx;
681     unsigned char hash[EVP_MAX_MD_SIZE], *input;
682     unsigned int md_len;
683     uint8_t i, offset = 0, out[128];
684     
685     pthread_mutex_lock(&lock);
686     if (first) {
687         EVP_MD_CTX_init(&mdctx);
688         first = 0;
689     }
690
691     input = auth;
692     for (;;) {
693         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
694             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
695             !EVP_DigestUpdate(&mdctx, input, 16) ||
696             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
697             md_len != 16) {
698             pthread_mutex_unlock(&lock);
699             return 0;
700         }
701         for (i = 0; i < 16; i++)
702             out[offset + i] = hash[i] ^ in[offset + i];
703         input = out + offset - 16;
704         offset += 16;
705         if (offset == len)
706             break;
707     }
708     memcpy(in, out, len);
709     pthread_mutex_unlock(&lock);
710     return 1;
711 }
712
713 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
714     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
715     static unsigned char first = 1;
716     static EVP_MD_CTX mdctx;
717     unsigned char hash[EVP_MAX_MD_SIZE], *input;
718     unsigned int md_len;
719     uint8_t i, offset = 0, out[128];
720     
721     pthread_mutex_lock(&lock);
722     if (first) {
723         EVP_MD_CTX_init(&mdctx);
724         first = 0;
725     }
726
727     input = auth;
728     for (;;) {
729         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
730             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
731             !EVP_DigestUpdate(&mdctx, input, 16) ||
732             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
733             md_len != 16) {
734             pthread_mutex_unlock(&lock);
735             return 0;
736         }
737         for (i = 0; i < 16; i++)
738             out[offset + i] = hash[i] ^ in[offset + i];
739         input = in + offset;
740         offset += 16;
741         if (offset == len)
742             break;
743     }
744     memcpy(in, out, len);
745     pthread_mutex_unlock(&lock);
746     return 1;
747 }
748
749 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
750     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
751     static unsigned char first = 1;
752     static EVP_MD_CTX mdctx;
753     unsigned char hash[EVP_MAX_MD_SIZE];
754     unsigned int md_len;
755     uint8_t i, offset;
756     
757     pthread_mutex_lock(&lock);
758     if (first) {
759         EVP_MD_CTX_init(&mdctx);
760         first = 0;
761     }
762
763 #if 0    
764     printf("msppencrypt auth in: ");
765     for (i = 0; i < 16; i++)
766         printf("%02x ", auth[i]);
767     printf("\n");
768     
769     printf("msppencrypt salt in: ");
770     for (i = 0; i < 2; i++)
771         printf("%02x ", salt[i]);
772     printf("\n");
773     
774     printf("msppencrypt in: ");
775     for (i = 0; i < len; i++)
776         printf("%02x ", text[i]);
777     printf("\n");
778 #endif
779     
780     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
781         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
782         !EVP_DigestUpdate(&mdctx, auth, 16) ||
783         !EVP_DigestUpdate(&mdctx, salt, 2) ||
784         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
785         pthread_mutex_unlock(&lock);
786         return 0;
787     }
788
789 #if 0    
790     printf("msppencrypt hash: ");
791     for (i = 0; i < 16; i++)
792         printf("%02x ", hash[i]);
793     printf("\n");
794 #endif
795     
796     for (i = 0; i < 16; i++)
797         text[i] ^= hash[i];
798     
799     for (offset = 16; offset < len; offset += 16) {
800 #if 0   
801         printf("text + offset - 16 c(%d): ", offset / 16);
802         for (i = 0; i < 16; i++)
803             printf("%02x ", (text + offset - 16)[i]);
804         printf("\n");
805 #endif
806         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
807             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
808             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
809             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
810             md_len != 16) {
811             pthread_mutex_unlock(&lock);
812             return 0;
813         }
814 #if 0   
815         printf("msppencrypt hash: ");
816         for (i = 0; i < 16; i++)
817             printf("%02x ", hash[i]);
818         printf("\n");
819 #endif    
820         
821         for (i = 0; i < 16; i++)
822             text[offset + i] ^= hash[i];
823     }
824     
825 #if 0
826     printf("msppencrypt out: ");
827     for (i = 0; i < len; i++)
828         printf("%02x ", text[i]);
829     printf("\n");
830 #endif
831
832     pthread_mutex_unlock(&lock);
833     return 1;
834 }
835
836 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
837     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
838     static unsigned char first = 1;
839     static EVP_MD_CTX mdctx;
840     unsigned char hash[EVP_MAX_MD_SIZE];
841     unsigned int md_len;
842     uint8_t i, offset;
843     char plain[255];
844     
845     pthread_mutex_lock(&lock);
846     if (first) {
847         EVP_MD_CTX_init(&mdctx);
848         first = 0;
849     }
850
851 #if 0    
852     printf("msppdecrypt auth in: ");
853     for (i = 0; i < 16; i++)
854         printf("%02x ", auth[i]);
855     printf("\n");
856     
857     printf("msppedecrypt salt in: ");
858     for (i = 0; i < 2; i++)
859         printf("%02x ", salt[i]);
860     printf("\n");
861     
862     printf("msppedecrypt in: ");
863     for (i = 0; i < len; i++)
864         printf("%02x ", text[i]);
865     printf("\n");
866 #endif
867     
868     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
869         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
870         !EVP_DigestUpdate(&mdctx, auth, 16) ||
871         !EVP_DigestUpdate(&mdctx, salt, 2) ||
872         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
873         pthread_mutex_unlock(&lock);
874         return 0;
875     }
876
877 #if 0    
878     printf("msppedecrypt hash: ");
879     for (i = 0; i < 16; i++)
880         printf("%02x ", hash[i]);
881     printf("\n");
882 #endif
883     
884     for (i = 0; i < 16; i++)
885         plain[i] = text[i] ^ hash[i];
886     
887     for (offset = 16; offset < len; offset += 16) {
888 #if 0   
889         printf("text + offset - 16 c(%d): ", offset / 16);
890         for (i = 0; i < 16; i++)
891             printf("%02x ", (text + offset - 16)[i]);
892         printf("\n");
893 #endif
894         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
895             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
896             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
897             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
898             md_len != 16) {
899             pthread_mutex_unlock(&lock);
900             return 0;
901         }
902 #if 0   
903     printf("msppedecrypt hash: ");
904     for (i = 0; i < 16; i++)
905         printf("%02x ", hash[i]);
906     printf("\n");
907 #endif    
908
909     for (i = 0; i < 16; i++)
910         plain[offset + i] = text[offset + i] ^ hash[i];
911     }
912
913     memcpy(text, plain, len);
914 #if 0
915     printf("msppedecrypt out: ");
916     for (i = 0; i < len; i++)
917         printf("%02x ", text[i]);
918     printf("\n");
919 #endif
920
921     pthread_mutex_unlock(&lock);
922     return 1;
923 }
924
925 struct server *id2server(char *id, uint8_t len) {
926     int i;
927     char **realm, *idrealm;
928
929     idrealm = strchr(id, '@');
930     if (idrealm) {
931         idrealm++;
932         len -= idrealm - id;
933     } else {
934         idrealm = "-";
935         len = 1;
936     }
937     for (i = 0; i < server_count; i++) {
938         for (realm = servers[i].realms; *realm; realm++) {
939             if ((strlen(*realm) == 1 && **realm == '*') ||
940                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
941                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
942                 return servers + i;
943             }
944         }
945     }
946     return NULL;
947 }
948
949 int rqinqueue(struct server *to, struct client *from, uint8_t id) {
950     int i;
951     
952     pthread_mutex_lock(&to->newrq_mutex);
953     for (i = 0; i < MAX_REQUESTS; i++)
954         if (to->requests[i].buf && to->requests[i].origid == id && to->requests[i].from == from)
955             break;
956     pthread_mutex_unlock(&to->newrq_mutex);
957     
958     return i < MAX_REQUESTS;
959 }
960         
961 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
962     uint8_t code, id, *auth, *attr, attrvallen;
963     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
964     int i;
965     uint16_t len;
966     int left;
967     struct server *to;
968     unsigned char newauth[16];
969     
970     code = *(uint8_t *)buf;
971     id = *(uint8_t *)(buf + 1);
972     len = RADLEN(buf);
973     auth = (uint8_t *)(buf + 4);
974
975     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
976     
977     if (code != RAD_Access_Request) {
978         printf("radsrv: server currently accepts only access-requests, ignoring\n");
979         return NULL;
980     }
981
982     left = len - 20;
983     attr = buf + 20;
984     
985     while (left > 1) {
986         left -= attr[RAD_Attr_Length];
987         if (left < 0) {
988             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
989             return NULL;
990         }
991         switch (attr[RAD_Attr_Type]) {
992         case RAD_Attr_User_Name:
993             usernameattr = attr;
994             break;
995         case RAD_Attr_User_Password:
996             userpwdattr = attr;
997             break;
998         case RAD_Attr_Tunnel_Password:
999             tunnelpwdattr = attr;
1000             break;
1001         case RAD_Attr_Message_Authenticator:
1002             messageauthattr = attr;
1003             break;
1004         }
1005         attr += attr[RAD_Attr_Length];
1006     }
1007     if (left)
1008         printf("radsrv: malformed packet? remaining byte after last attribute\n");
1009
1010     if (usernameattr) {
1011         printf("radsrv: Username: ");
1012         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
1013             printf("%c", usernameattr[RAD_Attr_Value + i]);
1014         printf("\n");
1015     }
1016
1017     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
1018     if (!to) {
1019         printf("radsrv: ignoring request, don't know where to send it\n");
1020         return NULL;
1021     }
1022
1023     if (rqinqueue(to, from, id)) {
1024         printf("radsrv: ignoring request from host %s with id %d, already got one\n", from->peer.host, id);
1025         return NULL;
1026     }
1027     
1028     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
1029                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
1030         printf("radsrv: message authentication failed\n");
1031         return NULL;
1032     }
1033
1034     if (!RAND_bytes(newauth, 16)) {
1035         printf("radsrv: failed to generate random auth\n");
1036         return NULL;
1037     }
1038
1039     printauth("auth", auth);
1040     printauth("newauth", newauth);
1041     
1042     if (userpwdattr) {
1043         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
1044         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
1045         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1046             printf("radsrv: invalid user password length\n");
1047             return NULL;
1048         }
1049         
1050         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1051             printf("radsrv: cannot decrypt password\n");
1052             return NULL;
1053         }
1054         printf("radsrv: password: ");
1055         for (i = 0; i < attrvallen; i++)
1056             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
1057         printf("\n");
1058         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1059             printf("radsrv: cannot encrypt password\n");
1060             return NULL;
1061         }
1062     }
1063
1064     if (tunnelpwdattr) {
1065         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
1066         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
1067         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
1068             printf("radsrv: invalid user password length\n");
1069             return NULL;
1070         }
1071         
1072         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
1073             printf("radsrv: cannot decrypt password\n");
1074             return NULL;
1075         }
1076         printf("radsrv: password: ");
1077         for (i = 0; i < attrvallen; i++)
1078             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
1079         printf("\n");
1080         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
1081             printf("radsrv: cannot encrypt password\n");
1082             return NULL;
1083         }
1084     }
1085
1086     rq->buf = buf;
1087     rq->from = from;
1088     rq->origid = id;
1089     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
1090     memcpy(rq->origauth, auth, 16);
1091     memcpy(auth, newauth, 16);
1092     printauth("rq->origauth", rq->origauth);
1093     printauth("auth", auth);
1094     return to;
1095 }
1096
1097 void *clientrd(void *arg) {
1098     struct server *server = (struct server *)arg;
1099     struct client *from;
1100     int i, left, subleft;
1101     unsigned char *buf, *messageauthattr, *subattr, *attr;
1102     struct sockaddr_storage fromsa;
1103     struct timeval lastconnecttry;
1104     char tmp[255];
1105     
1106     for (;;) {
1107     getnext:
1108         lastconnecttry = server->lastconnecttry;
1109         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
1110         if (!buf && server->peer.type == 'T') {
1111             tlsconnect(server, &lastconnecttry, "clientrd");
1112             continue;
1113         }
1114     
1115         server->connectionok = 1;
1116
1117         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
1118             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
1119             continue;
1120         }
1121
1122         printf("got message type: %d, id: %d\n", buf[0], buf[1]);
1123         
1124         i = buf[1]; /* i is the id */
1125
1126         pthread_mutex_lock(&server->newrq_mutex);
1127         if (!server->requests[i].buf || !server->requests[i].tries) {
1128             pthread_mutex_unlock(&server->newrq_mutex);
1129             printf("clientrd: no matching request sent with this id, ignoring\n");
1130             continue;
1131         }
1132
1133         if (server->requests[i].received) {
1134             pthread_mutex_unlock(&server->newrq_mutex);
1135             printf("clientrd: already received, ignoring\n");
1136             continue;
1137         }
1138         
1139         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1140             pthread_mutex_unlock(&server->newrq_mutex);
1141             printf("clientrd: invalid auth, ignoring\n");
1142             continue;
1143         }
1144         
1145         from = server->requests[i].from;
1146
1147         /* messageauthattr present? */
1148         messageauthattr = NULL;
1149         left = RADLEN(buf) - 20;
1150         attr = buf + 20;
1151         while (left > 1) {
1152             left -= attr[RAD_Attr_Length];
1153             if (left < 0) {
1154                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1155                 goto getnext;
1156             }
1157             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1158                 if (attr[RAD_Attr_Length] != 18) {
1159                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1160                     goto getnext;
1161                 }
1162                 memcpy(tmp, buf + 4, 16);
1163                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1164                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1165                     printf("clientrd: message authentication failed\n");
1166                     goto getnext;
1167                 }
1168                 memcpy(buf + 4, tmp, 16);
1169                 printf("clientrd: message auth ok\n");
1170                 messageauthattr = attr;
1171                 break;
1172             }
1173             attr += attr[RAD_Attr_Length];
1174         }
1175
1176         /* handle MS MPPE */
1177         left = RADLEN(buf) - 20;
1178         attr = buf + 20;
1179         while (left > 1) {
1180             left -= attr[RAD_Attr_Length];
1181             if (left < 0) {
1182                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1183                 goto getnext;
1184             }
1185             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1186                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1187                 subleft = attr[RAD_Attr_Length] - 6;
1188                 subattr = attr + 6;
1189                 while (subleft > 1) {
1190                     subleft -= subattr[RAD_Attr_Length];
1191                     if (subleft < 0)
1192                         break;
1193                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1194                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1195                         continue;
1196                     printf("clientrd: Got MS MPPE\n");
1197                     if (subattr[RAD_Attr_Length] < 20)
1198                         continue;
1199
1200                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1201                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1202                         printf("clientrd: failed to decrypt msppe key\n");
1203                         continue;
1204                     }
1205
1206                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1207                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1208                         printf("clientrd: failed to encrypt msppe key\n");
1209                         continue;
1210                     }
1211                 }
1212                 if (subleft < 0) {
1213                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1214                     goto getnext;
1215                 }
1216             }
1217             attr += attr[RAD_Attr_Length];
1218         }
1219
1220         /* once we set received = 1, requests[i] may be reused */
1221         buf[1] = (char)server->requests[i].origid;
1222         memcpy(buf + 4, server->requests[i].origauth, 16);
1223         printauth("origauth/buf+4", buf + 4);
1224         if (messageauthattr) {
1225             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1226                 continue;
1227             printf("clientrd: computed messageauthattr\n");
1228         }
1229
1230         if (from->peer.type == 'U')
1231             fromsa = server->requests[i].fromsa;
1232         server->requests[i].received = 1;
1233         pthread_mutex_unlock(&server->newrq_mutex);
1234
1235         if (!radsign(buf, from->peer.secret)) {
1236             printf("clientrd: failed to sign message\n");
1237             continue;
1238         }
1239         printauth("signedorigauth/buf+4", buf + 4);             
1240         printf("clientrd: giving packet back to where it came from\n");
1241         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1242     }
1243 }
1244
1245 void *clientwr(void *arg) {
1246     struct server *server = (struct server *)arg;
1247     struct request *rq;
1248     pthread_t clientrdth;
1249     int i;
1250     struct timeval now;
1251     struct timespec timeout;
1252
1253     memset(&timeout, 0, sizeof(struct timespec));
1254     
1255     if (server->peer.type == 'U') {
1256         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1257             printf("clientwr: connecttoserver failed\n");
1258             exit(1);
1259         }
1260     } else
1261         tlsconnect(server, NULL, "new client");
1262     
1263     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1264         errx("clientwr: pthread_create failed");
1265
1266     for (;;) {
1267         pthread_mutex_lock(&server->newrq_mutex);
1268         if (!server->newrq) {
1269             if (timeout.tv_nsec) {
1270                 printf("clientwr: waiting up to %ld secs for new request\n", timeout.tv_nsec);
1271                 pthread_cond_timedwait(&server->newrq_cond, &server->newrq_mutex, &timeout);
1272                 timeout.tv_nsec = 0;
1273             } else {
1274                 printf("clientwr: waiting for new request\n");
1275                 pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1276             }
1277             if (server->newrq) {
1278                 printf("clientwr: got new request\n");
1279                 server->newrq = 0;
1280             }
1281         }
1282         pthread_mutex_unlock(&server->newrq_mutex);
1283         
1284         for (i = 0; i < MAX_REQUESTS; i++) {
1285             pthread_mutex_lock(&server->newrq_mutex);
1286             while (!server->requests[i].buf && i < MAX_REQUESTS)
1287                 i++;
1288             if (i == MAX_REQUESTS) {
1289                 pthread_mutex_unlock(&server->newrq_mutex);
1290                 break;
1291             }
1292             rq = server->requests + i;
1293
1294             if (rq->received) {
1295                 printf("clientwr: removing received packet from queue\n");
1296                 free(rq->buf);
1297                 /* setting this to NULL means that it can be reused */
1298                 rq->buf = NULL;
1299                 pthread_mutex_unlock(&server->newrq_mutex);
1300                 continue;
1301             }
1302             
1303             gettimeofday(&now, NULL);
1304             if (now.tv_sec <= rq->expiry.tv_sec) {
1305                 if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1306                     timeout.tv_sec = rq->expiry.tv_sec;
1307                 pthread_mutex_unlock(&server->newrq_mutex);
1308                 continue;
1309             }
1310
1311             if (rq->tries == (server->peer.type == 'T' ? 1 : REQUEST_RETRIES)) {
1312                 printf("clientwr: removing expired packet from queue\n");
1313                 free(rq->buf);
1314                 /* setting this to NULL means that it can be reused */
1315                 rq->buf = NULL;
1316                 pthread_mutex_unlock(&server->newrq_mutex);
1317                 continue;
1318             }
1319             pthread_mutex_unlock(&server->newrq_mutex);
1320
1321             rq->expiry.tv_sec = now.tv_sec +
1322                 (server->peer.type == 'T' ? REQUEST_EXPIRY : REQUEST_EXPIRY / REQUEST_RETRIES);
1323             if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
1324                 timeout.tv_sec = rq->expiry.tv_sec;
1325             rq->tries++;
1326             clientradput(server, server->requests[i].buf);
1327             usleep(200000);
1328         }
1329     }
1330     /* should do more work to maintain TLS connections, keepalives etc */
1331 }
1332
1333 void *udpserverwr(void *arg) {
1334     struct replyq *replyq = &udp_server_replyq;
1335     struct reply *reply = replyq->replies;
1336     
1337     pthread_mutex_lock(&replyq->count_mutex);
1338     for (;;) {
1339         while (!replyq->count) {
1340             printf("udp server writer, waiting for signal\n");
1341             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1342             printf("udp server writer, got signal\n");
1343         }
1344         pthread_mutex_unlock(&replyq->count_mutex);
1345         
1346         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1347                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1348             err("sendudp: send failed");
1349         free(reply->buf);
1350         
1351         pthread_mutex_lock(&replyq->count_mutex);
1352         replyq->count--;
1353         memmove(replyq->replies, replyq->replies + 1,
1354                 replyq->count * sizeof(struct reply));
1355     }
1356 }
1357
1358 void *udpserverrd(void *arg) {
1359     struct request rq;
1360     unsigned char *buf;
1361     struct server *to;
1362     struct client *fr;
1363     pthread_t udpserverwrth;
1364     
1365     if ((udp_server_sock = bindport(SOCK_DGRAM, options.udpserverport)) < 0) {
1366         printf("udpserverrd: socket/bind failed\n");
1367         exit(1);
1368     }
1369     printf("udpserverrd: listening on UDP port %s\n", options.udpserverport);
1370
1371     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1372         errx("pthread_create failed");
1373     
1374     for (;;) {
1375         fr = NULL;
1376         memset(&rq, 0, sizeof(struct request));
1377         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1378         to = radsrv(&rq, buf, fr);
1379         if (!to) {
1380             printf("udpserverrd: ignoring request, no place to send it\n");
1381             continue;
1382         }
1383         sendrq(to, fr, &rq);
1384     }
1385 }
1386
1387 void *tlsserverwr(void *arg) {
1388     int cnt;
1389     unsigned long error;
1390     struct client *client = (struct client *)arg;
1391     struct replyq *replyq;
1392     
1393     printf("tlsserverwr starting for %s\n", client->peer.host);
1394     replyq = client->replyq;
1395     pthread_mutex_lock(&replyq->count_mutex);
1396     for (;;) {
1397         while (!replyq->count) {
1398             if (client->peer.ssl) {         
1399                 printf("tls server writer, waiting for signal\n");
1400                 pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1401                 printf("tls server writer, got signal\n");
1402             }
1403             if (!client->peer.ssl) {
1404                 //ssl might have changed while waiting
1405                 pthread_mutex_unlock(&replyq->count_mutex);
1406                 printf("tlsserverwr: exiting as requested\n");
1407                 pthread_exit(NULL);
1408             }
1409         }
1410         pthread_mutex_unlock(&replyq->count_mutex);
1411         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1412         if (cnt > 0)
1413             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1414                    cnt, RADLEN(replyq->replies->buf));
1415         else
1416             while ((error = ERR_get_error()))
1417                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1418         free(replyq->replies->buf);
1419
1420         pthread_mutex_lock(&replyq->count_mutex);
1421         replyq->count--;
1422         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1423     }
1424 }
1425
1426 void *tlsserverrd(void *arg) {
1427     struct request rq;
1428     char unsigned *buf;
1429     unsigned long error;
1430     struct server *to;
1431     int s;
1432     struct client *client = (struct client *)arg;
1433     pthread_t tlsserverwrth;
1434     SSL *ssl;
1435     
1436     printf("tlsserverrd starting for %s\n", client->peer.host);
1437     ssl = client->peer.ssl;
1438
1439     if (SSL_accept(ssl) <= 0) {
1440         while ((error = ERR_get_error()))
1441             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1442         errx("accept failed, child exiting");
1443     }
1444
1445     if (tlsverifycert(&client->peer)) {
1446         if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1447             errx("pthread_create failed");
1448     
1449         for (;;) {
1450             buf = radtlsget(client->peer.ssl);
1451             if (!buf)
1452                 break;
1453             printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1454             memset(&rq, 0, sizeof(struct request));
1455             to = radsrv(&rq, buf, client);
1456             if (!to) {
1457                 printf("ignoring request, no place to send it\n");
1458                 continue;
1459             }
1460             sendrq(to, client, &rq);
1461         }
1462         printf("tlsserverrd: connection lost\n");
1463         // stop writer by setting peer.ssl to NULL and give signal in case waiting for data
1464         client->peer.ssl = NULL;
1465         pthread_mutex_lock(&client->replyq->count_mutex);
1466         pthread_cond_signal(&client->replyq->count_cond);
1467         pthread_mutex_unlock(&client->replyq->count_mutex);
1468         printf("tlsserverrd: waiting for writer to end\n");
1469         pthread_join(tlsserverwrth, NULL);
1470     }
1471     s = SSL_get_fd(ssl);
1472     SSL_free(ssl);
1473     close(s);
1474     printf("tlsserverrd thread for %s exiting\n", client->peer.host);
1475     client->peer.ssl = NULL;
1476     pthread_exit(NULL);
1477 }
1478
1479 int tlslistener() {
1480     pthread_t tlsserverth;
1481     int s, snew;
1482     struct sockaddr_storage from;
1483     size_t fromlen = sizeof(from);
1484     struct client *client;
1485
1486     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1487         printf("tlslistener: socket/bind failed\n");
1488         exit(1);
1489     }
1490     
1491     listen(s, 0);
1492     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1493
1494     for (;;) {
1495         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1496         if (snew < 0)
1497             errx("accept failed");
1498         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1499
1500         client = find_client('T', (struct sockaddr *)&from, NULL);
1501         if (!client) {
1502             printf("ignoring request, not a known TLS client\n");
1503             shutdown(snew, SHUT_RDWR);
1504             close(snew);
1505             continue;
1506         }
1507
1508         if (client->peer.ssl) {
1509             printf("Ignoring incoming connection, already have one from this client\n");
1510             shutdown(snew, SHUT_RDWR);
1511             close(snew);
1512             continue;
1513         }
1514         client->peer.ssl = SSL_new(ssl_ctx);
1515         SSL_set_fd(client->peer.ssl, snew);
1516         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1517             errx("pthread_create failed");
1518         pthread_detach(tlsserverth);
1519     }
1520     return 0;
1521 }
1522
1523 char *parsehostport(char *s, struct peer *peer) {
1524     char *p, *field;
1525     int ipv6 = 0;
1526
1527     p = s;
1528     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1529     if (*p == '[') {
1530         p++;
1531         field = p;
1532         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1533         if (*p != ']') {
1534             printf("no ] matching initial [\n");
1535             exit(1);
1536         }
1537         ipv6 = 1;
1538     } else {
1539         field = p;
1540         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1541     }
1542     if (field == p) {
1543         printf("missing host/address\n");
1544         exit(1);
1545     }
1546     peer->host = stringcopy(field, p - field);
1547     if (ipv6) {
1548         p++;
1549         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1550             printf("unexpected character after ]\n");
1551             exit(1);
1552         }
1553     }
1554     if (*p == ':') {
1555             /* port number or service name is specified */;
1556             field = p++;
1557             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1558             if (field == p) {
1559                 printf("syntax error, : but no following port\n");
1560                 exit(1);
1561             }
1562             peer->port = stringcopy(field, p - field);
1563     } else
1564         peer->port = stringcopy(peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT, 0);
1565     return p;
1566 }
1567
1568 // * is default, else longest match ... ";" used for separator
1569 char *parserealmlist(char *s, struct server *server) {
1570     char *p;
1571     int i, n, l;
1572
1573     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1574         if (*p == ';')
1575             n++;
1576     l = p - s;
1577     if (!l) {
1578         printf("realm list must be specified\n");
1579         exit(1);
1580     }
1581     server->realmdata = stringcopy(s, l);
1582     server->realms = malloc((1+n) * sizeof(char *));
1583     if (!server->realms)
1584         errx("malloc failed");
1585     server->realms[0] = server->realmdata;
1586     for (n = 1, i = 0; i < l; i++)
1587         if (server->realmdata[i] == ';') {
1588             server->realmdata[i] = '\0';
1589             server->realms[n++] = server->realmdata + i + 1;
1590         }       
1591     server->realms[n] = NULL;
1592     return p;
1593 }
1594
1595 /* exactly one argument must be non-NULL */
1596 void getconfig(const char *serverfile, const char *clientfile) {
1597     FILE *f;
1598     char line[1024];
1599     char *p, *field, **r;
1600     const char *file;
1601     struct client *client;
1602     struct server *server;
1603     struct peer *peer;
1604     int i, count, *ucount, *tcount;
1605  
1606     file = serverfile ? serverfile : clientfile;
1607     f = fopen(file, "r");
1608     if (!f)
1609         errx("getconfig failed to open %s for reading", file);
1610     printf("opening file %s for reading\n", file);
1611     if (serverfile) {
1612         ucount = &server_udp_count;
1613         tcount = &server_tls_count;
1614     } else {
1615         ucount = &client_udp_count;
1616         tcount = &client_tls_count;
1617     }
1618     while (fgets(line, 1024, f)) {
1619         for (p = line; *p == ' ' || *p == '\t'; p++);
1620         switch (*p) {
1621         case '#':
1622         case '\n':
1623             break;
1624         case 'T':
1625             (*tcount)++;
1626             break;
1627         case 'U':
1628             (*ucount)++;
1629             break;
1630         default:
1631             printf("type must be U or T, got %c\n", *p);
1632             exit(1);
1633         }
1634     }
1635
1636     if (serverfile) {
1637         count = server_count = server_udp_count + server_tls_count;
1638         servers = calloc(count, sizeof(struct server));
1639         if (!servers)
1640             errx("malloc failed");
1641     } else {
1642         count = client_count = client_udp_count + client_tls_count;
1643         clients = calloc(count, sizeof(struct client));
1644         if (!clients)
1645             errx("malloc failed");
1646     }
1647     
1648     if (client_udp_count) {
1649         udp_server_replyq.replies = malloc(client_udp_count * MAX_REQUESTS * sizeof(struct reply));
1650         if (!udp_server_replyq.replies)
1651             errx("malloc failed");
1652         udp_server_replyq.size = client_udp_count * MAX_REQUESTS;
1653         udp_server_replyq.count = 0;
1654         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1655         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1656     }    
1657     
1658     rewind(f);
1659     for (i = 0; i < count && fgets(line, 1024, f);) {
1660         if (serverfile) {
1661             server = &servers[i];
1662             peer = &server->peer;
1663         } else {
1664             client = &clients[i];
1665             peer = &client->peer;
1666         }
1667         for (p = line; *p == ' ' || *p == '\t'; p++);
1668         if (*p == '#' || *p == '\n')
1669             continue;
1670         peer->type = *p;        // we already know it must be U or T
1671         for (p++; *p == ' ' || *p == '\t'; p++);
1672         p = parsehostport(p, peer);
1673         for (; *p == ' ' || *p == '\t'; p++);
1674         if (serverfile) {
1675             p = parserealmlist(p, server);
1676             for (; *p == ' ' || *p == '\t'; p++);
1677         }
1678         field = p;
1679         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1680         if (field == p) {
1681             /* no secret set and end of line, line is complete if TLS */
1682             if (peer->type == 'U') {
1683                 printf("secret must be specified for UDP\n");
1684                 exit(1);
1685             }
1686             peer->secret = stringcopy(DEFAULT_TLS_SECRET, 0);
1687         } else {
1688             peer->secret = stringcopy(field, p - field);
1689             /* check that rest of line only white space */
1690             for (; *p == ' ' || *p == '\t'; p++);
1691             if (*p && *p != '\n') {
1692                 printf("max 4 fields per line, found a 5th\n");
1693                 exit(1);
1694             }
1695         }
1696
1697         if ((serverfile && !resolvepeer(&server->peer)) ||
1698             (clientfile && !resolvepeer(&client->peer))) {
1699             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1700             exit(1);
1701         }
1702
1703         if (serverfile) {
1704             pthread_mutex_init(&server->lock, NULL);
1705             server->sock = -1;
1706             server->requests = calloc(MAX_REQUESTS, sizeof(struct request));
1707             if (!server->requests)
1708                 errx("malloc failed");
1709             server->newrq = 0;
1710             pthread_mutex_init(&server->newrq_mutex, NULL);
1711             pthread_cond_init(&server->newrq_cond, NULL);
1712         } else {
1713             if (peer->type == 'U')
1714                 client->replyq = &udp_server_replyq;
1715             else {
1716                 client->replyq = malloc(sizeof(struct replyq));
1717                 if (!client->replyq)
1718                     errx("malloc failed");
1719                 client->replyq->replies = calloc(MAX_REQUESTS, sizeof(struct reply));
1720                 if (!client->replyq->replies)
1721                     errx("malloc failed");
1722                 client->replyq->size = MAX_REQUESTS;
1723                 client->replyq->count = 0;
1724                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1725                 pthread_cond_init(&client->replyq->count_cond, NULL);
1726             }
1727         }
1728         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1729         if (serverfile) {
1730             printf("    with realms:");
1731             for (r = server->realms; *r; r++)
1732                 printf(" %s", *r);
1733             printf("\n");
1734         }
1735         i++;
1736     }
1737     fclose(f);
1738 }
1739
1740 void getmainconfig(const char *configfile) {
1741     FILE *f;
1742     char line[1024];
1743     char *p, *opt, *endopt, *val, *endval;
1744     
1745     printf("opening file %s for reading\n", configfile);
1746     f = fopen(configfile, "r");
1747     if (!f)
1748         errx("getmainconfig failed to open %s for reading", configfile);
1749
1750     memset(&options, 0, sizeof(options));
1751
1752     while (fgets(line, 1024, f)) {
1753         for (p = line; *p == ' ' || *p == '\t'; p++);
1754         if (!*p || *p == '#' || *p == '\n')
1755             continue;
1756         opt = p++;
1757         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1758         endopt = p - 1;
1759         for (; *p == ' ' || *p == '\t'; p++);
1760         if (!*p || *p == '\n') {
1761             endopt[1] = '\0';
1762             printf("error in %s, option %s has no value\n", configfile, opt);
1763             exit(1);
1764         }
1765         val = p;
1766         for (; *p && *p != '\n'; p++)
1767             if (*p != ' ' && *p != '\t')
1768                 endval = p;
1769         endopt[1] = '\0';
1770         endval[1] = '\0';
1771         printf("getmainconfig: %s = %s\n", opt, val);
1772         
1773         if (!strcasecmp(opt, "TLSCACertificateFile")) {
1774             options.tlscacertificatefile = stringcopy(val, 0);
1775             continue;
1776         }
1777         if (!strcasecmp(opt, "TLSCACertificatePath")) {
1778             options.tlscacertificatepath = stringcopy(val, 0);
1779             continue;
1780         }
1781         if (!strcasecmp(opt, "TLSCertificateFile")) {
1782             options.tlscertificatefile = stringcopy(val, 0);
1783             continue;
1784         }
1785         if (!strcasecmp(opt, "TLSCertificateKeyFile")) {
1786             options.tlscertificatekeyfile = stringcopy(val, 0);
1787             continue;
1788         }
1789         if (!strcasecmp(opt, "UDPServerPort")) {
1790             options.udpserverport = stringcopy(val, 0);
1791             continue;
1792         }
1793
1794         printf("error in %s, unknown option %s\n", configfile, opt);
1795         exit(1);
1796     }
1797     fclose(f);
1798
1799     if (!options.udpserverport)
1800         options.udpserverport = stringcopy(DEFAULT_UDP_PORT, 0);
1801 }
1802
1803 #if 0
1804 void parseargs(int argc, char **argv) {
1805     int c;
1806
1807     while ((c = getopt(argc, argv, "p:")) != -1) {
1808         switch (c) {
1809         case 'p':
1810             udp_server_port = optarg;
1811             break;
1812         default:
1813             goto usage;
1814         }
1815     }
1816
1817     return;
1818
1819  usage:
1820     printf("radsecproxy [ -p UDP-port ]\n");
1821     exit(1);
1822 }
1823 #endif
1824
1825 int main(int argc, char **argv) {
1826     pthread_t udpserverth;
1827     //    pthread_attr_t joinable;
1828     int i;
1829     
1830     //    parseargs(argc, argv);
1831     getmainconfig("radsecproxy.conf");
1832     getconfig("servers.conf", NULL);
1833     getconfig(NULL, "clients.conf");
1834
1835     //    pthread_attr_init(&joinable);
1836     //    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1837    
1838     if (client_udp_count)
1839         if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
1840             errx("pthread_create failed");
1841
1842     if (client_tls_count || server_tls_count)
1843         ssl_ctx = ssl_init();
1844     
1845     for (i = 0; i < server_count; i++)
1846         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1847             errx("pthread_create failed");
1848
1849     if (client_tls_count)
1850         return tlslistener();
1851
1852     /* just hang around doing nothing, anything to do here? */
1853     for (;;)
1854         sleep(1000);
1855 }