renamed some stuff, added client state for received rqs etc
authorvenaas <venaas>
Tue, 16 Sep 2008 09:29:11 +0000 (09:29 +0000)
committervenaas <venaas@e88ac4ed-0b26-0410-9574-a7f39faa03bf>
Tue, 16 Sep 2008 09:29:11 +0000 (09:29 +0000)
git-svn-id: https://svn.testnett.uninett.no/radsecproxy/trunk@379 e88ac4ed-0b26-0410-9574-a7f39faa03bf

dtls.c
radsecproxy.c
radsecproxy.h
tcp.c
tls.c
udp.c

diff --git a/dtls.c b/dtls.c
index a80c8fd..f1a0e1e 100644 (file)
--- a/dtls.c
+++ b/dtls.c
@@ -248,7 +248,8 @@ void *dtlsserverwr(void *arg) {
 }
 
 void dtlsserverrd(struct client *client) {
-    struct request rq;
+    struct request *rq;
+    uint8_t *buf;
     pthread_t dtlsserverwrth;
     
     debug(DBG_DBG, "dtlsserverrd: starting for %s", client->conf->host);
@@ -259,18 +260,25 @@ void dtlsserverrd(struct client *client) {
     }
 
     for (;;) {
-       memset(&rq, 0, sizeof(struct request));
-       rq.buf = raddtlsget(client->ssl, client->rbios, IDLE_TIMEOUT);
-       if (!rq.buf) {
+       buf = raddtlsget(client->ssl, client->rbios, IDLE_TIMEOUT);
+       if (!buf) {
            debug(DBG_ERR, "dtlsserverrd: connection from %s lost", client->conf->host);
            break;
        }
        debug(DBG_DBG, "dtlsserverrd: got Radius message from %s", client->conf->host);
-       rq.from = client;
-       if (!radsrv(&rq)) {
+       rq = newrequest();
+       if (!rq) {
+           free(buf);
+           continue;
+       }
+       rq->buf = buf;
+       rq->from = client;
+       if (!radsrv(rq)) {
+           freerq(rq);
            debug(DBG_ERR, "dtlsserverrd: message authentication/validation failed, closing connection from %s", client->conf->host);
            break;
        }
+       freerq(rq);
     }
     
     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
@@ -308,7 +316,7 @@ void *dtlsservernew(void *arg) {
     while (conf) {
        if (verifyconfcert(cert, conf)) {
            X509_free(cert);
-           client = addclient(conf);
+           client = addclient(conf, 1);
            if (client) {
                client->sock = params->sock;
                client->rbios = params->sesscache->rbios;
index cdf9fbb..6565c75 100644 (file)
@@ -32,6 +32,8 @@
 /* Bugs:
  * TCP accounting not yet supported
  * We are not removing client requests from dynamic servers, see removeclientrqs()
+ * Need to remove UDP clients when no activity for a while...
+ * Remove expired stuff from clients request list?
  */
 
 #include <signal.h>
@@ -88,7 +90,8 @@ int dynamicconfig(struct server *server);
 int confserver_cb(struct gconffile **cf, void *arg, char *block, char *opt, char *val);
 void freerealm(struct realm *realm);
 void freeclsrvconf(struct clsrvconf *conf);
-void freerqdata(struct request *rq);
+void freerq(struct request *rq);
+void freerqoutdata(struct rqout *rqout);
 
 static const struct protodefs protodefs[] = {
     {   "udp", /* UDP, assuming RAD_UDP defined as 0 */
@@ -99,6 +102,7 @@ static const struct protodefs protodefs[] = {
        10, /* retrycountmax */
        REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
        60, /* retryintervalmax */
+       DUPLICATE_INTERVAL, /* duplicateintervaldefault */
        udpserverrd, /* listener */
        &options.sourceudp, /* srcaddrport */
        NULL, /* connecter */
@@ -116,6 +120,7 @@ static const struct protodefs protodefs[] = {
        0, /* retrycountmax */
        REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
        60, /* retryintervalmax */
+       DUPLICATE_INTERVAL, /* duplicateintervaldefault */
        tlslistener, /* listener */
        &options.sourcetls, /* srcaddrport */
        tlsconnect, /* connecter */
@@ -133,6 +138,7 @@ static const struct protodefs protodefs[] = {
        0, /* retrycountmax */
        REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
        60, /* retryintervalmax */
+       DUPLICATE_INTERVAL, /* duplicateintervaldefault */
        tcplistener, /* listener */
        &options.sourcetcp, /* srcaddrport */
        tcpconnect, /* connecter */
@@ -150,6 +156,7 @@ static const struct protodefs protodefs[] = {
        10, /* retrycountmax */
        REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
        60, /* retryintervalmax */
+       DUPLICATE_INTERVAL, /* duplicateintervaldefault */
        udpdtlsserverrd, /* listener */
        &options.sourcedtls, /* srcaddrport */
        dtlsconnect, /* connecter */
@@ -551,44 +558,60 @@ void freebios(struct queue *q) {
     removequeue(q);
 }
 
-struct client *addclient(struct clsrvconf *conf) {
+struct client *addclient(struct clsrvconf *conf, uint8_t lock) {
     struct client *new = malloc(sizeof(struct client));
     
     if (!new) {
        debug(DBG_ERR, "malloc failed");
        return NULL;
     }
+
+    if (lock)
+       pthread_mutex_lock(conf->lock);
     if (!conf->clients) {
        conf->clients = list_create();
        if (!conf->clients) {
+           if (lock)
+               pthread_mutex_unlock(conf->lock);
            debug(DBG_ERR, "malloc failed");
            return NULL;
        }
     }
     
     memset(new, 0, sizeof(struct client));
+    pthread_mutex_init(&new->lock, NULL);
     new->conf = conf;
     if (conf->pdef->addclient)
        conf->pdef->addclient(new);
     else
        new->replyq = newqueue();
     list_push(conf->clients, new);
+    if (lock)
+       pthread_mutex_unlock(conf->lock);
     return new;
 }
 
 void removeclient(struct client *client) {
-    if (!client || !client->conf->clients)
+    if (!client)
        return;
-    removequeue(client->replyq);
-    list_removedata(client->conf->clients, client);
-    free(client->addr);
-    free(client);
+
+    pthread_mutex_lock(client->conf->lock);
+    if (client->conf->clients) {
+       pthread_mutex_lock(&client->lock);
+       removequeue(client->replyq);
+       list_removedata(client->conf->clients, client);
+       pthread_mutex_unlock(&client->lock);
+       pthread_mutex_destroy(&client->lock);
+       free(client->addr);
+       free(client);
+    }
+    pthread_mutex_unlock(client->conf->lock);
 }
 
 void removeclientrqs(struct client *client) {
     struct list_node *entry;
     struct server *server;
-    struct request *rq;
+    struct rqout *rqout;
     int i;
     
     for (entry = list_first(srvconfs); entry; entry = list_next(entry)) {
@@ -597,24 +620,26 @@ void removeclientrqs(struct client *client) {
            continue;
        pthread_mutex_lock(&server->newrq_mutex);
        for (i = 0; i < MAX_REQUESTS; i++) {
-           rq = server->requests + i;
-           if (rq->from == client)
-               rq->from = NULL;
+           rqout = server->requests + i;
+           if (rqout->rq && rqout->rq->from == client) {
+               freerq(rqout->rq);
+               rqout->rq = NULL;
+           }
        }
        pthread_mutex_unlock(&server->newrq_mutex);
     }
 }
 
 void freeserver(struct server *server, uint8_t destroymutex) {
-    struct request *rq, *end;
+    struct rqout *rqout, *end;
 
     if (!server)
        return;
 
     if (server->requests) {
-       rq = server->requests;
-       for (end = rq + MAX_REQUESTS; rq < end; rq++)
-           freerqdata(rq);
+       rqout = server->requests;
+       for (end = rqout + MAX_REQUESTS; rqout < end; rqout++)
+           freerqoutdata(rqout);
        free(server->requests);
     }
     if (server->rbios)
@@ -661,7 +686,7 @@ int addserver(struct clsrvconf *conf) {
     if (conf->pdef->addserverextra)
        conf->pdef->addserverextra(conf);
     
-    conf->servers->requests = calloc(MAX_REQUESTS, sizeof(struct request));
+    conf->servers->requests = calloc(MAX_REQUESTS, sizeof(struct rqout));
     if (!conf->servers->requests) {
        debug(DBG_ERR, "malloc failed");
        goto errexit;
@@ -890,15 +915,36 @@ unsigned char *attrget(unsigned char *attrs, int length, uint8_t type) {
 }
 
 void freerqdata(struct request *rq) {
-    if (rq->origusername)
-       free(rq->origusername);
-    if (rq->msg)
-       radmsg_free(rq->msg);
+    if (!rq)
+       return;
     if (rq->buf)
        free(rq->buf);
 }
 
-void sendrq(struct server *to, struct request *rq) {
+void freerq(struct request *rq) {
+    if (!rq)
+       return;
+    debug(DBG_DBG, "freerq: called with refcount %d", rq->refcount);
+    if (--rq->refcount)
+       return;
+    freerqdata(rq);
+    free(rq);
+}
+
+void freerqoutdata(struct rqout *rqout) {
+    if (!rqout)
+       return;
+    if (rqout->origusername)
+       free(rqout->origusername);
+    if (rqout->msg)
+       radmsg_free(rqout->msg);
+    if (rqout->buf)
+       free(rqout->buf);
+    if (rqout->rq)
+       freerq(rqout->rq);
+}
+
+void sendrq(struct server *to, struct rqout *rqout) {
     int i;
 
     pthread_mutex_lock(&to->newrq_mutex);
@@ -912,21 +958,21 @@ void sendrq(struct server *to, struct request *rq) {
                break;
        if (i == to->nextid) {
            debug(DBG_WARN, "sendrq: no room in queue, dropping request");
-           freerqdata(rq);
+           freerqoutdata(rqout);
            goto exit;
        }
     }
 
-    rq->msg->id = (uint8_t)i;
-    rq->buf = radmsg2buf(rq->msg, (uint8_t *)to->conf->secret);
-    if (!rq->buf) {
+    rqout->msg->id = (uint8_t)i;
+    rqout->buf = radmsg2buf(rqout->msg, (uint8_t *)to->conf->secret);
+    if (!rqout->buf) {
        debug(DBG_ERR, "sendrq: radmsg2buf failed");
-       freerqdata(rq);
+       freerqoutdata(rqout);
        goto exit;
     }
     
     debug(DBG_DBG, "sendrq: inserting packet with id %d in queue for %s", i, to->conf->host);
-    to->requests[i] = *rq;
+    to->requests[i] = *rqout;
     to->nextid = i + 1;
 
     if (!to->newrq) {
@@ -1254,15 +1300,15 @@ void removeserversubrealms(struct list *realmlist, struct clsrvconf *srv) {
 }
                        
 int rqinqueue(struct server *to, struct client *from, uint8_t id, uint8_t code) {
-    struct request *rq = to->requests, *end;
+    struct rqout *rqout = to->requests, *end;
     
     pthread_mutex_lock(&to->newrq_mutex);
-    for (end = rq + MAX_REQUESTS; rq < end; rq++)
-       if (rq->buf && !rq->received && rq->origid == id && rq->from == from && *rq->buf == code)
+    for (end = rqout + MAX_REQUESTS; rqout < end; rqout++)
+       if (rqout->buf && !rqout->received && rqout->origid == id && rqout->rq && rqout->rq->from == from && *rqout->buf == code)
            break;
     pthread_mutex_unlock(&to->newrq_mutex);
     
-    return rq < end;
+    return rqout < end;
 }
 
 int attrvalidate(unsigned char *attrs, int length) {
@@ -1317,7 +1363,7 @@ int msmpprecrypt(uint8_t *msmpp, uint8_t len, char *oldsecret, char *newsecret,
     return 1;
 }
 
-int msmppe(unsigned char *attrs, int length, uint8_t type, char *attrtxt, struct request *rq,
+int msmppe(unsigned char *attrs, int length, uint8_t type, char *attrtxt, struct rqout *rq,
           char *oldsecret, char *newsecret) {
     unsigned char *attr;
     
@@ -1508,14 +1554,14 @@ int dorewrite(struct radmsg *msg, struct rewrite *rewrite) {
     return 1;
 }
 
-int rewriteusername(struct request *rq, struct tlv *attr) {
+int rewriteusername(struct rqout *rqout, struct tlv *attr) {
     char *orig = (char *)tlv2str(attr);
-    if (!dorewritemodattr(attr, rq->from->conf->rewriteusername)) {
+    if (!dorewritemodattr(attr, rqout->rq->from->conf->rewriteusername)) {
        free(orig);
        return 0;
     }
     if (strlen(orig) != attr->l || memcmp(orig, attr->v, attr->l))
-       rq->origusername = (char *)orig;
+       rqout->origusername = (char *)orig;
     else
        free(orig);
     return 1;
@@ -1585,33 +1631,33 @@ void acclog(struct radmsg *msg, char *host) {
     }
 }
 
-void respondaccounting(struct request *rq) {
+void respondaccounting(struct rqout *rqout) {
     struct radmsg *msg;
 
-    msg = radmsg_init(RAD_Accounting_Response, rq->msg->id, rq->msg->auth);
+    msg = radmsg_init(RAD_Accounting_Response, rqout->msg->id, rqout->msg->auth);
     if (msg) {
-       debug(DBG_DBG, "respondaccounting: responding to %s", rq->from->conf->host);
-       sendreply(rq->from, msg, &rq->fromsa, rq->fromudpsock);
+       debug(DBG_DBG, "respondaccounting: responding to %s", rqout->rq->from->conf->host);
+       sendreply(rqout->rq->from, msg, &rqout->rq->fromsa, rqout->rq->fromudpsock);
     } else     
        debug(DBG_ERR, "respondaccounting: malloc failed");
 }
 
-void respondstatusserver(struct request *rq) {
+void respondstatusserver(struct rqout *rqout) {
     struct radmsg *msg;
 
-    msg = radmsg_init(RAD_Access_Accept, rq->msg->id, rq->msg->auth);
+    msg = radmsg_init(RAD_Access_Accept, rqout->msg->id, rqout->msg->auth);
     if (msg) {
-       debug(DBG_DBG, "respondstatusserver: responding to %s", rq->from->conf->host);
-       sendreply(rq->from, msg, &rq->fromsa, rq->fromudpsock);
+       debug(DBG_DBG, "respondstatusserver: responding to %s", rqout->rq->from->conf->host);
+       sendreply(rqout->rq->from, msg, &rqout->rq->fromsa, rqout->rq->fromudpsock);
     } else
        debug(DBG_ERR, "respondstatusserver: malloc failed");
 }
 
-void respondreject(struct request *rq, char *message) {
+void respondreject(struct rqout *rqout, char *message) {
     struct radmsg *msg;
     struct tlv *attr;
 
-    msg = radmsg_init(RAD_Access_Reject, rq->msg->id, rq->msg->auth);
+    msg = radmsg_init(RAD_Access_Reject, rqout->msg->id, rqout->msg->auth);
     if (!msg) {
        debug(DBG_ERR, "respondreject: malloc failed");
        return;
@@ -1626,8 +1672,8 @@ void respondreject(struct request *rq, char *message) {
        }
     }
 
-    debug(DBG_DBG, "respondreject: responding to %s", rq->from->conf->host);
-    sendreply(rq->from, msg, &rq->fromsa, rq->fromudpsock);
+    debug(DBG_DBG, "respondreject: responding to %s", rqout->rq->from->conf->host);
+    sendreply(rqout->rq->from, msg, &rqout->rq->fromsa, rqout->rq->fromudpsock);
 }
 
 struct clsrvconf *choosesrvconf(struct list *srvconfs) {
@@ -1677,50 +1723,110 @@ struct server *findserver(struct realm **realm, struct tlv *username, uint8_t ac
     return srvconf->servers;
 }
 
+
+struct request *newrequest() {
+    struct request *rq;
+
+    rq = malloc(sizeof(struct request));
+    if (!rq) {
+       debug(DBG_ERR, "newrequest: malloc failed");
+       return NULL;
+    }
+    memset(rq, 0, sizeof(struct request));
+    rq->refcount = 1;
+    gettimeofday(&rq->created, NULL);
+    return rq;
+}
+
+int addclientrq(struct request *rq, uint8_t id) {
+    struct request *r;
+    struct timeval now;
+
+    pthread_mutex_lock(&rq->from->lock);
+    gettimeofday(&now, NULL);
+    r = rq->from->rqs[id];
+    if (r) {
+       if (r->refcount > 1 && now.tv_sec - r->created.tv_sec < r->from->conf->dupinterval) {
+           pthread_mutex_unlock(&rq->from->lock);
+           return 0;
+       }
+       freerq(r);
+    }
+    rq->refcount++;
+    rq->from->rqs[id] = rq;
+    pthread_mutex_unlock(&rq->from->lock);
+    return 1;
+}
+
+void rmclientrq(struct request *rq, uint8_t id) {
+    struct request *r;
+
+    pthread_mutex_lock(&rq->from->lock);
+    r = rq->from->rqs[id];
+    if (r) {
+       freerq(r);
+       rq->from->rqs[id] = NULL;
+    }
+    pthread_mutex_unlock(&rq->from->lock);
+}
+
 /* returns 0 if validation/authentication fails, else 1 */
 int radsrv(struct request *rq) {
     struct radmsg *msg = NULL;
+    struct rqout *rqout, rqdata;
     struct tlv *attr;
     uint8_t *userascii = NULL;
     unsigned char newauth[16];
     struct realm *realm = NULL;
     struct server *to = NULL;
-
-    msg = buf2radmsg(rq->buf, (uint8_t *)rq->from->conf->secret, NULL);
+    struct client *from = rq->from;
+    
+    msg = buf2radmsg(rq->buf, (uint8_t *)from->conf->secret, NULL);
     if (!msg) {
        debug(DBG_WARN, "radsrv: message validation failed, ignoring packet");
-       freerqdata(rq);
        return 0;
     }
-    rq->msg = msg;
+
+    rqout = &rqdata;
+    memset(rqout, 0, sizeof(struct rqout));
+    rqout->msg = msg;
     debug(DBG_DBG, "radsrv: code %d, id %d", msg->code, msg->id);
     
     if (msg->code != RAD_Access_Request && msg->code != RAD_Status_Server && msg->code != RAD_Accounting_Request) {
        debug(DBG_INFO, "radsrv: server currently accepts only access-requests, accounting-requests and status-server, ignoring");
-       goto exit;
+       freerqoutdata(rqout);
+       return 1;
     }
 
+    if (!addclientrq(rq, msg->id)) {
+       debug(DBG_INFO, "radsrv: already got request with id %d from %s, ignoring", msg->id, from->conf->host);
+       freerqoutdata(rqout);
+       return 1;
+    }
+    rqout->rq = rq;
+    rq->refcount++;
+    
     if (msg->code == RAD_Status_Server) {
-       respondstatusserver(rq);
+       respondstatusserver(rqout);
        goto exit;
     }
 
     /* below: code == RAD_Access_Request || code == RAD_Accounting_Request */
 
-    if (rq->from->conf->rewritein && !dorewrite(msg, rq->from->conf->rewritein))
+    if (from->conf->rewritein && !dorewrite(msg, from->conf->rewritein))
        goto exit;
 
     attr = radmsg_gettype(msg, RAD_Attr_User_Name);
     if (!attr) {
        if (msg->code == RAD_Accounting_Request) {
-           acclog(msg, rq->from->conf->host);
-           respondaccounting(rq);
+           acclog(msg, from->conf->host);
+           respondaccounting(rqout);
        } else
            debug(DBG_WARN, "radsrv: ignoring access request, no username attribute");
        goto exit;
     }
     
-    if (rq->from->conf->rewriteusername && !rewriteusername(rq, attr)) {
+    if (from->conf->rewriteusername && !rewriteusername(rqout, attr)) {
        debug(DBG_WARN, "radsrv: username malloc failed, ignoring request");
        goto exit;
     }
@@ -1737,11 +1843,11 @@ int radsrv(struct request *rq) {
     }
     if (!to) {
        if (realm->message && msg->code == RAD_Access_Request) {
-           debug(DBG_INFO, "radsrv: sending reject to %s for %s", rq->from->conf->host, userascii);
-           respondreject(rq, realm->message);
+           debug(DBG_INFO, "radsrv: sending reject to %s for %s", from->conf->host, userascii);
+           respondreject(rqout, realm->message);
        } else if (realm->accresp && msg->code == RAD_Accounting_Request) {
-           acclog(msg, rq->from->conf->host);
-           respondaccounting(rq);
+           acclog(msg, from->conf->host);
+           respondaccounting(rqout);
        }
        goto exit;
     }
@@ -1749,17 +1855,20 @@ int radsrv(struct request *rq) {
     free(rq->buf);
     rq->buf = NULL;
     
-    if (options.loopprevention && !strcmp(rq->from->conf->name, to->conf->name)) {
+    if (options.loopprevention && !strcmp(from->conf->name, to->conf->name)) {
        debug(DBG_INFO, "radsrv: Loop prevented, not forwarding request from client %s to server %s, discarding",
-             rq->from->conf->name, to->conf->name);
+             from->conf->name, to->conf->name);
        goto exit;
     }
 
-    if (rqinqueue(to, rq->from, msg->id, msg->code)) {
+#if 0
+    skip this now that we have rqrcv... per client?
+    if (rqinqueue(to, from, msg->id, msg->code)) {
        debug(DBG_INFO, "radsrv: already got %s from host %s with id %d, ignoring",
-             radmsgtype2string(msg->code), rq->from->conf->host, msg->id);
+             radmsgtype2string(msg->code), from->conf->host, msg->id);
        goto exit;
     }
+#endif
     
     if (msg->code != RAD_Accounting_Request) {
        if (!RAND_bytes(newauth, 16)) {
@@ -1767,7 +1876,7 @@ int radsrv(struct request *rq) {
            goto exit;
        }
     }
-
+    
 #ifdef DEBUG
     printfchars(NULL, "auth", "%02x ", auth, 16);
 #endif
@@ -1775,37 +1884,38 @@ int radsrv(struct request *rq) {
     attr = radmsg_gettype(msg, RAD_Attr_User_Password);
     if (attr) {
        debug(DBG_DBG, "radsrv: found userpwdattr with value length %d", attr->l);
-       if (!pwdrecrypt(attr->v, attr->l, rq->from->conf->secret, to->conf->secret, msg->auth, newauth))
+       if (!pwdrecrypt(attr->v, attr->l, from->conf->secret, to->conf->secret, msg->auth, newauth))
            goto exit;
     }
 
     attr = radmsg_gettype(msg, RAD_Attr_Tunnel_Password);
     if (attr) {
        debug(DBG_DBG, "radsrv: found tunnelpwdattr with value length %d", attr->l);
-       if (!pwdrecrypt(attr->v, attr->l, rq->from->conf->secret, to->conf->secret, msg->auth, newauth))
+       if (!pwdrecrypt(attr->v, attr->l, from->conf->secret, to->conf->secret, msg->auth, newauth))
            goto exit;
     }
 
-    rq->origid = msg->id;
-    memcpy(rq->origauth, msg->auth, 16);
+    rqout->origid = msg->id;
+    memcpy(rqout->origauth, msg->auth, 16);
     memcpy(msg->auth, newauth, 16);    
 
     if (to->conf->rewriteout && !dorewrite(msg, to->conf->rewriteout))
        goto exit;
     
     free(userascii);
-    sendrq(to, rq);
+    sendrq(to, rqout);
     return 1;
     
  exit:
+    rmclientrq(rq, msg->id);
     free(userascii);
-    freerqdata(rq);
+    freerqoutdata(rqout);
     return 1;
 }
 
 void replyh(struct server *server, unsigned char *buf) {
     struct client *from;
-    struct request *rq;
+    struct rqout *rqout;
     int sublen;
     unsigned char *subattrs;
     struct sockaddr_storage fromsa;
@@ -1817,8 +1927,8 @@ void replyh(struct server *server, unsigned char *buf) {
     server->connectionok = 1;
     server->lostrqs = 0;
 
-    rq = server->requests + buf[1];
-    msg = buf2radmsg(buf, (uint8_t *)server->conf->secret, rq->msg->auth);
+    rqout = server->requests + buf[1];
+    msg = buf2radmsg(buf, (uint8_t *)server->conf->secret, rqout->msg->auth);
     free(buf);
     buf = NULL;
     if (!msg) {
@@ -1834,33 +1944,33 @@ void replyh(struct server *server, unsigned char *buf) {
     debug(DBG_DBG, "got %s message with id %d", radmsgtype2string(msg->code), msg->id);
 
     pthread_mutex_lock(&server->newrq_mutex);
-    if (!rq->buf || !rq->tries) {
+    if (!rqout->buf || !rqout->tries) {
        debug(DBG_INFO, "replyh: no matching request sent with this id, ignoring reply");
        goto errunlock;
     }
 
-    if (rq->received) {
+    if (rqout->received) {
        debug(DBG_INFO, "replyh: already received, ignoring reply");
        goto errunlock;
     }
        
     gettimeofday(&server->lastrcv, NULL);
     
-    if (rq->msg->code == RAD_Status_Server) {
-       rq->received = 1;
+    if (rqout->msg->code == RAD_Status_Server) {
+       rqout->received = 1;
        debug(DBG_DBG, "replyh: got status server response from %s", server->conf->host);
        goto errunlock;
     }
 
     gettimeofday(&server->lastreply, NULL);
     
-    from = rq->from;
-    if (!from) {
+    if (!rqout->rq) {
        debug(DBG_INFO, "replyh: client gone, ignoring reply");
        goto errunlock;
     }
+    from = rqout->rq->from;
        
-    if (server->conf->rewritein && !dorewrite(msg, rq->from->conf->rewritein)) {
+    if (server->conf->rewritein && !dorewrite(msg, from->conf->rewritein)) {
        debug(DBG_WARN, "replyh: rewritein failed");
        goto errunlock;
     }
@@ -1879,9 +1989,9 @@ void replyh(struct server *server, unsigned char *buf) {
        subattrs = attr->v + 4;  
        if (!attrvalidate(subattrs, sublen) ||
            !msmppe(subattrs, sublen, RAD_VS_ATTR_MS_MPPE_Send_Key, "MS MPPE Send Key",
-                   rq, server->conf->secret, from->conf->secret) ||
+                   rqout, server->conf->secret, from->conf->secret) ||
            !msmppe(subattrs, sublen, RAD_VS_ATTR_MS_MPPE_Recv_Key, "MS MPPE Recv Key",
-                   rq, server->conf->secret, from->conf->secret))
+                   rqout, server->conf->secret, from->conf->secret))
            break;
     }
     if (node) {
@@ -1890,9 +2000,9 @@ void replyh(struct server *server, unsigned char *buf) {
     }
 
     if (msg->code == RAD_Access_Accept || msg->code == RAD_Access_Reject || msg->code == RAD_Accounting_Response) {
-       username = radattr2ascii(radmsg_gettype(rq->msg, RAD_Attr_User_Name));
+       username = radattr2ascii(radmsg_gettype(rqout->msg, RAD_Attr_User_Name));
        if (username) {
-           stationid = radattr2ascii(radmsg_gettype(rq->msg, RAD_Attr_Calling_Station_Id));
+           stationid = radattr2ascii(radmsg_gettype(rqout->msg, RAD_Attr_Calling_Station_Id));
            if (stationid) {
                debug(DBG_INFO, "%s for user %s stationid %s from %s",
                      radmsgtype2string(msg->code), username, stationid, server->conf->host);
@@ -1903,19 +2013,19 @@ void replyh(struct server *server, unsigned char *buf) {
        }
     }
 
-    msg->id = (char)rq->origid;
-    memcpy(msg->auth, rq->origauth, 16);
+    msg->id = (char)rqout->origid;
+    memcpy(msg->auth, rqout->origauth, 16);
 
 #ifdef DEBUG   
     printfchars(NULL, "origauth/buf+4", "%02x ", buf + 4, 16);
 #endif
 
-    if (rq->origusername && (attr = radmsg_gettype(msg, RAD_Attr_User_Name))) {
-       if (!resizeattr(attr, strlen(rq->origusername))) {
+    if (rqout->origusername && (attr = radmsg_gettype(msg, RAD_Attr_User_Name))) {
+       if (!resizeattr(attr, strlen(rqout->origusername))) {
            debug(DBG_WARN, "replyh: malloc failed, ignoring reply");
            goto errunlock;
        }
-       memcpy(attr->v, rq->origusername, strlen(rq->origusername));
+       memcpy(attr->v, rqout->origusername, strlen(rqout->origusername));
     }
 
     if (from->conf->rewriteout && !dorewrite(msg, from->conf->rewriteout)) {
@@ -1923,12 +2033,12 @@ void replyh(struct server *server, unsigned char *buf) {
        goto errunlock;
     }
 
-    fromsa = rq->fromsa; /* only needed for UDP */
+    fromsa = rqout->rq->fromsa; /* only needed for UDP */
     /* once we set received = 1, rq may be reused */
-    rq->received = 1;
+    rqout->received = 1;
 
     debug(DBG_INFO, "replyh: passing reply to client %s", from->conf->name);
-    sendreply(from, msg, &fromsa, rq->fromudpsock);
+    sendreply(from, msg, &fromsa, rqout->rq->fromudpsock);
     pthread_mutex_unlock(&server->newrq_mutex);
     return;
 
@@ -1962,13 +2072,13 @@ struct radmsg *createstatsrvmsg() {
 /* code for removing state not finished */
 void *clientwr(void *arg) {
     struct server *server = (struct server *)arg;
-    struct request *rq;
+    struct rqout *rqout;
     pthread_t clientrdth;
     int i, secs, dynconffail = 0;
     uint8_t rnd;
     struct timeval now, laststatsrv;
     struct timespec timeout;
-    struct request statsrvrq;
+    struct rqout statsrvrq;
     struct clsrvconf *conf;
     
     conf = server->conf;
@@ -2045,32 +2155,32 @@ void *clientwr(void *arg) {
                pthread_mutex_unlock(&server->newrq_mutex);
                break;
            }
-           rq = server->requests + i;
+           rqout = server->requests + i;
 
-            if (rq->received) {
+            if (rqout->received) {
                debug(DBG_DBG, "clientwr: packet %d in queue is marked as received", i);
-               if (rq->buf) {
+               if (rqout->buf) {
                    debug(DBG_DBG, "clientwr: freeing received packet %d from queue", i);
-                   freerqdata(rq);
+                   freerqoutdata(rqout);
                    /* setting this to NULL means that it can be reused */
-                   rq->buf = NULL;
+                   rqout->buf = NULL;
                }
                 pthread_mutex_unlock(&server->newrq_mutex);
                 continue;
             }
            
            gettimeofday(&now, NULL);
-            if (now.tv_sec < rq->expiry.tv_sec) {
-               if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
-                   timeout.tv_sec = rq->expiry.tv_sec;
+            if (now.tv_sec < rqout->expiry.tv_sec) {
+               if (!timeout.tv_sec || rqout->expiry.tv_sec < timeout.tv_sec)
+                   timeout.tv_sec = rqout->expiry.tv_sec;
                pthread_mutex_unlock(&server->newrq_mutex);
                continue;
            }
 
-           if (rq->tries == (*rq->buf == RAD_Status_Server ? 1 : conf->retrycount + 1)) {
+           if (rqout->tries == (*rqout->buf == RAD_Status_Server ? 1 : conf->retrycount + 1)) {
                debug(DBG_DBG, "clientwr: removing expired packet from queue");
                if (conf->statusserver) {
-                   if (*rq->buf == RAD_Status_Server) {
+                   if (*rqout->buf == RAD_Status_Server) {
                        debug(DBG_WARN, "clientwr: no status server response, %s dead?", conf->host);
                        if (server->lostrqs < 255)
                            server->lostrqs++;
@@ -2080,18 +2190,18 @@ void *clientwr(void *arg) {
                    if (server->lostrqs < 255)
                        server->lostrqs++;
                }
-               freerqdata(rq);
+               freerqoutdata(rqout);
                /* setting this to NULL means that it can be reused */
-               rq->buf = NULL;
+               rqout->buf = NULL;
                pthread_mutex_unlock(&server->newrq_mutex);
                continue;
            }
             pthread_mutex_unlock(&server->newrq_mutex);
 
-           rq->expiry.tv_sec = now.tv_sec + conf->retryinterval;
-           if (!timeout.tv_sec || rq->expiry.tv_sec < timeout.tv_sec)
-               timeout.tv_sec = rq->expiry.tv_sec;
-           rq->tries++;
+           rqout->expiry.tv_sec = now.tv_sec + conf->retryinterval;
+           if (!timeout.tv_sec || rqout->expiry.tv_sec < timeout.tv_sec)
+               timeout.tv_sec = rqout->expiry.tv_sec;
+           rqout->tries++;
            conf->pdef->clientradput(server, server->requests[i].buf);
        }
        if (conf->statusserver) {
@@ -2099,7 +2209,7 @@ void *clientwr(void *arg) {
            gettimeofday(&now, NULL);
            if (now.tv_sec - secs > STATUS_SERVER_PERIOD) {
                laststatsrv = now;
-               memset(&statsrvrq, 0, sizeof(struct request));
+               memset(&statsrvrq, 0, sizeof(struct rqout));
                statsrvrq.msg = createstatsrvmsg();
                if (statsrvrq.msg) {
                    debug(DBG_DBG, "clientwr: sending status server to %s", conf->host);
@@ -2885,6 +2995,10 @@ void freeclsrvconf(struct clsrvconf *conf) {
     free(conf->rewriteout);
     if (conf->addrinfo)
        freeaddrinfo(conf->addrinfo);
+    if (conf->lock) {
+       pthread_mutex_destroy(conf->lock);
+       free(conf->lock);
+    }
     /* not touching ssl_ctx, clients and servers */
     free(conf);
 }
@@ -2934,11 +3048,12 @@ int mergesrvconf(struct clsrvconf *dst, struct clsrvconf *src) {
 int confclient_cb(struct gconffile **cf, void *arg, char *block, char *opt, char *val) {
     struct clsrvconf *conf;
     char *conftype = NULL, *rewriteinalias = NULL;
+    long int dupinterval = LONG_MIN;
     
     debug(DBG_DBG, "confclient_cb called for %s", block);
 
     conf = malloc(sizeof(struct clsrvconf));
-    if (!conf || !list_push(clconfs, conf))
+    if (!conf)
        debugx(1, DBG_ERR, "malloc failed");
     memset(conf, 0, sizeof(struct clsrvconf));
     conf->certnamecheck = 1;
@@ -2950,6 +3065,7 @@ int confclient_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
                     "tls", CONF_STR, &conf->tls,
                     "matchcertificateattribute", CONF_STR, &conf->matchcertattr,
                     "CertificateNameCheck", CONF_BLN, &conf->certnamecheck,
+                    "DuplicateInterval", CONF_LINT, &dupinterval,
                     "rewrite", CONF_STR, &rewriteinalias,
                     "rewriteIn", CONF_STR, &conf->confrewritein,
                     "rewriteOut", CONF_STR, &conf->confrewriteout,
@@ -2979,7 +3095,14 @@ int confclient_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
        if (conf->matchcertattr && !addmatchcertattr(conf))
            debugx(1, DBG_ERR, "error in block %s, invalid MatchCertificateAttributeValue", block);
     }
-
+    
+    if (dupinterval != LONG_MIN) {
+       if (dupinterval < 0 || dupinterval > 255)
+           debugx(1, DBG_ERR, "error in block %s, value of option DuplicateInterval is %d, must be 0-255", block, dupinterval);
+       conf->dupinterval = (uint8_t)dupinterval;
+    } else
+       conf->dupinterval = conf->pdef->duplicateintervaldefault;
+    
     if (!conf->confrewritein)
        conf->confrewritein = rewriteinalias;
     else
@@ -3004,6 +3127,14 @@ int confclient_cb(struct gconffile **cf, void *arg, char *block, char *opt, char
        if (!conf->secret)
            debugx(1, DBG_ERR, "malloc failed");
     }
+
+    conf->lock = malloc(sizeof(pthread_mutex_t));
+    if (!conf->lock)
+       debugx(1, DBG_ERR, "malloc failed");
+
+    pthread_mutex_init(conf->lock, NULL);
+    if (!list_push(clconfs, conf))
+       debugx(1, DBG_ERR, "malloc failed");
     return 1;
 }
 
index d37967f..00a185d 100644 (file)
@@ -17,6 +17,7 @@
 #define MAX_REQUESTS 256
 #define REQUEST_RETRY_INTERVAL 5
 #define REQUEST_RETRY_COUNT 2
+#define DUPLICATE_INTERVAL REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT
 #define MAX_CERT_DEPTH 5
 #define STATUS_SERVER_PERIOD 25
 #define IDLE_TIMEOUT 300
@@ -41,19 +42,26 @@ struct options {
     uint8_t loopprevention;
 };
 
-/* requests that our client will send */
 struct request {
+    struct timeval created;
+    uint8_t refcount;
+    uint8_t *buf;
+    struct client *from;
+    struct sockaddr_storage fromsa; /* used by udpservwr */
+    int fromudpsock; /* used by udpservwr */
+};
+
+/* requests that our client will send */
+struct rqout {
     unsigned char *buf;
     struct radmsg *msg;
     uint8_t tries;
     uint8_t received;
     struct timeval expiry;
-    struct client *from;
     char *origusername;
     uint8_t origid; /* used by servwr */
     char origauth[16]; /* used by servwr */
-    struct sockaddr_storage fromsa; /* used by udpservwr */
-    int fromudpsock; /* used by udpservwr */
+    struct request *rq;
 };
 
 /* replies that a server will send */
@@ -88,12 +96,14 @@ struct clsrvconf {
     uint8_t statusserver;
     uint8_t retryinterval;
     uint8_t retrycount;
+    uint8_t dupinterval;
     uint8_t certnamecheck;
     SSL_CTX *ssl_ctx;
     struct rewrite *rewritein;
     struct rewrite *rewriteout;
     struct addrinfo *addrinfo;
     uint8_t prefixlen;
+    pthread_mutex_t *lock; /* only used for updating clients so far */
     struct list *clients;
     struct server *servers;
 };
@@ -102,6 +112,8 @@ struct client {
     struct clsrvconf *conf;
     int sock; /* for tcp/dtls */
     SSL *ssl;
+    pthread_mutex_t lock; /* used for updating rqs */
+    struct request *rqs[MAX_REQUESTS];
     struct queue *replyq;
     struct queue *rbios; /* for dtls */
     struct sockaddr *addr; /* for udp */
@@ -121,7 +133,7 @@ struct server {
     char *dynamiclookuparg;
     int nextid;
     struct timeval lastrcv;
-    struct request *requests;
+    struct rqout *requests;
     uint8_t newrq;
     pthread_mutex_t newrq_mutex;
     pthread_cond_t newrq_cond;
@@ -173,6 +185,7 @@ struct protodefs {
     uint8_t retrycountmax;
     uint8_t retryintervaldefault;
     uint8_t retryintervalmax;
+    uint8_t duplicateintervaldefault;
     void *(*listener)(void*);
     char **srcaddrport;
     int (*connecter)(struct server *, struct timeval *, int, char *);
@@ -198,12 +211,14 @@ struct addrinfo *getsrcprotores(uint8_t type);
 struct clsrvconf *find_clconf(uint8_t type, struct sockaddr *addr, struct list_node **cur);
 struct clsrvconf *find_srvconf(uint8_t type, struct sockaddr *addr, struct list_node **cur);
 struct clsrvconf *find_clconf_type(uint8_t type, struct list_node **cur);
-struct client *addclient(struct clsrvconf *conf);
+struct client *addclient(struct clsrvconf *conf, uint8_t lock);
 void removeclient(struct client *client);
 void removeclientrqs(struct client *client);
 struct queue *newqueue();
 void removequeue(struct queue *q);
 void freebios(struct queue *q);
+struct request *newrequest();
+void freerq(struct request *rq);
 int radsrv(struct request *rq);
 X509 *verifytlscert(SSL *ssl);
 int verifyconfcert(X509 *cert, struct clsrvconf *conf);
diff --git a/tcp.c b/tcp.c
index a470120..b3e41df 100644 (file)
--- a/tcp.c
+++ b/tcp.c
@@ -223,7 +223,8 @@ void *tcpserverwr(void *arg) {
 }
 
 void tcpserverrd(struct client *client) {
-    struct request rq;
+    struct request *rq;
+    uint8_t *buf;
     pthread_t tcpserverwrth;
     
     debug(DBG_DBG, "tcpserverrd: starting for %s", client->conf->host);
@@ -234,18 +235,25 @@ void tcpserverrd(struct client *client) {
     }
 
     for (;;) {
-       memset(&rq, 0, sizeof(struct request));
-       rq.buf = radtcpget(client->sock, 0);
-       if (!rq.buf) {
+       buf = radtcpget(client->sock, 0);
+       if (!buf) {
            debug(DBG_ERR, "tcpserverrd: connection from %s lost", client->conf->host);
            break;
        }
        debug(DBG_DBG, "tcpserverrd: got Radius message from %s", client->conf->host);
-       rq.from = client;
-       if (!radsrv(&rq)) {
+       rq = newrequest();
+       if (!rq) {
+           free(buf);
+           continue;
+       }
+       rq->buf = buf;
+       rq->from = client;
+       if (!radsrv(rq)) {
+           freerq(rq);
            debug(DBG_ERR, "tcpserverrd: message authentication/validation failed, closing connection from %s", client->conf->host);
            break;
        }
+       freerq(rq);
     }
 
     /* stop writer by setting s to -1 and give signal in case waiting for data */
@@ -275,7 +283,7 @@ void *tcpservernew(void *arg) {
 
     conf = find_clconf(RAD_TCP, (struct sockaddr *)&from, NULL);
     if (conf) {
-       client = addclient(conf);
+       client = addclient(conf, 1);
        if (client) {
            client->sock = s;
            tcpserverrd(client);
diff --git a/tls.c b/tls.c
index 4a9641f..6836c28 100644 (file)
--- a/tls.c
+++ b/tls.c
@@ -276,7 +276,8 @@ void *tlsserverwr(void *arg) {
 }
 
 void tlsserverrd(struct client *client) {
-    struct request rq;
+    struct request *rq;
+    uint8_t *buf;
     pthread_t tlsserverwrth;
     
     debug(DBG_DBG, "tlsserverrd: starting for %s", client->conf->host);
@@ -287,18 +288,25 @@ void tlsserverrd(struct client *client) {
     }
 
     for (;;) {
-       memset(&rq, 0, sizeof(struct request));
-       rq.buf = radtlsget(client->ssl, 0);
-       if (!rq.buf) {
+       buf = radtlsget(client->ssl, 0);
+       if (!buf) {
            debug(DBG_ERR, "tlsserverrd: connection from %s lost", client->conf->host);
            break;
        }
        debug(DBG_DBG, "tlsserverrd: got Radius message from %s", client->conf->host);
-       rq.from = client;
-       if (!radsrv(&rq)) {
+       rq = newrequest();
+       if (!rq) {
+           free(buf);
+           continue;
+       }
+       rq->buf = buf;
+       rq->from = client;
+       if (!radsrv(rq)) {
+           freerq(rq);
            debug(DBG_ERR, "tlsserverrd: message authentication/validation failed, closing connection from %s", client->conf->host);
            break;
        }
+       freerq(rq);
     }
     
     /* stop writer by setting ssl to NULL and give signal in case waiting for data */
@@ -349,7 +357,7 @@ void *tlsservernew(void *arg) {
     while (conf) {
        if (verifyconfcert(cert, conf)) {
            X509_free(cert);
-           client = addclient(conf);
+           client = addclient(conf, 1);
            if (client) {
                client->ssl = ssl;
                tlsserverrd(client);
diff --git a/udp.c b/udp.c
index 4d27425..3079a68 100644 (file)
--- a/udp.c
+++ b/udp.c
@@ -102,22 +102,28 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
            debug(DBG_DBG, "radudpget: packet was padded with %d bytes", cnt - len);
 
        if (client) {
+           pthread_mutex_lock(p->lock);
            for (node = list_first(p->clients); node; node = list_next(node))
                if (addr_equal((struct sockaddr *)&from, ((struct client *)node->data)->addr))
                    break;
            if (node) {
                *client = (struct client *)node->data;
+               pthread_mutex_unlock(p->lock);
                break;
            }
            fromcopy = addr_copy((struct sockaddr *)&from);
-           if (!fromcopy)
+           if (!fromcopy) {
+               pthread_mutex_unlock(p->lock);
                continue;
-           *client = addclient(p);
+           }
+           *client = addclient(p, 0);
            if (!*client) {
                free(fromcopy);
+               pthread_mutex_unlock(p->lock);
                continue;
            }
            (*client)->addr = fromcopy;
+           pthread_mutex_unlock(p->lock);
        } else if (server)
            *server = p->servers;
        break;
@@ -178,14 +184,19 @@ void *udpclientrd(void *arg) {
 }
 
 void *udpserverrd(void *arg) {
-    struct request rq;
+    struct request *rq;
     int *sp = (int *)arg;
     
     for (;;) {
-       memset(&rq, 0, sizeof(struct request));
-       rq.buf = radudpget(*sp, &rq.from, NULL, &rq.fromsa);
-       rq.fromudpsock = *sp;
-       radsrv(&rq);
+       rq = newrequest();
+       if (!rq) {
+           sleep(5); /* malloc failed */
+           continue;
+       }
+       rq->buf = radudpget(*sp, &rq->from, NULL, &rq->fromsa);
+       rq->fromudpsock = *sp;
+       radsrv(rq);
+       freerq(rq);
     }
     free(sp);
 }