code for cleaning up when tls client goes away
[libradsec.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * make our server ignore client retrans and do its own instead?
15  * accounting
16  * radius keep alives (server status)
17  * tls certificate validation, see below urls
18  * clean tls shutdown, see http://www.linuxjournal.com/article/4822
19  *     and http://www.linuxjournal.com/article/5487
20  *     SSL_shutdown() and shutdown()
21  *     If shutdown() we may not need REUSEADDR
22  * when tls client goes away, ensure that all related threads and state
23  *          are removed
24  * setsockopt(keepalive...), check if openssl has some keepalive feature
25 */
26
27 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
28  *              rd is responsible for init and launching wr
29  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
30  *          each tlsserverrd launches tlsserverwr
31  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
32  *          for init and launching rd
33  *
34  * serverrd will receive a request, processes it and puts it in the requestq of
35  *          the appropriate clientwr
36  * clientwr monitors its requestq and sends requests
37  * clientrd looks for responses, processes them and puts them in the replyq of
38  *          the peer the request came from
39  * serverwr monitors its reply and sends replies
40  *
41  * In addition to the main thread, we have:
42  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
43  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
44  * For each TLS peer connecting to us there will be 2 more TLS threads
45  *       This is only for connected peers
46  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
47  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
48 */
49
50 #include <netdb.h>
51 #include <unistd.h>
52 #include <sys/time.h>
53 #include <pthread.h>
54 #include <openssl/ssl.h>
55 #include <openssl/rand.h>
56 #include <openssl/err.h>
57 #include <openssl/md5.h>
58 #include <openssl/hmac.h>
59 #include "radsecproxy.h"
60
61 static struct client clients[MAX_PEERS];
62 static struct server servers[MAX_PEERS];
63
64 static int client_count = 0;
65 static int server_count = 0;
66
67 static struct replyq udp_server_replyq;
68 static int udp_server_sock = -1;
69 static char *udp_server_port = DEFAULT_UDP_PORT;
70 static pthread_mutex_t *ssl_locks;
71 static long *ssl_lock_count;
72 static SSL_CTX *ssl_ctx_cl;
73 extern int optind;
74 extern char *optarg;
75
76 /* callbacks for making OpenSSL thread safe */
77 unsigned long ssl_thread_id() {
78         return (unsigned long)pthread_self();
79 };
80
81 void ssl_locking_callback(int mode, int type, const char *file, int line) {
82     if (mode & CRYPTO_LOCK) {
83         pthread_mutex_lock(&ssl_locks[type]);
84         ssl_lock_count[type]++;
85     } else
86         pthread_mutex_unlock(&ssl_locks[type]);
87 }
88
89 void ssl_locks_setup() {
90     int i;
91
92     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
93     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
94     for (i = 0; i < CRYPTO_num_locks(); i++) {
95         ssl_lock_count[i] = 0;
96         pthread_mutex_init(&ssl_locks[i], NULL);
97     }
98
99     CRYPTO_set_id_callback(ssl_thread_id);
100     CRYPTO_set_locking_callback(ssl_locking_callback);
101 }
102
103 void printauth(char *s, unsigned char *t) {
104     int i;
105     printf("%s:", s);
106     for (i = 0; i < 16; i++)
107             printf("%02x ", t[i]);
108     printf("\n");
109 }
110
111 int resolvepeer(struct peer *peer) {
112     struct addrinfo hints, *addrinfo;
113     
114     memset(&hints, 0, sizeof(hints));
115     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
116     hints.ai_family = AF_UNSPEC;
117     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
118         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
119         return 0;
120     }
121
122     if (peer->addrinfo)
123         freeaddrinfo(peer->addrinfo);
124     peer->addrinfo = addrinfo;
125     return 1;
126 }         
127
128 int connecttoserver(struct addrinfo *addrinfo) {
129     int s;
130     struct addrinfo *res;
131     
132     for (res = addrinfo; res; res = res->ai_next) {
133         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
134         if (s < 0) {
135             err("connecttoserver: socket failed");
136             continue;
137         }
138         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
139             break;
140         err("connecttoserver: connect failed");
141         close(s);
142         s = -1;
143     }
144     return s;
145 }         
146
147 /* returns the client with matching address, or NULL */
148 /* if client argument is not NULL, we only check that one client */
149 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
150     struct sockaddr_in6 *sa6;
151     struct in_addr *a4 = NULL;
152     struct client *c;
153     int i;
154     struct addrinfo *res;
155
156     if (addr->sa_family == AF_INET6) {
157         sa6 = (struct sockaddr_in6 *)addr;
158         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
159             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
160     } else
161         a4 = &((struct sockaddr_in *)addr)->sin_addr;
162
163     c = (client ? client : clients);
164     for (i = 0; i < client_count; i++) {
165         if (c->peer.type == type)
166             for (res = c->peer.addrinfo; res; res = res->ai_next)
167                 if ((a4 && res->ai_family == AF_INET &&
168                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
169                     (res->ai_family == AF_INET6 &&
170                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
171                     return c;
172         if (client)
173             break;
174         c++;
175     }
176     return NULL;
177 }
178
179 /* returns the server with matching address, or NULL */
180 /* if server argument is not NULL, we only check that one server */
181 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
182     struct sockaddr_in6 *sa6;
183     struct in_addr *a4 = NULL;
184     struct server *s;
185     int i;
186     struct addrinfo *res;
187
188     if (addr->sa_family == AF_INET6) {
189         sa6 = (struct sockaddr_in6 *)addr;
190         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
191             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
192     } else
193         a4 = &((struct sockaddr_in *)addr)->sin_addr;
194
195     s = (server ? server : servers);
196     for (i = 0; i < server_count; i++) {
197         if (s->peer.type == type)
198             for (res = s->peer.addrinfo; res; res = res->ai_next)
199                 if ((a4 && res->ai_family == AF_INET &&
200                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
201                     (res->ai_family == AF_INET6 &&
202                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
203                     return s;
204         if (server)
205             break;
206         s++;
207     }
208     return NULL;
209 }
210
211 /* exactly one of client and server must be non-NULL */
212 /* if *peer == NULL we return who we received from, else require it to be from peer */
213 /* return from in sa if not NULL */
214 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
215     int cnt, len;
216     void *f;
217     unsigned char buf[65536], *rad;
218     struct sockaddr_storage from;
219     socklen_t fromlen = sizeof(from);
220
221     for (;;) {
222         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
223         if (cnt == -1) {
224             err("radudpget: recv failed");
225             continue;
226         }
227         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
228
229         if (cnt < 20) {
230             printf("radudpget: packet too small\n");
231             continue;
232         }
233     
234         len = RADLEN(buf);
235
236         if (cnt < len) {
237             printf("radudpget: packet smaller than length field in radius header\n");
238             continue;
239         }
240         if (cnt > len)
241             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
242
243         f = (client
244              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
245              : (void *)find_server('U', (struct sockaddr *)&from, *server));
246         if (!f) {
247             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
248             continue;
249         }
250
251         rad = malloc(len);
252         if (rad)
253             break;
254         err("radudpget: malloc failed");
255     }
256     memcpy(rad, buf, len);
257     if (client)
258         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
259     else
260         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
261     if (sa)
262         *sa = from;
263     return rad;
264 }
265
266 void tlsconnect(struct server *server, struct timeval *when, char *text) {
267     struct timeval now;
268     time_t elapsed;
269     unsigned long error;
270
271     printf("tlsconnect called from %s\n", text);
272     pthread_mutex_lock(&server->lock);
273     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
274         /* already reconnected, nothing to do */
275         printf("tlsconnect(%s): seems already reconnected\n", text);
276         pthread_mutex_unlock(&server->lock);
277         return;
278     }
279
280     printf("tlsconnect %s\n", text);
281
282     for (;;) {
283         gettimeofday(&now, NULL);
284         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
285         if (server->connectionok) {
286             server->connectionok = 0;
287             sleep(10);
288         } else if (elapsed < 5)
289             sleep(10);
290         else if (elapsed < 600)
291             sleep(elapsed * 2);
292         else if (elapsed < 10000)
293                 sleep(900);
294         else
295             server->lastconnecttry.tv_sec = now.tv_sec;  // no sleep at startup
296         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
297         if (server->sock >= 0)
298             close(server->sock);
299         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
300             continue;
301         SSL_free(server->peer.ssl);
302         server->peer.ssl = SSL_new(ssl_ctx_cl);
303         SSL_set_fd(server->peer.ssl, server->sock);
304         if (SSL_connect(server->peer.ssl) > 0)
305             break;
306         while ((error = ERR_get_error()))
307             err("tlsconnect: TLS: %s", ERR_error_string(error, NULL));
308     }
309     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
310     gettimeofday(&server->lastconnecttry, NULL);
311     pthread_mutex_unlock(&server->lock);
312 }
313
314 unsigned char *radtlsget(SSL *ssl) {
315     int cnt, total, len;
316     unsigned char buf[4], *rad;
317
318     for (;;) {
319         for (total = 0; total < 4; total += cnt) {
320             cnt = SSL_read(ssl, buf + total, 4 - total);
321             if (cnt <= 0) {
322                 printf("radtlsget: connection lost\n");
323                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
324                     //remote end sent close_notify, send one back
325                     SSL_shutdown(ssl);
326                 }
327                 return NULL;
328             }
329         }
330
331         len = RADLEN(buf);
332         rad = malloc(len);
333         if (!rad) {
334             err("radtlsget: malloc failed");
335             continue;
336         }
337         memcpy(rad, buf, 4);
338
339         for (; total < len; total += cnt) {
340             cnt = SSL_read(ssl, rad + total, len - total);
341             if (cnt <= 0) {
342                 printf("radtlsget: connection lost\n");
343                 if (SSL_get_error(ssl, cnt) == SSL_ERROR_ZERO_RETURN) {
344                     //remote end sent close_notify, send one back
345                     SSL_shutdown(ssl);
346                 }
347                 free(rad);
348                 return NULL;
349             }
350         }
351     
352         if (total >= 20)
353             break;
354         
355         free(rad);
356         printf("radtlsget: packet smaller than minimum radius size\n");
357     }
358     
359     printf("radtlsget: got %d bytes\n", total);
360     return rad;
361 }
362
363 int clientradput(struct server *server, unsigned char *rad) {
364     int cnt;
365     size_t len;
366     unsigned long error;
367     struct timeval lastconnecttry;
368     
369     len = RADLEN(rad);
370     if (server->peer.type == 'U') {
371         if (send(server->sock, rad, len, 0) >= 0) {
372             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
373             return 1;
374         }
375         err("clientradput: send failed");
376         return 0;
377     }
378
379     lastconnecttry = server->lastconnecttry;
380     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
381         while ((error = ERR_get_error()))
382             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
383         tlsconnect(server, &lastconnecttry, "clientradput");
384         lastconnecttry = server->lastconnecttry;
385     }
386
387     server->connectionok = 1;
388     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
389            cnt, len, server->peer.host);
390     return 1;
391 }
392
393 int radsign(unsigned char *rad, unsigned char *sec) {
394     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
395     static unsigned char first = 1;
396     static EVP_MD_CTX mdctx;
397     unsigned int md_len;
398     int result;
399     
400     pthread_mutex_lock(&lock);
401     if (first) {
402         EVP_MD_CTX_init(&mdctx);
403         first = 0;
404     }
405
406     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
407         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
408         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
409         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
410         md_len == 16);
411     pthread_mutex_unlock(&lock);
412     return result;
413 }
414
415 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
416     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
417     static unsigned char first = 1;
418     static EVP_MD_CTX mdctx;
419     unsigned char hash[EVP_MAX_MD_SIZE];
420     unsigned int len;
421     int result;
422     
423     pthread_mutex_lock(&lock);
424     if (first) {
425         EVP_MD_CTX_init(&mdctx);
426         first = 0;
427     }
428
429     len = RADLEN(rad);
430     
431     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
432               EVP_DigestUpdate(&mdctx, rad, 4) &&
433               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
434               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
435               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
436               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
437               len == 16 &&
438               !memcmp(hash, rad + 4, 16));
439     pthread_mutex_unlock(&lock);
440     return result;
441 }
442               
443 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
444     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
445     static unsigned char first = 1;
446     static HMAC_CTX hmacctx;
447     unsigned int md_len;
448     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
449     
450     pthread_mutex_lock(&lock);
451     if (first) {
452         HMAC_CTX_init(&hmacctx);
453         first = 0;
454     }
455
456     memcpy(auth, authattr, 16);
457     memset(authattr, 0, 16);
458     md_len = 0;
459     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
460     HMAC_Update(&hmacctx, rad, RADLEN(rad));
461     HMAC_Final(&hmacctx, hash, &md_len);
462     memcpy(authattr, auth, 16);
463     if (md_len != 16) {
464         printf("message auth computation failed\n");
465         pthread_mutex_unlock(&lock);
466         return 0;
467     }
468
469     if (memcmp(auth, hash, 16)) {
470         printf("message authenticator, wrong value\n");
471         pthread_mutex_unlock(&lock);
472         return 0;
473     }   
474         
475     pthread_mutex_unlock(&lock);
476     return 1;
477 }
478
479 int createmessageauth(char *rad, char *authattrval, char *secret) {
480     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
481     static unsigned char first = 1;
482     static HMAC_CTX hmacctx;
483     unsigned int md_len;
484
485     if (!authattrval)
486         return 1;
487     
488     pthread_mutex_lock(&lock);
489     if (first) {
490         HMAC_CTX_init(&hmacctx);
491         first = 0;
492     }
493
494     memset(authattrval, 0, 16);
495     md_len = 0;
496     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
497     HMAC_Update(&hmacctx, rad, RADLEN(rad));
498     HMAC_Final(&hmacctx, authattrval, &md_len);
499     if (md_len != 16) {
500         printf("message auth computation failed\n");
501         pthread_mutex_unlock(&lock);
502         return 0;
503     }
504
505     pthread_mutex_unlock(&lock);
506     return 1;
507 }
508
509 void sendrq(struct server *to, struct client *from, struct request *rq) {
510     int i;
511     
512     pthread_mutex_lock(&to->newrq_mutex);
513     /* might simplify if only try nextid, might be ok */
514     for (i = to->nextid; i < MAX_REQUESTS; i++)
515         if (!to->requests[i].buf)
516             break;
517     if (i == MAX_REQUESTS) {
518         for (i = 0; i < to->nextid; i++)
519             if (!to->requests[i].buf)
520                 break;
521         if (i == to->nextid) {
522             printf("No room in queue, dropping request\n");
523             pthread_mutex_unlock(&to->newrq_mutex);
524             return;
525         }
526     }
527     
528     to->nextid = i + 1;
529     rq->buf[1] = (char)i;
530     printf("sendrq: inserting packet with id %d in queue for %s\n", i, to->peer.host);
531     
532     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
533         return;
534
535     gettimeofday(&rq->expiry, NULL);
536     rq->expiry.tv_sec += 30;
537     to->requests[i] = *rq;
538
539     if (!to->newrq) {
540         to->newrq = 1;
541         printf("signalling client writer\n");
542         pthread_cond_signal(&to->newrq_cond);
543     }
544     pthread_mutex_unlock(&to->newrq_mutex);
545 }
546
547 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
548     struct replyq *replyq = to->replyq;
549     
550     pthread_mutex_lock(&replyq->count_mutex);
551     if (replyq->count == replyq->size) {
552         printf("No room in queue, dropping request\n");
553         pthread_mutex_unlock(&replyq->count_mutex);
554         return;
555     }
556
557     replyq->replies[replyq->count].buf = buf;
558     if (tosa)
559         replyq->replies[replyq->count].tosa = *tosa;
560     replyq->count++;
561
562     if (replyq->count == 1) {
563         printf("signalling client writer\n");
564         pthread_cond_signal(&replyq->count_cond);
565     }
566     pthread_mutex_unlock(&replyq->count_mutex);
567 }
568
569 int pwdencrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
570     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
571     static unsigned char first = 1;
572     static EVP_MD_CTX mdctx;
573     unsigned char hash[EVP_MAX_MD_SIZE], *input;
574     unsigned int md_len;
575     uint8_t i, offset = 0, out[128];
576     
577     pthread_mutex_lock(&lock);
578     if (first) {
579         EVP_MD_CTX_init(&mdctx);
580         first = 0;
581     }
582
583     input = auth;
584     for (;;) {
585         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
586             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
587             !EVP_DigestUpdate(&mdctx, input, 16) ||
588             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
589             md_len != 16) {
590             pthread_mutex_unlock(&lock);
591             return 0;
592         }
593         for (i = 0; i < 16; i++)
594             out[offset + i] = hash[i] ^ in[offset + i];
595         input = out + offset - 16;
596         offset += 16;
597         if (offset == len)
598             break;
599     }
600     memcpy(in, out, len);
601     pthread_mutex_unlock(&lock);
602     return 1;
603 }
604
605 int pwddecrypt(uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth) {
606     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
607     static unsigned char first = 1;
608     static EVP_MD_CTX mdctx;
609     unsigned char hash[EVP_MAX_MD_SIZE], *input;
610     unsigned int md_len;
611     uint8_t i, offset = 0, out[128];
612     
613     pthread_mutex_lock(&lock);
614     if (first) {
615         EVP_MD_CTX_init(&mdctx);
616         first = 0;
617     }
618
619     input = auth;
620     for (;;) {
621         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
622             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
623             !EVP_DigestUpdate(&mdctx, input, 16) ||
624             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
625             md_len != 16) {
626             pthread_mutex_unlock(&lock);
627             return 0;
628         }
629         for (i = 0; i < 16; i++)
630             out[offset + i] = hash[i] ^ in[offset + i];
631         input = in + offset;
632         offset += 16;
633         if (offset == len)
634             break;
635     }
636     memcpy(in, out, len);
637     pthread_mutex_unlock(&lock);
638     return 1;
639 }
640
641 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
642     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
643     static unsigned char first = 1;
644     static EVP_MD_CTX mdctx;
645     unsigned char hash[EVP_MAX_MD_SIZE];
646     unsigned int md_len;
647     uint8_t i, offset;
648     
649     pthread_mutex_lock(&lock);
650     if (first) {
651         EVP_MD_CTX_init(&mdctx);
652         first = 0;
653     }
654
655 #if 0    
656     printf("msppencrypt auth in: ");
657     for (i = 0; i < 16; i++)
658         printf("%02x ", auth[i]);
659     printf("\n");
660     
661     printf("msppencrypt salt in: ");
662     for (i = 0; i < 2; i++)
663         printf("%02x ", salt[i]);
664     printf("\n");
665     
666     printf("msppencrypt in: ");
667     for (i = 0; i < len; i++)
668         printf("%02x ", text[i]);
669     printf("\n");
670 #endif
671     
672     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
673         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
674         !EVP_DigestUpdate(&mdctx, auth, 16) ||
675         !EVP_DigestUpdate(&mdctx, salt, 2) ||
676         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
677         pthread_mutex_unlock(&lock);
678         return 0;
679     }
680
681 #if 0    
682     printf("msppencrypt hash: ");
683     for (i = 0; i < 16; i++)
684         printf("%02x ", hash[i]);
685     printf("\n");
686 #endif
687     
688     for (i = 0; i < 16; i++)
689         text[i] ^= hash[i];
690     
691     for (offset = 16; offset < len; offset += 16) {
692 #if 0   
693         printf("text + offset - 16 c(%d): ", offset / 16);
694         for (i = 0; i < 16; i++)
695             printf("%02x ", (text + offset - 16)[i]);
696         printf("\n");
697 #endif
698         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
699             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
700             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
701             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
702             md_len != 16) {
703             pthread_mutex_unlock(&lock);
704             return 0;
705         }
706 #if 0   
707         printf("msppencrypt hash: ");
708         for (i = 0; i < 16; i++)
709             printf("%02x ", hash[i]);
710         printf("\n");
711 #endif    
712         
713         for (i = 0; i < 16; i++)
714             text[offset + i] ^= hash[i];
715     }
716     
717 #if 0
718     printf("msppencrypt out: ");
719     for (i = 0; i < len; i++)
720         printf("%02x ", text[i]);
721     printf("\n");
722 #endif
723
724     pthread_mutex_unlock(&lock);
725     return 1;
726 }
727
728 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
729     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
730     static unsigned char first = 1;
731     static EVP_MD_CTX mdctx;
732     unsigned char hash[EVP_MAX_MD_SIZE];
733     unsigned int md_len;
734     uint8_t i, offset;
735     char plain[255];
736     
737     pthread_mutex_lock(&lock);
738     if (first) {
739         EVP_MD_CTX_init(&mdctx);
740         first = 0;
741     }
742
743 #if 0    
744     printf("msppdecrypt auth in: ");
745     for (i = 0; i < 16; i++)
746         printf("%02x ", auth[i]);
747     printf("\n");
748     
749     printf("msppedecrypt salt in: ");
750     for (i = 0; i < 2; i++)
751         printf("%02x ", salt[i]);
752     printf("\n");
753     
754     printf("msppedecrypt in: ");
755     for (i = 0; i < len; i++)
756         printf("%02x ", text[i]);
757     printf("\n");
758 #endif
759     
760     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
761         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
762         !EVP_DigestUpdate(&mdctx, auth, 16) ||
763         !EVP_DigestUpdate(&mdctx, salt, 2) ||
764         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
765         pthread_mutex_unlock(&lock);
766         return 0;
767     }
768
769 #if 0    
770     printf("msppedecrypt hash: ");
771     for (i = 0; i < 16; i++)
772         printf("%02x ", hash[i]);
773     printf("\n");
774 #endif
775     
776     for (i = 0; i < 16; i++)
777         plain[i] = text[i] ^ hash[i];
778     
779     for (offset = 16; offset < len; offset += 16) {
780 #if 0   
781         printf("text + offset - 16 c(%d): ", offset / 16);
782         for (i = 0; i < 16; i++)
783             printf("%02x ", (text + offset - 16)[i]);
784         printf("\n");
785 #endif
786         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
787             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
788             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
789             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
790             md_len != 16) {
791             pthread_mutex_unlock(&lock);
792             return 0;
793         }
794 #if 0   
795     printf("msppedecrypt hash: ");
796     for (i = 0; i < 16; i++)
797         printf("%02x ", hash[i]);
798     printf("\n");
799 #endif    
800
801     for (i = 0; i < 16; i++)
802         plain[offset + i] = text[offset + i] ^ hash[i];
803     }
804
805     memcpy(text, plain, len);
806 #if 0
807     printf("msppedecrypt out: ");
808     for (i = 0; i < len; i++)
809         printf("%02x ", text[i]);
810     printf("\n");
811 #endif
812
813     pthread_mutex_unlock(&lock);
814     return 1;
815 }
816
817 struct server *id2server(char *id, uint8_t len) {
818     int i;
819     char **realm, *idrealm;
820
821     idrealm = strchr(id, '@');
822     if (idrealm) {
823         idrealm++;
824         len -= idrealm - id;
825     } else {
826         idrealm = "-";
827         len = 1;
828     }
829     for (i = 0; i < server_count; i++) {
830         for (realm = servers[i].realms; *realm; realm++) {
831             if ((strlen(*realm) == 1 && **realm == '*') ||
832                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
833                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
834                 return servers + i;
835             }
836         }
837     }
838     return NULL;
839 }
840
841 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
842     uint8_t code, id, *auth, *attr, attrvallen;
843     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
844     int i;
845     uint16_t len;
846     int left;
847     struct server *to;
848     unsigned char newauth[16];
849     
850     code = *(uint8_t *)buf;
851     id = *(uint8_t *)(buf + 1);
852     len = RADLEN(buf);
853     auth = (uint8_t *)(buf + 4);
854
855     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
856     
857     if (code != RAD_Access_Request) {
858         printf("radsrv: server currently accepts only access-requests, ignoring\n");
859         return NULL;
860     }
861
862     left = len - 20;
863     attr = buf + 20;
864     
865     while (left > 1) {
866         left -= attr[RAD_Attr_Length];
867         if (left < 0) {
868             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
869             return NULL;
870         }
871         switch (attr[RAD_Attr_Type]) {
872         case RAD_Attr_User_Name:
873             usernameattr = attr;
874             break;
875         case RAD_Attr_User_Password:
876             userpwdattr = attr;
877             break;
878         case RAD_Attr_Tunnel_Password:
879             tunnelpwdattr = attr;
880             break;
881         case RAD_Attr_Message_Authenticator:
882             messageauthattr = attr;
883             break;
884         }
885         attr += attr[RAD_Attr_Length];
886     }
887     if (left)
888         printf("radsrv: malformed packet? remaining byte after last attribute\n");
889
890     if (usernameattr) {
891         printf("radsrv: Username: ");
892         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
893             printf("%c", usernameattr[RAD_Attr_Value + i]);
894         printf("\n");
895     }
896
897     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
898     if (!to) {
899         printf("radsrv: ignoring request, don't know where to send it\n");
900         return NULL;
901     }
902     
903     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
904                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
905         printf("radsrv: message authentication failed\n");
906         return NULL;
907     }
908
909     if (!RAND_bytes(newauth, 16)) {
910         printf("radsrv: failed to generate random auth\n");
911         return NULL;
912     }
913
914     printauth("auth", auth);
915     printauth("newauth", newauth);
916     
917     if (userpwdattr) {
918         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
919         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
920         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
921             printf("radsrv: invalid user password length\n");
922             return NULL;
923         }
924         
925         if (!pwddecrypt(&userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
926             printf("radsrv: cannot decrypt password\n");
927             return NULL;
928         }
929         printf("radsrv: password: ");
930         for (i = 0; i < attrvallen; i++)
931             printf("%02x ", userpwdattr[RAD_Attr_Value + i]);
932         printf("\n");
933         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
934             printf("radsrv: cannot encrypt password\n");
935             return NULL;
936         }
937     }
938
939     if (tunnelpwdattr) {
940         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
941         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
942         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
943             printf("radsrv: invalid user password length\n");
944             return NULL;
945         }
946         
947         if (!pwddecrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
948             printf("radsrv: cannot decrypt password\n");
949             return NULL;
950         }
951         printf("radsrv: password: ");
952         for (i = 0; i < attrvallen; i++)
953             printf("%02x ", tunnelpwdattr[RAD_Attr_Value + i]);
954         printf("\n");
955         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
956             printf("radsrv: cannot encrypt password\n");
957             return NULL;
958         }
959     }
960
961     rq->buf = buf;
962     rq->from = from;
963     rq->origid = id;
964     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
965     memcpy(rq->origauth, auth, 16);
966     memcpy(auth, newauth, 16);
967     printauth("rq->origauth", rq->origauth);
968     printauth("auth", auth);
969     return to;
970 }
971
972 void *clientrd(void *arg) {
973     struct server *server = (struct server *)arg;
974     struct client *from;
975     int i, left, subleft;
976     unsigned char *buf, *messageauthattr, *subattr, *attr;
977     struct sockaddr_storage fromsa;
978     struct timeval lastconnecttry;
979     char tmp[255];
980     
981     for (;;) {
982     getnext:
983         lastconnecttry = server->lastconnecttry;
984         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
985         if (!buf && server->peer.type == 'T') {
986             tlsconnect(server, &lastconnecttry, "clientrd");
987             continue;
988         }
989     
990         server->connectionok = 1;
991
992         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
993             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
994             continue;
995         }
996         
997         i = buf[1]; /* i is the id */
998
999         pthread_mutex_lock(&server->newrq_mutex);
1000         if (!server->requests[i].buf || !server->requests[i].tries) {
1001             pthread_mutex_unlock(&server->newrq_mutex);
1002             printf("clientrd: no matching request sent with this id, ignoring\n");
1003             continue;
1004         }
1005
1006         if (server->requests[i].received) {
1007             pthread_mutex_unlock(&server->newrq_mutex);
1008             printf("clientrd: already received, ignoring\n");
1009             continue;
1010         }
1011         
1012         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
1013             pthread_mutex_unlock(&server->newrq_mutex);
1014             printf("clientrd: invalid auth, ignoring\n");
1015             continue;
1016         }
1017         
1018         from = server->requests[i].from;
1019
1020
1021         /* messageauthattr present? */
1022         messageauthattr = NULL;
1023         left = RADLEN(buf) - 20;
1024         attr = buf + 20;
1025         while (left > 1) {
1026             left -= attr[RAD_Attr_Length];
1027             if (left < 0) {
1028                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1029                 goto getnext;
1030             }
1031             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1032                 if (attr[RAD_Attr_Length] != 18) {
1033                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1034                     goto getnext;
1035                 }
1036                 memcpy(tmp, buf + 4, 16);
1037                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1038                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1039                     printf("clientrd: message authentication failed\n");
1040                     goto getnext;
1041                 }
1042                 memcpy(buf + 4, tmp, 16);
1043                 printf("clientrd: message auth ok\n");
1044                 messageauthattr = attr;
1045                 break;
1046             }
1047             attr += attr[RAD_Attr_Length];
1048         }
1049
1050         /* handle MS MPPE */
1051         left = RADLEN(buf) - 20;
1052         attr = buf + 20;
1053         while (left > 1) {
1054             left -= attr[RAD_Attr_Length];
1055             if (left < 0) {
1056                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1057                 goto getnext;
1058             }
1059             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1060                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1061                 subleft = attr[RAD_Attr_Length] - 6;
1062                 subattr = attr + 6;
1063                 while (subleft > 1) {
1064                     subleft -= subattr[RAD_Attr_Length];
1065                     if (subleft < 0)
1066                         break;
1067                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1068                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1069                         continue;
1070                     printf("clientrd: Got MS MPPE\n");
1071                     if (subattr[RAD_Attr_Length] < 20)
1072                         continue;
1073
1074                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1075                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1076                         printf("clientrd: failed to decrypt msppe key\n");
1077                         continue;
1078                     }
1079
1080                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1081                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1082                         printf("clientrd: failed to encrypt msppe key\n");
1083                         continue;
1084                     }
1085                 }
1086                 if (subleft < 0) {
1087                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1088                     goto getnext;
1089                 }
1090             }
1091             attr += attr[RAD_Attr_Length];
1092         }
1093
1094         /* once we set received = 1, requests[i] may be reused */
1095         buf[1] = (char)server->requests[i].origid;
1096         memcpy(buf + 4, server->requests[i].origauth, 16);
1097         printauth("origauth/buf+4", buf + 4);
1098         if (messageauthattr) {
1099             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1100                 continue;
1101             printf("clientrd: computed messageauthattr\n");
1102         }
1103
1104         if (from->peer.type == 'U')
1105             fromsa = server->requests[i].fromsa;
1106         server->requests[i].received = 1;
1107         pthread_mutex_unlock(&server->newrq_mutex);
1108
1109         if (!radsign(buf, from->peer.secret)) {
1110             printf("clientrd: failed to sign message\n");
1111             continue;
1112         }
1113         printauth("signedorigauth/buf+4", buf + 4);             
1114         printf("clientrd: giving packet back to where it came from\n");
1115         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1116     }
1117 }
1118
1119 void *clientwr(void *arg) {
1120     struct server *server = (struct server *)arg;
1121     struct request *rq;
1122     pthread_t clientrdth;
1123     int i;
1124     struct timeval now;
1125     
1126     if (server->peer.type == 'U') {
1127         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1128             printf("clientwr: connecttoserver failed\n");
1129             exit(1);
1130         }
1131     } else
1132         tlsconnect(server, NULL, "new client");
1133     
1134     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1135         errx("clientwr: pthread_create failed");
1136
1137     for (;;) {
1138         pthread_mutex_lock(&server->newrq_mutex);
1139         while (!server->newrq) {
1140             printf("clientwr: waiting for signal\n");
1141             pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1142             printf("clientwr: got signal\n");
1143         }
1144         server->newrq = 0;
1145         pthread_mutex_unlock(&server->newrq_mutex);
1146                
1147         for (i = 0; i < MAX_REQUESTS; i++) {
1148             pthread_mutex_lock(&server->newrq_mutex);
1149             while (!server->requests[i].buf && i < MAX_REQUESTS)
1150                 i++;
1151             if (i == MAX_REQUESTS) {
1152                 pthread_mutex_unlock(&server->newrq_mutex);
1153                 break;
1154             }
1155
1156             gettimeofday(&now, NULL);
1157             rq = server->requests + i;
1158
1159             if (rq->received) {
1160                 printf("clientwr: removing received packet from queue\n");
1161                 free(rq->buf);
1162                 /* setting this to NULL means that it can be reused */
1163                 rq->buf = NULL;
1164                 pthread_mutex_unlock(&server->newrq_mutex);
1165                 continue;
1166             }
1167             if (now.tv_sec > rq->expiry.tv_sec) {
1168                 printf("clientwr: removing expired packet from queue\n");
1169                 free(rq->buf);
1170                 /* setting this to NULL means that it can be reused */
1171                 rq->buf = NULL;
1172                 pthread_mutex_unlock(&server->newrq_mutex);
1173                 continue;
1174             }
1175
1176             if (rq->tries)
1177                 continue; // not re-sending (yet)
1178             
1179             rq->tries++;
1180             pthread_mutex_unlock(&server->newrq_mutex);
1181             
1182             clientradput(server, server->requests[i].buf);
1183         }
1184     }
1185     /* should do more work to maintain TLS connections, keepalives etc */
1186 }
1187
1188 void *udpserverwr(void *arg) {
1189     struct replyq *replyq = &udp_server_replyq;
1190     struct reply *reply = replyq->replies;
1191     
1192     pthread_mutex_lock(&replyq->count_mutex);
1193     for (;;) {
1194         while (!replyq->count) {
1195             printf("udp server writer, waiting for signal\n");
1196             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1197             printf("udp server writer, got signal\n");
1198         }
1199         pthread_mutex_unlock(&replyq->count_mutex);
1200         
1201         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1202                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1203             err("sendudp: send failed");
1204         free(reply->buf);
1205         
1206         pthread_mutex_lock(&replyq->count_mutex);
1207         replyq->count--;
1208         memmove(replyq->replies, replyq->replies + 1,
1209                 replyq->count * sizeof(struct reply));
1210     }
1211 }
1212
1213 void *udpserverrd(void *arg) {
1214     struct request rq;
1215     unsigned char *buf;
1216     struct server *to;
1217     struct client *fr;
1218     pthread_t udpserverwrth;
1219     
1220     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
1221         printf("udpserverrd: socket/bind failed\n");
1222         exit(1);
1223     }
1224     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
1225
1226     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1227         errx("pthread_create failed");
1228     
1229     for (;;) {
1230         fr = NULL;
1231         memset(&rq, 0, sizeof(struct request));
1232         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1233         to = radsrv(&rq, buf, fr);
1234         if (!to) {
1235             printf("udpserverrd: ignoring request, no place to send it\n");
1236             continue;
1237         }
1238         sendrq(to, fr, &rq);
1239     }
1240 }
1241
1242 void *tlsserverwr(void *arg) {
1243     int cnt;
1244     unsigned long error;
1245     struct client *client = (struct client *)arg;
1246     struct replyq *replyq;
1247     
1248     printf("tlsserverwr starting for %s\n", client->peer.host);
1249     replyq = client->replyq;
1250     pthread_mutex_lock(&replyq->count_mutex);
1251     for (;;) {
1252         while (!replyq->count) {
1253             if (client->peer.ssl) {         
1254                 printf("tls server writer, waiting for signal\n");
1255                 pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1256                 printf("tls server writer, got signal\n");
1257             }
1258             if (!client->peer.ssl) {
1259                 //ssl might have changed while waiting
1260                 pthread_mutex_unlock(&replyq->count_mutex);
1261                 printf("tlsserverwr: exiting as requested\n");
1262                 pthread_exit(NULL);
1263             }
1264         }
1265         pthread_mutex_unlock(&replyq->count_mutex);
1266         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1267         if (cnt > 0)
1268             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1269                    cnt, RADLEN(replyq->replies->buf));
1270         else
1271             while ((error = ERR_get_error()))
1272                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1273         free(replyq->replies->buf);
1274
1275         pthread_mutex_lock(&replyq->count_mutex);
1276         replyq->count--;
1277         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1278     }
1279 }
1280
1281 void *tlsserverrd(void *arg) {
1282     struct request rq;
1283     char unsigned *buf;
1284     unsigned long error;
1285     struct server *to;
1286     int s;
1287     struct client *client = (struct client *)arg;
1288     pthread_t tlsserverwrth;
1289     SSL *ssl;
1290     
1291     printf("tlsserverrd starting for %s\n", client->peer.host);
1292     ssl = client->peer.ssl;
1293
1294     if (SSL_accept(ssl) <= 0) {
1295         while ((error = ERR_get_error()))
1296             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1297         errx("accept failed, child exiting");
1298     }
1299
1300     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1301         errx("pthread_create failed");
1302     
1303     for (;;) {
1304         buf = radtlsget(client->peer.ssl);
1305         if (!buf)
1306             break;
1307         printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1308         memset(&rq, 0, sizeof(struct request));
1309         to = radsrv(&rq, buf, client);
1310         if (!to) {
1311             printf("ignoring request, no place to send it\n");
1312             continue;
1313         }
1314         sendrq(to, client, &rq);
1315     }
1316     printf("tlsserverrd: connection lost\n");
1317     // stop writer by setting peer.ssl to NULL and give signal in case waiting for data
1318     client->peer.ssl = NULL;
1319     pthread_mutex_lock(&client->replyq->count_mutex);
1320     pthread_cond_signal(&client->replyq->count_cond);
1321     pthread_mutex_unlock(&client->replyq->count_mutex);
1322     printf("tlsserverrd: waiting for writer to end\n");
1323     pthread_join(tlsserverwrth, NULL);
1324     s = SSL_get_fd(ssl);
1325     SSL_free(ssl);
1326     close(s);
1327     printf("tlsserverrd thread for %s exiting\n", client->peer.host);
1328     pthread_exit(NULL);
1329 }
1330
1331 int tlslistener(SSL_CTX *ssl_ctx) {
1332     pthread_t tlsserverth;
1333     int s, snew;
1334     struct sockaddr_storage from;
1335     size_t fromlen = sizeof(from);
1336     struct client *client;
1337
1338     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1339         printf("tlslistener: socket/bind failed\n");
1340         exit(1);
1341     }
1342     
1343     listen(s, 0);
1344     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1345
1346     for (;;) {
1347         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1348         if (snew < 0)
1349             errx("accept failed");
1350         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1351
1352         client = find_client('T', (struct sockaddr *)&from, NULL);
1353         if (!client) {
1354             printf("ignoring request, not a known TLS client\n");
1355             close(snew);
1356             continue;
1357         }
1358
1359         if (client->peer.ssl) {
1360             printf("Ignoring incoming connection, already have one from this client\n");
1361             close(snew);
1362             continue;
1363         }
1364         client->peer.ssl = SSL_new(ssl_ctx);
1365         SSL_set_fd(client->peer.ssl, snew);
1366         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1367             errx("pthread_create failed");
1368         pthread_detach(tlsserverth);
1369     }
1370     return 0;
1371 }
1372
1373 char *parsehostport(char *s, struct peer *peer) {
1374     char *p, *field;
1375     int ipv6 = 0;
1376
1377     p = s;
1378     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1379     if (*p == '[') {
1380         p++;
1381         field = p;
1382         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1383         if (*p != ']') {
1384             printf("no ] matching initial [\n");
1385             exit(1);
1386         }
1387         ipv6 = 1;
1388     } else {
1389         field = p;
1390         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1391     }
1392     if (field == p) {
1393         printf("missing host/address\n");
1394         exit(1);
1395     }
1396     peer->host = malloc(p - field + 1);
1397     if (!peer->host)
1398         errx("malloc failed");
1399     memcpy(peer->host, field, p - field);
1400     peer->host[p - field] = '\0';
1401     if (ipv6) {
1402         p++;
1403         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1404             printf("unexpected character after ]\n");
1405             exit(1);
1406         }
1407     }
1408     if (*p == ':') {
1409             /* port number or service name is specified */;
1410             field = p++;
1411             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1412             if (field == p) {
1413                 printf("syntax error, : but no following port\n");
1414                 exit(1);
1415             }
1416             peer->port = malloc(p - field + 1);
1417             if (!peer->port)
1418                 errx("malloc failed");
1419             memcpy(peer->port, field, p - field);
1420             peer->port[p - field] = '\0';
1421     } else
1422         peer->port = NULL;
1423     return p;
1424 }
1425
1426 // * is default, else longest match ... ";" used for separator
1427 char *parserealmlist(char *s, struct server *server) {
1428     char *p;
1429     int i, n, l;
1430
1431     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1432         if (*p == ';')
1433             n++;
1434     l = p - s;
1435     if (!l) {
1436         server->realms = NULL;
1437         return p;
1438     }
1439     server->realmdata = malloc(l + 1);
1440     if (!server->realmdata)
1441         errx("malloc failed");
1442     memcpy(server->realmdata, s, l);
1443     server->realmdata[l] = '\0';
1444     server->realms = malloc((1+n) * sizeof(char *));
1445     if (!server->realms)
1446         errx("malloc failed");
1447     server->realms[0] = server->realmdata;
1448     for (n = 1, i = 0; i < l; i++)
1449         if (server->realmdata[i] == ';') {
1450             server->realmdata[i] = '\0';
1451             server->realms[n++] = server->realmdata + i + 1;
1452         }       
1453     server->realms[n] = NULL;
1454     return p;
1455 }
1456
1457 /* exactly one argument must be non-NULL */
1458 void getconfig(const char *serverfile, const char *clientfile) {
1459     FILE *f;
1460     char line[1024];
1461     char *p, *field, **r;
1462     struct client *client;
1463     struct server *server;
1464     struct peer *peer;
1465     int *count;
1466     
1467     if (serverfile) {
1468         printf("opening file %s for reading\n", serverfile);
1469         f = fopen(serverfile, "r");
1470         if (!f)
1471             errx("getconfig failed to open %s for reading", serverfile);
1472         count = &server_count;
1473     } else {
1474         printf("opening file %s for reading\n", clientfile);
1475         f = fopen(clientfile, "r");
1476         if (!f)
1477             errx("getconfig failed to open %s for reading", clientfile);
1478         udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
1479         if (!udp_server_replyq.replies)
1480             errx("malloc failed");
1481         udp_server_replyq.size = 4 * MAX_REQUESTS;
1482         udp_server_replyq.count = 0;
1483         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1484         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1485         count = &client_count;
1486     }    
1487     
1488     *count = 0;
1489     while (fgets(line, 1024, f) && *count < MAX_PEERS) {
1490         if (serverfile) {
1491             server = &servers[*count];
1492             memset(server, 0, sizeof(struct server));
1493             peer = &server->peer;
1494         } else {
1495             client = &clients[*count];
1496             memset(client, 0, sizeof(struct client));
1497             peer = &client->peer;
1498         }
1499         for (p = line; *p == ' ' || *p == '\t'; p++);
1500         if (*p == '#' || *p == '\n')
1501             continue;
1502         if (*p != 'U' && *p != 'T') {
1503             printf("server type must be U or T, got %c\n", *p);
1504             exit(1);
1505         }
1506         peer->type = *p;
1507         for (p++; *p == ' ' || *p == '\t'; p++);
1508         p = parsehostport(p, peer);
1509         if (!peer->port)
1510             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
1511         for (; *p == ' ' || *p == '\t'; p++);
1512         if (serverfile) {
1513             p = parserealmlist(p, server);
1514             if (!server->realms) {
1515                 printf("realm list must be specified\n");
1516                 exit(1);
1517             }
1518             for (; *p == ' ' || *p == '\t'; p++);
1519         }
1520         field = p;
1521         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1522         if (field == p) {
1523             /* no secret set and end of line, line is complete if TLS */
1524             if (peer->type == 'U') {
1525                 printf("secret must be specified for UDP\n");
1526                 exit(1);
1527             }
1528             peer->secret = DEFAULT_TLS_SECRET;
1529         } else {
1530             peer->secret = malloc(p - field + 1);
1531             if (!peer->secret)
1532                 errx("malloc failed");
1533             memcpy(peer->secret, field, p - field);
1534             peer->secret[p - field] = '\0';
1535             /* check that rest of line only white space */
1536             for (; *p == ' ' || *p == '\t'; p++);
1537             if (*p && *p != '\n') {
1538                 printf("max 4 fields per line, found a 5th\n");
1539                 exit(1);
1540             }
1541         }
1542
1543         if ((serverfile && !resolvepeer(&server->peer)) ||
1544             (clientfile && !resolvepeer(&client->peer))) {
1545             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1546             exit(1);
1547         }
1548
1549         if (serverfile) {
1550             pthread_mutex_init(&server->lock, NULL);
1551             server->sock = -1;
1552             server->requests = malloc(MAX_REQUESTS * sizeof(struct request));
1553             if (!server->requests)
1554                 errx("malloc failed");
1555             memset(server->requests, 0, MAX_REQUESTS * sizeof(struct request));
1556             server->newrq = 0;
1557             pthread_mutex_init(&server->newrq_mutex, NULL);
1558             pthread_cond_init(&server->newrq_cond, NULL);
1559         } else {
1560             if (peer->type == 'U')
1561                 client->replyq = &udp_server_replyq;
1562             else {
1563                 client->replyq = malloc(sizeof(struct replyq));
1564                 if (!client->replyq)
1565                     errx("malloc failed");
1566                 client->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
1567                 if (!client->replyq->replies)
1568                     errx("malloc failed");
1569                 client->replyq->size = MAX_REQUESTS;
1570                 client->replyq->count = 0;
1571                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1572                 pthread_cond_init(&client->replyq->count_cond, NULL);
1573             }
1574         }
1575         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1576         if (serverfile) {
1577             printf("    with realms:");
1578             for (r = server->realms; *r; r++)
1579                 printf(" %s", *r);
1580             printf("\n");
1581         }
1582         (*count)++;
1583     }
1584     fclose(f);
1585 }
1586
1587 void parseargs(int argc, char **argv) {
1588     int c;
1589
1590     while ((c = getopt(argc, argv, "p:")) != -1) {
1591         switch (c) {
1592         case 'p':
1593             udp_server_port = optarg;
1594             break;
1595         default:
1596             goto usage;
1597         }
1598     }
1599
1600     return;
1601
1602  usage:
1603     printf("radsecproxy [ -p UDP-port ]\n");
1604     exit(1);
1605 }
1606                
1607 int main(int argc, char **argv) {
1608     SSL_CTX *ssl_ctx_srv;
1609     unsigned long error;
1610     pthread_t udpserverth;
1611     //    pthread_attr_t joinable;
1612     int i;
1613     
1614     parseargs(argc, argv);
1615     getconfig("servers.conf", NULL);
1616     getconfig(NULL, "clients.conf");
1617     
1618     ssl_locks_setup();
1619
1620     //    pthread_attr_init(&joinable);
1621     //    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1622    
1623     /* listen on UDP if at least one UDP client */
1624     
1625     for (i = 0; i < client_count; i++)
1626         if (clients[i].peer.type == 'U') {
1627             if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
1628                 errx("pthread_create failed");
1629             break;
1630         }
1631     
1632     /* SSL setup */
1633     SSL_load_error_strings();
1634     SSL_library_init();
1635
1636     while (!RAND_status()) {
1637         time_t t = time(NULL);
1638         pid_t pid = getpid();
1639         RAND_seed((unsigned char *)&t, sizeof(time_t));
1640         RAND_seed((unsigned char *)&pid, sizeof(pid));
1641     }
1642     
1643     /* initialise client part and start clients */
1644     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1645     if (!ssl_ctx_cl)
1646         errx("no ssl ctx");
1647     
1648     for (i = 0; i < server_count; i++) {
1649         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1650             errx("pthread_create failed");
1651     }
1652
1653     for (i = 0; i < client_count; i++)
1654         if (clients[i].peer.type == 'T')
1655             break;
1656
1657     if (i == client_count) {
1658         printf("No TLS clients defined, not starting TLS listener\n");
1659         /* just hang around doing nothing, anything to do here? */
1660         for (;;)
1661             sleep(1000);
1662     }
1663     
1664     /* setting up server/daemon part */
1665     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1666     if (!ssl_ctx_srv)
1667         errx("no ssl ctx");
1668     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1669         while ((error = ERR_get_error()))
1670             err("SSL: %s", ERR_error_string(error, NULL));
1671         errx("Failed to load certificate");
1672     }
1673     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1674         while ((error = ERR_get_error()))
1675             err("SSL: %s", ERR_error_string(error, NULL));
1676         errx("Failed to load private key");
1677     }
1678
1679     return tlslistener(ssl_ctx_srv);
1680 }