Alphabetise radsec.sym.
[libradsec.git] / lib / conn.c
1 /* Copyright 2010-2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information. */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <string.h>
9 #include <stdlib.h>
10 #include <errno.h>
11 #include <assert.h>
12 #include <event2/event.h>
13 #include <event2/bufferevent.h>
14 #include <radsec/radsec.h>
15 #include <radsec/radsec-impl.h>
16 #include "err.h"
17 #include "debug.h"
18 #include "conn.h"
19 #include "event.h"
20 #include "message.h"
21 #include "tcp.h"
22
23 int
24 conn_user_dispatch_p (const struct rs_connection *conn)
25 {
26   assert (conn);
27
28   return (conn->callbacks.connected_cb ||
29           conn->callbacks.disconnected_cb ||
30           conn->callbacks.received_cb ||
31           conn->callbacks.sent_cb);
32 }
33
34 int
35 conn_activate_timeout (struct rs_connection *conn)
36 {
37   const struct rs_conn_base *connbase;
38   assert (conn);
39   connbase = TO_BASE_CONN (conn);
40   assert (connbase->ctx);
41   assert (connbase->ctx->evb);
42   assert (conn->tev);
43   if (connbase->timeout.tv_sec || connbase->timeout.tv_usec)
44     {
45       rs_debug (("%s: activating timer: %d.%d\n", __func__,
46                  connbase->timeout.tv_sec, connbase->timeout.tv_usec));
47       if (evtimer_add (conn->tev, &connbase->timeout))
48         return rs_err_conn_push (conn, RSE_EVENT, "evtimer_add: %d", errno);
49     }
50   return RSE_OK;
51 }
52
53 int
54 conn_type_tls_p (const struct rs_connection *conn)
55 {
56   return TO_BASE_CONN(conn)->transport == RS_CONN_TYPE_TLS
57     || TO_BASE_CONN(conn)->transport == RS_CONN_TYPE_DTLS;
58 }
59
60 int
61 baseconn_type_datagram_p (const struct rs_conn_base *connbase)
62 {
63   return connbase->transport == RS_CONN_TYPE_UDP
64     || connbase->transport == RS_CONN_TYPE_DTLS;
65 }
66
67 int
68 baseconn_type_stream_p (const struct rs_conn_base *connbase)
69 {
70   return connbase->transport == RS_CONN_TYPE_TCP
71     || connbase->transport == RS_CONN_TYPE_TLS;
72 }
73
74 int
75 conn_cred_psk (const struct rs_connection *conn)
76 {
77   assert (conn);
78   return conn->active_peer != NULL
79     && conn->active_peer->transport_cred
80     && conn->active_peer->transport_cred->type == RS_CRED_TLS_PSK;
81 }
82
83 void
84 conn_init (struct rs_context *ctx, /* FIXME: rename connbase_init? */
85            struct rs_conn_base *connbase,
86            enum rs_conn_subtype type)
87 {
88   switch (type)
89     {
90     case RS_CONN_OBJTYPE_BASE:
91       connbase->magic = RS_CONN_MAGIC_BASE;
92       break;
93     case RS_CONN_OBJTYPE_GENERIC:
94       connbase->magic = RS_CONN_MAGIC_GENERIC;
95       break;
96     case RS_CONN_OBJTYPE_LISTENER:
97       connbase->magic = RS_CONN_MAGIC_LISTENER;
98       break;
99     default:
100       assert ("invalid connection subtype" == NULL);
101     }
102
103   connbase->ctx = ctx;
104   connbase->fd = -1;
105 }
106
107 int
108 conn_configure (struct rs_context *ctx, /* FIXME: rename conbbase_configure? */
109                 struct rs_conn_base *connbase,
110                 const char *config)
111 {
112   if (config)
113     {
114       struct rs_realm *r = rs_conf_find_realm (ctx, config);
115       if (r)
116         {
117           connbase->realm = r;
118           connbase->peers = r->peers; /* FIXME: Copy instead?  */
119 #if 0
120           for (p = connbase->peers; p != NULL; p = p->next)
121             p->connbase = connbase;
122 #endif
123           connbase->timeout.tv_sec = r->timeout;
124           connbase->tryagain = r->retries;
125         }
126     }
127 #if 0  /* incoming connections don't have a realm (a config object), update: they do, but "somebody else" is setting this up <-- FIXME */
128   if (connbase->realm == NULL)
129     {
130       struct rs_realm *r = rs_calloc (ctx, 1, sizeof (struct rs_realm));
131       if (r == NULL)
132         return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
133       r->next = ctx->realms;
134       ctx->realms = connbase->realm = r;
135     }
136 #else
137   if (connbase->realm)
138     connbase->transport = connbase->realm->type;
139 #endif
140
141   return RSE_OK;
142 }
143
144 int
145 conn_add_read_event (struct rs_connection *conn, void *user_data)
146 {
147   struct rs_conn_base *connbase = TO_BASE_CONN(conn);
148   int err;
149
150   assert(connbase);
151
152   if (connbase->bev)            /* TCP (including TLS).  */
153     {
154       bufferevent_setwatermark (connbase->bev, EV_READ, RS_HEADER_LEN, 0);
155       bufferevent_setcb (connbase->bev, tcp_read_cb, NULL, tcp_event_cb,
156                          user_data);
157       bufferevent_enable (connbase->bev, EV_READ);
158     }
159   else                          /* UDP.  */
160     {
161       /* Put fresh message in user_data for the callback and enable the
162          read event.  */
163       event_assign (connbase->rev, connbase->ctx->evb,
164                     event_get_fd (connbase->rev), EV_READ,
165                     event_get_callback (connbase->rev),
166                     user_data);
167       err = event_add (connbase->rev, NULL);
168       if (err < 0)
169         return rs_err_connbase_push_fl (connbase, RSE_EVENT, __FILE__, __LINE__,
170                                         "event_add: %s",
171                                         evutil_gai_strerror (err));
172
173       /* Activate retransmission timer.  */
174       conn_activate_timeout (conn);
175     }
176
177   return RSE_OK;
178 }
179
180 /** Return !=0 if \a conn is an originating connection, i.e. if its
181     peer is a server. */
182 int
183 conn_originating_p (const struct rs_connection *conn)
184 {
185   return conn->active_peer->type == RS_PEER_TYPE_SERVER;
186 }
187
188 int
189 baseconn_close (struct rs_conn_base *connbase)
190 {
191   int err = 0;
192   assert (connbase);
193
194   rs_debug (("%s: closing fd %d\n", __func__, connbase->fd));
195
196   err = evutil_closesocket (connbase->fd);
197   if (err)
198     err = rs_err_connbase_push (connbase, RSE_EVENT,
199                                 "evutil_closesocket: %d (%s)",
200                                 errno, strerror (errno));
201   connbase->fd = -1;
202   return err;
203 }
204
205 /* Public functions. */
206 int
207 rs_conn_create (struct rs_context *ctx,
208                 struct rs_connection **conn,
209                 const char *config)
210 {
211   int err = RSE_OK;
212   struct rs_connection *c = NULL;
213   assert (ctx);
214
215   c = rs_calloc (ctx, 1, sizeof (struct rs_connection));
216   if (c == NULL)
217     return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
218   conn_init (ctx, &c->base_, RS_CONN_OBJTYPE_GENERIC);
219   err = conn_configure (ctx, &c->base_, config);
220   if (err)
221     goto errout;
222
223   if (conn)
224     *conn = c;
225   return RSE_OK;
226
227  errout:
228   if (c)
229     rs_free (ctx, c);
230   return err;
231 }
232
233 void
234 rs_conn_set_type (struct rs_connection *conn, rs_conn_type_t type)
235 {
236   assert (conn);
237   assert (conn->base_.realm);
238   conn->base_.realm->type = type;
239 }
240
241 int
242 rs_conn_add_listener (struct rs_connection *conn,
243                       rs_conn_type_t type,
244                       const char *hostname,
245                       int port)
246 {
247   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
248 }
249
250
251 int
252 rs_conn_disconnect (struct rs_connection *conn)
253 {
254   int err = 0;
255
256   assert (conn);
257
258   if (conn->state == RS_CONN_STATE_CONNECTED)
259     event_on_disconnect (conn);
260
261   if (TO_BASE_CONN (conn)->bev)
262     {
263       bufferevent_free (TO_BASE_CONN (conn)->bev);
264       TO_BASE_CONN (conn)->bev = NULL;
265     }
266   if (TO_BASE_CONN (conn)->rev)
267     {
268       event_free (TO_BASE_CONN (conn)->rev);
269       TO_BASE_CONN (conn)->rev = NULL;
270     }
271   if (TO_BASE_CONN (conn)->wev)
272     {
273       event_free (TO_BASE_CONN (conn)->wev);
274       TO_BASE_CONN (conn)->wev = NULL;
275     }
276
277   err = evutil_closesocket (TO_BASE_CONN (conn)->fd);
278   TO_BASE_CONN (conn)->fd = -1;
279   return err;
280 }
281
282 int
283 rs_conn_destroy (struct rs_connection *conn)
284 {
285   int err = 0;
286
287   assert (conn);
288
289   /* NOTE: conn->realm is owned by context.  */
290   /* NOTE: conn->peers is owned by context.  */
291
292   if (conn->state == RS_CONN_STATE_CONNECTED)
293     err = rs_conn_disconnect (conn);
294
295 #if defined (RS_ENABLE_TLS)
296   if (conn->tls_ssl) /* FIXME: Free SSL strucxt in rs_conn_disconnect?  */
297     SSL_free (conn->tls_ssl);
298   if (conn->tls_ctx)
299     SSL_CTX_free (conn->tls_ctx);
300 #endif
301
302   if (conn->tev)
303     event_free (conn->tev);
304   if (conn->base_.bev)
305     bufferevent_free (conn->base_.bev);
306   if (conn->base_.rev)
307     event_free (conn->base_.rev);
308   if (conn->base_.wev)
309     event_free (conn->base_.wev);
310
311   rs_free (conn->base_.ctx, conn);
312
313   return err;
314 }
315
316 int
317 rs_conn_set_eventbase (struct rs_connection *conn, struct event_base *eb)
318 {
319   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
320 }
321
322 void
323 rs_conn_set_callbacks (struct rs_connection *conn,
324                        struct rs_conn_callbacks *cb,
325                        void *user_data)
326 {
327   assert (conn);
328   TO_BASE_CONN(conn)->user_data = user_data;
329   memcpy (&conn->callbacks, cb, sizeof (conn->callbacks));
330 }
331
332 void
333 rs_conn_del_callbacks (struct rs_connection *conn)
334 {
335   assert (conn);
336   memset (&conn->callbacks, 0, sizeof (conn->callbacks));
337 }
338
339 struct rs_conn_callbacks *
340 rs_conn_get_callbacks (struct rs_connection *conn)
341 {
342   assert (conn);
343   return &conn->callbacks;
344 }
345
346 int
347 rs_conn_select_peer (struct rs_connection *conn, const char *name)
348 {
349   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
350 }
351
352 int
353 rs_conn_get_current_peer (struct rs_connection *conn,
354                           const char *name,
355                           size_t buflen)
356 {
357   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
358 }
359
360 int
361 rs_conn_dispatch (struct rs_connection *conn)
362 {
363   assert (conn);
364   return event_base_loop (conn->base_.ctx->evb, EVLOOP_ONCE);
365 }
366
367 #if 0
368 struct event_base
369 *rs_conn_get_evb (const struct rs_connection *conn)
370 {
371   assert (conn);
372   return conn->evb;
373 }
374 #endif
375
376 int rs_conn_get_fd (struct rs_connection *conn)
377 {
378   assert (conn);
379   return conn->base_.fd;
380 }
381
382 static void
383 _rcb (struct rs_message *message, void *user_data)
384 {
385   struct rs_message *msg = (struct rs_message *) user_data;
386   assert (msg);
387   assert (msg->conn);
388
389   msg->flags |= RS_MESSAGE_RECEIVED;
390   if (msg->conn->base_.bev)     /* TCP -- disable bufferevent. */
391     bufferevent_disable (msg->conn->base_.bev, EV_WRITE|EV_READ);
392   else                          /* UDP -- remove read event. */
393     event_del (msg->conn->base_.rev);
394 }
395
396 int
397 rs_conn_receive_message (struct rs_connection *conn,
398                          struct rs_message *req_msg,
399                          struct rs_message **msg_out)
400 {
401   int err = 0;
402   struct rs_message *msg = NULL;
403
404   assert (conn);
405   assert (conn->base_.realm);
406   assert (!conn_user_dispatch_p (conn)); /* Blocking mode only.  */
407
408   if (rs_message_create (conn, &msg))
409     return -1;
410
411   assert (conn->base_.ctx->evb);
412   assert (conn->base_.fd >= 0);
413
414   conn->callbacks.received_cb = _rcb;
415   conn->base_.user_data = msg;
416   msg->flags &= ~RS_MESSAGE_RECEIVED;
417
418   err = conn_add_read_event (conn, msg);
419   if (err)
420     return err;
421
422   rs_debug (("%s: entering event loop\n", __func__));
423   err = event_base_dispatch (conn->base_.ctx->evb);
424   conn->callbacks.received_cb = NULL;
425   if (err < 0)
426     return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
427                                 "event_base_dispatch: %s",
428                                 evutil_gai_strerror (err));
429   rs_debug (("%s: event loop done\n", __func__));
430
431   if ((msg->flags & RS_MESSAGE_RECEIVED) == 0 /* No message. */
432       || (req_msg
433           && message_verify_response (conn, msg, req_msg) != RSE_OK))
434     {
435       if (rs_err_conn_peek_code (conn) == RSE_OK)
436         {
437           /* No message and no error on the stack _should_ mean that the
438              server hung up on us. */
439           rs_err_conn_push (conn, RSE_DISCO, "no response");
440         }
441       return rs_err_conn_peek_code (conn);
442     }
443
444   if (msg_out)
445     *msg_out = msg;
446   return RSE_OK;
447 }
448
449 void
450 rs_conn_set_timeout(struct rs_connection *conn, struct timeval *tv)
451 {
452   assert (conn);
453   assert (tv);
454   conn->base_.timeout = *tv;
455 }
456
457 struct rs_peer *
458 connbase_get_peers (const struct rs_conn_base *connbase)
459 {
460   assert (connbase);
461   return connbase->peers;
462 }