Handle failing rs_context_create().
[libradsec.git] / tcp.c
1 /*
2  * Copyright (C) 2008-2009 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include "list.h"
28 #include "hostport.h"
29 #include "radsecproxy.h"
30
31 #ifdef RADPROT_TCP
32 #include "debug.h"
33 #include "util.h"
34 static void setprotoopts(struct commonprotoopts *opts);
35 static char **getlistenerargs();
36 void *tcplistener(void *arg);
37 int tcpconnect(struct server *server, struct timeval *when, int timeout, char * text);
38 void *tcpclientrd(void *arg);
39 int clientradputtcp(struct server *server, unsigned char *rad);
40 void tcpsetsrcres();
41
42 static const struct protodefs protodefs = {
43     "tcp",
44     NULL, /* secretdefault */
45     SOCK_STREAM, /* socktype */
46     "1812", /* portdefault */
47     0, /* retrycountdefault */
48     0, /* retrycountmax */
49     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
50     60, /* retryintervalmax */
51     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
52     setprotoopts, /* setprotoopts */
53     getlistenerargs, /* getlistenerargs */
54     tcplistener, /* listener */
55     tcpconnect, /* connecter */
56     tcpclientrd, /* clientconnreader */
57     clientradputtcp, /* clientradput */
58     NULL, /* addclient */
59     NULL, /* addserverextra */
60     tcpsetsrcres, /* setsrcres */
61     NULL /* initextra */
62 };
63
64 static struct addrinfo *srcres = NULL;
65 static uint8_t handle;
66 static struct commonprotoopts *protoopts = NULL;
67 const struct protodefs *tcpinit(uint8_t h) {
68     handle = h;
69     return &protodefs;
70 }
71
72 static void setprotoopts(struct commonprotoopts *opts) {
73     protoopts = opts;
74 }
75
76 static char **getlistenerargs() {
77     return protoopts ? protoopts->listenargs : NULL;
78 }
79
80 void tcpsetsrcres() {
81     if (!srcres)
82         srcres = resolvepassiveaddrinfo(protoopts ? protoopts->sourcearg : NULL, NULL, protodefs.socktype);
83 }
84
85 int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text) {
86     struct timeval now;
87     time_t elapsed;
88
89     debug(DBG_DBG, "tcpconnect: called from %s", text);
90     pthread_mutex_lock(&server->lock);
91     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
92         /* already reconnected, nothing to do */
93         debug(DBG_DBG, "tcpconnect(%s): seems already reconnected", text);
94         pthread_mutex_unlock(&server->lock);
95         return 1;
96     }
97
98     for (;;) {
99         gettimeofday(&now, NULL);
100         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
101         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
102             debug(DBG_DBG, "tcpconnect: timeout");
103             if (server->sock >= 0)
104                 close(server->sock);
105             pthread_mutex_unlock(&server->lock);
106             return 0;
107         }
108         if (server->connectionok) {
109             server->connectionok = 0;
110             sleep(2);
111         } else if (elapsed < 1)
112             sleep(2);
113         else if (elapsed < 60) {
114             debug(DBG_INFO, "tcpconnect: sleeping %lds", elapsed);
115             sleep(elapsed);
116         } else if (elapsed < 100000) {
117             debug(DBG_INFO, "tcpconnect: sleeping %ds", 60);
118             sleep(60);
119         } else
120             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
121
122         if (server->sock >= 0)
123             close(server->sock);
124         if ((server->sock = connecttcphostlist(server->conf->hostports, srcres)) >= 0)
125             break;
126     }
127     server->connectionok = 1;
128     gettimeofday(&server->lastconnecttry, NULL);
129     pthread_mutex_unlock(&server->lock);
130     return 1;
131 }
132
133 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
134 /* returns 0 on timeout, -1 on error and num if ok */
135 int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
136     int ndesc, cnt, len;
137     fd_set readfds, writefds;
138     struct timeval timer;
139
140     if (s < 0)
141         return -1;
142     /* make socket non-blocking? */
143     for (len = 0; len < num; len += cnt) {
144         FD_ZERO(&readfds);
145         FD_SET(s, &readfds);
146         writefds = readfds;
147         if (timeout) {
148             timer.tv_sec = timeout;
149             timer.tv_usec = 0;
150         }
151         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
152         if (ndesc < 1)
153             return ndesc;
154
155         cnt = read(s, buf + len, num - len);
156         if (cnt <= 0)
157             return -1;
158     }
159     return num;
160 }
161
162 /* timeout in seconds, 0 means no timeout (blocking) */
163 unsigned char *radtcpget(int s, int timeout) {
164     int cnt, len;
165     unsigned char buf[4], *rad;
166
167     for (;;) {
168         cnt = tcpreadtimeout(s, buf, 4, timeout);
169         if (cnt < 1) {
170             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
171             return NULL;
172         }
173
174         len = RADLEN(buf);
175         rad = malloc(len);
176         if (!rad) {
177             debug(DBG_ERR, "radtcpget: malloc failed");
178             continue;
179         }
180         memcpy(rad, buf, 4);
181
182         cnt = tcpreadtimeout(s, rad + 4, len - 4, timeout);
183         if (cnt < 1) {
184             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
185             free(rad);
186             return NULL;
187         }
188
189         if (len >= 20)
190             break;
191
192         free(rad);
193         debug(DBG_WARN, "radtcpget: packet smaller than minimum radius size");
194     }
195
196     debug(DBG_DBG, "radtcpget: got %d bytes", len);
197     return rad;
198 }
199
200 int clientradputtcp(struct server *server, unsigned char *rad) {
201     int cnt;
202     size_t len;
203     struct clsrvconf *conf = server->conf;
204
205     if (!server->connectionok)
206         return 0;
207     len = RADLEN(rad);
208     if ((cnt = write(server->sock, rad, len)) <= 0) {
209         debug(DBG_ERR, "clientradputtcp: write error");
210         return 0;
211     }
212     debug(DBG_DBG, "clientradputtcp: Sent %d bytes, Radius packet of length %d to TCP peer %s", cnt, len, conf->name);
213     return 1;
214 }
215
216 void *tcpclientrd(void *arg) {
217     struct server *server = (struct server *)arg;
218     unsigned char *buf;
219     struct timeval lastconnecttry;
220
221     for (;;) {
222         /* yes, lastconnecttry is really necessary */
223         lastconnecttry = server->lastconnecttry;
224         buf = radtcpget(server->sock, 0);
225         if (!buf) {
226             tcpconnect(server, &lastconnecttry, 0, "tcpclientrd");
227             continue;
228         }
229
230         replyh(server, buf);
231     }
232     server->clientrdgone = 1;
233     return NULL;
234 }
235
236 void *tcpserverwr(void *arg) {
237     int cnt;
238     struct client *client = (struct client *)arg;
239     struct gqueue *replyq;
240     struct request *reply;
241
242     debug(DBG_DBG, "tcpserverwr: starting for %s", addr2string(client->addr));
243     replyq = client->replyq;
244     for (;;) {
245         pthread_mutex_lock(&replyq->mutex);
246         while (!list_first(replyq->entries)) {
247             if (client->sock >= 0) {
248                 debug(DBG_DBG, "tcpserverwr: waiting for signal");
249                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
250                 debug(DBG_DBG, "tcpserverwr: got signal");
251             }
252             if (client->sock < 0) {
253                 /* s might have changed while waiting */
254                 pthread_mutex_unlock(&replyq->mutex);
255                 debug(DBG_DBG, "tcpserverwr: exiting as requested");
256                 pthread_exit(NULL);
257             }
258         }
259         reply = (struct request *)list_shift(replyq->entries);
260         pthread_mutex_unlock(&replyq->mutex);
261         cnt = write(client->sock, reply->replybuf, RADLEN(reply->replybuf));
262         if (cnt > 0)
263             debug(DBG_DBG, "tcpserverwr: sent %d bytes, Radius packet of length %d to %s",
264                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
265         else
266             debug(DBG_ERR, "tcpserverwr: write error for %s", addr2string(client->addr));
267         freerq(reply);
268     }
269 }
270
271 void tcpserverrd(struct client *client) {
272     struct request *rq;
273     uint8_t *buf;
274     pthread_t tcpserverwrth;
275
276     debug(DBG_DBG, "tcpserverrd: starting for %s", addr2string(client->addr));
277
278     if (pthread_create(&tcpserverwrth, NULL, tcpserverwr, (void *)client)) {
279         debug(DBG_ERR, "tcpserverrd: pthread_create failed");
280         return;
281     }
282
283     for (;;) {
284         buf = radtcpget(client->sock, 0);
285         if (!buf) {
286             debug(DBG_ERR, "tcpserverrd: connection from %s lost", addr2string(client->addr));
287             break;
288         }
289         debug(DBG_DBG, "tcpserverrd: got Radius message from %s", addr2string(client->addr));
290         rq = newrequest();
291         if (!rq) {
292             free(buf);
293             continue;
294         }
295         rq->buf = buf;
296         rq->from = client;
297         if (!radsrv(rq)) {
298             debug(DBG_ERR, "tcpserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
299             break;
300         }
301     }
302
303     /* stop writer by setting s to -1 and give signal in case waiting for data */
304     client->sock = -1;
305     pthread_mutex_lock(&client->replyq->mutex);
306     pthread_cond_signal(&client->replyq->cond);
307     pthread_mutex_unlock(&client->replyq->mutex);
308     debug(DBG_DBG, "tcpserverrd: waiting for writer to end");
309     pthread_join(tcpserverwrth, NULL);
310     debug(DBG_DBG, "tcpserverrd: reader for %s exiting", addr2string(client->addr));
311 }
312 void *tcpservernew(void *arg) {
313     int s;
314     struct sockaddr_storage from;
315     socklen_t fromlen = sizeof(from);
316     struct clsrvconf *conf;
317     struct client *client;
318
319     s = *(int *)arg;
320     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
321         debug(DBG_DBG, "tcpservernew: getpeername failed, exiting");
322         goto exit;
323     }
324     debug(DBG_WARN, "tcpservernew: incoming TCP connection from %s", addr2string((struct sockaddr *)&from));
325
326     conf = find_clconf(handle, (struct sockaddr *)&from, NULL);
327     if (conf) {
328         client = addclient(conf, 1);
329         if (client) {
330             client->sock = s;
331             client->addr = addr_copy((struct sockaddr *)&from);
332             tcpserverrd(client);
333             removeclient(client);
334         } else
335             debug(DBG_WARN, "tcpservernew: failed to create new client instance");
336     } else
337         debug(DBG_WARN, "tcpservernew: ignoring request, no matching TCP client");
338
339 exit:
340     shutdown(s, SHUT_RDWR);
341     close(s);
342     pthread_exit(NULL);
343 }
344
345 void *tcplistener(void *arg) {
346     pthread_t tcpserverth;
347     int s, *sp = (int *)arg;
348     struct sockaddr_storage from;
349     socklen_t fromlen = sizeof(from);
350
351     listen(*sp, 0);
352
353     for (;;) {
354         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
355         if (s < 0) {
356             debug(DBG_WARN, "accept failed");
357             continue;
358         }
359         if (pthread_create(&tcpserverth, NULL, tcpservernew, (void *)&s)) {
360             debug(DBG_ERR, "tcplistener: pthread_create failed");
361             shutdown(s, SHUT_RDWR);
362             close(s);
363             continue;
364         }
365         pthread_detach(tcpserverth);
366     }
367     free(sp);
368     return NULL;
369 }
370 #else
371 const struct protodefs *tcpinit(uint8_t h) {
372     return NULL;
373 }
374 #endif
375
376 /* Local Variables: */
377 /* c-file-style: "stroustrup" */
378 /* End: */