some code improvemetns, more efficiently removing outstanding requests when removing...
[radsecproxy.git] / radsecproxy.c
index 80f6a47..280cb41 100644 (file)
 */
 
 /* Bugs:
- * Multiple outgoing connections if not enough IDs? (multiple servers per conf?)
- * Also useful for TCP accounting which is not yet supported?
- * We are not removing client requests from dynamic servers, see removeclientrqs()
- * Reserve ID 0 for statusserver requests?
+ * May segfault when dtls connections go down? More testing needed
  * Need to remove UDP clients when no activity for a while...
  * Remove expired stuff from clients request list?
+ * We are not removing client requests from dynamic servers, see removeclientrqs()
+ * Multiple outgoing connections if not enough IDs? (multiple servers per conf?)
+ * Useful for TCP accounting? Now we require separate server config for alt port
  */
 
 #include <signal.h>
@@ -390,52 +390,6 @@ void freeclsrvres(struct clsrvconf *res) {
     free(res);
 }
 
-int bindtoaddr(struct addrinfo *addrinfo, int family, int reuse, int v6only) {
-    int s, on = 1;
-    struct addrinfo *res;
-
-    for (res = addrinfo; res; res = res->ai_next) {
-       if (family != AF_UNSPEC && family != res->ai_family)
-           continue;
-       s = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
-       if (s < 0) {
-           debug(DBG_WARN, "bindtoaddr: socket failed");
-           continue;
-       }
-       if (reuse)
-           setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
-#ifdef IPV6_V6ONLY
-       if (v6only)
-           setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on));
-#endif
-       if (!bind(s, res->ai_addr, res->ai_addrlen))
-           return s;
-       debug(DBG_WARN, "bindtoaddr: bind failed");
-       close(s);
-    }
-    return -1;
-}
-       
-int connecttcp(struct addrinfo *addrinfo, struct addrinfo *src) {
-    int s;
-    struct addrinfo *res;
-
-    s = -1;
-    for (res = addrinfo; res; res = res->ai_next) {
-       s = bindtoaddr(src, res->ai_family, 1, 1);
-        if (s < 0) {
-            debug(DBG_WARN, "connecttoserver: socket failed");
-            continue;
-        }
-        if (connect(s, res->ai_addr, res->ai_addrlen) == 0)
-            break;
-        debug(DBG_WARN, "connecttoserver: connect failed");
-        close(s);
-        s = -1;
-    }
-    return s;
-}        
-
 /* returns 1 if the len first bits are equal, else 0 */
 int prefixmatch(void *a1, void *a2, uint8_t len) {
     static uint8_t mask[] = { 0, 0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe };
@@ -590,6 +544,24 @@ struct client *addclient(struct clsrvconf *conf, uint8_t lock) {
     return new;
 }
 
+void removeclientrqs(struct client *client) {
+    struct request *rq;
+    struct rqout *rqout;
+    int i;
+
+    for (i = 0; i < MAX_REQUESTS; i++) {
+       rq = client->rqs[i];
+       if (!rq)
+           continue;
+       rqout = rq->to->requests + rq->newid;
+       pthread_mutex_lock(rqout->lock);
+       if (rqout->rq == rq) /* still pointing to our request */
+           freerqoutdata(rqout);
+       pthread_mutex_unlock(rqout->lock);                              
+       freerq(rq);
+    }
+}
+
 void removeclient(struct client *client) {
     struct clsrvconf *conf;
     
@@ -598,6 +570,7 @@ void removeclient(struct client *client) {
     conf = client->conf;
     pthread_mutex_lock(conf->lock);
     if (conf->clients) {
+       removeclientrqs(client);
        removequeue(client->replyq);
        list_removedata(conf->clients, client);
        free(client->addr);
@@ -606,26 +579,6 @@ void removeclient(struct client *client) {
     pthread_mutex_unlock(conf->lock);
 }
 
-void removeclientrqs(struct client *client) {
-    struct list_node *entry;
-    struct server *server;
-    struct rqout *rqout;
-    int i;
-    
-    for (entry = list_first(srvconfs); entry; entry = list_next(entry)) {
-       server = ((struct clsrvconf *)entry->data)->servers;
-       if (!server)
-           continue;
-       for (i = 0; i < MAX_REQUESTS; i++) {
-           rqout = server->requests + i;
-           pthread_mutex_lock(rqout->lock);
-           if (rqout->rq && rqout->rq->from == client)
-               freerqoutdata(rqout);
-           pthread_mutex_unlock(rqout->lock);
-       }
-    }
-}
-
 void freeserver(struct server *server, uint8_t destroymutex) {
     struct rqout *rqout, *end;
 
@@ -955,21 +908,25 @@ void freerqoutdata(struct rqout *rqout) {
     memset(&rqout->expiry, 0, sizeof(struct timeval));
 }
 
-void sendrq(struct server *to, struct request *rq) {
-    int i;
+void sendrq(struct request *rq) {
+    int i, start;
+    struct server *to = rq->to;
     
+    start = to->conf->statusserver ? 1 : 0;
     pthread_mutex_lock(&to->newrq_mutex);
-    /* might simplify if only try nextid, might be ok */
-    for (i = to->nextid; i < MAX_REQUESTS; i++) {
-       if (!to->requests[i].rq) {
-           pthread_mutex_lock(to->requests[i].lock);
-           if (!to->requests[i].rq)
-               break;
-           pthread_mutex_unlock(to->requests[i].lock);
+    if (start && rq->msg->code == RAD_Status_Server) {
+       pthread_mutex_lock(to->requests[0].lock);
+       if (to->requests[0].rq) {
+           pthread_mutex_unlock(to->requests[0].lock);
+           debug(DBG_WARN, "sendrq: status server already in queue, dropping request");
+           goto errexit;
        }
-    }
-    if (i == MAX_REQUESTS) {
-       for (i = 0; i < to->nextid; i++) {
+       i = 0;
+    } else {
+       if (!to->nextid)
+           to->nextid = start;
+       /* might simplify if only try nextid, might be ok */
+       for (i = to->nextid; i < MAX_REQUESTS; i++) {
            if (!to->requests[i].rq) {
                pthread_mutex_lock(to->requests[i].lock);
                if (!to->requests[i].rq)
@@ -977,28 +934,35 @@ void sendrq(struct server *to, struct request *rq) {
                pthread_mutex_unlock(to->requests[i].lock);
            }
        }
-       if (i == to->nextid) {
-           debug(DBG_WARN, "sendrq: no room in queue, dropping request");
-           rmclientrq(rq, rq->msg->id);
-           freerq(rq);
-           goto exit;
+       if (i == MAX_REQUESTS) {
+           for (i = start; i < to->nextid; i++) {
+               if (!to->requests[i].rq) {
+                   pthread_mutex_lock(to->requests[i].lock);
+                   if (!to->requests[i].rq)
+                       break;
+                   pthread_mutex_unlock(to->requests[i].lock);
+               }
+           }
+           if (i == to->nextid) {
+               debug(DBG_WARN, "sendrq: no room in queue, dropping request");
+               goto errexit;
+           }
        }
     }
-
+    rq->newid = (uint8_t)i;
     rq->msg->id = (uint8_t)i;
     rq->buf = radmsg2buf(rq->msg, (uint8_t *)to->conf->secret);
     if (!rq->buf) {
        pthread_mutex_unlock(to->requests[i].lock);
        debug(DBG_ERR, "sendrq: radmsg2buf failed");
-       rmclientrq(rq, rq->msg->id);
-       freerq(rq);
-       goto exit;
+       goto errexit;
     }
     
     debug(DBG_DBG, "sendrq: inserting packet with id %d in queue for %s", i, to->conf->host);
     to->requests[i].rq = rq;
     pthread_mutex_unlock(to->requests[i].lock);
-    to->nextid = i + 1;
+    if (i >= start) /* i is not reserved for statusserver */
+       to->nextid = i + 1;
 
     if (!to->newrq) {
        to->newrq = 1;
@@ -1006,7 +970,13 @@ void sendrq(struct server *to, struct request *rq) {
        pthread_cond_signal(&to->newrq_cond);
     }
 
- exit:
+    pthread_mutex_unlock(&to->newrq_mutex);
+    return;
+
+ errexit:
+    if (rq->from)
+       rmclientrq(rq, rq->msg->id);
+    freerq(rq);
     pthread_mutex_unlock(&to->newrq_mutex);
 }
 
@@ -1869,7 +1839,8 @@ int radsrv(struct request *rq) {
        goto rmclrqexit;
     
     free(userascii);
-    sendrq(to, rq);
+    rq->to = to;
+    sendrq(rq);
     return 1;
     
  rmclrqexit:
@@ -2047,7 +2018,8 @@ void *clientwr(void *arg) {
     struct server *server = (struct server *)arg;
     struct rqout *rqout = NULL;
     pthread_t clientrdth;
-    int i, secs, dynconffail = 0;
+    int i, dynconffail = 0;
+    time_t secs;
     uint8_t rnd;
     struct timeval now, laststatsrv;
     struct timespec timeout;
@@ -2093,6 +2065,8 @@ void *clientwr(void *arg) {
            rnd /= 32;
            if (conf->statusserver) {
                secs = server->lastrcv.tv_sec > laststatsrv.tv_sec ? server->lastrcv.tv_sec : laststatsrv.tv_sec;
+               if (now.tv_sec - secs > STATUS_SERVER_PERIOD)
+                   secs = now.tv_sec;
                if (!timeout.tv_sec || timeout.tv_sec > secs + STATUS_SERVER_PERIOD + rnd)
                    timeout.tv_sec = secs + STATUS_SERVER_PERIOD + rnd;
            } else {
@@ -2168,15 +2142,16 @@ void *clientwr(void *arg) {
            conf->pdef->clientradput(server, rqout->rq->buf);
            pthread_mutex_unlock(rqout->lock);
        }
-       if (conf->statusserver) {
+       if (conf->statusserver && server->connectionok) {
            secs = server->lastrcv.tv_sec > laststatsrv.tv_sec ? server->lastrcv.tv_sec : laststatsrv.tv_sec;
            gettimeofday(&now, NULL);
            if (now.tv_sec - secs > STATUS_SERVER_PERIOD) {
                laststatsrv = now;
                statsrvrq = createstatsrvrq();
                if (statsrvrq) {
+                   statsrvrq->to = server;
                    debug(DBG_DBG, "clientwr: sending status server to %s", conf->host);
-                   sendrq(server, statsrvrq);
+                   sendrq(statsrvrq);
                }
            }
        }
@@ -2374,24 +2349,50 @@ SSL_CTX *tlscreatectx(uint8_t type, struct tls *conf) {
     return ctx;
 }
 
-SSL_CTX *tlsgetctx(uint8_t type, char *alt1, char *alt2) {
+struct tls *tlsgettls(char *alt1, char *alt2) {
     struct tls *t;
 
     t = hash_read(tlsconfs, alt1, strlen(alt1));
-    if (!t) {
+    if (!t)
        t = hash_read(tlsconfs, alt2, strlen(alt2));
-       if (!t)
-           return NULL;
-    }
+    return t;
+}
 
+SSL_CTX *tlsgetctx(uint8_t type, struct tls *t) {
+    struct timeval now;
+    
+    if (!t)
+       return NULL;
+    gettimeofday(&now, NULL);
+    
     switch (type) {
     case RAD_TLS:
-       if (!t->tlsctx)
+       if (t->tlsexpiry && t->tlsctx) {
+           if (t->tlsexpiry < now.tv_sec) {
+               t->tlsexpiry = now.tv_sec + t->cacheexpiry;
+               SSL_CTX_free(t->tlsctx);
+               return t->tlsctx = tlscreatectx(RAD_TLS, t);
+           }
+       }
+       if (!t->tlsctx) {
            t->tlsctx = tlscreatectx(RAD_TLS, t);
+           if (t->cacheexpiry)
+               t->tlsexpiry = now.tv_sec + t->cacheexpiry;
+       }
        return t->tlsctx;
     case RAD_DTLS:
-       if (!t->dtlsctx)
+       if (t->dtlsexpiry && t->dtlsctx) {
+           if (t->dtlsexpiry < now.tv_sec) {
+               t->dtlsexpiry = now.tv_sec + t->cacheexpiry;
+               SSL_CTX_free(t->dtlsctx);
+               return t->dtlsctx = tlscreatectx(RAD_DTLS, t);
+           }
+       }
+       if (!t->dtlsctx) {
            t->dtlsctx = tlscreatectx(RAD_DTLS, t);
+           if (t->cacheexpiry)
+               t->dtlsexpiry = now.tv_sec + t->cacheexpiry;
+       }
        return t->dtlsctx;
     }
     return NULL;
@@ -3052,8 +3053,8 @@ int confclient_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
     free(conftype);
     
     if (conf->type == RAD_TLS || conf->type == RAD_DTLS) {
-       conf->ssl_ctx = conf->tls ? tlsgetctx(conf->type, conf->tls, NULL) : tlsgetctx(conf->type, "defaultclient", "default");
-       if (!conf->ssl_ctx)
+       conf->tlsconf = conf->tls ? tlsgettls(conf->tls, NULL) : tlsgettls("defaultclient", "default");
+       if (!conf->tlsconf)
            debugx(1, DBG_ERR, "error in block %s, no tls context defined", block);
        if (conf->matchcertattr && !addmatchcertattr(conf))
            debugx(1, DBG_ERR, "error in block %s, invalid MatchCertificateAttributeValue", block);
@@ -3103,8 +3104,8 @@ int confclient_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
 
 int compileserverconfig(struct clsrvconf *conf, const char *block) {
     if (conf->type == RAD_TLS || conf->type == RAD_DTLS) {
-       conf->ssl_ctx = conf->tls ? tlsgetctx(conf->type, conf->tls, NULL) : tlsgetctx(conf->type, "defaultserver", "default");
-       if (!conf->ssl_ctx) {
+       conf->tlsconf = conf->tls ? tlsgettls(conf->tls, NULL) : tlsgettls("defaultserver", "default");
+       if (!conf->tlsconf) {
            debug(DBG_ERR, "error in block %s, no tls context defined", block);
            return 0;
        }
@@ -3292,6 +3293,7 @@ int confrealm_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
 
 int conftls_cb(struct gconffile **cf, void *arg, char *block, char *opt, char *val) {
     struct tls *conf;
+    long int expiry = LONG_MIN;
     
     debug(DBG_DBG, "conftls_cb called for %s", block);
     
@@ -3308,6 +3310,7 @@ int conftls_cb(struct gconffile **cf, void *arg, char *block, char *opt, char *v
                     "CertificateFile", CONF_STR, &conf->certfile,
                     "CertificateKeyFile", CONF_STR, &conf->certkeyfile,
                     "CertificateKeyPassword", CONF_STR, &conf->certkeypwd,
+                    "CacheExpiry", CONF_LINT, &expiry,
                     "CRLCheck", CONF_BLN, &conf->crlcheck,
                     NULL
                          )) {
@@ -3322,6 +3325,13 @@ int conftls_cb(struct gconffile **cf, void *arg, char *block, char *opt, char *v
        debug(DBG_ERR, "conftls_cb: CA Certificate file or path need to be specified in block %s", val);
        goto errexit;
     }
+    if (expiry != LONG_MIN) {
+       if (expiry < 0) {
+           debug(DBG_ERR, "error in block %s, value of option CacheExpiry is %ld, may not be negative", val, expiry);
+           goto errexit;
+       }
+       conf->cacheexpiry = expiry;
+    }    
 
     conf->name = stringcopy(val, 0);
     if (!conf->name) {