Merge branch 'logging_changes' of https://github.com/adam-bishop/trust_router
[trust_router.git] / tid / example / tids_main.c
1 /*
2  * Copyright (c) 2012, 2015, JANET(UK)
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * 3. Neither the name of JANET(UK) nor the names of its contributors
17  *    may be used to endorse or promote products derived from this software
18  *    without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
25  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
31  * OF THE POSSIBILITY OF SUCH DAMAGE.
32  *
33  */
34
35 #include <stdio.h>
36 #include <string.h>
37 #include <stdlib.h>
38 #include <talloc.h>
39 #include <sqlite3.h>
40
41 #include <tr_debug.h>
42 #include <tid_internal.h>
43 #include <trust_router/tr_constraint.h>
44 #include <trust_router/tr_dh.h>
45 #include <openssl/rand.h>
46
47 static sqlite3 *db = NULL;
48 static sqlite3_stmt *insert_stmt = NULL;
49 static sqlite3_stmt *authorization_insert = NULL;
50
51 static int  create_key_id(char *out_id, size_t len)
52 {
53   unsigned char rand_buf[32];
54   size_t bin_len;
55   if (len <8)
56     return -1;
57   strncpy(out_id, "key-", len);
58   len -= 4;
59   out_id += 4;
60   if (sizeof(rand_buf)*2+1 < len)
61     len = sizeof(rand_buf)*2 + 1;
62   bin_len = (len-1)/2;
63   if (-1 == RAND_pseudo_bytes(rand_buf, bin_len))
64       return -1;
65   tr_bin_to_hex(rand_buf, bin_len, out_id, len);
66   out_id[bin_len*2] = '\0';
67   return 0;
68 }
69
70 static int sqlify_wc(
71                      TID_REQ *req,
72                      const char **wc,
73                      size_t len,
74                      char **error)
75 {
76   size_t lc;
77   *error = NULL;
78   for (lc = 0; lc < len; lc++) {
79     if (strchr(wc[lc], '%')) {
80       *error = talloc_asprintf( req, "Constraint match `%s' is not appropriate for SQL",
81                                   wc[lc]);
82       return -1;
83     }
84     if ('*' ==wc[lc][0]) {
85       char *s;
86       s = talloc_strdup(req, wc[lc]);
87       s[0] = '%';
88       wc[lc] = s;
89     }
90   }
91   return 0;
92 }
93
94         
95
96 static int handle_authorizations(TID_REQ *req, const unsigned char *dh_hash,
97                                  size_t hash_len)
98 {
99   TR_CONSTRAINT_SET *intersected = NULL;
100   const char **domain_wc, **realm_wc;
101   size_t domain_len, realm_len;
102   size_t domain_index, realm_index;
103   char *error;
104   int sqlite3_result;
105
106   if (!req->cons) {
107     tr_debug("Request has no constraints, so no authorizations.");
108     return 0;
109   }
110   intersected = tr_constraint_set_intersect(req, req->cons);
111   if (!intersected)
112     return -1;
113   if (0 != tr_constraint_set_get_match_strings(req,
114                                                intersected, "domain",
115                                                &domain_wc, &domain_len))
116     return -1;
117   if (0 != tr_constraint_set_get_match_strings(req,
118                                                intersected, "realm",
119                                                &realm_wc, &realm_len))
120     return -1;
121   tr_debug(" %u domain constraint matches and %u realm constraint matches",
122            (unsigned) domain_len, (unsigned) realm_len);
123   if (0 != sqlify_wc(req, domain_wc, domain_len, &error)) {
124     tr_debug("Processing domain constraints: %s", error);
125     return -1;
126   }else if (0 != sqlify_wc(req, realm_wc, realm_len, &error)) {
127     tr_debug("Processing realm constraints: %s", error);
128     return -1;
129   }
130   if (!authorization_insert) {
131     tr_debug( " No database, no authorizations inserted");
132     return 0;
133   }
134   for (domain_index = 0; domain_index < domain_len; domain_index++)
135     for (realm_index = 0; realm_index < realm_len; realm_index++) {
136       TR_NAME *community = req->orig_coi;
137       if (!community)
138         community = req->comm;
139       sqlite3_bind_blob(authorization_insert, 1, dh_hash, hash_len, SQLITE_TRANSIENT);
140       sqlite3_bind_text(authorization_insert, 2, community->buf, community->len, SQLITE_TRANSIENT);
141       sqlite3_bind_text(authorization_insert, 3, realm_wc[realm_index], -1, SQLITE_TRANSIENT);
142       sqlite3_bind_text(authorization_insert, 4, domain_wc[domain_index], -1, SQLITE_TRANSIENT);
143       sqlite3_bind_text(authorization_insert, 5, req->comm->buf, req->comm->len, SQLITE_TRANSIENT);
144       sqlite3_result = sqlite3_step(authorization_insert);
145       if (SQLITE_DONE != sqlite3_result)
146         tr_crit("sqlite3: failed to write to database");
147       sqlite3_reset(authorization_insert);
148     }
149   return 0;
150 }
151
152
153 static int tids_req_handler (TIDS_INSTANCE *tids,
154                       TID_REQ *req, 
155                       TID_RESP *resp,
156                       void *cookie)
157 {
158   unsigned char *s_keybuf = NULL;
159   int s_keylen = 0;
160   char key_id[12];
161   unsigned char *pub_digest;
162   size_t pub_digest_len;
163   
164
165   tr_debug("tids_req_handler: Request received! target_realm = %s, community = %s", req->realm->buf, req->comm->buf);
166   if (tids)
167     tids->req_count++;
168
169   if (!(resp) || !resp) {
170     tr_debug("tids_req_handler: No response structure.");
171     return -1;
172   }
173
174
175   /* Allocate a new server block */
176   if (NULL == (resp->servers = malloc(sizeof(TID_SRVR_BLK)))){
177     tr_crit("tids_req_handler(): malloc failed.");
178     return -1;
179   }
180   memset(resp->servers, 0, sizeof(TID_SRVR_BLK));
181   resp->num_servers = 1;
182
183   /* TBD -- Set up the server IP Address */
184
185   if (!(req) || !(req->tidc_dh)) {
186     tr_debug("tids_req_handler(): No client DH info.");
187     return -1;
188   }
189
190   if ((!req->tidc_dh->p) || (!req->tidc_dh->g)) {
191     tr_debug("tids_req_handler: NULL dh values.");
192     return -1;
193   }
194
195   /* Generate the server DH block based on the client DH block */
196   // fprintf(stderr, "Generating the server DH block.\n");
197   // fprintf(stderr, "...from client DH block, dh_g = %s, dh_p = %s.\n", BN_bn2hex(req->tidc_dh->g), BN_bn2hex(req->tidc_dh->p));
198
199   if (NULL == (resp->servers->aaa_server_dh = tr_create_matching_dh(NULL, 0, req->tidc_dh))) {
200     tr_debug("tids_req_handler: Can't create server DH params.");
201     return -1;
202   }
203
204   if (0 == inet_aton(tids->ipaddr, &(resp->servers->aaa_server_addr))) {
205     tr_debug("tids_req_handler: inet_aton() failed.");
206     return -1;
207   }
208
209   /* Set the key name */
210   if (-1 == create_key_id(key_id, sizeof(key_id)))
211     return -1;
212   resp->servers->key_name = tr_new_name(key_id);
213
214   /* Generate the server key */
215   // fprintf(stderr, "Generating the server key.\n");
216
217   if (0 > (s_keylen = tr_compute_dh_key(&s_keybuf, 
218                                         req->tidc_dh->pub_key, 
219                                         resp->servers->aaa_server_dh))) {
220     tr_debug("tids_req_handler: Key computation failed.");
221     return -1;
222   }
223   if (0 != tr_dh_pub_hash(req,
224                           &pub_digest, &pub_digest_len)) {
225     tr_debug("tids_req_handler: Unable to digest client public key");
226     return -1;
227   }
228   if (0 != handle_authorizations(req, pub_digest, pub_digest_len))
229     return -1;
230   resp->servers->path = req->path;
231   if (req->expiration_interval < 1)
232     req->expiration_interval = 1;
233   g_get_current_time(&resp->servers->key_expiration);
234   resp->servers->key_expiration.tv_sec += req->expiration_interval;
235
236   if (NULL != insert_stmt) {
237     int sqlite3_result;
238     gchar *expiration_str = g_time_val_to_iso8601(&resp->servers->key_expiration);
239         sqlite3_bind_text(insert_stmt, 1, key_id, -1, SQLITE_TRANSIENT);
240     sqlite3_bind_blob(insert_stmt, 2, s_keybuf, s_keylen, SQLITE_TRANSIENT);
241     sqlite3_bind_blob(insert_stmt, 3, pub_digest, pub_digest_len, SQLITE_TRANSIENT);
242         sqlite3_bind_text(insert_stmt, 3, expiration_str, -1, SQLITE_TRANSIENT);
243     sqlite3_result = sqlite3_step(insert_stmt);
244     if (SQLITE_DONE != sqlite3_result)
245       tr_crit("sqlite3: failed to write to database");
246     sqlite3_reset(insert_stmt);
247   }
248   
249   /* Print out the key. */
250   // fprintf(stderr, "tids_req_handler(): Server Key Generated (len = %d):\n", s_keylen);
251   // for (i = 0; i < s_keylen; i++) {
252   // fprintf(stderr, "%x", s_keybuf[i]); 
253   // }
254   // fprintf(stderr, "\n");
255
256   return s_keylen;
257 }
258 static int auth_handler(gss_name_t gss_name, TR_NAME *client,
259                         void *expected_client)
260 {
261   TR_NAME *expected_client_trname = (TR_NAME*) expected_client;
262   return tr_name_cmp(client, expected_client_trname);
263 }
264
265
266 int main (int argc, 
267           const char *argv[]) 
268 {
269   TIDS_INSTANCE *tids;
270   int rc = 0;
271   char *ipaddr = NULL;
272   const char *hostname = NULL;
273   TR_NAME *gssname = NULL;
274
275   talloc_set_log_stderr();
276   /* Parse command-line arguments */ 
277   if (argc != 5) {
278     fprintf(stdout, "Usage: %s <ip-address> <gss-name> <hostname> <database-name>\n", argv[0]);
279     exit(1);
280   }
281
282   /* Use standalone logging */
283   tr_log_open();
284
285   /* set logging levels */
286   tr_log_threshold(LOG_CRIT);
287   tr_console_threshold(LOG_DEBUG);
288
289   ipaddr = (char *)argv[1];
290   gssname = tr_new_name((char *) argv[2]);
291   hostname = argv[3];
292   if (SQLITE_OK != sqlite3_open(argv[4], &db)) {
293     tr_crit("Error opening database %s", argv[4]);
294     exit(1);
295   }
296   sqlite3_busy_timeout( db, 1000);
297   sqlite3_prepare_v2(db, "insert into psk_keys (keyid, key, client_dh_pub, key_expiration) values(?, ?, ?, ?)",
298                      -1, &insert_stmt, NULL);
299   sqlite3_prepare_v2(db, "insert into authorizations (client_dh_pub, coi, acceptor_realm, hostname, apc) values(?, ?, ?, ?, ?)",
300                      -1, &authorization_insert, NULL);
301
302   /* Create a TID server instance */
303   if (NULL == (tids = tids_create())) {
304     tr_crit("Unable to create TIDS instance, exiting.");
305     return 1;
306   }
307
308   tids->ipaddr = ipaddr;
309
310   /* Start-up the server, won't return unless there is an error. */
311   rc = tids_start(tids, &tids_req_handler , auth_handler, hostname, TID_PORT, gssname);
312   
313   tr_crit("Error in tids_start(), rc = %d. Exiting.", rc);
314
315   /* Clean-up the TID server instance */
316   tids_destroy(tids);
317
318   return 1;
319 }
320