https://issues.shibboleth.net/jira/browse/SSPCPP-326
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * SocketListener.cpp
23  *
24  * Berkeley Socket-based ListenerService implementation.
25  */
26
27 #include "internal.h"
28 #include "exceptions.h"
29 #include "ServiceProvider.h"
30 #include "SPConfig.h"
31 #include "remoting/impl/SocketListener.h"
32
33 #include <errno.h>
34 #include <stack>
35 #include <sstream>
36 #include <xercesc/util/XMLUniDefs.hpp>
37
38 #include <xmltooling/util/NDC.h>
39 #include <xmltooling/util/XMLHelper.h>
40
41 #ifndef WIN32
42 # include <netinet/in.h>
43 #endif
44
45 #ifdef HAVE_UNISTD_H
46 # include <unistd.h>
47 #endif
48
49 using namespace shibsp;
50 using namespace xmltooling;
51 using namespace std;
52 using xercesc::DOMElement;
53
54 namespace shibsp {
55
56     // Manages the pool of connections
57     class SocketPool
58     {
59     public:
60         SocketPool(Category& log, const SocketListener* listener)
61             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
62         ~SocketPool();
63         SocketListener::ShibSocket get();
64         void put(SocketListener::ShibSocket s);
65
66     private:
67         SocketListener::ShibSocket connect();
68
69         Category& m_log;
70         const SocketListener* m_listener;
71         auto_ptr<Mutex> m_lock;
72         stack<SocketListener::ShibSocket> m_pool;
73     };
74
75     // Worker threads in server
76     class ServerThread {
77     public:
78         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
79         ~ServerThread();
80         void run();
81         int job();  // Return -1 on error, 1 for closed, 0 for success
82
83     private:
84         SocketListener::ShibSocket m_sock;
85         Thread* m_child;
86         SocketListener* m_listener;
87         string m_id;
88         char m_buf[16384];
89     };
90 }
91
92 SocketListener::ShibSocket SocketPool::connect()
93 {
94 #ifdef _DEBUG
95     NDC ndc("connect");
96 #endif
97
98     m_log.debug("trying to connect to listener");
99
100     SocketListener::ShibSocket sock;
101     if (!m_listener->create(sock)) {
102         m_log.error("cannot create socket");
103         throw ListenerException("Cannot create socket");
104     }
105
106     bool connected = false;
107     int num_tries = 3;
108
109     for (int i = num_tries-1; i >= 0; i--) {
110         if (m_listener->connect(sock)) {
111             connected = true;
112             break;
113         }
114
115         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
116
117         if (i) {
118 #ifdef WIN32
119             Sleep(2000*(num_tries-i));
120 #else
121             sleep(2*(num_tries-i));
122 #endif
123         }
124     }
125
126     if (!connected) {
127         m_log.crit("socket server unavailable, failing");
128         m_listener->close(sock);
129         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
130     }
131
132     m_log.debug("socket (%u) connected successfully", sock);
133     return sock;
134 }
135
136 SocketPool::~SocketPool()
137 {
138     while (!m_pool.empty()) {
139 #ifdef WIN32
140         closesocket(m_pool.top());
141 #else
142         ::close(m_pool.top());
143 #endif
144         m_pool.pop();
145     }
146 }
147
148 SocketListener::ShibSocket SocketPool::get()
149 {
150     m_lock->lock();
151     if (m_pool.empty()) {
152         m_lock->unlock();
153         return connect();
154     }
155     SocketListener::ShibSocket ret=m_pool.top();
156     m_pool.pop();
157     m_lock->unlock();
158     return ret;
159 }
160
161 void SocketPool::put(SocketListener::ShibSocket s)
162 {
163     m_lock->lock();
164     m_pool.push(s);
165     m_lock->unlock();
166 }
167
168 SocketListener::SocketListener(const DOMElement* e)
169     : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT".Listener")), m_socketpool(nullptr),
170         m_shutdown(nullptr), m_child_lock(nullptr), m_child_wait(nullptr), m_stackSize(0), m_socket((ShibSocket)0)
171 {
172     // Are we a client?
173     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
174         m_socketpool=new SocketPool(*log,this);
175     }
176     // Are we a server?
177     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
178         m_child_lock = Mutex::create();
179         m_child_wait = CondWait::create();
180
181         static const XMLCh stackSize[] = UNICODE_LITERAL_9(s,t,a,c,k,S,i,z,e);
182         m_stackSize = XMLHelper::getAttrInt(e, 0, stackSize) * 1024;
183     }
184 }
185
186 SocketListener::~SocketListener()
187 {
188     delete m_socketpool;
189     delete m_child_wait;
190     delete m_child_lock;
191 }
192
193 bool SocketListener::init(bool force)
194 {
195 #ifdef _DEBUG
196     NDC ndc("init");
197 #endif
198     log->info("listener service starting");
199
200     ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
201     sp->lock();
202     const PropertySet* props = sp->getPropertySet("OutOfProcess");
203     if (props) {
204         pair<bool,bool> flag = props->getBool("catchAll");
205         m_catchAll = flag.first && flag.second;
206     }
207     sp->unlock();
208
209     if (!create(m_socket)) {
210         log->crit("failed to create socket");
211         return false;
212     }
213     if (!bind(m_socket, force)) {
214         this->close(m_socket);
215         log->crit("failed to bind to socket.");
216         return false;
217     }
218
219     return true;
220 }
221
222 bool SocketListener::run(bool* shutdown)
223 {
224 #ifdef _DEBUG
225     NDC ndc("run");
226 #endif
227     // Save flag to monitor for shutdown request.
228     m_shutdown=shutdown;
229     unsigned long count = 0;
230
231     while (!*m_shutdown) {
232         fd_set readfds;
233         FD_ZERO(&readfds);
234         FD_SET(m_socket, &readfds);
235         struct timeval tv = { 0, 0 };
236         tv.tv_sec = 5;
237
238         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
239 #ifdef WIN32
240             case SOCKET_ERROR:
241 #else
242             case -1:
243 #endif
244                 if (errno == EINTR) continue;
245                 log_error();
246                 log->error("select() on main listener socket failed");
247                 *m_shutdown = true;
248                 break;
249
250             case 0:
251                 continue;
252
253             default:
254             {
255                 // Accept the connection.
256                 SocketListener::ShibSocket newsock;
257                 if (!accept(m_socket, newsock)) {
258                     log->crit("failed to accept incoming socket connection");
259                     continue;
260                 }
261
262                 // We throw away the result because the children manage themselves...
263                 try {
264                     new ServerThread(newsock,this,++count);
265                 }
266                 catch (exception& ex) {
267                     log->crit("exception starting new server thread to service incoming request: %s", ex.what());
268                 }
269                 catch (...) {
270                     log->crit("unknown error starting new server thread to service incoming request");
271                     if (!m_catchAll)
272                         *m_shutdown = true;
273                 }
274             }
275         }
276     }
277     log->info("listener service shutting down");
278
279     // Wait for all children to exit.
280     m_child_lock->lock();
281     while (!m_children.empty())
282         m_child_wait->wait(m_child_lock);
283     m_child_lock->unlock();
284
285     return true;
286 }
287
288 void SocketListener::term()
289 {
290     this->close(m_socket);
291     m_socket=(ShibSocket)0;
292 }
293
294 DDF SocketListener::send(const DDF& in)
295 {
296 #ifdef _DEBUG
297     NDC ndc("send");
298 #endif
299
300     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
301
302     // Serialize data for transmission.
303     ostringstream os;
304     os << in;
305     string ostr(os.str());
306
307     // Loop on the RPC in case we lost contact the first time through
308 #ifdef WIN32
309     u_long len;
310 #else
311     uint32_t len;
312 #endif
313     int retry = 1;
314     SocketListener::ShibSocket sock;
315     while (retry >= 0) {
316         sock = m_socketpool->get();
317
318         int outlen = ostr.length();
319         len = htonl(outlen);
320         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
321             log_error();
322             this->close(sock);
323             if (retry)
324                 retry--;
325             else
326                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
327         }
328         else {
329             // SUCCESS.
330             retry = -1;
331         }
332     }
333
334     log->debug("send completed, reading response message");
335
336     // Read the message.
337     while (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
338         if (errno == EINTR) continue;   // Apparently this happens when a signal interrupts the blocking call.
339         log->error("error reading size of output message");
340         this->close(sock);
341         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
342     }
343     len = ntohl(len);
344
345     char buf[16384];
346     int size_read;
347     stringstream is;
348     while (len) {
349         size_read = recv(sock, buf, sizeof(buf));
350         if (size_read > 0) {
351             is.write(buf, size_read);
352             len -= size_read;
353         }
354         else if (errno != EINTR) {
355                 break;
356         }
357     }
358
359     if (len) {
360         log->error("error reading output message from socket");
361         this->close(sock);
362         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
363     }
364
365     m_socketpool->put(sock);
366
367     // Unmarshall data.
368     DDF out;
369     is >> out;
370
371     // Check for exception to unmarshall and throw, otherwise return.
372     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
373         // Reconstitute exception object.
374         DDFJanitor jout(out);
375         XMLToolingException* except=nullptr;
376         try {
377             except=XMLToolingException::fromString(out.string());
378             log->error("remoted message returned an error: %s", except->what());
379         }
380         catch (XMLToolingException& e) {
381             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
382             log->error("XML was: %s", out.string());
383             throw ListenerException("Remote call failed with an unparsable exception.");
384         }
385
386         auto_ptr<XMLToolingException> wrapper(except);
387         wrapper->raise();
388     }
389
390     return out;
391 }
392
393 bool SocketListener::log_error(const char* fn) const
394 {
395     if (!fn)
396         fn = "unknown";
397 #ifdef WIN32
398     int rc=WSAGetLastError();
399 #else
400     int rc=errno;
401 #endif
402 #ifdef HAVE_STRERROR_R
403     char buf[256];
404     memset(buf,0,sizeof(buf));
405     strerror_r(rc,buf,sizeof(buf));
406     log->error("socket call (%s) resulted in error (%d): %s",fn, rc, isprint(*buf) ? buf : "no message");
407 #else
408     const char* buf=strerror(rc);
409     log->error("socket call (%s) resulted in error (%d): %s", fn, rc, isprint(*buf) ? buf : "no message");
410 #endif
411     return false;
412 }
413
414 // actual function run in listener on server threads
415 void* server_thread_fn(void* arg)
416 {
417     ServerThread* child = (ServerThread*)arg;
418
419 #ifndef WIN32
420     // First, let's block all signals
421     Thread::mask_all_signals();
422 #endif
423
424     // Run the child until it exits.
425     child->run();
426
427     // Now we can clean up and exit the thread.
428     delete child;
429     return nullptr;
430 }
431
432 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
433     : m_sock(s), m_child(nullptr), m_listener(listener)
434 {
435
436     ostringstream buf;
437     buf << "[" << id << "]";
438     m_id = buf.str();
439
440     // Create the child thread
441     m_child = Thread::create(server_thread_fn, (void*)this, m_listener->m_stackSize);
442     m_child->detach();
443 }
444
445 ServerThread::~ServerThread()
446 {
447     // Then lock the children map, remove this socket/thread, signal waiters, and return
448     m_listener->m_child_lock->lock();
449     m_listener->m_children.erase(m_sock);
450     m_listener->m_child_lock->unlock();
451     m_listener->m_child_wait->signal();
452
453     delete m_child;
454 }
455
456 void ServerThread::run()
457 {
458     NDC ndc(m_id);
459
460     // Before starting up, make sure we fully "own" this socket.
461     m_listener->m_child_lock->lock();
462     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
463         m_listener->m_child_wait->wait(m_listener->m_child_lock);
464     m_listener->m_children[m_sock] = m_child;
465     m_listener->m_child_lock->unlock();
466
467     int result;
468     fd_set readfds;
469     struct timeval tv = { 0, 0 };
470
471     while(!*(m_listener->m_shutdown)) {
472         FD_ZERO(&readfds);
473         FD_SET(m_sock, &readfds);
474         tv.tv_sec = 1;
475
476         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
477 #ifdef WIN32
478         case SOCKET_ERROR:
479 #else
480         case -1:
481 #endif
482             if (errno == EINTR) continue;
483             m_listener->log_error();
484             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
485             return;
486
487         case 0:
488             break;
489
490         default:
491             result = job();
492             if (result) {
493                 if (result < 0) {
494                     m_listener->log_error();
495                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
496                 }
497                 m_listener->close(m_sock);
498                 return;
499             }
500         }
501     }
502 }
503
504 int ServerThread::job()
505 {
506     Category& log = Category::getInstance(SHIBSP_LOGCAT".Listener");
507
508     bool incomingError = true;  // set false once incoming message is received
509     ostringstream sink;
510 #ifdef WIN32
511     u_long len;
512 #else
513     uint32_t len;
514 #endif
515
516     try {
517         // Read the message.
518         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
519         if (readlength == 0) {
520             log.info("detected socket closure, shutting down worker thread");
521             return 1;
522         }
523         else if (readlength != sizeof(len)) {
524             log.error("error reading size of input message");
525             return -1;
526         }
527         len = ntohl(len);
528
529         int size_read;
530         stringstream is;
531         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
532             is.write(m_buf, size_read);
533             len -= size_read;
534         }
535
536         if (len) {
537             log.error("error reading input message from socket");
538             return -1;
539         }
540
541         // Unmarshall the message.
542         DDF in;
543         DDFJanitor jin(in);
544         is >> in;
545
546         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
547
548         incomingError = false;
549
550         // Dispatch the message.
551         m_listener->receive(in, sink);
552     }
553     catch (XMLToolingException& e) {
554         if (incomingError)
555             log.error("error processing incoming message: %s", e.what());
556         DDF out=DDF("exception").string(e.toString().c_str());
557         DDFJanitor jout(out);
558         sink << out;
559     }
560     catch (exception& e) {
561         if (incomingError)
562             log.error("error processing incoming message: %s", e.what());
563         ListenerException ex(e.what());
564         DDF out=DDF("exception").string(ex.toString().c_str());
565         DDFJanitor jout(out);
566         sink << out;
567     }
568     catch (...) {
569         if (incomingError)
570             log.error("unexpected error processing incoming message");
571         if (!m_listener->m_catchAll)
572             throw;
573         ListenerException ex("An unexpected error occurred while processing an incoming message.");
574         DDF out=DDF("exception").string(ex.toString().c_str());
575         DDFJanitor jout(out);
576         sink << out;
577     }
578
579     // Return whatever's available.
580     string response(sink.str());
581     int outlen = response.length();
582     len = htonl(outlen);
583     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
584         log.error("error sending output message size");
585         return -1;
586     }
587     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
588         log.error("error sending output message");
589         return -1;
590     }
591
592     return 0;
593 }