some code improvemetns, more efficiently removing outstanding requests when removing...
[radsecproxy.git] / tls.c
diff --git a/tls.c b/tls.c
index ecbbbec..5014b46 100644 (file)
--- a/tls.c
+++ b/tls.c
@@ -36,6 +36,7 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
     struct timeval now;
     time_t elapsed;
     X509 *cert;
+    SSL_CTX *ctx = NULL;
     unsigned long error;
     
     debug(DBG_DBG, "tlsconnect: called from %s", text);
@@ -81,7 +82,14 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
        }
        
        SSL_free(server->ssl);
-       server->ssl = SSL_new(server->conf->ssl_ctx);
+       server->ssl = NULL;
+       ctx = tlsgetctx(RAD_TLS, server->conf->tlsconf);
+       if (!ctx)
+           continue;
+       server->ssl = SSL_new(ctx);
+       if (!server->ssl)
+           continue;
+
        SSL_set_fd(server->ssl, server->sock);
        if (SSL_connect(server->ssl) <= 0) {
            while ((error = ERR_get_error()))
@@ -98,6 +106,7 @@ int tlsconnect(struct server *server, struct timeval *when, int timeout, char *t
        X509_free(cert);
     }
     debug(DBG_WARN, "tlsconnect: TLS connection to %s port %s up", server->conf->host, server->conf->port);
+    server->connectionok = 1;
     gettimeofday(&server->lastconnecttry, NULL);
     pthread_mutex_unlock(&server->lock);
     return 1;
@@ -186,21 +195,17 @@ int clientradputtls(struct server *server, unsigned char *rad) {
     int cnt;
     size_t len;
     unsigned long error;
-    struct timeval lastconnecttry;
     struct clsrvconf *conf = server->conf;
-    
+
+    if (!server->connectionok)
+       return 0;
     len = RADLEN(rad);
-    lastconnecttry = server->lastconnecttry;
-    while ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
+    if ((cnt = SSL_write(server->ssl, rad, len)) <= 0) {
        while ((error = ERR_get_error()))
            debug(DBG_ERR, "clientradputtls: TLS: %s", ERR_error_string(error, NULL));
-       if (server->dynamiclookuparg)
-           return 0;
-       tlsconnect(server, &lastconnecttry, 0, "clientradputtls");
-       lastconnecttry = server->lastconnecttry;
+       return 0;
     }
 
-    server->connectionok = 1;
     debug(DBG_DBG, "clientradputtls: Sent %d bytes, Radius packet of length %d to TLS peer %s", cnt, len, conf->host);
     return 1;
 }
@@ -221,8 +226,8 @@ void *tlsclientrd(void *arg) {
            continue;
        }
 
-       if (!replyh(server, buf))
-           free(buf);
+       replyh(server, buf);
+
        if (server->dynamiclookuparg) {
            gettimeofday(&now, NULL);
            if (now.tv_sec - server->lastreply.tv_sec > IDLE_TIMEOUT) {
@@ -241,9 +246,9 @@ void *tlsserverwr(void *arg) {
     unsigned long error;
     struct client *client = (struct client *)arg;
     struct queue *replyq;
-    struct reply *reply;
+    struct request *reply;
     
-    debug(DBG_DBG, "tlsserverwr: starting for %s", client->conf->host);
+    debug(DBG_DBG, "tlsserverwr: starting for %s", addr2string(client->addr));
     replyq = client->replyq;
     for (;;) {
        pthread_mutex_lock(&replyq->mutex);
@@ -261,25 +266,25 @@ void *tlsserverwr(void *arg) {
                pthread_exit(NULL);
            }
        }
-       reply = (struct reply *)list_shift(replyq->entries);
+       reply = (struct request *)list_shift(replyq->entries);
        pthread_mutex_unlock(&replyq->mutex);
-       cnt = SSL_write(client->ssl, reply->buf, RADLEN(reply->buf));
+       cnt = SSL_write(client->ssl, reply->replybuf, RADLEN(reply->replybuf));
        if (cnt > 0)
-           debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d",
-                 cnt, RADLEN(reply->buf));
+           debug(DBG_DBG, "tlsserverwr: sent %d bytes, Radius packet of length %d to %s",
+                 cnt, RADLEN(reply->replybuf), addr2string(client->addr));
        else
            while ((error = ERR_get_error()))
                debug(DBG_ERR, "tlsserverwr: SSL: %s", ERR_error_string(error, NULL));
-       free(reply->buf);
-       free(reply);
+       freerq(reply);
     }
 }
 
 void tlsserverrd(struct client *client) {
-    struct request rq;
+    struct request *rq;
+    uint8_t *buf;
     pthread_t tlsserverwrth;
     
-    debug(DBG_DBG, "tlsserverrd: starting for %s", client->conf->host);
+    debug(DBG_DBG, "tlsserverrd: starting for %s", addr2string(client->addr));
     
     if (pthread_create(&tlsserverwrth, NULL, tlsserverwr, (void *)client)) {
        debug(DBG_ERR, "tlsserverrd: pthread_create failed");
@@ -287,16 +292,21 @@ void tlsserverrd(struct client *client) {
     }
 
     for (;;) {
-       memset(&rq, 0, sizeof(struct request));
-       rq.buf = radtlsget(client->ssl, 0);
-       if (!rq.buf) {
-           debug(DBG_ERR, "tlsserverrd: connection from %s lost", client->conf->host);
+       buf = radtlsget(client->ssl, 0);
+       if (!buf) {
+           debug(DBG_ERR, "tlsserverrd: connection from %s lost", addr2string(client->addr));
            break;
        }
-       debug(DBG_DBG, "tlsserverrd: got Radius message from %s", client->conf->host);
-       rq.from = client;
-       if (!radsrv(&rq)) {
-           debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", client->conf->host);
+       debug(DBG_DBG, "tlsserverrd: got Radius message from %s", addr2string(client->addr));
+       rq = newrequest();
+       if (!rq) {
+           free(buf);
+           continue;
+       }
+       rq->buf = buf;
+       rq->from = client;
+       if (!radsrv(rq)) {
+           debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
            break;
        }
     }
@@ -308,8 +318,7 @@ void tlsserverrd(struct client *client) {
     pthread_mutex_unlock(&client->replyq->mutex);
     debug(DBG_DBG, "tlsserverrd: waiting for writer to end");
     pthread_join(tlsserverwrth, NULL);
-    removeclientrqs(client);
-    debug(DBG_DBG, "tlsserverrd: reader for %s exiting", client->conf->host);
+    debug(DBG_DBG, "tlsserverrd: reader for %s exiting", addr2string(client->addr));
 }
 
 void *tlsservernew(void *arg) {
@@ -320,6 +329,7 @@ void *tlsservernew(void *arg) {
     struct list_node *cur = NULL;
     SSL *ssl = NULL;
     X509 *cert = NULL;
+    SSL_CTX *ctx = NULL;
     unsigned long error;
     struct client *client;
 
@@ -328,11 +338,16 @@ void *tlsservernew(void *arg) {
        debug(DBG_DBG, "tlsservernew: getpeername failed, exiting");
        goto exit;
     }
-    debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from, fromlen));
+    debug(DBG_WARN, "tlsservernew: incoming TLS connection from %s", addr2string((struct sockaddr *)&from));
 
     conf = find_clconf(RAD_TLS, (struct sockaddr *)&from, &cur);
     if (conf) {
-       ssl = SSL_new(conf->ssl_ctx);
+       ctx = tlsgetctx(RAD_TLS, conf->tlsconf);
+       if (!ctx)
+           goto exit;
+       ssl = SSL_new(ctx);
+       if (!ssl)
+           goto exit;
        SSL_set_fd(ssl, s);
 
        if (SSL_accept(ssl) <= 0) {
@@ -349,9 +364,10 @@ void *tlsservernew(void *arg) {
     while (conf) {
        if (verifyconfcert(cert, conf)) {
            X509_free(cert);
-           client = addclient(conf);
+           client = addclient(conf, 1);
            if (client) {
                client->ssl = ssl;
+               client->addr = addr_copy((struct sockaddr *)&from);
                tlsserverrd(client);
                removeclient(client);
            } else
@@ -365,7 +381,10 @@ void *tlsservernew(void *arg) {
        X509_free(cert);
 
  exit:
-    SSL_free(ssl);
+    if (ssl) {
+       SSL_shutdown(ssl);
+       SSL_free(ssl);
+    }
     ERR_remove_state(0);
     shutdown(s, SHUT_RDWR);
     close(s);