allowing build with only specific transports
[libradsec.git] / tcp.c
1 /*
2  * Copyright (C) 2008 Stig Venaas <venaas@uninett.no>
3  *
4  * Permission to use, copy, modify, and distribute this software for any
5  * purpose with or without fee is hereby granted, provided that the above
6  * copyright notice and this permission notice appear in all copies.
7  */
8
9 #ifdef RADPROT_TCP
10 #include <signal.h>
11 #include <sys/socket.h>
12 #include <netinet/in.h>
13 #include <netdb.h>
14 #include <string.h>
15 #include <unistd.h>
16 #include <limits.h>
17 #ifdef SYS_SOLARIS9
18 #include <fcntl.h>
19 #endif
20 #include <sys/time.h>
21 #include <sys/types.h>
22 #include <sys/select.h>
23 #include <ctype.h>
24 #include <sys/wait.h>
25 #include <arpa/inet.h>
26 #include <regex.h>
27 #include <pthread.h>
28 #include <openssl/ssl.h>
29 #include "debug.h"
30 #include "list.h"
31 #include "util.h"
32 #include "radsecproxy.h"
33
34 static void setprotoopts(struct commonprotoopts *opts);
35 static char **getlistenerargs();
36 void *tcplistener(void *arg);
37 int tcpconnect(struct server *server, struct timeval *when, int timeout, char * text);
38 void *tcpclientrd(void *arg);
39 int clientradputtcp(struct server *server, unsigned char *rad);
40 void tcpsetsrcres();
41
42 static const struct protodefs protodefs = {
43     "tcp",
44     NULL, /* secretdefault */
45     SOCK_STREAM, /* socktype */
46     "1812", /* portdefault */
47     0, /* retrycountdefault */
48     0, /* retrycountmax */
49     REQUEST_RETRY_INTERVAL * REQUEST_RETRY_COUNT, /* retryintervaldefault */
50     60, /* retryintervalmax */
51     DUPLICATE_INTERVAL, /* duplicateintervaldefault */
52     setprotoopts, /* setprotoopts */
53     getlistenerargs, /* getlistenerargs */
54     tcplistener, /* listener */
55     tcpconnect, /* connecter */
56     tcpclientrd, /* clientconnreader */
57     clientradputtcp, /* clientradput */
58     NULL, /* addclient */
59     NULL, /* addserverextra */
60     tcpsetsrcres, /* setsrcres */
61     NULL /* initextra */
62 };
63
64 static struct addrinfo *srcres = NULL;
65 static uint8_t handle;
66 static struct commonprotoopts *protoopts = NULL;
67 const struct protodefs *tcpinit(uint8_t h) {
68     handle = h;
69     return &protodefs;
70 }
71
72 static void setprotoopts(struct commonprotoopts *opts) {
73     protoopts = opts;
74 }
75
76 static char **getlistenerargs() {
77     return protoopts ? protoopts->listenargs : NULL;
78 }
79
80 void tcpsetsrcres() {
81     if (!srcres)
82         srcres = resolve_hostport_addrinfo(handle, protoopts ? protoopts->sourcearg : NULL);
83 }
84     
85 int tcpconnect(struct server *server, struct timeval *when, int timeout, char *text) {
86     struct timeval now;
87     time_t elapsed;
88     
89     debug(DBG_DBG, "tcpconnect: called from %s", text);
90     pthread_mutex_lock(&server->lock);
91     if (when && memcmp(&server->lastconnecttry, when, sizeof(struct timeval))) {
92         /* already reconnected, nothing to do */
93         debug(DBG_DBG, "tcpconnect(%s): seems already reconnected", text);
94         pthread_mutex_unlock(&server->lock);
95         return 1;
96     }
97
98     for (;;) {
99         gettimeofday(&now, NULL);
100         elapsed = now.tv_sec - server->lastconnecttry.tv_sec;
101         if (timeout && server->lastconnecttry.tv_sec && elapsed > timeout) {
102             debug(DBG_DBG, "tcpconnect: timeout");
103             if (server->sock >= 0)
104                 close(server->sock);
105             pthread_mutex_unlock(&server->lock);
106             return 0;
107         }
108         if (server->connectionok) {
109             server->connectionok = 0;
110             sleep(2);
111         } else if (elapsed < 1)
112             sleep(2);
113         else if (elapsed < 60) {
114             debug(DBG_INFO, "tcpconnect: sleeping %lds", elapsed);
115             sleep(elapsed);
116         } else if (elapsed < 100000) {
117             debug(DBG_INFO, "tcpconnect: sleeping %ds", 60);
118             sleep(60);
119         } else
120             server->lastconnecttry.tv_sec = now.tv_sec;  /* no sleep at startup */
121         debug(DBG_WARN, "tcpconnect: trying to open TCP connection to %s port %s", server->conf->host, server->conf->port);
122         if (server->sock >= 0)
123             close(server->sock);
124         if ((server->sock = connecttcp(server->conf->addrinfo, srcres)) >= 0)
125             break;
126         debug(DBG_ERR, "tcpconnect: connecttcp failed");
127     }
128     debug(DBG_WARN, "tcpconnect: TCP connection to %s port %s up", server->conf->host, server->conf->port);
129     server->connectionok = 1;
130     gettimeofday(&server->lastconnecttry, NULL);
131     pthread_mutex_unlock(&server->lock);
132     return 1;
133 }
134
135 /* timeout in seconds, 0 means no timeout (blocking), returns when num bytes have been read, or timeout */
136 /* returns 0 on timeout, -1 on error and num if ok */
137 int tcpreadtimeout(int s, unsigned char *buf, int num, int timeout) {
138     int ndesc, cnt, len;
139     fd_set readfds, writefds;
140     struct timeval timer;
141     
142     if (s < 0)
143         return -1;
144     /* make socket non-blocking? */
145     for (len = 0; len < num; len += cnt) {
146         FD_ZERO(&readfds);
147         FD_SET(s, &readfds);
148         writefds = readfds;
149         if (timeout) {
150             timer.tv_sec = timeout;
151             timer.tv_usec = 0;
152         }
153         ndesc = select(s + 1, &readfds, &writefds, NULL, timeout ? &timer : NULL);
154         if (ndesc < 1)
155             return ndesc;
156
157         cnt = read(s, buf + len, num - len);
158         if (cnt <= 0)
159             return -1;
160     }
161     return num;
162 }
163
164 /* timeout in seconds, 0 means no timeout (blocking) */
165 unsigned char *radtcpget(int s, int timeout) {
166     int cnt, len;
167     unsigned char buf[4], *rad;
168
169     for (;;) {
170         cnt = tcpreadtimeout(s, buf, 4, timeout);
171         if (cnt < 1) {
172             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
173             return NULL;
174         }
175
176         len = RADLEN(buf);
177         rad = malloc(len);
178         if (!rad) {
179             debug(DBG_ERR, "radtcpget: malloc failed");
180             continue;
181         }
182         memcpy(rad, buf, 4);
183         
184         cnt = tcpreadtimeout(s, rad + 4, len - 4, timeout);
185         if (cnt < 1) {
186             debug(DBG_DBG, cnt ? "radtcpget: connection lost" : "radtcpget: timeout");
187             free(rad);
188             return NULL;
189         }
190         
191         if (len >= 20)
192             break;
193         
194         free(rad);
195         debug(DBG_WARN, "radtcpget: packet smaller than minimum radius size");
196     }
197     
198     debug(DBG_DBG, "radtcpget: got %d bytes", len);
199     return rad;
200 }
201
202 int clientradputtcp(struct server *server, unsigned char *rad) {
203     int cnt;
204     size_t len;
205     struct clsrvconf *conf = server->conf;
206
207     if (!server->connectionok)
208         return 0;
209     len = RADLEN(rad);
210     if ((cnt = write(server->sock, rad, len)) <= 0) {
211         debug(DBG_ERR, "clientradputtcp: write error");
212         return 0;
213     }
214     debug(DBG_DBG, "clientradputtcp: Sent %d bytes, Radius packet of length %d to TCP peer %s", cnt, len, conf->host);
215     return 1;
216 }
217
218 void *tcpclientrd(void *arg) {
219     struct server *server = (struct server *)arg;
220     unsigned char *buf;
221     struct timeval lastconnecttry;
222     
223     for (;;) {
224         /* yes, lastconnecttry is really necessary */
225         lastconnecttry = server->lastconnecttry;
226         buf = radtcpget(server->sock, 0);
227         if (!buf) {
228             tcpconnect(server, &lastconnecttry, 0, "tcpclientrd");
229             continue;
230         }
231
232         replyh(server, buf);
233     }
234     server->clientrdgone = 1;
235     return NULL;
236 }
237
238 void *tcpserverwr(void *arg) {
239     int cnt;
240     struct client *client = (struct client *)arg;
241     struct queue *replyq;
242     struct request *reply;
243     
244     debug(DBG_DBG, "tcpserverwr: starting for %s", addr2string(client->addr));
245     replyq = client->replyq;
246     for (;;) {
247         pthread_mutex_lock(&replyq->mutex);
248         while (!list_first(replyq->entries)) {
249             if (client->sock >= 0) {        
250                 debug(DBG_DBG, "tcpserverwr: waiting for signal");
251                 pthread_cond_wait(&replyq->cond, &replyq->mutex);
252                 debug(DBG_DBG, "tcpserverwr: got signal");
253             }
254             if (client->sock < 0) {
255                 /* s might have changed while waiting */
256                 pthread_mutex_unlock(&replyq->mutex);
257                 debug(DBG_DBG, "tcpserverwr: exiting as requested");
258                 pthread_exit(NULL);
259             }
260         }
261         reply = (struct request *)list_shift(replyq->entries);
262         pthread_mutex_unlock(&replyq->mutex);
263         cnt = write(client->sock, reply->replybuf, RADLEN(reply->replybuf));
264         if (cnt > 0)
265             debug(DBG_DBG, "tcpserverwr: sent %d bytes, Radius packet of length %d to %s",
266                   cnt, RADLEN(reply->replybuf), addr2string(client->addr));
267         else
268             debug(DBG_ERR, "tcpserverwr: write error for %s", addr2string(client->addr));
269         freerq(reply);
270     }
271 }
272
273 void tcpserverrd(struct client *client) {
274     struct request *rq;
275     uint8_t *buf;
276     pthread_t tcpserverwrth;
277     
278     debug(DBG_DBG, "tcpserverrd: starting for %s", addr2string(client->addr));
279     
280     if (pthread_create(&tcpserverwrth, NULL, tcpserverwr, (void *)client)) {
281         debug(DBG_ERR, "tcpserverrd: pthread_create failed");
282         return;
283     }
284
285     for (;;) {
286         buf = radtcpget(client->sock, 0);
287         if (!buf) {
288             debug(DBG_ERR, "tcpserverrd: connection from %s lost", addr2string(client->addr));
289             break;
290         }
291         debug(DBG_DBG, "tcpserverrd: got Radius message from %s", addr2string(client->addr));
292         rq = newrequest();
293         if (!rq) {
294             free(buf);
295             continue;
296         }
297         rq->buf = buf;
298         rq->from = client;
299         if (!radsrv(rq)) {
300             debug(DBG_ERR, "tcpserverrd: message authentication/validation failed, closing connection from %s", addr2string(client->addr));
301             break;
302         }
303     }
304
305     /* stop writer by setting s to -1 and give signal in case waiting for data */
306     client->sock = -1;
307     pthread_mutex_lock(&client->replyq->mutex);
308     pthread_cond_signal(&client->replyq->cond);
309     pthread_mutex_unlock(&client->replyq->mutex);
310     debug(DBG_DBG, "tcpserverrd: waiting for writer to end");
311     pthread_join(tcpserverwrth, NULL);
312     debug(DBG_DBG, "tcpserverrd: reader for %s exiting", addr2string(client->addr));
313 }
314 void *tcpservernew(void *arg) {
315     int s;
316     struct sockaddr_storage from;
317     socklen_t fromlen = sizeof(from);
318     struct clsrvconf *conf;
319     struct client *client;
320
321     s = *(int *)arg;
322     if (getpeername(s, (struct sockaddr *)&from, &fromlen)) {
323         debug(DBG_DBG, "tcpservernew: getpeername failed, exiting");
324         goto exit;
325     }
326     debug(DBG_WARN, "tcpservernew: incoming TCP connection from %s", addr2string((struct sockaddr *)&from));
327
328     conf = find_clconf(handle, (struct sockaddr *)&from, NULL);
329     if (conf) {
330         client = addclient(conf, 1);
331         if (client) {
332             client->sock = s;
333             client->addr = addr_copy((struct sockaddr *)&from);
334             tcpserverrd(client);
335             removeclient(client);
336         } else
337             debug(DBG_WARN, "tcpservernew: failed to create new client instance");
338     } else
339         debug(DBG_WARN, "tcpservernew: ignoring request, no matching TCP client");
340
341  exit:
342     shutdown(s, SHUT_RDWR);
343     close(s);
344     pthread_exit(NULL);
345 }
346
347 void *tcplistener(void *arg) {
348     pthread_t tcpserverth;
349     int s, *sp = (int *)arg;
350     struct sockaddr_storage from;
351     socklen_t fromlen = sizeof(from);
352
353     listen(*sp, 0);
354
355     for (;;) {
356         s = accept(*sp, (struct sockaddr *)&from, &fromlen);
357         if (s < 0) {
358             debug(DBG_WARN, "accept failed");
359             continue;
360         }
361         if (pthread_create(&tcpserverth, NULL, tcpservernew, (void *)&s)) {
362             debug(DBG_ERR, "tcplistener: pthread_create failed");
363             shutdown(s, SHUT_RDWR);
364             close(s);
365             continue;
366         }
367         pthread_detach(tcpserverth);
368     }
369     free(sp);
370     return NULL;
371 }
372 #else
373 const struct protodefs *tcpinit(uint8_t h) {
374     return NULL;
375 }
376 #endif