Enable tls psk
[libradsec.git] / packet.c
1 /* Copyright 2010-2013 NORDUnet A/S. All rights reserved.
2    See LICENSE for licensing information. */
3
4 #if defined HAVE_CONFIG_H
5 #include <config.h>
6 #endif
7
8 #include <stdlib.h>
9 #include <assert.h>
10 #include <radius/client.h>
11 #include <event2/bufferevent.h>
12 #include <radsec/radsec.h>
13 #include <radsec/radsec-impl.h>
14 #include "conn.h"
15 #include "debug.h"
16 #include "packet.h"
17
18 #if defined (DEBUG)
19 #include <netdb.h>
20 #include <sys/socket.h>
21 #include <event2/buffer.h>
22 #endif
23
24 int
25 packet_verify_response (struct rs_connection *conn,
26                         struct rs_packet *response,
27                         struct rs_packet *request)
28 {
29   int err;
30
31   assert (conn);
32   assert (conn->active_peer);
33   assert (conn->active_peer->secret);
34   assert (response);
35   assert (response->rpkt);
36   assert (request);
37   assert (request->rpkt);
38
39   response->rpkt->secret = conn->active_peer->secret;
40   response->rpkt->sizeof_secret = strlen (conn->active_peer->secret);
41
42   /* Verify header and message authenticator.  */
43   err = nr_packet_verify (response->rpkt, request->rpkt);
44   if (err)
45     {
46       if (conn->is_connected)
47         rs_conn_disconnect(conn);
48       return rs_err_conn_push_fl (conn, -err, __FILE__, __LINE__,
49                                   "nr_packet_verify");
50     }
51
52   /* Decode and decrypt.  */
53   err = nr_packet_decode (response->rpkt, request->rpkt);
54   if (err)
55     {
56       if (conn->is_connected)
57         rs_conn_disconnect(conn);
58       return rs_err_conn_push_fl (conn, -err, __FILE__, __LINE__,
59                                   "nr_packet_decode");
60     }
61
62   return RSE_OK;
63 }
64
65
66 /* Badly named function for preparing a RADIUS message and queue it.
67    FIXME: Rename.  */
68 int
69 packet_do_send (struct rs_packet *pkt)
70 {
71   int err;
72
73   assert (pkt);
74   assert (pkt->conn);
75   assert (pkt->conn->active_peer);
76   assert (pkt->conn->active_peer->secret);
77   assert (pkt->rpkt);
78
79   pkt->rpkt->secret = pkt->conn->active_peer->secret;
80   pkt->rpkt->sizeof_secret = strlen (pkt->rpkt->secret);
81
82   /* Encode message.  */
83   err = nr_packet_encode (pkt->rpkt, NULL);
84   if (err < 0)
85     return rs_err_conn_push_fl (pkt->conn, -err, __FILE__, __LINE__,
86                                 "nr_packet_encode");
87   /* Sign message.  */
88   err = nr_packet_sign (pkt->rpkt, NULL);
89   if (err < 0)
90     return rs_err_conn_push_fl (pkt->conn, -err, __FILE__, __LINE__,
91                                 "nr_packet_sign");
92 #if defined (DEBUG)
93   {
94     char host[80], serv[80];
95
96     getnameinfo (pkt->conn->active_peer->addr_cache->ai_addr,
97                  pkt->conn->active_peer->addr_cache->ai_addrlen,
98                  host, sizeof(host), serv, sizeof(serv),
99                  0 /* NI_NUMERICHOST|NI_NUMERICSERV*/);
100     rs_debug (("%s: about to send this to %s:%s:\n", __func__, host, serv));
101     rs_dump_packet (pkt);
102   }
103 #endif
104
105   /* Put message in output buffer.  */
106   if (pkt->conn->bev)           /* TCP.  */
107     {
108       int err = bufferevent_write (pkt->conn->bev, pkt->rpkt->data,
109                                    pkt->rpkt->length);
110       if (err < 0)
111         return rs_err_conn_push_fl (pkt->conn, RSE_EVENT, __FILE__, __LINE__,
112                                     "bufferevent_write: %s",
113                                     evutil_gai_strerror (err));
114     }
115   else                          /* UDP.  */
116     {
117       struct rs_packet **pp = &pkt->conn->out_queue;
118
119       while (*pp && (*pp)->next)
120         *pp = (*pp)->next;
121       *pp = pkt;
122     }
123
124   return RSE_OK;
125 }
126
127 /* Public functions.  */
128 int
129 rs_packet_create (struct rs_connection *conn, struct rs_packet **pkt_out)
130 {
131   struct rs_packet *p;
132   RADIUS_PACKET *rpkt;
133   int err;
134
135   *pkt_out = NULL;
136
137   rpkt = rs_malloc (conn->ctx, sizeof(*rpkt) + RS_MAX_PACKET_LEN);
138   if (rpkt == NULL)
139     return rs_err_conn_push (conn, RSE_NOMEM, __func__);
140
141   err = nr_packet_init (rpkt, NULL, NULL,
142                         PW_ACCESS_REQUEST,
143                         rpkt + 1, RS_MAX_PACKET_LEN);
144   if (err < 0)
145     return rs_err_conn_push (conn, -err, __func__);
146
147   p = (struct rs_packet *) rs_calloc (conn->ctx, 1, sizeof (*p));
148   if (p == NULL)
149     {
150       rs_free (conn->ctx, rpkt);
151       return rs_err_conn_push (conn, RSE_NOMEM, __func__);
152     }
153   p->conn = conn;
154   p->rpkt = rpkt;
155
156   *pkt_out = p;
157   return RSE_OK;
158 }
159
160 int
161 rs_packet_create_authn_request (struct rs_connection *conn,
162                                 struct rs_packet **pkt_out,
163                                 const char *user_name, const char *user_pw)
164 {
165   struct rs_packet *pkt;
166   int err;
167
168   if (rs_packet_create (conn, pkt_out))
169     return -1;
170
171   pkt = *pkt_out;
172   pkt->rpkt->code = PW_ACCESS_REQUEST;
173
174   if (user_name)
175     {
176       err = rs_packet_add_avp (pkt, PW_USER_NAME, 0, user_name,
177                                strlen (user_name));
178       if (err)
179         return err;
180     }
181
182   if (user_pw)
183     {
184       err = rs_packet_add_avp (pkt, PW_USER_PASSWORD, 0, user_pw,
185                                strlen (user_pw));
186       if (err)
187         return err;
188     }
189
190   return RSE_OK;
191 }
192
193 void
194 rs_packet_destroy (struct rs_packet *pkt)
195 {
196   assert (pkt);
197   assert (pkt->conn);
198   assert (pkt->conn->ctx);
199
200   rs_avp_free (&pkt->rpkt->vps);
201   rs_free (pkt->conn->ctx, pkt->rpkt);
202   rs_free (pkt->conn->ctx, pkt);
203 }
204
205 int
206 rs_packet_add_avp (struct rs_packet *pkt,
207                    unsigned int attr, unsigned int vendor,
208                    const void *data, size_t data_len)
209
210 {
211   const DICT_ATTR *da;
212   VALUE_PAIR *vp;
213   int err;
214
215   assert (pkt);
216   assert (pkt->conn);
217   assert (pkt->conn->ctx);
218
219   da = nr_dict_attr_byvalue (attr, vendor);
220   if (da == NULL)
221     return rs_err_conn_push (pkt->conn, RSE_ATTR_TYPE_UNKNOWN,
222                              "nr_dict_attr_byvalue");
223   vp = rs_malloc (pkt->conn->ctx, sizeof(*vp));
224   if (vp == NULL)
225     return rs_err_conn_push (pkt->conn, RSE_NOMEM, NULL);
226   if (nr_vp_init (vp, da) == NULL)
227     {
228       nr_vp_free (&vp);
229       return rs_err_conn_push (pkt->conn, RSE_INTERNAL, NULL);
230     }
231   err = nr_vp_set_data (vp, data, data_len);
232   if (err < 0)
233     {
234       nr_vp_free (&vp);
235       return rs_err_conn_push (pkt->conn, -err, "nr_vp_set_data");
236     }
237   nr_vps_append (&pkt->rpkt->vps, vp);
238
239   return RSE_OK;
240 }
241
242 /* TODO: Rename rs_packet_append_avp, indicating that encoding is
243    being done. */
244 int
245 rs_packet_append_avp (struct rs_packet *pkt,
246                       unsigned int attr, unsigned int vendor,
247                       const void *data, size_t data_len)
248 {
249   const DICT_ATTR *da;
250   int err;
251
252   assert (pkt);
253
254   da = nr_dict_attr_byvalue (attr, vendor);
255   if (da == NULL)
256     return rs_err_conn_push (pkt->conn, RSE_ATTR_TYPE_UNKNOWN, __func__);
257
258   err = nr_packet_attr_append (pkt->rpkt, NULL, da, data, data_len);
259   if (err < 0)
260     return rs_err_conn_push (pkt->conn, -err, __func__);
261
262   return RSE_OK;
263 }
264
265 void
266 rs_packet_avps (struct rs_packet *pkt, rs_avp ***vps)
267 {
268   assert (pkt);
269   *vps = &pkt->rpkt->vps;
270 }
271
272 unsigned int
273 rs_packet_code (struct rs_packet *pkt)
274 {
275   assert (pkt);
276   return pkt->rpkt->code;
277 }
278
279 rs_const_avp *
280 rs_packet_find_avp (struct rs_packet *pkt, unsigned int attr, unsigned int vendor)
281 {
282   assert (pkt);
283   return rs_avp_find_const (pkt->rpkt->vps, attr, vendor);
284 }
285
286 int
287 rs_packet_set_id (struct rs_packet *pkt, int id)
288 {
289   int old = pkt->rpkt->id;
290
291   pkt->rpkt->id = id;
292
293   return old;
294 }