handling mschapv2 and bug fixing
[libradsec.git] / radsecproxy.c
1 /*
2  * Copyright (C) 2006 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 /* BUGS:
10  * peers can not yet be specified with literal IPv6 addresses due to port syntax
11  */
12
13 /* TODO:
14  * make our server ignore client retrans and do its own instead?
15  * tls keep alives (server status)
16  * tls certificate validation
17 */
18
19 /* For UDP there is one server instance consisting of udpserverrd and udpserverth
20  *              rd is responsible for init and launching wr
21  * For TLS there is a server instance that launches tlsserverrd for each TLS peer
22  *          each tlsserverrd launches tlsserverwr
23  * For each UDP/TLS peer there is clientrd and clientwr, clientwr is responsible
24  *          for init and launching rd
25  *
26  * serverrd will receive a request, processes it and puts it in the requestq of
27  *          the appropriate clientwr
28  * clientwr monitors its requestq and sends requests
29  * clientrd looks for responses, processes them and puts them in the replyq of
30  *          the peer the request came from
31  * serverwr monitors its reply and sends replies
32  *
33  * In addition to the main thread, we have:
34  * If UDP peers are configured, there will be 2 + 2 * #peers UDP threads
35  * If TLS peers are configured, there will initially be 2 * #peers TLS threads
36  * For each TLS peer connecting to us there will be 2 more TLS threads
37  *       This is only for connected peers
38  * Example: With 3 UDP peer and 30 TLS peers, there will be a max of
39  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
40 */
41
42 #include <netdb.h>
43 #include <unistd.h>
44 #include <sys/time.h>
45 #include <pthread.h>
46 #include <openssl/ssl.h>
47 #include <openssl/rand.h>
48 #include <openssl/err.h>
49 #include <openssl/md5.h>
50 #include <openssl/hmac.h>
51 #include "radsecproxy.h"
52
53 static struct client clients[MAX_PEERS];
54 static struct server servers[MAX_PEERS];
55
56 static int client_count = 0;
57 static int server_count = 0;
58
59 static struct replyq udp_server_replyq;
60 static int udp_server_sock = -1;
61 static char *udp_server_port = DEFAULT_UDP_PORT;
62 static pthread_mutex_t *ssl_locks;
63 static long *ssl_lock_count;
64 static SSL_CTX *ssl_ctx_cl;
65 extern int optind;
66 extern char *optarg;
67
68 /* callbacks for making OpenSSL thread safe */
69 unsigned long ssl_thread_id() {
70         return (unsigned long)pthread_self();
71 };
72
73 void ssl_locking_callback(int mode, int type, const char *file, int line) {
74     if (mode & CRYPTO_LOCK) {
75         pthread_mutex_lock(&ssl_locks[type]);
76         ssl_lock_count[type]++;
77     } else
78         pthread_mutex_unlock(&ssl_locks[type]);
79 }
80
81 void ssl_locks_setup() {
82     int i;
83
84     ssl_locks = malloc(CRYPTO_num_locks() * sizeof(pthread_mutex_t));
85     ssl_lock_count = OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
86     for (i = 0; i < CRYPTO_num_locks(); i++) {
87         ssl_lock_count[i] = 0;
88         pthread_mutex_init(&ssl_locks[i], NULL);
89     }
90
91     CRYPTO_set_id_callback(ssl_thread_id);
92     CRYPTO_set_locking_callback(ssl_locking_callback);
93 }
94
95 void printauth(char *s, unsigned char *t) {
96     int i;
97     printf("%s:", s);
98     for (i = 0; i < 16; i++)
99             printf("%02x ", t[i]);
100     printf("\n");
101 }
102
103 int resolvepeer(struct peer *peer) {
104     struct addrinfo hints, *addrinfo;
105     
106     memset(&hints, 0, sizeof(hints));
107     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
108     hints.ai_family = AF_UNSPEC;
109     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
110         err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
111         return 0;
112     }
113
114     if (peer->addrinfo)
115         freeaddrinfo(peer->addrinfo);
116     peer->addrinfo = addrinfo;
117     return 1;
118 }         
119
120 int connecttoserver(struct addrinfo *addrinfo) {
121     int s;
122     struct addrinfo *res;
123     
124     for (res = addrinfo; res; res = res->ai_next) {
125         s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
126         if (s < 0) {
127             err("connecttoserver: socket failed");
128             continue;
129         }
130         if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
131             break;
132         err("connecttoserver: connect failed");
133         close(s);
134         s = -1;
135     }
136     return s;
137 }         
138
139 /* returns the client with matching address, or NULL */
140 /* if client argument is not NULL, we only check that one client */
141 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
142     struct sockaddr_in6 *sa6;
143     struct in_addr *a4 = NULL;
144     struct client *c;
145     int i;
146     struct addrinfo *res;
147
148     if (addr->sa_family == AF_INET6) {
149         sa6 = (struct sockaddr_in6 *)addr;
150         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
151             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
152     } else
153         a4 = &((struct sockaddr_in *)addr)->sin_addr;
154
155     c = (client ? client : clients);
156     for (i = 0; i < client_count; i++) {
157         if (c->peer.type == type)
158             for (res = c->peer.addrinfo; res; res = res->ai_next)
159                 if ((a4 && res->ai_family == AF_INET &&
160                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
161                     (res->ai_family == AF_INET6 &&
162                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
163                     return c;
164         if (client)
165             break;
166         c++;
167     }
168     return NULL;
169 }
170
171 /* returns the server with matching address, or NULL */
172 /* if server argument is not NULL, we only check that one server */
173 struct server *find_server(char type, struct sockaddr *addr, struct server *server) {
174     struct sockaddr_in6 *sa6;
175     struct in_addr *a4 = NULL;
176     struct server *s;
177     int i;
178     struct addrinfo *res;
179
180     if (addr->sa_family == AF_INET6) {
181         sa6 = (struct sockaddr_in6 *)addr;
182         if (IN6_IS_ADDR_V4MAPPED(&sa6->sin6_addr))
183             a4 = (struct in_addr *)&sa6->sin6_addr.s6_addr[12];
184     } else
185         a4 = &((struct sockaddr_in *)addr)->sin_addr;
186
187     s = (server ? server : servers);
188     for (i = 0; i < server_count; i++) {
189         if (s->peer.type == type)
190             for (res = s->peer.addrinfo; res; res = res->ai_next)
191                 if ((a4 && res->ai_family == AF_INET &&
192                      !memcmp(a4, &((struct sockaddr_in *)res->ai_addr)->sin_addr, 4)) ||
193                     (res->ai_family == AF_INET6 &&
194                      !memcmp(&sa6->sin6_addr, &((struct sockaddr_in6 *)res->ai_addr)->sin6_addr, 16)))
195                     return s;
196         if (server)
197             break;
198         s++;
199     }
200     return NULL;
201 }
202
203 /* exactly one of client and server must be non-NULL */
204 /* if *peer == NULL we return who we received from, else require it to be from peer */
205 /* return from in sa if not NULL */
206 unsigned char *radudpget(int s, struct client **client, struct server **server, struct sockaddr_storage *sa) {
207     int cnt, len;
208     void *f;
209     unsigned char buf[65536], *rad;
210     struct sockaddr_storage from;
211     socklen_t fromlen = sizeof(from);
212
213     for (;;) {
214         cnt = recvfrom(s, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
215         if (cnt == -1) {
216             err("radudpget: recv failed");
217             continue;
218         }
219         printf("radudpget: got %d bytes from %s\n", cnt, addr2string((struct sockaddr *)&from, fromlen));
220
221         if (cnt < 20) {
222             printf("radudpget: packet too small\n");
223             continue;
224         }
225     
226         len = RADLEN(buf);
227
228         if (cnt < len) {
229             printf("radudpget: packet smaller than length field in radius header\n");
230             continue;
231         }
232         if (cnt > len)
233             printf("radudpget: packet was padded with %d bytes\n", cnt - len);
234
235         f = (client
236              ? (void *)find_client('U', (struct sockaddr *)&from, *client)
237              : (void *)find_server('U', (struct sockaddr *)&from, *server));
238         if (!f) {
239             printf("radudpget: got packet from wrong or unknown UDP peer, ignoring\n");
240             continue;
241         }
242
243         rad = malloc(len);
244         if (rad)
245             break;
246         err("radudpget: malloc failed");
247     }
248     memcpy(rad, buf, len);
249     if (client)
250         *client = (struct client *)f; /* only need this if *client == NULL, but if not NULL *client == f here */
251     else
252         *server = (struct server *)f; /* only need this if *server == NULL, but if not NULL *server == f here */
253     if (sa)
254         *sa = from;
255     return rad;
256 }
257
258 void tlsconnect(struct server *server, struct timeval *when, char *text) {
259     struct timeval now;
260     time_t elapsed;
261     unsigned long error;
262
263     printf("tlsconnect called from %s\n", text);
264     pthread_mutex_lock(&server->lock);
265     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
266         /* already reconnected, nothing to do */
267         printf("tlsconnect(%s): seems already reconnected\n", text);
268         pthread_mutex_unlock(&server->lock);
269         return;
270     }
271
272     printf("tlsconnect %s\n", text);
273
274     for (;;) {
275         gettimeofday(&now, NULL);
276         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
277         if (server->connectionok) {
278             server->connectionok = 0;
279             sleep(10);
280         } else if (elapsed < 5)
281             sleep(10);
282         else if (elapsed < 600)
283             sleep(elapsed * 2);
284         else if (elapsed < 10000) /* no sleep at startup */
285                 sleep(900);
286         printf("tlsconnect: trying to open TLS connection to %s port %s\n", server->peer.host, server->peer.port);
287         if (server->sock >= 0)
288             close(server->sock);
289         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0)
290             continue;
291         SSL_free(server->peer.ssl);
292         server->peer.ssl = SSL_new(ssl_ctx_cl);
293         SSL_set_fd(server->peer.ssl, server->sock);
294         if (SSL_connect(server->peer.ssl) > 0)
295             break;
296         while ((error = ERR_get_error()))
297             err("tlsconnect: TLS: %s", ERR_error_string(error, NULL));
298     }
299     printf("tlsconnect: TLS connection to %s port %s up\n", server->peer.host, server->peer.port);
300     gettimeofday(&server->lastconnecttry, NULL);
301     pthread_mutex_unlock(&server->lock);
302 }
303
304 unsigned char *radtlsget(SSL *ssl) {
305     int cnt, total, len;
306     unsigned char buf[4], *rad;
307
308     for (;;) {
309         for (total = 0; total < 4; total += cnt) {
310             cnt = SSL_read(ssl, buf + total, 4 - total);
311             if (cnt <= 0) {
312                 printf("radtlsget: connection lost\n");
313                 return NULL;
314             }
315         }
316
317         len = RADLEN(buf);
318         rad = malloc(len);
319         if (!rad) {
320             err("radtlsget: malloc failed");
321             continue;
322         }
323         memcpy(rad, buf, 4);
324
325         for (; total < len; total += cnt) {
326             cnt = SSL_read(ssl, rad + total, len - total);
327             if (cnt <= 0) {
328                 printf("radtlsget: connection lost\n");
329                 free(rad);
330                 return NULL;
331             }
332         }
333     
334         if (total >= 20)
335             break;
336         
337         free(rad);
338         printf("radtlsget: packet smaller than minimum radius size\n");
339     }
340     
341     printf("radtlsget: got %d bytes\n", total);
342     return rad;
343 }
344
345 int clientradput(struct server *server, unsigned char *rad) {
346     int cnt;
347     size_t len;
348     unsigned long error;
349     struct timeval lastconnecttry;
350     
351     len = RADLEN(rad);
352     if (server->peer.type == 'U') {
353         if (send(server->sock, rad, len, 0) >= 0) {
354             printf("clienradput: sent UDP of length %d to %s port %s\n", len, server->peer.host, server->peer.port);
355             return 1;
356         }
357         err("clientradput: send failed");
358         return 0;
359     }
360
361     lastconnecttry = server->lastconnecttry;
362     while ((cnt = SSL_write(server->peer.ssl, rad, len)) <= 0) {
363         while ((error = ERR_get_error()))
364             err("clientwr: TLS: %s", ERR_error_string(error, NULL));
365         tlsconnect(server, &lastconnecttry, "clientradput");
366         lastconnecttry = server->lastconnecttry;
367     }
368
369     server->connectionok = 1;
370     printf("clientradput: Sent %d bytes, Radius packet of length %d to TLS peer %s\n",
371            cnt, len, server->peer.host);
372     return 1;
373 }
374
375 int radsign(unsigned char *rad, unsigned char *sec) {
376     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
377     static unsigned char first = 1;
378     static EVP_MD_CTX mdctx;
379     unsigned int md_len;
380     int result;
381     
382     pthread_mutex_lock(&lock);
383     if (first) {
384         EVP_MD_CTX_init(&mdctx);
385         first = 0;
386     }
387
388     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
389         EVP_DigestUpdate(&mdctx, rad, RADLEN(rad)) &&
390         EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
391         EVP_DigestFinal_ex(&mdctx, rad + 4, &md_len) &&
392         md_len == 16);
393     pthread_mutex_unlock(&lock);
394     return result;
395 }
396
397 int validauth(unsigned char *rad, unsigned char *reqauth, unsigned char *sec) {
398     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
399     static unsigned char first = 1;
400     static EVP_MD_CTX mdctx;
401     unsigned char hash[EVP_MAX_MD_SIZE];
402     unsigned int len;
403     int result;
404     
405     pthread_mutex_lock(&lock);
406     if (first) {
407         EVP_MD_CTX_init(&mdctx);
408         first = 0;
409     }
410
411     len = RADLEN(rad);
412     
413     result = (EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) &&
414               EVP_DigestUpdate(&mdctx, rad, 4) &&
415               EVP_DigestUpdate(&mdctx, reqauth, 16) &&
416               (len <= 20 || EVP_DigestUpdate(&mdctx, rad + 20, len - 20)) &&
417               EVP_DigestUpdate(&mdctx, sec, strlen(sec)) &&
418               EVP_DigestFinal_ex(&mdctx, hash, &len) &&
419               len == 16 &&
420               !memcmp(hash, rad + 4, 16));
421     pthread_mutex_unlock(&lock);
422     return result;
423 }
424               
425 int checkmessageauth(char *rad, uint8_t *authattr, char *secret) {
426     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
427     static unsigned char first = 1;
428     static HMAC_CTX hmacctx;
429     unsigned int md_len;
430     uint8_t auth[16], hash[EVP_MAX_MD_SIZE];
431     
432     pthread_mutex_lock(&lock);
433     if (first) {
434         HMAC_CTX_init(&hmacctx);
435         first = 0;
436     }
437
438     memcpy(auth, authattr, 16);
439     memset(authattr, 0, 16);
440     md_len = 0;
441     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
442     HMAC_Update(&hmacctx, rad, RADLEN(rad));
443     HMAC_Final(&hmacctx, hash, &md_len);
444     memcpy(authattr, auth, 16);
445     if (md_len != 16) {
446         printf("message auth computation failed\n");
447         pthread_mutex_unlock(&lock);
448         return 0;
449     }
450
451     if (memcmp(auth, hash, 16)) {
452         printf("message authenticator, wrong value\n");
453         pthread_mutex_unlock(&lock);
454         return 0;
455     }   
456         
457     pthread_mutex_unlock(&lock);
458     return 1;
459 }
460
461 int createmessageauth(char *rad, char *authattrval, char *secret) {
462     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
463     static unsigned char first = 1;
464     static HMAC_CTX hmacctx;
465     unsigned int md_len;
466
467     if (!authattrval)
468         return 1;
469     
470     pthread_mutex_lock(&lock);
471     if (first) {
472         HMAC_CTX_init(&hmacctx);
473         first = 0;
474     }
475
476     memset(authattrval, 0, 16);
477     md_len = 0;
478     HMAC_Init_ex(&hmacctx, secret, strlen(secret), EVP_md5(), NULL);
479     HMAC_Update(&hmacctx, rad, RADLEN(rad));
480     HMAC_Final(&hmacctx, authattrval, &md_len);
481     if (md_len != 16) {
482         printf("message auth computation failed\n");
483         pthread_mutex_unlock(&lock);
484         return 0;
485     }
486
487     pthread_mutex_unlock(&lock);
488     return 1;
489 }
490
491 void sendrq(struct server *to, struct client *from, struct request *rq) {
492     int i;
493     
494     pthread_mutex_lock(&to->newrq_mutex);
495
496     /* should search from where inserted last */
497     for (i = 0; i < MAX_REQUESTS; i++)
498         if (!to->requests[i].buf)
499             break;
500     if (i == MAX_REQUESTS) {
501         printf("No room in queue, dropping request\n");
502         pthread_mutex_unlock(&to->newrq_mutex);
503         return;
504     }
505     rq->buf[1] = (char)i;
506     
507     if (!createmessageauth(rq->buf, rq->messageauthattrval, to->peer.secret))
508         return;
509
510     gettimeofday(&rq->expiry, NULL);
511     rq->expiry.tv_sec += 30;
512     to->requests[i] = *rq;
513
514     if (!to->newrq) {
515         to->newrq = 1;
516         printf("signalling client writer\n");
517         pthread_cond_signal(&to->newrq_cond);
518     }
519     pthread_mutex_unlock(&to->newrq_mutex);
520 }
521
522 void sendreply(struct client *to, struct server *from, char *buf, struct sockaddr_storage *tosa) {
523     struct replyq *replyq = to->replyq;
524     
525     pthread_mutex_lock(&replyq->count_mutex);
526     if (replyq->count == replyq->size) {
527         printf("No room in queue, dropping request\n");
528         pthread_mutex_unlock(&replyq->count_mutex);
529         return;
530     }
531
532     replyq->replies[replyq->count].buf = buf;
533     if (tosa)
534         replyq->replies[replyq->count].tosa = *tosa;
535     replyq->count++;
536
537     if (replyq->count == 1) {
538         printf("signalling client writer\n");
539         pthread_cond_signal(&replyq->count_cond);
540     }
541     pthread_mutex_unlock(&replyq->count_mutex);
542 }
543
544 int pwdencrypt(uint8_t *out, uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen,
545                 uint8_t *auth) {
546     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
547     static unsigned char first = 1;
548     static EVP_MD_CTX mdctx;
549     unsigned char hash[EVP_MAX_MD_SIZE], *input;
550     unsigned int md_len;
551     uint8_t i, offset = 0;
552     
553     pthread_mutex_lock(&lock);
554     if (first) {
555         EVP_MD_CTX_init(&mdctx);
556         first = 0;
557     }
558
559     input = auth;
560     for (;;) {
561         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
562             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
563             !EVP_DigestUpdate(&mdctx, input, 16) ||
564             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
565             md_len != 16) {
566             pthread_mutex_unlock(&lock);
567             return 0;
568         }
569         for (i = 0; i < 16; i++)
570             out[offset + i] = hash[i] ^ in[offset + i];
571         input = out + offset - 16;
572         offset += 16;
573         if (offset == len)
574             break;
575     }
576     pthread_mutex_unlock(&lock);
577     return 1;
578 }
579
580 int pwddecrypt(uint8_t *out, uint8_t *in, uint8_t len, uint8_t *shared, uint8_t sharedlen,
581                 uint8_t *auth) {
582     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
583     static unsigned char first = 1;
584     static EVP_MD_CTX mdctx;
585     unsigned char hash[EVP_MAX_MD_SIZE], *input;
586     unsigned int md_len;
587     uint8_t i, offset = 0;
588     
589     pthread_mutex_lock(&lock);
590     if (first) {
591         EVP_MD_CTX_init(&mdctx);
592         first = 0;
593     }
594
595     input = auth;
596     for (;;) {
597         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
598             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
599             !EVP_DigestUpdate(&mdctx, input, 16) ||
600             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
601             md_len != 16) {
602             pthread_mutex_unlock(&lock);
603             return 0;
604         }
605         for (i = 0; i < 16; i++)
606             out[offset + i] = hash[i] ^ in[offset + i];
607         input = in + offset;
608         offset += 16;
609         if (offset == len)
610             break;
611     }
612     pthread_mutex_unlock(&lock);
613     return 1;
614 }
615
616 int msmppencrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
617     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
618     static unsigned char first = 1;
619     static EVP_MD_CTX mdctx;
620     unsigned char hash[EVP_MAX_MD_SIZE];
621     unsigned int md_len;
622     uint8_t i, offset;
623     
624     pthread_mutex_lock(&lock);
625     if (first) {
626         EVP_MD_CTX_init(&mdctx);
627         first = 0;
628     }
629
630 #if 0    
631     printf("msppencrypt auth in: ");
632     for (i = 0; i < 16; i++)
633         printf("%02x ", auth[i]);
634     printf("\n");
635     
636     printf("msppencrypt salt in: ");
637     for (i = 0; i < 2; i++)
638         printf("%02x ", salt[i]);
639     printf("\n");
640     
641     printf("msppencrypt in: ");
642     for (i = 0; i < len; i++)
643         printf("%02x ", text[i]);
644     printf("\n");
645 #endif
646     
647     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
648         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
649         !EVP_DigestUpdate(&mdctx, auth, 16) ||
650         !EVP_DigestUpdate(&mdctx, salt, 2) ||
651         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
652         pthread_mutex_unlock(&lock);
653         return 0;
654     }
655
656 #if 0    
657     printf("msppencrypt hash: ");
658     for (i = 0; i < 16; i++)
659         printf("%02x ", hash[i]);
660     printf("\n");
661 #endif
662     
663     for (i = 0; i < 16; i++)
664         text[i] ^= hash[i];
665     
666     for (offset = 16; offset < len; offset += 16) {
667         printf("text + offset - 16 c(%d): ", offset / 16);
668         for (i = 0; i < 16; i++)
669             printf("%02x ", (text + offset - 16)[i]);
670         printf("\n");
671
672         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
673             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
674             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
675             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
676             md_len != 16) {
677             pthread_mutex_unlock(&lock);
678             return 0;
679         }
680 #if 0   
681         printf("msppencrypt hash: ");
682         for (i = 0; i < 16; i++)
683             printf("%02x ", hash[i]);
684         printf("\n");
685 #endif    
686         
687         for (i = 0; i < 16; i++)
688             text[offset + i] ^= hash[i];
689     }
690     
691 #if 0
692     printf("msppencrypt out: ");
693     for (i = 0; i < len; i++)
694         printf("%02x ", text[i]);
695     printf("\n");
696 #endif
697
698     pthread_mutex_unlock(&lock);
699     return 1;
700 }
701
702 int msmppdecrypt(uint8_t *text, uint8_t len, uint8_t *shared, uint8_t sharedlen, uint8_t *auth, uint8_t *salt) {
703     static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;
704     static unsigned char first = 1;
705     static EVP_MD_CTX mdctx;
706     unsigned char hash[EVP_MAX_MD_SIZE];
707     unsigned int md_len;
708     uint8_t i, offset;
709     char plain[255];
710     
711     pthread_mutex_lock(&lock);
712     if (first) {
713         EVP_MD_CTX_init(&mdctx);
714         first = 0;
715     }
716
717 #if 0    
718     printf("msppdecrypt auth in: ");
719     for (i = 0; i < 16; i++)
720         printf("%02x ", auth[i]);
721     printf("\n");
722     
723     printf("msppedecrypt salt in: ");
724     for (i = 0; i < 2; i++)
725         printf("%02x ", salt[i]);
726     printf("\n");
727     
728     printf("msppedecrypt in: ");
729     for (i = 0; i < len; i++)
730         printf("%02x ", text[i]);
731     printf("\n");
732 #endif
733     
734     if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
735         !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
736         !EVP_DigestUpdate(&mdctx, auth, 16) ||
737         !EVP_DigestUpdate(&mdctx, salt, 2) ||
738         !EVP_DigestFinal_ex(&mdctx, hash, &md_len)) {
739         pthread_mutex_unlock(&lock);
740         return 0;
741     }
742
743 #if 0    
744     printf("msppedecrypt hash: ");
745     for (i = 0; i < 16; i++)
746         printf("%02x ", hash[i]);
747     printf("\n");
748 #endif
749     
750     for (i = 0; i < 16; i++)
751         plain[i] = text[i] ^ hash[i];
752     
753     for (offset = 16; offset < len; offset += 16) {
754 #if 0   
755         printf("text + offset - 16 c(%d): ", offset / 16);
756         for (i = 0; i < 16; i++)
757             printf("%02x ", (text + offset - 16)[i]);
758         printf("\n");
759 #endif
760         if (!EVP_DigestInit_ex(&mdctx, EVP_md5(), NULL) ||
761             !EVP_DigestUpdate(&mdctx, shared, sharedlen) ||
762             !EVP_DigestUpdate(&mdctx, text + offset - 16, 16) ||
763             !EVP_DigestFinal_ex(&mdctx, hash, &md_len) ||
764             md_len != 16) {
765             pthread_mutex_unlock(&lock);
766             return 0;
767         }
768 #if 0   
769     printf("msppedecrypt hash: ");
770     for (i = 0; i < 16; i++)
771         printf("%02x ", hash[i]);
772     printf("\n");
773 #endif    
774
775     for (i = 0; i < 16; i++)
776         plain[offset + i] = text[offset + i] ^ hash[i];
777     }
778
779     memcpy(text, plain, len);
780 #if 0
781     printf("msppedecrypt out: ");
782     for (i = 0; i < len; i++)
783         printf("%02x ", text[i]);
784     printf("\n");
785 #endif
786
787     pthread_mutex_unlock(&lock);
788     return 1;
789 }
790
791 struct server *id2server(char *id, uint8_t len) {
792     int i;
793     char **realm, *idrealm;
794
795     idrealm = strchr(id, '@');
796     if (idrealm) {
797         idrealm++;
798         len -= idrealm - id;
799     } else {
800         idrealm = "-";
801         len = 1;
802     }
803     for (i = 0; i < server_count; i++) {
804         for (realm = servers[i].realms; *realm; realm++) {
805             if ((strlen(*realm) == 1 && **realm == '*') ||
806                 (strlen(*realm) == len && !memcmp(idrealm, *realm, len))) {
807                 printf("found matching realm: %s, host %s\n", *realm, servers[i].peer.host);
808                 return servers + i;
809             }
810         }
811     }
812     return NULL;
813 }
814
815 struct server *radsrv(struct request *rq, char *buf, struct client *from) {
816     uint8_t code, id, *auth, *attr, pwd[128], attrvallen;
817     uint8_t *usernameattr = NULL, *userpwdattr = NULL, *tunnelpwdattr = NULL, *messageauthattr = NULL;
818     int i;
819     uint16_t len;
820     int left;
821     struct server *to;
822     unsigned char newauth[16];
823     
824     code = *(uint8_t *)buf;
825     id = *(uint8_t *)(buf + 1);
826     len = RADLEN(buf);
827     auth = (uint8_t *)(buf + 4);
828
829     printf("radsrv: code %d, id %d, length %d\n", code, id, len);
830     
831     if (code != RAD_Access_Request) {
832         printf("radsrv: server currently accepts only access-requests, ignoring\n");
833         return NULL;
834     }
835
836     left = len - 20;
837     attr = buf + 20;
838     
839     while (left > 1) {
840         left -= attr[RAD_Attr_Length];
841         if (left < 0) {
842             printf("radsrv: attribute length exceeds packet length, ignoring packet\n");
843             return NULL;
844         }
845         switch (attr[RAD_Attr_Type]) {
846         case RAD_Attr_User_Name:
847             usernameattr = attr;
848             break;
849         case RAD_Attr_User_Password:
850             userpwdattr = attr;
851             break;
852         case RAD_Attr_Tunnel_Password:
853             tunnelpwdattr = attr;
854             break;
855         case RAD_Attr_Message_Authenticator:
856             messageauthattr = attr;
857             break;
858         }
859         attr += attr[RAD_Attr_Length];
860     }
861     if (left)
862         printf("radsrv: malformed packet? remaining byte after last attribute\n");
863
864     if (usernameattr) {
865         printf("radsrv: Username: ");
866         for (i = 0; i < usernameattr[RAD_Attr_Length] - 2; i++)
867             printf("%c", usernameattr[RAD_Attr_Value + i]);
868         printf("\n");
869     }
870
871     to = id2server(&usernameattr[RAD_Attr_Value], usernameattr[RAD_Attr_Length] - 2);
872     if (!to) {
873         printf("radsrv: ignoring request, don't know where to send it\n");
874         return NULL;
875     }
876     
877     if (messageauthattr && (messageauthattr[RAD_Attr_Length] != 18 ||
878                             !checkmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))) {
879         printf("radsrv: message authentication failed\n");
880         return NULL;
881     }
882
883     if (!RAND_bytes(newauth, 16)) {
884         printf("radsrv: failed to generate random auth\n");
885         return NULL;
886     }
887
888     printauth("auth", auth);
889     printauth("newauth", newauth);
890     
891     if (userpwdattr) {
892         printf("radsrv: found userpwdattr of length %d\n", userpwdattr[RAD_Attr_Length]);
893         attrvallen = userpwdattr[RAD_Attr_Length] - 2;
894         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
895             printf("radsrv: invalid user password length\n");
896             return NULL;
897         }
898         
899         if (!pwddecrypt(pwd, &userpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
900             printf("radsrv: cannot decrypt password\n");
901             return NULL;
902         }
903         printf("radsrv: password: ");
904         for (i = 0; i < attrvallen; i++)
905             printf("%02x ", pwd[i]);
906         printf("\n");
907         if (!pwdencrypt(&userpwdattr[RAD_Attr_Value], pwd, attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
908             printf("radsrv: cannot encrypt password\n");
909             return NULL;
910         }
911     }
912
913     if (tunnelpwdattr) {
914         printf("radsrv: found tunnelpwdattr of length %d\n", tunnelpwdattr[RAD_Attr_Length]);
915         attrvallen = tunnelpwdattr[RAD_Attr_Length] - 2;
916         if (attrvallen < 16 || attrvallen > 128 || attrvallen % 16) {
917             printf("radsrv: invalid user password length\n");
918             return NULL;
919         }
920         
921         if (!pwddecrypt(pwd, &tunnelpwdattr[RAD_Attr_Value], attrvallen, from->peer.secret, strlen(from->peer.secret), auth)) {
922             printf("radsrv: cannot decrypt password\n");
923             return NULL;
924         }
925         printf("radsrv: password: ");
926         for (i = 0; i < attrvallen; i++)
927             printf("%02x ", pwd[i]);
928         printf("\n");
929         if (!pwdencrypt(&tunnelpwdattr[RAD_Attr_Value], pwd, attrvallen, to->peer.secret, strlen(to->peer.secret), newauth)) {
930             printf("radsrv: cannot encrypt password\n");
931             return NULL;
932         }
933     }
934
935     rq->buf = buf;
936     rq->from = from;
937     rq->origid = id;
938     rq->messageauthattrval = (messageauthattr ? &messageauthattr[RAD_Attr_Value] : NULL);
939     memcpy(rq->origauth, auth, 16);
940     memcpy(auth, newauth, 16);
941     printauth("rq->origauth", rq->origauth);
942     printauth("auth", auth);
943     return to;
944 }
945
946 void *clientrd(void *arg) {
947     struct server *server = (struct server *)arg;
948     struct client *from;
949     int i, left, subleft;
950     unsigned char *buf, *messageauthattr, *subattr, *attr;
951     struct sockaddr_storage fromsa;
952     struct timeval lastconnecttry;
953     char tmp[255];
954     
955     for (;;) {
956     getnext:
957         lastconnecttry = server->lastconnecttry;
958         buf = (server->peer.type == 'U' ? radudpget(server->sock, NULL, &server, NULL) : radtlsget(server->peer.ssl));
959         if (!buf && server->peer.type == 'T') {
960             tlsconnect(server, &lastconnecttry, "clientrd");
961             continue;
962         }
963     
964         server->connectionok = 1;
965
966         if (*buf != RAD_Access_Accept && *buf != RAD_Access_Reject && *buf != RAD_Access_Challenge) {
967             printf("clientrd: discarding, only accept access accept, access reject and access challenge messages\n");
968             continue;
969         }
970         
971         i = buf[1]; /* i is the id */
972
973         pthread_mutex_lock(&server->newrq_mutex);
974         if (!server->requests[i].buf || !server->requests[i].tries) {
975             pthread_mutex_unlock(&server->newrq_mutex);
976             printf("clientrd: no matching request sent with this id, ignoring\n");
977             continue;
978         }
979
980         if (server->requests[i].received) {
981             pthread_mutex_unlock(&server->newrq_mutex);
982             printf("clientrd: already received, ignoring\n");
983             continue;
984         }
985         
986         if (!validauth(buf, server->requests[i].buf + 4, server->peer.secret)) {
987             pthread_mutex_unlock(&server->newrq_mutex);
988             printf("clientrd: invalid auth, ignoring\n");
989             continue;
990         }
991         
992         from = server->requests[i].from;
993
994
995         /* messageauthattr present? */
996         messageauthattr = NULL;
997         left = RADLEN(buf) - 20;
998         attr = buf + 20;
999         while (left > 1) {
1000             left -= attr[RAD_Attr_Length];
1001             if (left < 0) {
1002                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1003                 goto getnext;
1004             }
1005             if (attr[RAD_Attr_Type] == RAD_Attr_Message_Authenticator) {
1006                 if (attr[RAD_Attr_Length] != 18) {
1007                     printf("clientrd: illegal message auth attribute length, ignoring packet\n");
1008                     goto getnext;
1009                 }
1010                 memcpy(tmp, buf + 4, 16);
1011                 memcpy(buf + 4, server->requests[i].buf + 4, 16);
1012                 if (!checkmessageauth(buf, &attr[RAD_Attr_Value], server->peer.secret)) {
1013                     printf("clientrd: message authentication failed\n");
1014                     goto getnext;
1015                 }
1016                 memcpy(buf + 4, tmp, 16);
1017                 printf("clientrd: message auth ok\n");
1018                 messageauthattr = attr;
1019                 break;
1020             }
1021             attr += attr[RAD_Attr_Length];
1022         }
1023
1024         /* handle MS MPPE */
1025         left = RADLEN(buf) - 20;
1026         attr = buf + 20;
1027         while (left > 1) {
1028             left -= attr[RAD_Attr_Length];
1029             if (left < 0) {
1030                 printf("clientrd: attribute length exceeds packet length, ignoring packet\n");
1031                 goto getnext;
1032             }
1033             if (attr[RAD_Attr_Type] == RAD_Attr_Vendor_Specific &&
1034                 ((uint16_t *)attr)[1] == 0 && ntohs(((uint16_t *)attr)[2]) == 311) { // 311 == MS
1035                 subleft = attr[RAD_Attr_Length] - 6;
1036                 subattr = attr + 6;
1037                 while (subleft > 1) {
1038                     subleft -= subattr[RAD_Attr_Length];
1039                     if (subleft < 0)
1040                         break;
1041                     if (subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Send_Key &&
1042                         subattr[RAD_Attr_Type] != RAD_VS_ATTR_MS_MPPE_Recv_Key)
1043                         continue;
1044                     printf("clientrd: Got MS MPPE\n");
1045                     if (subattr[RAD_Attr_Length] < 20)
1046                         continue;
1047
1048                     if (!msmppdecrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1049                             server->peer.secret, strlen(server->peer.secret), server->requests[i].buf + 4, subattr + 2)) {
1050                         printf("clientrd: failed to decrypt msppe key\n");
1051                         continue;
1052                     }
1053
1054                     if (!msmppencrypt(subattr + 4, subattr[RAD_Attr_Length] - 4,
1055                             from->peer.secret, strlen(from->peer.secret), server->requests[i].origauth, subattr + 2)) {
1056                         printf("clientrd: failed to encrypt msppe key\n");
1057                         continue;
1058                     }
1059                 }
1060                 if (subleft < 0) {
1061                     printf("clientrd: bad vendor specific attr or subattr length, ignoring packet\n");
1062                     goto getnext;
1063                 }
1064             }
1065             attr += attr[RAD_Attr_Length];
1066         }
1067
1068         /* once we set received = 1, requests[i] may be reused */
1069         buf[1] = (char)server->requests[i].origid;
1070         memcpy(buf + 4, server->requests[i].origauth, 16);
1071         printauth("origauth/buf+4", buf + 4);
1072         if (messageauthattr) {
1073             if (!createmessageauth(buf, &messageauthattr[RAD_Attr_Value], from->peer.secret))
1074                 continue;
1075             printf("clientrd: computed messageauthattr\n");
1076         }
1077
1078         if (from->peer.type == 'U')
1079             fromsa = server->requests[i].fromsa;
1080         server->requests[i].received = 1;
1081         pthread_mutex_unlock(&server->newrq_mutex);
1082
1083         if (!radsign(buf, from->peer.secret)) {
1084             printf("clientrd: failed to sign message\n");
1085             continue;
1086         }
1087         printauth("signedorigauth/buf+4", buf + 4);             
1088         printf("clientrd: giving packet back to where it came from\n");
1089         sendreply(from, server, buf, from->peer.type == 'U' ? &fromsa : NULL);
1090     }
1091 }
1092
1093 void *clientwr(void *arg) {
1094     struct server *server = (struct server *)arg;
1095     struct request *rq;
1096     pthread_t clientrdth;
1097     int i;
1098     struct timeval now;
1099     
1100     if (server->peer.type == 'U') {
1101         if ((server->sock = connecttoserver(server->peer.addrinfo)) < 0) {
1102             printf("clientwr: connecttoserver failed\n");
1103             exit(1);
1104         }
1105     } else
1106         tlsconnect(server, NULL, "new client");
1107     
1108     if (pthread_create(&clientrdth, NULL, clientrd, (void *)server))
1109         errx("clientwr: pthread_create failed");
1110
1111     for (;;) {
1112         pthread_mutex_lock(&server->newrq_mutex);
1113         while (!server->newrq) {
1114             printf("clientwr: waiting for signal\n");
1115             pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
1116             printf("clientwr: got signal\n");
1117         }
1118         server->newrq = 0;
1119         pthread_mutex_unlock(&server->newrq_mutex);
1120                
1121         for (i = 0; i < MAX_REQUESTS; i++) {
1122             pthread_mutex_lock(&server->newrq_mutex);
1123             while (!server->requests[i].buf && i < MAX_REQUESTS)
1124                 i++;
1125             if (i == MAX_REQUESTS) {
1126                 pthread_mutex_unlock(&server->newrq_mutex);
1127                 break;
1128             }
1129
1130             gettimeofday(&now, NULL);
1131             rq = server->requests + i;
1132
1133             if (rq->received) {
1134                 printf("clientwr: removing received packet from queue\n");
1135                 free(rq->buf);
1136                 /* setting this to NULL means that it can be reused */
1137                 rq->buf = NULL;
1138                 pthread_mutex_unlock(&server->newrq_mutex);
1139                 continue;
1140             }
1141             if (now.tv_sec > rq->expiry.tv_sec) {
1142                 printf("clientwr: removing expired packet from queue\n");
1143                 free(rq->buf);
1144                 /* setting this to NULL means that it can be reused */
1145                 rq->buf = NULL;
1146                 pthread_mutex_unlock(&server->newrq_mutex);
1147                 continue;
1148             }
1149
1150             if (rq->tries)
1151                 continue; // not re-sending (yet)
1152             
1153             rq->tries++;
1154             pthread_mutex_unlock(&server->newrq_mutex);
1155             
1156             clientradput(server, server->requests[i].buf);
1157         }
1158     }
1159     /* should do more work to maintain TLS connections, keepalives etc */
1160 }
1161
1162 void *udpserverwr(void *arg) {
1163     struct replyq *replyq = &udp_server_replyq;
1164     struct reply *reply = replyq->replies;
1165     
1166     pthread_mutex_lock(&replyq->count_mutex);
1167     for (;;) {
1168         while (!replyq->count) {
1169             printf("udp server writer, waiting for signal\n");
1170             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1171             printf("udp server writer, got signal\n");
1172         }
1173         pthread_mutex_unlock(&replyq->count_mutex);
1174         
1175         if (sendto(udp_server_sock, reply->buf, RADLEN(reply->buf), 0,
1176                    (struct sockaddr *)&reply->tosa, SOCKADDR_SIZE(reply->tosa)) < 0)
1177             err("sendudp: send failed");
1178         free(reply->buf);
1179         
1180         pthread_mutex_lock(&replyq->count_mutex);
1181         replyq->count--;
1182         memmove(replyq->replies, replyq->replies + 1,
1183                 replyq->count * sizeof(struct reply));
1184     }
1185 }
1186
1187 void *udpserverrd(void *arg) {
1188     struct request rq;
1189     unsigned char *buf;
1190     struct server *to;
1191     struct client *fr;
1192     pthread_t udpserverwrth;
1193     
1194     if ((udp_server_sock = bindport(SOCK_DGRAM, udp_server_port)) < 0) {
1195         printf("udpserverrd: socket/bind failed\n");
1196         exit(1);
1197     }
1198     printf("udpserverrd: listening on UDP port %s\n", udp_server_port);
1199
1200     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
1201         errx("pthread_create failed");
1202     
1203     for (;;) {
1204         fr = NULL;
1205         memset(&rq, 0, sizeof(struct request));
1206         buf = radudpget(udp_server_sock, &fr, NULL, &rq.fromsa);
1207         to = radsrv(&rq, buf, fr);
1208         if (!to) {
1209             printf("udpserverrd: ignoring request, no place to send it\n");
1210             continue;
1211         }
1212         sendrq(to, fr, &rq);
1213     }
1214 }
1215
1216 void *tlsserverwr(void *arg) {
1217     int cnt;
1218     unsigned long error;
1219     struct client *client = (struct client *)arg;
1220     struct replyq *replyq;
1221     
1222     pthread_mutex_lock(&client->replycount_mutex);
1223     for (;;) {
1224         replyq = client->replyq;
1225         while (!replyq->count) {
1226             printf("tls server writer, waiting for signal\n");
1227             pthread_cond_wait(&replyq->count_cond, &replyq->count_mutex);
1228             printf("tls server writer, got signal\n");
1229         }
1230         pthread_mutex_unlock(&replyq->count_mutex);
1231         cnt = SSL_write(client->peer.ssl, replyq->replies->buf, RADLEN(replyq->replies->buf));
1232         if (cnt > 0)
1233             printf("tlsserverwr: Sent %d bytes, Radius packet of length %d\n",
1234                    cnt, RADLEN(replyq->replies->buf));
1235         else
1236             while ((error = ERR_get_error()))
1237                 err("tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
1238         free(replyq->replies->buf);
1239
1240         pthread_mutex_lock(&replyq->count_mutex);
1241         replyq->count--;
1242         memmove(replyq->replies, replyq->replies + 1, replyq->count * sizeof(struct reply));
1243     }
1244 }
1245
1246 void *tlsserverrd(void *arg) {
1247     struct request rq;
1248     char unsigned *buf;
1249     unsigned long error;
1250     struct server *to;
1251     int s;
1252     struct client *client = (struct client *)arg;
1253     pthread_t tlsserverwrth;
1254
1255     printf("tlsserverrd starting\n");
1256     if (SSL_accept(client->peer.ssl) <= 0) {
1257         while ((error = ERR_get_error()))
1258             err("tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
1259         errx("accept failed, child exiting");
1260     }
1261
1262     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client))
1263         errx("pthread_create failed");
1264     
1265     for (;;) {
1266         buf = radtlsget(client->peer.ssl);
1267         if (!buf) {
1268             printf("tlsserverrd: connection lost\n");
1269             s = SSL_get_fd(client->peer.ssl);
1270             SSL_free(client->peer.ssl);
1271             client->peer.ssl = NULL;
1272             if (s >= 0)
1273                 close(s);
1274             pthread_exit(NULL);
1275         }
1276         printf("tlsserverrd: got Radius message from %s\n", client->peer.host);
1277         memset(&rq, 0, sizeof(struct request));
1278         to = radsrv(&rq, buf, client);
1279         if (!to) {
1280             printf("ignoring request, no place to send it\n");
1281             continue;
1282         }
1283         sendrq(to, client, &rq);
1284     }
1285 }
1286
1287 int tlslistener(SSL_CTX *ssl_ctx) {
1288     pthread_t tlsserverth;
1289     int s, snew;
1290     struct sockaddr_storage from;
1291     size_t fromlen = sizeof(from);
1292     struct client *client;
1293
1294     if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
1295         printf("tlslistener: socket/bind failed\n");
1296         exit(1);
1297     }
1298     
1299     listen(s, 0);
1300     printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
1301
1302     for (;;) {
1303         snew = accept(s, (struct sockaddr *)&from, &fromlen);
1304         if (snew < 0)
1305             errx("accept failed");
1306         printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
1307
1308         client = find_client('T', (struct sockaddr *)&from, NULL);
1309         if (!client) {
1310             printf("ignoring request, not a known TLS client\n");
1311             close(snew);
1312             continue;
1313         }
1314
1315         if (client->peer.ssl) {
1316             printf("Ignoring incoming connection, already have one from this client\n");
1317             close(snew);
1318             continue;
1319         }
1320         client->peer.ssl = SSL_new(ssl_ctx);
1321         SSL_set_fd(client->peer.ssl, snew);
1322         if (pthread_create(&tlsserverth, NULL, tlsserverrd, (void *)client))
1323             errx("pthread_create failed");
1324     }
1325     return 0;
1326 }
1327
1328 char *parsehostport(char *s, struct peer *peer) {
1329     char *p, *field;
1330     int ipv6 = 0;
1331
1332     p = s;
1333     // allow literal addresses and port, e.g. [2001:db8::1]:1812
1334     if (*p == '[') {
1335         p++;
1336         field = p;
1337         for (; *p && *p != ']' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1338         if (*p != ']') {
1339             printf("no ] matching initial [\n");
1340             exit(1);
1341         }
1342         ipv6 = 1;
1343     } else {
1344         field = p;
1345         for (; *p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1346     }
1347     if (field == p) {
1348         printf("missing host/address\n");
1349         exit(1);
1350     }
1351     peer->host = malloc(p - field + 1);
1352     if (!peer->host)
1353         errx("malloc failed");
1354     memcpy(peer->host, field, p - field);
1355     peer->host[p - field] = '\0';
1356     if (ipv6) {
1357         p++;
1358         if (*p && *p != ':' && *p != ' ' && *p != '\t' && *p != '\n') {
1359             printf("unexpected character after ]\n");
1360             exit(1);
1361         }
1362     }
1363     if (*p == ':') {
1364             /* port number or service name is specified */;
1365             field = p++;
1366             for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1367             if (field == p) {
1368                 printf("syntax error, : but no following port\n");
1369                 exit(1);
1370             }
1371             peer->port = malloc(p - field + 1);
1372             if (!peer->port)
1373                 errx("malloc failed");
1374             memcpy(peer->port, field, p - field);
1375             peer->port[p - field] = '\0';
1376     } else
1377         peer->port = NULL;
1378     return p;
1379 }
1380
1381 // * is default, else longest match ... ";" used for separator
1382 char *parserealmlist(char *s, struct server *server) {
1383     char *p;
1384     int i, n, l;
1385
1386     for (p = s, n = 1; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++)
1387         if (*p == ';')
1388             n++;
1389     l = p - s;
1390     if (!l) {
1391         server->realms = NULL;
1392         return p;
1393     }
1394     server->realmdata = malloc(l + 1);
1395     if (!server->realmdata)
1396         errx("malloc failed");
1397     memcpy(server->realmdata, s, l);
1398     server->realmdata[l] = '\0';
1399     server->realms = malloc((1+n) * sizeof(char *));
1400     if (!server->realms)
1401         errx("malloc failed");
1402     server->realms[0] = server->realmdata;
1403     for (n = 1, i = 0; i < l; i++)
1404         if (server->realmdata[i] == ';') {
1405             server->realmdata[i] = '\0';
1406             server->realms[n++] = server->realmdata + i + 1;
1407         }       
1408     server->realms[n] = NULL;
1409     return p;
1410 }
1411
1412 /* exactly one argument must be non-NULL */
1413 void getconfig(const char *serverfile, const char *clientfile) {
1414     FILE *f;
1415     char line[1024];
1416     char *p, *field, **r;
1417     struct client *client;
1418     struct server *server;
1419     struct peer *peer;
1420     int *count;
1421     
1422     if (serverfile) {
1423         printf("opening file %s for reading\n", serverfile);
1424         f = fopen(serverfile, "r");
1425         if (!f)
1426             errx("getconfig failed to open %s for reading", serverfile);
1427         count = &server_count;
1428     } else {
1429         printf("opening file %s for reading\n", clientfile);
1430         f = fopen(clientfile, "r");
1431         if (!f)
1432             errx("getconfig failed to open %s for reading", clientfile);
1433         udp_server_replyq.replies = malloc(4 * MAX_REQUESTS * sizeof(struct reply));
1434         if (!udp_server_replyq.replies)
1435             errx("malloc failed");
1436         udp_server_replyq.size = 4 * MAX_REQUESTS;
1437         udp_server_replyq.count = 0;
1438         pthread_mutex_init(&udp_server_replyq.count_mutex, NULL);
1439         pthread_cond_init(&udp_server_replyq.count_cond, NULL);
1440         count = &client_count;
1441     }    
1442     
1443     *count = 0;
1444     while (fgets(line, 1024, f) && *count < MAX_PEERS) {
1445         if (serverfile) {
1446             server = &servers[*count];
1447             memset(server, 0, sizeof(struct server));
1448             peer = &server->peer;
1449         } else {
1450             client = &clients[*count];
1451             memset(client, 0, sizeof(struct client));
1452             peer = &client->peer;
1453         }
1454         for (p = line; *p == ' ' || *p == '\t'; p++);
1455         if (*p == '#' || *p == '\n')
1456             continue;
1457         if (*p != 'U' && *p != 'T') {
1458             printf("server type must be U or T, got %c\n", *p);
1459             exit(1);
1460         }
1461         peer->type = *p;
1462         for (p++; *p == ' ' || *p == '\t'; p++);
1463         p = parsehostport(p, peer);
1464         if (!peer->port)
1465             peer->port = (peer->type == 'U' ? DEFAULT_UDP_PORT : DEFAULT_TLS_PORT);
1466         for (; *p == ' ' || *p == '\t'; p++);
1467         if (serverfile) {
1468             p = parserealmlist(p, server);
1469             if (!server->realms) {
1470                 printf("realm list must be specified\n");
1471                 exit(1);
1472             }
1473             for (; *p == ' ' || *p == '\t'; p++);
1474         }
1475         field = p;
1476         for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
1477         if (field == p) {
1478             /* no secret set and end of line, line is complete if TLS */
1479             if (peer->type == 'U') {
1480                 printf("secret must be specified for UDP\n");
1481                 exit(1);
1482             }
1483             peer->secret = DEFAULT_TLS_SECRET;
1484         } else {
1485             peer->secret = malloc(p - field + 1);
1486             if (!peer->secret)
1487                 errx("malloc failed");
1488             memcpy(peer->secret, field, p - field);
1489             peer->secret[p - field] = '\0';
1490             /* check that rest of line only white space */
1491             for (; *p == ' ' || *p == '\t'; p++);
1492             if (*p && *p != '\n') {
1493                 printf("max 4 fields per line, found a 5th\n");
1494                 exit(1);
1495             }
1496         }
1497
1498         if ((serverfile && !resolvepeer(&server->peer)) ||
1499             (clientfile && !resolvepeer(&client->peer))) {
1500             printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
1501             exit(1);
1502         }
1503
1504         if (serverfile) {
1505             pthread_mutex_init(&server->lock, NULL);
1506             server->sock = -1;
1507             server->requests = malloc(MAX_REQUESTS * sizeof(struct request));
1508             if (!server->requests)
1509                 errx("malloc failed");
1510             memset(server->requests, 0, MAX_REQUESTS * sizeof(struct request));
1511             server->newrq = 0;
1512             pthread_mutex_init(&server->newrq_mutex, NULL);
1513             pthread_cond_init(&server->newrq_cond, NULL);
1514         } else {
1515             if (peer->type == 'U')
1516                 client->replyq = &udp_server_replyq;
1517             else {
1518                 client->replyq = malloc(sizeof(struct replyq));
1519                 if (!client->replyq)
1520                     errx("malloc failed");
1521                 client->replyq->replies = malloc(MAX_REQUESTS * sizeof(struct reply));
1522                 if (!client->replyq->replies)
1523                     errx("malloc failed");
1524                 client->replyq->size = MAX_REQUESTS;
1525                 client->replyq->count = 0;
1526                 pthread_mutex_init(&client->replyq->count_mutex, NULL);
1527                 pthread_cond_init(&client->replyq->count_cond, NULL);
1528             }
1529         }
1530         printf("got type %c, host %s, port %s, secret %s\n", peer->type, peer->host, peer->port, peer->secret);
1531         if (serverfile) {
1532             printf("    with realms:");
1533             for (r = server->realms; *r; r++)
1534                 printf(" %s", *r);
1535             printf("\n");
1536         }
1537         (*count)++;
1538     }
1539     fclose(f);
1540 }
1541
1542 void parseargs(int argc, char **argv) {
1543     int c;
1544
1545     while ((c = getopt(argc, argv, "p:")) != -1) {
1546         switch (c) {
1547         case 'p':
1548             udp_server_port = optarg;
1549             break;
1550         default:
1551             goto usage;
1552         }
1553     }
1554
1555     return;
1556
1557  usage:
1558     printf("radsecproxy [ -p UDP-port ]\n");
1559     exit(1);
1560 }
1561                
1562 int main(int argc, char **argv) {
1563     SSL_CTX *ssl_ctx_srv;
1564     unsigned long error;
1565     pthread_t udpserverth;
1566     pthread_attr_t joinable;
1567     int i;
1568     
1569     parseargs(argc, argv);
1570     getconfig("servers.conf", NULL);
1571     getconfig(NULL, "clients.conf");
1572     
1573     ssl_locks_setup();
1574
1575     pthread_attr_init(&joinable);
1576     pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE);
1577    
1578     /* listen on UDP if at least one UDP client */
1579     
1580     for (i = 0; i < client_count; i++)
1581         if (clients[i].peer.type == 'U') {
1582             if (pthread_create(&udpserverth, &joinable, udpserverrd, NULL))
1583                 errx("pthread_create failed");
1584             break;
1585         }
1586     
1587     /* SSL setup */
1588     SSL_load_error_strings();
1589     SSL_library_init();
1590
1591     while (!RAND_status()) {
1592         time_t t = time(NULL);
1593         pid_t pid = getpid();
1594         RAND_seed((unsigned char *)&t, sizeof(time_t));
1595         RAND_seed((unsigned char *)&pid, sizeof(pid));
1596     }
1597     
1598     /* initialise client part and start clients */
1599     ssl_ctx_cl = SSL_CTX_new(TLSv1_client_method());
1600     if (!ssl_ctx_cl)
1601         errx("no ssl ctx");
1602     
1603     for (i = 0; i < server_count; i++) {
1604         if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
1605             errx("pthread_create failed");
1606     }
1607
1608     for (i = 0; i < client_count; i++)
1609         if (clients[i].peer.type == 'T')
1610             break;
1611
1612     if (i == client_count) {
1613         printf("No TLS clients defined, not starting TLS listener\n");
1614         /* just hang around doing nothing, anything to do here? */
1615         for (;;)
1616             sleep(1000);
1617     }
1618     
1619     /* setting up server/daemon part */
1620     ssl_ctx_srv = SSL_CTX_new(TLSv1_server_method());
1621     if (!ssl_ctx_srv)
1622         errx("no ssl ctx");
1623     if (!SSL_CTX_use_certificate_file(ssl_ctx_srv, "/tmp/server.pem", SSL_FILETYPE_PEM)) {
1624         while ((error = ERR_get_error()))
1625             err("SSL: %s", ERR_error_string(error, NULL));
1626         errx("Failed to load certificate");
1627     }
1628     if (!SSL_CTX_use_PrivateKey_file(ssl_ctx_srv, "/tmp/server.key", SSL_FILETYPE_PEM)) {
1629         while ((error = ERR_get_error()))
1630             err("SSL: %s", ERR_error_string(error, NULL));
1631         errx("Failed to load private key");
1632     }
1633
1634     return tlslistener(ssl_ctx_srv);
1635 }