Add a timeout to ReadBuffer() method
[trust_router.git] / gsscon / gsscon_common.c
index d067979..54cdd56 100755 (executable)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2011, JANET(UK)
+ * Copyright (c) 2012, JANET(UK)
  * All rights reserved.
  *
  * Redistribution and use in source and binary forms, with or without
@@ -53,6 +53,8 @@
  */
 
 #include <gsscon.h>
+#include <fcntl.h>
+#include <poll.h>
 
 /* --------------------------------------------------------------------------- */
 /* Display the contents of the buffer in hex and ascii                         */
@@ -86,6 +88,7 @@ static void PrintBuffer (const char *inBuffer,
 /* --------------------------------------------------------------------------- */
 /* Standard network read loop, accounting for EINTR, EOF and incomplete reads  */
 
+#define READBUFFER_TIMEOUT 60
 static int ReadBuffer (int     inSocket, 
                        size_t  inBufferLength, 
                        char   *ioBuffer)
@@ -94,13 +97,37 @@ static int ReadBuffer (int     inSocket,
     ssize_t bytesRead = 0;
     
     if (!ioBuffer) { err = EINVAL; }
-    
+
+    /* Read in non-blocking mode */
+    if (!err) {
+        err = fcntl(inSocket, F_SETFL, O_NONBLOCK);
+    }
+
     if (!err) {
         char *ptr = ioBuffer;
         do {
-            ssize_t count = read (inSocket, ptr, inBufferLength - bytesRead);
+            ssize_t count;
+            struct pollfd fds = {inSocket, POLLIN, 0}; /* poll for data ready on the socket */
+            int poll_rc = 0;
+
+            poll_rc = poll(&fds, 1, READBUFFER_TIMEOUT);
+            if (poll_rc == 0) {
+                /* timed out */
+                err = ETIMEDOUT;
+                continue;
+            } else if (poll_rc < 0) {
+                /* try again if we were interrupted, otherwise exit */
+                if (errno != EINTR) {
+                    err = errno;
+                }
+                continue;
+            }
+
+            /* Data should be ready to read */
+            count = read (inSocket, ptr, inBufferLength - bytesRead);
             if (count < 0) {
-                /* Try again on EINTR */
+                /* Try again on EINTR (if we get EAGAIN or EWOULDBLOCK, something is wrong because
+                 * we just polled the fd) */
                 if (errno != EINTR) { err = errno; }
             } else if (count == 0) {
                 err = ECONNRESET; /* EOF and we expected data */
@@ -109,8 +136,8 @@ static int ReadBuffer (int     inSocket,
                 bytesRead += count;
             }
         } while (!err && (bytesRead < inBufferLength));
-    } 
-    
+    }
+
     if (err) { gsscon_print_error (err, "ReadBuffer failed"); }
 
     return err;
@@ -131,7 +158,10 @@ static int WriteBuffer (int         inSocket,
     if (!err) {
         const char *ptr = inBuffer;
         do {
-            ssize_t count = write (inSocket, ptr, inBufferLength - bytesWritten);
+            ssize_t count;
+
+            count = write (inSocket, ptr, inBufferLength - bytesWritten);
+
             if (count < 0) {
                 /* Try again on EINTR */
                 if (errno != EINTR) { err = errno; }
@@ -142,7 +172,7 @@ static int WriteBuffer (int         inSocket,
         } while (!err && (bytesWritten < inBufferLength));
     } 
     
-    if (err) { gsscon_print_error (err, "WritBuffer failed"); }
+    if (err) { gsscon_print_error (err, "WriteBuffer failed"); }
 
     return err;
 }
@@ -168,15 +198,16 @@ int gsscon_read_token (int      inSocket,
     if (!err) {
        tokenLength = ntohl (tokenLength);
        token = malloc (tokenLength);
-       memset (token, 0, tokenLength); 
+        if (token==NULL) {
+          err=EIO;
+        } else {
+          memset (token, 0, tokenLength); 
         
-       err = ReadBuffer (inSocket, tokenLength, token);
+          err = ReadBuffer (inSocket, tokenLength, token);
+        }
     }
     
     if (!err) {
-        printf ("Read token:\n");
-        PrintBuffer (token, tokenLength);
-        
        *outTokenLength = tokenLength;
         *outTokenValue = token;        
         token = NULL; /* only free on error */
@@ -210,8 +241,9 @@ int gsscon_write_token (int         inSocket,
     }
     
     if (!err) {
-        printf ("Wrote token:\n");
-        PrintBuffer (inTokenValue, inTokenLength);
+    //    printf ("Wrote token:\n");
+    //    PrintBuffer (inTokenValue, inTokenLength);
+
     } else { 
         gsscon_print_error (err, "gsscon_write_token() failed");
     }
@@ -271,8 +303,8 @@ int gsscon_read_encrypted_token (int                  inSocket,
     if (!err) {
         memcpy (unencryptedToken, outputBuffer.value, outputBuffer.length);
         
-        printf ("Unencrypted token:\n");
-        PrintBuffer (unencryptedToken, outputBuffer.length);
+       // printf ("Unencrypted token:\n");
+        // PrintBuffer (unencryptedToken, outputBuffer.length);
         
        *outTokenLength = outputBuffer.length;
         *outTokenValue = unencryptedToken;
@@ -327,8 +359,8 @@ int gsscon_write_encrypted_token (int                 inSocket,
     }
     
     if (!err) {
-        printf ("Unencrypted token:\n");
-        PrintBuffer (inToken, inTokenLength);
+      //  printf ("Unencrypted token:\n");
+      //  PrintBuffer (inToken, inTokenLength);
        err = gsscon_write_token (inSocket, outputBuffer.value, outputBuffer.length);
     }