cleaning up code
[libradsec.git] / tcp.c
1 /*
2  * Copyright (C) 2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #include <signal.h>
10 #include <sys/socket.h>
11 #include <netinet/in.h>
12 #include <netdb.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <limits.h>
16 #ifdef SYS_SOLARIS9
17 #include <fcntl.h>
18 #endif
19 #include <sys/time.h>
20 #include <sys/types.h>
21 #include <sys/select.h>
22 #include <ctype.h>
23 #include <sys/wait.h>
24 #include <arpa/inet.h>
25 #include <regex.h>
26 #include <pthread.h>
27 #include <openssl/ssl.h>
28 #include "debug.h"
29 #include "list.h"
30 #include "util.h"
31 #include "radsecproxy.h"
32 #include "tcp.h"
33
34 static struct addrinfo *srcres = NULL;
35
36 void tcpsetsrcres(char *source) {
37     if (!srcres)
38         srcres = resolve_hostport_addrinfo(RAD_TCP, source);
39 }
40     
41 int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text) {
42     struct timeval now;
43     time_t elapsed;
44     
45     debug(DBG_DBG, "tcpconnect: called from %s", text);
46     pthread_mutex_lock(&server->lock);
47     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
48         /* already reconnected, nothing to do */
49         debug(DBG_DBG, "tcpconnect(%s): seems already reconnected", text);
50         pthread_mutex_unlock(&server->lock);
51         return 1;
52     }
53
54     for (;;) {
55         gettimeofday(&now, NULL);
56         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
57         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
58             debug(DBG_DBG, "tcpconnect: timeout");
59             if (server->sock >= 0)
60                 close(server->sock);
61             pthread_mutex_unlock(&server->lock);
62             return 0;
63         }
64         if (server->connectionok) {
65             server->connectionok = 0;
66             sleep(2);
67         } else if (elapsed < 1)
68             sleep(2);
69         else if (elapsed < 60) {
70             debug(DBG_INFO, "tcpconnect: sleeping %lds", elapsed);
71             sleep(elapsed);
72         } else if (elapsed < 100000) {
73             debug(DBG_INFO, "tcpconnect: sleeping %ds", 60);
74             sleep(60);
75         } else
76             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
77         debug(DBG_WARN, "tcpconnect: trying to open TCP connection to %s port %s", server->conf->host, server->conf->port);
78         if (server->sock >= 0)
79             close(server->sock);
80         if ((server->sock = connecttcp(server->conf->addrinfo, srcres)) >= 0)
81             break;
82         debug(DBG_ERR, "tcpconnect: connecttcp failed");
83     }
84     debug(DBG_WARN, "tcpconnect: TCP connection to %s port %s up", server->conf->host, server->conf->port);
85     server->connectionok = 1;
86     gettimeofday(&server->lastconnecttry, NULL);
87     pthread_mutex_unlock(&server->lock);
88     return 1;
89 }
90
91 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
92 /* returns 0 on timeout, -1 on error and num if ok */
93 int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
94     int ndesc, cnt, len;
95     fd_set readfds, writefds;
96     struct timeval timer;
97     
98     if (s < 0)
99         return -1;
100     /* make socket non-blocking? */
101     for (len = 0; len < num; len += cnt) {
102         FD_ZERO(&readfds);
103         FD_SET(s, &readfds);
104         writefds = readfds;
105         if (timeout) {
106             timer.tv_sec = timeout;
107             timer.tv_usec = 0;
108         }
109         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
110         if (ndesc < 1)
111             return ndesc;
112
113         cnt = read(s, buf + len, num - len);
114         if (cnt <= 0)
115             return -1;
116     }
117     return num;
118 }
119
120 /* timeout in seconds, 0 means no timeout (blocking) */
121 unsigned char *radtcpget(int s, int timeout) {
122     int cnt, len;
123     unsigned char buf[4], *rad;
124
125     for (;;) {
126         cnt = tcpreadtimeout(s, buf, 4, timeout);
127         if (cnt < 1) {
128             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
129             return NULL;
130         }
131
132         len = RADLEN(buf);
133         rad = malloc(len);
134         if (!rad) {
135             debug(DBG_ERR, "radtcpget: malloc failed");
136             continue;
137         }
138         memcpy(rad, buf, 4);
139         
140         cnt = tcpreadtimeout(s, rad + 4, len - 4, timeout);
141         if (cnt < 1) {
142             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
143             free(rad);
144             return NULL;
145         }
146         
147         if (len >= 20)
148             break;
149         
150         free(rad);
151         debug(DBG_WARN, "radtcpget: packet smaller than minimum radius size");
152     }
153     
154     debug(DBG_DBG, "radtcpget: got %d bytes", len);
155     return rad;
156 }
157
158 int clientradputtcp(struct server *server, unsigned char *rad) {
159     int cnt;
160     size_t len;
161     struct clsrvconf *conf = server->conf;
162
163     if (!server->connectionok)
164         return 0;
165     len = RADLEN(rad);
166     if ((cnt = write(server->sock, rad, len)) <= 0) {
167         debug(DBG_ERR, "clientradputtcp: write error");
168         return 0;
169     }
170     debug(DBG_DBG, "clientradputtcp: Sent %d bytes, Radius packet of length %d to TCP peer %s", cnt, len, conf->host);
171     return 1;
172 }
173
174 void *tcpclientrd(void *arg) {
175     struct server *server = (struct server *)arg;
176     unsigned char *buf;
177     struct timeval lastconnecttry;
178     
179     for (;;) {
180         /* yes, lastconnecttry is really necessary */
181         lastconnecttry = server->lastconnecttry;
182         buf = radtcpget(server->sock, 0);
183         if (!buf) {
184             tcpconnect(server, &lastconnecttry, 0, "tcpclientrd");
185             continue;
186         }
187
188         replyh(server, buf);
189     }
190     server->clientrdgone = 1;
191     return NULL;
192 }
193
194 void *tcpserverwr(void *arg) {
195     int cnt;
196     struct client *client = (struct client *)arg;
197     struct queue *replyq;
198     struct request *reply;
199     
200     debug(DBG_DBG, "tcpserverwr: starting for %s", addr2string(client->addr));
201     replyq = client->replyq;
202     for (;;) {
203         pthread_mutex_lock(&replyq->mutex);
204         while (!list_first(replyq->entries)) {
205             if (client->sock >= 0) {        
206                 debug(DBG_DBG, "tcpserverwr: waiting for signal");
207                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
208                 debug(DBG_DBG, "tcpserverwr: got signal");
209             }
210             if (client->sock < 0) {
211                 /* s might have changed while waiting */
212                 pthread_mutex_unlock(&replyq->mutex);
213                 debug(DBG_DBG, "tcpserverwr: exiting as requested");
214                 pthread_exit(NULL);
215             }
216         }
217         reply = (struct request *)list_shift(replyq->entries);
218         pthread_mutex_unlock(&replyq->mutex);
219         cnt = write(client->sock, reply->replybuf, RADLEN(reply->replybuf));
220         if (cnt > 0)
221             debug(DBG_DBG, "tcpserverwr: sent %d bytes, Radius packet of length %d to %s",
222                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
223         else
224             debug(DBG_ERR, "tcpserverwr: write error for %s", addr2string(client->addr));
225         freerq(reply);
226     }
227 }
228
229 void tcpserverrd(struct client *client) {
230     struct request *rq;
231     uint8_t *buf;
232     pthread_t tcpserverwrth;
233     
234     debug(DBG_DBG, "tcpserverrd: starting for %s", addr2string(client->addr));
235     
236     if (pthread_create(&tcpserverwrth, NULL, tcpserverwr, (void *)client)) {
237         debug(DBG_ERR, "tcpserverrd: pthread_create failed");
238         return;
239     }
240
241     for (;;) {
242         buf = radtcpget(client->sock, 0);
243         if (!buf) {
244             debug(DBG_ERR, "tcpserverrd: connection from %s lost", addr2string(client->addr));
245             break;
246         }
247         debug(DBG_DBG, "tcpserverrd: got Radius message from %s", addr2string(client->addr));
248         rq = newrequest();
249         if (!rq) {
250             free(buf);
251             continue;
252         }
253         rq->buf = buf;
254         rq->from = client;
255         if (!radsrv(rq)) {
256             debug(DBG_ERR, "tcpserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
257             break;
258         }
259     }
260
261     /* stop writer by setting s to -1 and give signal in case waiting for data */
262     client->sock = -1;
263     pthread_mutex_lock(&client->replyq->mutex);
264     pthread_cond_signal(&client->replyq->cond);
265     pthread_mutex_unlock(&client->replyq->mutex);
266     debug(DBG_DBG, "tcpserverrd: waiting for writer to end");
267     pthread_join(tcpserverwrth, NULL);
268     debug(DBG_DBG, "tcpserverrd: reader for %s exiting", addr2string(client->addr));
269 }
270 void *tcpservernew(void *arg) {
271     int s;
272     struct sockaddr_storage from;
273     socklen_t fromlen = sizeof(from);
274     struct clsrvconf *conf;
275     struct client *client;
276
277     s = *(int *)arg;
278     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
279         debug(DBG_DBG, "tcpservernew: getpeername failed, exiting");
280         goto exit;
281     }
282     debug(DBG_WARN, "tcpservernew: incoming TCP connection from %s", addr2string((struct sockaddr *)&from));
283
284     conf = find_clconf(RAD_TCP, (struct sockaddr *)&from, NULL);
285     if (conf) {
286         client = addclient(conf, 1);
287         if (client) {
288             client->sock = s;
289             client->addr = addr_copy((struct sockaddr *)&from);
290             tcpserverrd(client);
291             removeclient(client);
292         } else
293             debug(DBG_WARN, "tcpservernew: failed to create new client instance");
294     } else
295         debug(DBG_WARN, "tcpservernew: ignoring request, no matching TCP client");
296
297  exit:
298     shutdown(s, SHUT_RDWR);
299     close(s);
300     pthread_exit(NULL);
301 }
302
303 void *tcplistener(void *arg) {
304     pthread_t tcpserverth;
305     int s, *sp = (int *)arg;
306     struct sockaddr_storage from;
307     socklen_t fromlen = sizeof(from);
308
309     listen(*sp, 0);
310
311     for (;;) {
312         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
313         if (s < 0) {
314             debug(DBG_WARN, "accept failed");
315             continue;
316         }
317         if (pthread_create(&tcpserverth, NULL, tcpservernew, (void *)&s)) {
318             debug(DBG_ERR, "tcplistener: pthread_create failed");
319             shutdown(s, SHUT_RDWR);
320             close(s);
321             continue;
322         }
323         pthread_detach(tcpserverth);
324     }
325     free(sp);
326     return NULL;
327 }