added tcp client support
authorvenaas <venaas>
Wed, 23 Jul 2008 09:14:16 +0000 (09:14 +0000)
committervenaas <venaas@e88ac4ed-0b26-0410-9574-a7f39faa03bf>
Wed, 23 Jul 2008 09:14:16 +0000 (09:14 +0000)
git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@322 e88ac4ed-0b26-0410-9574-a7f39faa03bf

radsecproxy.c
radsecproxy.h

index 6caa657..f31e107 100644 (file)
  *          1 + (2 + 2 * 3) + (2 * 30) + (2 * 30) = 129 threads
 */
 
+/* Bugs:
+ * We are not removing client requests from dynamic servers, see removeclientrqs()
+ */
+
 #include <signal.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
@@ -63,8 +67,7 @@
 static struct options options;
 struct list *clconfs, *srvconfs, *realms, *tlsconfs, *rewriteconfs;
 
-static struct addrinfo *srcudpres = NULL;
-static struct addrinfo *srctcpres = NULL;
+static struct addrinfo *srcprotores[3] = { NULL, NULL, NULL };
 
 static struct replyq *udp_server_replyq = NULL;
 static int udp_client4_sock = -1;
@@ -84,7 +87,15 @@ void freerqdata(struct request *rq);
 void *udpserverrd(void *arg);
 void *tlslistener(void *arg);
 void *tcplistener(void *arg);
-
+int tlsconnect(struct server *server, struct timeval *when, int timeout, char *text);
+int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text);
+void *udpclientrd(void *arg);
+void *tlsclientrd(void *arg);
+void *tcpclientrd(void *arg);
+int clientradputudp(struct server *server, unsigned char *rad);
+int clientradputtls(struct server *server, unsigned char *rad);
+int clientradputtcp(struct server *server, unsigned char *rad);
+    
 static const struct protodefs protodefs[] = {
     {   "udp", /* UDP, assuming RAD_UDP defined as 0 */
        NULL, /* secretdefault */
@@ -94,7 +105,11 @@ static const struct protodefs protodefs[] = {
        10, /* retrycountmax */
        REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
        60, /* retryintervalmax */
-       udpserverrd
+       udpserverrd, /* listener */
+       &options.sourceudp, /* srcaddrport */
+       NULL, /* connecter */
+       udpclientrd, /* clientreader */
+       clientradputudp /* clientradput */
     },
     {   "tls", /* TLS, assuming RAD_TLS defined as 1 */
        "mysecret", /* secretdefault */
@@ -104,7 +119,11 @@ static const struct protodefs protodefs[] = {
        0, /* retrycountmax */
        REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
        60, /* retryintervalmax */
-       tlslistener
+       tlslistener, /* listener */
+       &options.sourcetls, /* srcaddrport */
+       tlsconnect, /* connecter */
+       tlsclientrd, /* clientreader */
+       clientradputtls /* clientradput */
     },
     {   "tcp", /* TCP, assuming RAD_TCP defined as 2 */
        NULL, /* secretdefault */
@@ -114,7 +133,11 @@ static const struct protodefs protodefs[] = {
        0, /* retrycountmax */
        REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
        60, /* retryintervalmax */
-       tcplistener
+       tcplistener, /* listener */
+       &options.sourcetcp, /* srcaddrport */
+       tcpconnect, /* connecter */
+       tcpclientrd, /* clientreader */
+       clientradputtcp /* clientradput */
     },
     {   NULL
     }
@@ -363,13 +386,13 @@ void freeclsrvres(struct clsrvconf *res) {
     free(res);
 }
 
-int connecttcp(struct addrinfo *addrinfo) {
+int connecttcp(struct addrinfo *addrinfo, struct addrinfo *src) {
     int s;
     struct addrinfo *res;
 
     s = -1;
     for (res = addrinfo; res; res = res->ai_next) {
-       s = bindtoaddr(srctcpres, res->ai_family, 1, 1);
+       s = bindtoaddr(src, res->ai_family, 1, 1);
         if (s < 0) {
             debug(DBG_WARN, "connecttoserver: socket failed");
             continue;
@@ -604,17 +627,18 @@ int addserver(struct clsrvconf *conf) {
     memset(conf->servers, 0, sizeof(struct server));
     conf->servers->conf = conf;
 
+    if (!srcprotores[conf->type]) {
+       res = resolve_hostport(conf->type, *conf->pdef->srcaddrport, NULL);
+       srcprotores[conf->type] = res->addrinfo;
+       res->addrinfo = NULL;
+       freeclsrvres(res);
+    }
+
     if (conf->type == RAD_UDP) {
-       if (!srcudpres) {
-           res = resolve_hostport(RAD_UDP, options.sourceudp, NULL);
-           srcudpres = res->addrinfo;
-           res->addrinfo = NULL;
-           freeclsrvres(res);
-       }
        switch (conf->addrinfo->ai_family) {
        case AF_INET:
            if (udp_client4_sock < 0) {
-               udp_client4_sock = bindtoaddr(srcudpres, AF_INET, 0, 1);
+               udp_client4_sock = bindtoaddr(srcprotores[RAD_UDP], AF_INET, 0, 1);
                if (udp_client4_sock < 0)
                    debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->host);
            }
@@ -622,7 +646,7 @@ int addserver(struct clsrvconf *conf) {
            break;
        case AF_INET6:
            if (udp_client6_sock < 0) {
-               udp_client6_sock = bindtoaddr(srcudpres, AF_INET6, 0, 1);
+               udp_client6_sock = bindtoaddr(srcprotores[RAD_UDP], AF_INET6, 0, 1);
                if (udp_client6_sock < 0)
                    debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->host);
            }
@@ -632,15 +656,8 @@ int addserver(struct clsrvconf *conf) {
            debugx(1, DBG_ERR, "addserver: unsupported address family");
        }
        
-    } else {
-       if (!srctcpres) {
-           res = resolve_hostport(RAD_TLS, options.sourcetcp, NULL);
-           srctcpres = res->addrinfo;
-           res->addrinfo = NULL;
-           freeclsrvres(res);
-       }
+    } else
        conf->servers->sock = -1;
-    }
     
     conf->servers->requests = calloc(MAX_REQUESTS, sizeof(struct request));
     if (!conf->servers->requests) {
@@ -943,7 +960,7 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
     time_t elapsed;
     X509 *cert;
     
-    debug(DBG_DBG, "tlsconnect called from %s", text);
+    debug(DBG_DBG, "tlsconnect: called from %s", text);
     pthread_mutex_lock(&server->lock);
     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
        /* already reconnected, nothing to do */
@@ -952,8 +969,6 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
        return 1;
     }
 
-    debug(DBG_DBG, "tlsconnect %s", text);
-
     for (;;) {
        gettimeofday(&now, NULL);
        elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
@@ -982,7 +997,7 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
        debug(DBG_WARN, "tlsconnect: trying to open TLS connection to %s port %s", server->conf->host, server->conf->port);
        if (server->sock >= 0)
            close(server->sock);
-       if ((server->sock = connecttcp(server->conf->addrinfo)) < 0) {
+       if ((server->sock = connecttcp(server->conf->addrinfo, srcprotores[RAD_TLS])) < 0) {
            debug(DBG_ERR, "tlsconnect: connecttcp failed");
            continue;
        }
@@ -1007,6 +1022,55 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
     return 1;
 }
 
+int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text) {
+    struct timeval now;
+    time_t elapsed;
+    
+    debug(DBG_DBG, "tcpconnect: called from %s", text);
+    pthread_mutex_lock(&server->lock);
+    if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
+       /* already reconnected, nothing to do */
+       debug(DBG_DBG, "tcpconnect(%s): seems already reconnected", text);
+       pthread_mutex_unlock(&server->lock);
+       return 1;
+    }
+
+    for (;;) {
+       gettimeofday(&now, NULL);
+       elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
+       if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
+           debug(DBG_DBG, "tcpconnect: timeout");
+           if (server->sock >= 0)
+               close(server->sock);
+           pthread_mutex_unlock(&server->lock);
+           return 0;
+       }
+       if (server->connectionok) {
+           server->connectionok = 0;
+           sleep(2);
+       } else if (elapsed < 1)
+           sleep(2);
+       else if (elapsed < 60) {
+           debug(DBG_INFO, "tcpconnect: sleeping %lds", elapsed);
+           sleep(elapsed);
+       } else if (elapsed < 100000) {
+           debug(DBG_INFO, "tcpconnect: sleeping %ds", 60);
+           sleep(60);
+       } else
+           server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
+       debug(DBG_WARN, "tcpconnect: trying to open TCP connection to %s port %s", server->conf->host, server->conf->port);
+       if (server->sock >= 0)
+           close(server->sock);
+       if ((server->sock = connecttcp(server->conf->addrinfo, srcprotores[RAD_TCP])) >= 0)
+           break;
+       debug(DBG_ERR, "tcpconnect: connecttcp failed");
+    }
+    debug(DBG_WARN, "tcpconnect: TCP connection to %s port %s up", server->conf->host, server->conf->port);
+    gettimeofday(&server->lastconnecttry, NULL);
+    pthread_mutex_unlock(&server->lock);
+    return 1;
+}
+
 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
 /* returns 0 on timeout, -1 on error and num if ok */
 int sslreadtimeout(SSL *ssl, unsigned char *buf, int num, int timeout) {
@@ -1214,14 +1278,23 @@ int clientradputtls(struct server *server, unsigned char *rad) {
     return 1;
 }
 
-int clientradput(struct server *server, unsigned char *rad) {
-    switch (server->conf->type) {
-    case RAD_UDP:
-       return clientradputudp(server, rad);
-    case RAD_TLS:
-       return clientradputtls(server, rad);
+int clientradputtcp(struct server *server, unsigned char *rad) {
+    int cnt;
+    size_t len;
+    struct timeval lastconnecttry;
+    struct clsrvconf *conf = server->conf;
+    
+    len = RADLEN(rad);
+    lastconnecttry = server->lastconnecttry;
+    while ((cnt = write(server->sock, rad, len)) <= 0) {
+       debug(DBG_ERR, "clientradputtcp: write error");
+       tcpconnect(server, &lastconnecttry, 0, "clientradputtcp");
+       lastconnecttry = server->lastconnecttry;
     }
-    return 0;
+
+    server->connectionok = 1;
+    debug(DBG_DBG, "clientradputtcp: Sent %d bytes, Radius packet of length %d to TCP peer %s", cnt, len, conf->host);
+    return 1;
 }
 
 int radsign(unsigned char *rad, unsigned char *sec) {
@@ -2454,7 +2527,7 @@ void *tlsclientrd(void *arg) {
        if (!buf) {
            if (server->dynamiclookuparg)
                break;
-           tlsconnect(server, &lastconnecttry, 0, "clientrd");
+           tlsconnect(server, &lastconnecttry, 0, "tlsclientrd");
            continue;
        }
 
@@ -2463,7 +2536,7 @@ void *tlsclientrd(void *arg) {
        if (server->dynamiclookuparg) {
            gettimeofday(&now, NULL);
            if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
-               debug(DBG_INFO, "clientrd: idle timeout for %s", server->conf->name);
+               debug(DBG_INFO, "tlsclientrd: idle timeout for %s", server->conf->name);
                break;
            }
        }
@@ -2472,11 +2545,32 @@ void *tlsclientrd(void *arg) {
     return NULL;
 }
 
+void *tcpclientrd(void *arg) {
+    struct server *server = (struct server *)arg;
+    unsigned char *buf;
+    struct timeval lastconnecttry;
+    
+    for (;;) {
+       /* yes, lastconnecttry is really necessary */
+       lastconnecttry = server->lastconnecttry;
+       buf = radtcpget(server->sock, 0);
+       if (!buf) {
+           tcpconnect(server, &lastconnecttry, 0, "tcpclientrd");
+           continue;
+       }
+
+       if (!replyh(server, buf))
+           free(buf);
+    }
+    server->clientrdgone = 1;
+    return NULL;
+}
+
 /* code for removing state not finished */
 void *clientwr(void *arg) {
     struct server *server = (struct server *)arg;
     struct request *rq;
-    pthread_t tlsclientrdth;
+    pthread_t clientrdth;
     int i, dynconffail = 0;
     uint8_t rnd;
     struct timeval now, lastsend;
@@ -2508,19 +2602,18 @@ void *clientwr(void *arg) {
        statsrvbuf[21] = 18;
        gettimeofday(&lastsend, NULL);
     }
-    
-    if (conf->type == RAD_UDP) {
-       server->connectionok = 1;
-    } else {
-       if (!tlsconnect(server, NULL, server->dynamiclookuparg ? 6 : 0, "clientwr"))
+
+    if (conf->pdef->connecter) {
+       if (!conf->pdef->connecter(server, NULL, server->dynamiclookuparg ? 6 : 0, "clientwr"))
            goto errexit;
        server->connectionok = 1;
-       if (pthread_create(&tlsclientrdth, NULL, tlsclientrd, (void *)server)) {
+       if (pthread_create(&clientrdth, NULL, conf->pdef->clientreader, (void *)server)) {
            debug(DBG_ERR, "clientwr: pthread_create failed");
            goto errexit;
        }
-    }
-
+    } else
+       server->connectionok = 1;
+    
     for (;;) {
        pthread_mutex_lock(&server->newrq_mutex);
        if (!server->newrq) {
@@ -2554,7 +2647,7 @@ void *clientwr(void *arg) {
 
        for (i = 0; i < MAX_REQUESTS; i++) {
            if (server->clientrdgone) {
-               pthread_join(tlsclientrdth, NULL);
+               pthread_join(clientrdth, NULL);
                goto errexit;
            }
            pthread_mutex_lock(&server->newrq_mutex);
@@ -2611,7 +2704,7 @@ void *clientwr(void *arg) {
            if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
                timeout.tv_sec = rq->expiry.tv_sec;
            rq->tries++;
-           clientradput(server, server->requests[i].buf);
+           conf->pdef->clientradput(server, server->requests[i].buf);
            gettimeofday(&lastsend, NULL);
        }
        if (conf->statusserver) {
@@ -3991,6 +4084,7 @@ void getmainconfig(const char *configfile) {
                          "ListenAccountingUDP", CONF_MSTR, &options.listenaccudp,
                          "SourceUDP", CONF_STR, &options.sourceudp,
                          "SourceTCP", CONF_STR, &options.sourcetcp,
+                         "SourceTLS", CONF_STR, &options.sourcetls,
                          "LogLevel", CONF_LINT, &loglevel,
                          "LogDestination", CONF_STR, &options.logdestination,
                          "LoopPrevention", CONF_BLN, &options.loopprevention,
@@ -4143,16 +4237,16 @@ int main(int argc, char **argv) {
                           (void *)(srvconf->servers)))
            debugx(1, DBG_ERR, "pthread_create failed");
     }
-    /* srcudpres no longer needed, while srctcpres is needed later */
-    if (srcudpres) {
-       freeaddrinfo(srcudpres);
-       srcudpres = NULL;
+    /* srcprotores for UDP no longer needed */
+    if (srcprotores[RAD_UDP]) {
+       freeaddrinfo(srcprotores[RAD_UDP]);
+       srcprotores[RAD_UDP] = NULL;
     }
     if (udp_client4_sock >= 0)
-       if (pthread_create(&udpclient4rdth, NULL, udpclientrd, (void *)&udp_client4_sock))
+       if (pthread_create(&udpclient4rdth, NULL, protodefs[RAD_UDP].clientreader, (void *)&udp_client4_sock))
            debugx(1, DBG_ERR, "pthread_create failed");
     if (udp_client6_sock >= 0)
-       if (pthread_create(&udpclient6rdth, NULL, udpclientrd, (void *)&udp_client6_sock))
+       if (pthread_create(&udpclient6rdth, NULL, protodefs[RAD_UDP].clientreader, (void *)&udp_client6_sock))
            debugx(1, DBG_ERR, "pthread_create failed");
     
     if (find_conf_type(RAD_TCP, clconfs, NULL))
index 171c3fd..2736733 100644 (file)
@@ -48,6 +48,7 @@ struct options {
     char **listenaccudp;
     char *sourceudp;
     char *sourcetcp;
+    char *sourcetls;
     char *logdestination;
     uint8_t loglevel;
     uint8_t loopprevention;
@@ -175,6 +176,10 @@ struct protodefs {
     uint8_t retryintervaldefault;
     uint8_t retryintervalmax;
     void *(*listener)(void*);
+    char **srcaddrport;
+    int (*connecter)(struct server *, struct timeval *, int, char *);
+    void *(*clientreader)(void*);
+    int (*clientradput)(struct server *, unsigned char *);
 };
 
 #define RADLEN(x) ntohs(((uint16_t *)(x))[1])