Move recv() functions to talloc.
[freeradius.git] / src / lib / tcp.c
1 /*
2  * tcp.c        TCP-specific functions.
3  *
4  * Version:     $Id$
5  *
6  *   This program is free software; you can redistribute it and/or modify
7  *   it under the terms of the GNU General Public License as published by
8  *   the Free Software Foundation; either version 2 of the License, or
9  *   (at your option) any later version.
10  *
11  *   This program is distributed in the hope that it will be useful,
12  *   but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *   GNU General Public License for more details.
15  *
16  *   You should have received a copy of the GNU General Public License
17  *   along with this program; if not, write to the Free Software
18  *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
19  *
20  * Copyright (C) 2009 Dante http://dante.net
21  */
22
23 #include        <freeradius-devel/ident.h>
24 RCSID("$Id$")
25
26 #include        <freeradius-devel/libradius.h>
27 #include        <freeradius-devel/tcp.h>
28
29 #ifdef WITH_TCP
30
31 /* FIXME: into common RADIUS header? */
32 #define MAX_PACKET_LEN 4096
33
34 /*
35  *      Open a socket on the given IP and port.
36  */
37 int fr_tcp_socket(fr_ipaddr_t *ipaddr, int port)
38 {
39         int sockfd;
40         int on = 1;
41         struct sockaddr_storage salocal;
42         socklen_t       salen;
43
44         if ((port < 0) || (port > 65535)) {
45                 fr_strerror_printf("Port %d is out of allowed bounds", port);
46                 return -1;
47         }
48
49         sockfd = socket(ipaddr->af, SOCK_STREAM, 0);
50         if (sockfd < 0) {
51                 return sockfd;
52         }
53
54         if (fr_nonblock(sockfd) < 0) {
55                 close(sockfd);
56                 return -1;
57         }
58
59         if (!fr_ipaddr2sockaddr(ipaddr, port, &salocal, &salen)) {
60                 close(sockfd);
61                 return -1;
62         }
63
64 #ifdef HAVE_STRUCT_SOCKADDR_IN6
65         if (ipaddr->af == AF_INET6) {
66                 /*
67                  *      Listening on '::' does NOT get you IPv4 to
68                  *      IPv6 mapping.  You've got to listen on an IPv4
69                  *      address, too.  This makes the rest of the server
70                  *      design a little simpler.
71                  */
72 #ifdef IPV6_V6ONLY
73
74                 if (IN6_IS_ADDR_UNSPECIFIED(&ipaddr->ipaddr.ip6addr)) {
75                         if (setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY,
76                                        (char *)&on, sizeof(on)) < 0) {
77                                 fr_strerror_printf("Failed in setsockopt(): %s",
78                                                    strerror(errno));
79                                 close(sockfd);
80                                 return -1;
81                         }
82                 }
83 #endif /* IPV6_V6ONLY */
84         }
85 #endif /* HAVE_STRUCT_SOCKADDR_IN6 */
86
87         if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
88                 fr_strerror_printf("Failed in setsockopt(): %s", strerror(errno));
89                 close(sockfd);
90                 return -1;
91         }
92
93         if (bind(sockfd, (struct sockaddr *) &salocal, salen) < 0) {
94                 fr_strerror_printf("Failed in bind(): %s", strerror(errno));
95                 close(sockfd);
96                 return -1;
97         }
98
99         if (listen(sockfd, 8) < 0) {
100                 fr_strerror_printf("Failed in listen(): %s", strerror(errno));
101                 close(sockfd);
102                 return -1;
103         }
104
105         return sockfd;
106 }
107
108
109 /*
110  *      Open a socket TO the given IP and port.
111  */
112 int fr_tcp_client_socket(fr_ipaddr_t *src_ipaddr,
113                          fr_ipaddr_t *dst_ipaddr, int dst_port)
114 {
115         int sockfd;
116         struct sockaddr_storage salocal;
117         socklen_t       salen;
118
119         if ((dst_port < 0) || (dst_port > 65535)) {
120                 fr_strerror_printf("Port %d is out of allowed bounds",
121                                    dst_port);
122                 return -1;
123         }
124
125         if (!dst_ipaddr) return -1;
126
127         sockfd = socket(dst_ipaddr->af, SOCK_STREAM, 0);
128         if (sockfd < 0) {
129                 return sockfd;
130         }
131
132 #if 0
133 #ifdef O_NONBLOCK
134         {
135                 int flags;
136                 
137                 if ((flags = fcntl(sockfd, F_GETFL, NULL)) < 0)  {
138                         fr_strerror_printf("Failure getting socket flags: %s",
139                                    strerror(errno));
140                         close(sockfd);
141                         return -1;
142                 }
143                 
144                 flags |= O_NONBLOCK;
145                 if( fcntl(sockfd, F_SETFL, flags) < 0) {
146                         fr_strerror_printf("Failure setting socket flags: %s",
147                                    strerror(errno));
148                         close(sockfd);
149                         return -1;
150                 }
151         }
152 #endif
153 #endif
154         /*
155          *      Allow the caller to bind us to a specific source IP.
156          */
157         if (src_ipaddr && (src_ipaddr->af != AF_UNSPEC)) {
158                 if (!fr_ipaddr2sockaddr(src_ipaddr, 0, &salocal, &salen)) {
159                         close(sockfd);
160                         return -1;
161                 }
162                 
163                 if (bind(sockfd, (struct sockaddr *) &salocal, salen) < 0) {
164                         fr_strerror_printf("Failure binding to IP: %s",
165                                            strerror(errno));
166                         close(sockfd);
167                         return -1;
168                 }
169         }
170
171         if (!fr_ipaddr2sockaddr(dst_ipaddr, dst_port, &salocal, &salen)) {
172                         close(sockfd);
173                 return -1;
174         }
175
176         /*
177          *      FIXME: If EINPROGRESS, then tell the caller that
178          *      somehow.  The caller can then call connect() when the
179          *      socket is ready...
180          */
181         if (connect(sockfd, (struct sockaddr *) &salocal, salen) < 0) {
182                 fr_strerror_printf("Failed in connect(): %s", strerror(errno));
183                 close(sockfd);
184                 return -1;
185         }
186
187         return sockfd;
188 }
189
190
191 RADIUS_PACKET *fr_tcp_recv(int sockfd, int flags)
192 {
193         RADIUS_PACKET *packet = rad_alloc(NULL, 0);
194
195         if (!packet) return NULL;
196
197         packet->sockfd = sockfd;
198
199         if (fr_tcp_read_packet(packet, flags) != 1) {
200                 rad_free(&packet);
201                 return NULL;
202         }
203
204         return packet;
205 }
206
207
208 /*
209  *      Receives a packet, assuming that the RADIUS_PACKET structure
210  *      has been filled out already.
211  *
212  *      This ASSUMES that the packet is allocated && fields
213  *      initialized.
214  *
215  *      This ASSUMES that the socket is marked as O_NONBLOCK, which
216  *      the function above does set, if your system supports it.
217  *
218  *      Calling this function MAY change sockfd,
219  *      if src_ipaddr.af == AF_UNSPEC.
220  */
221 int fr_tcp_read_packet(RADIUS_PACKET *packet, int flags)
222 {
223         ssize_t len;
224
225         /*
226          *      No data allocated.  Read the 4-byte header into
227          *      a temporary buffer.
228          */
229         if (!packet->data) {
230                 int packet_len;
231
232                 len = recv(packet->sockfd, packet->vector + packet->data_len,
233                            4 - packet->data_len, 0);
234                 if (len == 0) return -2; /* clean close */
235
236 #ifdef ECONNRESET
237                 if ((len < 0) && (errno == ECONNRESET)) { /* forced */
238                         return -2;
239                 }
240 #endif
241
242                 if (len < 0) {
243                         fr_strerror_printf("Error receiving packet: %s",
244                                    strerror(errno));
245                         return -1;
246                 }
247
248                 packet->data_len += len;
249                 if (packet->data_len < 4) { /* want more data */
250                         return 0;
251                 }
252
253                 packet_len = (packet->vector[2] << 8) | packet->vector[3];
254
255                 if (packet_len < AUTH_HDR_LEN) {
256                         fr_strerror_printf("Discarding packet: Smaller than RFC minimum of 20 bytes.");
257                         return -1;
258                 }
259
260                 /*
261                  *      If the packet is too big, then the socket is bad.
262                  */
263                 if (packet_len > MAX_PACKET_LEN) {
264                         fr_strerror_printf("Discarding packet: Larger than RFC limitation of 4096 bytes.");
265                         return -1;
266                 }
267                 
268                 packet->data = talloc_array(packet, uint8_t, packet_len);
269                 if (!packet->data) {
270                         fr_strerror_printf("Out of memory");
271                         return -1;
272                 }
273
274                 packet->data_len = packet_len;
275                 packet->partial = 4;
276                 memcpy(packet->data, packet->vector, 4);
277         }
278
279         /*
280          *      Try to read more data.
281          */
282         len = recv(packet->sockfd, packet->data + packet->partial,
283                    packet->data_len - packet->partial, 0);
284         if (len == 0) return -2; /* clean close */
285
286 #ifdef ECONNRESET
287         if ((len < 0) && (errno == ECONNRESET)) { /* forced */
288                 return -2;
289         }
290 #endif
291
292         if (len < 0) {
293                 fr_strerror_printf("Error receiving packet: %s", strerror(errno));
294                 return -1;
295         }
296
297         packet->partial += len;
298
299         if (packet->partial < packet->data_len) {
300                 return 0;
301         }
302
303         /*
304          *      See if it's a well-formed RADIUS packet.
305          */
306         if (!rad_packet_ok(packet, flags)) {
307                 return -1;
308         }
309
310         /*
311          *      Explicitly set the VP list to empty.
312          */
313         packet->vps = NULL;
314
315         if (fr_debug_flag) {
316                 char ip_buf[128], buffer[256];
317
318                 if (packet->src_ipaddr.af != AF_UNSPEC) {
319                         inet_ntop(packet->src_ipaddr.af,
320                                   &packet->src_ipaddr.ipaddr,
321                                   ip_buf, sizeof(ip_buf));
322                         snprintf(buffer, sizeof(buffer), "host %s port %d",
323                                  ip_buf, packet->src_port);
324                 } else {
325                         snprintf(buffer, sizeof(buffer), "socket %d",
326                                  packet->sockfd);
327                 }
328
329
330                 if ((packet->code > 0) && (packet->code < FR_MAX_PACKET_CODE)) {
331                         DEBUG("rad_recv: %s packet from %s",
332                               fr_packet_codes[packet->code], buffer);
333                 } else {
334                         DEBUG("rad_recv: Packet from %s code=%d",
335                               buffer, packet->code);
336                 }
337                 DEBUG(", id=%d, length=%zu\n", packet->id, packet->data_len);
338         }
339
340         return 1;               /* done reading the packet */
341 }
342
343 RADIUS_PACKET *fr_tcp_accept(int sockfd)
344 {
345         int newfd;
346         socklen_t salen;
347         RADIUS_PACKET *packet;
348         struct sockaddr_storage src;
349         
350         salen = sizeof(src);
351
352         newfd = accept(sockfd, (struct sockaddr *) &src, &salen);
353         if (newfd < 0) {
354                 /*
355                  *      Non-blocking sockets must handle this.
356                  */
357 #ifdef EWOULDBLOCK
358                 if (errno == EWOULDBLOCK) {
359                         packet = rad_alloc(NULL, 0);
360                         if (!packet) return NULL;
361
362                         packet->sockfd = sockfd;
363                         packet->src_ipaddr.af = AF_UNSPEC;
364                         return packet;
365                 }
366 #endif
367
368                 return NULL;
369         }
370                 
371         packet = rad_alloc(NULL, 0);
372         if (!packet) {
373                 close(newfd);
374                 return NULL;
375         }
376
377         if (src.ss_family == AF_INET) {
378                 struct sockaddr_in      s4;
379
380                 memcpy(&s4, &src, sizeof(s4));
381                 packet->src_ipaddr.af = AF_INET;
382                 packet->src_ipaddr.ipaddr.ip4addr = s4.sin_addr;
383                 packet->src_port = ntohs(s4.sin_port);
384
385 #ifdef HAVE_STRUCT_SOCKADDR_IN6
386         } else if (src.ss_family == AF_INET6) {
387                 struct sockaddr_in6     s6;
388
389                 memcpy(&s6, &src, sizeof(s6));
390                 packet->src_ipaddr.af = AF_INET6;
391                 packet->src_ipaddr.ipaddr.ip6addr = s6.sin6_addr;
392                 packet->src_port = ntohs(s6.sin6_port);
393
394 #endif
395         } else {
396                 rad_free(&packet);
397                 close(newfd);
398                 return NULL;
399         }
400
401         packet->sockfd = newfd;
402
403         /*
404          *      Note: Caller has to set dst_ipaddr && dst_port.
405          */
406         return packet;
407 }
408
409
410 /*
411  *      Writes a packet, assuming it's already been encoded.
412  *
413  *      It returns the number of bytes written, which MAY be less than
414  *      the packet size (data_len).  It is the caller's responsibility
415  *      to check the return code, and to schedule writes again.
416  */
417 ssize_t fr_tcp_write_packet(RADIUS_PACKET *packet)
418 {
419         ssize_t rcode;
420
421         if (!packet || !packet->data) return 0;
422
423         if (packet->partial >= packet->data_len) return packet->data_len;
424
425         rcode = write(packet->sockfd, packet->data + packet->partial,
426                       packet->data_len - packet->partial);
427         if (rcode < 0) return packet->partial; /* ignore most errors */
428
429         packet->partial += rcode;
430
431         return packet->partial;
432 }
433 #endif /* WITH_TCP */