WIP commit moving towards working server support.
[radsecproxy.git] / lib / conn.c
1 /* Copyright 2010,2011,2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information. */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <string.h>
9 #include <stdlib.h>
10 #include <errno.h>
11 #include <assert.h>
12 #include <event2/event.h>
13 #include <event2/bufferevent.h>
14 #include <radsec/radsec.h>
15 #include <radsec/radsec-impl.h>
16 #include "err.h"
17 #include "debug.h"
18 #include "conn.h"
19 #include "event.h"
20 #include "message.h"
21 #include "tcp.h"
22
23 int
24 conn_user_dispatch_p (const struct rs_connection *conn)
25 {
26   assert (conn);
27
28   return (conn->callbacks.connected_cb ||
29           conn->callbacks.disconnected_cb ||
30           conn->callbacks.received_cb ||
31           conn->callbacks.sent_cb);
32 }
33
34 int
35 conn_activate_timeout (struct rs_connection *conn)
36 {
37   assert (conn);
38   assert (conn->tev);
39   assert (conn->base_.ctx->evb);
40   if (conn->base_.timeout.tv_sec || conn->base_.timeout.tv_usec)
41     {
42       rs_debug (("%s: activating timer: %d.%d\n", __func__,
43                  conn->base_.timeout.tv_sec, conn->base_.timeout.tv_usec));
44       if (evtimer_add (conn->tev, &conn->base_.timeout))
45         return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
46                                     "evtimer_add: %d", errno);
47     }
48   return RSE_OK;
49 }
50
51 int
52 conn_type_tls_p (const struct rs_connection *conn)
53 {
54   return TO_BASE_CONN(conn)->transport == RS_CONN_TYPE_TLS
55     || TO_BASE_CONN(conn)->transport == RS_CONN_TYPE_DTLS;
56 }
57
58 int
59 baseconn_type_datagram_p (const struct rs_conn_base *connbase)
60 {
61   return connbase->transport == RS_CONN_TYPE_UDP
62     || connbase->transport == RS_CONN_TYPE_DTLS;
63 }
64
65 int
66 baseconn_type_stream_p (const struct rs_conn_base *connbase)
67 {
68   return connbase->transport == RS_CONN_TYPE_TCP
69     || connbase->transport == RS_CONN_TYPE_TLS;
70 }
71
72 int
73 conn_cred_psk (const struct rs_connection *conn)
74 {
75   assert (conn);
76   assert (conn->active_peer);
77   return conn->active_peer->transport_cred &&
78     conn->active_peer->transport_cred->type == RS_CRED_TLS_PSK;
79 }
80
81 void
82 conn_init (struct rs_context *ctx, /* FIXME: rename connbase_init? */
83            struct rs_conn_base *connbase,
84            enum rs_conn_subtype type)
85 {
86   switch (type)
87     {
88     case RS_CONN_OBJTYPE_BASE:
89       connbase->magic = RS_CONN_MAGIC_BASE;
90       break;
91     case RS_CONN_OBJTYPE_GENERIC:
92       connbase->magic = RS_CONN_MAGIC_GENERIC;
93       break;
94     case RS_CONN_OBJTYPE_LISTENER:
95       connbase->magic = RS_CONN_MAGIC_LISTENER;
96       break;
97     default:
98       assert ("invalid connection subtype" == NULL);
99     }
100
101   connbase->ctx = ctx;
102   connbase->fd = -1;
103 }
104
105 int
106 conn_configure (struct rs_context *ctx, /* FIXME: rename conbbase_configure? */
107                 struct rs_conn_base *connbase,
108                 const char *config)
109 {
110   if (config)
111     {
112       struct rs_realm *r = rs_conf_find_realm (ctx, config);
113       if (r)
114         {
115           connbase->realm = r;
116           connbase->peers = r->peers; /* FIXME: Copy instead?  */
117 #if 0
118           for (p = connbase->peers; p != NULL; p = p->next)
119             p->connbase = connbase;
120 #endif
121           connbase->timeout.tv_sec = r->timeout;
122           connbase->tryagain = r->retries;
123         }
124     }
125 #if 0  /* incoming connections don't have a realm (a config object), update: they do, but "somebody else" is setting this up <-- FIXME */
126   if (connbase->realm == NULL)
127     {
128       struct rs_realm *r = rs_calloc (ctx, 1, sizeof (struct rs_realm));
129       if (r == NULL)
130         return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
131       r->next = ctx->realms;
132       ctx->realms = connbase->realm = r;
133     }
134 #else
135   if (connbase->realm)
136     connbase->transport = connbase->realm->type;
137 #endif
138
139   return RSE_OK;
140 }
141
142 int
143 conn_add_read_event (struct rs_connection *conn, void *user_data)
144 {
145   struct rs_conn_base *connbase = TO_BASE_CONN(conn);
146   int err;
147
148   assert(connbase);
149
150   if (connbase->bev)            /* TCP (including TLS).  */
151     {
152       bufferevent_setwatermark (connbase->bev, EV_READ, RS_HEADER_LEN, 0);
153       bufferevent_setcb (connbase->bev, tcp_read_cb, NULL, tcp_event_cb,
154                          user_data);
155       bufferevent_enable (connbase->bev, EV_READ);
156     }
157   else                          /* UDP.  */
158     {
159       /* Put fresh message in user_data for the callback and enable the
160          read event.  */
161       event_assign (connbase->rev, connbase->ctx->evb,
162                     event_get_fd (connbase->rev), EV_READ,
163                     event_get_callback (connbase->rev),
164                     user_data);
165       err = event_add (connbase->rev, NULL);
166       if (err < 0)
167         return rs_err_connbase_push_fl (connbase, RSE_EVENT, __FILE__, __LINE__,
168                                         "event_add: %s",
169                                         evutil_gai_strerror (err));
170
171       /* Activate retransmission timer.  */
172       conn_activate_timeout (conn);
173     }
174
175   return RSE_OK;
176 }
177
178 /** Return !=0 if \a conn is an originating connection, i.e. if its
179     peer is a server. */
180 int
181 conn_originating_p (const struct rs_connection *conn)
182 {
183   return conn->active_peer->type == RS_PEER_TYPE_SERVER;
184 }
185
186 int
187 baseconn_close (struct rs_conn_base *connbase)
188 {
189   int err = 0;
190   assert (connbase);
191
192   rs_debug (("%s: closing fd %d\n", __func__, connbase->fd));
193
194   err = evutil_closesocket (connbase->fd);
195   if (err)
196     err = rs_err_connbase_push (connbase, RSE_EVENT,
197                                 "evutil_closesocket: %d (%s)",
198                                 errno, strerror (errno));
199   connbase->fd = -1;
200   return err;
201 }
202
203 /* Public functions. */
204 int
205 rs_conn_create (struct rs_context *ctx,
206                 struct rs_connection **conn,
207                 const char *config)
208 {
209   int err = RSE_OK;
210   struct rs_connection *c = NULL;
211   assert (ctx);
212
213   c = rs_calloc (ctx, 1, sizeof (struct rs_connection));
214   if (c == NULL)
215     return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
216   conn_init (ctx, &c->base_, RS_CONN_OBJTYPE_GENERIC);
217   err = conn_configure (ctx, &c->base_, config);
218   if (err)
219     goto errout;
220
221   if (conn)
222     *conn = c;
223   return RSE_OK;
224
225  errout:
226   if (c)
227     rs_free (ctx, c);
228   return err;
229 }
230
231 void
232 rs_conn_set_type (struct rs_connection *conn, rs_conn_type_t type)
233 {
234   assert (conn);
235   assert (conn->base_.realm);
236   conn->base_.realm->type = type;
237 }
238
239 int
240 rs_conn_add_listener (struct rs_connection *conn,
241                       rs_conn_type_t type,
242                       const char *hostname,
243                       int port)
244 {
245   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
246 }
247
248
249 int
250 rs_conn_disconnect (struct rs_connection *conn)
251 {
252   int err = baseconn_close (TO_BASE_CONN (conn));
253   conn->state = RS_CONN_STATE_UNDEFINED;
254   return err;
255 }
256
257 int
258 rs_conn_destroy (struct rs_connection *conn)
259 {
260   int err = 0;
261
262   assert (conn);
263
264   /* NOTE: conn->realm is owned by context.  */
265   /* NOTE: conn->peers is owned by context.  */
266
267   if (conn->state == RS_CONN_STATE_CONNECTED)
268     err = rs_conn_disconnect (conn);
269
270 #if defined (RS_ENABLE_TLS)
271   if (conn->tls_ssl) /* FIXME: Free SSL strucxt in rs_conn_disconnect?  */
272     SSL_free (conn->tls_ssl);
273   if (conn->tls_ctx)
274     SSL_CTX_free (conn->tls_ctx);
275 #endif
276
277   if (conn->tev)
278     event_free (conn->tev);
279   if (conn->base_.bev)
280     bufferevent_free (conn->base_.bev);
281   if (conn->base_.rev)
282     event_free (conn->base_.rev);
283   if (conn->base_.wev)
284     event_free (conn->base_.wev);
285
286   rs_free (conn->base_.ctx, conn);
287
288   return err;
289 }
290
291 int
292 rs_conn_set_eventbase (struct rs_connection *conn, struct event_base *eb)
293 {
294   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
295 }
296
297 void
298 rs_conn_set_callbacks (struct rs_connection *conn,
299                        struct rs_conn_callbacks *cb,
300                        void *user_data)
301 {
302   assert (conn);
303   TO_BASE_CONN(conn)->user_data = user_data;
304   memcpy (&conn->callbacks, cb, sizeof (conn->callbacks));
305 }
306
307 void
308 rs_conn_del_callbacks (struct rs_connection *conn)
309 {
310   assert (conn);
311   memset (&conn->callbacks, 0, sizeof (conn->callbacks));
312 }
313
314 struct rs_conn_callbacks *
315 rs_conn_get_callbacks (struct rs_connection *conn)
316 {
317   assert (conn);
318   return &conn->callbacks;
319 }
320
321 int
322 rs_conn_select_peer (struct rs_connection *conn, const char *name)
323 {
324   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
325 }
326
327 int
328 rs_conn_get_current_peer (struct rs_connection *conn,
329                           const char *name,
330                           size_t buflen)
331 {
332   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
333 }
334
335 int
336 rs_conn_dispatch (struct rs_connection *conn)
337 {
338   assert (conn);
339   return event_base_loop (conn->base_.ctx->evb, EVLOOP_ONCE);
340 }
341
342 #if 0
343 struct event_base
344 *rs_conn_get_evb (const struct rs_connection *conn)
345 {
346   assert (conn);
347   return conn->evb;
348 }
349 #endif
350
351 int rs_conn_get_fd (struct rs_connection *conn)
352 {
353   assert (conn);
354   return conn->base_.fd;
355 }
356
357 static void
358 _rcb (struct rs_message *message, void *user_data)
359 {
360   struct rs_message *msg = (struct rs_message *) user_data;
361   assert (msg);
362   assert (msg->conn);
363
364   msg->flags |= RS_MESSAGE_RECEIVED;
365   if (msg->conn->base_.bev)     /* TCP -- disable bufferevent. */
366     bufferevent_disable (msg->conn->base_.bev, EV_WRITE|EV_READ);
367   else                          /* UDP -- remove read event. */
368     event_del (msg->conn->base_.rev);
369 }
370
371 int
372 rs_conn_receive_message (struct rs_connection *conn,
373                          struct rs_message *req_msg,
374                          struct rs_message **msg_out)
375 {
376   int err = 0;
377   struct rs_message *msg = NULL;
378
379   assert (conn);
380   assert (conn->base_.realm);
381   assert (!conn_user_dispatch_p (conn)); /* Blocking mode only.  */
382
383   if (rs_message_create (conn, &msg))
384     return -1;
385
386   assert (conn->base_.ctx->evb);
387   assert (conn->base_.fd >= 0);
388
389   conn->callbacks.received_cb = _rcb;
390   conn->base_.user_data = msg;
391   msg->flags &= ~RS_MESSAGE_RECEIVED;
392
393   err = conn_add_read_event (conn, msg);
394   if (err)
395     return err;
396
397   rs_debug (("%s: entering event loop\n", __func__));
398   err = event_base_dispatch (conn->base_.ctx->evb);
399   conn->callbacks.received_cb = NULL;
400   if (err < 0)
401     return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
402                                 "event_base_dispatch: %s",
403                                 evutil_gai_strerror (err));
404   rs_debug (("%s: event loop done\n", __func__));
405
406   if ((msg->flags & RS_MESSAGE_RECEIVED) == 0 /* No message. */
407       || (req_msg
408           && message_verify_response (conn, msg, req_msg) != RSE_OK))
409     {
410       if (rs_err_conn_peek_code (conn) == RSE_OK)
411         {
412           /* No message and no error on the stack _should_ mean that the
413              server hung up on us. */
414           rs_err_conn_push (conn, RSE_DISCO, "no response");
415         }
416       return rs_err_conn_peek_code (conn);
417     }
418
419   if (msg_out)
420     *msg_out = msg;
421   return RSE_OK;
422 }
423
424 void
425 rs_conn_set_timeout(struct rs_connection *conn, struct timeval *tv)
426 {
427   assert (conn);
428   assert (tv);
429   conn->base_.timeout = *tv;
430 }
431
432 struct rs_peer *
433 connbase_get_peers (const struct rs_conn_base *connbase)
434 {
435   assert (connbase);
436   return connbase->peers;
437 }