Alphabetise radsec.sym.
[libradsec.git] / lib / tcp.c
1 /* Copyright 2011-2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information. */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <assert.h>
9 #include <event2/event.h>
10 #include <event2/bufferevent.h>
11 #if defined (RS_ENABLE_TLS)
12 #include <event2/bufferevent_ssl.h>
13 #include <openssl/err.h>
14 #endif
15 #include <radius/client.h>
16 #include <radsec/radsec.h>
17 #include <radsec/radsec-impl.h>
18 #include "tcp.h"
19 #include "message.h"
20 #include "conn.h"
21 #include "debug.h"
22 #include "event.h"
23
24 #if defined (DEBUG)
25 #include <event2/buffer.h>
26 #endif
27
28 /** Read one RADIUS message header. Return !0 on error. */
29 static int
30 _read_header (struct rs_message *msg)
31 {
32   size_t n = 0;
33
34   n = bufferevent_read (TO_BASE_CONN(msg->conn)->bev, msg->hdr, RS_HEADER_LEN);
35   if (n == RS_HEADER_LEN)
36     {
37       msg->flags |= RS_MESSAGE_HEADER_READ;
38       msg->rpkt->length = (msg->hdr[2] << 8) + msg->hdr[3];
39       if (msg->rpkt->length < 20 || msg->rpkt->length > RS_MAX_PACKET_LEN)
40         {
41           rs_debug (("%s: invalid packet length: %d\n", __func__,
42                      msg->rpkt->length));
43           rs_conn_disconnect (msg->conn);
44           return  rs_err_conn_push (msg->conn, RSE_INVALID_MSG,
45                                     "invalid message length: %d",
46                                     msg->rpkt->length);
47         }
48       memcpy (msg->rpkt->data, msg->hdr, RS_HEADER_LEN);
49       bufferevent_setwatermark (TO_BASE_CONN(msg->conn)->bev, EV_READ,
50                                 msg->rpkt->length - RS_HEADER_LEN, 0);
51       rs_debug (("%s: message header read, total msg len=%d\n",
52                  __func__, msg->rpkt->length));
53     }
54   else if (n < 0)
55     rs_debug (("%s: buffer frozen while reading header\n", __func__));
56   else      /* Error: libevent gave us less than the low watermark. */
57     {
58       rs_debug (("%s: got: %d octets reading header\n", __func__, n));
59       rs_conn_disconnect (msg->conn);
60       return rs_err_conn_push (msg->conn, RSE_INTERNAL,
61                                "got %d octets reading header", n);
62     }
63
64   return RSE_OK;
65 }
66
67 /** Read a message, check that it's valid RADIUS and hand it off to
68     registered user callback.
69
70     The message is read from the bufferevent associated with \a msg and
71     the data is stored in \a msg->rpkt.
72
73     Return 0 on success and !0 on failure. */
74 static int
75 _read_message (struct rs_message *msg)
76 {
77   size_t n = 0;
78   int err;
79
80   rs_debug (("%s: trying to read %d octets of message data\n", __func__,
81              msg->rpkt->length - RS_HEADER_LEN));
82
83   n = bufferevent_read (msg->conn->base_.bev,
84                         msg->rpkt->data + RS_HEADER_LEN,
85                         msg->rpkt->length - RS_HEADER_LEN);
86
87   rs_debug (("%s: read %ld octets of message data\n", __func__, n));
88
89   if (n == msg->rpkt->length - RS_HEADER_LEN)
90     {
91       bufferevent_disable (msg->conn->base_.bev, EV_READ);
92       rs_debug (("%s: complete message read\n", __func__));
93       msg->flags &= ~RS_MESSAGE_HEADER_READ;
94       memset (msg->hdr, 0, sizeof(*msg->hdr));
95
96       /* Checks done by nr_packet_ok:
97          - lenghts (FIXME: checks really ok for tcp?)
98          - invalid code field
99          - attribute lengths >= 2
100          - attribute sizes adding up correctly  */
101       err = nr_packet_ok (msg->rpkt);
102       if (err)
103         {
104           rs_debug (("%s: %d: invalid packet\n", __func__, -err));
105           rs_conn_disconnect (msg->conn);
106           return rs_err_conn_push (msg->conn, -err, "invalid message");
107         }
108
109 #if defined (DEBUG)
110       /* Find out what happens if there's data left in the buffer.  */
111       {
112         size_t rest = 0;
113         rest =
114           evbuffer_get_length (bufferevent_get_input (msg->conn->base_.bev));
115         if (rest)
116           rs_debug (("%s: returning with %d octets left in buffer\n", __func__,
117                      rest));
118       }
119 #endif
120
121       /* Hand over message to user.  This changes ownership of msg.
122          Don't touch it afterwards -- it might have been freed.  */
123       if (msg->conn->callbacks.received_cb)
124         msg->conn->callbacks.received_cb (msg, msg->conn->base_.user_data);
125     }
126   else if (n < 0)               /* Buffer frozen.  */
127     rs_debug (("%s: buffer frozen when reading message\n", __func__));
128   else                          /* Short message.  */
129     rs_debug (("%s: waiting for another %d octets\n", __func__,
130                msg->rpkt->length - RS_HEADER_LEN - n));
131
132   return 0;
133 }
134
135 /* The read callback for TCP.
136
137    Read exactly one RADIUS message from \a bev and store it in the
138    struct rs_message passed in \a user_data.
139
140    Inform upper layer about successful reception of received RADIUS
141    message by invoking conn->callbacks.recevied_cb(), if not NULL. */
142 void
143 tcp_read_cb (struct bufferevent *bev, void *user_data)
144 {
145   struct rs_message *msg = (struct rs_message *) user_data;
146
147   assert (msg);
148   assert (msg->conn);
149   assert (msg->rpkt);
150
151   msg->rpkt->sockfd = msg->conn->base_.fd;
152   msg->rpkt->vps = NULL; /* FIXME: can this be done when initializing msg? */
153
154   /* Read a message header if not already read, return if that
155      fails. Read a message and have it dispatched to the user
156      registered callback.
157
158      Room for improvement: Peek inside buffer (evbuffer_copyout()) to
159      avoid the extra copying. */
160   if ((msg->flags & RS_MESSAGE_HEADER_READ) == 0)
161     if (_read_header (msg))
162       return;                   /* Invalid header. */
163   _read_message (msg);
164 }
165
166 void
167 tcp_event_cb (struct bufferevent *bev, short events, void *user_data)
168 {
169   struct rs_message *msg = (struct rs_message *) user_data;
170   struct rs_connection *conn = NULL;
171   int sockerr = 0;
172 #if defined (RS_ENABLE_TLS)
173   unsigned long tlserr = 0;
174 #endif
175 #if defined (DEBUG)
176   struct rs_peer *p = NULL;
177 #endif
178
179   assert (msg);
180   assert (msg->conn);
181   conn = msg->conn;
182 #if defined (DEBUG)
183   assert (conn->active_peer);
184   p = conn->active_peer;
185 #endif
186
187   if (events & BEV_EVENT_CONNECTED)
188     {
189       int err = -1;
190
191       if (conn_originating_p (conn)) /* We're a client. */
192         {
193           assert (conn->tev);
194           if (conn->tev)
195             evtimer_del (conn->tev); /* Cancel connect timer.  */
196           err = event_on_connect_orig (conn, msg);
197         }
198       else                      /* We're a server. */
199         {
200           assert (conn->tev == NULL);
201           err = event_on_connect_term (conn, msg);
202         }
203       if (err)
204         {
205           event_on_disconnect (conn);
206           event_loopbreak (conn);
207         }
208     }
209   else if (events & BEV_EVENT_EOF)
210     {
211       event_on_disconnect (conn);
212     }
213   else if (events & BEV_EVENT_TIMEOUT)
214     {
215       rs_debug (("%s: %p times out on %s\n", __func__, p,
216                  (events & BEV_EVENT_READING) ? "read" : "write"));
217       rs_err_conn_push_fl (conn, RSE_TIMEOUT_IO, __FILE__, __LINE__, NULL);
218     }
219   else if (events & BEV_EVENT_ERROR)
220     {
221       sockerr = evutil_socket_geterror (conn->active_peer->fd);
222       if (sockerr == 0) /* FIXME: True that errno == 0 means closed? */
223         {
224           event_on_disconnect (conn);
225           rs_err_conn_push_fl (conn, RSE_DISCO, __FILE__, __LINE__, NULL);
226         }
227       else
228         {
229           rs_debug (("%s: %d: %d (%s)\n", __func__, conn->base_.fd, sockerr,
230                      evutil_socket_error_to_string (sockerr)));
231           rs_err_conn_push_fl (conn, RSE_SOCKERR, __FILE__, __LINE__,
232                                "%d: %d (%s)", conn->base_.fd, sockerr,
233                                evutil_socket_error_to_string (sockerr));
234         }
235 #if defined (RS_ENABLE_TLS)
236       if (conn->tls_ssl)        /* FIXME: correct check?  */
237         {
238           for (tlserr = bufferevent_get_openssl_error (conn->base_.bev);
239                tlserr;
240                tlserr = bufferevent_get_openssl_error (conn->base_.bev))
241             {
242               rs_debug (("%s: openssl error: %s\n", __func__,
243                          ERR_error_string (tlserr, NULL)));
244               rs_err_conn_push_fl (conn, RSE_SSLERR, __FILE__, __LINE__,
245                                    ERR_error_string (tlserr, NULL));
246             }
247         }
248 #endif  /* RS_ENABLE_TLS */
249       event_loopbreak (conn);
250     }
251
252 #if defined (DEBUG)
253   if (events & BEV_EVENT_ERROR && events != BEV_EVENT_ERROR)
254     rs_debug (("%s: BEV_EVENT_ERROR and more: 0x%x\n", __func__, events));
255 #endif
256 }
257
258 void
259 tcp_write_cb (struct bufferevent *bev, void *ctx)
260 {
261   struct rs_message *msg = (struct rs_message *) ctx;
262
263   assert (msg);
264   assert (msg->conn);
265
266   if (msg->conn->callbacks.sent_cb)
267     msg->conn->callbacks.sent_cb (msg->conn->base_.user_data);
268 }
269
270 int
271 tcp_init_connect_timer (struct rs_connection *conn)
272 {
273   assert (conn);
274   assert (conn->base_.ctx);
275
276   if (conn->tev)
277     event_free (conn->tev);
278   conn->tev = evtimer_new (conn->base_.ctx->evb, event_conn_timeout_cb, conn);
279   if (!conn->tev)
280     return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
281                                 "evtimer_new");
282
283   return RSE_OK;
284 }