https://issues.shibboleth.net/jira/browse/SSPCPP-624
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / TCPListener.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * TCPListener.cpp
23  *
24  * TCP-based SocketListener implementation.
25  */
26
27 #include "internal.h"
28 #include "exceptions.h"
29 #include "remoting/impl/SocketListener.h"
30 #include "util/IPRange.h"
31
32 #include <boost/bind.hpp>
33 #include <boost/algorithm/string.hpp>
34 #include <xercesc/util/XMLUniDefs.hpp>
35 #include <xmltooling/unicode.h>
36 #include <xmltooling/util/XMLHelper.h>
37
38 #ifdef WIN32
39 # include <winsock2.h>
40 # include <ws2tcpip.h>
41 #endif
42
43 #ifdef HAVE_UNISTD_H
44 # include <sys/socket.h>
45 # include <sys/un.h>
46 # include <netdb.h>
47 # include <unistd.h>
48 # include <arpa/inet.h>
49 # include <netinet/in.h>
50 #endif
51
52 #include <sys/types.h>
53 #include <sys/stat.h>           /* for chmod() */
54 #include <stdio.h>
55 #include <stdlib.h>
56 #include <fcntl.h>
57 #include <errno.h>
58
59 using namespace shibsp;
60 using namespace xmltooling;
61 using namespace xercesc;
62 using namespace boost;
63 using namespace std;
64
65 namespace shibsp {
66     class TCPListener : virtual public SocketListener
67     {
68     public:
69         TCPListener(const DOMElement* e);
70         ~TCPListener() {}
71
72         bool create(ShibSocket& s) const;
73         bool bind(ShibSocket& s, bool force=false) const;
74         bool connect(ShibSocket& s) const;
75         bool close(ShibSocket& s) const;
76         bool accept(ShibSocket& listener, ShibSocket& s) const;
77
78         int send(ShibSocket& s, const char* buf, int len) const {
79             return ::send(s, buf, len, 0);
80         }
81
82         int recv(ShibSocket& s, char* buf, int buflen) const {
83             return ::recv(s, buf, buflen, 0);
84         }
85
86     private:
87         bool setup_tcp_sockaddr();
88
89         string m_address;
90         unsigned short m_port;
91         vector<IPRange> m_acl;
92         size_t m_sockaddrlen;
93 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
94         struct sockaddr_storage m_sockaddr;
95 #else
96         struct sockaddr_in m_sockaddr;
97 #endif
98     };
99
100     ListenerService* SHIBSP_DLLLOCAL TCPListenerServiceFactory(const DOMElement* const & e)
101     {
102         return new TCPListener(e);
103     }
104
105     static const XMLCh address[] = UNICODE_LITERAL_7(a,d,d,r,e,s,s);
106     static const XMLCh port[] = UNICODE_LITERAL_4(p,o,r,t);
107     static const XMLCh acl[] = UNICODE_LITERAL_3(a,c,l);
108 };
109
110 TCPListener::TCPListener(const DOMElement* e)
111     : SocketListener(e),
112       m_address(XMLHelper::getAttrString(e, getenv("SHIBSP_LISTENER_ADDRESS"), address)),
113       m_port(XMLHelper::getAttrInt(e, 0, port))
114 {
115     if (m_address.empty())
116         m_address = "127.0.0.1";
117
118     if (m_port == 0) {
119         const char* p = getenv("SHIBSP_LISTENER_PORT");
120         if (p && *p)
121             m_port = atoi(p);
122         if (m_port == 0)
123             m_port = 1600;
124     }
125
126     vector<string> rawacls;
127     string aclbuf = XMLHelper::getAttrString(e, "127.0.0.1", acl);
128     boost::trim(aclbuf);
129     boost::split(rawacls, aclbuf, boost::is_space(), algorithm::token_compress_on);
130     for (vector<string>::const_iterator i = rawacls.begin();  i < rawacls.end();  ++i) {
131         try {
132             m_acl.push_back(IPRange::parseCIDRBlock(i->c_str()));
133         }
134         catch (std::exception& ex) {
135             log->error("invalid CIDR block (%s): %s", i->c_str(), ex.what());
136         }
137     }
138
139     if (m_acl.empty()) {
140         log->warn("invalid CIDR range(s) in acl property, allowing 127.0.0.1 as a fall back");
141         m_acl.push_back(IPRange::parseCIDRBlock("127.0.0.1"));
142     }
143
144     if (!setup_tcp_sockaddr()) {
145         throw ConfigurationException("Unable to use configured socket address property.");
146     }
147 }
148
149 bool TCPListener::setup_tcp_sockaddr()
150 {
151     struct addrinfo* ret = nullptr;
152     struct addrinfo hints;
153
154     memset(&hints, 0, sizeof(hints));
155     hints.ai_flags = AI_NUMERICHOST;
156     hints.ai_family = AF_UNSPEC;
157
158     if (getaddrinfo(m_address.c_str(), nullptr, &hints, &ret) != 0) {
159         log->error("unable to parse server address (%s)", m_address.c_str());
160         return false;
161     }
162
163     m_sockaddrlen = ret->ai_addrlen;
164     if (ret->ai_family == AF_INET) {
165         memcpy(&m_sockaddr, ret->ai_addr, m_sockaddrlen);
166         freeaddrinfo(ret);
167         ((struct sockaddr_in*)&m_sockaddr)->sin_port=htons(m_port);
168         return true;
169     }
170 #if defined(AF_INET6) && defined(HAVE_STRUCT_SOCKADDR_STORAGE)
171     else if (ret->ai_family == AF_INET6) {
172         memcpy(&m_sockaddr, ret->ai_addr, m_sockaddrlen);
173         freeaddrinfo(ret);
174         ((struct sockaddr_in6*)&m_sockaddr)->sin6_port=htons(m_port);
175         return true;
176     }
177 #endif
178
179     log->error("unknown address type (%d)", ret->ai_family);
180     freeaddrinfo(ret);
181     return false;
182 }
183
184 bool TCPListener::create(ShibSocket& s) const
185 {
186     int type = SOCK_STREAM;
187 #ifdef HAVE_SOCK_CLOEXEC
188     type |= SOCK_CLOEXEC;
189 #endif
190
191 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
192     s = socket(m_sockaddr.ss_family, type, 0);
193 #else
194     s = socket(m_sockaddr.sin_family, type, 0);
195 #endif
196 #ifdef WIN32
197     if(s == INVALID_SOCKET)
198 #else
199     if (s < 0)
200 #endif
201         return log_error("socket");
202
203 #if !defined(HAVE_SOCK_CLOEXEC) && defined(HAVE_FD_CLOEXEC)
204     int fdflags = fcntl(s, F_GETFD);
205     if (fdflags != -1) {
206         fdflags |= FD_CLOEXEC;
207         fcntl(s, F_SETFD, fdflags);
208     }
209 #endif
210
211     return true;
212 }
213
214 bool TCPListener::bind(ShibSocket& s, bool force) const
215 {
216     // XXX: Do we care about the return value from setsockopt?
217     int opt = 1;
218     ::setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char*)&opt, sizeof(opt));
219
220 #ifdef WIN32
221     if (SOCKET_ERROR==::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen) || SOCKET_ERROR==::listen(s, 3)) {
222         log_error("bind");
223         close(s);
224         return false;
225     }
226 #else
227     // Newer BSDs, and Solaris, require the struct length be passed based on the socket address.
228     // All but Solaris seem to have an ss_len field in the sockaddr_storage struct.
229 # ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
230 #  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
231     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0) {
232 #  else
233     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0) {
234 #  endif
235 # else
236     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen) < 0) {
237 # endif
238         log_error("bind");
239         close(s);
240         return false;
241     }
242     ::listen(s, 3);
243 #endif
244     return true;
245 }
246
247 bool TCPListener::connect(ShibSocket& s) const
248 {
249 #ifdef WIN32
250     if(SOCKET_ERROR==::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen))
251         return log_error("connect");
252 #else
253     // Newer BSDs require the struct length be passed based on the socket address.
254     // Others have no field for that and take the whole struct size like Windows does.
255 # ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
256 #  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
257     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0)
258 #  else
259     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0)
260 #  endif
261 # else
262     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddrlen) < 0)
263 # endif
264         return log_error("connect");
265 #endif
266     return true;
267 }
268
269 bool TCPListener::close(ShibSocket& s) const
270 {
271 #ifdef WIN32
272     closesocket(s);
273 #else
274     ::close(s);
275 #endif
276     return true;
277 }
278
279 bool TCPListener::accept(ShibSocket& listener, ShibSocket& s) const
280 {
281 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
282     struct sockaddr_storage addr;
283 #else
284     struct sockaddr_in addr;
285 #endif
286     memset(&addr, 0, sizeof(addr));
287
288 #ifdef WIN32
289     int size=sizeof(addr);
290     s=::accept(listener, (struct sockaddr*)&addr, &size);
291     if(s==INVALID_SOCKET)
292 #else
293     socklen_t size=sizeof(addr);
294     s=::accept(listener, (struct sockaddr*)&addr, &size);
295     if (s < 0)
296 #endif
297         return log_error("accept");
298
299     static bool (IPRange::* contains)(const struct sockaddr*) const = &IPRange::contains;
300     if (find_if(m_acl.begin(), m_acl.end(), boost::bind(contains, _1, (const struct sockaddr*)&addr)) == m_acl.end()) {
301         close(s);
302         s = -1;
303         log->error("accept() rejected client with invalid address");
304         return false;
305     }
306     return true;
307 }