Revamping for listeners.
[radsecproxy.git] / lib / conn.c
1 /* Copyright 2010,2011,2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information. */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <string.h>
9 #include <stdlib.h>
10 #include <errno.h>
11 #include <assert.h>
12 #include <event2/event.h>
13 #include <event2/bufferevent.h>
14 #include <radsec/radsec.h>
15 #include <radsec/radsec-impl.h>
16 #include "debug.h"
17 #include "conn.h"
18 #include "event.h"
19 #include "message.h"
20 #include "tcp.h"
21
22 int
23 conn_close (struct rs_connection **connp)
24 {
25   int r = 0;
26   assert (connp);
27   assert (*connp);
28   if ((*connp)->state == RS_CONN_STATE_CONNECTED)
29     r = rs_conn_disconnect (*connp);
30   if (r == RSE_OK)
31     *connp = NULL;
32   return r;
33 }
34
35 int
36 conn_user_dispatch_p (const struct rs_connection *conn)
37 {
38   assert (conn);
39
40   return (conn->callbacks.connected_cb ||
41           conn->callbacks.disconnected_cb ||
42           conn->callbacks.received_cb ||
43           conn->callbacks.sent_cb);
44 }
45
46
47 int
48 conn_activate_timeout (struct rs_connection *conn)
49 {
50   assert (conn);
51   assert (conn->tev);
52   assert (conn->base_.ctx->evb);
53   if (conn->base_.timeout.tv_sec || conn->base_.timeout.tv_usec)
54     {
55       rs_debug (("%s: activating timer: %d.%d\n", __func__,
56                  conn->base_.timeout.tv_sec, conn->base_.timeout.tv_usec));
57       if (evtimer_add (conn->tev, &conn->base_.timeout))
58         return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
59                                     "evtimer_add: %d", errno);
60     }
61   return RSE_OK;
62 }
63
64 int
65 conn_type_tls (const struct rs_connection *conn)
66 {
67   assert (conn->base_.active_peer);
68   return conn->base_.realm->type == RS_CONN_TYPE_TLS
69     || conn->base_.realm->type == RS_CONN_TYPE_DTLS;
70 }
71
72 int
73 conn_cred_psk (const struct rs_connection *conn)
74 {
75   assert (conn->base_.active_peer);
76   return conn->base_.active_peer->transport_cred &&
77     conn->base_.active_peer->transport_cred->type == RS_CRED_TLS_PSK;
78 }
79
80 void
81 conn_init (struct rs_context *ctx,
82            struct rs_conn_base *connbase,
83            enum rs_conn_subtype type)
84 {
85   switch (type)
86     {
87     case RS_CONN_OBJTYPE_BASE:
88       connbase->magic = RS_CONN_MAGIC_BASE;
89       break;
90     case RS_CONN_OBJTYPE_GENERIC:
91       connbase->magic = RS_CONN_MAGIC_GENERIC;
92       break;
93     case RS_CONN_OBJTYPE_LISTENER:
94       connbase->magic = RS_CONN_MAGIC_LISTENER;
95       break;
96     default:
97       assert ("invalid connection subtype" == NULL);
98     }
99
100   connbase->ctx = ctx;
101   connbase->fd = -1;
102 }
103
104 int
105 conn_configure (struct rs_context *ctx,
106                 struct rs_conn_base *connbase,
107                 const char *config)
108 {
109   if (config)
110     {
111       struct rs_realm *r = rs_conf_find_realm (ctx, config);
112       if (r)
113         {
114           connbase->realm = r;
115           connbase->peers = r->peers; /* FIXME: Copy instead?  */
116 #if 0
117           for (p = connbase->peers; p != NULL; p = p->next)
118             p->connbase = connbase;
119 #endif
120           connbase->timeout.tv_sec = r->timeout;
121           connbase->tryagain = r->retries;
122         }
123     }
124   if (connbase->realm == NULL)
125     {
126       connbase->realm = rs_calloc (ctx, 1, sizeof (struct rs_realm));
127       if (connbase->realm == NULL)
128         return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
129     }
130   return RSE_OK;
131 }
132
133 /* Public functions. */
134 int
135 rs_conn_create (struct rs_context *ctx,
136                 struct rs_connection **conn,
137                 const char *config)
138 {
139   int err = RSE_OK;
140   struct rs_connection *c = NULL;
141   assert (ctx);
142
143   c = rs_calloc (ctx, 1, sizeof (struct rs_connection));
144   if (c == NULL)
145     return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
146   conn_init (ctx, &c->base_, RS_CONN_OBJTYPE_GENERIC);
147   err = conn_configure (ctx, &c->base_, config);
148   if (err)
149     goto errout;
150
151   if (conn)
152     *conn = c;
153   return RSE_OK;
154
155  errout:
156   if (c)
157     rs_free (ctx, c);
158   return err;
159 }
160
161 void
162 rs_conn_set_type (struct rs_connection *conn, rs_conn_type_t type)
163 {
164   assert (conn);
165   assert (conn->base_.realm);
166   conn->base_.realm->type = type;
167 }
168
169 int
170 rs_conn_add_listener (struct rs_connection *conn,
171                       rs_conn_type_t type,
172                       const char *hostname,
173                       int port)
174 {
175   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
176 }
177
178
179 int
180 rs_conn_disconnect (struct rs_connection *conn)
181 {
182   int err = 0;
183
184   assert (conn);
185
186   err = evutil_closesocket (conn->base_.fd);
187   conn->base_.fd = -1;
188   return err;
189 }
190
191 int
192 rs_conn_destroy (struct rs_connection *conn)
193 {
194   int err = 0;
195
196   assert (conn);
197
198   /* NOTE: conn->realm is owned by context.  */
199   /* NOTE: conn->peers is owned by context.  */
200
201   if (conn->state == RS_CONN_STATE_CONNECTED)
202     err = rs_conn_disconnect (conn);
203
204 #if defined (RS_ENABLE_TLS)
205   if (conn->tls_ssl) /* FIXME: Free SSL strucxt in rs_conn_disconnect?  */
206     SSL_free (conn->tls_ssl);
207   if (conn->tls_ctx)
208     SSL_CTX_free (conn->tls_ctx);
209 #endif
210
211   if (conn->tev)
212     event_free (conn->tev);
213   if (conn->base_.bev)
214     bufferevent_free (conn->base_.bev);
215   if (conn->base_.rev)
216     event_free (conn->base_.rev);
217   if (conn->base_.wev)
218     event_free (conn->base_.wev);
219
220   rs_free (conn->base_.ctx, conn);
221
222   return err;
223 }
224
225 int
226 rs_conn_set_eventbase (struct rs_connection *conn, struct event_base *eb)
227 {
228   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
229 }
230
231 void
232 rs_conn_set_callbacks (struct rs_connection *conn, struct rs_conn_callbacks *cb)
233 {
234   assert (conn);
235   memcpy (&conn->callbacks, cb, sizeof (conn->callbacks));
236 }
237
238 void
239 rs_conn_del_callbacks (struct rs_connection *conn)
240 {
241   assert (conn);
242   memset (&conn->callbacks, 0, sizeof (conn->callbacks));
243 }
244
245 struct rs_conn_callbacks *
246 rs_conn_get_callbacks(struct rs_connection *conn)
247 {
248   assert (conn);
249   return &conn->callbacks;
250 }
251
252 int
253 rs_conn_select_peer (struct rs_connection *conn, const char *name)
254 {
255   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
256 }
257
258 int
259 rs_conn_get_current_peer (struct rs_connection *conn,
260                           const char *name,
261                           size_t buflen)
262 {
263   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
264 }
265
266 int
267 rs_conn_dispatch (struct rs_connection *conn)
268 {
269   assert (conn);
270   return event_base_loop (conn->base_.ctx->evb, EVLOOP_ONCE);
271 }
272
273 #if 0
274 struct event_base
275 *rs_conn_get_evb (const struct rs_connection *conn)
276 {
277   assert (conn);
278   return conn->evb;
279 }
280 #endif
281
282 int rs_conn_get_fd (struct rs_connection *conn)
283 {
284   assert (conn);
285   assert (conn->base_.active_peer);
286   return conn->base_.fd;
287 }
288
289 static void
290 _rcb (struct rs_message *message, void *user_data)
291 {
292   struct rs_message *msg = (struct rs_message *) user_data;
293   assert (msg);
294   assert (msg->conn);
295
296   msg->flags |= RS_MESSAGE_RECEIVED;
297   if (msg->conn->base_.bev)
298     bufferevent_disable (msg->conn->base_.bev, EV_WRITE|EV_READ);
299   else
300     event_del (msg->conn->base_.rev);
301 }
302
303 int
304 rs_conn_receive_message (struct rs_connection *conn,
305                          struct rs_message *req_msg,
306                          struct rs_message **msg_out)
307 {
308   int err = 0;
309   struct rs_message *msg = NULL;
310
311   assert (conn);
312   assert (conn->base_.realm);
313   assert (!conn_user_dispatch_p (conn)); /* Blocking mode only.  */
314
315   if (rs_message_create (conn, &msg))
316     return -1;
317
318   assert (conn->base_.ctx->evb);
319   assert (conn->base_.fd >= 0);
320
321   conn->callbacks.received_cb = _rcb;
322   conn->base_.user_data = msg;
323   msg->flags &= ~RS_MESSAGE_RECEIVED;
324
325   if (conn->base_.bev)          /* TCP.  */
326     {
327       bufferevent_setwatermark (conn->base_.bev, EV_READ, RS_HEADER_LEN, 0);
328       bufferevent_setcb (conn->base_.bev, tcp_read_cb, NULL, tcp_event_cb, msg);
329       bufferevent_enable (conn->base_.bev, EV_READ);
330     }
331   else                          /* UDP.  */
332     {
333       /* Put fresh message in user_data for the callback and enable the
334          read event.  */
335       event_assign (conn->base_.rev, conn->base_.ctx->evb,
336                     event_get_fd (conn->base_.rev), EV_READ,
337                     event_get_callback (conn->base_.rev), msg);
338       err = event_add (conn->base_.rev, NULL);
339       if (err < 0)
340         return rs_err_conn_push_fl (msg->conn, RSE_EVENT, __FILE__, __LINE__,
341                                     "event_add: %s",
342                                     evutil_gai_strerror (err));
343
344       /* Activate retransmission timer.  */
345       conn_activate_timeout (msg->conn);
346     }
347
348   rs_debug (("%s: entering event loop\n", __func__));
349   err = event_base_dispatch (conn->base_.ctx->evb);
350   conn->callbacks.received_cb = NULL;
351   if (err < 0)
352     return rs_err_conn_push_fl (msg->conn, RSE_EVENT, __FILE__, __LINE__,
353                                 "event_base_dispatch: %s",
354                                 evutil_gai_strerror (err));
355   rs_debug (("%s: event loop done\n", __func__));
356
357   if ((msg->flags & RS_MESSAGE_RECEIVED) == 0
358       || (req_msg
359           && message_verify_response (msg->conn, msg, req_msg) != RSE_OK))
360     {
361       if (rs_err_conn_peek_code (msg->conn) == RSE_OK)
362         /* No message and no error on the stack _should_ mean that the
363            server hung up on us.  */
364         rs_err_conn_push (msg->conn, RSE_DISCO, "no response");
365       return rs_err_conn_peek_code (conn);
366     }
367
368   if (msg_out)
369     *msg_out = msg;
370   return RSE_OK;
371 }
372
373 void
374 rs_conn_set_timeout(struct rs_connection *conn, struct timeval *tv)
375 {
376   assert (conn);
377   assert (tv);
378   conn->base_.timeout = *tv;
379 }