added tcp and udp listen directives
[radsecproxy.git] / radsecproxy.c
index 0efb1a8..3956da0 100644 (file)
@@ -61,6 +61,8 @@ static int server_udp_count = 0;
 static int server_tls_count = 0;
 static int server_count = 0;
 
+static struct peer *tcp_server_listen;
+static struct peer *udp_server_listen;
 static struct replyq udp_server_replyq;
 static int udp_server_sock = -1;
 static pthread_mutex_t *ssl_locks;
@@ -82,6 +84,14 @@ void ssl_locking_callback(int mode, int type, const char *file, int line) {
        pthread_mutex_unlock(&ssl_locks[type]);
 }
 
+static int pem_passwd_cb(char *buf, int size, int rwflag, void *userdata) {
+    int pwdlen = strlen(userdata);
+    if (rwflag != 0 || pwdlen > size) /* not for decryption or too large */
+       return 0;
+    memcpy(buf, userdata, pwdlen);
+    return pwdlen;
+}
+
 static int verify_cb(int ok, X509_STORE_CTX *ctx) {
   char buf[256];
   X509 *err_cert;
@@ -156,6 +166,10 @@ SSL_CTX *ssl_init() {
     }
 
     ctx = SSL_CTX_new(TLSv1_method());
+    if (options.tlscertificatekeypassword) {
+       SSL_CTX_set_default_passwd_cb_userdata(ctx, options.tlscertificatekeypassword);
+       SSL_CTX_set_default_passwd_cb(ctx, pem_passwd_cb);
+    }
     if (SSL_CTX_use_certificate_chain_file(ctx, options.tlscertificatefile) &&
        SSL_CTX_use_PrivateKey_file(ctx, options.tlscertificatekeyfile, SSL_FILETYPE_PEM) &&
        SSL_CTX_check_private_key(ctx) &&
@@ -178,12 +192,13 @@ void printauth(char *s, unsigned char *t) {
     printf("\n");
 }
 
-int resolvepeer(struct peer *peer) {
+int resolvepeer(struct peer *peer, int ai_flags) {
     struct addrinfo hints, *addrinfo;
     
     memset(&hints, 0, sizeof(hints));
     hints.ai_socktype = (peer->type == 'T' ? SOCK_STREAM : SOCK_DGRAM);
     hints.ai_family = AF_UNSPEC;
+    hints.ai_flags = ai_flags;
     if (getaddrinfo(peer->host, peer->port, &hints, &addrinfo)) {
        err("resolvepeer: can't resolve %s port %s", peer->host, peer->port);
        return 0;
@@ -214,6 +229,25 @@ int connecttoserver(struct addrinfo *addrinfo) {
     return s;
 }        
 
+int bindtoaddr(struct addrinfo *addrinfo) {
+    int s, on = 1;
+    struct addrinfo *res;
+    
+    for (res = addrinfo; res; res = res->ai_next) {
+        s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
+        if (s < 0) {
+            err("bindtoaddr: socket failed");
+            continue;
+        }
+       setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
+       if (!bind(s, res->ai_addr, res->ai_addrlen))
+           return s;
+       err("bindtoaddr: bind failed");
+        close(s);
+    }
+    return -1;
+}        
+
 /* returns the client with matching address, or NULL */
 /* if client argument is not NULL, we only check that one client */
 struct client *find_client(char type, struct sockaddr *addr, struct client *client) {
@@ -1278,13 +1312,14 @@ void *clientwr(void *arg) {
                printf("clientwr: waiting for new request\n");
                pthread_cond_wait(&server->newrq_cond, &server->newrq_mutex);
            }
-           if (server->newrq) {
-               printf("clientwr: got new request\n");
-               server->newrq = 0;
-           }
        }
+       if (server->newrq) {
+           printf("clientwr: got new request\n");
+           server->newrq = 0;
+       } else
+           printf("clientwr: request timer expired, processing request queue\n");
        pthread_mutex_unlock(&server->newrq_mutex);
-       
+
        for (i = 0; i < MAX_REQUESTS; i++) {
            pthread_mutex_lock(&server->newrq_mutex);
            while (!server->requests[i].buf && i < MAX_REQUESTS)
@@ -1365,12 +1400,12 @@ void *udpserverrd(void *arg) {
     struct server *to;
     struct client *fr;
     pthread_t udpserverwrth;
-    
-    if ((udp_server_sock = bindport(SOCK_DGRAM, options.udpserverport)) < 0) {
+
+    if ((udp_server_sock = bindtoaddr(udp_server_listen->addrinfo)) < 0) {
         printf("udpserverrd: socket/bind failed\n");
        exit(1);
     }
-    printf("udpserverrd: listening on UDP port %s\n", options.udpserverport);
+    printf("udpserverrd: listening for UDP on host %s port %s\n", udp_server_listen->host, udp_server_listen->port);
 
     if (pthread_create(&udpserverwrth, NULL, udpserverwr, NULL))
        errx("pthread_create failed");
@@ -1487,18 +1522,20 @@ int tlslistener() {
     size_t fromlen = sizeof(from);
     struct client *client;
 
-    if ((s = bindport(SOCK_STREAM, DEFAULT_TLS_PORT)) < 0) {
+    if ((s = bindtoaddr(tcp_server_listen->addrinfo)) < 0) {
         printf("tlslistener: socket/bind failed\n");
        exit(1);
     }
     
     listen(s, 0);
-    printf("listening for incoming TLS on port %s\n", DEFAULT_TLS_PORT);
+    printf("listening for incoming TCP on address %s port %s\n", tcp_server_listen->host, tcp_server_listen->port);
 
     for (;;) {
        snew = accept(s, (struct sockaddr *)&from, &fromlen);
-       if (snew < 0)
-           errx("accept failed");
+       if (snew < 0) {
+           err("accept failed");
+           continue;
+       }
        printf("incoming TLS connection from %s\n", addr2string((struct sockaddr *)&from, fromlen));
 
        client = find_client('T', (struct sockaddr *)&from, NULL);
@@ -1557,7 +1594,7 @@ char *parsehostport(char *s, struct peer *peer) {
     }
     if (*p == ':') {
            /* port number or service name is specified */;
-           field = p++;
+           field = ++p;
            for (; *p && *p != ' ' && *p != '\t' && *p != '\n'; p++);
            if (field == p) {
                printf("syntax error, : but no following port\n");
@@ -1717,8 +1754,8 @@ void getconfig(const char *serverfile, const char *clientfile) {
            }
        }
 
-       if ((serverfile && !resolvepeer(&server->peer)) ||
-           (clientfile && !resolvepeer(&client->peer))) {
+       if ((serverfile && !resolvepeer(&server->peer, 0)) ||
+           (clientfile && !resolvepeer(&client->peer, 0))) {
            printf("failed to resolve host %s port %s, exiting\n", peer->host, peer->port);
            exit(1);
        }
@@ -1760,6 +1797,33 @@ void getconfig(const char *serverfile, const char *clientfile) {
     fclose(f);
 }
 
+struct peer *server_create(char type) {
+    struct peer *server;
+    char *conf;
+
+    server = malloc(sizeof(struct peer));
+    if (!server)
+       errx("malloc failed");
+    memset(server, 0, sizeof(struct peer));
+    server->type = type;
+    conf = (type == 'T' ? options.listentcp : options.listenudp);
+    if (conf) {
+       parsehostport(conf, server);
+       if (!strcmp(server->host, "*")) {
+           free(server->host);
+           server->host = NULL;
+       }
+    } else if (type == 'T')
+       server->port = stringcopy(DEFAULT_TLS_PORT, 0);
+    else
+       server->port = stringcopy(options.udpserverport ? options.udpserverport : DEFAULT_UDP_PORT, 0);
+    if (!resolvepeer(server, AI_PASSIVE)) {
+       printf("failed to resolve host %s port %s, exiting\n", server->host, server->port);
+       exit(1);
+    }
+    return server;
+}
+               
 void getmainconfig(const char *configfile) {
     FILE *f;
     char line[1024];
@@ -1805,18 +1869,26 @@ void getmainconfig(const char *configfile) {
            options.tlscertificatekeyfile = stringcopy(val, 0);
            continue;
        }
+       if (!strcasecmp(opt, "TLSCertificateKeyPassword")) {
+           options.tlscertificatekeypassword = stringcopy(val, 0);
+           continue;
+       }
        if (!strcasecmp(opt, "UDPServerPort")) {
            options.udpserverport = stringcopy(val, 0);
            continue;
        }
-
+       if (!strcasecmp(opt, "ListenUDP")) {
+           options.listenudp = stringcopy(val, 0);
+           continue;
+       }
+       if (!strcasecmp(opt, "ListenTCP")) {
+           options.listentcp = stringcopy(val, 0);
+           continue;
+       }
        printf("error in %s, unknown option %s\n", configfile, opt);
        exit(1);
     }
     fclose(f);
-
-    if (!options.udpserverport)
-       options.udpserverport = stringcopy(DEFAULT_UDP_PORT, 0);
 }
 
 #if 0
@@ -1854,10 +1926,12 @@ int main(int argc, char **argv) {
     /*    pthread_attr_init(&joinable); */
     /*    pthread_attr_setdetachstate(&joinable, PTHREAD_CREATE_JOINABLE); */
    
-    if (client_udp_count)
+    if (client_udp_count) {
+       udp_server_listen = server_create('U');
        if (pthread_create(&udpserverth, NULL /*&joinable*/, udpserverrd, NULL))
            errx("pthread_create failed");
-
+    }
+    
     if (client_tls_count || server_tls_count)
        ssl_ctx = ssl_init();
     
@@ -1865,9 +1939,11 @@ int main(int argc, char **argv) {
        if (pthread_create(&servers[i].clientth, NULL, clientwr, (void *)&servers[i]))
            errx("pthread_create failed");
 
-    if (client_tls_count)
+    if (client_tls_count) {
+       tcp_server_listen = server_create('T');
        return tlslistener();
-
+    }
+    
     /* just hang around doing nothing, anything to do here? */
     for (;;)
        sleep(1000);