added tcp server support, improved some log messages, fixed bug when trying to free...
authorvenaas <venaas>
Wed, 23 Jul 2008 07:09:07 +0000 (07:09 +0000)
committervenaas <venaas@e88ac4ed-0b26-0410-9574-a7f39faa03bf>
Wed, 23 Jul 2008 07:09:07 +0000 (07:09 +0000)
git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@321 e88ac4ed-0b26-0410-9574-a7f39faa03bf

radsecproxy.c
radsecproxy.h

index 7114345..6caa657 100644 (file)
@@ -81,6 +81,9 @@ int confserver_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
 void freerealm(struct realm *realm);
 void freeclsrvconf(struct clsrvconf *conf);
 void freerqdata(struct request *rq);
+void *udpserverrd(void *arg);
+void *tlslistener(void *arg);
+void *tcplistener(void *arg);
 
 static const struct protodefs protodefs[] = {
     {   "udp", /* UDP, assuming RAD_UDP defined as 0 */
@@ -90,7 +93,8 @@ static const struct protodefs protodefs[] = {
        REQUEST_RETRY_COUNT, /* retrycountdefault */
        10, /* retrycountmax */
        REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
-       60 /* retryintervalmax */
+       60, /* retryintervalmax */
+       udpserverrd
     },
     {   "tls", /* TLS, assuming RAD_TLS defined as 1 */
        "mysecret", /* secretdefault */
@@ -99,7 +103,8 @@ static const struct protodefs protodefs[] = {
        0, /* retrycountdefault */
        0, /* retrycountmax */
        REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
-       60 /* retryintervalmax */
+       60, /* retryintervalmax */
+       tlslistener
     },
     {   "tcp", /* TCP, assuming RAD_TCP defined as 2 */
        NULL, /* secretdefault */
@@ -108,7 +113,8 @@ static const struct protodefs protodefs[] = {
        0, /* retrycountdefault */
        0, /* retrycountmax */
        REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
-       60 /* retryintervalmax */
+       60, /* retryintervalmax */
+       tcplistener
     },
     {   NULL
     }
@@ -550,6 +556,8 @@ void removeclientrqs(struct client *client) {
     
     for (entry = list_first(srvconfs); entry; entry = list_next(entry)) {
        server = ((struct clsrvconf *)entry->data)->servers;
+       if (!server)
+           continue;
        pthread_mutex_lock(&server->newrq_mutex);
        for (i = 0; i < MAX_REQUESTS; i++) {
            rq = server->requests + i;
@@ -1078,6 +1086,73 @@ unsigned char *radtlsget(SSL *ssl, int timeout) {
     return rad;
 }
 
+/* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
+/* returns 0 on timeout, -1 on error and num if ok */
+int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
+    int ndesc, cnt, len;
+    fd_set readfds, writefds;
+    struct timeval timer;
+    
+    if (s < 0)
+       return -1;
+    /* make socket non-blocking? */
+    for (len = 0; len < num; len += cnt) {
+       FD_ZERO(&readfds);
+       FD_SET(s, &readfds);
+       writefds = readfds;
+       if (timeout) {
+           timer.tv_sec = timeout;
+           timer.tv_usec = 0;
+       }
+       ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
+       if (ndesc < 1)
+           return ndesc;
+
+       cnt = read(s, buf + len, num - len);
+       if (cnt <= 0)
+           return -1;
+    }
+    return num;
+}
+
+/* timeout in seconds, 0 means no timeout (blocking) */
+unsigned char *radtcpget(int s, int timeout) {
+    int cnt, len;
+    unsigned char buf[4], *rad;
+
+    for (;;) {
+       cnt = tcpreadtimeout(s, buf, 4, timeout);
+       if (cnt < 1) {
+           debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
+           return NULL;
+       }
+
+       len = RADLEN(buf);
+       rad = malloc(len);
+       if (!rad) {
+           debug(DBG_ERR, "radtcpget: malloc failed");
+           continue;
+       }
+       memcpy(rad, buf, 4);
+       
+       cnt = tcpreadtimeout(s, rad + 4, len - 4, timeout);
+       if (cnt < 1) {
+           debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
+           free(rad);
+           return NULL;
+       }
+       
+       if (len >= 20)
+           break;
+       
+       free(rad);
+       debug(DBG_WARN, "radtcpget: packet smaller than minimum radius size");
+    }
+    
+    debug(DBG_DBG, "radtcpget: got %d bytes", len);
+    return rad;
+}
+
 int clientradputudp(struct server *server, unsigned char *rad) {
     size_t len;
     struct sockaddr_storage sa;
@@ -2616,15 +2691,15 @@ void *tlsserverwr(void *arg) {
     struct replyq *replyq;
     struct reply *reply;
     
-    debug(DBG_DBG, "tlsserverwr starting for %s", client->conf->host);
+    debug(DBG_DBG, "tlsserverwr: starting for %s", client->conf->host);
     replyq = client->replyq;
     for (;;) {
        pthread_mutex_lock(&replyq->mutex);
        while (!list_first(replyq->replies)) {
            if (client->ssl) {      
-               debug(DBG_DBG, "tls server writer, waiting for signal");
+               debug(DBG_DBG, "tlsserverwr: waiting for signal");
                pthread_cond_wait(&replyq->cond, &replyq->mutex);
-               debug(DBG_DBG, "tls server writer, got signal");
+               debug(DBG_DBG, "tlsserverwr: got signal");
            }
            if (!client->ssl) {
                /* ssl might have changed while waiting */
@@ -2637,7 +2712,7 @@ void *tlsserverwr(void *arg) {
        pthread_mutex_unlock(&replyq->mutex);
        cnt = SSL_write(client->ssl, reply->buf, RADLEN(reply->buf));
        if (cnt > 0)
-           debug(DBG_DBG, "tlsserverwr: Sent %d bytes, Radius packet of length %d",
+           debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d",
                  cnt, RADLEN(reply->buf));
        else
            while ((error = ERR_get_error()))
@@ -2651,7 +2726,7 @@ void tlsserverrd(struct client *client) {
     struct request rq;
     pthread_t tlsserverwrth;
     
-    debug(DBG_DBG, "tlsserverrd starting for %s", client->conf->host);
+    debug(DBG_DBG, "tlsserverrd: starting for %s", client->conf->host);
     
     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
        debug(DBG_ERR, "tlsserverrd: pthread_create failed");
@@ -2677,7 +2752,7 @@ void tlsserverrd(struct client *client) {
     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
     pthread_join(tlsserverwrth, NULL);
     removeclientrqs(client);
-    debug(DBG_DBG, "tlsserverrd for %s exiting", client->conf->host);
+    debug(DBG_DBG, "tlsserverrd: reader for %s exiting", client->conf->host);
 }
 
 void *tlsservernew(void *arg) {
@@ -2693,10 +2768,10 @@ void *tlsservernew(void *arg) {
 
     s = *(int *)arg;
     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
-       debug(DBG_DBG, "tlsserverrd: getpeername failed, exiting");
+       debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
        goto exit;
     }
-    debug(DBG_WARN, "incoming TLS connection from %s", addr2string((struct sockaddr *)&from, fromlen));
+    debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from, fromlen));
 
     conf = find_conf(RAD_TLS, (struct sockaddr *)&from, clconfs, &cur);
     if (conf) {
@@ -2705,8 +2780,8 @@ void *tlsservernew(void *arg) {
 
        if (SSL_accept(ssl) <= 0) {
            while ((error = ERR_get_error()))
-               debug(DBG_ERR, "tlsserverrd: SSL: %s", ERR_error_string(error, NULL));
-           debug(DBG_ERR, "SSL_accept failed");
+               debug(DBG_ERR, "tlsservernew: SSL: %s", ERR_error_string(error, NULL));
+           debug(DBG_ERR, "tlsservernew: SSL_accept failed");
            goto exit;
        }
        cert = verifytlscert(ssl);
@@ -2723,12 +2798,12 @@ void *tlsservernew(void *arg) {
                tlsserverrd(client);
                removeclient(client);
            } else
-               debug(DBG_WARN, "Failed to create new client instance");
+               debug(DBG_WARN, "tlsservernew: failed to create new client instance");
            goto exit;
        }
        conf = find_conf(RAD_TLS, (struct sockaddr *)&from, clconfs, &cur);
     }
-    debug(DBG_WARN, "ignoring request, no matching TLS client");
+    debug(DBG_WARN, "tlsservernew: ignoring request, no matching TLS client");
     if (cert)
        X509_free(cert);
 
@@ -2765,6 +2840,133 @@ void *tlslistener(void *arg) {
     return NULL;
 }
 
+void *tcpserverwr(void *arg) {
+    int cnt;
+    struct client *client = (struct client *)arg;
+    struct replyq *replyq;
+    struct reply *reply;
+    
+    debug(DBG_DBG, "tcpserverwr: starting for %s", client->conf->host);
+    replyq = client->replyq;
+    for (;;) {
+       pthread_mutex_lock(&replyq->mutex);
+       while (!list_first(replyq->replies)) {
+           if (client->s >= 0) {           
+               debug(DBG_DBG, "tcpserverwr: waiting for signal");
+               pthread_cond_wait(&replyq->cond, &replyq->mutex);
+               debug(DBG_DBG, "tcpserverwr: got signal");
+           }
+           if (client->s < 0) {
+               /* s might have changed while waiting */
+               pthread_mutex_unlock(&replyq->mutex);
+               debug(DBG_DBG, "tcpserverwr: exiting as requested");
+               pthread_exit(NULL);
+           }
+       }
+       reply = (struct reply *)list_shift(replyq->replies);
+       pthread_mutex_unlock(&replyq->mutex);
+       cnt = write(client->s, reply->buf, RADLEN(reply->buf));
+       if (cnt > 0)
+           debug(DBG_DBG, "tcpserverwr: sent %d bytes, Radius packet of length %d",
+                 cnt, RADLEN(reply->buf));
+       else
+           debug(DBG_ERR, "tcpserverwr: write error for %s", client->conf->host);
+       free(reply->buf);
+       free(reply);
+    }
+}
+
+void tcpserverrd(struct client *client) {
+    struct request rq;
+    pthread_t tcpserverwrth;
+    
+    debug(DBG_DBG, "tcpserverrd: starting for %s", client->conf->host);
+    
+    if (pthread_create(&tcpserverwrth, NULL, tcpserverwr, (void *)client)) {
+       debug(DBG_ERR, "tcpserverrd: pthread_create failed");
+       return;
+    }
+
+    for (;;) {
+       memset(&rq, 0, sizeof(struct request));
+       rq.buf = radtcpget(client->s, 0);
+       if (!rq.buf)
+           break;
+       debug(DBG_DBG, "tcpserverrd: got Radius message from %s", client->conf->host);
+       rq.from = client;
+       radsrv(&rq);
+    }
+    
+    debug(DBG_ERR, "tcpserverrd: connection lost");
+    /* stop writer by setting s to -1 and give signal in case waiting for data */
+    client->s = -1;
+    pthread_mutex_lock(&client->replyq->mutex);
+    pthread_cond_signal(&client->replyq->cond);
+    pthread_mutex_unlock(&client->replyq->mutex);
+    debug(DBG_DBG, "tcpserverrd: waiting for writer to end");
+    pthread_join(tcpserverwrth, NULL);
+    removeclientrqs(client);
+    debug(DBG_DBG, "tcpserverrd: reader for %s exiting", client->conf->host);
+}
+
+void *tcpservernew(void *arg) {
+    int s;
+    struct sockaddr_storage from;
+    size_t fromlen = sizeof(from);
+    struct clsrvconf *conf;
+    struct client *client;
+
+    s = *(int *)arg;
+    if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
+       debug(DBG_DBG, "tcpservernew: getpeername failed, exiting");
+       goto exit;
+    }
+    debug(DBG_WARN, "tcpservernew: incoming TCP connection from %s", addr2string((struct sockaddr *)&from, fromlen));
+
+    conf = find_conf(RAD_TCP, (struct sockaddr *)&from, clconfs, NULL);
+    if (conf) {
+       client = addclient(conf);
+       if (client) {
+           client->s = s;
+           tcpserverrd(client);
+           removeclient(client);
+       } else
+           debug(DBG_WARN, "tcpservernew: failed to create new client instance");
+    } else
+       debug(DBG_WARN, "tcpservernew: ignoring request, no matching TCP client");
+
+ exit:
+    shutdown(s, SHUT_RDWR);
+    close(s);
+    pthread_exit(NULL);
+}
+
+void *tcplistener(void *arg) {
+    pthread_t tcpserverth;
+    int s;
+    struct sockaddr_storage from;
+    size_t fromlen = sizeof(from);
+    struct listenerarg *larg = (struct listenerarg *)arg;
+
+    listen(larg->s, 0);
+
+    for (;;) {
+       s = accept(larg->s, (struct sockaddr *)&from, &fromlen);
+       if (s < 0) {
+           debug(DBG_WARN, "accept failed");
+           continue;
+       }
+       if (pthread_create(&tcpserverth, NULL, tcpservernew, (void *)&s)) {
+           debug(DBG_ERR, "tcplistener: pthread_create failed");
+           shutdown(s, SHUT_RDWR);
+           close(s);
+           continue;
+       }
+       pthread_detach(tcpserverth);
+    }
+    return NULL;
+}
+
 void createlistener(uint8_t type, char *arg, uint8_t acconly) {
     pthread_t th;
     struct clsrvconf *listenres;
@@ -2799,15 +3001,15 @@ void createlistener(uint8_t type, char *arg, uint8_t acconly) {
             debugx(1, DBG_ERR, "malloc failed");
         larg->s = s;
         larg->acconly = acconly;
-       if (pthread_create(&th, NULL, type == RAD_UDP ? udpserverrd : tlslistener, (void *)larg))
+       if (pthread_create(&th, NULL, protodefs[type].listener, (void *)larg))
             debugx(1, DBG_ERR, "pthread_create failed");
        pthread_detach(th);
     }
     if (!larg)
        debugx(1, DBG_ERR, "createlistener: socket/bind failed");
     
-    debug(DBG_WARN, "createlistener: listening for %s%s on %s:%s",
-         type == RAD_UDP ? "UDP" : "TCP", acconly ? " accounting" : "",
+    debug(DBG_WARN, "createlistener: listening for %s%s on %s:%s", protodefs[type].name,
+         acconly ? " accounting" : "",
          listenres->host ? listenres->host : "*", listenres->port);
     freeclsrvres(listenres);
 }
@@ -3953,6 +4155,9 @@ int main(int argc, char **argv) {
        if (pthread_create(&udpclient6rdth, NULL, udpclientrd, (void *)&udp_client6_sock))
            debugx(1, DBG_ERR, "pthread_create failed");
     
+    if (find_conf_type(RAD_TCP, clconfs, NULL))
+       createlisteners(RAD_TCP, options.listentcp, 0);
+    
     if (find_conf_type(RAD_TLS, clconfs, NULL))
        createlisteners(RAD_TLS, options.listentls, 0);
     
index 2048354..171c3fd 100644 (file)
@@ -115,6 +115,7 @@ struct clsrvconf {
 
 struct client {
     struct clsrvconf *conf;
+    int s; /* for tcp */
     SSL *ssl;
     struct replyq *replyq;
 };
@@ -173,6 +174,7 @@ struct protodefs {
     uint8_t retrycountmax;
     uint8_t retryintervaldefault;
     uint8_t retryintervalmax;
+    void *(*listener)(void*);
 };
 
 #define RADLEN(x) ntohs(((uint16_t *)(x))[1])