Merge pull request #87 from armitasp/master
[freeradius.git] / src / lib / tcp.c
1 /*
2  * tcp.c        TCP-specific functions.
3  *
4  * Version:     $Id$
5  *
6  *   This program is free software; you can redistribute it and/or modify
7  *   it under the terms of the GNU General Public License as published by
8  *   the Free Software Foundation; either version 2 of the License, or
9  *   (at your option) any later version.
10  *
11  *   This program is distributed in the hope that it will be useful,
12  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *   GNU General Public License for more details.
15  *
16  *   You should have received a copy of the GNU General Public License
17  *   along with this program; if not, write to the Free Software
18  *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  *
20  * Copyright (C) 2009 Dante http://dante.net
21  */
22
23 #include        <freeradius-devel/ident.h>
24 RCSID("$Id$")
25
26 #include        <freeradius-devel/libradius.h>
27 #include        <freeradius-devel/tcp.h>
28
29 #ifdef WITH_TCP
30
31 /* FIXME: into common RADIUS header? */
32 #define MAX_PACKET_LEN 4096
33
34 /*
35  *      Open a socket on the given IP and port.
36  */
37 int fr_tcp_socket(fr_ipaddr_t *ipaddr, int port)
38 {
39         int sockfd;
40         int on = 1;
41         struct sockaddr_storage salocal;
42         socklen_t       salen;
43
44         if ((port < 0) || (port > 65535)) {
45                 fr_strerror_printf("Port %d is out of allowed bounds", port);
46                 return -1;
47         }
48
49         sockfd = socket(ipaddr->af, SOCK_STREAM, 0);
50         if (sockfd < 0) {
51                 return sockfd;
52         }
53
54         if (fr_nonblock(sockfd) < 0) {
55                 close(sockfd);
56                 return -1;
57         }
58
59         if (!fr_ipaddr2sockaddr(ipaddr, port, &salocal, &salen)) {
60                 close(sockfd);
61                 return -1;
62         }
63
64 #ifdef HAVE_STRUCT_SOCKADDR_IN6
65         if (ipaddr->af == AF_INET6) {
66                 /*
67                  *      Listening on '::' does NOT get you IPv4 to
68                  *      IPv6 mapping.  You've got to listen on an IPv4
69                  *      address, too.  This makes the rest of the server
70                  *      design a little simpler.
71                  */
72 #ifdef IPV6_V6ONLY
73
74                 if (IN6_IS_ADDR_UNSPECIFIED(&ipaddr->ipaddr.ip6addr)) {
75                         setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
76                                    (char *)&on, sizeof(on));
77                 }
78 #endif /* IPV6_V6ONLY */
79         }
80 #endif /* HAVE_STRUCT_SOCKADDR_IN6 */
81
82         if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
83                 fr_strerror_printf("Failed in setsockopt(): %s", strerror(errno));
84                 close(sockfd);
85                 return -1;
86         }
87  
88         if (bind(sockfd, (struct sockaddr *) &salocal, salen) < 0) {
89                 fr_strerror_printf("Failed in bind(): %s", strerror(errno));
90                 close(sockfd);
91                 return -1;
92         }
93
94         if (listen(sockfd, 8) < 0) {
95                 fr_strerror_printf("Failed in listen(): %s", strerror(errno));
96                 close(sockfd);
97                 return -1;
98         }
99
100         return sockfd;
101 }
102
103
104 /*
105  *      Open a socket TO the given IP and port.
106  */
107 int fr_tcp_client_socket(fr_ipaddr_t *src_ipaddr,
108                          fr_ipaddr_t *dst_ipaddr, int dst_port)
109 {
110         int sockfd;
111         struct sockaddr_storage salocal;
112         socklen_t       salen;
113
114         if ((dst_port < 0) || (dst_port > 65535)) {
115                 fr_strerror_printf("Port %d is out of allowed bounds",
116                                    dst_port);
117                 return -1;
118         }
119
120         if (!dst_ipaddr) return -1;
121
122         sockfd = socket(dst_ipaddr->af, SOCK_STREAM, 0);
123         if (sockfd < 0) {
124                 return sockfd;
125         }
126
127 #if 0
128 #ifdef O_NONBLOCK
129         {
130                 int flags;
131                 
132                 if ((flags = fcntl(sockfd, F_GETFL, NULL)) < 0)  {
133                         fr_strerror_printf("Failure getting socket flags: %s",
134                                    strerror(errno));
135                         close(sockfd);
136                         return -1;
137                 }
138                 
139                 flags |= O_NONBLOCK;
140                 if( fcntl(sockfd, F_SETFL, flags) < 0) {
141                         fr_strerror_printf("Failure setting socket flags: %s",
142                                    strerror(errno));
143                         close(sockfd);
144                         return -1;
145                 }
146         }
147 #endif
148 #endif
149         /*
150          *      Allow the caller to bind us to a specific source IP.
151          */
152         if (src_ipaddr && (src_ipaddr->af != AF_UNSPEC)) {
153                 if (!fr_ipaddr2sockaddr(src_ipaddr, 0, &salocal, &salen)) {
154                         close(sockfd);
155                         return -1;
156                 }
157                 
158                 if (bind(sockfd, (struct sockaddr *) &salocal, salen) < 0) {
159                         fr_strerror_printf("Failure binding to IP: %s",
160                                            strerror(errno));
161                         close(sockfd);
162                         return -1;
163                 }
164         }
165
166         if (!fr_ipaddr2sockaddr(dst_ipaddr, dst_port, &salocal, &salen)) {
167                         close(sockfd);
168                 return -1;
169         }
170
171         /*
172          *      FIXME: If EINPROGRESS, then tell the caller that
173          *      somehow.  The caller can then call connect() when the
174          *      socket is ready...
175          */
176         if (connect(sockfd, (struct sockaddr *) &salocal, salen) < 0) {
177                 fr_strerror_printf("Failed in connect(): %s", strerror(errno));
178                 close(sockfd);
179                 return -1;
180         }
181
182         return sockfd;
183 }
184
185
186 RADIUS_PACKET *fr_tcp_recv(int sockfd, int flags)
187 {
188         RADIUS_PACKET *packet = rad_alloc(0);
189
190         if (!packet) return NULL;
191
192         packet->sockfd = sockfd;
193
194         if (fr_tcp_read_packet(packet, flags) != 1) {
195                 rad_free(&packet);
196                 return NULL;
197         }
198
199         return packet;
200 }
201
202
203 /*
204  *      Receives a packet, assuming that the RADIUS_PACKET structure
205  *      has been filled out already.
206  *
207  *      This ASSUMES that the packet is allocated && fields
208  *      initialized.
209  *
210  *      This ASSUMES that the socket is marked as O_NONBLOCK, which
211  *      the function above does set, if your system supports it.
212  *
213  *      Calling this function MAY change sockfd,
214  *      if src_ipaddr.af == AF_UNSPEC.
215  */
216 int fr_tcp_read_packet(RADIUS_PACKET *packet, int flags)
217 {
218         ssize_t len;
219
220         /*
221          *      No data allocated.  Read the 4-byte header into
222          *      a temporary buffer.
223          */
224         if (!packet->data) {
225                 int packet_len;
226
227                 len = recv(packet->sockfd, packet->vector + packet->data_len,
228                            4 - packet->data_len, 0);
229                 if (len == 0) return -2; /* clean close */
230
231 #ifdef ECONNRESET
232                 if ((len < 0) && (errno == ECONNRESET)) { /* forced */
233                         return -2;
234                 }
235 #endif
236
237                 if (len < 0) {
238                         fr_strerror_printf("Error receiving packet: %s",
239                                    strerror(errno));
240                         return -1;
241                 }
242
243                 packet->data_len += len;
244                 if (packet->data_len < 4) { /* want more data */
245                         return 0;
246                 }
247
248                 packet_len = (packet->vector[2] << 8) | packet->vector[3];
249
250                 if (packet_len < AUTH_HDR_LEN) {
251                         fr_strerror_printf("Discarding packet: Smaller than RFC minimum of 20 bytes.");
252                         return -1;
253                 }
254
255                 /*
256                  *      If the packet is too big, then the socket is bad.
257                  */
258                 if (packet_len > MAX_PACKET_LEN) {
259                         fr_strerror_printf("Discarding packet: Larger than RFC limitation of 4096 bytes.");
260                         return -1;
261                 }
262                 
263                 packet->data = malloc(packet_len);
264                 if (!packet->data) {
265                         fr_strerror_printf("Out of memory");
266                         return -1;
267                 }
268
269                 packet->data_len = packet_len;
270                 packet->partial = 4;
271                 memcpy(packet->data, packet->vector, 4);
272         }
273
274         /*
275          *      Try to read more data.
276          */
277         len = recv(packet->sockfd, packet->data + packet->partial,
278                    packet->data_len - packet->partial, 0);
279         if (len == 0) return -2; /* clean close */
280
281 #ifdef ECONNRESET
282         if ((len < 0) && (errno == ECONNRESET)) { /* forced */
283                 return -2;
284         }
285 #endif
286
287         if (len < 0) {
288                 fr_strerror_printf("Error receiving packet: %s", strerror(errno));
289                 return -1;
290         }
291
292         packet->partial += len;
293
294         if (packet->partial < packet->data_len) {
295                 return 0;
296         }
297
298         /*
299          *      See if it's a well-formed RADIUS packet.
300          */
301         if (!rad_packet_ok(packet, flags)) {
302                 return -1;
303         }
304
305         /*
306          *      Explicitly set the VP list to empty.
307          */
308         packet->vps = NULL;
309
310         if (fr_debug_flag) {
311                 char ip_buf[128], buffer[256];
312
313                 if (packet->src_ipaddr.af != AF_UNSPEC) {
314                         inet_ntop(packet->src_ipaddr.af,
315                                   &packet->src_ipaddr.ipaddr,
316                                   ip_buf, sizeof(ip_buf));
317                         snprintf(buffer, sizeof(buffer), "host %s port %d",
318                                  ip_buf, packet->src_port);
319                 } else {
320                         snprintf(buffer, sizeof(buffer), "socket %d",
321                                  packet->sockfd);
322                 }
323
324
325                 if ((packet->code > 0) && (packet->code < FR_MAX_PACKET_CODE)) {
326                         DEBUG("rad_recv: %s packet from %s",
327                               fr_packet_codes[packet->code], buffer);
328                 } else {
329                         DEBUG("rad_recv: Packet from %s code=%d",
330                               buffer, packet->code);
331                 }
332                 DEBUG(", id=%d, length=%zd\n", packet->id, packet->data_len);
333         }
334
335         return 1;               /* done reading the packet */
336 }
337
338 RADIUS_PACKET *fr_tcp_accept(int sockfd)
339 {
340         int newfd;
341         socklen_t salen;
342         RADIUS_PACKET *packet;
343         struct sockaddr_storage src;
344         
345         salen = sizeof(src);
346
347         newfd = accept(sockfd, (struct sockaddr *) &src, &salen);
348         if (newfd < 0) {
349                 /*
350                  *      Non-blocking sockets must handle this.
351                  */
352 #ifdef EWOULDBLOCK
353                 if (errno == EWOULDBLOCK) {
354                         packet = rad_alloc(0);
355                         if (!packet) return NULL;
356
357                         packet->sockfd = sockfd;
358                         packet->src_ipaddr.af = AF_UNSPEC;
359                         return packet;
360                 }
361 #endif
362
363                 return NULL;
364         }
365                 
366         packet = rad_alloc(0);
367         if (!packet) return NULL;
368
369         if (src.ss_family == AF_INET) {
370                 struct sockaddr_in      s4;
371
372                 memcpy(&s4, &src, sizeof(s4));
373                 packet->src_ipaddr.af = AF_INET;
374                 packet->src_ipaddr.ipaddr.ip4addr = s4.sin_addr;
375                 packet->src_port = ntohs(s4.sin_port);
376
377 #ifdef HAVE_STRUCT_SOCKADDR_IN6
378         } else if (src.ss_family == AF_INET6) {
379                 struct sockaddr_in6     s6;
380
381                 memcpy(&s6, &src, sizeof(s6));
382                 packet->src_ipaddr.af = AF_INET6;
383                 packet->src_ipaddr.ipaddr.ip6addr = s6.sin6_addr;
384                 packet->src_port = ntohs(s6.sin6_port);
385
386 #endif
387         } else {
388                 rad_free(&packet);
389                 return NULL;
390         }
391
392         packet->sockfd = newfd;
393
394         /*
395          *      Note: Caller has to set dst_ipaddr && dst_port.
396          */
397         return packet;
398 }
399
400
401 /*
402  *      Writes a packet, assuming it's already been encoded.
403  *
404  *      It returns the number of bytes written, which MAY be less than
405  *      the packet size (data_len).  It is the caller's responsibility
406  *      to check the return code, and to schedule writes again.
407  */
408 ssize_t fr_tcp_write_packet(RADIUS_PACKET *packet)
409 {
410         ssize_t rcode;
411
412         if (!packet || !packet->data) return 0;
413
414         if (packet->partial >= packet->data_len) return packet->data_len;
415
416         rcode = write(packet->sockfd, packet->data + packet->partial,
417                       packet->data_len - packet->partial);
418         if (rcode < 0) return packet->partial; /* ignore most errors */
419
420         packet->partial += rcode;
421
422         return packet->partial;
423 }
424 #endif /* WITH_TCP */