tid: implement GSS name callback
[trust_router.git] / tid / example / tids_main.c
1 /*
2  * Copyright (c) 2012, JANET(UK)
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * 3. Neither the name of JANET(UK) nor the names of its contributors
17  *    may be used to endorse or promote products derived from this software
18  *    without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
25  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
31  * OF THE POSSIBILITY OF SUCH DAMAGE.
32  *
33  */
34
35 #include <stdio.h>
36 #include <string.h>
37 #include <stdlib.h>
38 #include <sqlite3.h>
39
40 #include <trust_router/tid.h>
41 #include <trust_router/tr_dh.h>
42 #include <openssl/rand.h>
43
44 static sqlite3 *db = NULL;
45 static sqlite3_stmt *insert_stmt = NULL;
46
47 static int  create_key_id(char *out_id, size_t len)
48 {
49   unsigned char rand_buf[32];
50   size_t bin_len;
51   if (len <8)
52     return -1;
53   strncpy(out_id, "key-", len);
54   len -= 4;
55   out_id += 4;
56   if (sizeof(rand_buf)*2+1 < len)
57     len = sizeof(rand_buf)*2 + 1;
58   bin_len = (len-1)/2;
59   if (-1 == RAND_pseudo_bytes(rand_buf, bin_len))
60       return -1;
61   tr_bin_to_hex(rand_buf, bin_len, out_id, len);
62   out_id[bin_len*2] = '\0';
63   return 0;
64 }
65   
66 static int tids_req_handler (TIDS_INSTANCE * tids,
67                       TID_REQ *req, 
68                       TID_RESP **resp,
69                       void *cookie)
70 {
71   unsigned char *s_keybuf = NULL;
72   int s_keylen = 0;
73   char key_id[12];
74   
75
76   fprintf(stdout, "tids_req_handler: Request received! target_realm = %s, community = %s\n", req->realm->buf, req->comm->buf);
77   if (tids)
78     tids->req_count++;
79
80   if (!(resp) || !(*resp)) {
81     fprintf(stderr, "tids_req_handler: No response structure.\n");
82     return -1;
83   }
84
85   /* Allocate a new server block */
86   if (NULL == ((*resp)->servers = malloc(sizeof(TID_SRVR_BLK)))){
87     fprintf(stderr, "tids_req_handler(): malloc failed.\n");
88     return -1;
89   }
90   memset((*resp)->servers, 0, sizeof(TID_SRVR_BLK));
91
92   /* TBD -- Set up the server IP Address */
93
94   if (!(req) || !(req->tidc_dh)) {
95     fprintf(stderr, "tids_req_handler(): No client DH info.\n");
96     return -1;
97   }
98
99   if ((!req->tidc_dh->p) || (!req->tidc_dh->g)) {
100     fprintf(stderr, "tids_req_handler(): NULL dh values.\n");
101     return -1;
102   }
103
104   /* Generate the server DH block based on the client DH block */
105   // fprintf(stderr, "Generating the server DH block.\n");
106   // fprintf(stderr, "...from client DH block, dh_g = %s, dh_p = %s.\n", BN_bn2hex(req->tidc_dh->g), BN_bn2hex(req->tidc_dh->p));
107
108   if (NULL == ((*resp)->servers->aaa_server_dh = tr_create_matching_dh(NULL, 0, req->tidc_dh))) {
109     fprintf(stderr, "tids_req_handler(): Can't create server DH params.\n");
110     return -1;
111   }
112
113   if (0 == inet_aton(tids->ipaddr, &((*resp)->servers->aaa_server_addr))) {
114     fprintf(stderr, "tids_req_handler(): inet_aton() failed.\n");
115     return -1;
116   }
117
118   /* Set the key name */
119   if (-1 == create_key_id(key_id, sizeof(key_id)))
120     return -1;
121   (*resp)->servers->key_name = tr_new_name(key_id);
122
123   /* Generate the server key */
124   // fprintf(stderr, "Generating the server key.\n");
125
126   if (0 > (s_keylen = tr_compute_dh_key(&s_keybuf, 
127                                         req->tidc_dh->pub_key, 
128                                         (*resp)->servers->aaa_server_dh))) {
129     fprintf(stderr, "tids_req_handler(): Key computation failed.");
130     return -1;
131   }
132   if (NULL != insert_stmt) {
133     int sqlite3_result;
134     sqlite3_bind_text(insert_stmt, 1, key_id, -1, SQLITE_TRANSIENT);
135     sqlite3_bind_blob(insert_stmt, 2, s_keybuf, s_keylen, SQLITE_TRANSIENT);
136     sqlite3_result = sqlite3_step(insert_stmt);
137     if (SQLITE_DONE != sqlite3_result)
138       printf("sqlite3: failed to write to database\n");
139     sqlite3_reset(insert_stmt);
140   }
141   
142   /* Print out the key. */
143   // fprintf(stderr, "tids_req_handler(): Server Key Generated (len = %d):\n", s_keylen);
144   // for (i = 0; i < s_keylen; i++) {
145   // fprintf(stderr, "%x", s_keybuf[i]); 
146   // }
147   // fprintf(stderr, "\n");
148
149   return s_keylen;
150 }
151 static int auth_handler(gss_name_t gss_name, TR_NAME *client,
152                         void *expected_client)
153 {
154   TR_NAME *expected_client_trname = (TR_NAME*) expected_client;
155   return tr_name_cmp(client, expected_client_trname);
156 }
157
158
159 int main (int argc, 
160           const char *argv[]) 
161 {
162   TIDS_INSTANCE *tids;
163   int rc = 0;
164   char *ipaddr = NULL;
165   TR_NAME *gssname = NULL;
166
167   /* Parse command-line arguments */ 
168   if (argc > 4)
169     fprintf(stdout, "Usage: %s [<ip-address> <gss-name> [<database-name>]]\n", argv[0]);
170
171   if (argc >= 2) {
172     ipaddr = (char *)argv[1];
173   } else {
174     ipaddr = "127.0.0.1";
175   }
176   gssname = tr_new_name((char *) argv[2]);
177
178   /* TBD -- check that input is a valid IP address? */
179
180   /*If we have a database, open and prepare*/
181   if (argc == 4) {
182     if (SQLITE_OK != sqlite3_open(argv[3], &db)) {
183       fprintf(stdout, "Error opening database %s\n", argv[2]);
184       exit(1);
185     }
186     sqlite3_prepare_v2(db, "insert into psk_keys (keyid, key) values(?, ?)",
187                        -1, &insert_stmt, NULL);
188   }
189
190   /* Create a TID server instance */
191   if (NULL == (tids = tids_create())) {
192     fprintf(stdout, "Unable to create TIDS instance,exiting.\n");
193     return 1;
194   }
195
196   tids->ipaddr = ipaddr;
197
198   /* Start-up the server, won't return unless there is an error. */
199   rc = tids_start(tids, &tids_req_handler , auth_handler, gssname);
200   
201   fprintf(stdout, "Error in tids_start(), rc = %d. Exiting.\n", rc);
202
203   /* Clean-up the TID server instance */
204   tids_destroy(tids);
205
206   return 1;
207 }
208