Apply BSD fix to connect call
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / TCPListener.cpp
index 0c860f5..5f07611 100644 (file)
@@ -1,34 +1,49 @@
-/*
- *  Copyright 2001-2007 Internet2
+/**
+ * Licensed to the University Corporation for Advanced Internet
+ * Development, Inc. (UCAID) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for
+ * additional information regarding copyright ownership.
  *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
+ * UCAID licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"); you may not use this file except
+ * in compliance with the License. You may obtain a copy of the
+ * License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
  */
 
 /**
  * TCPListener.cpp
  *
- * TCP-based SocketListener implementation
+ * TCP-based SocketListener implementation.
  */
 
 #include "internal.h"
+#include "exceptions.h"
 #include "remoting/impl/SocketListener.h"
+#include "util/IPRange.h"
 
+#include <boost/bind.hpp>
+#include <boost/algorithm/string.hpp>
 #include <xercesc/util/XMLUniDefs.hpp>
 #include <xmltooling/unicode.h>
+#include <xmltooling/util/XMLHelper.h>
+
+#ifdef WIN32
+# include <winsock2.h>
+# include <ws2tcpip.h>
+#endif
 
 #ifdef HAVE_UNISTD_H
 # include <sys/socket.h>
 # include <sys/un.h>
+# include <netdb.h>
 # include <unistd.h>
 # include <arpa/inet.h>
 # include <netinet/in.h>
 using namespace shibsp;
 using namespace xmltooling;
 using namespace xercesc;
+using namespace boost;
 using namespace std;
 
 namespace shibsp {
-    static const XMLCh address[] = UNICODE_LITERAL_7(a,d,d,r,e,s,s);
-    static const XMLCh port[] = UNICODE_LITERAL_4(p,o,r,t);
-    static const XMLCh acl[] = UNICODE_LITERAL_3(a,c,l);
-
     class TCPListener : virtual public SocketListener
     {
     public:
@@ -71,111 +83,167 @@ namespace shibsp {
         }
 
     private:
-        void setup_tcp_sockaddr(struct sockaddr_in* addr) const;
+        bool setup_tcp_sockaddr();
 
         string m_address;
         unsigned short m_port;
-        vector<string> m_acl;
+        vector<IPRange> m_acl;
+#ifdef HAVE_STRUCT_SOCKADDR_STORAGE
+        struct sockaddr_storage m_sockaddr;
+#else
+        struct sockaddr_in m_sockaddr;
+#endif
     };
 
     ListenerService* SHIBSP_DLLLOCAL TCPListenerServiceFactory(const DOMElement* const & e)
     {
         return new TCPListener(e);
     }
+
+    static const XMLCh address[] = UNICODE_LITERAL_7(a,d,d,r,e,s,s);
+    static const XMLCh port[] = UNICODE_LITERAL_4(p,o,r,t);
+    static const XMLCh acl[] = UNICODE_LITERAL_3(a,c,l);
 };
 
-TCPListener::TCPListener(const DOMElement* e) : SocketListener(e), m_address("127.0.0.1"), m_port(12345)
+TCPListener::TCPListener(const DOMElement* e)
+    : SocketListener(e),
+      m_address(XMLHelper::getAttrString(e, getenv("SHIBSP_LISTENER_ADDRESS"), address)),
+      m_port(XMLHelper::getAttrInt(e, 0, port))
 {
-    // We're stateless, but we need to load the configuration.
-    const XMLCh* tag=e->getAttributeNS(NULL,address);
-    if (tag && *tag) {
-        auto_ptr_char a(tag);
-        m_address=a.get();
-    }
+    if (m_address.empty())
+        m_address = "127.0.0.1";
 
-    tag=e->getAttributeNS(NULL,port);
-    if (tag && *tag) {
-        m_port=XMLString::parseInt(tag);
-        if (m_port==0)
-            m_port=12345;
+    if (m_port == 0) {
+        const char* p = getenv("SHIBSP_LISTENER_PORT");
+        if (p && *p)
+            m_port = atoi(p);
+        if (m_port == 0)
+            m_port = 1600;
     }
 
-    tag=e->getAttributeNS(NULL,acl);
-    if (tag && *tag) {
-        auto_ptr_char temp(tag);
-        string sockacl=temp.get();
-        if (sockacl.length()) {
-            int j = 0;
-            for (unsigned int i=0;  i < sockacl.length();  i++) {
-                if (sockacl.at(i)==' ') {
-                    m_acl.push_back(sockacl.substr(j, i-j));
-                    j = i+1;
-                }
-            }
-            m_acl.push_back(sockacl.substr(j, sockacl.length()-j));
+    vector<string> rawacls;
+    string aclbuf = XMLHelper::getAttrString(e, "127.0.0.1", acl);
+    boost::split(rawacls, aclbuf, boost::is_space(), algorithm::token_compress_on);
+    for (vector<string>::const_iterator i = rawacls.begin();  i < rawacls.end();  ++i) {
+        try {
+            m_acl.push_back(IPRange::parseCIDRBlock(i->c_str()));
+        }
+        catch (std::exception& ex) {
+            log->error("invalid CIDR block (%s): %s", i->c_str(), ex.what());
         }
     }
-    else
-        m_acl.push_back("127.0.0.1");
+
+    if (m_acl.empty()) {
+        log->warn("invalid CIDR range(s) in acl property, allowing 127.0.0.1 as a fall back");
+        m_acl.push_back(IPRange::parseCIDRBlock("127.0.0.1"));
+    }
+
+    if (!setup_tcp_sockaddr()) {
+        throw ConfigurationException("Unable to use configured socket address property.");
+    }
 }
 
-void TCPListener::setup_tcp_sockaddr(struct sockaddr_in* addr) const
+bool TCPListener::setup_tcp_sockaddr()
 {
-    // Split on host:port boundary. Default to port only.
-    memset(addr,0,sizeof(struct sockaddr_in));
-    addr->sin_family=AF_INET;
-    addr->sin_port=htons(m_port);
-    addr->sin_addr.s_addr=inet_addr(m_address.c_str());
+    struct addrinfo* ret = nullptr;
+    struct addrinfo hints;
+
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_flags = AI_NUMERICHOST;
+    hints.ai_family = AF_UNSPEC;
+
+    if (getaddrinfo(m_address.c_str(), nullptr, &hints, &ret) != 0) {
+        log->error("unable to parse server address (%s)", m_address.c_str());
+        return false;
+    }
+
+    if (ret->ai_family == AF_INET) {
+        memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
+        freeaddrinfo(ret);
+        ((struct sockaddr_in*)&m_sockaddr)->sin_port=htons(m_port);
+        return true;
+    }
+#if defined(AF_INET6) && defined(HAVE_STRUCT_SOCKADDR_STORAGE)
+    else if (ret->ai_family == AF_INET6) {
+        memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
+        freeaddrinfo(ret);
+        ((struct sockaddr_in6*)&m_sockaddr)->sin6_port=htons(m_port);
+        return true;
+    }
+#endif
+
+    log->error("unknown address type (%d)", ret->ai_family);
+    freeaddrinfo(ret);
+    return false;
 }
 
 bool TCPListener::create(ShibSocket& s) const
 {
-    s=socket(AF_INET,SOCK_STREAM,0);
+#ifdef HAVE_STRUCT_SOCKADDR_STORAGE
+    s = socket(m_sockaddr.ss_family, SOCK_STREAM, 0);
+#else
+    s = socket(m_sockaddr.sin_family, SOCK_STREAM, 0);
+#endif
 #ifdef WIN32
-    if(s==INVALID_SOCKET)
+    if(s == INVALID_SOCKET)
 #else
     if (s < 0)
 #endif
-        return log_error();
+        return log_error("socket");
     return true;
 }
 
 bool TCPListener::bind(ShibSocket& s, bool force) const
 {
-    struct sockaddr_in addr;
-    setup_tcp_sockaddr(&addr);
-
     // XXX: Do we care about the return value from setsockopt?
     int opt = 1;
     ::setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char*)&opt, sizeof(opt));
 
 #ifdef WIN32
-    if (SOCKET_ERROR==::bind(s,(struct sockaddr *)&addr,sizeof(addr)) || SOCKET_ERROR==::listen(s,3)) {
-        log_error();
+    if (SOCKET_ERROR==::bind(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) || SOCKET_ERROR==::listen(s, 3)) {
+        log_error("bind");
         close(s);
         return false;
     }
 #else
-    if (::bind(s, (struct sockaddr *)&addr, sizeof (addr)) < 0) {
-        log_error();
+    // Newer BSDs require the struct length be passed based on the socket address.
+    // Others have no field for that and take the whole struct size like Windows does.
+# ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
+#  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
+    if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0) {
+#  else
+    if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0) {
+#  endif
+# else
+    if (::bind(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) < 0) {
+# endif
+        log_error("bind");
         close(s);
         return false;
     }
-    ::listen(s,3);
+    ::listen(s, 3);
 #endif
     return true;
 }
 
 bool TCPListener::connect(ShibSocket& s) const
 {
-    struct sockaddr_in addr;
-    setup_tcp_sockaddr(&addr);
 #ifdef WIN32
-    if(SOCKET_ERROR==::connect(s,(struct sockaddr *)&addr,sizeof(addr)))
-        return log_error();
+    if(SOCKET_ERROR==::connect(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)))
+        return log_error("connect");
 #else
-    if (::connect(s, (struct sockaddr*)&addr, sizeof (addr)) < 0)
-        return log_error();
+    // Newer BSDs require the struct length be passed based on the socket address.
+    // Others have no field for that and take the whole struct size like Windows does.
+# ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
+#  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
+    if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0)
+#  else
+    if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0)
+#  endif
+# else
+    if (::connect(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) < 0)
+# endif
+        return log_error("connect");
 #endif
     return true;
 }
@@ -192,25 +260,30 @@ bool TCPListener::close(ShibSocket& s) const
 
 bool TCPListener::accept(ShibSocket& listener, ShibSocket& s) const
 {
+#ifdef HAVE_STRUCT_SOCKADDR_STORAGE
+    struct sockaddr_storage addr;
+#else
     struct sockaddr_in addr;
+#endif
+    memset(&addr, 0, sizeof(addr));
 
 #ifdef WIN32
     int size=sizeof(addr);
-    s=::accept(listener,(struct sockaddr*)&addr,&size);
+    s=::accept(listener, (struct sockaddr*)&addr, &size);
     if(s==INVALID_SOCKET)
 #else
     socklen_t size=sizeof(addr);
-    s=::accept(listener,(struct sockaddr*)&addr,&size);
+    s=::accept(listener, (struct sockaddr*)&addr, &size);
     if (s < 0)
 #endif
-        return log_error();
-    char* client=inet_ntoa(addr.sin_addr);
-    for (vector<string>::const_iterator i=m_acl.begin(); i!=m_acl.end(); i++) {
-        if (*i==client)
-            return true;
+        return log_error("accept");
+
+    static bool (IPRange::* contains)(const struct sockaddr*) const = &IPRange::contains;
+    if (find_if(m_acl.begin(), m_acl.end(), boost::bind(contains, _1, (const struct sockaddr*)&addr)) == m_acl.end()) {
+        close(s);
+        s = -1;
+        log->error("accept() rejected client with invalid address");
+        return false;
     }
-    close(s);
-    s=-1;
-    log->error("accept() rejected client at %s\n",client);
-    return false;
+    return true;
 }