Enable tls psk
[libradsec.git] / tcp.c
1 /* Copyright 2011-2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information. */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <assert.h>
9 #include <event2/event.h>
10 #include <event2/bufferevent.h>
11 #if defined (RS_ENABLE_TLS)
12 #include <event2/bufferevent_ssl.h>
13 #include <openssl/err.h>
14 #endif
15 #include <radius/client.h>
16 #include <radsec/radsec.h>
17 #include <radsec/radsec-impl.h>
18 #include "tcp.h"
19 #include "packet.h"
20 #include "conn.h"
21 #include "debug.h"
22 #include "event.h"
23
24 #if defined (DEBUG)
25 #include <event2/buffer.h>
26 #endif
27
28 /** Read one RADIUS packet header. Return !0 on error. */
29 static int
30 _read_header (struct rs_packet *pkt)
31 {
32   size_t n = 0;
33
34   n = bufferevent_read (pkt->conn->bev, pkt->hdr, RS_HEADER_LEN);
35   if (n == RS_HEADER_LEN)
36     {
37       pkt->flags |= RS_PACKET_HEADER_READ;
38       pkt->rpkt->length = (pkt->hdr[2] << 8) + pkt->hdr[3];
39       if (pkt->rpkt->length < 20 || pkt->rpkt->length > RS_MAX_PACKET_LEN)
40         {
41           rs_debug (("%s: invalid packet length: %d\n",
42                      __func__, pkt->rpkt->length));
43           rs_conn_disconnect (pkt->conn);
44           return rs_err_conn_push (pkt->conn, RSE_INVALID_PKT,
45                                    "invalid packet length: %d",
46                                    pkt->rpkt->length);
47         }
48       memcpy (pkt->rpkt->data, pkt->hdr, RS_HEADER_LEN);
49       bufferevent_setwatermark (pkt->conn->bev, EV_READ,
50                                 pkt->rpkt->length - RS_HEADER_LEN, 0);
51       rs_debug (("%s: packet header read, total pkt len=%d\n",
52                  __func__, pkt->rpkt->length));
53     }
54   else if (n < 0)
55     {
56       rs_debug (("%s: buffer frozen while reading header\n", __func__));
57     }
58   else      /* Error: libevent gave us less than the low watermark. */
59     {
60       rs_debug (("%s: got: %d octets reading header\n", __func__, n));
61       rs_conn_disconnect (pkt->conn);
62       return rs_err_conn_push_fl (pkt->conn, RSE_INTERNAL, __FILE__, __LINE__,
63                                   "got %d octets reading header", n);
64     }
65
66   return 0;
67 }
68
69 /** Read a message, check that it's valid RADIUS and hand it off to
70     registered user callback.
71
72     The packet is read from the bufferevent associated with \a pkt and
73     the data is stored in \a pkt->rpkt.
74
75     Return 0 on success and !0 on failure. */
76 static int
77 _read_packet (struct rs_packet *pkt)
78 {
79   size_t n = 0;
80   int err;
81
82   rs_debug (("%s: trying to read %d octets of packet data\n", __func__,
83              pkt->rpkt->length - RS_HEADER_LEN));
84
85   n = bufferevent_read (pkt->conn->bev,
86                         pkt->rpkt->data + RS_HEADER_LEN,
87                         pkt->rpkt->length - RS_HEADER_LEN);
88
89   rs_debug (("%s: read %ld octets of packet data\n", __func__, n));
90
91   if (n == pkt->rpkt->length - RS_HEADER_LEN)
92     {
93       bufferevent_disable (pkt->conn->bev, EV_READ);
94       rs_debug (("%s: complete packet read\n", __func__));
95       pkt->flags &= ~RS_PACKET_HEADER_READ;
96       memset (pkt->hdr, 0, sizeof(*pkt->hdr));
97
98       /* Checks done by rad_packet_ok:
99          - lenghts (FIXME: checks really ok for tcp?)
100          - invalid code field
101          - attribute lengths >= 2
102          - attribute sizes adding up correctly  */
103       err = nr_packet_ok (pkt->rpkt);
104       if (err != RSE_OK)
105         {
106           rs_debug (("%s: %d: invalid packet\n", __func__, -err));
107           rs_conn_disconnect (pkt->conn);
108           return rs_err_conn_push_fl (pkt->conn, -err, __FILE__, __LINE__,
109                                       "invalid packet");
110         }
111
112 #if defined (DEBUG)
113       /* Find out what happens if there's data left in the buffer.  */
114       {
115         size_t rest = 0;
116         rest = evbuffer_get_length (bufferevent_get_input (pkt->conn->bev));
117         if (rest)
118           rs_debug (("%s: returning with %d octets left in buffer\n", __func__,
119                      rest));
120       }
121 #endif
122
123       /* Hand over message to user.  This changes ownership of pkt.
124          Don't touch it afterwards -- it might have been freed.  */
125       if (pkt->conn->callbacks.received_cb)
126         pkt->conn->callbacks.received_cb (pkt, pkt->conn->user_data);
127     }
128   else if (n < 0)               /* Buffer frozen.  */
129     rs_debug (("%s: buffer frozen when reading packet\n", __func__));
130   else                          /* Short packet.  */
131     rs_debug (("%s: waiting for another %d octets\n", __func__,
132                pkt->rpkt->length - RS_HEADER_LEN - n));
133
134   return 0;
135 }
136
137 /* The read callback for TCP.
138
139    Read exactly one RADIUS message from BEV and store it in struct
140    rs_packet passed in USER_DATA.
141
142    Inform upper layer about successful reception of received RADIUS
143    message by invoking conn->callbacks.recevied_cb(), if !NULL.  */
144 void
145 tcp_read_cb (struct bufferevent *bev, void *user_data)
146 {
147   struct rs_packet *pkt = (struct rs_packet *) user_data;
148
149   assert (pkt);
150   assert (pkt->conn);
151   assert (pkt->rpkt);
152
153   pkt->rpkt->sockfd = pkt->conn->fd;
154   pkt->rpkt->vps = NULL;        /* FIXME: can this be done when initializing pkt? */
155
156   /* Read a message header if not already read, return if that
157      fails. Read a message and have it dispatched to the user
158      registered callback.
159
160      Room for improvement: Peek inside buffer (evbuffer_copyout()) to
161      avoid the extra copying. */
162   if ((pkt->flags & RS_PACKET_HEADER_READ) == 0)
163     if (_read_header (pkt))
164       return;                   /* Error.  */
165   _read_packet (pkt);
166 }
167
168 void
169 tcp_event_cb (struct bufferevent *bev, short events, void *user_data)
170 {
171   struct rs_packet *pkt = (struct rs_packet *) user_data;
172   struct rs_connection *conn = NULL;
173   int sockerr = 0;
174 #if defined (RS_ENABLE_TLS)
175   unsigned long tlserr = 0;
176 #endif
177 #if defined (DEBUG)
178   struct rs_peer *p = NULL;
179 #endif
180
181   assert (pkt);
182   assert (pkt->conn);
183   conn = pkt->conn;
184 #if defined (DEBUG)
185   assert (pkt->conn->active_peer);
186   p = conn->active_peer;
187 #endif
188
189   conn->is_connecting = 0;
190   if (events & BEV_EVENT_CONNECTED)
191     {
192       if (conn->tev)
193         evtimer_del (conn->tev); /* Cancel connect timer.  */
194       if (event_on_connect (conn, pkt))
195         {
196           event_on_disconnect (conn);
197           event_loopbreak (conn);
198         }
199     }
200   else if (events & BEV_EVENT_EOF)
201     {
202       event_on_disconnect (conn);
203     }
204   else if (events & BEV_EVENT_TIMEOUT)
205     {
206       rs_debug (("%s: %p times out on %s\n", __func__, p,
207                  (events & BEV_EVENT_READING) ? "read" : "write"));
208       rs_err_conn_push_fl (conn, RSE_TIMEOUT_IO, __FILE__, __LINE__, NULL);
209     }
210   else if (events & BEV_EVENT_ERROR)
211     {
212       sockerr = evutil_socket_geterror (conn->active_peer->fd);
213       if (sockerr == 0) /* FIXME: True that errno == 0 means closed? */
214         {
215           event_on_disconnect (conn);
216           rs_err_conn_push_fl (conn, RSE_DISCO, __FILE__, __LINE__, NULL);
217         }
218       else
219         {
220           rs_debug (("%s: %d: %d (%s)\n", __func__, conn->fd, sockerr,
221                      evutil_socket_error_to_string (sockerr)));
222           rs_err_conn_push_fl (conn, RSE_SOCKERR, __FILE__, __LINE__,
223                                "%d: %d (%s)", conn->fd, sockerr,
224                                evutil_socket_error_to_string (sockerr));
225         }
226 #if defined (RS_ENABLE_TLS)
227       if (conn->tls_ssl)        /* FIXME: correct check?  */
228         {
229           for (tlserr = bufferevent_get_openssl_error (conn->bev);
230                tlserr;
231                tlserr = bufferevent_get_openssl_error (conn->bev))
232             {
233               rs_debug (("%s: openssl error: %s\n", __func__,
234                          ERR_error_string (tlserr, NULL)));
235               rs_err_conn_push_fl (conn, RSE_SSLERR, __FILE__, __LINE__,
236                                    ERR_error_string (tlserr, NULL));
237             }
238         }
239 #endif  /* RS_ENABLE_TLS */
240       event_loopbreak (conn);
241     }
242
243 #if defined (DEBUG)
244   if (events & BEV_EVENT_ERROR && events != BEV_EVENT_ERROR)
245     rs_debug (("%s: BEV_EVENT_ERROR and more: 0x%x\n", __func__, events));
246 #endif
247 }
248
249 void
250 tcp_write_cb (struct bufferevent *bev, void *ctx)
251 {
252   struct rs_packet *pkt = (struct rs_packet *) ctx;
253
254   assert (pkt);
255   assert (pkt->conn);
256
257   if (pkt->conn->callbacks.sent_cb)
258     pkt->conn->callbacks.sent_cb (pkt->conn->user_data);
259 }
260
261 int
262 tcp_init_connect_timer (struct rs_connection *conn)
263 {
264   assert (conn);
265
266   if (conn->tev)
267     event_free (conn->tev);
268   conn->tev = evtimer_new (conn->evb, event_conn_timeout_cb, conn);
269   if (!conn->tev)
270     return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
271                                 "evtimer_new");
272
273   return RSE_OK;
274 }