Imported Upstream version 2.4+dfsg
[shibboleth/sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2010 Internet2
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  *
20  * Berkeley Socket-based ListenerService implementation.
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "ServiceProvider.h"
26 #include "SPConfig.h"
27 #include "remoting/impl/SocketListener.h"
28
29 #include <errno.h>
30 #include <stack>
31 #include <sstream>
32 #include <xercesc/util/XMLUniDefs.hpp>
33
34 #include <xmltooling/util/NDC.h>
35 #include <xmltooling/util/XMLHelper.h>
36
37 #ifndef WIN32
38 # include <netinet/in.h>
39 #endif
40
41 #ifdef HAVE_UNISTD_H
42 # include <unistd.h>
43 #endif
44
45 using namespace shibsp;
46 using namespace xmltooling;
47 using namespace std;
48 using xercesc::DOMElement;
49
50 namespace shibsp {
51
52     // Manages the pool of connections
53     class SocketPool
54     {
55     public:
56         SocketPool(Category& log, const SocketListener* listener)
57             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
58         ~SocketPool();
59         SocketListener::ShibSocket get();
60         void put(SocketListener::ShibSocket s);
61
62     private:
63         SocketListener::ShibSocket connect();
64
65         Category& m_log;
66         const SocketListener* m_listener;
67         auto_ptr<Mutex> m_lock;
68         stack<SocketListener::ShibSocket> m_pool;
69     };
70
71     // Worker threads in server
72     class ServerThread {
73     public:
74         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
75         ~ServerThread();
76         void run();
77         int job();  // Return -1 on error, 1 for closed, 0 for success
78
79     private:
80         SocketListener::ShibSocket m_sock;
81         Thread* m_child;
82         SocketListener* m_listener;
83         string m_id;
84         char m_buf[16384];
85     };
86 }
87
88 SocketListener::ShibSocket SocketPool::connect()
89 {
90 #ifdef _DEBUG
91     NDC ndc("connect");
92 #endif
93
94     m_log.debug("trying to connect to listener");
95
96     SocketListener::ShibSocket sock;
97     if (!m_listener->create(sock)) {
98         m_log.error("cannot create socket");
99         throw ListenerException("Cannot create socket");
100     }
101
102     bool connected = false;
103     int num_tries = 3;
104
105     for (int i = num_tries-1; i >= 0; i--) {
106         if (m_listener->connect(sock)) {
107             connected = true;
108             break;
109         }
110
111         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
112
113         if (i) {
114 #ifdef WIN32
115             Sleep(2000*(num_tries-i));
116 #else
117             sleep(2*(num_tries-i));
118 #endif
119         }
120     }
121
122     if (!connected) {
123         m_log.crit("socket server unavailable, failing");
124         m_listener->close(sock);
125         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
126     }
127
128     m_log.debug("socket (%u) connected successfully", sock);
129     return sock;
130 }
131
132 SocketPool::~SocketPool()
133 {
134     while (!m_pool.empty()) {
135 #ifdef WIN32
136         closesocket(m_pool.top());
137 #else
138         ::close(m_pool.top());
139 #endif
140         m_pool.pop();
141     }
142 }
143
144 SocketListener::ShibSocket SocketPool::get()
145 {
146     m_lock->lock();
147     if (m_pool.empty()) {
148         m_lock->unlock();
149         return connect();
150     }
151     SocketListener::ShibSocket ret=m_pool.top();
152     m_pool.pop();
153     m_lock->unlock();
154     return ret;
155 }
156
157 void SocketPool::put(SocketListener::ShibSocket s)
158 {
159     m_lock->lock();
160     m_pool.push(s);
161     m_lock->unlock();
162 }
163
164 SocketListener::SocketListener(const DOMElement* e)
165     : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT".Listener")), m_socketpool(nullptr),
166         m_shutdown(nullptr), m_child_lock(nullptr), m_child_wait(nullptr), m_stackSize(0), m_socket((ShibSocket)0)
167 {
168     // Are we a client?
169     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
170         m_socketpool=new SocketPool(*log,this);
171     }
172     // Are we a server?
173     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
174         m_child_lock = Mutex::create();
175         m_child_wait = CondWait::create();
176
177         static const XMLCh stackSize[] = UNICODE_LITERAL_9(s,t,a,c,k,S,i,z,e);
178         m_stackSize = XMLHelper::getAttrInt(e, 0, stackSize) * 1024;
179     }
180 }
181
182 SocketListener::~SocketListener()
183 {
184     delete m_socketpool;
185     delete m_child_wait;
186     delete m_child_lock;
187 }
188
189 bool SocketListener::init(bool force)
190 {
191 #ifdef _DEBUG
192     NDC ndc("init");
193 #endif
194     log->info("listener service starting");
195
196     ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
197     sp->lock();
198     const PropertySet* props = sp->getPropertySet("OutOfProcess");
199     if (props) {
200         pair<bool,bool> flag = props->getBool("catchAll");
201         m_catchAll = flag.first && flag.second;
202     }
203     sp->unlock();
204
205     if (!create(m_socket)) {
206         log->crit("failed to create socket");
207         return false;
208     }
209     if (!bind(m_socket, force)) {
210         this->close(m_socket);
211         log->crit("failed to bind to socket.");
212         return false;
213     }
214
215     return true;
216 }
217
218 bool SocketListener::run(bool* shutdown)
219 {
220 #ifdef _DEBUG
221     NDC ndc("run");
222 #endif
223     // Save flag to monitor for shutdown request.
224     m_shutdown=shutdown;
225     unsigned long count = 0;
226
227     while (!*m_shutdown) {
228         fd_set readfds;
229         FD_ZERO(&readfds);
230         FD_SET(m_socket, &readfds);
231         struct timeval tv = { 0, 0 };
232         tv.tv_sec = 5;
233
234         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
235 #ifdef WIN32
236             case SOCKET_ERROR:
237 #else
238             case -1:
239 #endif
240                 if (errno == EINTR) continue;
241                 log_error();
242                 log->error("select() on main listener socket failed");
243                 *m_shutdown = true;
244                 break;
245
246             case 0:
247                 continue;
248
249             default:
250             {
251                 // Accept the connection.
252                 SocketListener::ShibSocket newsock;
253                 if (!accept(m_socket, newsock)) {
254                     log->crit("failed to accept incoming socket connection");
255                     continue;
256                 }
257
258                 // We throw away the result because the children manage themselves...
259                 try {
260                     new ServerThread(newsock,this,++count);
261                 }
262                 catch (exception& ex) {
263                     log->crit("exception starting new server thread to service incoming request: %s", ex.what());
264                 }
265                 catch (...) {
266                     log->crit("unknown error starting new server thread to service incoming request");
267                     if (!m_catchAll)
268                         *m_shutdown = true;
269                 }
270             }
271         }
272     }
273     log->info("listener service shutting down");
274
275     // Wait for all children to exit.
276     m_child_lock->lock();
277     while (!m_children.empty())
278         m_child_wait->wait(m_child_lock);
279     m_child_lock->unlock();
280
281     return true;
282 }
283
284 void SocketListener::term()
285 {
286     this->close(m_socket);
287     m_socket=(ShibSocket)0;
288 }
289
290 DDF SocketListener::send(const DDF& in)
291 {
292 #ifdef _DEBUG
293     NDC ndc("send");
294 #endif
295
296     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
297
298     // Serialize data for transmission.
299     ostringstream os;
300     os << in;
301     string ostr(os.str());
302
303     // Loop on the RPC in case we lost contact the first time through
304 #ifdef WIN32
305     u_long len;
306 #else
307     uint32_t len;
308 #endif
309     int retry = 1;
310     SocketListener::ShibSocket sock;
311     while (retry >= 0) {
312         sock = m_socketpool->get();
313
314         int outlen = ostr.length();
315         len = htonl(outlen);
316         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
317             log_error();
318             this->close(sock);
319             if (retry)
320                 retry--;
321             else
322                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
323         }
324         else {
325             // SUCCESS.
326             retry = -1;
327         }
328     }
329
330     log->debug("send completed, reading response message");
331
332     // Read the message.
333     while (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
334         if (errno == EINTR) continue;   // Apparently this happens when a signal interrupts the blocking call.
335         log->error("error reading size of output message");
336         this->close(sock);
337         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
338     }
339     len = ntohl(len);
340
341     char buf[16384];
342     int size_read;
343     stringstream is;
344     while (len) {
345         size_read = recv(sock, buf, sizeof(buf));
346         if (size_read > 0) {
347             is.write(buf, size_read);
348             len -= size_read;
349         }
350         else if (errno != EINTR) {
351                 break;
352         }
353     }
354
355     if (len) {
356         log->error("error reading output message from socket");
357         this->close(sock);
358         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
359     }
360
361     m_socketpool->put(sock);
362
363     // Unmarshall data.
364     DDF out;
365     is >> out;
366
367     // Check for exception to unmarshall and throw, otherwise return.
368     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
369         // Reconstitute exception object.
370         DDFJanitor jout(out);
371         XMLToolingException* except=nullptr;
372         try {
373             except=XMLToolingException::fromString(out.string());
374             log->error("remoted message returned an error: %s", except->what());
375         }
376         catch (XMLToolingException& e) {
377             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
378             log->error("XML was: %s", out.string());
379             throw ListenerException("Remote call failed with an unparsable exception.");
380         }
381
382         auto_ptr<XMLToolingException> wrapper(except);
383         wrapper->raise();
384     }
385
386     return out;
387 }
388
389 bool SocketListener::log_error() const
390 {
391 #ifdef WIN32
392     int rc=WSAGetLastError();
393 #else
394     int rc=errno;
395 #endif
396 #ifdef HAVE_STRERROR_R
397     char buf[256];
398     memset(buf,0,sizeof(buf));
399     strerror_r(rc,buf,sizeof(buf));
400     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
401 #else
402     const char* buf=strerror(rc);
403     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
404 #endif
405     return false;
406 }
407
408 // actual function run in listener on server threads
409 void* server_thread_fn(void* arg)
410 {
411     ServerThread* child = (ServerThread*)arg;
412
413 #ifndef WIN32
414     // First, let's block all signals
415     Thread::mask_all_signals();
416 #endif
417
418     // Run the child until it exits.
419     child->run();
420
421     // Now we can clean up and exit the thread.
422     delete child;
423     return nullptr;
424 }
425
426 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
427     : m_sock(s), m_child(nullptr), m_listener(listener)
428 {
429
430     ostringstream buf;
431     buf << "[" << id << "]";
432     m_id = buf.str();
433
434     // Create the child thread
435     m_child = Thread::create(server_thread_fn, (void*)this, m_listener->m_stackSize);
436     m_child->detach();
437 }
438
439 ServerThread::~ServerThread()
440 {
441     // Then lock the children map, remove this socket/thread, signal waiters, and return
442     m_listener->m_child_lock->lock();
443     m_listener->m_children.erase(m_sock);
444     m_listener->m_child_lock->unlock();
445     m_listener->m_child_wait->signal();
446
447     delete m_child;
448 }
449
450 void ServerThread::run()
451 {
452     NDC ndc(m_id);
453
454     // Before starting up, make sure we fully "own" this socket.
455     m_listener->m_child_lock->lock();
456     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
457         m_listener->m_child_wait->wait(m_listener->m_child_lock);
458     m_listener->m_children[m_sock] = m_child;
459     m_listener->m_child_lock->unlock();
460
461     int result;
462     fd_set readfds;
463     struct timeval tv = { 0, 0 };
464
465     while(!*(m_listener->m_shutdown)) {
466         FD_ZERO(&readfds);
467         FD_SET(m_sock, &readfds);
468         tv.tv_sec = 1;
469
470         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
471 #ifdef WIN32
472         case SOCKET_ERROR:
473 #else
474         case -1:
475 #endif
476             if (errno == EINTR) continue;
477             m_listener->log_error();
478             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
479             return;
480
481         case 0:
482             break;
483
484         default:
485             result = job();
486             if (result) {
487                 if (result < 0) {
488                     m_listener->log_error();
489                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
490                 }
491                 m_listener->close(m_sock);
492                 return;
493             }
494         }
495     }
496 }
497
498 int ServerThread::job()
499 {
500     Category& log = Category::getInstance(SHIBSP_LOGCAT".Listener");
501
502     bool incomingError = true;  // set false once incoming message is received
503     ostringstream sink;
504 #ifdef WIN32
505     u_long len;
506 #else
507     uint32_t len;
508 #endif
509
510     try {
511         // Read the message.
512         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
513         if (readlength == 0) {
514             log.info("detected socket closure, shutting down worker thread");
515             return 1;
516         }
517         else if (readlength != sizeof(len)) {
518             log.error("error reading size of input message");
519             return -1;
520         }
521         len = ntohl(len);
522
523         int size_read;
524         stringstream is;
525         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
526             is.write(m_buf, size_read);
527             len -= size_read;
528         }
529
530         if (len) {
531             log.error("error reading input message from socket");
532             return -1;
533         }
534
535         // Unmarshall the message.
536         DDF in;
537         DDFJanitor jin(in);
538         is >> in;
539
540         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
541
542         incomingError = false;
543
544         // Dispatch the message.
545         m_listener->receive(in, sink);
546     }
547     catch (XMLToolingException& e) {
548         if (incomingError)
549             log.error("error processing incoming message: %s", e.what());
550         DDF out=DDF("exception").string(e.toString().c_str());
551         DDFJanitor jout(out);
552         sink << out;
553     }
554     catch (exception& e) {
555         if (incomingError)
556             log.error("error processing incoming message: %s", e.what());
557         ListenerException ex(e.what());
558         DDF out=DDF("exception").string(ex.toString().c_str());
559         DDFJanitor jout(out);
560         sink << out;
561     }
562     catch (...) {
563         if (incomingError)
564             log.error("unexpected error processing incoming message");
565         if (!m_listener->m_catchAll)
566             throw;
567         ListenerException ex("An unexpected error occurred while processing an incoming message.");
568         DDF out=DDF("exception").string(ex.toString().c_str());
569         DDFJanitor jout(out);
570         sink << out;
571     }
572
573     // Return whatever's available.
574     string response(sink.str());
575     int outlen = response.length();
576     len = htonl(outlen);
577     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
578         log.error("error sending output message size");
579         return -1;
580     }
581     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
582         log.error("error sending output message");
583         return -1;
584     }
585
586     return 0;
587 }