Fix memory leak in tids.c.
[trust_router.git] / tid / tids.c
1 /*
2  * Copyright (c) 2012, 2015, JANET(UK)
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * 3. Neither the name of JANET(UK) nor the names of its contributors
17  *    may be used to endorse or promote products derived from this software
18  *    without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
25  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
31  * OF THE POSSIBILITY OF SUCH DAMAGE.
32  *
33  */
34
35 #include <assert.h>
36 #include <stdlib.h>
37 #include <unistd.h>
38 #include <fcntl.h>
39 #include <string.h>
40 #include <stdio.h>
41 #include <errno.h>
42 #include <sys/socket.h>
43 #include <sys/wait.h>
44 #include <netinet/in.h>
45 #include <jansson.h>
46 #include <talloc.h>
47 #include <tid_internal.h>
48 #include <gsscon.h>
49 #include <tr_debug.h>
50 #include <tr_msg.h>
51
52 static TID_RESP *tids_create_response (TIDS_INSTANCE *tids, TID_REQ *req) 
53 {
54   TID_RESP *resp=NULL;
55   int success=0;
56
57   if ((NULL == (resp = talloc_zero(req, TID_RESP)))) {
58     tr_crit("tids_create_response: Error allocating response structure.");
59     return NULL;
60   }
61   
62   resp->result = TID_SUCCESS; /* presume success */
63   if ((NULL == (resp->rp_realm = tr_dup_name(req->rp_realm))) ||
64       (NULL == (resp->realm = tr_dup_name(req->realm))) ||
65       (NULL == (resp->comm = tr_dup_name(req->comm)))) {
66     tr_crit("tids_create_response: Error allocating fields in response.");
67     goto cleanup;
68   }
69   if (req->orig_coi) {
70     if (NULL == (resp->orig_coi = tr_dup_name(req->orig_coi))) {
71       tr_crit("tids_create_response: Error allocating fields in response.");
72       goto cleanup;
73     }
74   }
75
76   success=1;
77
78 cleanup:
79   if ((!success) && (resp!=NULL)) {
80     if (resp->rp_realm!=NULL)
81       tr_free_name(resp->rp_realm);
82     if (resp->realm!=NULL)
83       tr_free_name(resp->realm);
84     if (resp->comm!=NULL)
85       tr_free_name(resp->comm);
86     if (resp->orig_coi!=NULL)
87       tr_free_name(resp->orig_coi);
88     talloc_free(resp);
89     resp=NULL;
90   }
91   return resp;
92 }
93
94 static void tids_destroy_response(TIDS_INSTANCE *tids, TID_RESP *resp) 
95 {
96   if (resp) {
97     if (resp->err_msg)
98       tr_free_name(resp->err_msg);
99     if (resp->rp_realm)
100       tr_free_name(resp->rp_realm);
101     if (resp->realm)
102       tr_free_name(resp->realm);
103     if (resp->comm)
104       tr_free_name(resp->comm);
105     if (resp->orig_coi)
106       tr_free_name(resp->orig_coi);
107     talloc_free(resp);
108   }
109 }
110
111 static int tids_listen (TIDS_INSTANCE *tids, int port) 
112 {
113     int rc = 0;
114     int conn = -1;
115     int optval = 1;
116
117     union {
118       struct sockaddr_storage storage;
119       struct sockaddr_in in4;
120     } addr;
121
122     struct sockaddr_in *saddr = (struct sockaddr_in *) &addr.in4;
123
124     saddr->sin_port = htons (port);
125     saddr->sin_family = AF_INET;
126     saddr->sin_addr.s_addr = INADDR_ANY;
127
128     if (0 > (conn = socket (AF_INET, SOCK_STREAM, 0)))
129       return conn;
130
131     setsockopt(conn, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
132
133     if (0 > (rc = bind (conn, (struct sockaddr *) saddr, sizeof(struct sockaddr_in))))
134       return rc;
135
136     if (0 > (rc = listen(conn, 512)))
137       return rc;
138
139     tr_debug("tids_listen: TID Server listening on port %d", port);
140     return conn;
141 }
142
143 /* returns EACCES if authorization is denied */
144 static int tids_auth_cb(gss_name_t clientName, gss_buffer_t displayName,
145                         void *data)
146 {
147   struct tids_instance *inst = (struct tids_instance *) data;
148   TR_NAME name ={(char *) displayName->value,
149                  displayName->length};
150   int result=0;
151
152   if (0!=inst->auth_handler(clientName, &name, inst->cookie)) {
153     tr_debug("tids_auth_cb: client '%.*s' denied authorization.", name.len, name.buf);
154     result=EACCES; /* denied */
155   }
156
157   return result;
158 }
159
160 /* returns 0 on authorization success, 1 on failure, or -1 in case of error */
161 static int tids_auth_connection (TIDS_INSTANCE *inst,
162                                  int conn,
163                                  gss_ctx_id_t *gssctx)
164 {
165   int rc = 0;
166   int auth, autherr = 0;
167   gss_buffer_desc nameBuffer = {0, NULL};
168   char *name = 0;
169   int nameLen = 0;
170
171   nameLen = asprintf(&name, "trustidentity@%s", inst->hostname);
172   nameBuffer.length = nameLen;
173   nameBuffer.value = name;
174
175   if (rc = gsscon_passive_authenticate(conn, nameBuffer, gssctx, tids_auth_cb, inst)) {
176     tr_debug("tids_auth_connection: Error from gsscon_passive_authenticate(), rc = %d.", rc);
177     return -1;
178   }
179
180   if (rc = gsscon_authorize(*gssctx, &auth, &autherr)) {
181     tr_debug("tids_auth_connection: Error from gsscon_authorize, rc = %d, autherr = %d.", 
182             rc, autherr);
183     return -1;
184   }
185
186   if (auth)
187     tr_debug("tids_auth_connection: Connection authenticated, conn = %d.", conn);
188   else
189     tr_debug("tids_auth_connection: Authentication failed, conn %d.", conn);
190
191   return !auth;
192 }
193
194 static int tids_read_request (TIDS_INSTANCE *tids, int conn, gss_ctx_id_t *gssctx, TR_MSG **mreq)
195 {
196   int err;
197   char *buf;
198   size_t buflen = 0;
199
200   if (err = gsscon_read_encrypted_token(conn, *gssctx, &buf, &buflen)) {
201     if (buf)
202       free(buf);
203     return -1;
204   }
205
206   tr_debug("tids_read_request():Request Received, %u bytes.", (unsigned) buflen);
207
208   /* Parse request */
209   if (NULL == ((*mreq) = tr_msg_decode(buf, buflen))) {
210     tr_debug("tids_read_request():Error decoding request.");
211     free (buf);
212     return -1;
213   }
214
215   /* If this isn't a TID Request, just drop it. */
216   if (TID_REQUEST != (*mreq)->msg_type) {
217     tr_debug("tids_read_request(): Not a TID Request, dropped.");
218     return -1;
219   }
220
221   free (buf);
222   return buflen;
223 }
224
225 static int tids_handle_request (TIDS_INSTANCE *tids, TR_MSG *mreq, TID_RESP *resp) 
226 {
227   int rc;
228
229   /* Check that this is a valid TID Request.  If not, send an error return. */
230   if ((!tr_msg_get_req(mreq)) ||
231       (!tr_msg_get_req(mreq)->rp_realm) ||
232       (!tr_msg_get_req(mreq)->realm) ||
233       (!tr_msg_get_req(mreq)->comm)) {
234     tr_notice("tids_handle_request(): Not a valid TID Request.");
235     resp->result = TID_ERROR;
236     resp->err_msg = tr_new_name("Bad request format");
237     return -1;
238   }
239
240   tid_req_add_path(tr_msg_get_req(mreq), tids->hostname, tids->tids_port);
241   
242   /* Call the caller's request handler */
243   /* TBD -- Handle different error returns/msgs */
244   if (0 > (rc = (*tids->req_handler)(tids, tr_msg_get_req(mreq), resp, tids->cookie))) {
245     /* set-up an error response */
246     resp->result = TID_ERROR;
247     if (!resp->err_msg) /* Use msg set by handler, if any */
248       resp->err_msg = tr_new_name("Internal processing error");
249   }
250   else {
251     /* set-up a success response */
252     resp->result = TID_SUCCESS;
253     resp->err_msg = NULL;       /* No error msg on successful return */
254   }
255     
256   return rc;
257 }
258
259 int tids_send_err_response (TIDS_INSTANCE *tids, TID_REQ *req, const char *err_msg) {
260   TID_RESP *resp = NULL;
261   int rc = 0;
262
263   /* If we already sent a response, don't send another no matter what. */
264   if (req->resp_sent)
265     return 0;
266
267   if (NULL == (resp = tids_create_response(tids, req))) {
268     tr_crit("tids_send_err_response: Can't create response.");
269     return -1;
270   }
271
272   
273   /* mark this as an error response, and include the error message */
274   resp->result = TID_ERROR;
275   resp->err_msg = tr_new_name((char *)err_msg);
276   resp->error_path = req->path;
277
278   rc = tids_send_response(tids, req, resp);
279   
280   tids_destroy_response(tids, resp);
281   return rc;
282 }
283
284 int tids_send_response (TIDS_INSTANCE *tids, TID_REQ *req, TID_RESP *resp)
285 {
286   int err;
287   TR_MSG mresp;
288   char *resp_buf;
289
290   if ((!tids) || (!req) || (!resp))
291     tr_debug("tids_send_response: Invalid parameters.");
292
293   /* Never send a second response if we already sent one. */
294   if (req->resp_sent)
295     return 0;
296
297   mresp.msg_type = TID_RESPONSE;
298   tr_msg_set_resp(&mresp, resp);
299
300   if (NULL == (resp_buf = tr_msg_encode(&mresp))) {
301
302     fprintf(stderr, "tids_send_response: Error encoding json response.\n");
303     tr_audit_req(req);
304
305     return -1;
306   }
307
308   tr_debug("tids_send_response: Encoded response: %s", resp_buf);
309
310   /* If external logging is enabled, fire off a message */
311   /* TODO Can be moved to end once segfault in gsscon_write_encrypted_token fixed */
312   tr_audit_resp(resp);
313
314   /* Send the response over the connection */
315   if (err = gsscon_write_encrypted_token (req->conn, req->gssctx, resp_buf, 
316                                           strlen(resp_buf) + 1)) {
317     tr_notice("tids_send_response: Error sending response over connection.");
318
319     tr_audit_req(req);
320
321     return -1;
322   }
323
324   /* indicate that a response has been sent for this request */
325   req->resp_sent = 1;
326
327   free(resp_buf);
328
329   return 0;
330 }
331
332 static void tids_handle_connection (TIDS_INSTANCE *tids, int conn)
333 {
334   TR_MSG *mreq = NULL;
335   TID_RESP *resp = NULL;
336   int rc = 0;
337   gss_ctx_id_t gssctx = GSS_C_NO_CONTEXT;
338
339   if (tids_auth_connection(tids, conn, &gssctx)) {
340     tr_notice("tids_handle_connection: Error authorizing TID Server connection.");
341     close(conn);
342     return;
343   }
344
345   tr_debug("tids_handle_connection: Connection authorized!");
346
347   while (1) {   /* continue until an error breaks us out */
348
349     if (0 > (rc = tids_read_request(tids, conn, &gssctx, &mreq))) {
350       tr_debug("tids_handle_connection: Error from tids_read_request(), rc = %d.", rc);
351       return;
352     } else if (0 == rc) {
353       continue;
354     }
355
356     /* Put connection information into the request structure */
357     tr_msg_get_req(mreq)->conn = conn;
358     tr_msg_get_req(mreq)->gssctx = gssctx;
359
360     /* Allocate a response structure and populate common fields */
361     if (NULL == (resp = tids_create_response (tids, tr_msg_get_req(mreq)))) {
362       tr_crit("tids_handle_connection: Error creating response structure.");
363       /* try to send an error */
364       tids_send_err_response(tids, tr_msg_get_req(mreq), "Error creating response.");
365       tr_msg_free_decoded(mreq);
366       return;
367     }
368
369     if (0 > (rc = tids_handle_request(tids, mreq, resp))) {
370       tr_debug("tids_handle_connection: Error from tids_handle_request(), rc = %d.", rc);
371       /* Fall through, to send the response, either way */
372     }
373
374     if (0 > (rc = tids_send_response(tids, tr_msg_get_req(mreq), resp))) {
375       tr_debug("tids_handle_connection: Error from tids_send_response(), rc = %d.", rc);
376       /* if we didn't already send a response, try to send a generic error. */
377       if (!tr_msg_get_req(mreq)->resp_sent)
378         tids_send_err_response(tids, tr_msg_get_req(mreq), "Error sending response.");
379       /* Fall through to free the response, either way. */
380     }
381     
382     tids_destroy_response(tids, resp);
383     tr_msg_free_decoded(mreq);
384     return;
385   } 
386 }
387
388 TIDS_INSTANCE *tids_create (TALLOC_CTX *mem_ctx)
389 {
390   return talloc_zero(mem_ctx, TIDS_INSTANCE);
391 }
392
393 /* Get a listener for tids requests, returns its socket fd. Accept
394  * connections with tids_accept() */
395 int tids_get_listener(TIDS_INSTANCE *tids, 
396                       TIDS_REQ_FUNC *req_handler,
397                       TIDS_AUTH_FUNC *auth_handler,
398                       const char *hostname,
399                       unsigned int port,
400                       void *cookie)
401 {
402   int listen = -1;
403
404   tids->tids_port = port;
405   if (0 > (listen = tids_listen(tids, port))) {
406     char errbuf[256];
407     if (0 == strerror_r(errno, errbuf, 256)) {
408       tr_debug("tids_get_listener: Error opening port %d: %s.", port, errbuf);
409     } else {
410       tr_debug("tids_get_listener: Unknown error openining port %d.", port);
411     }
412   } 
413
414   if (listen > 0) {
415     /* opening port succeeded */
416     tr_debug("tids_get_listener: Opened port %d.", port);
417     
418     /* make this socket non-blocking */
419     if (0 != fcntl(listen, F_SETFL, O_NONBLOCK)) {
420       tr_debug("tids_get_listener: Error setting O_NONBLOCK.");
421       close(listen);
422       listen=-1;
423     }
424   }
425
426   if (listen > 0) {
427     /* store the caller's request handler & cookie */
428     tids->req_handler = req_handler;
429     tids->auth_handler = auth_handler;
430     tids->hostname = hostname;
431     tids->cookie = cookie;
432   }
433
434   return listen;
435 }
436
437 /* Accept and process a connection on a port opened with tids_get_listener() */
438 int tids_accept(TIDS_INSTANCE *tids, int listen)
439 {
440   int conn=-1;
441   int pid=-1;
442
443   if (0 > (conn = accept(listen, NULL, NULL))) {
444     perror("Error from TIDS Server accept()");
445     return 1;
446   }
447
448   if (0 > (pid = fork())) {
449     perror("Error on fork()");
450     return 1;
451   }
452
453   if (pid == 0) {
454     close(listen);
455     tids_handle_connection(tids, conn);
456     close(conn);
457     exit(0); /* exit to kill forked child process */
458   } else {
459     close(conn);
460   }
461
462   /* clean up any processes that have completed  (TBD: move to main loop?) */
463   while (waitpid(-1, 0, WNOHANG) > 0);
464
465   return 0;
466 }
467
468 /* Process tids requests forever. Should not return except on error. */
469 int tids_start (TIDS_INSTANCE *tids, 
470                 TIDS_REQ_FUNC *req_handler,
471                 TIDS_AUTH_FUNC *auth_handler,
472                 const char *hostname,
473                 unsigned int port,
474                 void *cookie)
475 {
476   int listen = -1;
477   int conn = -1;
478   pid_t pid;
479
480   tids->tids_port = port;
481   if (0 > (listen = tids_listen(tids, port)))
482     perror ("Error from tids_listen()");
483
484   /* store the caller's request handler & cookie */
485   tids->req_handler = req_handler;
486   tids->auth_handler = auth_handler;
487   tids->hostname = hostname;
488   tids->cookie = cookie;
489
490   tr_info("Trust Path Query Server starting on host %s:%d.", hostname, port);
491
492   while(1) {    /* accept incoming conns until we are stopped */
493
494     if (0 > (conn = accept(listen, NULL, NULL))) {
495       perror("Error from TIDS Server accept()");
496       return 1;
497     }
498
499     if (0 > (pid = fork())) {
500       perror("Error on fork()");
501       return 1;
502     }
503
504     if (pid == 0) {
505       close(listen);
506       tids_handle_connection(tids, conn);
507       close(conn);
508       exit(0); /* exit to kill forked child process */
509     } else {
510       close(conn);
511     }
512
513     /* clean up any processes that have completed */
514     while (waitpid(-1, 0, WNOHANG) > 0);
515   }
516
517   return 1;     /* should never get here, loops "forever" */
518 }
519
520 void tids_destroy (TIDS_INSTANCE *tids)
521 {
522   /* clean up logfiles */
523   tr_log_close();
524
525   if (tids)
526     free(tids);
527 }