Add logging when catching unknown errors.
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based ListenerService implementation
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "ServiceProvider.h"
26 #include "remoting/impl/SocketListener.h"
27
28 #include <errno.h>
29 #include <stack>
30 #include <sstream>
31 #include <shibsp/SPConfig.h>
32 #include <xmltooling/util/NDC.h>
33
34 #ifndef WIN32
35 # include <netinet/in.h>
36 #endif
37
38 using namespace shibsp;
39 using namespace xmltooling;
40 using namespace std;
41 using xercesc::DOMElement;
42
43 namespace shibsp {
44   
45     // Manages the pool of connections
46     class SocketPool
47     {
48     public:
49         SocketPool(Category& log, const SocketListener* listener)
50             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
51         ~SocketPool();
52         SocketListener::ShibSocket get();
53         void put(SocketListener::ShibSocket s);
54   
55     private:
56         SocketListener::ShibSocket connect();
57        
58         Category& m_log; 
59         const SocketListener* m_listener;
60         auto_ptr<Mutex> m_lock;
61         stack<SocketListener::ShibSocket> m_pool;
62     };
63   
64     // Worker threads in server
65     class ServerThread {
66     public:
67         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
68         ~ServerThread();
69         void run();
70         int job();  // Return -1 on error, 1 for closed, 0 for success
71
72     private:
73         SocketListener::ShibSocket m_sock;
74         Thread* m_child;
75         SocketListener* m_listener;
76         string m_id;
77         char m_buf[16384];
78     };
79 }
80
81 SocketListener::ShibSocket SocketPool::connect()
82 {
83 #ifdef _DEBUG
84     NDC ndc("connect");
85 #endif
86
87     m_log.debug("trying to connect to listener");
88
89     SocketListener::ShibSocket sock;
90     if (!m_listener->create(sock)) {
91         m_log.error("cannot create socket");
92         throw ListenerException("Cannot create socket");
93     }
94
95     bool connected = false;
96     int num_tries = 3;
97
98     for (int i = num_tries-1; i >= 0; i--) {
99         if (m_listener->connect(sock)) {
100             connected = true;
101             break;
102         }
103     
104         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
105
106         if (i) {
107 #ifdef WIN32
108             Sleep(2000*(num_tries-i));
109 #else
110             sleep(2*(num_tries-i));
111 #endif
112         }
113     }
114
115     if (!connected) {
116         m_log.crit("socket server unavailable, failing");
117         m_listener->close(sock);
118         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
119     }
120
121     m_log.debug("socket (%u) connected successfully", sock);
122     return sock;
123 }
124
125 SocketPool::~SocketPool()
126 {
127     while (!m_pool.empty()) {
128 #ifdef WIN32
129         closesocket(m_pool.top());
130 #else
131         ::close(m_pool.top());
132 #endif
133         m_pool.pop();
134     }
135 }
136
137 SocketListener::ShibSocket SocketPool::get()
138 {
139     m_lock->lock();
140     if (m_pool.empty()) {
141         m_lock->unlock();
142         return connect();
143     }
144     SocketListener::ShibSocket ret=m_pool.top();
145     m_pool.pop();
146     m_lock->unlock();
147     return ret;
148 }
149
150 void SocketPool::put(SocketListener::ShibSocket s)
151 {
152     m_lock->lock();
153     m_pool.push(s);
154     m_lock->unlock();
155 }
156
157 SocketListener::SocketListener(const DOMElement* e) : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT".Listener")),
158     m_socketpool(NULL), m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socket((ShibSocket)0)
159 {
160     // Are we a client?
161     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
162         m_socketpool=new SocketPool(*log,this);
163     }
164     // Are we a server?
165     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
166         m_child_lock = Mutex::create();
167         m_child_wait = CondWait::create();
168     }
169 }
170
171 SocketListener::~SocketListener()
172 {
173     delete m_socketpool;
174     delete m_child_wait;
175     delete m_child_lock;
176 }
177
178 bool SocketListener::run(bool* shutdown)
179 {
180 #ifdef _DEBUG
181     NDC ndc("run");
182 #endif
183     log->info("listener service starting");
184
185     ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
186     sp->lock();
187     const PropertySet* props = sp->getPropertySet("OutOfProcess");
188     if (props) {
189         pair<bool,bool> flag = props->getBool("catchAll");
190         m_catchAll = flag.first && flag.second;
191     }
192     sp->unlock();
193     
194     // Save flag to monitor for shutdown request.
195     m_shutdown=shutdown;
196     unsigned long count = 0;
197
198     if (!create(m_socket)) {
199         log->crit("failed to create socket");
200         return false;
201     }
202     if (!bind(m_socket,true)) {
203         this->close(m_socket);
204         log->crit("failed to bind to socket.");
205         return false;
206     }
207
208     while (!*m_shutdown) {
209         fd_set readfds;
210         FD_ZERO(&readfds);
211         FD_SET(m_socket, &readfds);
212         struct timeval tv = { 0, 0 };
213         tv.tv_sec = 5;
214     
215         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
216 #ifdef WIN32
217             case SOCKET_ERROR:
218 #else
219             case -1:
220 #endif
221                 if (errno == EINTR) continue;
222                 log_error();
223                 log->error("select() on main listener socket failed");
224                 return false;
225         
226             case 0:
227                 continue;
228         
229             default:
230             {
231                 // Accept the connection.
232                 SocketListener::ShibSocket newsock;
233                 if (!accept(m_socket, newsock))
234                     log->crit("failed to accept incoming socket connection");
235
236                 // We throw away the result because the children manage themselves...
237                 try {
238                     new ServerThread(newsock,this,++count);
239                 }
240                 catch (...) {
241                     log->crit("error starting new server thread to service incoming request");
242                     if (!m_catchAll)
243                         *m_shutdown = true;
244                 }
245             }
246         }
247     }
248     log->info("listener service shutting down");
249
250     // Wait for all children to exit.
251     m_child_lock->lock();
252     while (!m_children.empty())
253         m_child_wait->wait(m_child_lock);
254     m_child_lock->unlock();
255
256     this->close(m_socket);
257     m_socket=(ShibSocket)0;
258     return true;
259 }
260
261 DDF SocketListener::send(const DDF& in)
262 {
263 #ifdef _DEBUG
264     NDC ndc("send");
265 #endif
266
267     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
268
269     // Serialize data for transmission.
270     ostringstream os;
271     os << in;
272     string ostr(os.str());
273
274     // Loop on the RPC in case we lost contact the first time through
275 #ifdef WIN32
276     u_long len;
277 #else
278     uint32_t len;
279 #endif
280     int retry = 1;
281     SocketListener::ShibSocket sock;
282     while (retry >= 0) {
283         sock = m_socketpool->get();
284         
285         int outlen = ostr.length();
286         len = htonl(outlen);
287         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
288             log_error();
289             this->close(sock);
290             if (retry)
291                 retry--;
292             else
293                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
294         }
295         else {
296             // SUCCESS.
297             retry = -1;
298         }
299     }
300
301     log->debug("send completed, reading response message");
302
303     // Read the message.
304     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
305         log->error("error reading size of output message");
306         this->close(sock);
307         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
308     }
309     len = ntohl(len);
310     
311     char buf[16384];
312     int size_read;
313     stringstream is;
314     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
315         is.write(buf, size_read);
316         len -= size_read;
317     }
318     
319     if (len) {
320         log->error("error reading output message from socket");
321         this->close(sock);
322         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
323     }
324     
325     m_socketpool->put(sock);
326
327     // Unmarshall data.
328     DDF out;
329     is >> out;
330     
331     // Check for exception to unmarshall and throw, otherwise return.
332     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
333         // Reconstitute exception object.
334         DDFJanitor jout(out);
335         XMLToolingException* except=NULL;
336         try { 
337             except=XMLToolingException::fromString(out.string());
338             log->error("remoted message returned an error: %s", except->what());
339         }
340         catch (XMLToolingException& e) {
341             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
342             log->error("XML was: %s", out.string());
343             throw ListenerException("Remote call failed with an unparsable exception.");
344         }
345
346         auto_ptr<XMLToolingException> wrapper(except);
347         wrapper->raise();
348     }
349
350     return out;
351 }
352
353 bool SocketListener::log_error() const
354 {
355 #ifdef WIN32
356     int rc=WSAGetLastError();
357 #else
358     int rc=errno;
359 #endif
360 #ifdef HAVE_STRERROR_R
361     char buf[256];
362     memset(buf,0,sizeof(buf));
363     strerror_r(rc,buf,sizeof(buf));
364     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
365 #else
366     const char* buf=strerror(rc);
367     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
368 #endif
369     return false;
370 }
371
372 // actual function run in listener on server threads
373 void* server_thread_fn(void* arg)
374 {
375     ServerThread* child = (ServerThread*)arg;
376
377 #ifndef WIN32
378     // First, let's block all signals
379     Thread::mask_all_signals();
380 #endif
381
382     // Run the child until it exits.
383     child->run();
384
385     // Now we can clean up and exit the thread.
386     delete child;
387     return NULL;
388 }
389
390 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
391     : m_sock(s), m_child(NULL), m_listener(listener)
392 {
393
394     ostringstream buf;
395     buf << "[" << id << "]";
396     m_id = buf.str();
397
398     // Create the child thread
399     m_child = Thread::create(server_thread_fn, (void*)this);
400     m_child->detach();
401 }
402
403 ServerThread::~ServerThread()
404 {
405     // Then lock the children map, remove this socket/thread, signal waiters, and return
406     m_listener->m_child_lock->lock();
407     m_listener->m_children.erase(m_sock);
408     m_listener->m_child_lock->unlock();
409     m_listener->m_child_wait->signal();
410   
411     delete m_child;
412 }
413
414 void ServerThread::run()
415 {
416     NDC ndc(m_id);
417
418     // Before starting up, make sure we fully "own" this socket.
419     m_listener->m_child_lock->lock();
420     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
421         m_listener->m_child_wait->wait(m_listener->m_child_lock);
422     m_listener->m_children[m_sock] = m_child;
423     m_listener->m_child_lock->unlock();
424     
425     int result;
426     fd_set readfds;
427     struct timeval tv = { 0, 0 };
428
429     while(!*(m_listener->m_shutdown)) {
430         FD_ZERO(&readfds);
431         FD_SET(m_sock, &readfds);
432         tv.tv_sec = 1;
433
434         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
435 #ifdef WIN32
436         case SOCKET_ERROR:
437 #else
438         case -1:
439 #endif
440             if (errno == EINTR) continue;
441             m_listener->log_error();
442             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
443             return;
444
445         case 0:
446             break;
447
448         default:
449             result = job();
450             if (result) {
451                 if (result < 0) {
452                     m_listener->log_error();
453                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
454                 }
455                 m_listener->close(m_sock);
456                 return;
457             }
458         }
459     }
460 }
461
462 int ServerThread::job()
463 {
464     Category& log = Category::getInstance("shibd.Listener");
465
466     bool incomingError = true;  // set false once incoming message is received
467     ostringstream sink;
468 #ifdef WIN32
469     u_long len;
470 #else
471     uint32_t len;
472 #endif
473
474     try {
475         // Read the message.
476         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
477         if (readlength == 0) {
478             log.info("detected socket closure, shutting down worker thread");
479             return 1;
480         }
481         else if (readlength != sizeof(len)) {
482             log.error("error reading size of input message");
483             return -1;
484         }
485         len = ntohl(len);
486         
487         int size_read;
488         stringstream is;
489         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
490             is.write(m_buf, size_read);
491             len -= size_read;
492         }
493         
494         if (len) {
495             log.error("error reading input message from socket");
496             return -1;
497         }
498         
499         // Unmarshall the message.
500         DDF in;
501         DDFJanitor jin(in);
502         is >> in;
503
504         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
505
506         incomingError = false;
507
508         // Dispatch the message.
509         m_listener->receive(in, sink);
510     }
511     catch (XMLToolingException& e) {
512         if (incomingError)
513             log.error("error processing incoming message: %s", e.what());
514         DDF out=DDF("exception").string(e.toString().c_str());
515         DDFJanitor jout(out);
516         sink << out;
517     }
518     catch (exception& e) {
519         if (incomingError)
520             log.error("error processing incoming message: %s", e.what());
521         ListenerException ex(e.what());
522         DDF out=DDF("exception").string(ex.toString().c_str());
523         DDFJanitor jout(out);
524         sink << out;
525     }
526     catch (...) {
527         if (incomingError)
528             log.error("unexpected error processing incoming message");
529         if (!m_listener->m_catchAll)
530             throw;
531         ListenerException ex("An unexpected error occurred while processing an incoming message.");
532         DDF out=DDF("exception").string(ex.toString().c_str());
533         DDFJanitor jout(out);
534         sink << out;
535     }
536     
537     // Return whatever's available.
538     string response(sink.str());
539     int outlen = response.length();
540     len = htonl(outlen);
541     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
542         log.error("error sending output message size");
543         return -1;
544     }
545     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
546         log.error("error sending output message");
547         return -1;
548     }
549     
550     return 0;
551 }