50aa95bb6541ef2bd32b70844f0e8449b8d960d8
[shibboleth/sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2010 Internet2
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  *
20  * Berkeley Socket-based ListenerService implementation.
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "ServiceProvider.h"
26 #include "remoting/impl/SocketListener.h"
27
28 #include <errno.h>
29 #include <stack>
30 #include <sstream>
31 #include <shibsp/SPConfig.h>
32 #include <xmltooling/util/NDC.h>
33
34 #ifndef WIN32
35 # include <netinet/in.h>
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 using namespace shibsp;
43 using namespace xmltooling;
44 using namespace std;
45 using xercesc::DOMElement;
46
47 namespace shibsp {
48
49     // Manages the pool of connections
50     class SocketPool
51     {
52     public:
53         SocketPool(Category& log, const SocketListener* listener)
54             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
55         ~SocketPool();
56         SocketListener::ShibSocket get();
57         void put(SocketListener::ShibSocket s);
58
59     private:
60         SocketListener::ShibSocket connect();
61
62         Category& m_log;
63         const SocketListener* m_listener;
64         auto_ptr<Mutex> m_lock;
65         stack<SocketListener::ShibSocket> m_pool;
66     };
67
68     // Worker threads in server
69     class ServerThread {
70     public:
71         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
72         ~ServerThread();
73         void run();
74         int job();  // Return -1 on error, 1 for closed, 0 for success
75
76     private:
77         SocketListener::ShibSocket m_sock;
78         Thread* m_child;
79         SocketListener* m_listener;
80         string m_id;
81         char m_buf[16384];
82     };
83 }
84
85 SocketListener::ShibSocket SocketPool::connect()
86 {
87 #ifdef _DEBUG
88     NDC ndc("connect");
89 #endif
90
91     m_log.debug("trying to connect to listener");
92
93     SocketListener::ShibSocket sock;
94     if (!m_listener->create(sock)) {
95         m_log.error("cannot create socket");
96         throw ListenerException("Cannot create socket");
97     }
98
99     bool connected = false;
100     int num_tries = 3;
101
102     for (int i = num_tries-1; i >= 0; i--) {
103         if (m_listener->connect(sock)) {
104             connected = true;
105             break;
106         }
107
108         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
109
110         if (i) {
111 #ifdef WIN32
112             Sleep(2000*(num_tries-i));
113 #else
114             sleep(2*(num_tries-i));
115 #endif
116         }
117     }
118
119     if (!connected) {
120         m_log.crit("socket server unavailable, failing");
121         m_listener->close(sock);
122         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
123     }
124
125     m_log.debug("socket (%u) connected successfully", sock);
126     return sock;
127 }
128
129 SocketPool::~SocketPool()
130 {
131     while (!m_pool.empty()) {
132 #ifdef WIN32
133         closesocket(m_pool.top());
134 #else
135         ::close(m_pool.top());
136 #endif
137         m_pool.pop();
138     }
139 }
140
141 SocketListener::ShibSocket SocketPool::get()
142 {
143     m_lock->lock();
144     if (m_pool.empty()) {
145         m_lock->unlock();
146         return connect();
147     }
148     SocketListener::ShibSocket ret=m_pool.top();
149     m_pool.pop();
150     m_lock->unlock();
151     return ret;
152 }
153
154 void SocketPool::put(SocketListener::ShibSocket s)
155 {
156     m_lock->lock();
157     m_pool.push(s);
158     m_lock->unlock();
159 }
160
161 SocketListener::SocketListener(const DOMElement* e)
162     : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT".Listener")), m_socketpool(NULL),
163         m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_stackSize(0), m_socket((ShibSocket)0)
164 {
165     // Are we a client?
166     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
167         m_socketpool=new SocketPool(*log,this);
168     }
169     // Are we a server?
170     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
171         m_child_lock = Mutex::create();
172         m_child_wait = CondWait::create();
173
174         static const XMLCh stackSize[] = UNICODE_LITERAL_9(s,t,a,c,k,S,i,z,e);
175         const XMLCh* attr = e ? e->getAttributeNS(NULL, stackSize) : NULL;
176         if (attr && *attr)
177             m_stackSize = XMLString::parseInt(attr) * 1024;
178     }
179 }
180
181 SocketListener::~SocketListener()
182 {
183     delete m_socketpool;
184     delete m_child_wait;
185     delete m_child_lock;
186 }
187
188 bool SocketListener::init(bool force)
189 {
190 #ifdef _DEBUG
191     NDC ndc("init");
192 #endif
193     log->info("listener service starting");
194
195     ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
196     sp->lock();
197     const PropertySet* props = sp->getPropertySet("OutOfProcess");
198     if (props) {
199         pair<bool,bool> flag = props->getBool("catchAll");
200         m_catchAll = flag.first && flag.second;
201     }
202     sp->unlock();
203
204     if (!create(m_socket)) {
205         log->crit("failed to create socket");
206         return false;
207     }
208     if (!bind(m_socket, force)) {
209         this->close(m_socket);
210         log->crit("failed to bind to socket.");
211         return false;
212     }
213
214     return true;
215 }
216
217 bool SocketListener::run(bool* shutdown)
218 {
219 #ifdef _DEBUG
220     NDC ndc("run");
221 #endif
222     // Save flag to monitor for shutdown request.
223     m_shutdown=shutdown;
224     unsigned long count = 0;
225
226     while (!*m_shutdown) {
227         fd_set readfds;
228         FD_ZERO(&readfds);
229         FD_SET(m_socket, &readfds);
230         struct timeval tv = { 0, 0 };
231         tv.tv_sec = 5;
232
233         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
234 #ifdef WIN32
235             case SOCKET_ERROR:
236 #else
237             case -1:
238 #endif
239                 if (errno == EINTR) continue;
240                 log_error();
241                 log->error("select() on main listener socket failed");
242                 *m_shutdown = true;
243                 break;
244
245             case 0:
246                 continue;
247
248             default:
249             {
250                 // Accept the connection.
251                 SocketListener::ShibSocket newsock;
252                 if (!accept(m_socket, newsock)) {
253                     log->crit("failed to accept incoming socket connection");
254                     continue;
255                 }
256
257                 // We throw away the result because the children manage themselves...
258                 try {
259                     new ServerThread(newsock,this,++count);
260                 }
261                 catch (exception& ex) {
262                     log->crit("exception starting new server thread to service incoming request: %s", ex.what());
263                 }
264                 catch (...) {
265                     log->crit("unknown error starting new server thread to service incoming request");
266                     if (!m_catchAll)
267                         *m_shutdown = true;
268                 }
269             }
270         }
271     }
272     log->info("listener service shutting down");
273
274     // Wait for all children to exit.
275     m_child_lock->lock();
276     while (!m_children.empty())
277         m_child_wait->wait(m_child_lock);
278     m_child_lock->unlock();
279
280     return true;
281 }
282
283 void SocketListener::term()
284 {
285     this->close(m_socket);
286     m_socket=(ShibSocket)0;
287 }
288
289 DDF SocketListener::send(const DDF& in)
290 {
291 #ifdef _DEBUG
292     NDC ndc("send");
293 #endif
294
295     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
296
297     // Serialize data for transmission.
298     ostringstream os;
299     os << in;
300     string ostr(os.str());
301
302     // Loop on the RPC in case we lost contact the first time through
303 #ifdef WIN32
304     u_long len;
305 #else
306     uint32_t len;
307 #endif
308     int retry = 1;
309     SocketListener::ShibSocket sock;
310     while (retry >= 0) {
311         sock = m_socketpool->get();
312
313         int outlen = ostr.length();
314         len = htonl(outlen);
315         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
316             log_error();
317             this->close(sock);
318             if (retry)
319                 retry--;
320             else
321                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
322         }
323         else {
324             // SUCCESS.
325             retry = -1;
326         }
327     }
328
329     log->debug("send completed, reading response message");
330
331     // Read the message.
332     while (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
333         if (errno == EINTR) continue;   // Apparently this happens when a signal interrupts the blocking call.
334         log->error("error reading size of output message");
335         this->close(sock);
336         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
337     }
338     len = ntohl(len);
339
340     char buf[16384];
341     int size_read;
342     stringstream is;
343     while (len) {
344         size_read = recv(sock, buf, sizeof(buf));
345         if (size_read > 0) {
346             is.write(buf, size_read);
347             len -= size_read;
348         }
349         else if (errno != EINTR) {
350                 break;
351         }
352     }
353
354     if (len) {
355         log->error("error reading output message from socket");
356         this->close(sock);
357         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
358     }
359
360     m_socketpool->put(sock);
361
362     // Unmarshall data.
363     DDF out;
364     is >> out;
365
366     // Check for exception to unmarshall and throw, otherwise return.
367     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
368         // Reconstitute exception object.
369         DDFJanitor jout(out);
370         XMLToolingException* except=NULL;
371         try {
372             except=XMLToolingException::fromString(out.string());
373             log->error("remoted message returned an error: %s", except->what());
374         }
375         catch (XMLToolingException& e) {
376             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
377             log->error("XML was: %s", out.string());
378             throw ListenerException("Remote call failed with an unparsable exception.");
379         }
380
381         auto_ptr<XMLToolingException> wrapper(except);
382         wrapper->raise();
383     }
384
385     return out;
386 }
387
388 bool SocketListener::log_error() const
389 {
390 #ifdef WIN32
391     int rc=WSAGetLastError();
392 #else
393     int rc=errno;
394 #endif
395 #ifdef HAVE_STRERROR_R
396     char buf[256];
397     memset(buf,0,sizeof(buf));
398     strerror_r(rc,buf,sizeof(buf));
399     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
400 #else
401     const char* buf=strerror(rc);
402     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
403 #endif
404     return false;
405 }
406
407 // actual function run in listener on server threads
408 void* server_thread_fn(void* arg)
409 {
410     ServerThread* child = (ServerThread*)arg;
411
412 #ifndef WIN32
413     // First, let's block all signals
414     Thread::mask_all_signals();
415 #endif
416
417     // Run the child until it exits.
418     child->run();
419
420     // Now we can clean up and exit the thread.
421     delete child;
422     return NULL;
423 }
424
425 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
426     : m_sock(s), m_child(NULL), m_listener(listener)
427 {
428
429     ostringstream buf;
430     buf << "[" << id << "]";
431     m_id = buf.str();
432
433     // Create the child thread
434     m_child = Thread::create(server_thread_fn, (void*)this, m_listener->m_stackSize);
435     m_child->detach();
436 }
437
438 ServerThread::~ServerThread()
439 {
440     // Then lock the children map, remove this socket/thread, signal waiters, and return
441     m_listener->m_child_lock->lock();
442     m_listener->m_children.erase(m_sock);
443     m_listener->m_child_lock->unlock();
444     m_listener->m_child_wait->signal();
445
446     delete m_child;
447 }
448
449 void ServerThread::run()
450 {
451     NDC ndc(m_id);
452
453     // Before starting up, make sure we fully "own" this socket.
454     m_listener->m_child_lock->lock();
455     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
456         m_listener->m_child_wait->wait(m_listener->m_child_lock);
457     m_listener->m_children[m_sock] = m_child;
458     m_listener->m_child_lock->unlock();
459
460     int result;
461     fd_set readfds;
462     struct timeval tv = { 0, 0 };
463
464     while(!*(m_listener->m_shutdown)) {
465         FD_ZERO(&readfds);
466         FD_SET(m_sock, &readfds);
467         tv.tv_sec = 1;
468
469         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
470 #ifdef WIN32
471         case SOCKET_ERROR:
472 #else
473         case -1:
474 #endif
475             if (errno == EINTR) continue;
476             m_listener->log_error();
477             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
478             return;
479
480         case 0:
481             break;
482
483         default:
484             result = job();
485             if (result) {
486                 if (result < 0) {
487                     m_listener->log_error();
488                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
489                 }
490                 m_listener->close(m_sock);
491                 return;
492             }
493         }
494     }
495 }
496
497 int ServerThread::job()
498 {
499     Category& log = Category::getInstance(SHIBSP_LOGCAT".Listener");
500
501     bool incomingError = true;  // set false once incoming message is received
502     ostringstream sink;
503 #ifdef WIN32
504     u_long len;
505 #else
506     uint32_t len;
507 #endif
508
509     try {
510         // Read the message.
511         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
512         if (readlength == 0) {
513             log.info("detected socket closure, shutting down worker thread");
514             return 1;
515         }
516         else if (readlength != sizeof(len)) {
517             log.error("error reading size of input message");
518             return -1;
519         }
520         len = ntohl(len);
521
522         int size_read;
523         stringstream is;
524         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
525             is.write(m_buf, size_read);
526             len -= size_read;
527         }
528
529         if (len) {
530             log.error("error reading input message from socket");
531             return -1;
532         }
533
534         // Unmarshall the message.
535         DDF in;
536         DDFJanitor jin(in);
537         is >> in;
538
539         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
540
541         incomingError = false;
542
543         // Dispatch the message.
544         m_listener->receive(in, sink);
545     }
546     catch (XMLToolingException& e) {
547         if (incomingError)
548             log.error("error processing incoming message: %s", e.what());
549         DDF out=DDF("exception").string(e.toString().c_str());
550         DDFJanitor jout(out);
551         sink << out;
552     }
553     catch (exception& e) {
554         if (incomingError)
555             log.error("error processing incoming message: %s", e.what());
556         ListenerException ex(e.what());
557         DDF out=DDF("exception").string(ex.toString().c_str());
558         DDFJanitor jout(out);
559         sink << out;
560     }
561     catch (...) {
562         if (incomingError)
563             log.error("unexpected error processing incoming message");
564         if (!m_listener->m_catchAll)
565             throw;
566         ListenerException ex("An unexpected error occurred while processing an incoming message.");
567         DDF out=DDF("exception").string(ex.toString().c_str());
568         DDFJanitor jout(out);
569         sink << out;
570     }
571
572     // Return whatever's available.
573     string response(sink.str());
574     int outlen = response.length();
575     len = htonl(outlen);
576     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
577         log.error("error sending output message size");
578         return -1;
579     }
580     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
581         log.error("error sending output message");
582         return -1;
583     }
584
585     return 0;
586 }