WIP commit moving towards working server support.
[radsecproxy.git] / lib / tcp.c
1 /* Copyright 2011,2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information.  */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <assert.h>
9 #include <event2/event.h>
10 #include <event2/bufferevent.h>
11 #if defined (RS_ENABLE_TLS)
12 #include <event2/bufferevent_ssl.h>
13 #include <openssl/err.h>
14 #endif
15 #include <radius/client.h>
16 #include <radsec/radsec.h>
17 #include <radsec/radsec-impl.h>
18 #include "tcp.h"
19 #include "message.h"
20 #include "conn.h"
21 #include "debug.h"
22 #include "event.h"
23
24 #if defined (DEBUG)
25 #include <event2/buffer.h>
26 #endif
27
28 /** Read one RADIUS message header. Return !0 on error. */
29 static int
30 _read_header (struct rs_message *msg)
31 {
32   size_t n = 0;
33
34   n = bufferevent_read (TO_BASE_CONN(msg->conn)->bev, msg->hdr, RS_HEADER_LEN);
35   if (n == RS_HEADER_LEN)
36     {
37       msg->flags |= RS_MESSAGE_HEADER_READ;
38       msg->rpkt->length = (msg->hdr[2] << 8) + msg->hdr[3];
39       if (msg->rpkt->length < 20 || msg->rpkt->length > RS_MAX_PACKET_LEN)
40         return  rs_err_conn_push (msg->conn, RSE_INVALID_MSG,
41                                   "invalid message length: %d",
42                                   msg->rpkt->length);
43       memcpy (msg->rpkt->data, msg->hdr, RS_HEADER_LEN);
44       bufferevent_setwatermark (TO_BASE_CONN(msg->conn)->bev, EV_READ,
45                                 msg->rpkt->length - RS_HEADER_LEN, 0);
46       rs_debug (("%s: message header read, total msg len=%d\n",
47                  __func__, msg->rpkt->length));
48     }
49   else if (n < 0)
50     rs_debug (("%s: buffer frozen while reading header\n", __func__));
51   else      /* Error: libevent gave us less than the low watermark. */
52     return rs_err_conn_push_fl (msg->conn, RSE_INTERNAL, __FILE__, __LINE__,
53                                 "got %d octets reading header", n);
54   return RSE_OK;
55 }
56
57 /** Read a message, check that it's valid RADIUS and hand it off to
58     registered user callback.
59
60     The message is read from the bufferevent associated with \a msg and
61     the data is stored in \a msg->rpkt.
62
63     Return 0 on success and !0 on failure. */
64 static int
65 _read_message (struct rs_message *msg)
66 {
67   size_t n = 0;
68   int err;
69
70   rs_debug (("%s: trying to read %d octets of message data\n", __func__,
71              msg->rpkt->length - RS_HEADER_LEN));
72
73   n = bufferevent_read (msg->conn->base_.bev,
74                         msg->rpkt->data + RS_HEADER_LEN,
75                         msg->rpkt->length - RS_HEADER_LEN);
76
77   rs_debug (("%s: read %ld octets of message data\n", __func__, n));
78
79   if (n == msg->rpkt->length - RS_HEADER_LEN)
80     {
81       bufferevent_disable (msg->conn->base_.bev, EV_READ);
82       rs_debug (("%s: complete message read\n", __func__));
83       msg->flags &= ~RS_MESSAGE_HEADER_READ;
84       memset (msg->hdr, 0, sizeof(*msg->hdr));
85
86       /* Checks done by nr_packet_ok:
87          - lenghts (FIXME: checks really ok for tcp?)
88          - invalid code field
89          - attribute lengths >= 2
90          - attribute sizes adding up correctly  */
91       err = nr_packet_ok (msg->rpkt);
92       if (err)
93         return rs_err_conn_push_fl (msg->conn, err, __FILE__, __LINE__,
94                                     "invalid message");
95
96 #if defined (DEBUG)
97       /* Find out what happens if there's data left in the buffer.  */
98       {
99         size_t rest = 0;
100         rest =
101           evbuffer_get_length (bufferevent_get_input (msg->conn->base_.bev));
102         if (rest)
103           rs_debug (("%s: returning with %d octets left in buffer\n", __func__,
104                      rest));
105       }
106 #endif
107
108       /* Hand over message to user.  This changes ownership of msg.
109          Don't touch it afterwards -- it might have been freed.  */
110       if (msg->conn->callbacks.received_cb)
111         msg->conn->callbacks.received_cb (msg, msg->conn->base_.user_data);
112     }
113   else if (n < 0)               /* Buffer frozen.  */
114     rs_debug (("%s: buffer frozen when reading message\n", __func__));
115   else                          /* Short message.  */
116     rs_debug (("%s: waiting for another %d octets\n", __func__,
117                msg->rpkt->length - RS_HEADER_LEN - n));
118
119   return 0;
120 }
121
122 /* The read callback for TCP.
123
124    Read exactly one RADIUS message from \a bev and store it in the
125    struct rs_message passed in \a user_data.
126
127    Inform upper layer about successful reception of received RADIUS
128    message by invoking conn->callbacks.recevied_cb(), if not NULL. */
129 void
130 tcp_read_cb (struct bufferevent *bev, void *user_data)
131 {
132   struct rs_message *msg = (struct rs_message *) user_data;
133
134   assert (msg);
135   assert (msg->conn);
136   assert (msg->rpkt);
137
138   msg->rpkt->sockfd = msg->conn->base_.fd;
139   msg->rpkt->vps = NULL; /* FIXME: can this be done when initializing msg? */
140
141   /* Read a message header if not already read, return if that
142      fails. Read a message and have it dispatched to the user
143      registered callback.
144
145      Room for improvement: Peek inside buffer (evbuffer_copyout()) to
146      avoid the extra copying. */
147   if ((msg->flags & RS_MESSAGE_HEADER_READ) == 0)
148     if (_read_header (msg))
149       return;                   /* Invalid header. */
150   if (_read_message (msg))
151     return;                     /* Invalid message. */
152 }
153
154 void
155 tcp_event_cb (struct bufferevent *bev, short events, void *user_data)
156 {
157   struct rs_message *msg = (struct rs_message *) user_data;
158   struct rs_connection *conn = NULL;
159   int sockerr = 0;
160 #if defined (RS_ENABLE_TLS)
161   unsigned long tlserr = 0;
162 #endif
163 #if defined (DEBUG)
164   struct rs_peer *p = NULL;
165 #endif
166
167   assert (msg);
168   assert (msg->conn);
169   conn = msg->conn;
170 #if defined (DEBUG)
171   assert (conn->active_peer);
172   p = conn->active_peer;
173 #endif
174
175   if (events & BEV_EVENT_CONNECTED)
176     {
177       int err = -1;
178
179       if (conn_originating_p (conn)) /* We're a client. */
180         {
181           assert (conn->tev);
182           if (conn->tev)
183             evtimer_del (conn->tev); /* Cancel connect timer.  */
184           err = event_on_connect_orig (conn, msg);
185         }
186       else                      /* We're a server. */
187         {
188           assert (conn->tev == NULL);
189           err = event_on_connect_term (conn, msg);
190         }
191       if (err)
192         {
193           event_on_disconnect (conn);
194           event_loopbreak (conn);
195         }
196     }
197   else if (events & BEV_EVENT_EOF)
198     {
199       event_on_disconnect (conn);
200     }
201   else if (events & BEV_EVENT_TIMEOUT)
202     {
203       rs_debug (("%s: %p times out on %s\n", __func__, p,
204                  (events & BEV_EVENT_READING) ? "read" : "write"));
205       rs_err_conn_push_fl (conn, RSE_TIMEOUT_IO, __FILE__, __LINE__, NULL);
206     }
207   else if (events & BEV_EVENT_ERROR)
208     {
209       sockerr = evutil_socket_geterror (conn->active_peer->fd);
210       if (sockerr == 0) /* FIXME: True that errno == 0 means closed? */
211         {
212           event_on_disconnect (conn);
213           rs_err_conn_push_fl (conn, RSE_DISCO, __FILE__, __LINE__, NULL);
214         }
215       else
216         {
217           rs_debug (("%s: %d: %d (%s)\n", __func__, conn->base_.fd, sockerr,
218                      evutil_socket_error_to_string (sockerr)));
219           rs_err_conn_push_fl (conn, RSE_SOCKERR, __FILE__, __LINE__,
220                                "%d: %d (%s)", conn->base_.fd, sockerr,
221                                evutil_socket_error_to_string (sockerr));
222         }
223 #if defined (RS_ENABLE_TLS)
224       if (conn->tls_ssl)        /* FIXME: correct check?  */
225         {
226           for (tlserr = bufferevent_get_openssl_error (conn->base_.bev);
227                tlserr;
228                tlserr = bufferevent_get_openssl_error (conn->base_.bev))
229             {
230               rs_debug (("%s: openssl error: %s\n", __func__,
231                          ERR_error_string (tlserr, NULL)));
232               rs_err_conn_push_fl (conn, RSE_SSLERR, __FILE__, __LINE__,
233                                    ERR_error_string (tlserr, NULL));
234             }
235         }
236 #endif  /* RS_ENABLE_TLS */
237       event_loopbreak (conn);
238     }
239
240 #if defined (DEBUG)
241   if (events & BEV_EVENT_ERROR && events != BEV_EVENT_ERROR)
242     rs_debug (("%s: BEV_EVENT_ERROR and more: 0x%x\n", __func__, events));
243 #endif
244 }
245
246 void
247 tcp_write_cb (struct bufferevent *bev, void *ctx)
248 {
249   struct rs_message *msg = (struct rs_message *) ctx;
250
251   assert (msg);
252   assert (msg->conn);
253
254   if (msg->conn->callbacks.sent_cb)
255     msg->conn->callbacks.sent_cb (msg->conn->base_.user_data);
256 }
257
258 int
259 tcp_init_connect_timer (struct rs_connection *conn)
260 {
261   assert (conn);
262   assert (conn->base_.ctx);
263
264   if (conn->tev)
265     event_free (conn->tev);
266   conn->tev = evtimer_new (conn->base_.ctx->evb, event_conn_timeout_cb, conn);
267   if (!conn->tev)
268     return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
269                                 "evtimer_new");
270
271   return RSE_OK;
272 }