a6b8bc2266a15c4c4104d29355a70d7bd2046c16
[radsecproxy.git] / udp.c
1 /* Copyright (c) 2006-2009, Stig Venaas, UNINETT AS.
2  * Copyright (c) 2010, UNINETT AS, NORDUnet A/S.
3  * Copyright (c) 2010-2012, NORDUnet A/S. */
4 /* See LICENSE for licensing information. */
5
6 #include <signal.h>
7 #include <sys/socket.h>
8 #include <netinet/in.h>
9 #include <netdb.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <limits.h>
13 #ifdef SYS_SOLARIS9
14 #include <fcntl.h>
15 #endif
16 #include <sys/time.h>
17 #include <sys/types.h>
18 #include <sys/select.h>
19 #include <ctype.h>
20 #include <sys/wait.h>
21 #include <arpa/inet.h>
22 #include <regex.h>
23 #include <pthread.h>
24 #include <assert.h>
25 #include "radsecproxy.h"
26 #include "hostport.h"
27
28 #ifdef RADPROT_UDP
29 #include "debug.h"
30 #include "util.h"
31
32 static void setprotoopts(struct commonprotoopts *opts);
33 static char **getlistenerargs();
34 void *udpserverrd(void *arg);
35 int clientradputudp(struct server *server, unsigned char *rad);
36 void addclientudp(struct client *client);
37 void addserverextraudp(struct clsrvconf *conf);
38 void udpsetsrcres();
39 void initextraudp();
40
41 static const struct protodefs protodefs = {
42     "udp",
43     NULL, /* secretdefault */
44     SOCK_DGRAM, /* socktype */
45     "1812", /* portdefault */
46     REQUEST_RETRY_COUNT, /* retrycountdefault */
47     10, /* retrycountmax */
48     REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
49     60, /* retryintervalmax */
50     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
51     setprotoopts, /* setprotoopts */
52     getlistenerargs, /* getlistenerargs */
53     udpserverrd, /* listener */
54     NULL, /* connecter */
55     NULL, /* clientconnreader */
56     clientradputudp, /* clientradput */
57     addclientudp, /* addclient */
58     addserverextraudp, /* addserverextra */
59     udpsetsrcres, /* setsrcres */
60     initextraudp /* initextra */
61 };
62
63 static int client4_sock = -1;
64 static int client6_sock = -1;
65 static struct gqueue *server_replyq = NULL;
66
67 static struct addrinfo *srcres = NULL;
68 static uint8_t handle;
69 static struct commonprotoopts *protoopts = NULL;
70
71 const struct protodefs *udpinit(uint8_t h) {
72     handle = h;
73     return &protodefs;
74 }
75
76 static void setprotoopts(struct commonprotoopts *opts) {
77     protoopts = opts;
78 }
79
80 static char **getlistenerargs() {
81     return protoopts ? protoopts->listenargs : NULL;
82 }
83
84 void udpsetsrcres() {
85     if (!srcres)
86         srcres =
87             resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL,
88                                    AF_UNSPEC, NULL, protodefs.socktype);
89 }
90
91 void removeudpclientfromreplyq(struct client *c) {
92     struct list_node *n;
93     struct request *r;
94
95     /* lock the common queue and remove replies for this client */
96     pthread_mutex_lock(&c->replyq->mutex);
97     for (n = list_first(c->replyq->entries); n; n = list_next(n)) {
98         r = (struct request *)n->data;
99         if (r->from == c)
100             r->from = NULL;
101     }
102     pthread_mutex_unlock(&c->replyq->mutex);
103 }
104
105 static int addr_equal(struct sockaddr *a, struct sockaddr *b) {
106     switch (a->sa_family) {
107     case AF_INET:
108         return !memcmp(&((struct sockaddr_in*)a)->sin_addr,
109                        &((struct sockaddr_in*)b)->sin_addr,
110                        sizeof(struct in_addr));
111     case AF_INET6:
112         return IN6_ARE_ADDR_EQUAL(&((struct sockaddr_in6*)a)->sin6_addr,
113                                   &((struct sockaddr_in6*)b)->sin6_addr);
114     default:
115         /* Must not reach */
116         return 0;
117     }
118 }
119
120 uint16_t port_get(struct sockaddr *sa) {
121     switch (sa->sa_family) {
122     case AF_INET:
123         return ntohs(((struct sockaddr_in *)sa)->sin_port);
124     case AF_INET6:
125         return ntohs(((struct sockaddr_in6 *)sa)->sin6_port);
126     }
127     return 0;
128 }
129
130 /* exactly one of client and server must be non-NULL */
131 /* return who we received from in *client or *server */
132 /* return from in sa if not NULL */
133 unsigned char *radudpget(int s, struct client **client, struct server **server, uint16_t *port) {
134     int cnt, len;
135     unsigned char buf[4], *rad = NULL;
136     struct sockaddr_storage from;
137     struct sockaddr *fromcopy;
138     socklen_t fromlen = sizeof(from);
139     struct clsrvconf *p;
140     struct list_node *node;
141     fd_set readfds;
142     struct client *c = NULL;
143     struct timeval now;
144
145     for (;;) {
146         if (rad) {
147             free(rad);
148             rad = NULL;
149         }
150         FD_ZERO(&readfds);
151         FD_SET(s, &readfds);
152         if (select(s + 1, &readfds, NULL, NULL, NULL) < 1)
153             continue;
154         cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
155         if (cnt == -1) {
156             debug(DBG_WARN, "radudpget: recv failed");
157             continue;
158         }
159
160         p = client
161             ? find_clconf(handle, (struct sockaddr *)&from, NULL)
162             : find_srvconf(handle, (struct sockaddr *)&from, NULL);
163         if (!p) {
164             debug(DBG_WARN, "radudpget: got packet from wrong or unknown UDP peer %s, ignoring", addr2string((struct sockaddr *)&from));
165             recv(s, buf, 4, 0);
166             continue;
167         }
168
169         len = RADLEN(buf);
170         if (len < 20) {
171             debug(DBG_WARN, "radudpget: length too small");
172             recv(s, buf, 4, 0);
173             continue;
174         }
175
176         rad = malloc(len);
177         if (!rad) {
178             debug(DBG_ERR, "radudpget: malloc failed");
179             recv(s, buf, 4, 0);
180             continue;
181         }
182
183         cnt = recv(s, rad, len, MSG_TRUNC);
184         debug(DBG_DBG, "radudpget: got %d bytes from %s", cnt, addr2string((struct sockaddr *)&from));
185
186         if (cnt < len) {
187             debug(DBG_WARN, "radudpget: packet smaller than length field in radius header");
188             continue;
189         }
190         if (cnt > len)
191             debug(DBG_DBG, "radudpget: packet was padded with %d bytes", cnt - len);
192
193         if (client) {
194             *client = NULL;
195             pthread_mutex_lock(p->lock);
196             for (node = list_first(p->clients); node;) {
197                 c = (struct client *)node->data;
198                 node = list_next(node);
199                 if (s != c->sock)
200                     continue;
201                 gettimeofday(&now, NULL);
202                 if (!*client && addr_equal((struct sockaddr *)&from, c->addr)) {
203                     c->expiry = now.tv_sec + 60;
204                     *client = c;
205                 }
206                 if (c->expiry >= now.tv_sec)
207                     continue;
208
209                 debug(DBG_DBG, "radudpget: removing expired client (%s)", addr2string(c->addr));
210                 removeudpclientfromreplyq(c);
211                 c->replyq = NULL; /* stop removeclient() from removing common udp replyq */
212                 removelockedclient(c);
213                 break;
214             }
215             if (!*client) {
216                 fromcopy = addr_copy((struct sockaddr *)&from);
217                 if (!fromcopy) {
218                     pthread_mutex_unlock(p->lock);
219                     continue;
220                 }
221                 c = addclient(p, 0);
222                 if (!c) {
223                     free(fromcopy);
224                     pthread_mutex_unlock(p->lock);
225                     continue;
226                 }
227                 c->sock = s;
228                 c->addr = fromcopy;
229                 gettimeofday(&now, NULL);
230                 c->expiry = now.tv_sec + 60;
231                 *client = c;
232             }
233             pthread_mutex_unlock(p->lock);
234         } else if (server)
235             *server = p->servers;
236         break;
237     }
238     if (port)
239         *port = port_get((struct sockaddr *)&from);
240     return rad;
241 }
242
243 int clientradputudp(struct server *server, unsigned char *rad) {
244     size_t len;
245     struct clsrvconf *conf = server->conf;
246     struct addrinfo *ai;
247
248     len = RADLEN(rad);
249     ai = ((struct hostportres *)list_first(conf->hostports)->data)->addrinfo;
250     if (sendto(server->sock, rad, len, 0, ai->ai_addr, ai->ai_addrlen) >= 0) {
251         debug(DBG_DBG, "clienradputudp: sent UDP of length %d to %s port %d", len, addr2string(ai->ai_addr), port_get(ai->ai_addr));
252         return 1;
253     }
254
255     debug(DBG_WARN, "clientradputudp: send failed");
256     return 0;
257 }
258
259 void *udpclientrd(void *arg) {
260     struct server *server;
261     unsigned char *buf;
262     int *s = (int *)arg;
263
264     for (;;) {
265         server = NULL;
266         buf = radudpget(*s, NULL, &server, NULL);
267         replyh(server, buf);
268     }
269 }
270
271 void *udpserverrd(void *arg) {
272     struct request *rq;
273     int *sp = (int *)arg;
274
275     for (;;) {
276         rq = newrequest();
277         if (!rq) {
278             sleep(5); /* malloc failed */
279             continue;
280         }
281         rq->buf = radudpget(*sp, &rq->from, NULL, &rq->udpport);
282         rq->udpsock = *sp;
283         radsrv(rq);
284     }
285     free(sp);
286     return NULL;
287 }
288
289 void *udpserverwr(void *arg) {
290     struct gqueue *replyq = (struct gqueue *)arg;
291     struct request *reply;
292     struct sockaddr_storage to;
293
294     for (;;) {
295         pthread_mutex_lock(&replyq->mutex);
296         while (!(reply = (struct request *)list_shift(replyq->entries))) {
297             debug(DBG_DBG, "udp server writer, waiting for signal");
298             pthread_cond_wait(&replyq->cond, &replyq->mutex);
299             debug(DBG_DBG, "udp server writer, got signal");
300         }
301         /* do this with lock, udpserverrd may set from = NULL if from expires */
302         if (reply->from)
303             memcpy(&to, reply->from->addr, SOCKADDRP_SIZE(reply->from->addr));
304         pthread_mutex_unlock(&replyq->mutex);
305         if (reply->from) {
306             port_set((struct sockaddr *)&to, reply->udpport);
307             if (sendto(reply->udpsock, reply->replybuf, RADLEN(reply->replybuf), 0, (struct sockaddr *)&to, SOCKADDR_SIZE(to)) < 0)
308                 debug(DBG_WARN, "udpserverwr: send failed");
309         }
310         debug(DBG_DBG, "udpserverwr: refcount %d", reply->refcount);
311         freerq(reply);
312     }
313 }
314
315 void addclientudp(struct client *client) {
316     client->replyq = server_replyq;
317 }
318
319 void addserverextraudp(struct clsrvconf *conf) {
320     assert(list_first(conf->hostports) != NULL);
321     switch (((struct hostportres *)list_first(conf->hostports)->data)->addrinfo->ai_family) {
322     case AF_INET:
323         if (client4_sock < 0) {
324             client4_sock = bindtoaddr(srcres, AF_INET, 0, 1);
325             if (client4_sock < 0)
326                 debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->name);
327         }
328         conf->servers->sock = client4_sock;
329         break;
330     case AF_INET6:
331         if (client6_sock < 0) {
332             client6_sock = bindtoaddr(srcres, AF_INET6, 0, 1);
333             if (client6_sock < 0)
334                 debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->name);
335         }
336         conf->servers->sock = client6_sock;
337         break;
338     default:
339         debugx(1, DBG_ERR, "addserver: unsupported address family");
340     }
341 }
342
343 void initextraudp() {
344     pthread_t cl4th, cl6th, srvth;
345
346     if (srcres) {
347         freeaddrinfo(srcres);
348         srcres = NULL;
349     }
350
351     if (client4_sock >= 0)
352         if (pthread_create(&cl4th, NULL, udpclientrd, (void *)&client4_sock))
353             debugx(1, DBG_ERR, "pthread_create failed");
354     if (client6_sock >= 0)
355         if (pthread_create(&cl6th, NULL, udpclientrd, (void *)&client6_sock))
356             debugx(1, DBG_ERR, "pthread_create failed");
357
358     if (find_clconf_type(handle, NULL)) {
359         server_replyq = newqueue();
360         if (pthread_create(&srvth, NULL, udpserverwr, (void *)server_replyq))
361             debugx(1, DBG_ERR, "pthread_create failed");
362     }
363 }
364 #else
365 const struct protodefs *udpinit(uint8_t h) {
366     return NULL;
367 }
368 #endif
369
370 /* Local Variables: */
371 /* c-file-style: "stroustrup" */
372 /* End: */