Update copyright and licensing information.
[radsecproxy.git] / tcp.c
1 /* Copyright (c) 2006-2009, Stig Venaas, UNINETT AS.
2  * Copyright (c) 2010, UNINETT AS, NORDUnet A/S.
3  * Copyright (c) 2010-2012, NORDUnet A/S. */
4 /* See LICENSE for licensing information. */
5
6 #include <signal.h>
7 #include <sys/socket.h>
8 #include <netinet/in.h>
9 #include <netdb.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <limits.h>
13 #ifdef SYS_SOLARIS9
14 #include <fcntl.h>
15 #endif
16 #include <sys/time.h>
17 #include <sys/types.h>
18 #include <sys/select.h>
19 #include <ctype.h>
20 #include <sys/wait.h>
21 #include <arpa/inet.h>
22 #include <regex.h>
23 #include <pthread.h>
24 #include "radsecproxy.h"
25 #include "hostport.h"
26
27 #ifdef RADPROT_TCP
28 #include "debug.h"
29 #include "util.h"
30 static void setprotoopts(struct commonprotoopts *opts);
31 static char **getlistenerargs();
32 void *tcplistener(void *arg);
33 int tcpconnect(struct server *server, struct timeval *when, int timeout, char * text);
34 void *tcpclientrd(void *arg);
35 int clientradputtcp(struct server *server, unsigned char *rad);
36 void tcpsetsrcres();
37
38 static const struct protodefs protodefs = {
39     "tcp",
40     NULL, /* secretdefault */
41     SOCK_STREAM, /* socktype */
42     "1812", /* portdefault */
43     0, /* retrycountdefault */
44     0, /* retrycountmax */
45     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
46     60, /* retryintervalmax */
47     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
48     setprotoopts, /* setprotoopts */
49     getlistenerargs, /* getlistenerargs */
50     tcplistener, /* listener */
51     tcpconnect, /* connecter */
52     tcpclientrd, /* clientconnreader */
53     clientradputtcp, /* clientradput */
54     NULL, /* addclient */
55     NULL, /* addserverextra */
56     tcpsetsrcres, /* setsrcres */
57     NULL /* initextra */
58 };
59
60 static struct addrinfo *srcres = NULL;
61 static uint8_t handle;
62 static struct commonprotoopts *protoopts = NULL;
63 const struct protodefs *tcpinit(uint8_t h) {
64     handle = h;
65     return &protodefs;
66 }
67
68 static void setprotoopts(struct commonprotoopts *opts) {
69     protoopts = opts;
70 }
71
72 static char **getlistenerargs() {
73     return protoopts ? protoopts->listenargs : NULL;
74 }
75
76 void tcpsetsrcres() {
77     if (!srcres)
78         srcres =
79             resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL,
80                                    AF_UNSPEC, NULL, protodefs.socktype);
81 }
82
83 int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text) {
84     struct timeval now;
85     time_t elapsed;
86
87     debug(DBG_DBG, "tcpconnect: called from %s", text);
88     pthread_mutex_lock(&server->lock);
89     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
90         /* already reconnected, nothing to do */
91         debug(DBG_DBG, "tcpconnect(%s): seems already reconnected", text);
92         pthread_mutex_unlock(&server->lock);
93         return 1;
94     }
95
96     for (;;) {
97         gettimeofday(&now, NULL);
98         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
99         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
100             debug(DBG_DBG, "tcpconnect: timeout");
101             if (server->sock >= 0)
102                 close(server->sock);
103             pthread_mutex_unlock(&server->lock);
104             return 0;
105         }
106         if (server->connectionok) {
107             server->connectionok = 0;
108             sleep(2);
109         } else if (elapsed < 1)
110             sleep(2);
111         else if (elapsed < 60) {
112             debug(DBG_INFO, "tcpconnect: sleeping %lds", elapsed);
113             sleep(elapsed);
114         } else if (elapsed < 100000) {
115             debug(DBG_INFO, "tcpconnect: sleeping %ds", 60);
116             sleep(60);
117         } else
118             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
119
120         if (server->sock >= 0)
121             close(server->sock);
122         if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) >= 0)
123             break;
124     }
125     server->connectionok = 1;
126     gettimeofday(&server->lastconnecttry, NULL);
127     pthread_mutex_unlock(&server->lock);
128     return 1;
129 }
130
131 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
132 /* returns 0 on timeout, -1 on error and num if ok */
133 int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
134     int ndesc, cnt, len;
135     fd_set readfds, writefds;
136     struct timeval timer;
137
138     if (s < 0)
139         return -1;
140     /* make socket non-blocking? */
141     for (len = 0; len < num; len += cnt) {
142         FD_ZERO(&readfds);
143         FD_SET(s, &readfds);
144         writefds = readfds;
145         if (timeout) {
146             timer.tv_sec = timeout;
147             timer.tv_usec = 0;
148         }
149         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
150         if (ndesc < 1)
151             return ndesc;
152
153         cnt = read(s, buf + len, num - len);
154         if (cnt <= 0)
155             return -1;
156     }
157     return num;
158 }
159
160 /* timeout in seconds, 0 means no timeout (blocking) */
161 unsigned char *radtcpget(int s, int timeout) {
162     int cnt, len;
163     unsigned char buf[4], *rad;
164
165     for (;;) {
166         cnt = tcpreadtimeout(s, buf, 4, timeout);
167         if (cnt < 1) {
168             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
169             return NULL;
170         }
171
172         len = RADLEN(buf);
173         rad = malloc(len);
174         if (!rad) {
175             debug(DBG_ERR, "radtcpget: malloc failed");
176             continue;
177         }
178         memcpy(rad, buf, 4);
179
180         cnt = tcpreadtimeout(s, rad + 4, len - 4, timeout);
181         if (cnt < 1) {
182             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
183             free(rad);
184             return NULL;
185         }
186
187         if (len >= 20)
188             break;
189
190         free(rad);
191         debug(DBG_WARN, "radtcpget: packet smaller than minimum radius size");
192     }
193
194     debug(DBG_DBG, "radtcpget: got %d bytes", len);
195     return rad;
196 }
197
198 int clientradputtcp(struct server *server, unsigned char *rad) {
199     int cnt;
200     size_t len;
201     struct clsrvconf *conf = server->conf;
202
203     if (!server->connectionok)
204         return 0;
205     len = RADLEN(rad);
206     if ((cnt = write(server->sock, rad, len)) <= 0) {
207         debug(DBG_ERR, "clientradputtcp: write error");
208         return 0;
209     }
210     debug(DBG_DBG, "clientradputtcp: Sent %d bytes, Radius packet of length %d to TCP peer %s", cnt, len, conf->name);
211     return 1;
212 }
213
214 void *tcpclientrd(void *arg) {
215     struct server *server = (struct server *)arg;
216     unsigned char *buf;
217     struct timeval lastconnecttry;
218
219     for (;;) {
220         /* yes, lastconnecttry is really necessary */
221         lastconnecttry = server->lastconnecttry;
222         buf = radtcpget(server->sock, 0);
223         if (!buf) {
224             tcpconnect(server, &lastconnecttry, 0, "tcpclientrd");
225             continue;
226         }
227
228         replyh(server, buf);
229     }
230     server->clientrdgone = 1;
231     return NULL;
232 }
233
234 void *tcpserverwr(void *arg) {
235     int cnt;
236     struct client *client = (struct client *)arg;
237     struct gqueue *replyq;
238     struct request *reply;
239
240     debug(DBG_DBG, "tcpserverwr: starting for %s", addr2string(client->addr));
241     replyq = client->replyq;
242     for (;;) {
243         pthread_mutex_lock(&replyq->mutex);
244         while (!list_first(replyq->entries)) {
245             if (client->sock >= 0) {
246                 debug(DBG_DBG, "tcpserverwr: waiting for signal");
247                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
248                 debug(DBG_DBG, "tcpserverwr: got signal");
249             }
250             if (client->sock < 0) {
251                 /* s might have changed while waiting */
252                 pthread_mutex_unlock(&replyq->mutex);
253                 debug(DBG_DBG, "tcpserverwr: exiting as requested");
254                 pthread_exit(NULL);
255             }
256         }
257         reply = (struct request *)list_shift(replyq->entries);
258         pthread_mutex_unlock(&replyq->mutex);
259         cnt = write(client->sock, reply->replybuf, RADLEN(reply->replybuf));
260         if (cnt > 0)
261             debug(DBG_DBG, "tcpserverwr: sent %d bytes, Radius packet of length %d to %s",
262                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
263         else
264             debug(DBG_ERR, "tcpserverwr: write error for %s", addr2string(client->addr));
265         freerq(reply);
266     }
267 }
268
269 void tcpserverrd(struct client *client) {
270     struct request *rq;
271     uint8_t *buf;
272     pthread_t tcpserverwrth;
273
274     debug(DBG_DBG, "tcpserverrd: starting for %s", addr2string(client->addr));
275
276     if (pthread_create(&tcpserverwrth, NULL, tcpserverwr, (void *)client)) {
277         debug(DBG_ERR, "tcpserverrd: pthread_create failed");
278         return;
279     }
280
281     for (;;) {
282         buf = radtcpget(client->sock, 0);
283         if (!buf) {
284             debug(DBG_ERR, "tcpserverrd: connection from %s lost", addr2string(client->addr));
285             break;
286         }
287         debug(DBG_DBG, "tcpserverrd: got Radius message from %s", addr2string(client->addr));
288         rq = newrequest();
289         if (!rq) {
290             free(buf);
291             continue;
292         }
293         rq->buf = buf;
294         rq->from = client;
295         if (!radsrv(rq)) {
296             debug(DBG_ERR, "tcpserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
297             break;
298         }
299     }
300
301     /* stop writer by setting s to -1 and give signal in case waiting for data */
302     client->sock = -1;
303     pthread_mutex_lock(&client->replyq->mutex);
304     pthread_cond_signal(&client->replyq->cond);
305     pthread_mutex_unlock(&client->replyq->mutex);
306     debug(DBG_DBG, "tcpserverrd: waiting for writer to end");
307     pthread_join(tcpserverwrth, NULL);
308     debug(DBG_DBG, "tcpserverrd: reader for %s exiting", addr2string(client->addr));
309 }
310 void *tcpservernew(void *arg) {
311     int s;
312     struct sockaddr_storage from;
313     socklen_t fromlen = sizeof(from);
314     struct clsrvconf *conf;
315     struct client *client;
316
317     s = *(int *)arg;
318     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
319         debug(DBG_DBG, "tcpservernew: getpeername failed, exiting");
320         goto exit;
321     }
322     debug(DBG_WARN, "tcpservernew: incoming TCP connection from %s", addr2string((struct sockaddr *)&from));
323
324     conf = find_clconf(handle, (struct sockaddr *)&from, NULL);
325     if (conf) {
326         client = addclient(conf, 1);
327         if (client) {
328             client->sock = s;
329             client->addr = addr_copy((struct sockaddr *)&from);
330             tcpserverrd(client);
331             removeclient(client);
332         } else
333             debug(DBG_WARN, "tcpservernew: failed to create new client instance");
334     } else
335         debug(DBG_WARN, "tcpservernew: ignoring request, no matching TCP client");
336
337 exit:
338     shutdown(s, SHUT_RDWR);
339     close(s);
340     pthread_exit(NULL);
341 }
342
343 void *tcplistener(void *arg) {
344     pthread_t tcpserverth;
345     int s, *sp = (int *)arg;
346     struct sockaddr_storage from;
347     socklen_t fromlen = sizeof(from);
348
349     listen(*sp, 0);
350
351     for (;;) {
352         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
353         if (s < 0) {
354             debug(DBG_WARN, "accept failed");
355             continue;
356         }
357         if (pthread_create(&tcpserverth, NULL, tcpservernew, (void *)&s)) {
358             debug(DBG_ERR, "tcplistener: pthread_create failed");
359             shutdown(s, SHUT_RDWR);
360             close(s);
361             continue;
362         }
363         pthread_detach(tcpserverth);
364     }
365     free(sp);
366     return NULL;
367 }
368 #else
369 const struct protodefs *tcpinit(uint8_t h) {
370     return NULL;
371 }
372 #endif
373
374 /* Local Variables: */
375 /* c-file-style: "stroustrup" */
376 /* End: */