Support multiple AAA servers. Compiles but untested.
[trust_router.git] / tid / example / tids_main.c
1 /*
2  * Copyright (c) 2012, 2015, JANET(UK)
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * 3. Neither the name of JANET(UK) nor the names of its contributors
17  *    may be used to endorse or promote products derived from this software
18  *    without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
25  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
31  * OF THE POSSIBILITY OF SUCH DAMAGE.
32  *
33  */
34
35 #include <stdio.h>
36 #include <string.h>
37 #include <stdlib.h>
38 #include <talloc.h>
39 #include <sqlite3.h>
40 #include <argp.h>
41 #include <poll.h>
42
43 #include <tr_debug.h>
44 #include <tr_util.h>
45 #include <tid_internal.h>
46 #include <trust_router/tr_constraint.h>
47 #include <trust_router/tr_dh.h>
48 #include <openssl/rand.h>
49
50 static sqlite3 *db = NULL;
51 static sqlite3_stmt *insert_stmt = NULL;
52 static sqlite3_stmt *authorization_insert = NULL;
53
54 static int  create_key_id(char *out_id, size_t len)
55 {
56   unsigned char rand_buf[32];
57   size_t bin_len;
58   if (len <8)
59     return -1;
60   strncpy(out_id, "key-", len);
61   len -= 4;
62   out_id += 4;
63   if (sizeof(rand_buf)*2+1 < len)
64     len = sizeof(rand_buf)*2 + 1;
65   bin_len = (len-1)/2;
66   if (-1 == RAND_pseudo_bytes(rand_buf, bin_len))
67       return -1;
68   tr_bin_to_hex(rand_buf, bin_len, out_id, len);
69   out_id[bin_len*2] = '\0';
70   return 0;
71 }
72
73 static int sqlify_wc(
74                      TID_REQ *req,
75                      const char **wc,
76                      size_t len,
77                      char **error)
78 {
79   size_t lc;
80   *error = NULL;
81   for (lc = 0; lc < len; lc++) {
82     if (strchr(wc[lc], '%')) {
83       *error = talloc_asprintf( req, "Constraint match `%s' is not appropriate for SQL",
84                                   wc[lc]);
85       return -1;
86     }
87     if ('*' ==wc[lc][0]) {
88       char *s;
89       s = talloc_strdup(req, wc[lc]);
90       s[0] = '%';
91       wc[lc] = s;
92     }
93   }
94   return 0;
95 }
96
97         
98
99 static int handle_authorizations(TID_REQ *req, const unsigned char *dh_hash,
100                                  size_t hash_len)
101 {
102   TR_CONSTRAINT_SET *intersected = NULL;
103   const char **domain_wc, **realm_wc;
104   size_t domain_len, realm_len;
105   size_t domain_index, realm_index;
106   char *error;
107   int sqlite3_result;
108
109   if (!req->cons) {
110     tr_debug("Request has no constraints, so no authorizations.");
111     return 0;
112   }
113   intersected = tr_constraint_set_intersect(req, req->cons);
114   if (!intersected)
115     return -1;
116   if (0 != tr_constraint_set_get_match_strings(req,
117                                                intersected, "domain",
118                                                &domain_wc, &domain_len))
119     return -1;
120   if (0 != tr_constraint_set_get_match_strings(req,
121                                                intersected, "realm",
122                                                &realm_wc, &realm_len))
123     return -1;
124   tr_debug(" %u domain constraint matches and %u realm constraint matches",
125            (unsigned) domain_len, (unsigned) realm_len);
126   if (0 != sqlify_wc(req, domain_wc, domain_len, &error)) {
127     tr_debug("Processing domain constraints: %s", error);
128     return -1;
129   }else if (0 != sqlify_wc(req, realm_wc, realm_len, &error)) {
130     tr_debug("Processing realm constraints: %s", error);
131     return -1;
132   }
133   if (!authorization_insert) {
134     tr_debug( " No database, no authorizations inserted");
135     return 0;
136   }
137   for (domain_index = 0; domain_index < domain_len; domain_index++)
138     for (realm_index = 0; realm_index < realm_len; realm_index++) {
139       TR_NAME *community = req->orig_coi;
140       if (!community)
141         community = req->comm;
142       sqlite3_bind_blob(authorization_insert, 1, dh_hash, hash_len, SQLITE_TRANSIENT);
143       sqlite3_bind_text(authorization_insert, 2, community->buf, community->len, SQLITE_TRANSIENT);
144       sqlite3_bind_text(authorization_insert, 3, realm_wc[realm_index], -1, SQLITE_TRANSIENT);
145       sqlite3_bind_text(authorization_insert, 4, domain_wc[domain_index], -1, SQLITE_TRANSIENT);
146       sqlite3_bind_text(authorization_insert, 5, req->comm->buf, req->comm->len, SQLITE_TRANSIENT);
147       sqlite3_result = sqlite3_step(authorization_insert);
148       if (SQLITE_DONE != sqlite3_result)
149         tr_crit("sqlite3: failed to write to database");
150       sqlite3_reset(authorization_insert);
151     }
152   return 0;
153 }
154
155
156 static int tids_req_handler (TIDS_INSTANCE *tids,
157                       TID_REQ *req, 
158                       TID_RESP *resp,
159                       void *cookie)
160 {
161   unsigned char *s_keybuf = NULL;
162   int s_keylen = 0;
163   char key_id[12];
164   unsigned char *pub_digest;
165   size_t pub_digest_len;
166   
167
168   tr_debug("tids_req_handler: Request received! target_realm = %s, community = %s", req->realm->buf, req->comm->buf);
169   if (tids)
170     tids->req_count++;
171
172   if (!(resp) || !resp) {
173     tr_debug("tids_req_handler: No response structure.");
174     return -1;
175   }
176
177
178   /* Allocate a new server block */
179   tid_srvr_blk_add(resp->servers, tid_srvr_blk_new(resp));
180   if (NULL==resp->servers) {
181     tr_crit("tids_req_handler(): unable to allocate server block.");
182     return -1;
183   }
184
185   /* TBD -- Set up the server IP Address */
186
187   if (!(req) || !(req->tidc_dh)) {
188     tr_debug("tids_req_handler(): No client DH info.");
189     return -1;
190   }
191
192   if ((!req->tidc_dh->p) || (!req->tidc_dh->g)) {
193     tr_debug("tids_req_handler: NULL dh values.");
194     return -1;
195   }
196
197   /* Generate the server DH block based on the client DH block */
198   // fprintf(stderr, "Generating the server DH block.\n");
199   // fprintf(stderr, "...from client DH block, dh_g = %s, dh_p = %s.\n", BN_bn2hex(req->tidc_dh->g), BN_bn2hex(req->tidc_dh->p));
200
201   if (NULL == (resp->servers->aaa_server_dh = tr_create_matching_dh(NULL, 0, req->tidc_dh))) {
202     tr_debug("tids_req_handler: Can't create server DH params.");
203     return -1;
204   }
205
206   resp->servers->aaa_server_addr=talloc_strdup(resp->servers, tids->ipaddr);
207
208   /* Set the key name */
209   if (-1 == create_key_id(key_id, sizeof(key_id)))
210     return -1;
211   resp->servers->key_name = tr_new_name(key_id);
212
213   /* Generate the server key */
214   // fprintf(stderr, "Generating the server key.\n");
215
216   if (0 > (s_keylen = tr_compute_dh_key(&s_keybuf, 
217                                         req->tidc_dh->pub_key, 
218                                         resp->servers->aaa_server_dh))) {
219     tr_debug("tids_req_handler: Key computation failed.");
220     return -1;
221   }
222   if (0 != tr_dh_pub_hash(req,
223                           &pub_digest, &pub_digest_len)) {
224     tr_debug("tids_req_handler: Unable to digest client public key");
225     return -1;
226   }
227   if (0 != handle_authorizations(req, pub_digest, pub_digest_len))
228     return -1;
229   tid_srvr_blk_set_path(resp->servers, req->path);
230
231   if (req->expiration_interval < 1)
232     req->expiration_interval = 1;
233   g_get_current_time(&resp->servers->key_expiration);
234   resp->servers->key_expiration.tv_sec += req->expiration_interval * 60 /*in minutes*/;
235
236   if (NULL != insert_stmt) {
237     int sqlite3_result;
238     gchar *expiration_str = g_time_val_to_iso8601(&resp->servers->key_expiration);
239         sqlite3_bind_text(insert_stmt, 1, key_id, -1, SQLITE_TRANSIENT);
240     sqlite3_bind_blob(insert_stmt, 2, s_keybuf, s_keylen, SQLITE_TRANSIENT);
241     sqlite3_bind_blob(insert_stmt, 3, pub_digest, pub_digest_len, SQLITE_TRANSIENT);
242         sqlite3_bind_text(insert_stmt, 4, expiration_str, -1, SQLITE_TRANSIENT);
243     sqlite3_result = sqlite3_step(insert_stmt);
244     if (SQLITE_DONE != sqlite3_result)
245       tr_crit("sqlite3: failed to write to database");
246     sqlite3_reset(insert_stmt);
247   }
248   
249   /* Print out the key. */
250   // fprintf(stderr, "tids_req_handler(): Server Key Generated (len = %d):\n", s_keylen);
251   // for (i = 0; i < s_keylen; i++) {
252   // fprintf(stderr, "%x", s_keybuf[i]); 
253   // }
254   // fprintf(stderr, "\n");
255
256   return s_keylen;
257 }
258
259 static int auth_handler(gss_name_t gss_name, TR_NAME *client,
260                         void *expected_client)
261 {
262   TR_NAME *expected_client_trname = (TR_NAME*) expected_client;
263   int result=tr_name_cmp(client, expected_client_trname);
264   if (result != 0) {
265     tr_notice("Auth denied for incorrect gss-name ('%.*s' requested, expected '%.*s').",
266               client->len, client->buf,
267               expected_client_trname->len, expected_client_trname->buf);
268   }
269   return result;
270 }
271
272 /* command-line option setup */
273
274 /* argp global parameters */
275 const char *argp_program_bug_address=PACKAGE_BUGREPORT; /* bug reporting address */
276
277 /* doc strings */
278 static const char doc[]=PACKAGE_NAME " - TID Server";
279 static const char arg_doc[]="<ip-address> <gss-name> <hostname> <database-name>"; /* string describing arguments, if any */
280
281 /* define the options here. Fields are:
282  * { long-name, short-name, variable name, options, help description } */
283 static const struct argp_option cmdline_options[] = {
284   { NULL }
285 };
286
287 /* structure for communicating with option parser */
288 struct cmdline_args {
289   char *ip_address;
290   char *gss_name;
291   char *hostname;
292   char *database_name;
293 };
294
295 /* parser for individual options - fills in a struct cmdline_args */
296 static error_t parse_option(int key, char *arg, struct argp_state *state)
297 {
298   /* get a shorthand to the command line argument structure, part of state */
299   struct cmdline_args *arguments=state->input;
300
301   switch (key) {
302   case ARGP_KEY_ARG: /* handle argument (not option) */
303     switch (state->arg_num) {
304     case 0:
305       arguments->ip_address=arg;
306       break;
307
308     case 1:
309       arguments->gss_name=arg;
310       break;
311
312     case 2:
313       arguments->hostname=arg;
314       break;
315
316     case 3:
317       arguments->database_name=arg;
318       break;
319
320     default:
321       /* too many arguments */
322       argp_usage(state);
323     }
324     break;
325
326   case ARGP_KEY_END: /* no more arguments */
327     if (state->arg_num < 4) {
328       /* not enough arguments encountered */
329       argp_usage(state);
330     }
331     break;
332
333   default:
334     return ARGP_ERR_UNKNOWN;
335   }
336
337   return 0; /* success */
338 }
339
340 /* assemble the argp parser */
341 static struct argp argp = {cmdline_options, parse_option, arg_doc, doc};
342
343 int main (int argc, 
344           char *argv[]) 
345 {
346   TIDS_INSTANCE *tids;
347   TR_NAME *gssname = NULL;
348   struct cmdline_args opts={NULL};
349 #define MAX_SOCKETS 10
350   int tids_socket[MAX_SOCKETS];
351   size_t n_sockets;
352   struct pollfd poll_fds[MAX_SOCKETS];
353   size_t ii=0;
354
355   /* parse the command line*/
356   argp_parse(&argp, argc, argv, 0, 0, &opts);
357
358   talloc_set_log_stderr();
359
360   /* Use standalone logging */
361   tr_log_open();
362
363   /* set logging levels */
364   tr_log_threshold(LOG_CRIT);
365   tr_console_threshold(LOG_DEBUG);
366
367   gssname = tr_new_name(opts.gss_name);
368   if (SQLITE_OK != sqlite3_open(opts.database_name, &db)) {
369     tr_crit("Error opening database %s", opts.database_name);
370     exit(1);
371   }
372   sqlite3_busy_timeout( db, 1000);
373   sqlite3_prepare_v2(db, "insert into psk_keys_tab (keyid, key, client_dh_pub, key_expiration) values(?, ?, ?, ?)",
374                      -1, &insert_stmt, NULL);
375   sqlite3_prepare_v2(db, "insert into authorizations (client_dh_pub, coi, acceptor_realm, hostname, apc) values(?, ?, ?, ?, ?)",
376                      -1, &authorization_insert, NULL);
377
378   /* Create a TID server instance */
379   if (NULL == (tids = tids_create(NULL))) {
380     tr_crit("Unable to create TIDS instance, exiting.");
381     return 1;
382   }
383
384   tids->ipaddr = opts.ip_address;
385
386   /* get listener for tids port */
387   n_sockets = tids_get_listener(tids, &tids_req_handler, auth_handler, opts.hostname, TID_PORT, gssname,
388                                 tids_socket, MAX_SOCKETS);
389
390   for (ii=0; ii<n_sockets; ii++) {
391     poll_fds[ii].fd=tids_socket[ii];
392     poll_fds[ii].events=POLLIN; /* poll on ready for reading */
393     poll_fds[ii].revents=0;
394   }
395
396   /* main event loop */
397   while (1) {
398     /* wait up to 100 ms for an event, then handle any idle work */
399     if(poll(poll_fds, n_sockets, 100) > 0) {
400       for (ii=0; ii<n_sockets; ii++) {
401         if (poll_fds[ii].revents & POLLIN) {
402           if (0 != tids_accept(tids, tids_socket[ii])) {
403             tr_err("Error handling tids request.");
404           }
405         }
406       }
407     }
408     /* idle loop stuff here */
409   }
410
411   /* Clean-up the TID server instance */
412   tids_destroy(tids);
413
414   return 1;
415 }
416