ff49126d99e22f225c552dcd01783a5ddbb2c985
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / TCPListener.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * TCPListener.cpp
23  *
24  * TCP-based SocketListener implementation.
25  */
26
27 #include "internal.h"
28 #include "exceptions.h"
29 #include "remoting/impl/SocketListener.h"
30 #include "util/IPRange.h"
31
32 #include <xercesc/util/XMLUniDefs.hpp>
33 #include <xmltooling/unicode.h>
34 #include <xmltooling/util/XMLHelper.h>
35
36 #ifdef WIN32
37 # include <winsock2.h>
38 # include <ws2tcpip.h>
39 #endif
40
41 #ifdef HAVE_UNISTD_H
42 # include <sys/socket.h>
43 # include <sys/un.h>
44 # include <netdb.h>
45 # include <unistd.h>
46 # include <arpa/inet.h>
47 # include <netinet/in.h>
48 #endif
49
50 #include <sys/types.h>
51 #include <sys/stat.h>           /* for chmod() */
52 #include <stdio.h>
53 #include <stdlib.h>
54 #include <errno.h>
55
56 using namespace shibsp;
57 using namespace xmltooling;
58 using namespace xercesc;
59 using namespace std;
60
61 namespace shibsp {
62     class TCPListener : virtual public SocketListener
63     {
64     public:
65         TCPListener(const DOMElement* e);
66         ~TCPListener() {}
67
68         bool create(ShibSocket& s) const;
69         bool bind(ShibSocket& s, bool force=false) const;
70         bool connect(ShibSocket& s) const;
71         bool close(ShibSocket& s) const;
72         bool accept(ShibSocket& listener, ShibSocket& s) const;
73
74         int send(ShibSocket& s, const char* buf, int len) const {
75             return ::send(s, buf, len, 0);
76         }
77
78         int recv(ShibSocket& s, char* buf, int buflen) const {
79             return ::recv(s, buf, buflen, 0);
80         }
81
82     private:
83         bool setup_tcp_sockaddr();
84
85         string m_address;
86         unsigned short m_port;
87         vector<IPRange> m_acl;
88 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
89         struct sockaddr_storage m_sockaddr;
90 #else
91         struct sockaddr_in m_sockaddr;
92 #endif
93     };
94
95     ListenerService* SHIBSP_DLLLOCAL TCPListenerServiceFactory(const DOMElement* const & e)
96     {
97         return new TCPListener(e);
98     }
99
100     static const XMLCh address[] = UNICODE_LITERAL_7(a,d,d,r,e,s,s);
101     static const XMLCh port[] = UNICODE_LITERAL_4(p,o,r,t);
102     static const XMLCh acl[] = UNICODE_LITERAL_3(a,c,l);
103 };
104
105 TCPListener::TCPListener(const DOMElement* e)
106     : SocketListener(e),
107       m_address(XMLHelper::getAttrString(e, getenv("SHIBSP_LISTENER_ADDRESS"), address)),
108       m_port(XMLHelper::getAttrInt(e, 0, port))
109 {
110     if (m_address.empty())
111         m_address = "127.0.0.1";
112
113     if (m_port == 0) {
114         const char* p = getenv("SHIBSP_LISTENER_PORT");
115         if (p && *p)
116             m_port = atoi(p);
117         if (m_port == 0)
118             m_port = 1600;
119     }
120
121     int j = 0;
122     string aclbuf = XMLHelper::getAttrString(e, "127.0.0.1", acl);
123     for (unsigned int i = 0;  i < aclbuf.length();  ++i) {
124         if (aclbuf.at(i) == ' ') {
125             try {
126                 m_acl.push_back(IPRange::parseCIDRBlock(aclbuf.substr(j, i-j).c_str()));
127             }
128             catch (exception& ex) {
129                 log->error("invalid CIDR block (%s): %s", aclbuf.substr(j, i-j).c_str(), ex.what());
130             }
131             j = i + 1;
132         }
133     }
134     try {
135         m_acl.push_back(IPRange::parseCIDRBlock(aclbuf.substr(j, aclbuf.length()-j).c_str()));
136     }
137     catch (exception& ex) {
138         log->error("invalid CIDR block (%s): %s", aclbuf.substr(j, aclbuf.length()-j).c_str(), ex.what());
139     }
140
141     if (m_acl.empty()) {
142         log->warn("invalid CIDR range(s) in acl property, allowing 127.0.0.1 as a fall back");
143         m_acl.push_back(IPRange::parseCIDRBlock("127.0.0.1"));
144     }
145
146     if (!setup_tcp_sockaddr()) {
147         throw ConfigurationException("Unable to use configured socket address property.");
148     }
149 }
150
151 bool TCPListener::setup_tcp_sockaddr()
152 {
153     struct addrinfo* ret = nullptr;
154     struct addrinfo hints;
155
156     memset(&hints, 0, sizeof(hints));
157     hints.ai_flags = AI_NUMERICHOST;
158     hints.ai_family = AF_UNSPEC;
159
160     if (getaddrinfo(m_address.c_str(), nullptr, &hints, &ret) != 0) {
161         log->error("unable to parse server address (%s)", m_address.c_str());
162         return false;
163     }
164
165     if (ret->ai_family == AF_INET) {
166         memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
167         freeaddrinfo(ret);
168         ((struct sockaddr_in*)&m_sockaddr)->sin_port=htons(m_port);
169         return true;
170     }
171 #if defined(AF_INET6) && defined(HAVE_STRUCT_SOCKADDR_STORAGE)
172     else if (ret->ai_family == AF_INET6) {
173         memcpy(&m_sockaddr, ret->ai_addr, ret->ai_addrlen);
174         freeaddrinfo(ret);
175         ((struct sockaddr_in6*)&m_sockaddr)->sin6_port=htons(m_port);
176         return true;
177     }
178 #endif
179
180     log->error("unknown address type (%d)", ret->ai_family);
181     freeaddrinfo(ret);
182     return false;
183 }
184
185 bool TCPListener::create(ShibSocket& s) const
186 {
187 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
188     s = socket(m_sockaddr.ss_family, SOCK_STREAM, 0);
189 #else
190     s = socket(m_sockaddr.sin_family, SOCK_STREAM, 0);
191 #endif
192 #ifdef WIN32
193     if(s == INVALID_SOCKET)
194 #else
195     if (s < 0)
196 #endif
197         return log_error("socket");
198     return true;
199 }
200
201 bool TCPListener::bind(ShibSocket& s, bool force) const
202 {
203     // XXX: Do we care about the return value from setsockopt?
204     int opt = 1;
205     ::setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (const char*)&opt, sizeof(opt));
206
207 #ifdef WIN32
208     if (SOCKET_ERROR==::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) || SOCKET_ERROR==::listen(s, 3)) {
209         log_error("bind");
210         close(s);
211         return false;
212     }
213 #else
214 # ifdef HAVE_STRUCT_SOCKADDR_STORAGE
215     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0)
216 # else
217     if (::bind(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0)
218 # endif
219     {
220         log_error("bind");
221         close(s);
222         return false;
223     }
224     ::listen(s, 3);
225 #endif
226     return true;
227 }
228
229 bool TCPListener::connect(ShibSocket& s) const
230 {
231 #ifdef WIN32
232     if(SOCKET_ERROR==::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len))
233         return log_error("connect");
234 #else
235 # ifdef HAVE_STRUCT_SOCKADDR_STORAGE
236     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.ss_len) < 0)
237 # else
238     if (::connect(s, (const struct sockaddr*)&m_sockaddr, m_sockaddr.sin_len) < 0)
239 # endif
240         return log_error("connect");
241 #endif
242     return true;
243 }
244
245 bool TCPListener::close(ShibSocket& s) const
246 {
247 #ifdef WIN32
248     closesocket(s);
249 #else
250     ::close(s);
251 #endif
252     return true;
253 }
254
255 bool TCPListener::accept(ShibSocket& listener, ShibSocket& s) const
256 {
257 #ifdef HAVE_STRUCT_SOCKADDR_STORAGE
258     struct sockaddr_storage addr;
259 #else
260     struct sockaddr_in addr;
261 #endif
262     memset(&addr, 0, sizeof(addr));
263
264 #ifdef WIN32
265     int size=sizeof(addr);
266     s=::accept(listener, (struct sockaddr*)&addr, &size);
267     if(s==INVALID_SOCKET)
268 #else
269     socklen_t size=sizeof(addr);
270     s=::accept(listener, (struct sockaddr*)&addr, &size);
271     if (s < 0)
272 #endif
273         return log_error("accept");
274     bool found = false;
275     for (vector<IPRange>::const_iterator acl = m_acl.begin(); !found && acl != m_acl.end(); ++acl) {
276         found = acl->contains((const struct sockaddr*)&addr);
277     }
278     if (!found) {
279         close(s);
280         s = -1;
281         log->error("accept() rejected client with invalid address");
282         return false;
283     }
284     return true;
285 }