Track num_servers correctly
[trust_router.git] / common / tr_msg.c
1 /*
2  * Copyright (c) 2012-2014 , JANET(UK)
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * 3. Neither the name of JANET(UK) nor the names of its contributors
17  *    may be used to endorse or promote products derived from this software
18  *    without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
25  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
31  * OF THE POSSIBILITY OF SUCH DAMAGE.
32  *
33  */
34 #include <sys/socket.h>
35 #include <netinet/in.h>
36 #include <arpa/inet.h>
37 #include <string.h>
38 #include <openssl/dh.h>
39 #include <jansson.h>
40 #include <assert.h>
41 #include <talloc.h>
42
43
44 #include <tr_msg.h>
45 #include <trust_router/tr_name.h>
46 #include <tid_internal.h>
47 #include <trust_router/tr_constraint.h>
48 #include <tr_debug.h>
49
50 enum msg_type tr_msg_get_msg_type(TR_MSG *msg) 
51 {
52   return msg->msg_type;
53 }
54
55 void tr_msg_set_msg_type(TR_MSG *msg, enum msg_type type)
56 {
57   msg->msg_type = type;
58 }
59
60 TID_REQ *tr_msg_get_req(TR_MSG *msg)
61 {
62   return msg->tid_req;
63 }
64
65 void tr_msg_set_req(TR_MSG *msg, TID_REQ *req)
66 {
67   msg->tid_req = req;
68 }
69
70 TID_RESP *tr_msg_get_resp(TR_MSG *msg)
71 {
72   return msg->tid_resp;
73 }
74
75 void tr_msg_set_resp(TR_MSG *msg, TID_RESP *resp)
76 {
77   msg->tid_resp = resp;
78 }
79
80 static json_t *tr_msg_encode_dh(DH *dh)
81 {
82   json_t *jdh = NULL;
83   json_t *jbn = NULL;
84
85   if ((!dh) || (!dh->p) || (!dh->g) || (!dh->pub_key))
86     return NULL;
87
88   jdh = json_object();
89
90   jbn = json_string(BN_bn2hex(dh->p));
91   json_object_set_new(jdh, "dh_p", jbn);
92
93   jbn = json_string(BN_bn2hex(dh->g));
94   json_object_set_new(jdh, "dh_g", jbn);
95
96   jbn = json_string(BN_bn2hex(dh->pub_key));
97   json_object_set_new(jdh, "dh_pub_key", jbn);
98
99   return jdh;
100 }
101
102 static DH *tr_msg_decode_dh(json_t *jdh)
103 {
104   DH *dh = NULL;
105   json_t *jp = NULL;
106   json_t *jg = NULL;
107   json_t *jpub_key = NULL;
108
109   if (!(dh = malloc(sizeof(DH)))) {
110     fprintf (stderr, "tr_msg_decode_dh(): Error allocating DH structure.\n");
111     return NULL;
112   }
113  
114   memset(dh, 0, sizeof(DH));
115
116   /* store required fields from dh object */
117   if ((NULL == (jp = json_object_get(jdh, "dh_p"))) ||
118       (NULL == (jg = json_object_get(jdh, "dh_g"))) ||
119       (NULL == (jpub_key = json_object_get(jdh, "dh_pub_key")))) {
120     fprintf (stderr, "tr_msg_decode_dh(): Error parsing dh_info.\n");
121     free(dh);
122     return NULL;
123   }
124
125   BN_hex2bn(&(dh->p), json_string_value(jp));
126   BN_hex2bn(&(dh->g), json_string_value(jg));
127   BN_hex2bn(&(dh->pub_key), json_string_value(jpub_key));
128
129   return dh;
130 }
131
132 static json_t * tr_msg_encode_tidreq(TID_REQ *req)
133 {
134   json_t *jreq = NULL;
135   json_t *jstr = NULL;
136
137   if ((!req) || (!req->rp_realm) || (!req->realm) || !(req->comm))
138     return NULL;
139
140   assert(jreq = json_object());
141
142   jstr = json_string(req->rp_realm->buf);
143   json_object_set_new(jreq, "rp_realm", jstr);
144
145   jstr = json_string(req->realm->buf);
146   json_object_set_new(jreq, "target_realm", jstr);
147
148   jstr = json_string(req->comm->buf);
149   json_object_set_new(jreq, "community", jstr);
150   
151   if (req->orig_coi) {
152     jstr = json_string(req->orig_coi->buf);
153     json_object_set_new(jreq, "orig_coi", jstr);
154   }
155
156   json_object_set_new(jreq, "dh_info", tr_msg_encode_dh(req->tidc_dh));
157
158   if (req->cons)
159     json_object_set(jreq, "constraints", (json_t *) req->cons);
160
161   return jreq;
162 }
163
164 static TID_REQ *tr_msg_decode_tidreq(json_t *jreq)
165 {
166   TID_REQ *treq = NULL;
167   json_t *jrp_realm = NULL;
168   json_t *jrealm = NULL;
169   json_t *jcomm = NULL;
170   json_t *jorig_coi = NULL;
171   json_t *jdh = NULL;
172
173   if (!(treq =tid_req_new())) {
174     fprintf (stderr, "tr_msg_decode_tidreq(): Error allocating TID_REQ structure.\n");
175     return NULL;
176   }
177  
178   /* store required fields from request */
179   if ((NULL == (jrp_realm = json_object_get(jreq, "rp_realm"))) ||
180       (NULL == (jrealm = json_object_get(jreq, "target_realm"))) ||
181       (NULL == (jcomm = json_object_get(jreq, "community")))) {
182     fprintf (stderr, "tr_msg_decode(): Error parsing required fields.\n");
183     tid_req_free(treq);
184     return NULL;
185   }
186
187   treq->rp_realm = tr_new_name((char *)json_string_value(jrp_realm));
188   treq->realm = tr_new_name((char *)json_string_value(jrealm));
189   treq->comm = tr_new_name((char *)json_string_value(jcomm));
190
191   /* Get DH Info from the request */
192   if (NULL == (jdh = json_object_get(jreq, "dh_info"))) {
193     fprintf (stderr, "tr_msg_decode(): Error parsing dh_info.\n");
194     tid_req_free(treq);
195     return NULL;
196   }
197   treq->tidc_dh = tr_msg_decode_dh(jdh);
198
199   /* store optional "orig_coi" field */
200   if (NULL != (jorig_coi = json_object_get(jreq, "orig_coi"))) {
201     treq->orig_coi = tr_new_name((char *)json_string_value(jorig_coi));
202   }
203
204   treq->cons = (TR_CONSTRAINT_SET *) json_object_get(jreq, "constraints");
205   if (treq->cons) {
206     if (!tr_constraint_set_validate(treq->cons)) {
207       tr_debug("Constraint set validation failed\n");
208     tid_req_free(treq);
209     return NULL;
210     }
211     json_incref((json_t *) treq->cons);
212     tid_req_cleanup_json(treq, (json_t *) treq->cons);
213   }
214   return treq;
215 }
216
217 static json_t *tr_msg_encode_one_server(TID_SRVR_BLK *srvr)
218 {
219   json_t *jsrvr = NULL;
220   json_t *jstr = NULL;
221
222   fprintf(stderr, "Encoding one server.\n");
223
224   jsrvr = json_object();
225
226   /* Server IP Address -- TBD handle IPv6 */
227   jstr = json_string(inet_ntoa(srvr->aaa_server_addr));
228   json_object_set_new(jsrvr, "server_addr", jstr);
229
230   /* Server DH Block */
231   jstr = json_string(srvr->key_name->buf);
232   json_object_set_new(jsrvr, "key_name", jstr);
233   json_object_set_new(jsrvr, "server_dh", tr_msg_encode_dh(srvr->aaa_server_dh));
234   
235   //  fprintf(stderr,"tr_msg_encode_one_server(): jsrvr contains:\n");
236   //  fprintf(stderr,"%s\n", json_dumps(jsrvr, 0));
237   return jsrvr;
238 }
239
240 static int tr_msg_decode_one_server(json_t *jsrvr, TID_SRVR_BLK *srvr) 
241 {
242   json_t *jsrvr_addr = NULL;
243   json_t *jsrvr_kn = NULL;
244   json_t *jsrvr_dh = NULL;
245
246   if (jsrvr == NULL)
247     return -1;
248
249
250   if ((NULL == (jsrvr_addr = json_object_get(jsrvr, "server_addr"))) ||
251       (NULL == (jsrvr_kn = json_object_get(jsrvr, "key_name"))) ||
252       (NULL == (jsrvr_dh = json_object_get(jsrvr, "server_dh")))) {
253     tr_debug("tr_msg_decode_one_server(): Error parsing required fields.\n");
254     return -1;
255   }
256   
257   /* TBD -- handle IPv6 Addresses */
258   inet_aton(json_string_value(jsrvr_addr), &(srvr->aaa_server_addr));
259   srvr->key_name = tr_new_name((char *)json_string_value(jsrvr_kn));
260   srvr->aaa_server_dh = tr_msg_decode_dh(jsrvr_dh);
261   return 0;
262 }
263
264 static json_t *tr_msg_encode_servers(TID_RESP *resp)
265 {
266   json_t *jservers = NULL;
267   json_t *jsrvr = NULL;
268   TID_SRVR_BLK *srvr = NULL;
269   size_t index;
270
271   jservers = json_array();
272
273   tid_resp_servers_foreach(resp, srvr, index) {
274     if ((NULL == (jsrvr = tr_msg_encode_one_server(srvr))) ||
275         (-1 == json_array_append_new(jservers, jsrvr))) {
276       return NULL;
277     }
278   }
279
280   //  fprintf(stderr,"tr_msg_encode_servers(): servers contains:\n");
281   //  fprintf(stderr,"%s\n", json_dumps(jservers, 0));
282   return jservers;
283 }
284
285 static TID_SRVR_BLK *tr_msg_decode_servers(void * ctx, json_t *jservers, size_t *out_len)
286 {
287   TID_SRVR_BLK *servers = NULL;
288   json_t *jsrvr;
289   size_t i, num_servers;
290
291   num_servers = json_array_size(jservers);
292   fprintf(stderr, "tr_msg_decode_servers(): Number of servers = %u.\n", (unsigned) num_servers);
293   
294   if (0 == num_servers) {
295     fprintf(stderr, "tr_msg_decode_servers(): Server array is empty.\n"); 
296     return NULL;
297   }
298     servers = talloc_zero_array(ctx, TID_SRVR_BLK, num_servers);
299
300   for (i = 0; i < num_servers; i++) {
301     jsrvr = json_array_get(jservers, i);
302     if (0 != tr_msg_decode_one_server(jsrvr, &servers[i])) {
303       talloc_free(servers);
304       return NULL;
305     }
306
307
308   }
309   *out_len = num_servers;
310   return servers;
311 }
312
313 static json_t * tr_msg_encode_tidresp(TID_RESP *resp)
314 {
315   json_t *jresp = NULL;
316   json_t *jstr = NULL;
317   json_t *jservers = NULL;
318
319   if ((!resp) || (!resp->rp_realm) || (!resp->realm) || !(resp->comm))
320     return NULL;
321
322   jresp = json_object();
323
324   if (TID_ERROR == resp->result) {
325     jstr = json_string("error");
326     json_object_set_new(jresp, "result", jstr);
327     if (resp->err_msg) {
328       jstr = json_string(resp->err_msg->buf);
329       json_object_set_new(jresp, "err_msg", jstr);
330     }
331   }
332   else {
333     jstr = json_string("success");
334     json_object_set_new(jresp, "result", jstr);
335   }
336
337   jstr = json_string(resp->rp_realm->buf);
338   json_object_set_new(jresp, "rp_realm", jstr);
339
340   jstr = json_string(resp->realm->buf);
341   json_object_set_new(jresp, "target_realm", jstr);
342
343   jstr = json_string(resp->comm->buf);
344   json_object_set_new(jresp, "comm", jstr);
345
346   if (resp->orig_coi) {
347     jstr = json_string(resp->orig_coi->buf);
348     json_object_set_new(jresp, "orig_coi", jstr);
349   }
350
351   if (NULL == resp->servers) {
352     fprintf(stderr, "tr_msg_encode_tidresp(): No servers to encode.\n");
353     return jresp;
354   }
355   jservers = tr_msg_encode_servers(resp);
356   json_object_set_new(jresp, "servers", jservers);
357   
358   return jresp;
359 }
360
361 static TID_RESP *tr_msg_decode_tidresp(json_t *jresp)
362 {
363   TID_RESP *tresp = NULL;
364   json_t *jresult = NULL;
365   json_t *jrp_realm = NULL;
366   json_t *jrealm = NULL;
367   json_t *jcomm = NULL;
368   json_t *jorig_coi = NULL;
369   json_t *jservers = NULL;
370   json_t *jerr_msg = NULL;
371
372   if (!(tresp = talloc_zero(NULL, TID_RESP))) {
373     fprintf (stderr, "tr_msg_decode_tidresp(): Error allocating TID_RESP structure.\n");
374     return NULL;
375   }
376  
377
378   /* store required fields from response */
379   if ((NULL == (jresult = json_object_get(jresp, "result"))) ||
380       (!json_is_string(jresult)) ||
381       (NULL == (jrp_realm = json_object_get(jresp, "rp_realm"))) ||
382       (!json_is_string(jrp_realm)) ||
383       (NULL == (jrealm = json_object_get(jresp, "target_realm"))) ||
384       (!json_is_string(jrealm)) ||
385       (NULL == (jcomm = json_object_get(jresp, "comm"))) ||
386       (!json_is_string(jcomm))) {
387     fprintf (stderr, "tr_msg_decode_tidresp(): Error parsing response.\n");
388     talloc_free(tresp);
389     return NULL;
390   }
391
392   if (0 == (strcmp(json_string_value(jresult), "success"))) {
393     fprintf(stderr, "tr_msg_decode_tidresp(): Success! result = %s.\n", json_string_value(jresult));
394     if ((NULL != (jservers = json_object_get(jresp, "servers"))) ||
395         (!json_is_array(jservers))) {
396       tresp->servers = tr_msg_decode_servers(tresp, jservers, &tresp->num_servers); 
397     } 
398     else {
399       talloc_free(tresp);
400       return NULL;
401     }
402     tresp->result = TID_SUCCESS;
403   }
404   else {
405     tresp->result = TID_ERROR;
406     fprintf(stderr, "tr_msg_decode_tidresp(): Error! result = %s.\n", json_string_value(jresult));
407     if ((NULL != (jerr_msg = json_object_get(jresp, "err_msg"))) ||
408         (!json_is_string(jerr_msg))) {
409       tresp->err_msg = tr_new_name((char *)json_string_value(jerr_msg));
410     }
411   }
412
413   tresp->rp_realm = tr_new_name((char *)json_string_value(jrp_realm));
414   tresp->realm = tr_new_name((char *)json_string_value(jrealm));
415   tresp->comm = tr_new_name((char *)json_string_value(jcomm));
416
417   /* store optional "orig_coi" field */
418   if ((NULL != (jorig_coi = json_object_get(jresp, "orig_coi"))) &&
419       (!json_is_object(jorig_coi))) {
420     tresp->orig_coi = tr_new_name((char *)json_string_value(jorig_coi));
421   }
422      
423   return tresp;
424 }
425
426 char *tr_msg_encode(TR_MSG *msg) 
427 {
428   json_t *jmsg;
429   json_t *jmsg_type;
430
431   /* TBD -- add error handling */
432   jmsg = json_object();
433
434   switch (msg->msg_type) 
435     {
436     case TID_REQUEST:
437       jmsg_type = json_string("tid_request");
438       json_object_set_new(jmsg, "msg_type", jmsg_type);
439       json_object_set_new(jmsg, "msg_body", tr_msg_encode_tidreq(msg->tid_req));
440       break;
441
442     case TID_RESPONSE:
443       jmsg_type = json_string("tid_response");
444       json_object_set_new(jmsg, "msg_type", jmsg_type);
445       json_object_set_new(jmsg, "msg_body", tr_msg_encode_tidresp(msg->tid_resp));
446       break;
447
448       /* TBD -- Add TR message types */
449
450     default:
451       json_decref(jmsg);
452       return NULL;
453     }
454   
455   return(json_dumps(jmsg, 0));
456 }
457
458 TR_MSG *tr_msg_decode(char *jbuf, size_t buflen)
459 {
460   TR_MSG *msg;
461   json_t *jmsg = NULL;
462   json_error_t rc;
463   json_t *jtype;
464   json_t *jbody;
465   const char *mtype = NULL;
466
467   if (NULL == (jmsg = json_loadb(jbuf, buflen, JSON_DISABLE_EOF_CHECK, &rc))) {
468     fprintf (stderr, "tr_msg_decode(): error loading object\n");
469     return NULL;
470   }
471
472   if (!(msg = malloc(sizeof(TR_MSG)))) {
473     fprintf (stderr, "tr_msg_decode(): Error allocating TR_MSG structure.\n");
474     json_decref(jmsg);
475     return NULL;
476   }
477  
478   memset(msg, 0, sizeof(TR_MSG));
479
480   if ((NULL == (jtype = json_object_get(jmsg, "msg_type"))) ||
481       (NULL == (jbody = json_object_get(jmsg, "msg_body")))) {
482     fprintf (stderr, "tr_msg_decode(): Error parsing message header.\n");
483     json_decref(jmsg);
484     tr_msg_free_decoded(msg);
485     return NULL;
486   }
487
488   mtype = json_string_value(jtype);
489
490   if (0 == strcmp(mtype, "tid_request")) {
491     msg->msg_type = TID_REQUEST;
492     msg->tid_req = tr_msg_decode_tidreq(jbody);
493   }
494   else if (0 == strcmp(mtype, "tid_response")) {
495     msg->msg_type = TID_RESPONSE;
496     msg->tid_resp = tr_msg_decode_tidresp(jbody);
497   }
498   else {
499     msg->msg_type = TR_UNKNOWN;
500     msg->tid_req = NULL;
501   }
502   return msg;
503 }
504
505 void tr_msg_free_encoded(char *jmsg)
506 {
507   if (jmsg)
508     free (jmsg);
509 }
510
511 void tr_msg_free_decoded(TR_MSG *msg)
512 {
513   if (msg)
514     free (msg);
515 }
516
517