Pulled fix from branch_1_1
[freeradius.git] / src / lib / udpfromto.c
1 /*
2  *   This library is free software; you can redistribute it and/or
3  *   modify it under the terms of the GNU Lesser General Public
4  *   License as published by the Free Software Foundation; either
5  *   version 2.1 of the License, or (at your option) any later version.
6  *
7  *   This library is distributed in the hope that it will be useful,
8  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
9  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10  *   Lesser General Public License for more details.
11  *
12  *   You should have received a copy of the GNU Lesser General Public
13  *   License along with this library; if not, write to the Free Software
14  *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
15  *
16  *  Helper functions to get/set addresses of UDP packets
17  *  based on recvfromto by Miquel van Smoorenburg
18  *
19  * recvfromto   Like recvfrom, but also stores the destination
20  *              IP address. Useful on multihomed hosts.
21  *
22  *              Should work on Linux and BSD.
23  *
24  *              Copyright (C) 2002 Miquel van Smoorenburg.
25  *
26  *              This program is free software; you can redistribute it and/or
27  *              modify it under the terms of the GNU Lesser General Public
28  *              License as published by the Free Software Foundation; either
29  *              version 2 of the License, or (at your option) any later version.
30  *      Copyright (C) 2007 Alan DeKok <aland@deployingradius.com>
31  *
32  * sendfromto   added 18/08/2003, Jan Berkel <jan@sitadelle.com>
33  *              Works on Linux and FreeBSD (5.x)
34  *
35  * Version: $Id$
36  */
37
38 #include <freeradius-devel/ident.h>
39 RCSID("$Id$")
40
41 #include <freeradius-devel/udpfromto.h>
42
43 #ifdef WITH_UDPFROMTO
44
45 #ifdef HAVE_SYS_UIO_H
46 #include <sys/uio.h>
47 #endif
48
49 #include <fcntl.h>
50
51 int udpfromto_init(int s)
52 {
53         int proto, flag, opt = 1;
54         struct sockaddr_storage si;
55         socklen_t si_len = sizeof(si);
56
57         errno = ENOSYS;
58
59         proto = -1;
60
61         if (getsockname(s, (struct sockaddr *) &si, &si_len) < 0) {
62                 return -1;
63         }
64
65         if (si.ss_family == AF_INET) {
66 #ifdef HAVE_IP_PKTINFO
67                 /*
68                  *      Linux
69                  */
70                 proto = SOL_IP;
71                 flag = IP_PKTINFO;
72 #endif
73                 
74 #ifdef HAVE_IP_RECVDSTADDR
75                 /*
76                  *      Set the IP_RECVDSTADDR option (BSD).  Note:
77                  *      IP_RECVDSTADDR == IP_SENDSRCADDR
78                  */
79                 proto = IPPROTO_IP;
80                 flag = IP_RECVDSTADDR;
81 #endif
82
83 #ifdef HAVE_AF_INET6
84         } else if (si.ss_family == AF_INET6) {
85 #ifdef HAVE_IN6_PKTINFO
86                 /*
87                  *      This should actually be standard IPv6
88                  */
89                 proto = IPPROTO_IPV6;
90                 flag = IPV6_PKTINFO;
91 #endif
92 #endif
93         } else {
94                 /*
95                  *      Unknown AF.
96                  */
97                 return -1;
98         }
99                 
100         /*
101          *      Unsupported.  Don't worry about it.
102          */
103         if (proto < 0) return 0;
104
105         return setsockopt(s, proto, flag, &opt, sizeof(opt));
106 }
107
108 int recvfromto(int s, void *buf, size_t len, int flags,
109                struct sockaddr *from, socklen_t *fromlen,
110                struct sockaddr *to, socklen_t *tolen)
111 {
112         struct msghdr msgh;
113         struct cmsghdr *cmsg;
114         struct iovec iov;
115         char cbuf[256];
116         int err;
117         struct sockaddr_storage si;
118         socklen_t si_len = sizeof(si);
119
120 #if !defined(HAVE_IP_PKTINFO) && !defined(HAVE_IP_RECVDSTADDR) && !defined (HAVE_IN6_PKTINFO)
121         /*
122          *      If the recvmsg() flags aren't defined, fall back to
123          *      using recvfrom().
124          */
125         to = NULL;
126 #endif
127
128         /*
129          *      Catch the case where the caller passes invalid arguments.
130          */
131         if (!to || !tolen) return recvfrom(s, buf, len, flags, from, fromlen);
132
133
134         /*
135          *      recvmsg doesn't provide sin_port so we have to
136          *      retrieve it using getsockname().
137          */
138         if (getsockname(s, (struct sockaddr *)&si, &si_len) < 0) {
139                 return -1;
140         }
141
142         /*
143          *      Initialize the 'to' address.  It may be INADDR_ANY here,
144          *      with a more specific address given by recvmsg(), below.
145          */
146         if (si.ss_family == AF_INET) {
147                 struct sockaddr_in *dst = (struct sockaddr_in *) to;
148                 struct sockaddr_in *src = (struct sockaddr_in *) &si;
149                 
150                 if (*tolen < sizeof(*dst)) {
151                         errno = EINVAL;
152                         return -1;
153                 }
154                 *tolen = sizeof(*dst);
155                 *dst = *src;
156
157 #if !defined(HAVE_IP_PKTINFO) && !defined(HAVE_IP_RECVDSTADDR)
158                 /*
159                  *      recvmsg() flags aren't defined.  Use recvfrom()
160                  */
161                 return recvfrom(s, buf, len, flags, from, fromlen);
162 #endif
163         }
164
165 #ifdef AF_INET6
166         else if (si.ss_family == AF_INET6) {
167                 struct sockaddr_in6 *dst = (struct sockaddr_in6 *) to;
168                 struct sockaddr_in6 *src = (struct sockaddr_in6 *) &si;
169                 
170                 if (*tolen < sizeof(*dst)) {
171                         errno = EINVAL;
172                         return -1;
173                 }
174                 *tolen = sizeof(*dst);
175                 *dst = *src;
176
177 #if !defined(HAVE_IN6_PKTINFO)
178                 /*
179                  *      recvmsg() flags aren't defined.  Use recvfrom()
180                  */
181                 return recvfrom(s, buf, len, flags, from, fromlen);
182 #endif
183         }
184 #endif
185         /*
186          *      Unknown address family.
187          */             
188         else {
189                 errno = EINVAL;
190                 return -1;
191         }
192
193         /* Set up iov and msgh structures. */
194         memset(&msgh, 0, sizeof(struct msghdr));
195         iov.iov_base = buf;
196         iov.iov_len  = len;
197         msgh.msg_control = cbuf;
198         msgh.msg_controllen = sizeof(cbuf);
199         msgh.msg_name = from;
200         msgh.msg_namelen = fromlen ? *fromlen : 0;
201         msgh.msg_iov  = &iov;
202         msgh.msg_iovlen = 1;
203         msgh.msg_flags = 0;
204
205         /* Receive one packet. */
206         if ((err = recvmsg(s, &msgh, flags)) < 0) {
207                 return err;
208         }
209
210         if (fromlen) *fromlen = msgh.msg_namelen;
211
212         /* Process auxiliary received data in msgh */
213         for (cmsg = CMSG_FIRSTHDR(&msgh);
214              cmsg != NULL;
215              cmsg = CMSG_NXTHDR(&msgh,cmsg)) {
216
217 #ifdef HAVE_IP_PKTINFO
218                 if ((cmsg->cmsg_level == SOL_IP) &&
219                     (cmsg->cmsg_type == IP_PKTINFO)) {
220                         struct in_pktinfo *i =
221                                 (struct in_pktinfo *) CMSG_DATA(cmsg);
222                         ((struct sockaddr_in *)to)->sin_addr = i->ipi_addr;
223                         *tolen = sizeof(struct sockaddr_in);
224                         break;
225                 }
226 #endif
227
228 #ifdef HAVE_IP_RECVDSTADDR
229                 if ((cmsg->cmsg_level == IPPROTO_IP)
230                     (cmsg->cmsg_type == IP_RECVDSTADDR)) {
231                         struct in_addr *i = (struct in_addr *) CMSG_DATA(cmsg);
232                         ((struct sockaddr_in *)to)->sin_addr = *i;
233                         *tolen = sizeof(struct sockaddr_in);
234                         break;
235                 }
236 #endif
237
238 #ifdef HAVE_IN6_PKTINFO
239                 if ((cmsg->cmsg_level == IPPROTO_IPV6) &&
240                     (cmsg->cmsg_type == IPV6_PKTINFO)) {
241                         struct in6_pktinfo *i =
242                                 (struct in6_pktinfo *) CMSG_DATA(cmsg);
243                         ((struct sockaddr_in6 *)to)->sin6_addr = i->ipi6_addr;
244                         *tolen = sizeof(struct sockaddr_in6);
245                         break;
246                 }
247 #endif
248         }
249
250         return err;
251 }
252
253 int sendfromto(int s, void *buf, size_t len, int flags,
254                struct sockaddr *from, socklen_t fromlen,
255                struct sockaddr *to, socklen_t tolen)
256 {
257         struct msghdr msgh;
258         struct cmsghdr *cmsg;
259         struct iovec iov;
260         char cbuf[256];
261
262 #if !defined(HAVE_IP_PKTINFO) && !defined(HAVE_IP_SENDSRCADDR) && !defined(HAVE_IN6_PKTINFO)
263         /*
264          *      If the sendmsg() flags aren't defined, fall back to
265          *      using sendto().
266          */
267         from = NULL;
268 #endif
269
270         /*
271          *      Catch the case where the caller passes invalid arguments.
272          */
273         if (!from || (fromlen == 0) || (from->sa_family == AF_UNSPEC)) {
274                 return sendto(s, buf, len, flags, to, tolen);
275         }
276
277         /* Set up iov and msgh structures. */
278         memset(&msgh, 0, sizeof(struct msghdr));
279         iov.iov_base = buf;
280         iov.iov_len = len;
281         msgh.msg_iov = &iov;
282         msgh.msg_iovlen = 1;
283         msgh.msg_name = to;
284         msgh.msg_namelen = tolen;
285
286         if (from->sa_family == AF_INET) {
287                 struct sockaddr_in *s4 = (struct sockaddr_in *) from;
288
289 #ifdef HAVE_IP_PKTINFO
290                 struct in_pktinfo *pkt;
291
292                 msgh.msg_control = cbuf;
293                 msgh.msg_controllen = CMSG_SPACE(sizeof(*pkt));
294
295                 cmsg = CMSG_FIRSTHDR(&msgh);
296                 cmsg->cmsg_level = SOL_IP;
297                 cmsg->cmsg_type = IP_PKTINFO;
298                 cmsg->cmsg_len = CMSG_LEN(sizeof(*pkt));
299
300                 pkt = (struct in_pktinfo *) CMSG_DATA(cmsg);
301                 memset(pkt, 0, sizeof(*pkt));
302                 pkt->ipi_spec_dst = s4->sin_addr;
303 #endif
304
305 #ifdef HAVE_IP_SENDSRCADDR
306                 struct in_addr *in;
307
308                 msgh.msg_control = cbuf;
309                 msgh.msg_controllen = CMSG_SPACE(sizeof(*in));
310
311                 cmsg = CMSG_FIRSTHDR(&msgh);
312                 cmsg->cmsg_level = IPPROTO_IP;
313                 cmsg->cmsg_type = IP_SENDSRCADDR;
314                 cmsg->cmsg_len = CMSG_LEN(sizeof(*in));
315
316                 in = (struct in_addr *) CMSG_DATA(cmsg);
317                 *in = s4->sin_addr;
318 #endif
319         }
320
321 #ifdef AF_INET6
322         else if (from->sa_family == AF_INET6) {
323 #ifdef HAVE_IN6_PKTINFO
324                 struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) from;
325
326                 struct in6_pktinfo *pkt;
327
328                 msgh.msg_control = cbuf;
329                 msgh.msg_controllen = CMSG_SPACE(sizeof(*pkt));
330
331                 cmsg = CMSG_FIRSTHDR(&msgh);
332                 cmsg->cmsg_level = IPPROTO_IPV6;
333                 cmsg->cmsg_type = IPV6_PKTINFO;
334                 cmsg->cmsg_len = CMSG_LEN(sizeof(*pkt));
335
336                 pkt = (struct in6_pktinfo *) CMSG_DATA(cmsg);
337                 memset(pkt, 0, sizeof(*pkt));
338                 pkt->ipi6_addr = s6->sin6_addr;
339 #endif
340         }
341 #endif
342
343         /*
344          *      Unknown address family.
345          */             
346         else {
347                 errno = EINVAL;
348                 return -1;
349         }
350
351         return sendmsg(s, &msgh, flags);
352 }
353
354
355 #ifdef TESTING
356 /*
357  *      Small test program to test recvfromto/sendfromto
358  *
359  *      use a virtual IP address as first argument to test
360  *
361  *      reply packet should originate from virtual IP and not
362  *      from the default interface the alias is bound to
363  */
364
365 #include <stdio.h>
366 #include <stdlib.h>
367 #include <arpa/inet.h>
368 #include <sys/types.h>
369 #include <sys/wait.h>
370
371 #define DEF_PORT 20000          /* default port to listen on */
372 #define DESTIP "127.0.0.1"      /* send packet to localhost per default */
373 #define TESTSTRING "foo"        /* what to send */
374 #define TESTLEN 4                       /* 4 bytes */
375
376 int main(int argc, char **argv)
377 {
378         struct sockaddr_in from, to, in;
379         char buf[TESTLEN];
380         char *destip = DESTIP;
381         int port = DEF_PORT;
382         int n, server_socket, client_socket, fl, tl, pid;
383
384         if (argc > 1) destip = argv[1];
385         if (argc > 2) port = atoi(argv[2]);
386
387         in.sin_family = AF_INET;
388         in.sin_addr.s_addr = INADDR_ANY;
389         in.sin_port = htons(port);
390         fl = tl = sizeof(struct sockaddr_in);
391         memset(&from, 0, sizeof(from));
392         memset(&to,   0, sizeof(to));
393
394         switch(pid = fork()) {
395                 case -1:
396                         perror("fork");
397                         return 0;
398                 case 0:
399                         /* child */
400                         usleep(100000);
401                         goto client;
402         }
403
404         /* parent: server */
405         server_socket = socket(PF_INET, SOCK_DGRAM, 0);
406         if (udpfromto_init(server_socket) != 0) {
407                 perror("udpfromto_init\n");
408                 waitpid(pid, NULL, WNOHANG);
409                 return 0;
410         }
411
412         if (bind(server_socket, (struct sockaddr *)&in, sizeof(in)) < 0) {
413                 perror("server: bind");
414                 waitpid(pid, NULL, WNOHANG);
415                 return 0;
416         }
417
418         printf("server: waiting for packets on INADDR_ANY:%d\n", port);
419         if ((n = recvfromto(server_socket, buf, sizeof(buf), 0,
420             (struct sockaddr *)&from, &fl,
421             (struct sockaddr *)&to, &tl)) < 0) {
422                 perror("server: recvfromto");
423                 waitpid(pid, NULL, WNOHANG);
424                 return 0;
425         }
426
427         printf("server: received a packet of %d bytes [%s] ", n, buf);
428         printf("(src ip:port %s:%d ",
429                 inet_ntoa(from.sin_addr), ntohs(from.sin_port));
430         printf(" dst ip:port %s:%d)\n",
431                 inet_ntoa(to.sin_addr), ntohs(to.sin_port));
432
433         printf("server: replying from address packet was received on to source address\n");
434
435         if ((n = sendfromto(server_socket, buf, n, 0,
436                 (struct sockaddr *)&to, tl,
437                 (struct sockaddr *)&from, fl)) < 0) {
438                 perror("server: sendfromto");
439         }
440
441         waitpid(pid, NULL, 0);
442         return 0;
443
444 client:
445         close(server_socket);
446         client_socket = socket(PF_INET, SOCK_DGRAM, 0);
447         if (udpfromto_init(client_socket) != 0) {
448                 perror("udpfromto_init");
449                 _exit(0);
450         }
451         /* bind client on different port */
452         in.sin_port = htons(port+1);
453         if (bind(client_socket, (struct sockaddr *)&in, sizeof(in)) < 0) {
454                 perror("client: bind");
455                 _exit(0);
456         }
457
458         in.sin_port = htons(port);
459         in.sin_addr.s_addr = inet_addr(destip);
460
461         printf("client: sending packet to %s:%d\n", destip, port);
462         if (sendto(client_socket, TESTSTRING, TESTLEN, 0,
463                         (struct sockaddr *)&in, sizeof(in)) < 0) {
464                 perror("client: sendto");
465                 _exit(0);
466         }
467
468         printf("client: waiting for reply from server on INADDR_ANY:%d\n", port+1);
469
470         if ((n = recvfromto(client_socket, buf, sizeof(buf), 0,
471             (struct sockaddr *)&from, &fl,
472             (struct sockaddr *)&to, &tl)) < 0) {
473                 perror("client: recvfromto");
474                 _exit(0);
475         }
476
477         printf("client: received a packet of %d bytes [%s] ", n, buf);
478         printf("(src ip:port %s:%d",
479                 inet_ntoa(from.sin_addr), ntohs(from.sin_port));
480         printf(" dst ip:port %s:%d)\n",
481                 inet_ntoa(to.sin_addr), ntohs(to.sin_port));
482
483         _exit(0);
484 }
485
486 #endif /* TESTING */
487 #endif /* WITH_UDPFROMTO */