https://issues.shibboleth.net/jira/browse/SSPCPP-569
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / TCPListener.cpp
index 5f07611..4d1452a 100644 (file)
@@ -53,6 +53,7 @@
 #include <sys/stat.h>          /* for chmod() */
 #include <stdio.h>
 #include <stdlib.h>
+#include <fcntl.h>
 #include <errno.h>
 
 using namespace shibsp;
@@ -88,6 +89,7 @@ namespace shibsp {
         string m_address;
         unsigned short m_port;
         vector<IPRange> m_acl;
+        size_t m_sockaddrlen;
 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
         struct sockaddr_storage m_sockaddr;
 #else
@@ -157,15 +159,16 @@ bool TCPListener::setup_tcp_sockaddr()
         return false;
     }
 
+    m_sockaddrlen = ret->ai_addrlen;
     if (ret->ai_family == AF_INET) {
-        memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
+        memcpy(&m_sockaddr, ret->ai_addr, m_sockaddrlen);
         freeaddrinfo(ret);
         ((struct sockaddr_in*)&m_sockaddr)->sin_port=htons(m_port);
         return true;
     }
 #if defined(AF_INET6) && defined(HAVE_STRUCT_SOCKADDR_STORAGE)
     else if (ret->ai_family == AF_INET6) {
-        memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
+        memcpy(&m_sockaddr, ret->ai_addr, m_sockaddrlen);
         freeaddrinfo(ret);
         ((struct sockaddr_in6*)&m_sockaddr)->sin6_port=htons(m_port);
         return true;
@@ -179,10 +182,15 @@ bool TCPListener::setup_tcp_sockaddr()
 
 bool TCPListener::create(ShibSocket& s) const
 {
+    int type = SOCK_STREAM;
+#ifdef HAVE_SOCK_CLOEXEC
+    type |= SOCK_CLOEXEC;
+#endif
+
 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
-    s = socket(m_sockaddr.ss_family, SOCK_STREAM, 0);
+    s = socket(m_sockaddr.ss_family, type, 0);
 #else
-    s = socket(m_sockaddr.sin_family, SOCK_STREAM, 0);
+    s = socket(m_sockaddr.sin_family, type, 0);
 #endif
 #ifdef WIN32
     if(s == INVALID_SOCKET)
@@ -190,6 +198,15 @@ bool TCPListener::create(ShibSocket& s) const
     if (s < 0)
 #endif
         return log_error("socket");
+
+#if !defined(HAVE_SOCK_CLOEXEC) && defined(HAVE_FD_CLOEXEC)
+    int fdflags = fcntl(s, F_GETFD);
+    if (fdflags != -1) {
+        fdflags |= FD_CLOEXEC;
+        fcntl(s, F_SETFD, fdflags);
+    }
+#endif
+
     return true;
 }
 
@@ -200,14 +217,14 @@ bool TCPListener::bind(ShibSocket& s, bool force) const
     ::setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char*)&opt, sizeof(opt));
 
 #ifdef WIN32
-    if (SOCKET_ERROR==::bind(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) || SOCKET_ERROR==::listen(s, 3)) {
+    if (SOCKET_ERROR==::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen) || SOCKET_ERROR==::listen(s, 3)) {
         log_error("bind");
         close(s);
         return false;
     }
 #else
-    // Newer BSDs require the struct length be passed based on the socket address.
-    // Others have no field for that and take the whole struct size like Windows does.
+    // Newer BSDs, and Solaris, require the struct length be passed based on the socket address.
+    // All but Solaris seem to have an ss_len field in the sockaddr_storage struct.
 # ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
 #  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0) {
@@ -215,7 +232,7 @@ bool TCPListener::bind(ShibSocket& s, bool force) const
     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0) {
 #  endif
 # else
-    if (::bind(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) < 0) {
+    if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen) < 0) {
 # endif
         log_error("bind");
         close(s);
@@ -229,7 +246,7 @@ bool TCPListener::bind(ShibSocket& s, bool force) const
 bool TCPListener::connect(ShibSocket& s) const
 {
 #ifdef WIN32
-    if(SOCKET_ERROR==::connect(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)))
+    if(SOCKET_ERROR==::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen))
         return log_error("connect");
 #else
     // Newer BSDs require the struct length be passed based on the socket address.
@@ -241,7 +258,7 @@ bool TCPListener::connect(ShibSocket& s) const
     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0)
 #  endif
 # else
-    if (::connect(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) < 0)
+    if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen) < 0)
 # endif
         return log_error("connect");
 #endif