Adding testing CLI client (based off the Heimdal testing sample)
[mod_auth_kerb.git] / client / http_client.c
1 /*
2  * Copyright (c) 2003 - 2005 Kungliga Tekniska Högskolan
3  * (Royal Institute of Technology, Stockholm, Sweden). 
4  * All rights reserved. 
5  *
6  * Redistribution and use in source and binary forms, with or without 
7  * modification, are permitted provided that the following conditions 
8  * are met: 
9  *
10  * 1. Redistributions of source code must retain the above copyright 
11  *    notice, this list of conditions and the following disclaimer. 
12  *
13  * 2. Redistributions in binary form must reproduce the above copyright 
14  *    notice, this list of conditions and the following disclaimer in the 
15  *    documentation and/or other materials provided with the distribution. 
16  *
17  * 3. Neither the name of the Institute nor the names of its contributors 
18  *    may be used to endorse or promote products derived from this software 
19  *    without specific prior written permission. 
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND 
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
23  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
24  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE 
25  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 
26  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 
27  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 
29  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 
30  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 
31  * SUCH DAMAGE. 
32  */
33
#include "client_locl.h"
#include <errno.h>
#include <gssapi.h>
#include <gssapi/gssapi_ext.h>
#include "gss_common.h"
#include "base64.h"
39
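/*
 * A simplistic client for HTTP authentication in the style of
 * draft-brezak-spnego-http-04.txt.  Note that it uses the "GSSAPI"
 * authentication scheme name rather than the draft's "Negotiate".
 */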
43
/*
 * Open a TCP connection to hostname:port.
 *
 * Every address returned by getaddrinfo() is tried in turn; a failed
 * connect() is reported with warn() and the next address is attempted.
 * Exits the program if name resolution fails or no address accepts the
 * connection, so on return the descriptor is always a connected socket.
 */
static int
do_connect(const char *hostname, const char *port)
{
    struct addrinfo hints = {
        .ai_family = PF_UNSPEC,
        .ai_socktype = SOCK_STREAM,
        .ai_protocol = 0,
    };
    struct addrinfo *res, *cur;
    int gai_ret;
    int fd = -1;

    gai_ret = getaddrinfo(hostname, port, &hints, &res);
    if (gai_ret != 0)
        errx(1, "getaddrinfo(%s): %s", hostname, gai_strerror(gai_ret));

    for (cur = res; cur != NULL; cur = cur->ai_next) {
        fd = socket(cur->ai_family, cur->ai_socktype, cur->ai_protocol);
        if (fd < 0)
            continue;
        if (connect(fd, cur->ai_addr, cur->ai_addrlen) == 0)
            break;              /* connected: keep this descriptor */
        warn("connect(%s)", hostname);
        close(fd);
    }
    freeaddrinfo(res);

    if (cur == NULL)
        errx(1, "failed to contact %s", hostname);

    return fd;
}
78
/*
 * printf-style write to a file descriptor.
 *
 * Formats fmt and its arguments, then writes the complete result to s,
 * retrying short writes and writes interrupted by a signal (EINTR).
 * Exits via errx()/err() if formatting or allocation fails, on a write
 * error, or if write() reports that no data could be written.
 *
 * The string is sized with a first vsnprintf() pass rather than with
 * vasprintf(), whose output pointer is unspecified on failure on some
 * platforms (glibc), so it cannot be tested against NULL.
 */
static void
fdprintf(int s, const char *fmt, ...)
{
    va_list ap, ap2;
    int n;
    size_t len;
    char *str, *buf;

    va_start(ap, fmt);
    va_copy(ap2, ap);
    n = vsnprintf(NULL, 0, fmt, ap);    /* measure only */
    va_end(ap);
    if (n < 0) {
        va_end(ap2);
        errx(1, "vsnprintf");
    }

    len = (size_t)n;
    str = malloc(len + 1);
    if (str == NULL) {
        va_end(ap2);
        errx(1, "malloc");
    }
    vsnprintf(str, len + 1, fmt, ap2);
    va_end(ap2);

    buf = str;
    while (len > 0) {
        ssize_t ret = write(s, buf, len);

        if (ret < 0) {
            if (errno == EINTR)
                continue;
            err(1, "write");
        }
        if (ret == 0)
            errx(1, "connection closed");   /* errno is not meaningful here */
        len -= (size_t)ret;
        buf += ret;
    }
    free(str);
}
107
/*
 * Program options.
 *
 * NOTE(review): verbose_flag and delegate_flag have no command-line switch
 * in long_opts/short_opts, so they always keep the initial values below.
 */
static int verbose_flag;            /* print response line, headers, body */
static int mutual_flag = 1;         /* request GSS_C_MUTUAL_FLAG */
static int delegate_flag;           /* request GSS_C_DELEG_FLAG */
static char *mech = NULL;           /* -m: mechanism name given to select_mech() */
static const char *port_str = "http";       /* service/port for getaddrinfo() */
static const char *gss_service = "HTTP";    /* -s: service part of service@host */
static char *user = NULL;           /* -u: initiator user name */
static char *pwd = NULL;            /* -p: password for initial credentials */

static struct option const long_opts[] = {
    { "help", no_argument, 0, 'h' },
    { "mech", required_argument, 0, 'm' },
    { "password", required_argument, 0, 'p' },
    { "gss-service", required_argument, 0, 's' },
    { "user", required_argument, 0, 'u' },
    { NULL, 0, NULL, 0 }
};

static const char *short_opts = "hm:p:s:u:";
128
/*
 * Print the command-line synopsis on stderr and terminate the program
 * with exit status ret (0 for --help, non-zero for a usage error).
 */
static void
usage(int ret)
{
    static const char synopsis[] =
        "Usage: http_client [OPTION] URL\n"
        "-m mech, --mech=mech               gssapi mech to use\n"
        "-p pass, --password=pass           password to acquire credentials\n"
        "-s service, --gss-service=service  gssapi service to use\n"
        "-u user, --user=user               client's username\n";

    fputs(synopsis, stderr);
    exit(ret);
}
139
/*
 * A received HTTP response, as filled in by http_query().
 */
struct http_req {
    char *response;     /* status line, without its CRLF */
    char **headers;     /* num_headers header lines, each without CRLF */
    int num_headers;
    void *body;         /* body_size bytes followed by a NUL; may be NULL */
    size_t body_size;   /* body length in bytes, excluding the NUL */
};


/* Reset every field of req to the empty state (no memory is released). */
static void
http_req_zero(struct http_req *req)
{
    req->response = NULL;
    req->headers = NULL;
    req->num_headers = 0;
    req->body = NULL;
    req->body_size = 0;
}

/* Release everything owned by req and leave it in the empty state. */
static void
http_req_free(struct http_req *req)
{
    int i;

    free(req->response);
    for (i = 0; i < req->num_headers; i++)
        free(req->headers[i]);
    free(req->headers);
    free(req->body);
    http_req_zero(req);
}

/*
 * Look up a header by its name prefix, including the colon
 * (e.g. "WWW-Authenticate:"); the comparison is case-insensitive.
 *
 * Returns a pointer into the stored header line just past the colon and
 * any following spaces/tabs (possibly pointing at an empty string), or
 * NULL when no header matches.  The pointer stays valid until
 * http_req_free().
 */
static const char *
http_find_header(struct http_req *req, const char *header)
{
    size_t len = strlen(header);
    int i;

    for (i = 0; i < req->num_headers; i++) {
        const char *line = req->headers[i];

        if (strncasecmp(header, line, len) == 0) {
            const char *value = line + len;

            /*
             * Whitespace after the colon is optional; skipping it (rather
             * than assuming exactly one space) never steps past the NUL
             * of a header that has an empty value.
             */
            while (*value == ' ' || *value == '\t')
                value++;
            return value;
        }
    }
    return NULL;
}
188
189
190 static int
191 http_query(int s, const char *host, const char *page, 
192            char **headers, int num_headers, struct http_req *req)
193 {
194     enum { RESPONSE, HEADER, BODY } state;
195     ssize_t ret;
196 //    char in_buf[4096], *in_ptr = in_buf;
197     char in_buf[8000], *in_ptr = in_buf;
198     size_t in_len = 0;
199     int i;
200     size_t content_length = 0;
201
202     http_req_zero(req);
203
204     fdprintf(s, "GET %s HTTP/1.0\r\n", page);
205     for (i = 0; i < num_headers; i++)
206         fdprintf(s, "%s\r\n", headers[i]);
207     fdprintf(s, "Keep-Alive: 115\r\n");
208     fdprintf(s, "Connection: keep-alive\r\n");
209     fdprintf(s, "Host: %s\r\n\r\n", host);
210
211     state = RESPONSE;
212
213     while (1) {
214         ret = read (s, in_ptr, sizeof(in_buf) - in_len - 1);
215         if (ret == 0)
216             break;
217         else if (ret < 0)
218             err (1, "read: %lu", (unsigned long)ret);
219         
220         in_buf[ret + in_len] = '\0';
221
222         if (state == HEADER || state == RESPONSE) {
223             char *p;
224
225             in_len += ret;
226             in_ptr += ret;
227
228             while (1) {
229                 p = strstr(in_buf, "\r\n");
230
231                 if (p == NULL) {
232                     break;
233                 } else if (p == in_buf) {
234                     memmove(in_buf, in_buf + 2, sizeof(in_buf) - 2);
235                     state = BODY;
236                     in_len -= 2;
237                     in_ptr -= 2;
238                     break;
239                 } else if (state == RESPONSE) {
240                     req->response = strndup(in_buf, p - in_buf);
241                     state = HEADER;
242                 } else {
243                     req->headers = realloc(req->headers,
244                                            (req->num_headers + 1) * sizeof(req->headers[0]));
245                     req->headers[req->num_headers] = strndup(in_buf, p - in_buf);
246                     if (req->headers[req->num_headers] == NULL)
247                         errx(1, "strdup");
248                     if (strncmp(req->headers[req->num_headers], "Content-Length:", 15) == 0)
249                         content_length = atoi(req->headers[req->num_headers] + 16);
250                     req->num_headers++;
251                 }
252                 memmove(in_buf, p + 2, sizeof(in_buf) - (p - in_buf) - 2);
253                 in_len -= (p - in_buf) + 2;
254                 in_ptr -= (p - in_buf) + 2;
255             }
256         }
257
258         if (state == BODY) {
259
260             req->body = realloc(req->body, req->body_size + in_len + 1);
261
262             memcpy((char *)req->body + req->body_size, in_buf, in_len);
263             req->body_size += in_len;
264             ((char *)req->body)[req->body_size] = '\0';
265
266             if (content_length && req->body_size == content_length)
267                 break;
268
269             in_ptr = in_buf;
270             in_len = 0;
271         }
272 //      else
273 //          abort();
274     }
275
276 #if 0
277     if (verbose_flag) {
278         int i;
279         printf("response: %s\n", req->response);
280         for (i = 0; i < req->num_headers; i++)
281             printf("header[%d] %s\n", i, req->headers[i]);
282         printf("body: %.*s\n", (int)req->body_size, (char *)req->body);
283     }
284 #endif
285
286     return 0;
287 }
288
289 static int
290 do_http(const char *host, const char *page, gss_OID mech_oid, gss_cred_id_t cred)
291 {
292     struct http_req req;
293     int i, done, print_body, gssapi_done, gssapi_started;
294     char *headers[10]; /* XXX */
295     int num_headers;
296     gss_ctx_id_t context_hdl = GSS_C_NO_CONTEXT;
297     gss_name_t server = GSS_C_NO_NAME;
298     OM_uint32 flags = 0;
299     int s;
300
301     flags = 0;
302     if (delegate_flag)
303         flags |= GSS_C_DELEG_FLAG;
304     if (mutual_flag)
305         flags |= GSS_C_MUTUAL_FLAG;
306
307     done = 0;
308     num_headers = 0;
309     gssapi_done = 1;
310     gssapi_started = 0;
311
312     s = do_connect(host, port_str);
313     if (s < 0)
314         errx(1, "connection failed");
315
316     do {
317         print_body = 0;
318
319         http_query(s, host, page, headers, num_headers, &req);
320         for (i = 0 ; i < num_headers; i++) 
321             free(headers[i]);
322         num_headers = 0;
323
324         if (strstr(req.response, " 200 ") != NULL) {
325             print_body = 1;
326             done = 1;
327         } else if (strstr(req.response, " 401 ") != NULL) {
328             if (http_find_header(&req, "WWW-Authenticate:") == NULL)
329                 errx(1, "Got %s but missed `WWW-Authenticate'", req.response);
330             gssapi_done = 0;
331         }
332
333         if (!gssapi_done) {
334             const char *h = http_find_header(&req, "WWW-Authenticate:");
335             if (h == NULL)
336                 errx(1, "Got %s but missed `WWW-Authenticate'", req.response);
337
338             if (strncasecmp(h, "GSSAPI", 6) == 0) {
339                 OM_uint32 maj_stat, min_stat;
340                 gss_buffer_desc input_token, output_token;
341
342                 if (verbose_flag)
343                     printf("Negotiate found\n");
344                 
345 #if 1
346                 if (server == GSS_C_NO_NAME) {
347                     char *name;
348                     asprintf(&name, "%s@%s", gss_service, host);
349                     input_token.length = strlen(name);
350                     input_token.value = name;
351
352                     maj_stat = gss_import_name(&min_stat,
353                                                &input_token,
354                                                GSS_C_NT_HOSTBASED_SERVICE,
355                                                &server);
356                     if (GSS_ERROR(maj_stat))
357                         gss_err (1, maj_stat, min_stat, "gss_inport_name");
358                     free(name);
359                     input_token.length = 0;
360                     input_token.value = NULL;
361                 }
362 #endif
363
364 //              i = 9;
365                 i = 6;
366                 while(h[i] && isspace((unsigned char)h[i]))
367                     i++;
368                 if (h[i] != '\0') {
369                     int len = strlen(&h[i]);
370                     if (len == 0)
371                         errx(1, "invalid Negotiate token");
372                     input_token.value = malloc(len);
373                     len = base64_decode(&h[i], input_token.value);
374                     if (len < 0)
375                         errx(1, "invalid base64 Negotiate token %s", &h[i]);
376                     input_token.length = len;
377                 } else {
378                     if (gssapi_started)
379                         errx(1, "Negotiate already started");
380                     gssapi_started = 1;
381
382                     input_token.length = 0;
383                     input_token.value = NULL;
384                 }
385
386                 maj_stat =
387                     gss_init_sec_context(&min_stat,
388                                          cred,
389                                          &context_hdl,
390                                          server,
391                                          mech_oid,
392                                          flags,
393                                          0,
394                                          GSS_C_NO_CHANNEL_BINDINGS,
395                                          &input_token,
396                                          NULL,
397                                          &output_token,
398                                          NULL,
399                                          NULL);
400                 if (GSS_ERROR(maj_stat))
401                     gss_err (1, maj_stat, min_stat, "gss_init_sec_context");
402                 else if (maj_stat & GSS_S_CONTINUE_NEEDED)
403                     gssapi_done = 0;
404                 else {
405                     gss_name_t targ_name, src_name;
406                     gss_buffer_desc name_buffer;
407                     gss_OID mech_type;
408
409                     gssapi_done = 1;
410
411                     printf("\nNegotiate done: %s\n", mech);
412
413                     maj_stat = gss_inquire_context(&min_stat,
414                                                    context_hdl,
415                                                    &src_name,
416                                                    &targ_name,
417                                                    NULL,
418                                                    &mech_type,
419                                                    NULL,
420                                                    NULL,
421                                                    NULL);
422                     if (GSS_ERROR(maj_stat))
423                         gss_err (1, maj_stat, min_stat, "gss_inquire_context");
424
425                     maj_stat = gss_display_name(&min_stat,
426                                                 src_name,
427                                                 &name_buffer,
428                                                 NULL);
429                     if (GSS_ERROR(maj_stat))
430                         gss_err (1, maj_stat, min_stat, "gss_display_name");
431
432                     printf("Source: %.*s\n",
433                            (int)name_buffer.length,
434                            (char *)name_buffer.value);
435
436                     gss_release_buffer(&min_stat, &name_buffer);
437
438                     maj_stat = gss_display_name(&min_stat,
439                                                 targ_name,
440                                                 &name_buffer,
441                                                 NULL);
442                     if (GSS_ERROR(maj_stat))
443                         gss_err (1, maj_stat, min_stat, "gss_display_name");
444
445                     printf("Target: %.*s\n",
446                            (int)name_buffer.length,
447                            (char *)name_buffer.value);
448
449                     gss_release_name(&min_stat, &targ_name);
450                     gss_release_buffer(&min_stat, &name_buffer);
451                 }
452
453                 if (output_token.length) {
454                     char *neg_token;
455
456                     base64_encode(output_token.value,
457                                   output_token.length,
458                                   &neg_token);
459                     
460                     asprintf(&headers[0], "Authorization: GSSAPI %s",
461                              neg_token);
462                     num_headers = 1;
463                     free(neg_token);
464                     gss_release_buffer(&min_stat, &output_token);
465                 }
466                 if (input_token.length)
467                     free(input_token.value);
468
469             } else
470                 done = 1;
471         } else
472             done = 1;
473
474         if (verbose_flag) {
475             printf("%s\n\n", req.response);
476
477             for (i = 0; i < req.num_headers; i++)
478                 printf("%s\n", req.headers[i]);
479             printf("\n");
480         }
481         if (print_body || verbose_flag)
482             printf("%.*s\n", (int)req.body_size, (char *)req.body);
483
484         http_req_free(&req);
485     } while (!done);
486
487     close(s);
488
489     if (gssapi_done == 0)
490         errx(1, "gssapi not done but http dance done");
491
492     return 0;
493 }
494
495 int
496 main(int argc, char *argv[])
497 {
498     int c, ret;
499     gss_buffer_desc token;
500     gss_OID mech_oid = GSS_C_NO_OID;
501     OM_uint32 maj_stat, min_stat;
502     gss_name_t gss_username = GSS_C_NO_NAME;
503     gss_cred_id_t cred = GSS_C_NO_CREDENTIAL;
504     char *p, *host, *page;
505
506     while ((c = getopt_long(argc, argv, short_opts, long_opts, NULL)) != EOF) {
507         switch (c) {
508             case 'h':
509                 usage(0);
510             case 'm':
511                 mech = optarg;
512                 mech_oid = select_mech(mech);
513                 break;
514             case 'p':
515                 pwd = optarg;
516                 break;
517             case 's':
518                 gss_service = optarg;
519                 break;
520             case 'u':
521                 user = optarg;
522                 break;
523             default:
524                 usage(1);
525         }
526     }
527
528     if (optind >= argc)
529         usage(1);
530
531     p = argv[optind];
532     if (strncmp(p, "http://", 7) == 0)
533         p += 7;
534     host = p;
535     p = strchr(host, '/');
536     if (p) {
537         page = strdup(p);
538         *p = '\0';
539     } else
540         page = strdup("/");
541
542     if (user) {
543         token.value = user;
544         token.length = strlen(token.value);
545         maj_stat = gss_import_name(&min_stat, &token,
546                                    GSS_C_NT_USER_NAME,
547                                    &gss_username);
548         if (GSS_ERROR(maj_stat))
549             gss_err(1, maj_stat, min_stat, "Invalid user name %s", user);
550     }
551
552     if (pwd) {
553         gss_OID_set_desc mechs, *mechsp = GSS_C_NO_OID_SET;
554
555         token.value = pwd;
556         token.length = strlen(token.value);
557         mechs.elements = mech_oid;
558         mechs.count = 1;
559         mechsp = &mechs;
560         maj_stat = gss_acquire_cred_with_password(&min_stat,
561                         gss_username, &token, 0,
562                         mechsp, GSS_C_INITIATE,
563                         &cred, NULL, NULL);
564         if (GSS_ERROR(maj_stat))
565             gss_err(1, maj_stat, min_stat, "Failed to load initial credentials");
566     }
567
568     ret = do_http(host, page, mech_oid, cred);
569
570     if (gss_username != GSS_C_NO_NAME)
571         gss_release_name(&min_stat, &gss_username);
572
573     if (cred != GSS_C_NO_CREDENTIAL)
574         gss_release_cred(&min_stat, &cred);
575
576     free(page);
577
578     return (ret);
579 }