Handle failing rs_context_create().
[libradsec.git] / udp.c
diff --git a/udp.c b/udp.c
index be61b0b..4740fd0 100644 (file)
--- a/udp.c
+++ b/udp.c
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
+ * Copyright (C) 2006-2009 Stig Venaas <venaas@uninett.no>
  *
  * Permission to use, copy, modify, and distribute this software for any
  * purpose with or without fee is hereby granted, provided that the above
 #include <arpa/inet.h>
 #include <regex.h>
 #include <pthread.h>
-#include <openssl/ssl.h>
-#include "debug.h"
 #include "list.h"
-#include "util.h"
+#include "hostport.h"
 #include "radsecproxy.h"
-#include "tls.h"
+
+#ifdef RADPROT_UDP
+#include "debug.h"
+#include "util.h"
+
+static void setprotoopts(struct commonprotoopts *opts);
+static char **getlistenerargs();
+void *udpserverrd(void *arg);
+int clientradputudp(struct server *server, unsigned char *rad);
+void addclientudp(struct client *client);
+void addserverextraudp(struct clsrvconf *conf);
+void udpsetsrcres();
+void initextraudp();
+
+static const struct protodefs protodefs = {
+    "udp",
+    NULL, /* secretdefault */
+    SOCK_DGRAM, /* socktype */
+    "1812", /* portdefault */
+    REQUEST_RETRY_COUNT, /* retrycountdefault */
+    10, /* retrycountmax */
+    REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
+    60, /* retryintervalmax */
+    DUPLICATE_INTERVAL, /* duplicateintervaldefault */
+    setprotoopts, /* setprotoopts */
+    getlistenerargs, /* getlistenerargs */
+    udpserverrd, /* listener */
+    NULL, /* connecter */
+    NULL, /* clientconnreader */
+    clientradputudp, /* clientradput */
+    addclientudp, /* addclient */
+    addserverextraudp, /* addserverextra */
+    udpsetsrcres, /* setsrcres */
+    initextraudp /* initextra */
+};
 
 static int client4_sock = -1;
 static int client6_sock = -1;
-static struct queue *server_replyq = NULL;
+static struct gqueue *server_replyq = NULL;
+
+static struct addrinfo *srcres = NULL;
+static uint8_t handle;
+static struct commonprotoopts *protoopts = NULL;
+
+const struct protodefs *udpinit(uint8_t h) {
+    handle = h;
+    return &protodefs;
+}
+
+static void setprotoopts(struct commonprotoopts *opts) {
+    protoopts = opts;
+}
+
+static char **getlistenerargs() {
+    return protoopts ? protoopts->listenargs : NULL;
+}
+
+void udpsetsrcres() {
+    if (!srcres)
+       srcres = resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL, NULL, protodefs.socktype);
+}
+
+void removeudpclientfromreplyq(struct client *c) {
+    struct list_node *n;
+    struct request *r;
+
+    /* lock the common queue and remove replies for this client */
+    pthread_mutex_lock(&c->replyq->mutex);
+    for (n = list_first(c->replyq->entries); n; n = list_next(n)) {
+       r = (struct request *)n->data;
+       if (r->from == c)
+           r->from = NULL;
+    }
+    pthread_mutex_unlock(&c->replyq->mutex);
+}
+
+static int addr_equal(struct sockaddr *a, struct sockaddr *b) {
+    switch (a->sa_family) {
+    case AF_INET:
+       return !memcmp(&((struct sockaddr_in*)a)->sin_addr,
+                      &((struct sockaddr_in*)b)->sin_addr,
+                      sizeof(struct in_addr));
+    case AF_INET6:
+       return IN6_ARE_ADDR_EQUAL(&((struct sockaddr_in6*)a)->sin6_addr,
+                                 &((struct sockaddr_in6*)b)->sin6_addr);
+    default:
+       /* Must not reach */
+       return 0;
+    }
+}
+
+uint16_t port_get(struct sockaddr *sa) {
+    switch (sa->sa_family) {
+    case AF_INET:
+       return ntohs(((struct sockaddr_in *)sa)->sin_port);
+    case AF_INET6:
+       return ntohs(((struct sockaddr_in6 *)sa)->sin6_port);
+    }
+    return 0;
+}
 
 /* exactly one of client and server must be non-NULL */
 /* return who we received from in *client or *server */
@@ -47,7 +140,9 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
     struct clsrvconf *p;
     struct list_node *node;
     fd_set readfds;
-    
+    struct client *c = NULL;
+    struct timeval now;
+
     for (;;) {
        if (rad) {
            free(rad);
@@ -62,37 +157,32 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
            debug(DBG_WARN, "radudpget: recv failed");
            continue;
        }
-       if (cnt < 20) {
-           debug(DBG_WARN, "radudpget: length too small");
-           recv(s, buf, 4, 0);
-           continue;
-       }
-       
+
        p = client
-           ? find_clconf(RAD_UDP, (struct sockaddr *)&from, NULL)
-           : find_srvconf(RAD_UDP, (struct sockaddr *)&from, NULL);
+           ? find_clconf(handle, (struct sockaddr *)&from, NULL)
+           : find_srvconf(handle, (struct sockaddr *)&from, NULL);
        if (!p) {
-           debug(DBG_WARN, "radudpget: got packet from wrong or unknown UDP peer %s, ignoring", addr2string((struct sockaddr *)&from, fromlen));
+           debug(DBG_WARN, "radudpget: got packet from wrong or unknown UDP peer %s, ignoring", addr2string((struct sockaddr *)&from));
            recv(s, buf, 4, 0);
            continue;
        }
-       
+
        len = RADLEN(buf);
        if (len < 20) {
            debug(DBG_WARN, "radudpget: length too small");
            recv(s, buf, 4, 0);
            continue;
        }
-           
+
        rad = malloc(len);
        if (!rad) {
            debug(DBG_ERR, "radudpget: malloc failed");
            recv(s, buf, 4, 0);
            continue;
        }
-       
+
        cnt = recv(s, rad, len, MSG_TRUNC);
-       debug(DBG_DBG, "radudpget: got %d bytes from %s", cnt, addr2string((struct sockaddr *)&from, fromlen));
+       debug(DBG_DBG, "radudpget: got %d bytes from %s", cnt, addr2string((struct sockaddr *)&from));
 
        if (cnt < len) {
            debug(DBG_WARN, "radudpget: packet smaller than length field in radius header");
@@ -102,27 +192,45 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
            debug(DBG_DBG, "radudpget: packet was padded with %d bytes", cnt - len);
 
        if (client) {
+           *client = NULL;
            pthread_mutex_lock(p->lock);
-           for (node = list_first(p->clients); node; node = list_next(node))
-               if (addr_equal((struct sockaddr *)&from, ((struct client *)node->data)->addr))
-                   break;
-           if (node) {
-               *client = (struct client *)node->data;
-               pthread_mutex_unlock(p->lock);
+           for (node = list_first(p->clients); node;) {
+               c = (struct client *)node->data;
+               node = list_next(node);
+               if (s != c->sock)
+                   continue;
+               gettimeofday(&now, NULL);
+               if (!*client && addr_equal((struct sockaddr *)&from, c->addr)) {
+                   c->expiry = now.tv_sec + 60;
+                   *client = c;
+               }
+               if (c->expiry >= now.tv_sec)
+                   continue;
+
+               debug(DBG_DBG, "radudpget: removing expired client (%s)", addr2string(c->addr));
+               removeudpclientfromreplyq(c);
+               c->replyq = NULL; /* stop removeclient() from removing common udp replyq */
+               removelockedclient(c);
                break;
            }
-           fromcopy = addr_copy((struct sockaddr *)&from);
-           if (!fromcopy) {
-               pthread_mutex_unlock(p->lock);
-               continue;
-           }
-           *client = addclient(p, 0);
            if (!*client) {
-               free(fromcopy);
-               pthread_mutex_unlock(p->lock);
-               continue;
+               fromcopy = addr_copy((struct sockaddr *)&from);
+               if (!fromcopy) {
+                   pthread_mutex_unlock(p->lock);
+                   continue;
+               }
+               c = addclient(p, 0);
+               if (!c) {
+                   free(fromcopy);
+                   pthread_mutex_unlock(p->lock);
+                   continue;
+               }
+               c->sock = s;
+               c->addr = fromcopy;
+               gettimeofday(&now, NULL);
+               c->expiry = now.tv_sec + 60;
+               *client = c;
            }
-           (*client)->addr = fromcopy;
            pthread_mutex_unlock(p->lock);
        } else if (server)
            *server = p->servers;
@@ -135,23 +243,13 @@ unsigned char *radudpget(int s, struct client **client, struct server **server,
 
 int clientradputudp(struct server *server, unsigned char *rad) {
     size_t len;
-    struct sockaddr_storage sa;
-    struct sockaddr *sap;
     struct clsrvconf *conf = server->conf;
-    uint16_t port;
-    
+    struct addrinfo *ai;
+
     len = RADLEN(rad);
-    port = port_get(conf->addrinfo->ai_addr);
-    
-    if (*rad == RAD_Accounting_Request) {
-       sap = (struct sockaddr *)&sa;
-       memcpy(sap, conf->addrinfo->ai_addr, conf->addrinfo->ai_addrlen);
-       port_set(sap, ++port);
-    } else
-       sap = conf->addrinfo->ai_addr;
-
-    if (sendto(server->sock, rad, len, 0, sap, conf->addrinfo->ai_addrlen) >= 0) {
-       debug(DBG_DBG, "clienradputudp: sent UDP of length %d to %s port %d", len, conf->host, port);
+    ai = ((struct hostportres *)list_first(conf->hostports)->data)->addrinfo;
+    if (sendto(server->sock, rad, len, 0, ai->ai_addr, ai->ai_addrlen) >= 0) {
+       debug(DBG_DBG, "clienradputudp: sent UDP of length %d to %s port %d", len, addr2string(ai->ai_addr), port_get(ai->ai_addr));
        return 1;
     }
 
@@ -163,7 +261,7 @@ void *udpclientrd(void *arg) {
     struct server *server;
     unsigned char *buf;
     int *s = (int *)arg;
-    
+
     for (;;) {
        server = NULL;
        buf = radudpget(*s, NULL, &server, NULL);
@@ -174,7 +272,7 @@ void *udpclientrd(void *arg) {
 void *udpserverrd(void *arg) {
     struct request *rq;
     int *sp = (int *)arg;
-    
+
     for (;;) {
        rq = newrequest();
        if (!rq) {
@@ -186,13 +284,14 @@ void *udpserverrd(void *arg) {
        radsrv(rq);
     }
     free(sp);
+    return NULL;
 }
 
 void *udpserverwr(void *arg) {
-    struct queue *replyq = (struct queue *)arg;
+    struct gqueue *replyq = (struct gqueue *)arg;
     struct request *reply;
     struct sockaddr_storage to;
-    
+
     for (;;) {
        pthread_mutex_lock(&replyq->mutex);
        while (!(reply = (struct request *)list_shift(replyq->entries))) {
@@ -200,12 +299,16 @@ void *udpserverwr(void *arg) {
            pthread_cond_wait(&replyq->cond, &replyq->mutex);
            debug(DBG_DBG, "udp server writer, got signal");
        }
+       /* do this with lock, udpserverrd may set from = NULL if from expires */
+       if (reply->from)
+           memcpy(&to, reply->from->addr, SOCKADDRP_SIZE(reply->from->addr));
        pthread_mutex_unlock(&replyq->mutex);
-
-       memcpy(&to, reply->from->addr, SOCKADDRP_SIZE(reply->from->addr));
-       port_set((struct sockaddr *)&to, reply->udpport);
-       if (sendto(reply->udpsock, reply->replybuf, RADLEN(reply->replybuf), 0, (struct sockaddr *)&to, SOCKADDR_SIZE(to)) < 0)
-           debug(DBG_WARN, "udpserverwr: send failed");
+       if (reply->from) {
+           port_set((struct sockaddr *)&to, reply->udpport);
+           if (sendto(reply->udpsock, reply->replybuf, RADLEN(reply->replybuf), 0, (struct sockaddr *)&to, SOCKADDR_SIZE(to)) < 0)
+               debug(DBG_WARN, "udpserverwr: send failed");
+       }
+       debug(DBG_DBG, "udpserverwr: refcount %d", reply->refcount);
        freerq(reply);
     }
 }
@@ -215,20 +318,20 @@ void addclientudp(struct client *client) {
 }
 
 void addserverextraudp(struct clsrvconf *conf) {
-    switch (conf->addrinfo->ai_family) {
+    switch (((struct hostportres *)list_first(conf->hostports)->data)->addrinfo->ai_family) {
     case AF_INET:
        if (client4_sock < 0) {
-           client4_sock = bindtoaddr(getsrcprotores(RAD_UDP), AF_INET, 0, 1);
+           client4_sock = bindtoaddr(srcres, AF_INET, 0, 1);
            if (client4_sock < 0)
-               debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->host);
+               debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->name);
        }
        conf->servers->sock = client4_sock;
        break;
     case AF_INET6:
        if (client6_sock < 0) {
-           client6_sock = bindtoaddr(getsrcprotores(RAD_UDP), AF_INET6, 0, 1);
+           client6_sock = bindtoaddr(srcres, AF_INET6, 0, 1);
            if (client6_sock < 0)
-               debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->host);
+               debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->name);
        }
        conf->servers->sock = client6_sock;
        break;
@@ -239,7 +342,12 @@ void addserverextraudp(struct clsrvconf *conf) {
 
 void initextraudp() {
     pthread_t cl4th, cl6th, srvth;
-    
+
+    if (srcres) {
+       freeaddrinfo(srcres);
+       srcres = NULL;
+    }
+
     if (client4_sock >= 0)
        if (pthread_create(&cl4th, NULL, udpclientrd, (void *)&client4_sock))
            debugx(1, DBG_ERR, "pthread_create failed");
@@ -247,9 +355,18 @@ void initextraudp() {
        if (pthread_create(&cl6th, NULL, udpclientrd, (void *)&client6_sock))
            debugx(1, DBG_ERR, "pthread_create failed");
 
-    if (find_clconf_type(RAD_UDP, NULL)) {
+    if (find_clconf_type(handle, NULL)) {
        server_replyq = newqueue();
        if (pthread_create(&srvth, NULL, udpserverwr, (void *)server_replyq))
            debugx(1, DBG_ERR, "pthread_create failed");
     }
 }
+#else
+const struct protodefs *udpinit(uint8_t h) {
+    return NULL;
+}
+#endif
+
+/* Local Variables: */
+/* c-file-style: "stroustrup" */
+/* End: */