cleaning up code
[libradsec.git] / tcp.c
1 /*
2  * Copyright (C) 2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include "debug.h"
29 #include "list.h"
30 #include "util.h"
31 #include "radsecproxy.h"
32
33 void *tcplistener(void *arg);
34 int tcpconnect(struct server *server, struct timeval *when, int timeout, char * text);
35 void *tcpclientrd(void *arg);
36 int clientradputtcp(struct server *server, unsigned char *rad);
37 void tcpsetsrcres(char *source);
38
39 static const struct protodefs protodefs = {
40     "tcp",
41     NULL, /* secretdefault */
42     SOCK_STREAM, /* socktype */
43     "1812", /* portdefault */
44     0, /* retrycountdefault */
45     0, /* retrycountmax */
46     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
47     60, /* retryintervalmax */
48     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
49     tcplistener, /* listener */
50     tcpconnect, /* connecter */
51     tcpclientrd, /* clientconnreader */
52     clientradputtcp, /* clientradput */
53     NULL, /* addclient */
54     NULL, /* addserverextra */
55     tcpsetsrcres, /* setsrcres */
56     NULL /* initextra */
57 };
58
59 static struct addrinfo *srcres = NULL;
60 static uint8_t handle;
61
62 const struct protodefs *tcpinit(uint8_t h) {
63     handle = h;
64     return &protodefs;
65 }
66
67 void tcpsetsrcres(char *source) {
68     if (!srcres)
69         srcres = resolve_hostport_addrinfo(handle, source);
70 }
71     
72 int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text) {
73     struct timeval now;
74     time_t elapsed;
75     
76     debug(DBG_DBG, "tcpconnect: called from %s", text);
77     pthread_mutex_lock(&server->lock);
78     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
79         /* already reconnected, nothing to do */
80         debug(DBG_DBG, "tcpconnect(%s): seems already reconnected", text);
81         pthread_mutex_unlock(&server->lock);
82         return 1;
83     }
84
85     for (;;) {
86         gettimeofday(&now, NULL);
87         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
88         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
89             debug(DBG_DBG, "tcpconnect: timeout");
90             if (server->sock >= 0)
91                 close(server->sock);
92             pthread_mutex_unlock(&server->lock);
93             return 0;
94         }
95         if (server->connectionok) {
96             server->connectionok = 0;
97             sleep(2);
98         } else if (elapsed < 1)
99             sleep(2);
100         else if (elapsed < 60) {
101             debug(DBG_INFO, "tcpconnect: sleeping %lds", elapsed);
102             sleep(elapsed);
103         } else if (elapsed < 100000) {
104             debug(DBG_INFO, "tcpconnect: sleeping %ds", 60);
105             sleep(60);
106         } else
107             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
108         debug(DBG_WARN, "tcpconnect: trying to open TCP connection to %s port %s", server->conf->host, server->conf->port);
109         if (server->sock >= 0)
110             close(server->sock);
111         if ((server->sock = connecttcp(server->conf->addrinfo, srcres)) >= 0)
112             break;
113         debug(DBG_ERR, "tcpconnect: connecttcp failed");
114     }
115     debug(DBG_WARN, "tcpconnect: TCP connection to %s port %s up", server->conf->host, server->conf->port);
116     server->connectionok = 1;
117     gettimeofday(&server->lastconnecttry, NULL);
118     pthread_mutex_unlock(&server->lock);
119     return 1;
120 }
121
122 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
123 /* returns 0 on timeout, -1 on error and num if ok */
124 int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
125     int ndesc, cnt, len;
126     fd_set readfds, writefds;
127     struct timeval timer;
128     
129     if (s < 0)
130         return -1;
131     /* make socket non-blocking? */
132     for (len = 0; len < num; len += cnt) {
133         FD_ZERO(&readfds);
134         FD_SET(s, &readfds);
135         writefds = readfds;
136         if (timeout) {
137             timer.tv_sec = timeout;
138             timer.tv_usec = 0;
139         }
140         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
141         if (ndesc < 1)
142             return ndesc;
143
144         cnt = read(s, buf + len, num - len);
145         if (cnt <= 0)
146             return -1;
147     }
148     return num;
149 }
150
151 /* timeout in seconds, 0 means no timeout (blocking) */
152 unsigned char *radtcpget(int s, int timeout) {
153     int cnt, len;
154     unsigned char buf[4], *rad;
155
156     for (;;) {
157         cnt = tcpreadtimeout(s, buf, 4, timeout);
158         if (cnt < 1) {
159             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
160             return NULL;
161         }
162
163         len = RADLEN(buf);
164         rad = malloc(len);
165         if (!rad) {
166             debug(DBG_ERR, "radtcpget: malloc failed");
167             continue;
168         }
169         memcpy(rad, buf, 4);
170         
171         cnt = tcpreadtimeout(s, rad + 4, len - 4, timeout);
172         if (cnt < 1) {
173             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
174             free(rad);
175             return NULL;
176         }
177         
178         if (len >= 20)
179             break;
180         
181         free(rad);
182         debug(DBG_WARN, "radtcpget: packet smaller than minimum radius size");
183     }
184     
185     debug(DBG_DBG, "radtcpget: got %d bytes", len);
186     return rad;
187 }
188
189 int clientradputtcp(struct server *server, unsigned char *rad) {
190     int cnt;
191     size_t len;
192     struct clsrvconf *conf = server->conf;
193
194     if (!server->connectionok)
195         return 0;
196     len = RADLEN(rad);
197     if ((cnt = write(server->sock, rad, len)) <= 0) {
198         debug(DBG_ERR, "clientradputtcp: write error");
199         return 0;
200     }
201     debug(DBG_DBG, "clientradputtcp: Sent %d bytes, Radius packet of length %d to TCP peer %s", cnt, len, conf->host);
202     return 1;
203 }
204
205 void *tcpclientrd(void *arg) {
206     struct server *server = (struct server *)arg;
207     unsigned char *buf;
208     struct timeval lastconnecttry;
209     
210     for (;;) {
211         /* yes, lastconnecttry is really necessary */
212         lastconnecttry = server->lastconnecttry;
213         buf = radtcpget(server->sock, 0);
214         if (!buf) {
215             tcpconnect(server, &lastconnecttry, 0, "tcpclientrd");
216             continue;
217         }
218
219         replyh(server, buf);
220     }
221     server->clientrdgone = 1;
222     return NULL;
223 }
224
225 void *tcpserverwr(void *arg) {
226     int cnt;
227     struct client *client = (struct client *)arg;
228     struct queue *replyq;
229     struct request *reply;
230     
231     debug(DBG_DBG, "tcpserverwr: starting for %s", addr2string(client->addr));
232     replyq = client->replyq;
233     for (;;) {
234         pthread_mutex_lock(&replyq->mutex);
235         while (!list_first(replyq->entries)) {
236             if (client->sock >= 0) {        
237                 debug(DBG_DBG, "tcpserverwr: waiting for signal");
238                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
239                 debug(DBG_DBG, "tcpserverwr: got signal");
240             }
241             if (client->sock < 0) {
242                 /* s might have changed while waiting */
243                 pthread_mutex_unlock(&replyq->mutex);
244                 debug(DBG_DBG, "tcpserverwr: exiting as requested");
245                 pthread_exit(NULL);
246             }
247         }
248         reply = (struct request *)list_shift(replyq->entries);
249         pthread_mutex_unlock(&replyq->mutex);
250         cnt = write(client->sock, reply->replybuf, RADLEN(reply->replybuf));
251         if (cnt > 0)
252             debug(DBG_DBG, "tcpserverwr: sent %d bytes, Radius packet of length %d to %s",
253                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
254         else
255             debug(DBG_ERR, "tcpserverwr: write error for %s", addr2string(client->addr));
256         freerq(reply);
257     }
258 }
259
260 void tcpserverrd(struct client *client) {
261     struct request *rq;
262     uint8_t *buf;
263     pthread_t tcpserverwrth;
264     
265     debug(DBG_DBG, "tcpserverrd: starting for %s", addr2string(client->addr));
266     
267     if (pthread_create(&tcpserverwrth, NULL, tcpserverwr, (void *)client)) {
268         debug(DBG_ERR, "tcpserverrd: pthread_create failed");
269         return;
270     }
271
272     for (;;) {
273         buf = radtcpget(client->sock, 0);
274         if (!buf) {
275             debug(DBG_ERR, "tcpserverrd: connection from %s lost", addr2string(client->addr));
276             break;
277         }
278         debug(DBG_DBG, "tcpserverrd: got Radius message from %s", addr2string(client->addr));
279         rq = newrequest();
280         if (!rq) {
281             free(buf);
282             continue;
283         }
284         rq->buf = buf;
285         rq->from = client;
286         if (!radsrv(rq)) {
287             debug(DBG_ERR, "tcpserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
288             break;
289         }
290     }
291
292     /* stop writer by setting s to -1 and give signal in case waiting for data */
293     client->sock = -1;
294     pthread_mutex_lock(&client->replyq->mutex);
295     pthread_cond_signal(&client->replyq->cond);
296     pthread_mutex_unlock(&client->replyq->mutex);
297     debug(DBG_DBG, "tcpserverrd: waiting for writer to end");
298     pthread_join(tcpserverwrth, NULL);
299     debug(DBG_DBG, "tcpserverrd: reader for %s exiting", addr2string(client->addr));
300 }
301 void *tcpservernew(void *arg) {
302     int s;
303     struct sockaddr_storage from;
304     socklen_t fromlen = sizeof(from);
305     struct clsrvconf *conf;
306     struct client *client;
307
308     s = *(int *)arg;
309     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
310         debug(DBG_DBG, "tcpservernew: getpeername failed, exiting");
311         goto exit;
312     }
313     debug(DBG_WARN, "tcpservernew: incoming TCP connection from %s", addr2string((struct sockaddr *)&from));
314
315     conf = find_clconf(handle, (struct sockaddr *)&from, NULL);
316     if (conf) {
317         client = addclient(conf, 1);
318         if (client) {
319             client->sock = s;
320             client->addr = addr_copy((struct sockaddr *)&from);
321             tcpserverrd(client);
322             removeclient(client);
323         } else
324             debug(DBG_WARN, "tcpservernew: failed to create new client instance");
325     } else
326         debug(DBG_WARN, "tcpservernew: ignoring request, no matching TCP client");
327
328  exit:
329     shutdown(s, SHUT_RDWR);
330     close(s);
331     pthread_exit(NULL);
332 }
333
334 void *tcplistener(void *arg) {
335     pthread_t tcpserverth;
336     int s, *sp = (int *)arg;
337     struct sockaddr_storage from;
338     socklen_t fromlen = sizeof(from);
339
340     listen(*sp, 0);
341
342     for (;;) {
343         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
344         if (s < 0) {
345             debug(DBG_WARN, "accept failed");
346             continue;
347         }
348         if (pthread_create(&tcpserverth, NULL, tcpservernew, (void *)&s)) {
349             debug(DBG_ERR, "tcplistener: pthread_create failed");
350             shutdown(s, SHUT_RDWR);
351             close(s);
352             continue;
353         }
354         pthread_detach(tcpserverth);
355     }
356     free(sp);
357     return NULL;
358 }