cleaning up code
[radsecproxy.git] / udp.c
1 /*
2  * Copyright (C) 2006-2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include "debug.h"
29 #include "list.h"
30 #include "util.h"
31 #include "radsecproxy.h"
32
33 static void setprotoopts(struct commonprotoopts *opts);
34 static char **getlistenerargs();
35 void *udpserverrd(void *arg);
36 int clientradputudp(struct server *server, unsigned char *rad);
37 void addclientudp(struct client *client);
38 void addserverextraudp(struct clsrvconf *conf);
39 void udpsetsrcres();
40 void initextraudp();
41
42 static const struct protodefs protodefs = {
43     "udp",
44     NULL, /* secretdefault */
45     SOCK_DGRAM, /* socktype */
46     "1812", /* portdefault */
47     REQUEST_RETRY_COUNT, /* retrycountdefault */
48     10, /* retrycountmax */
49     REQUEST_RETRY_INTERVAL, /* retryintervaldefault */
50     60, /* retryintervalmax */
51     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
52     setprotoopts, /* setprotoopts */
53     getlistenerargs, /* getlistenerargs */
54     udpserverrd, /* listener */
55     NULL, /* connecter */
56     NULL, /* clientconnreader */
57     clientradputudp, /* clientradput */
58     addclientudp, /* addclient */
59     addserverextraudp, /* addserverextra */
60     udpsetsrcres, /* setsrcres */
61     initextraudp /* initextra */
62 };
63
64 static int client4_sock = -1;
65 static int client6_sock = -1;
66 static struct queue *server_replyq = NULL;
67
68 static struct addrinfo *srcres = NULL;
69 static uint8_t handle;
70 static struct commonprotoopts *protoopts = NULL;
71
72 const struct protodefs *udpinit(uint8_t h) {
73     handle = h;
74     return &protodefs;
75 }
76
77 static void setprotoopts(struct commonprotoopts *opts) {
78     protoopts = opts;
79 }
80
81 static char **getlistenerargs() {
82     return protoopts ? protoopts->listenargs : NULL;
83 }
84
85 void udpsetsrcres() {
86     if (!srcres)
87         srcres = resolve_hostport_addrinfo(handle, protoopts ? protoopts->sourcearg : NULL);
88 }
89
90 void removeudpclientfromreplyq(struct client *c) {
91     struct list_node *n;
92     struct request *r;
93     
94     /* lock the common queue and remove replies for this client */
95     pthread_mutex_lock(&c->replyq->mutex);
96     for (n = list_first(c->replyq->entries); n; n = list_next(n)) {
97         r = (struct request *)n->data;
98         if (r->from == c)
99             r->from = NULL;
100     }
101     pthread_mutex_unlock(&c->replyq->mutex);
102 }       
103
104 /* exactly one of client and server must be non-NULL */
105 /* return who we received from in *client or *server */
106 /* return from in sa if not NULL */
107 unsigned char *radudpget(int s, struct client **client, struct server **server, uint16_t *port) {
108     int cnt, len;
109     unsigned char buf[4], *rad = NULL;
110     struct sockaddr_storage from;
111     struct sockaddr *fromcopy;
112     socklen_t fromlen = sizeof(from);
113     struct clsrvconf *p;
114     struct list_node *node;
115     fd_set readfds;
116     struct client *c = NULL;
117     struct timeval now;
118     
119     for (;;) {
120         if (rad) {
121             free(rad);
122             rad = NULL;
123         }
124         FD_ZERO(&readfds);
125         FD_SET(s, &readfds);
126         if (select(s + 1, &readfds, NULL, NULL, NULL) < 1)
127             continue;
128         cnt = recvfrom(s, buf, 4, MSG_PEEK | MSG_TRUNC, (struct sockaddr *)&from, &fromlen);
129         if (cnt == -1) {
130             debug(DBG_WARN, "radudpget: recv failed");
131             continue;
132         }
133         if (cnt < 20) {
134             debug(DBG_WARN, "radudpget: length too small");
135             recv(s, buf, 4, 0);
136             continue;
137         }
138         
139         p = client
140             ? find_clconf(handle, (struct sockaddr *)&from, NULL)
141             : find_srvconf(handle, (struct sockaddr *)&from, NULL);
142         if (!p) {
143             debug(DBG_WARN, "radudpget: got packet from wrong or unknown UDP peer %s, ignoring", addr2string((struct sockaddr *)&from));
144             recv(s, buf, 4, 0);
145             continue;
146         }
147         
148         len = RADLEN(buf);
149         if (len < 20) {
150             debug(DBG_WARN, "radudpget: length too small");
151             recv(s, buf, 4, 0);
152             continue;
153         }
154             
155         rad = malloc(len);
156         if (!rad) {
157             debug(DBG_ERR, "radudpget: malloc failed");
158             recv(s, buf, 4, 0);
159             continue;
160         }
161         
162         cnt = recv(s, rad, len, MSG_TRUNC);
163         debug(DBG_DBG, "radudpget: got %d bytes from %s", cnt, addr2string((struct sockaddr *)&from));
164
165         if (cnt < len) {
166             debug(DBG_WARN, "radudpget: packet smaller than length field in radius header");
167             continue;
168         }
169         if (cnt > len)
170             debug(DBG_DBG, "radudpget: packet was padded with %d bytes", cnt - len);
171
172         if (client) {
173             *client = NULL;
174             pthread_mutex_lock(p->lock);
175             for (node = list_first(p->clients); node;) {
176                 c = (struct client *)node->data;
177                 node = list_next(node);
178                 if (s != c->sock)
179                     continue;
180                 gettimeofday(&now, NULL);
181                 if (!*client && addr_equal((struct sockaddr *)&from, c->addr)) {
182                     c->expiry = now.tv_sec + 60;
183                     *client = c;
184                 }
185                 if (c->expiry >= now.tv_sec)
186                     continue;
187                 
188                 debug(DBG_DBG, "radudpget: removing expired client (%s)", addr2string(c->addr));
189                 removeudpclientfromreplyq(c);
190                 c->replyq = NULL; /* stop removeclient() from removing common udp replyq */
191                 removelockedclient(c);
192                 break;
193             }
194             if (!*client) {
195                 fromcopy = addr_copy((struct sockaddr *)&from);
196                 if (!fromcopy) {
197                     pthread_mutex_unlock(p->lock);
198                     continue;
199                 }
200                 c = addclient(p, 0);
201                 if (!c) {
202                     free(fromcopy);
203                     pthread_mutex_unlock(p->lock);
204                     continue;
205                 }
206                 c->sock = s;
207                 c->addr = fromcopy;
208                 gettimeofday(&now, NULL);
209                 c->expiry = now.tv_sec + 60;
210                 *client = c;
211             }
212             pthread_mutex_unlock(p->lock);
213         } else if (server)
214             *server = p->servers;
215         break;
216     }
217     if (port)
218         *port = port_get((struct sockaddr *)&from);
219     return rad;
220 }
221
222 int clientradputudp(struct server *server, unsigned char *rad) {
223     size_t len;
224     struct clsrvconf *conf = server->conf;
225     
226     len = RADLEN(rad);
227     if (sendto(server->sock, rad, len, 0, conf->addrinfo->ai_addr, conf->addrinfo->ai_addrlen) >= 0) {
228         debug(DBG_DBG, "clienradputudp: sent UDP of length %d to %s port %d", len, conf->host, port_get(conf->addrinfo->ai_addr));
229         return 1;
230     }
231
232     debug(DBG_WARN, "clientradputudp: send failed");
233     return 0;
234 }
235
236 void *udpclientrd(void *arg) {
237     struct server *server;
238     unsigned char *buf;
239     int *s = (int *)arg;
240     
241     for (;;) {
242         server = NULL;
243         buf = radudpget(*s, NULL, &server, NULL);
244         replyh(server, buf);
245     }
246 }
247
248 void *udpserverrd(void *arg) {
249     struct request *rq;
250     int *sp = (int *)arg;
251     
252     for (;;) {
253         rq = newrequest();
254         if (!rq) {
255             sleep(5); /* malloc failed */
256             continue;
257         }
258         rq->buf = radudpget(*sp, &rq->from, NULL, &rq->udpport);
259         rq->udpsock = *sp;
260         radsrv(rq);
261     }
262     free(sp);
263 }
264
265 void *udpserverwr(void *arg) {
266     struct queue *replyq = (struct queue *)arg;
267     struct request *reply;
268     struct sockaddr_storage to;
269     
270     for (;;) {
271         pthread_mutex_lock(&replyq->mutex);
272         while (!(reply = (struct request *)list_shift(replyq->entries))) {
273             debug(DBG_DBG, "udp server writer, waiting for signal");
274             pthread_cond_wait(&replyq->cond, &replyq->mutex);
275             debug(DBG_DBG, "udp server writer, got signal");
276         }
277         /* do this with lock, udpserverrd may set from = NULL if from expires */
278         if (reply->from)
279             memcpy(&to, reply->from->addr, SOCKADDRP_SIZE(reply->from->addr));
280         pthread_mutex_unlock(&replyq->mutex);
281         if (reply->from) {
282             port_set((struct sockaddr *)&to, reply->udpport);
283             if (sendto(reply->udpsock, reply->replybuf, RADLEN(reply->replybuf), 0, (struct sockaddr *)&to, SOCKADDR_SIZE(to)) < 0)
284                 debug(DBG_WARN, "udpserverwr: send failed");
285         }
286         debug(DBG_DBG, "udpserverwr: refcount %d", reply->refcount);
287         freerq(reply);
288     }
289 }
290
291 void addclientudp(struct client *client) {
292     client->replyq = server_replyq;
293 }
294
295 void addserverextraudp(struct clsrvconf *conf) {
296     switch (conf->addrinfo->ai_family) {
297     case AF_INET:
298         if (client4_sock < 0) {
299             client4_sock = bindtoaddr(srcres, AF_INET, 0, 1);
300             if (client4_sock < 0)
301                 debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->host);
302         }
303         conf->servers->sock = client4_sock;
304         break;
305     case AF_INET6:
306         if (client6_sock < 0) {
307             client6_sock = bindtoaddr(srcres, AF_INET6, 0, 1);
308             if (client6_sock < 0)
309                 debugx(1, DBG_ERR, "addserver: failed to create client socket for server %s", conf->host);
310         }
311         conf->servers->sock = client6_sock;
312         break;
313     default:
314         debugx(1, DBG_ERR, "addserver: unsupported address family");
315     }
316 }
317
318 void initextraudp() {
319     pthread_t cl4th, cl6th, srvth;
320
321     if (srcres) {
322         freeaddrinfo(srcres);
323         srcres = NULL;
324     }
325     
326     if (client4_sock >= 0)
327         if (pthread_create(&cl4th, NULL, udpclientrd, (void *)&client4_sock))
328             debugx(1, DBG_ERR, "pthread_create failed");
329     if (client6_sock >= 0)
330         if (pthread_create(&cl6th, NULL, udpclientrd, (void *)&client6_sock))
331             debugx(1, DBG_ERR, "pthread_create failed");
332
333     if (find_clconf_type(handle, NULL)) {
334         server_replyq = newqueue();
335         if (pthread_create(&srvth, NULL, udpserverwr, (void *)server_replyq))
336             debugx(1, DBG_ERR, "pthread_create failed");
337     }
338 }