54cdd56e99b28852118718b9827679093d4ea42a
[trust_router.git] / gsscon / gsscon_common.c
1 /*
2  * Copyright (c) 2012, JANET(UK)
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright
10  *    notice, this list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * 3. Neither the name of JANET(UK) nor the names of its contributors
17  *    may be used to endorse or promote products derived from this software
18  *    without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
23  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
24  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
25  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
29  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
31  * OF THE POSSIBILITY OF SUCH DAMAGE.
32  *
33  * This code was adapted from the MIT Kerberos Consortium's
34  * GSS example code, which was distributed under the following
35  * license:
36  *
37  * Copyright 2004-2006 Massachusetts Institute of Technology.
38  * All Rights Reserved.
39  *
40  * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and
41  * distribute this software and its documentation for any purpose and
42  * without fee is hereby granted, provided that the above copyright
43  * notice appear in all copies and that both that copyright notice and
44  * this permission notice appear in supporting documentation, and that
45  * the name of M.I.T. not be used in advertising or publicity pertaining
46  * to distribution of the software without specific, written prior
47  * permission.  Furthermore if you modify this software you must label
48  * your software as modified software and not distribute it in such a
49  * fashion that it might be confused with the original M.I.T. software.
50  * M.I.T. makes no representations about the suitability of
51  * this software for any purpose.  It is provided "as is" without express
52  * or implied warranty.
53  */
54
55 #include <gsscon.h>
56 #include <fcntl.h>
57 #include <poll.h>
58
59 /* --------------------------------------------------------------------------- */
60 /* Display the contents of the buffer in hex and ascii                         */
61
62 static void PrintBuffer (const char *inBuffer, 
63                          size_t      inLength)
64 {
65     int i;  
66     
67     for (i = 0; i < inLength; i += 16) {
68         int l;
69         for (l = i; l < (i + 16); l++) {
70             if (l >= inLength) {
71                 printf ("  ");
72             } else {
73                 u_int8_t *byte = (u_int8_t *) inBuffer + l;
74                 printf ("%2.2x", *byte);
75             }
76             if ((l % 4) == 3) { printf (" "); }
77         }
78         printf ("   ");
79         for (l = i; l < (i + 16) && l < inLength; l++) {
80             printf ("%c", ((inBuffer[l] > 0x1f) && 
81                            (inBuffer[l] < 0x7f)) ? inBuffer[l] : '.');            
82         }
83         printf ("\n");
84     }
85     printf ("\n");
86 }
87
88 /* --------------------------------------------------------------------------- */
89 /* Standard network read loop, accounting for EINTR, EOF and incomplete reads  */
90
91 #define READBUFFER_TIMEOUT 60
92 static int ReadBuffer (int     inSocket, 
93                        size_t  inBufferLength, 
94                        char   *ioBuffer)
95 {
96     int err = 0;
97     ssize_t bytesRead = 0;
98     
99     if (!ioBuffer) { err = EINVAL; }
100
101     /* Read in non-blocking mode */
102     if (!err) {
103         err = fcntl(inSocket, F_SETFL, O_NONBLOCK);
104     }
105
106     if (!err) {
107         char *ptr = ioBuffer;
108         do {
109             ssize_t count;
110             struct pollfd fds = {inSocket, POLLIN, 0}; /* poll for data ready on the socket */
111             int poll_rc = 0;
112
113             poll_rc = poll(&fds, 1, READBUFFER_TIMEOUT);
114             if (poll_rc == 0) {
115                 /* timed out */
116                 err = ETIMEDOUT;
117                 continue;
118             } else if (poll_rc < 0) {
119                 /* try again if we were interrupted, otherwise exit */
120                 if (errno != EINTR) {
121                     err = errno;
122                 }
123                 continue;
124             }
125
126             /* Data should be ready to read */
127             count = read (inSocket, ptr, inBufferLength - bytesRead);
128             if (count < 0) {
129                 /* Try again on EINTR (if we get EAGAIN or EWOULDBLOCK, something is wrong because
130                  * we just polled the fd) */
131                 if (errno != EINTR) { err = errno; }
132             } else if (count == 0) {
133                 err = ECONNRESET; /* EOF and we expected data */
134             } else {
135                 ptr += count;
136                 bytesRead += count;
137             }
138         } while (!err && (bytesRead < inBufferLength));
139     }
140
141     if (err) { gsscon_print_error (err, "ReadBuffer failed"); }
142
143     return err;
144 }
145
146 /* --------------------------------------------------------------------------- */
147 /* Standard network write loop, accounting for EINTR and incomplete writes     */
148
149 static int WriteBuffer (int         inSocket, 
150                         const char *inBuffer, 
151                         size_t      inBufferLength)
152 {
153     int err = 0;
154     ssize_t bytesWritten = 0;
155     
156     if (!inBuffer) { err = EINVAL; }
157     
158     if (!err) {
159         const char *ptr = inBuffer;
160         do {
161             ssize_t count;
162
163             count = write (inSocket, ptr, inBufferLength - bytesWritten);
164
165             if (count < 0) {
166                 /* Try again on EINTR */
167                 if (errno != EINTR) { err = errno; }
168             } else {
169                 ptr += count;
170                 bytesWritten += count;
171             }
172         } while (!err && (bytesWritten < inBufferLength));
173     } 
174     
175     if (err) { gsscon_print_error (err, "WriteBuffer failed"); }
176
177     return err;
178 }
179
180 /* --------------------------------------------------------------------------- */
181 /* Read a GSS token (length + data) off the network                            */
182
183 int gsscon_read_token (int      inSocket, 
184                char   **outTokenValue, 
185                size_t  *outTokenLength)
186 {
187     int err = 0;
188     char *token = NULL;
189     u_int32_t tokenLength = 0;
190     
191     if (!outTokenValue ) { err = EINVAL; }
192     if (!outTokenLength) { err = EINVAL; }
193     
194     if (!err) {
195         err = ReadBuffer (inSocket, 4, (char *) &tokenLength);
196     }
197     
198     if (!err) {
199         tokenLength = ntohl (tokenLength);
200         token = malloc (tokenLength);
201         if (token==NULL) {
202           err=EIO;
203         } else {
204           memset (token, 0, tokenLength); 
205         
206           err = ReadBuffer (inSocket, tokenLength, token);
207         }
208     }
209     
210     if (!err) {
211         *outTokenLength = tokenLength;
212         *outTokenValue = token;        
213         token = NULL; /* only free on error */
214     } else { 
215         gsscon_print_error (err, "ReadToken failed"); 
216     }
217
218     if (token) { free (token); }
219     
220     return err;
221 }
222
223 /* --------------------------------------------------------------------------- */
224 /* Write a GSS token (length + data) onto the network                          */
225
226 int gsscon_write_token (int         inSocket, 
227                 const char *inTokenValue, 
228                 size_t      inTokenLength)
229 {
230     int err = 0;
231     u_int32_t tokenLength = htonl (inTokenLength);
232
233     if (!inTokenValue) { err = EINVAL; }
234     
235     if (!err) {
236         err = WriteBuffer (inSocket, (char *) &tokenLength, 4);
237     }
238         
239     if (!err) { 
240         err = WriteBuffer (inSocket, inTokenValue, inTokenLength);
241     }
242     
243     if (!err) {
244     //    printf ("Wrote token:\n");
245     //    PrintBuffer (inTokenValue, inTokenLength);
246
247     } else { 
248         gsscon_print_error (err, "gsscon_write_token() failed");
249     }
250    
251     return err;
252 }
253
254 /* --------------------------------------------------------------------------- */
255 /* Read an encrypted GSS token (length + encrypted data) off the network       */
256
257
258 int gsscon_read_encrypted_token (int                  inSocket, 
259                         const gss_ctx_id_t   inContext, 
260                         char               **outTokenValue, 
261                         size_t              *outTokenLength)
262 {
263     int err = 0;
264     char *token = NULL;
265     size_t tokenLength = 0;
266     OM_uint32 majorStatus;
267     OM_uint32 minorStatus = 0;
268     gss_buffer_desc outputBuffer = { 0 , NULL};
269     char *unencryptedToken = NULL;
270     
271     if (!inContext     ) { err = EINVAL; }
272     if (!outTokenValue ) { err = EINVAL; }
273     if (!outTokenLength) { err = EINVAL; }
274     
275     if (!err) {
276         err = gsscon_read_token (inSocket, &token, &tokenLength);
277     }
278     
279     if (!err) {
280         gss_buffer_desc inputBuffer = { tokenLength, token};
281         int encrypted = 0; /* did mechanism encrypt/integrity protect? */
282
283         majorStatus = gss_unwrap (&minorStatus, 
284                                   inContext, 
285                                   &inputBuffer, 
286                                   &outputBuffer, 
287                                   &encrypted, 
288                                   NULL /* qop_state */);
289         if (majorStatus != GSS_S_COMPLETE) { 
290             gsscon_print_gss_errors("gss_unwrap", majorStatus, minorStatus);
291             err = minorStatus ? minorStatus : majorStatus; 
292         } else if (!encrypted) {
293             fprintf (stderr, "WARNING!  Mechanism not using encryption!");
294             err = EINVAL; /* You may not want to fail here. */
295         }
296     }
297     
298     if (!err) {
299         unencryptedToken = malloc (outputBuffer.length);
300         if (unencryptedToken == NULL) { err = ENOMEM; }
301     }
302     
303     if (!err) {
304         memcpy (unencryptedToken, outputBuffer.value, outputBuffer.length);
305         
306         // printf ("Unencrypted token:\n");
307         // PrintBuffer (unencryptedToken, outputBuffer.length);
308         
309         *outTokenLength = outputBuffer.length;
310         *outTokenValue = unencryptedToken;
311         unencryptedToken = NULL; /* only free on error */
312         
313     } else { 
314         gsscon_print_error (err, "ReadToken failed"); 
315     }
316     
317     if (token             ) { free (token); }
318     if (outputBuffer.value) { gss_release_buffer (&minorStatus, &outputBuffer); }
319     if (unencryptedToken  ) { free (unencryptedToken); }
320     
321     return err;
322 }
323
324 /* --------------------------------------------------------------------------- */
325 /* Write an encrypted GSS token (length + encrypted data) onto the network     */
326
327 int gsscon_write_encrypted_token (int                 inSocket, 
328                          const gss_ctx_id_t  inContext, 
329                          const char         *inToken, 
330                          size_t              inTokenLength)
331 {
332     int err = 0;
333     OM_uint32 majorStatus;
334     OM_uint32 minorStatus = 0;
335     gss_buffer_desc outputBuffer = { 0, NULL };
336     
337     if (!inContext) { err = EINVAL; }
338     if (!inToken  ) { err = EINVAL; }
339     
340     if (!err) {
341         gss_buffer_desc inputBuffer = { inTokenLength, (char *) inToken };
342         int encrypt = 1;   /* do encryption and integrity protection */
343         int encrypted = 0; /* did mechanism encrypt/integrity protect? */
344         
345         majorStatus = gss_wrap (&minorStatus, 
346                                 inContext, 
347                                 encrypt, 
348                                 GSS_C_QOP_DEFAULT,
349                                 &inputBuffer, 
350                                 &encrypted, 
351                                 &outputBuffer);
352         if (majorStatus != GSS_S_COMPLETE) { 
353             gsscon_print_gss_errors ("gss_wrap", majorStatus, minorStatus);
354             err = minorStatus ? minorStatus : majorStatus; 
355         } else if (!encrypted) {
356             fprintf (stderr, "WARNING!  Mechanism does not support encryption!");
357             err = EINVAL; /* You may not want to fail here. */
358         }
359     }
360     
361     if (!err) {
362       //  printf ("Unencrypted token:\n");
363       //  PrintBuffer (inToken, inTokenLength);
364         err = gsscon_write_token (inSocket, outputBuffer.value, outputBuffer.length);
365     }
366     
367     if (!err) {
368     } else { 
369         gsscon_print_error (err, "gsscon_write_token failed");
370     }
371     
372     if (outputBuffer.value) { gss_release_buffer (&minorStatus, &outputBuffer); }
373     
374     return err;
375 }
376
377 /* --------------------------------------------------------------------------- */
378 /* Print BSD error                                                             */
379
380 void gsscon_print_error (int         inError, 
381                  const char *inString)
382 {
383     fprintf (stderr, "%s: %s (err = %d)\n", 
384              inString, error_message (inError), inError);
385 }
386
387 /* --------------------------------------------------------------------------- */
388 /* PrintGSSAPI errors                                                         */
389
390 void gsscon_print_gss_errors (const char *inRoutineName, 
391                      OM_uint32   inMajorStatus, 
392                      OM_uint32   inMinorStatus)
393 {
394     OM_uint32 minorStatus;
395     OM_uint32 majorStatus;      
396     gss_buffer_desc errorBuffer;
397
398     OM_uint32 messageContext = 0; /* first message */
399     int count = 1;
400     
401     fprintf (stderr, "Error returned by %s:\n", inRoutineName);
402     
403     do {
404         majorStatus = gss_display_status (&minorStatus, 
405                                           inMajorStatus, 
406                                           GSS_C_GSS_CODE, 
407                                           GSS_C_NULL_OID, 
408                                           &messageContext, 
409                                           &errorBuffer);
410         if (majorStatus == GSS_S_COMPLETE) {
411             fprintf (stderr,"      major error <%d> %s\n", 
412                      count, (char *) errorBuffer.value);
413             gss_release_buffer (&minorStatus, &errorBuffer);
414         }
415         ++count;
416     } while (messageContext != 0);
417     
418     count = 1;
419     messageContext = 0;
420     do {
421         majorStatus = gss_display_status (&minorStatus, 
422                                           inMinorStatus, 
423                                           GSS_C_MECH_CODE, 
424                                           GSS_C_NULL_OID, 
425                                           &messageContext, 
426                                           &errorBuffer);
427         fprintf (stderr,"      minor error <%d> %s\n", 
428                  count, (char *) errorBuffer.value);
429         ++count;
430     } while (messageContext != 0);
431 }
432