5f07611f6d5b4cb8f31974e249011aeba33add10
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / TCPListener.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * TCPListener.cpp
23  *
24  * TCP-based SocketListener implementation.
25  */
26
27 #include "internal.h"
28 #include "exceptions.h"
29 #include "remoting/impl/SocketListener.h"
30 #include "util/IPRange.h"
31
32 #include <boost/bind.hpp>
33 #include <boost/algorithm/string.hpp>
34 #include <xercesc/util/XMLUniDefs.hpp>
35 #include <xmltooling/unicode.h>
36 #include <xmltooling/util/XMLHelper.h>
37
38 #ifdef WIN32
39 # include <winsock2.h>
40 # include <ws2tcpip.h>
41 #endif
42
43 #ifdef HAVE_UNISTD_H
44 # include <sys/socket.h>
45 # include <sys/un.h>
46 # include <netdb.h>
47 # include <unistd.h>
48 # include <arpa/inet.h>
49 # include <netinet/in.h>
50 #endif
51
52 #include <sys/types.h>
53 #include <sys/stat.h>           /* for chmod() */
54 #include <stdio.h>
55 #include <stdlib.h>
56 #include <errno.h>
57
58 using namespace shibsp;
59 using namespace xmltooling;
60 using namespace xercesc;
61 using namespace boost;
62 using namespace std;
63
64 namespace shibsp {
65     class TCPListener : virtual public SocketListener
66     {
67     public:
68         TCPListener(const DOMElement* e);
69         ~TCPListener() {}
70
71         bool create(ShibSocket& s) const;
72         bool bind(ShibSocket& s, bool force=false) const;
73         bool connect(ShibSocket& s) const;
74         bool close(ShibSocket& s) const;
75         bool accept(ShibSocket& listener, ShibSocket& s) const;
76
77         int send(ShibSocket& s, const char* buf, int len) const {
78             return ::send(s, buf, len, 0);
79         }
80
81         int recv(ShibSocket& s, char* buf, int buflen) const {
82             return ::recv(s, buf, buflen, 0);
83         }
84
85     private:
86         bool setup_tcp_sockaddr();
87
88         string m_address;
89         unsigned short m_port;
90         vector<IPRange> m_acl;
91 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
92         struct sockaddr_storage m_sockaddr;
93 #else
94         struct sockaddr_in m_sockaddr;
95 #endif
96     };
97
98     ListenerService* SHIBSP_DLLLOCAL TCPListenerServiceFactory(const DOMElement* const & e)
99     {
100         return new TCPListener(e);
101     }
102
103     static const XMLCh address[] = UNICODE_LITERAL_7(a,d,d,r,e,s,s);
104     static const XMLCh port[] = UNICODE_LITERAL_4(p,o,r,t);
105     static const XMLCh acl[] = UNICODE_LITERAL_3(a,c,l);
106 };
107
108 TCPListener::TCPListener(const DOMElement* e)
109     : SocketListener(e),
110       m_address(XMLHelper::getAttrString(e, getenv("SHIBSP_LISTENER_ADDRESS"), address)),
111       m_port(XMLHelper::getAttrInt(e, 0, port))
112 {
113     if (m_address.empty())
114         m_address = "127.0.0.1";
115
116     if (m_port == 0) {
117         const char* p = getenv("SHIBSP_LISTENER_PORT");
118         if (p && *p)
119             m_port = atoi(p);
120         if (m_port == 0)
121             m_port = 1600;
122     }
123
124     vector<string> rawacls;
125     string aclbuf = XMLHelper::getAttrString(e, "127.0.0.1", acl);
126     boost::split(rawacls, aclbuf, boost::is_space(), algorithm::token_compress_on);
127     for (vector<string>::const_iterator i = rawacls.begin();  i < rawacls.end();  ++i) {
128         try {
129             m_acl.push_back(IPRange::parseCIDRBlock(i->c_str()));
130         }
131         catch (std::exception& ex) {
132             log->error("invalid CIDR block (%s): %s", i->c_str(), ex.what());
133         }
134     }
135
136     if (m_acl.empty()) {
137         log->warn("invalid CIDR range(s) in acl property, allowing 127.0.0.1 as a fall back");
138         m_acl.push_back(IPRange::parseCIDRBlock("127.0.0.1"));
139     }
140
141     if (!setup_tcp_sockaddr()) {
142         throw ConfigurationException("Unable to use configured socket address property.");
143     }
144 }
145
146 bool TCPListener::setup_tcp_sockaddr()
147 {
148     struct addrinfo* ret = nullptr;
149     struct addrinfo hints;
150
151     memset(&hints, 0, sizeof(hints));
152     hints.ai_flags = AI_NUMERICHOST;
153     hints.ai_family = AF_UNSPEC;
154
155     if (getaddrinfo(m_address.c_str(), nullptr, &hints, &ret) != 0) {
156         log->error("unable to parse server address (%s)", m_address.c_str());
157         return false;
158     }
159
160     if (ret->ai_family == AF_INET) {
161         memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
162         freeaddrinfo(ret);
163         ((struct sockaddr_in*)&m_sockaddr)->sin_port=htons(m_port);
164         return true;
165     }
166 #if defined(AF_INET6) && defined(HAVE_STRUCT_SOCKADDR_STORAGE)
167     else if (ret->ai_family == AF_INET6) {
168         memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
169         freeaddrinfo(ret);
170         ((struct sockaddr_in6*)&m_sockaddr)->sin6_port=htons(m_port);
171         return true;
172     }
173 #endif
174
175     log->error("unknown address type (%d)", ret->ai_family);
176     freeaddrinfo(ret);
177     return false;
178 }
179
180 bool TCPListener::create(ShibSocket& s) const
181 {
182 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
183     s = socket(m_sockaddr.ss_family, SOCK_STREAM, 0);
184 #else
185     s = socket(m_sockaddr.sin_family, SOCK_STREAM, 0);
186 #endif
187 #ifdef WIN32
188     if(s == INVALID_SOCKET)
189 #else
190     if (s < 0)
191 #endif
192         return log_error("socket");
193     return true;
194 }
195
196 bool TCPListener::bind(ShibSocket& s, bool force) const
197 {
198     // XXX: Do we care about the return value from setsockopt?
199     int opt = 1;
200     ::setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char*)&opt, sizeof(opt));
201
202 #ifdef WIN32
203     if (SOCKET_ERROR==::bind(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) || SOCKET_ERROR==::listen(s, 3)) {
204         log_error("bind");
205         close(s);
206         return false;
207     }
208 #else
209     // Newer BSDs require the struct length be passed based on the socket address.
210     // Others have no field for that and take the whole struct size like Windows does.
211 # ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
212 #  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
213     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0) {
214 #  else
215     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0) {
216 #  endif
217 # else
218     if (::bind(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) < 0) {
219 # endif
220         log_error("bind");
221         close(s);
222         return false;
223     }
224     ::listen(s, 3);
225 #endif
226     return true;
227 }
228
229 bool TCPListener::connect(ShibSocket& s) const
230 {
231 #ifdef WIN32
232     if(SOCKET_ERROR==::connect(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)))
233         return log_error("connect");
234 #else
235     // Newer BSDs require the struct length be passed based on the socket address.
236     // Others have no field for that and take the whole struct size like Windows does.
237 # ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
238 #  ifdef HAVE_STRUCT_SOCKADDR_STORAGE
239     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0)
240 #  else
241     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0)
242 #  endif
243 # else
244     if (::connect(s, (const struct sockaddr*)&m_sockaddr, sizeof(m_sockaddr)) < 0)
245 # endif
246         return log_error("connect");
247 #endif
248     return true;
249 }
250
251 bool TCPListener::close(ShibSocket& s) const
252 {
253 #ifdef WIN32
254     closesocket(s);
255 #else
256     ::close(s);
257 #endif
258     return true;
259 }
260
261 bool TCPListener::accept(ShibSocket& listener, ShibSocket& s) const
262 {
263 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
264     struct sockaddr_storage addr;
265 #else
266     struct sockaddr_in addr;
267 #endif
268     memset(&addr, 0, sizeof(addr));
269
270 #ifdef WIN32
271     int size=sizeof(addr);
272     s=::accept(listener, (struct sockaddr*)&addr, &size);
273     if(s==INVALID_SOCKET)
274 #else
275     socklen_t size=sizeof(addr);
276     s=::accept(listener, (struct sockaddr*)&addr, &size);
277     if (s < 0)
278 #endif
279         return log_error("accept");
280
281     static bool (IPRange::* contains)(const struct sockaddr*) const = &IPRange::contains;
282     if (find_if(m_acl.begin(), m_acl.end(), boost::bind(contains, _1, (const struct sockaddr*)&addr)) == m_acl.end()) {
283         close(s);
284         s = -1;
285         log->error("accept() rejected client with invalid address");
286         return false;
287     }
288     return true;
289 }