Detect socket closure.
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based ListenerService implementation
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "remoting/impl/SocketListener.h"
26
27 #include <errno.h>
28 #include <stack>
29 #include <sstream>
30 #include <shibsp/SPConfig.h>
31 #include <xmltooling/util/NDC.h>
32
33 #ifndef WIN32
34 # include <netinet/in.h>
35 #endif
36
37 using namespace shibsp;
38 using namespace xmltooling;
39 using namespace std;
40 using xercesc::DOMElement;
41
42 namespace shibsp {
43   
44     // Manages the pool of connections
45     class SocketPool
46     {
47     public:
48         SocketPool(Category& log, const SocketListener* listener)
49             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
50         ~SocketPool();
51         SocketListener::ShibSocket get();
52         void put(SocketListener::ShibSocket s);
53   
54     private:
55         SocketListener::ShibSocket connect();
56        
57         Category& m_log; 
58         const SocketListener* m_listener;
59         auto_ptr<Mutex> m_lock;
60         stack<SocketListener::ShibSocket> m_pool;
61     };
62   
63     // Worker threads in server
64     class ServerThread {
65     public:
66         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
67         ~ServerThread();
68         void run();
69         int job();  // Return -1 on error, 1 for closed, 0 for success
70
71     private:
72         SocketListener::ShibSocket m_sock;
73         Thread* m_child;
74         SocketListener* m_listener;
75         string m_id;
76         char m_buf[16384];
77     };
78 }
79
80 SocketListener::ShibSocket SocketPool::connect()
81 {
82 #ifdef _DEBUG
83     NDC ndc("connect");
84 #endif
85
86     m_log.debug("trying to connect to listener");
87
88     SocketListener::ShibSocket sock;
89     if (!m_listener->create(sock)) {
90         m_log.error("cannot create socket");
91         throw ListenerException("Cannot create socket");
92     }
93
94     bool connected = false;
95     int num_tries = 3;
96
97     for (int i = num_tries-1; i >= 0; i--) {
98         if (m_listener->connect(sock)) {
99             connected = true;
100             break;
101         }
102     
103         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
104
105         if (i) {
106 #ifdef WIN32
107             Sleep(2000*(num_tries-i));
108 #else
109             sleep(2*(num_tries-i));
110 #endif
111         }
112     }
113
114     if (!connected) {
115         m_log.crit("socket server unavailable, failing");
116         m_listener->close(sock);
117         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
118     }
119
120     m_log.debug("socket (%u) connected successfully", sock);
121     return sock;
122 }
123
124 SocketPool::~SocketPool()
125 {
126     while (!m_pool.empty()) {
127 #ifdef WIN32
128         closesocket(m_pool.top());
129 #else
130         ::close(m_pool.top());
131 #endif
132         m_pool.pop();
133     }
134 }
135
136 SocketListener::ShibSocket SocketPool::get()
137 {
138     m_lock->lock();
139     if (m_pool.empty()) {
140         m_lock->unlock();
141         return connect();
142     }
143     SocketListener::ShibSocket ret=m_pool.top();
144     m_pool.pop();
145     m_lock->unlock();
146     return ret;
147 }
148
149 void SocketPool::put(SocketListener::ShibSocket s)
150 {
151     m_lock->lock();
152     m_pool.push(s);
153     m_lock->unlock();
154 }
155
156 SocketListener::SocketListener(const DOMElement* e) : log(&Category::getInstance(SHIBSP_LOGCAT".Listener")),
157     m_socketpool(NULL), m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socket((ShibSocket)0)
158 {
159     // Are we a client?
160     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
161         m_socketpool=new SocketPool(*log,this);
162     }
163     // Are we a server?
164     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
165         m_child_lock = Mutex::create();
166         m_child_wait = CondWait::create();
167     }
168 }
169
170 SocketListener::~SocketListener()
171 {
172     delete m_socketpool;
173     delete m_child_wait;
174     delete m_child_lock;
175 }
176
177 bool SocketListener::run(bool* shutdown)
178 {
179 #ifdef _DEBUG
180     NDC ndc("run");
181 #endif
182     log->info("listener service starting");
183
184     // Save flag to monitor for shutdown request.
185     m_shutdown=shutdown;
186     unsigned long count = 0;
187
188     if (!create(m_socket)) {
189         log->crit("failed to create socket");
190         return false;
191     }
192     if (!bind(m_socket,true)) {
193         this->close(m_socket);
194         log->crit("failed to bind to socket.");
195         return false;
196     }
197
198     while (!*m_shutdown) {
199         fd_set readfds;
200         FD_ZERO(&readfds);
201         FD_SET(m_socket, &readfds);
202         struct timeval tv = { 0, 0 };
203         tv.tv_sec = 5;
204     
205         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
206 #ifdef WIN32
207             case SOCKET_ERROR:
208 #else
209             case -1:
210 #endif
211                 if (errno == EINTR) continue;
212                 log_error();
213                 log->error("select() on main listener socket failed");
214                 return false;
215         
216             case 0:
217                 continue;
218         
219             default:
220             {
221                 // Accept the connection.
222                 SocketListener::ShibSocket newsock;
223                 if (!accept(m_socket, newsock))
224                     log->crit("failed to accept incoming socket connection");
225
226                 // We throw away the result because the children manage themselves...
227                 try {
228                     new ServerThread(newsock,this,++count);
229                 }
230                 catch (...) {
231                     log->crit("error starting new server thread to service incoming request");
232                 }
233             }
234         }
235     }
236     log->info("listener service shutting down");
237
238     // Wait for all children to exit.
239     m_child_lock->lock();
240     while (!m_children.empty())
241         m_child_wait->wait(m_child_lock);
242     m_child_lock->unlock();
243
244     this->close(m_socket);
245     m_socket=(ShibSocket)0;
246     return true;
247 }
248
249 DDF SocketListener::send(const DDF& in)
250 {
251 #ifdef _DEBUG
252     NDC ndc("send");
253 #endif
254
255     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
256
257     // Serialize data for transmission.
258     ostringstream os;
259     os << in;
260     string ostr(os.str());
261
262     // Loop on the RPC in case we lost contact the first time through
263 #ifdef WIN32
264     u_long len;
265 #else
266     uint32_t len;
267 #endif
268     int retry = 1;
269     SocketListener::ShibSocket sock;
270     while (retry >= 0) {
271         sock = m_socketpool->get();
272         
273         int outlen = ostr.length();
274         len = htonl(outlen);
275         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
276             log_error();
277             this->close(sock);
278             if (retry)
279                 retry--;
280             else
281                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
282         }
283         else {
284             // SUCCESS.
285             retry = -1;
286         }
287     }
288
289     log->debug("send completed, reading response message");
290
291     // Read the message.
292     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
293         log->error("error reading size of output message");
294         this->close(sock);
295         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
296     }
297     len = ntohl(len);
298     
299     char buf[16384];
300     int size_read;
301     stringstream is;
302     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
303         is.write(buf, size_read);
304         len -= size_read;
305     }
306     
307     if (len) {
308         log->error("error reading output message from socket");
309         this->close(sock);
310         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
311     }
312     
313     m_socketpool->put(sock);
314
315     // Unmarshall data.
316     DDF out;
317     is >> out;
318     
319     // Check for exception to unmarshall and throw, otherwise return.
320     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
321         // Reconstitute exception object.
322         DDFJanitor jout(out);
323         XMLToolingException* except=NULL;
324         try { 
325             except=XMLToolingException::fromString(out.string());
326             log->error("remoted message returned an error: %s", except->what());
327         }
328         catch (XMLToolingException& e) {
329             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
330             log->error("XML was: %s", out.string());
331             throw ListenerException("Remote call failed with an unparsable exception.");
332         }
333
334         auto_ptr<XMLToolingException> wrapper(except);
335         wrapper->raise();
336     }
337
338     return out;
339 }
340
341 bool SocketListener::log_error() const
342 {
343 #ifdef WIN32
344     int rc=WSAGetLastError();
345 #else
346     int rc=errno;
347 #endif
348 #ifdef HAVE_STRERROR_R
349     char buf[256];
350     memset(buf,0,sizeof(buf));
351     strerror_r(rc,buf,sizeof(buf));
352     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
353 #else
354     const char* buf=strerror(rc);
355     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
356 #endif
357     return false;
358 }
359
360 // actual function run in listener on server threads
361 void* server_thread_fn(void* arg)
362 {
363     ServerThread* child = (ServerThread*)arg;
364
365 #ifndef WIN32
366     // First, let's block all signals
367     Thread::mask_all_signals();
368 #endif
369
370     // Run the child until it exits.
371     child->run();
372
373     // Now we can clean up and exit the thread.
374     delete child;
375     return NULL;
376 }
377
378 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
379     : m_sock(s), m_child(NULL), m_listener(listener)
380 {
381
382     ostringstream buf;
383     buf << "[" << id << "]";
384     m_id = buf.str();
385
386     // Create the child thread
387     m_child = Thread::create(server_thread_fn, (void*)this);
388     m_child->detach();
389 }
390
391 ServerThread::~ServerThread()
392 {
393     // Then lock the children map, remove this socket/thread, signal waiters, and return
394     m_listener->m_child_lock->lock();
395     m_listener->m_children.erase(m_sock);
396     m_listener->m_child_lock->unlock();
397     m_listener->m_child_wait->signal();
398   
399     delete m_child;
400 }
401
402 void ServerThread::run()
403 {
404     NDC ndc(m_id);
405
406     // Before starting up, make sure we fully "own" this socket.
407     m_listener->m_child_lock->lock();
408     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
409         m_listener->m_child_wait->wait(m_listener->m_child_lock);
410     m_listener->m_children[m_sock] = m_child;
411     m_listener->m_child_lock->unlock();
412     
413     int result;
414     fd_set readfds;
415     struct timeval tv = { 0, 0 };
416
417     while(!*(m_listener->m_shutdown)) {
418         FD_ZERO(&readfds);
419         FD_SET(m_sock, &readfds);
420         tv.tv_sec = 1;
421
422         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
423 #ifdef WIN32
424         case SOCKET_ERROR:
425 #else
426         case -1:
427 #endif
428             if (errno == EINTR) continue;
429             m_listener->log_error();
430             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
431             return;
432
433         case 0:
434             break;
435
436         default:
437             result = job();
438             if (result) {
439                 if (result < 0) {
440                     m_listener->log_error();
441                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
442                 }
443                 m_listener->close(m_sock);
444                 return;
445             }
446         }
447     }
448 }
449
450 int ServerThread::job()
451 {
452     Category& log = Category::getInstance("shibd.Listener");
453
454     bool incomingError = true;  // set false once incoming message is received
455     ostringstream sink;
456 #ifdef WIN32
457     u_long len;
458 #else
459     uint32_t len;
460 #endif
461
462     try {
463         // Read the message.
464         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
465         if (readlength == 0) {
466             log.info("detected socket closure, shutting down worker thread");
467             return 1;
468         }
469         else if (readlength != sizeof(len)) {
470             log.error("error reading size of input message");
471             return -1;
472         }
473         len = ntohl(len);
474         
475         int size_read;
476         stringstream is;
477         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
478             is.write(m_buf, size_read);
479             len -= size_read;
480         }
481         
482         if (len) {
483             log.error("error reading input message from socket");
484             return -1;
485         }
486         
487         // Unmarshall the message.
488         DDF in;
489         DDFJanitor jin(in);
490         is >> in;
491
492         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
493
494         incomingError = false;
495
496         // Dispatch the message.
497         m_listener->receive(in, sink);
498     }
499     catch (XMLToolingException& e) {
500         if (incomingError)
501             log.error("error processing incoming message: %s", e.what());
502         DDF out=DDF("exception").string(e.toString().c_str());
503         DDFJanitor jout(out);
504         sink << out;
505     }
506     catch (exception& e) {
507         if (incomingError)
508             log.error("error processing incoming message: %s", e.what());
509         ListenerException ex(e.what());
510         DDF out=DDF("exception").string(ex.toString().c_str());
511         DDFJanitor jout(out);
512         sink << out;
513     }
514 #ifndef _DEBUG
515     catch (...) {
516         if (incomingError)
517             log.error("unexpected error processing incoming message");
518         ListenerException ex("An unexpected error occurred while processing an incoming message.");
519         DDF out=DDF("exception").string(ex.toString().c_str());
520         DDFJanitor jout(out);
521         sink << out;
522     }
523 #endif
524     
525     // Return whatever's available.
526     string response(sink.str());
527     int outlen = response.length();
528     len = htonl(outlen);
529     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
530         log.error("error sending output message size");
531         return -1;
532     }
533     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
534         log.error("error sending output message");
535         return -1;
536     }
537     
538     return 0;
539 }