messageauth
[libradsec.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * Among other things:
15  * timer based client retrans or maybe no retrans and just a timer...
16  * make our server ignore client retrans?
17  * tls keep alives
18  * routing based on id....
19  * need to also encrypt Tunnel-Password and Message-Authenticator attrs
20  * tls certificate validation
21 */
22
23 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
24  *              rd is responsible for init and launching wr
25  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
26  *          each tlsserverrd launches tlsserverwr
27  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
28  *          for init and launching rd
29  *
30  * serverrd will receive a request, processes it and puts it in the requestq of
31  *          the appropriate clientwr
32  * clientwr monitors its requestq and sends requests
33  * clientrd looks for responses, processes them and puts them in the replyq of
34  *          the peer the request came from
35  * serverwr monitors its reply and sends replies
36  *
37  * In addition to the main thread, we have:
38  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
39  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
40  * For each TLS peer connecting to us there will be 2 more TLS threads
41  *       This is only for connected peers
42  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
43  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
44 */
45
46 #include <netdb.h>
47 #include <unistd.h>
48 #include <sys/time.h>
49 #include <pthread.h>
50 #include <openssl/ssl.h>
51 #include <openssl/rand.h>
52 #include <openssl/err.h>
53 #include <openssl/md5.h>
54 #include <openssl/hmac.h>
55 #include "radsecproxy.h"
56
57 static struct client clients[MAX_PEERS];
58 static struct server servers[MAX_PEERS];
59
60 static int client_count = 0;
61 static int server_count = 0;
62
63 static struct replyq udp_server_replyq;
64 static int udp_server_sock = -1;
65 static char *udp_server_port = DEFAULT_UDP_PORT;
66 static pthread_mutex_t *ssl_locks;
67 static long *ssl_lock_count;
68 static SSL_CTX *ssl_ctx_cl;
69 extern int optind;
70 extern char *optarg;
71
72 /* callbacks for making OpenSSL thread safe */
73 unsigned long ssl_thread_id() {
74         return (unsigned long)pthread_self();
75 };
76
77 void ssl_locking_callback(int mode, int type, const char *file, int line) {
78     if (mode & CRYPTO_LOCK) {
79         pthread_mutex_lock(&ssl_locks[type]);
80         ssl_lock_count[type]++;
81     } else
82         pthread_mutex_unlock(&ssl_locks[type]);
83 }
84
85 void ssl_locks_setup() {
86     int i;
87
88     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
89     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
90     for (i = 0; i < CRYPTO_num_locks(); i++) {
91         ssl_lock_count[i] = 0;
92         pthread_mutex_init(&ssl_locks[i], NULL);
93     }
94
95     CRYPTO_set_id_callback(ssl_thread_id);
96     CRYPTO_set_locking_callback(ssl_locking_callback);
97 }
98
99 int resolvepeer(struct peer *peer) {
100     struct addrinfo hints, *addrinfo;
101     
102     memset(&hints, 0, sizeof(hints));
103     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
104     hints.ai_family = AF_UNSPEC;
105     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
106         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
107         return 0;
108     }
109
110     if (peer->addrinfo)
111         freeaddrinfo(peer->addrinfo);
112     peer->addrinfo = addrinfo;
113     return 1;
114 }         
115
116 int connecttoserver(struct addrinfo *addrinfo) {
117     int s;
118     struct addrinfo *res;
119     
120     for (res = addrinfo; res; res = res->ai_next) {
121         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
122         if (s < 0) {
123             err("connecttoserver: socket failed");
124             continue;
125         }
126         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
127             break;
128         err("connecttoserver: connect failed");
129         close(s);
130         s = -1;
131     }
132     return s;
133 }         
134
135 /* returns the client with matching address, or NULL */
136 /* if client argument is not NULL, we only check that one client */
137 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
138     struct sockaddr_in6 *sa6;
139     struct in_addr *a4 = NULL;
140     struct client *c;
141     int i;
142     struct addrinfo *res;
143
144     if (addr->sa_family == AF_INET6) {
145         sa6 = (struct sockaddr_in6 *)addr;
146         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
147             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
148     } else
149         a4 = &((struct sockaddr_in *)addr)->sin_addr;
150
151     c = (client ? client : clients);
152     for (i = 0; i < client_count; i++) {
153         if (c->peer.type == type)
154             for (res = c->peer.addrinfo; res; res = res->ai_next)
155                 if ((a4 && res->ai_family == AF_INET &&
156                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
157                     (res->ai_family == AF_INET6 &&
158                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
159                     return c;
160         if (client)
161             break;
162         c++;
163     }
164     return NULL;
165 }
166
167 /* returns the server with matching address, or NULL */
168 /* if server argument is not NULL, we only check that one server */
169 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
170     struct sockaddr_in6 *sa6;
171     struct in_addr *a4 = NULL;
172     struct server *s;
173     int i;
174     struct addrinfo *res;
175
176     if (addr->sa_family == AF_INET6) {
177         sa6 = (struct sockaddr_in6 *)addr;
178         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
179             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
180     } else
181         a4 = &((struct sockaddr_in *)addr)->sin_addr;
182
183     s = (server ? server : servers);
184     for (i = 0; i < server_count; i++) {
185         if (s->peer.type == type)
186             for (res = s->peer.addrinfo; res; res = res->ai_next)
187                 if ((a4 && res->ai_family == AF_INET &&
188                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
189                     (res->ai_family == AF_INET6 &&
190                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
191                     return s;
192         if (server)
193             break;
194         s++;
195     }
196     return NULL;
197 }
198
199 /* exactly one of client and server must be non-NULL */
200 /* if *peer == NULL we return who we received from, else require it to be from peer */
201 /* return from in sa if not NULL */
202 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
203     int cnt, len;
204     void *f;
205     unsigned char buf[65536], *rad;
206     struct sockaddr_storage from;
207     socklen_t fromlen = sizeof(from);
208
209     for (;;) {
210         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
211         if (cnt == -1) {
212             err("radudpget: recv failed");
213             continue;
214         }
215         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
216
217         if (cnt < 20) {
218             printf("radudpget: packet too small\n");
219             continue;
220         }
221     
222         len = RADLEN(buf);
223
224         if (cnt < len) {
225             printf("radudpget: packet smaller than length field in radius header\n");
226             continue;
227         }
228         if (cnt > len)
229             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
230
231         f = (client
232              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
233              : (void *)find_server('U', (struct sockaddr *)&from, *server));
234         if (!f) {
235             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
236             continue;
237         }
238
239         rad = malloc(len);
240         if (rad)
241             break;
242         err("radudpget: malloc failed");
243     }
244     memcpy(rad, buf, len);
245     if (client)
246         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
247     else
248         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
249     if (sa)
250         *sa = from;
251     return rad;
252 }
253
254 void tlsconnect(struct server *server, struct timeval *when, char *text) {
255     struct timeval now;
256     time_t elapsed;
257     unsigned long error;
258
259     printf("tlsconnect called from %s\n", text);
260     pthread_mutex_lock(&server->lock);
261     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
262         /* already reconnected, nothing to do */
263         printf("tlsconnect(%s): seems already reconnected\n", text);
264         pthread_mutex_unlock(&server->lock);
265         return;
266     }
267
268     printf("tlsconnect %s\n", text);
269
270     for (;;) {
271         gettimeofday(&now, NULL);
272         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
273         if (server->connectionok) {
274             server->connectionok = 0;
275             sleep(10);
276         } else if (elapsed < 5)
277             sleep(10);
278         else if (elapsed < 600)
279             sleep(elapsed * 2);
280         else if (elapsed < 10000) /* no sleep at startup */
281                 sleep(900);
282         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
283         if (server->sock >= 0)
284             close(server->sock);
285         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
286             continue;
287         SSL_free(server->peer.ssl);
288         server->peer.ssl = SSL_new(ssl_ctx_cl);
289         SSL_set_fd(server->peer.ssl, server->sock);
290         if (SSL_connect(server->peer.ssl) > 0)
291             break;
292         while ((error = ERR_get_error()))
293             err("tlsconnect: TLS: %s", ERR_error_string(error, NULL));
294     }
295     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
296     gettimeofday(&server->lastconnecttry, NULL);
297     pthread_mutex_unlock(&server->lock);
298 }
299
300 unsigned char *radtlsget(SSL *ssl) {
301     int cnt, total, len;
302     unsigned char buf[4], *rad;
303
304     for (;;) {
305         for (total = 0; total < 4; total += cnt) {
306             cnt = SSL_read(ssl, buf + total, 4 - total);
307             if (cnt <= 0) {
308                 printf("radtlsget: connection lost\n");
309                 return NULL;
310             }
311         }
312
313         len = RADLEN(buf);
314         rad = malloc(len);
315         if (!rad) {
316             err("radtlsget: malloc failed");
317             continue;
318         }
319         memcpy(rad, buf, 4);
320
321         for (; total < len; total += cnt) {
322             cnt = SSL_read(ssl, rad + total, len - total);
323             if (cnt <= 0) {
324                 printf("radtlsget: connection lost\n");
325                 free(rad);
326                 return NULL;
327             }
328         }
329     
330         if (total >= 20)
331             break;
332         
333         free(rad);
334         printf("radtlsget: packet smaller than minimum radius size\n");
335     }
336     
337     printf("radtlsget: got %d bytes\n", total);
338     return rad;
339 }
340
341 int clientradput(struct server *server, unsigned char *rad) {
342     int cnt;
343     size_t len;
344     unsigned long error;
345     struct timeval lastconnecttry;
346     
347     len = RADLEN(rad);
348     if (server->peer.type == 'U') {
349         if (send(server->sock, rad, len, 0) >= 0) {
350             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
351             return 1;
352         }
353         err("clientradput: send failed");
354         return 0;
355     }
356
357     lastconnecttry = server->lastconnecttry;
358     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
359         while ((error = ERR_get_error()))
360             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
361         tlsconnect(server, &lastconnecttry, "clientradput");
362         lastconnecttry = server->lastconnecttry;
363     }
364
365     server->connectionok = 1;
366     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
367            cnt, len, server->peer.host);
368     return 1;
369 }
370
371 int radsign(unsigned char *rad, unsigned char *sec) {
372     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
373     static unsigned char first = 1;
374     static EVP_MD_CTX mdctx;
375     unsigned int md_len;
376     int result;
377     
378     pthread_mutex_lock(&lock);
379     if (first) {
380         EVP_MD_CTX_init(&mdctx);
381         first = 0;
382     }
383
384     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
385         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
386         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
387         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
388         md_len == 16);
389     pthread_mutex_unlock(&lock);
390     return result;
391 }
392
393 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
394     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
395     static unsigned char first = 1;
396     static EVP_MD_CTX mdctx;
397     unsigned char hash[EVP_MAX_MD_SIZE];
398     unsigned int len;
399     int result;
400     
401     pthread_mutex_lock(&lock);
402     if (first) {
403         EVP_MD_CTX_init(&mdctx);
404         first = 0;
405     }
406
407     len = RADLEN(rad);
408     
409     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
410               EVP_DigestUpdate(&mdctx, rad, 4) &&
411               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
412               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
413               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
414               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
415               len == 16 &&
416               !memcmp(hash, rad + 4, 16));
417     pthread_mutex_unlock(&lock);
418     return result;
419 }
420               
421 void sendrq(struct server *to, struct client *from, struct request *rq) {
422     int i;
423
424     pthread_mutex_lock(&to->newrq_mutex);
425     for (i = 0; i < MAX_REQUESTS; i++)
426         if (!to->requests[i].buf)
427             break;
428     if (i == MAX_REQUESTS) {
429         printf("No room in queue, dropping request\n");
430         pthread_mutex_unlock(&to->newrq_mutex);
431         return;
432     }
433     
434     rq->buf[1] = (char)i;
435     to->requests[i] = *rq;
436
437     if (!to->newrq) {
438         to->newrq = 1;
439         printf("signalling client writer\n");
440         pthread_cond_signal(&to->newrq_cond);
441     }
442     pthread_mutex_unlock(&to->newrq_mutex);
443 }
444
445 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
446     struct replyq *replyq = to->replyq;
447     
448     pthread_mutex_lock(&replyq->count_mutex);
449     if (replyq->count == replyq->size) {
450         printf("No room in queue, dropping request\n");
451         pthread_mutex_unlock(&replyq->count_mutex);
452         return;
453     }
454
455     replyq->replies[replyq->count].buf = buf;
456     if (tosa)
457         replyq->replies[replyq->count].tosa = *tosa;
458     replyq->count++;
459
460     if (replyq->count == 1) {
461         printf("signalling client writer\n");
462         pthread_cond_signal(&replyq->count_cond);
463     }
464     pthread_mutex_unlock(&replyq->count_mutex);
465 }
466
467 int pwdcrypt(uint8_t *plain, uint8_t *enc, uint8_t enclen, uint8_t *shared, uint8_t sharedlen,
468                 uint8_t *auth) {
469     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
470     static unsigned char first = 1;
471     static EVP_MD_CTX mdctx;
472     unsigned char hash[EVP_MAX_MD_SIZE], *input;
473     unsigned int md_len;
474     uint8_t i, offset = 0;
475     
476     pthread_mutex_lock(&lock);
477     if (first) {
478         EVP_MD_CTX_init(&mdctx);
479         first = 0;
480     }
481
482     input = auth;
483     for (;;) {
484         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
485             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
486             !EVP_DigestUpdate(&mdctx, input, 16) ||
487             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
488             md_len != 16) {
489             pthread_mutex_unlock(&lock);
490             return 0;
491         }
492         for (i = 0; i < 16; i++)
493             plain[offset + i] = hash[i] ^ enc[offset + i];
494         offset += 16;
495         if (offset == enclen)
496             break;
497         input = enc + offset - 16;
498     }
499     pthread_mutex_unlock(&lock);
500     return 1;
501 }
502
503 struct server *id2server(char *id, uint8_t len) {
504     int i;
505     char **realm, *idrealm;
506
507     idrealm = strchr(id, '@');
508     if (idrealm) {
509         idrealm++;
510         len -= idrealm - id;
511     } else {
512         idrealm = "-";
513         len = 1;
514     }
515     for (i = 0; i < server_count; i++) {
516         for (realm = servers[i].realms; *realm; realm++) {
517             if ((strlen(*realm) == 1 && **realm == '*') ||
518                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
519                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
520                 return servers + i;
521             }
522         }
523     }
524     return NULL;
525 }
526
527 int messageauth(char *rad, uint8_t *authattr, uint8_t *newauth, struct peer *from, struct peer *to) {
528     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
529     static unsigned char first = 1;
530     static HMAC_CTX hmacctx;
531     unsigned int md_len;
532     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
533     
534     pthread_mutex_lock(&lock);
535     if (first) {
536         HMAC_CTX_init(&hmacctx);
537         first = 0;
538     }
539
540     memcpy(auth, authattr, 16);
541     memset(authattr, 0, 16);
542     md_len = 0;
543     HMAC_Init_ex(&hmacctx, from->secret, strlen(from->secret), EVP_md5(), NULL);
544     HMAC_Update(&hmacctx, rad, RADLEN(rad));
545     HMAC_Final(&hmacctx, hash, &md_len);
546     if (md_len != 16) {
547         printf("message auth computation failed\n");
548         pthread_mutex_unlock(&lock);
549         return 0;
550     }
551
552     if (memcmp(auth, hash, 16)) {
553         printf("message authenticator, wrong value\n");
554         pthread_mutex_unlock(&lock);
555         return 0;
556     }   
557
558     md_len = 0;
559     HMAC_Init_ex(&hmacctx, to->secret, strlen(to->secret), EVP_md5(), NULL);
560     HMAC_Update(&hmacctx, rad, RADLEN(rad));
561     HMAC_Final(&hmacctx, authattr, &md_len);
562     if (md_len != 16) {
563         printf("message auth re-computation failed\n");
564         pthread_mutex_unlock(&lock);
565         return 0;
566     }
567         
568     pthread_mutex_unlock(&lock);
569     return 1;
570 }
571
572 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
573     uint8_t code, id, *auth, *attr, pwd[128], attrvallen;
574     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
575     int i;
576     uint16_t len;
577     int left;
578     struct server *to;
579     unsigned char newauth[16];
580     
581     code = *(uint8_t *)buf;
582     id = *(uint8_t *)(buf + 1);
583     len = RADLEN(buf);
584     auth = (uint8_t *)(buf + 4);
585
586     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
587     
588     if (code != RAD_Access_Request) {
589         printf("radsrv: server currently accepts only access-requests, ignoring\n");
590         return NULL;
591     }
592
593     left = len - 20;
594     attr = buf + 20;
595     
596     while (left > 1) {
597         left -= attr[RAD_Attr_Length];
598         if (left < 0) {
599             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
600             return NULL;
601         }
602         switch (attr[RAD_Attr_Type]) {
603         case RAD_Attr_User_Name:
604             usernameattr = attr;
605             break;
606         case RAD_Attr_User_Password:
607             userpwdattr = attr;
608             break;
609         case RAD_Attr_Tunnel_Password:
610             tunnelpwdattr = attr;
611             break;
612         case RAD_Attr_Message_Authenticator:
613             messageauthattr = attr;
614             break;
615         }
616         attr += attr[RAD_Attr_Length];
617     }
618     if (left)
619         printf("radsrv: malformed packet? remaining byte after last attribute\n");
620
621     if (usernameattr) {
622         printf("radsrv: Username: ");
623         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
624             printf("%c", usernameattr[RAD_Attr_Value + i]);
625         printf("\n");
626     }
627
628     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
629     if (!to) {
630         printf("radsrv: ignoring request, don't know where to send it\n");
631         return NULL;
632     }
633
634     if (!RAND_bytes(newauth, 16)) {
635         printf("radsrv: failed to generate random auth\n");
636         return NULL;
637     }
638
639     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
640                             !messageauth(buf, &messageauthattr[RAD_Attr_Value], newauth, &from->peer, &to->peer))) {
641         printf("radsrv: message authentication failed\n");
642         return NULL;
643     }
644
645     if (userpwdattr) {
646         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
647         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
648         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
649             printf("radsrv: invalid user password length\n");
650             return NULL;
651         }
652         
653         if (!pwdcrypt(pwd, &userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
654             printf("radsrv: cannot decrypt password\n");
655             return NULL;
656         }
657         printf("radsrv: password: ");
658         for (i = 0; i < attrvallen; i++)
659             printf("%02x ", pwd[i]);
660         printf("\n");
661         if (!pwdcrypt(&userpwdattr[RAD_Attr_Value], pwd, attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
662             printf("radsrv: cannot encrypt password\n");
663             return NULL;
664         }
665     }
666
667     if (tunnelpwdattr) {
668         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
669         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
670         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
671             printf("radsrv: invalid user password length\n");
672             return NULL;
673         }
674         
675         if (!pwdcrypt(pwd, &tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
676             printf("radsrv: cannot decrypt password\n");
677             return NULL;
678         }
679         printf("radsrv: password: ");
680         for (i = 0; i < attrvallen; i++)
681             printf("%02x ", pwd[i]);
682         printf("\n");
683         if (!pwdcrypt(&tunnelpwdattr[RAD_Attr_Value], pwd, attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
684             printf("radsrv: cannot encrypt password\n");
685             return NULL;
686         }
687     }
688
689     rq->buf = buf;
690     rq->from = from;
691     rq->origid = id;
692     memcpy(rq->origauth, auth, 16);
693     memcpy(rq->buf + 4, newauth, 16);
694     return to;
695 }
696
697 void *clientrd(void *arg) {
698     struct server *server = (struct server *)arg;
699     struct client *from;
700     int i;
701     unsigned char *buf;
702     struct sockaddr_storage fromsa;
703     struct timeval lastconnecttry;
704     
705     for (;;) {
706         lastconnecttry = server->lastconnecttry;
707         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
708         if (!buf && server->peer.type == 'T') {
709             tlsconnect(server, &lastconnecttry, "clientrd");
710             continue;
711         }
712     
713         server->connectionok = 1;
714         
715         i = buf[1]; /* i is the id */
716
717         pthread_mutex_lock(&server->newrq_mutex);
718         if (!server->requests[i].buf || !server->requests[i].tries) {
719             pthread_mutex_unlock(&server->newrq_mutex);
720             printf("clientrd: no matching request sent with this id, ignoring\n");
721             continue;
722         }
723         
724         if (server->requests[i].received) {
725             pthread_mutex_unlock(&server->newrq_mutex);
726             printf("clientrd: already received, ignoring\n");
727             continue;
728         }
729
730         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
731             pthread_mutex_unlock(&server->newrq_mutex);
732             printf("clientrd: invalid auth, ignoring\n");
733             continue;
734         }
735
736         /* once we set received = 1, requests[i] may be reused */
737         buf[1] = (char)server->requests[i].origid;
738         memcpy(buf + 4, server->requests[i].origauth, 16);
739         from = server->requests[i].from;
740         if (from->peer.type == 'U')
741             fromsa = server->requests[i].fromsa;
742         server->requests[i].received = 1;
743         pthread_mutex_unlock(&server->newrq_mutex);
744
745         if (!radsign(buf, from->peer.secret)) {
746             printf("clientrd: failed to sign message\n");
747             continue;
748         }
749         
750         printf("clientrd: giving packet back to where it came from\n");
751         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
752     }
753 }
754
755 void *clientwr(void *arg) {
756     struct server *server = (struct server *)arg;
757     pthread_t clientrdth;
758     int i;
759
760     if (server->peer.type == 'U') {
761         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
762             printf("clientwr: connecttoserver failed\n");
763             exit(1);
764         }
765     } else
766         tlsconnect(server, NULL, "new client");
767     
768     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
769         errx("clientwr: pthread_create failed");
770
771     for (;;) {
772         pthread_mutex_lock(&server->newrq_mutex);
773         while (!server->newrq) {
774             printf("clientwr: waiting for signal\n");
775             pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
776             printf("clientwr: got signal\n");
777         }
778         server->newrq = 0;
779         pthread_mutex_unlock(&server->newrq_mutex);
780                
781         for (i = 0; i < MAX_REQUESTS; i++) {
782             pthread_mutex_lock(&server->newrq_mutex);
783             while (!server->requests[i].buf && i < MAX_REQUESTS)
784                 i++;
785             if (i == MAX_REQUESTS) {
786                 pthread_mutex_unlock(&server->newrq_mutex);
787                 break;
788             }
789
790             /* already received or too many tries */
791             if (server->requests[i].received || server->requests[i].tries > 2) {
792                 free(server->requests[i].buf);
793                 /* setting this to NULL means that it can be reused */
794                 server->requests[i].buf = NULL;
795                 pthread_mutex_unlock(&server->newrq_mutex);
796                 continue;
797             }
798             pthread_mutex_unlock(&server->newrq_mutex);
799             
800             server->requests[i].tries++;
801             clientradput(server, server->requests[i].buf);
802         }
803     }
804     /* should do more work to maintain TLS connections, keepalives etc */
805 }
806
807 void *udpserverwr(void *arg) {
808     struct replyq *replyq = &udp_server_replyq;
809     struct reply *reply = replyq->replies;
810     
811     pthread_mutex_lock(&replyq->count_mutex);
812     for (;;) {
813         while (!replyq->count) {
814             printf("udp server writer, waiting for signal\n");
815             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
816             printf("udp server writer, got signal\n");
817         }
818         pthread_mutex_unlock(&replyq->count_mutex);
819         
820         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
821                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
822             err("sendudp: send failed");
823         free(reply->buf);
824         
825         pthread_mutex_lock(&replyq->count_mutex);
826         replyq->count--;
827         memmove(replyq->replies, replyq->replies + 1,
828                 replyq->count * sizeof(struct reply));
829     }
830 }
831
832 void *udpserverrd(void *arg) {
833     struct request rq;
834     unsigned char *buf;
835     struct server *to;
836     struct client *fr;
837     pthread_t udpserverwrth;
838     
839     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
840         printf("udpserverrd: socket/bind failed\n");
841         exit(1);
842     }
843     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
844
845     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
846         errx("pthread_create failed");
847     
848     for (;;) {
849         fr = NULL;
850         memset(&rq, 0, sizeof(struct request));
851         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
852         to = radsrv(&rq, buf, fr);
853         if (!to) {
854             printf("udpserverrd: ignoring request, no place to send it\n");
855             continue;
856         }
857         sendrq(to, fr, &rq);
858     }
859 }
860
861 void *tlsserverwr(void *arg) {
862     int cnt;
863     unsigned long error;
864     struct client *client = (struct client *)arg;
865     struct replyq *replyq;
866     
867     pthread_mutex_lock(&client->replycount_mutex);
868     for (;;) {
869         replyq = client->replyq;
870         while (!replyq->count) {
871             printf("tls server writer, waiting for signal\n");
872             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
873             printf("tls server writer, got signal\n");
874         }
875         pthread_mutex_unlock(&replyq->count_mutex);
876         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
877         if (cnt > 0)
878             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
879                    cnt, RADLEN(replyq->replies->buf));
880         else
881             while ((error = ERR_get_error()))
882                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
883         free(replyq->replies->buf);
884
885         pthread_mutex_lock(&replyq->count_mutex);
886         replyq->count--;
887         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
888     }
889 }
890
891 void *tlsserverrd(void *arg) {
892     struct request rq;
893     char unsigned *buf;
894     unsigned long error;
895     struct server *to;
896     int s;
897     struct client *client = (struct client *)arg;
898     pthread_t tlsserverwrth;
899
900     printf("tlsserverrd starting\n");
901     if (SSL_accept(client->peer.ssl) <= 0) {
902         while ((error = ERR_get_error()))
903             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
904         errx("accept failed, child exiting");
905     }
906
907     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
908         errx("pthread_create failed");
909     
910     for (;;) {
911         buf = radtlsget(client->peer.ssl);
912         if (!buf) {
913             printf("tlsserverrd: connection lost\n");
914             s = SSL_get_fd(client->peer.ssl);
915             SSL_free(client->peer.ssl);
916             client->peer.ssl = NULL;
917             if (s >= 0)
918                 close(s);
919             pthread_exit(NULL);
920         }
921         printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
922         memset(&rq, 0, sizeof(struct request));
923         to = radsrv(&rq, buf, client);
924         if (!to) {
925             printf("ignoring request, no place to send it\n");
926             continue;
927         }
928         sendrq(to, client, &rq);
929     }
930 }
931
932 int tlslistener(SSL_CTX *ssl_ctx) {
933     pthread_t tlsserverth;
934     int s, snew;
935     struct sockaddr_storage from;
936     size_t fromlen = sizeof(from);
937     struct client *client;
938
939     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
940         printf("tlslistener: socket/bind failed\n");
941         exit(1);
942     }
943     
944     listen(s, 0);
945     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
946
947     for (;;) {
948         snew = accept(s, (struct sockaddr *)&from, &fromlen);
949         if (snew < 0)
950             errx("accept failed");
951         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
952
953         client = find_client('T', (struct sockaddr *)&from, NULL);
954         if (!client) {
955             printf("ignoring request, not a known TLS client\n");
956             close(snew);
957             continue;
958         }
959
960         if (client->peer.ssl) {
961             printf("Ignoring incoming connection, already have one from this client\n");
962             close(snew);
963             continue;
964         }
965         client->peer.ssl = SSL_new(ssl_ctx);
966         SSL_set_fd(client->peer.ssl, snew);
967         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
968             errx("pthread_create failed");
969     }
970     return 0;
971 }
972
973 char *parsehostport(char *s, struct peer *peer) {
974     char *p, *field;
975     int ipv6 = 0;
976
977     p = s;
978     // allow literal addresses and port, e.g. [2001:db8::1]:1812
979     if (*p == '[') {
980         p++;
981         field = p;
982         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
983         if (*p != ']') {
984             printf("no ] matching initial [\n");
985             exit(1);
986         }
987         ipv6 = 1;
988     } else {
989         field = p;
990         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
991     }
992     if (field == p) {
993         printf("missing host/address\n");
994         exit(1);
995     }
996     peer->host = malloc(p - field + 1);
997     if (!peer->host)
998         errx("malloc failed");
999     memcpy(peer->host, field, p - field);
1000     peer->host[p - field] = '\0';
1001     if (ipv6) {
1002         p++;
1003         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1004             printf("unexpected character after ]\n");
1005             exit(1);
1006         }
1007     }
1008     if (*p == ':') {
1009             /* port number or service name is specified */;
1010             field = p++;
1011             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1012             if (field == p) {
1013                 printf("syntax error, : but no following port\n");
1014                 exit(1);
1015             }
1016             peer->port = malloc(p - field + 1);
1017             if (!peer->port)
1018                 errx("malloc failed");
1019             memcpy(peer->port, field, p - field);
1020             peer->port[p - field] = '\0';
1021     } else
1022         peer->port = NULL;
1023     return p;
1024 }
1025
1026 // * is default, else longest match ... ";" used for separator
1027 char *parserealmlist(char *s, struct server *server) {
1028     char *p;
1029     int i, n, l;
1030
1031     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1032         if (*p == ';')
1033             n++;
1034     l = p - s;
1035     if (!l) {
1036         server->realms = NULL;
1037         return p;
1038     }
1039     server->realmdata = malloc(l + 1);
1040     if (!server->realmdata)
1041         errx("malloc failed");
1042     memcpy(server->realmdata, s, l);
1043     server->realmdata[l] = '\0';
1044     server->realms = malloc((1+n) * sizeof(char *));
1045     if (!server->realms)
1046         errx("malloc failed");
1047     server->realms[0] = server->realmdata;
1048     for (n = 1, i = 0; i < l; i++)
1049         if (server->realmdata[i] == ';') {
1050             server->realmdata[i] = '\0';
1051             server->realms[n++] = server->realmdata + i + 1;
1052         }       
1053     server->realms[n] = NULL;
1054     return p;
1055 }
1056
1057 /* exactly one argument must be non-NULL */
1058 void getconfig(const char *serverfile, const char *clientfile) {
1059     FILE *f;
1060     char line[1024];
1061     char *p, *field, **r;
1062     struct client *client;
1063     struct server *server;
1064     struct peer *peer;
1065     int *count;
1066     
1067     if (serverfile) {
1068         printf("opening file %s for reading\n", serverfile);
1069         f = fopen(serverfile, "r");
1070         if (!f)
1071             errx("getconfig failed to open %s for reading", serverfile);
1072         count = &server_count;
1073     } else {
1074         printf("opening file %s for reading\n", clientfile);
1075         f = fopen(clientfile, "r");
1076         if (!f)
1077             errx("getconfig failed to open %s for reading", clientfile);
1078         udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
1079         if (!udp_server_replyq.replies)
1080             errx("malloc failed");
1081         udp_server_replyq.size = 4 * MAX_REQUESTS;
1082         udp_server_replyq.count = 0;
1083         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1084         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1085         count = &client_count;
1086     }    
1087     
1088     *count = 0;
1089     while (fgets(line, 1024, f) && *count < MAX_PEERS) {
1090         if (serverfile) {
1091             server = &servers[*count];
1092             memset(server, 0, sizeof(struct server));
1093             peer = &server->peer;
1094         } else {
1095             client = &clients[*count];
1096             memset(client, 0, sizeof(struct client));
1097             peer = &client->peer;
1098         }
1099         for (p = line; *p == ' ' || *p == '\t'; p++);
1100         if (*p == '#' || *p == '\n')
1101             continue;
1102         if (*p != 'U' && *p != 'T') {
1103             printf("server type must be U or T, got %c\n", *p);
1104             exit(1);
1105         }
1106         peer->type = *p;
1107         for (p++; *p == ' ' || *p == '\t'; p++);
1108         p = parsehostport(p, peer);
1109         if (!peer->port)
1110             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
1111         for (; *p == ' ' || *p == '\t'; p++);
1112         if (serverfile) {
1113             p = parserealmlist(p, server);
1114             if (!server->realms) {
1115                 printf("realm list must be specified\n");
1116                 exit(1);
1117             }
1118             for (; *p == ' ' || *p == '\t'; p++);
1119         }
1120         field = p;
1121         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1122         if (field == p) {
1123             /* no secret set and end of line, line is complete if TLS */
1124             if (peer->type == 'U') {
1125                 printf("secret must be specified for UDP\n");
1126                 exit(1);
1127             }
1128             peer->secret = DEFAULT_TLS_SECRET;
1129         } else {
1130             peer->secret = malloc(p - field + 1);
1131             if (!peer->secret)
1132                 errx("malloc failed");
1133             memcpy(peer->secret, field, p - field);
1134             peer->secret[p - field] = '\0';
1135             /* check that rest of line only white space */
1136             for (; *p == ' ' || *p == '\t'; p++);
1137             if (*p && *p != '\n') {
1138                 printf("max 4 fields per line, found a 5th\n");
1139                 exit(1);
1140             }
1141         }
1142
1143         if ((serverfile && !resolvepeer(&server->peer)) ||
1144             (clientfile && !resolvepeer(&client->peer))) {
1145             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1146             exit(1);
1147         }
1148
1149         if (serverfile) {
1150             pthread_mutex_init(&server->lock, NULL);
1151             server->sock = -1;
1152             server->requests = malloc(MAX_REQUESTS * sizeof(struct request));
1153             if (!server->requests)
1154                 errx("malloc failed");
1155             memset(server->requests, 0, MAX_REQUESTS * sizeof(struct request));
1156             server->newrq = 0;
1157             pthread_mutex_init(&server->newrq_mutex, NULL);
1158             pthread_cond_init(&server->newrq_cond, NULL);
1159         } else {
1160             if (peer->type == 'U')
1161                 client->replyq = &udp_server_replyq;
1162             else {
1163                 client->replyq = malloc(sizeof(struct replyq));
1164                 if (!client->replyq)
1165                     errx("malloc failed");
1166                 client->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
1167                 if (!client->replyq->replies)
1168                     errx("malloc failed");
1169                 client->replyq->size = MAX_REQUESTS;
1170                 client->replyq->count = 0;
1171                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1172                 pthread_cond_init(&client->replyq->count_cond, NULL);
1173             }
1174         }
1175         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1176         if (serverfile) {
1177             printf("    with realms:");
1178             for (r = server->realms; *r; r++)
1179                 printf(" %s", *r);
1180             printf("\n");
1181         }
1182         (*count)++;
1183     }
1184     fclose(f);
1185 }
1186
1187 void parseargs(int argc, char **argv) {
1188     int c;
1189
1190     while ((c = getopt(argc, argv, "p:")) != -1) {
1191         switch (c) {
1192         case 'p':
1193             udp_server_port = optarg;
1194             break;
1195         default:
1196             goto usage;
1197         }
1198     }
1199
1200     return;
1201
1202  usage:
1203     printf("radsecproxy [ -p UDP-port ]\n");
1204     exit(1);
1205 }
1206                
1207 int main(int argc, char **argv) {
1208     SSL_CTX *ssl_ctx_srv;
1209     unsigned long error;
1210     pthread_t udpserverth;
1211     pthread_attr_t joinable;
1212     int i;
1213     
1214     parseargs(argc, argv);
1215     getconfig("servers.conf", NULL);
1216     getconfig(NULL, "clients.conf");
1217     
1218     ssl_locks_setup();
1219
1220     pthread_attr_init(&joinable);
1221     pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1222    
1223     /* listen on UDP if at least one UDP client */
1224     
1225     for (i = 0; i < client_count; i++)
1226         if (clients[i].peer.type == 'U') {
1227             if (pthread_create(&udpserverth, &joinable, udpserverrd, NULL))
1228                 errx("pthread_create failed");
1229             break;
1230         }
1231     
1232     /* SSL setup */
1233     SSL_load_error_strings();
1234     SSL_library_init();
1235
1236     while (!RAND_status()) {
1237         time_t t = time(NULL);
1238         pid_t pid = getpid();
1239         RAND_seed((unsigned char *)&t, sizeof(time_t));
1240         RAND_seed((unsigned char *)&pid, sizeof(pid));
1241     }
1242     
1243     /* initialise client part and start clients */
1244     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1245     if (!ssl_ctx_cl)
1246         errx("no ssl ctx");
1247     
1248     for (i = 0; i < server_count; i++) {
1249         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1250             errx("pthread_create failed");
1251     }
1252
1253     for (i = 0; i < client_count; i++)
1254         if (clients[i].peer.type == 'T')
1255             break;
1256
1257     if (i == client_count) {
1258         printf("No TLS clients defined, not starting TLS listener\n");
1259         /* just hang around doing nothing, anything to do here? */
1260         for (;;)
1261             sleep(1000);
1262     }
1263     
1264     /* setting up server/daemon part */
1265     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1266     if (!ssl_ctx_srv)
1267         errx("no ssl ctx");
1268     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1269         while ((error = ERR_get_error()))
1270             err("SSL: %s", ERR_error_string(error, NULL));
1271         errx("Failed to load certificate");
1272     }
1273     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1274         while ((error = ERR_get_error()))
1275             err("SSL: %s", ERR_error_string(error, NULL));
1276         errx("Failed to load private key");
1277     }
1278
1279     return tlslistener(ssl_ctx_srv);
1280 }