Add two helper functions to conn.[ch].
[libradsec.git] / lib / conn.c
1 /* Copyright 2010, 2011 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information.  */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <string.h>
9 #include <stdlib.h>
10 #include <errno.h>
11 #include <assert.h>
12 #include <event2/event.h>
13 #include <event2/bufferevent.h>
14 #include <radsec/radsec.h>
15 #include <radsec/radsec-impl.h>
16 #include "debug.h"
17 #include "conn.h"
18 #include "event.h"
19 #include "packet.h"
20 #include "tcp.h"
21
22 int
23 conn_close (struct rs_connection **connp)
24 {
25   int r = 0;
26   assert (connp);
27   assert (*connp);
28   if ((*connp)->is_connected)
29     r = rs_conn_disconnect (*connp);
30   if (r == RSE_OK)
31     *connp = NULL;
32   return r;
33 }
34
35 int
36 conn_user_dispatch_p (const struct rs_connection *conn)
37 {
38   assert (conn);
39
40   return (conn->callbacks.connected_cb ||
41           conn->callbacks.disconnected_cb ||
42           conn->callbacks.received_cb ||
43           conn->callbacks.sent_cb);
44 }
45
46
47 int
48 conn_activate_timeout (struct rs_connection *conn)
49 {
50   assert (conn);
51   assert (conn->tev);
52   assert (conn->evb);
53   if (conn->timeout.tv_sec || conn->timeout.tv_usec)
54     {
55       rs_debug (("%s: activating timer: %d.%d\n", __func__,
56                  conn->timeout.tv_sec, conn->timeout.tv_usec));
57       if (evtimer_add (conn->tev, &conn->timeout))
58         return rs_err_conn_push_fl (conn, RSE_EVENT, __FILE__, __LINE__,
59                                     "evtimer_add: %d", errno);
60     }
61   return RSE_OK;
62 }
63
64 int
65 conn_type_tls (const struct rs_connection *conn)
66 {
67   return conn->realm->type == RS_CONN_TYPE_TLS
68     || conn->realm->type == RS_CONN_TYPE_DTLS;
69 }
70
71 int
72 conn_cred_psk (const struct rs_connection *conn)
73 {
74   return conn->realm->transport_cred &&
75     conn->realm->transport_cred->type == RS_CRED_TLS_PSK;
76 }
77
78
79 /* Public functions. */
80 int
81 rs_conn_create (struct rs_context *ctx,
82                 struct rs_connection **conn,
83                 const char *config)
84 {
85   struct rs_connection *c;
86
87   c = (struct rs_connection *) malloc (sizeof(struct rs_connection));
88   if (!c)
89     return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__, NULL);
90
91   memset (c, 0, sizeof(struct rs_connection));
92   c->ctx = ctx;
93   c->fd = -1;
94   if (config)
95     {
96       struct rs_realm *r = rs_conf_find_realm (ctx, config);
97       if (r)
98         {
99           struct rs_peer *p;
100
101           c->realm = r;
102           c->peers = r->peers;  /* FIXME: Copy instead?  */
103           for (p = c->peers; p; p = p->next)
104             p->conn = c;
105           c->timeout.tv_sec = r->timeout;
106           c->tryagain = r->retries;
107         }
108       else
109         {
110           c->realm = rs_malloc (ctx, sizeof (struct rs_realm));
111           if (!c->realm)
112             return rs_err_ctx_push_fl (ctx, RSE_NOMEM, __FILE__, __LINE__,
113                                        NULL);
114           memset (c->realm, 0, sizeof (struct rs_realm));
115         }
116     }
117
118   if (conn)
119     *conn = c;
120   return RSE_OK;
121 }
122
123 void
124 rs_conn_set_type (struct rs_connection *conn, rs_conn_type_t type)
125 {
126   assert (conn);
127   assert (conn->realm);
128   conn->realm->type = type;
129 }
130
131 int
132 rs_conn_add_listener (struct rs_connection *conn,
133                       rs_conn_type_t type,
134                       const char *hostname,
135                       int port)
136 {
137   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
138 }
139
140
141 int
142 rs_conn_disconnect (struct rs_connection *conn)
143 {
144   int err = 0;
145
146   assert (conn);
147
148   err = evutil_closesocket (conn->fd);
149   conn->fd = -1;
150   return err;
151 }
152
153 int
154 rs_conn_destroy (struct rs_connection *conn)
155 {
156   int err = 0;
157
158   assert (conn);
159
160   /* NOTE: conn->realm is owned by context.  */
161   /* NOTE: conn->peers is owned by context.  */
162
163   if (conn->is_connected)
164     err = rs_conn_disconnect (conn);
165
166 #if defined (RS_ENABLE_TLS)
167   if (conn->tls_ssl) /* FIXME: Free SSL strucxt in rs_conn_disconnect?  */
168     SSL_free (conn->tls_ssl);
169   if (conn->tls_ctx)
170     SSL_CTX_free (conn->tls_ctx);
171 #endif
172
173   if (conn->tev)
174     event_free (conn->tev);
175   if (conn->bev)
176     bufferevent_free (conn->bev);
177   if (conn->rev)
178     event_free (conn->rev);
179   if (conn->wev)
180     event_free (conn->wev);
181   if (conn->evb)
182     event_base_free (conn->evb);
183
184   rs_free (conn->ctx, conn);
185
186   return err;
187 }
188
189 int
190 rs_conn_set_eventbase (struct rs_connection *conn, struct event_base *eb)
191 {
192   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
193 }
194
195 void
196 rs_conn_set_callbacks (struct rs_connection *conn, struct rs_conn_callbacks *cb)
197 {
198   assert (conn);
199   memcpy (&conn->callbacks, cb, sizeof (conn->callbacks));
200 }
201
202 void
203 rs_conn_del_callbacks (struct rs_connection *conn)
204 {
205   assert (conn);
206   memset (&conn->callbacks, 0, sizeof (conn->callbacks));
207 }
208
209 struct rs_conn_callbacks *
210 rs_conn_get_callbacks(struct rs_connection *conn)
211 {
212   assert (conn);
213   return &conn->callbacks;
214 }
215
216 int
217 rs_conn_select_peer (struct rs_connection *conn, const char *name)
218 {
219   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
220 }
221
222 int
223 rs_conn_get_current_peer (struct rs_connection *conn,
224                           const char *name,
225                           size_t buflen)
226 {
227   return rs_err_conn_push_fl (conn, RSE_NOSYS, __FILE__, __LINE__, NULL);
228 }
229
230 int rs_conn_fd (struct rs_connection *conn)
231 {
232   assert (conn);
233   assert (conn->active_peer);
234   return conn->fd;
235 }
236
237 static void
238 _rcb (struct rs_packet *packet, void *user_data)
239 {
240   struct rs_packet *pkt = (struct rs_packet *) user_data;
241   assert (pkt);
242   assert (pkt->conn);
243
244   pkt->flags |= RS_PACKET_RECEIVED;
245   if (pkt->conn->bev)
246     bufferevent_disable (pkt->conn->bev, EV_WRITE|EV_READ);
247   else
248     event_del (pkt->conn->rev);
249 }
250
251 int
252 rs_conn_receive_packet (struct rs_connection *conn,
253                         struct rs_packet *req_msg,
254                         struct rs_packet **pkt_out)
255 {
256   int err = 0;
257   struct rs_packet *pkt = NULL;
258
259   assert (conn);
260   assert (conn->realm);
261   assert (!conn_user_dispatch_p (conn)); /* Blocking mode only.  */
262
263   if (rs_packet_create (conn, &pkt))
264     return -1;
265
266   assert (conn->evb);
267   assert (conn->fd >= 0);
268
269   conn->callbacks.received_cb = _rcb;
270   conn->user_data = pkt;
271   pkt->flags &= ~RS_PACKET_RECEIVED;
272
273   if (conn->bev)                /* TCP.  */
274     {
275       bufferevent_setwatermark (conn->bev, EV_READ, RS_HEADER_LEN, 0);
276       bufferevent_setcb (conn->bev, tcp_read_cb, NULL, tcp_event_cb, pkt);
277       bufferevent_enable (conn->bev, EV_READ);
278     }
279   else                          /* UDP.  */
280     {
281       /* Put fresh packet in user_data for the callback and enable the
282          read event.  */
283       event_assign (conn->rev, conn->evb, event_get_fd (conn->rev),
284                     EV_READ, event_get_callback (conn->rev), pkt);
285       err = event_add (conn->rev, NULL);
286       if (err < 0)
287         return rs_err_conn_push_fl (pkt->conn, RSE_EVENT, __FILE__, __LINE__,
288                                     "event_add: %s",
289                                     evutil_gai_strerror (err));
290
291       /* Activate retransmission timer.  */
292       conn_activate_timeout (pkt->conn);
293     }
294
295   rs_debug (("%s: entering event loop\n", __func__));
296   err = event_base_dispatch (conn->evb);
297   conn->callbacks.received_cb = NULL;
298   if (err < 0)
299     return rs_err_conn_push_fl (pkt->conn, RSE_EVENT, __FILE__, __LINE__,
300                                 "event_base_dispatch: %s",
301                                 evutil_gai_strerror (err));
302   rs_debug (("%s: event loop done\n", __func__));
303
304   if ((pkt->flags & RS_PACKET_RECEIVED) == 0
305       || (req_msg
306           && packet_verify_response (pkt->conn, pkt, req_msg) != RSE_OK))
307     {
308       if (rs_err_conn_peek_code (pkt->conn) == RSE_OK)
309         /* No packet and no error on the stack _should_ mean that the
310            server hung up on us.  */
311         rs_err_conn_push (pkt->conn, RSE_DISCO, "no response");
312       return rs_err_conn_peek_code (conn);
313     }
314
315   if (pkt_out)
316     *pkt_out = pkt;
317   return RSE_OK;
318 }
319
320 void
321 rs_conn_set_timeout(struct rs_connection *conn, struct timeval *tv)
322 {
323   assert (conn);
324   assert (tv);
325   conn->timeout = *tv;
326 }