Fix log category.
[shibboleth/sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based ListenerService implementation
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "ServiceProvider.h"
26 #include "remoting/impl/SocketListener.h"
27
28 #include <errno.h>
29 #include <stack>
30 #include <sstream>
31 #include <shibsp/SPConfig.h>
32 #include <xmltooling/util/NDC.h>
33
34 #ifndef WIN32
35 # include <netinet/in.h>
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 using namespace shibsp;
43 using namespace xmltooling;
44 using namespace std;
45 using xercesc::DOMElement;
46
47 namespace shibsp {
48   
49     // Manages the pool of connections
50     class SocketPool
51     {
52     public:
53         SocketPool(Category& log, const SocketListener* listener)
54             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
55         ~SocketPool();
56         SocketListener::ShibSocket get();
57         void put(SocketListener::ShibSocket s);
58   
59     private:
60         SocketListener::ShibSocket connect();
61        
62         Category& m_log; 
63         const SocketListener* m_listener;
64         auto_ptr<Mutex> m_lock;
65         stack<SocketListener::ShibSocket> m_pool;
66     };
67   
68     // Worker threads in server
69     class ServerThread {
70     public:
71         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
72         ~ServerThread();
73         void run();
74         int job();  // Return -1 on error, 1 for closed, 0 for success
75
76     private:
77         SocketListener::ShibSocket m_sock;
78         Thread* m_child;
79         SocketListener* m_listener;
80         string m_id;
81         char m_buf[16384];
82     };
83 }
84
85 SocketListener::ShibSocket SocketPool::connect()
86 {
87 #ifdef _DEBUG
88     NDC ndc("connect");
89 #endif
90
91     m_log.debug("trying to connect to listener");
92
93     SocketListener::ShibSocket sock;
94     if (!m_listener->create(sock)) {
95         m_log.error("cannot create socket");
96         throw ListenerException("Cannot create socket");
97     }
98
99     bool connected = false;
100     int num_tries = 3;
101
102     for (int i = num_tries-1; i >= 0; i--) {
103         if (m_listener->connect(sock)) {
104             connected = true;
105             break;
106         }
107     
108         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
109
110         if (i) {
111 #ifdef WIN32
112             Sleep(2000*(num_tries-i));
113 #else
114             sleep(2*(num_tries-i));
115 #endif
116         }
117     }
118
119     if (!connected) {
120         m_log.crit("socket server unavailable, failing");
121         m_listener->close(sock);
122         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
123     }
124
125     m_log.debug("socket (%u) connected successfully", sock);
126     return sock;
127 }
128
129 SocketPool::~SocketPool()
130 {
131     while (!m_pool.empty()) {
132 #ifdef WIN32
133         closesocket(m_pool.top());
134 #else
135         ::close(m_pool.top());
136 #endif
137         m_pool.pop();
138     }
139 }
140
141 SocketListener::ShibSocket SocketPool::get()
142 {
143     m_lock->lock();
144     if (m_pool.empty()) {
145         m_lock->unlock();
146         return connect();
147     }
148     SocketListener::ShibSocket ret=m_pool.top();
149     m_pool.pop();
150     m_lock->unlock();
151     return ret;
152 }
153
154 void SocketPool::put(SocketListener::ShibSocket s)
155 {
156     m_lock->lock();
157     m_pool.push(s);
158     m_lock->unlock();
159 }
160
161 SocketListener::SocketListener(const DOMElement* e) : m_catchAll(false), log(&Category::getInstance(SHIBSP_LOGCAT".Listener")),
162     m_socketpool(NULL), m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socket((ShibSocket)0)
163 {
164     // Are we a client?
165     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
166         m_socketpool=new SocketPool(*log,this);
167     }
168     // Are we a server?
169     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
170         m_child_lock = Mutex::create();
171         m_child_wait = CondWait::create();
172     }
173 }
174
175 SocketListener::~SocketListener()
176 {
177     delete m_socketpool;
178     delete m_child_wait;
179     delete m_child_lock;
180 }
181
182 bool SocketListener::run(bool* shutdown)
183 {
184 #ifdef _DEBUG
185     NDC ndc("run");
186 #endif
187     log->info("listener service starting");
188
189     ServiceProvider* sp = SPConfig::getConfig().getServiceProvider();
190     sp->lock();
191     const PropertySet* props = sp->getPropertySet("OutOfProcess");
192     if (props) {
193         pair<bool,bool> flag = props->getBool("catchAll");
194         m_catchAll = flag.first && flag.second;
195     }
196     sp->unlock();
197     
198     // Save flag to monitor for shutdown request.
199     m_shutdown=shutdown;
200     unsigned long count = 0;
201
202     if (!create(m_socket)) {
203         log->crit("failed to create socket");
204         return false;
205     }
206     if (!bind(m_socket,true)) {
207         this->close(m_socket);
208         log->crit("failed to bind to socket.");
209         return false;
210     }
211
212     while (!*m_shutdown) {
213         fd_set readfds;
214         FD_ZERO(&readfds);
215         FD_SET(m_socket, &readfds);
216         struct timeval tv = { 0, 0 };
217         tv.tv_sec = 5;
218     
219         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
220 #ifdef WIN32
221             case SOCKET_ERROR:
222 #else
223             case -1:
224 #endif
225                 if (errno == EINTR) continue;
226                 log_error();
227                 log->error("select() on main listener socket failed");
228                 return false;
229         
230             case 0:
231                 continue;
232         
233             default:
234             {
235                 // Accept the connection.
236                 SocketListener::ShibSocket newsock;
237                 if (!accept(m_socket, newsock))
238                     log->crit("failed to accept incoming socket connection");
239
240                 // We throw away the result because the children manage themselves...
241                 try {
242                     new ServerThread(newsock,this,++count);
243                 }
244                 catch (...) {
245                     log->crit("error starting new server thread to service incoming request");
246                     if (!m_catchAll)
247                         *m_shutdown = true;
248                 }
249             }
250         }
251     }
252     log->info("listener service shutting down");
253
254     // Wait for all children to exit.
255     m_child_lock->lock();
256     while (!m_children.empty())
257         m_child_wait->wait(m_child_lock);
258     m_child_lock->unlock();
259
260     this->close(m_socket);
261     m_socket=(ShibSocket)0;
262     return true;
263 }
264
265 DDF SocketListener::send(const DDF& in)
266 {
267 #ifdef _DEBUG
268     NDC ndc("send");
269 #endif
270
271     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
272
273     // Serialize data for transmission.
274     ostringstream os;
275     os << in;
276     string ostr(os.str());
277
278     // Loop on the RPC in case we lost contact the first time through
279 #ifdef WIN32
280     u_long len;
281 #else
282     uint32_t len;
283 #endif
284     int retry = 1;
285     SocketListener::ShibSocket sock;
286     while (retry >= 0) {
287         sock = m_socketpool->get();
288         
289         int outlen = ostr.length();
290         len = htonl(outlen);
291         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
292             log_error();
293             this->close(sock);
294             if (retry)
295                 retry--;
296             else
297                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
298         }
299         else {
300             // SUCCESS.
301             retry = -1;
302         }
303     }
304
305     log->debug("send completed, reading response message");
306
307     // Read the message.
308     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
309         log->error("error reading size of output message");
310         this->close(sock);
311         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
312     }
313     len = ntohl(len);
314     
315     char buf[16384];
316     int size_read;
317     stringstream is;
318     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
319         is.write(buf, size_read);
320         len -= size_read;
321     }
322     
323     if (len) {
324         log->error("error reading output message from socket");
325         this->close(sock);
326         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
327     }
328     
329     m_socketpool->put(sock);
330
331     // Unmarshall data.
332     DDF out;
333     is >> out;
334     
335     // Check for exception to unmarshall and throw, otherwise return.
336     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
337         // Reconstitute exception object.
338         DDFJanitor jout(out);
339         XMLToolingException* except=NULL;
340         try { 
341             except=XMLToolingException::fromString(out.string());
342             log->error("remoted message returned an error: %s", except->what());
343         }
344         catch (XMLToolingException& e) {
345             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
346             log->error("XML was: %s", out.string());
347             throw ListenerException("Remote call failed with an unparsable exception.");
348         }
349
350         auto_ptr<XMLToolingException> wrapper(except);
351         wrapper->raise();
352     }
353
354     return out;
355 }
356
357 bool SocketListener::log_error() const
358 {
359 #ifdef WIN32
360     int rc=WSAGetLastError();
361 #else
362     int rc=errno;
363 #endif
364 #ifdef HAVE_STRERROR_R
365     char buf[256];
366     memset(buf,0,sizeof(buf));
367     strerror_r(rc,buf,sizeof(buf));
368     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
369 #else
370     const char* buf=strerror(rc);
371     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
372 #endif
373     return false;
374 }
375
376 // actual function run in listener on server threads
377 void* server_thread_fn(void* arg)
378 {
379     ServerThread* child = (ServerThread*)arg;
380
381 #ifndef WIN32
382     // First, let's block all signals
383     Thread::mask_all_signals();
384 #endif
385
386     // Run the child until it exits.
387     child->run();
388
389     // Now we can clean up and exit the thread.
390     delete child;
391     return NULL;
392 }
393
394 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
395     : m_sock(s), m_child(NULL), m_listener(listener)
396 {
397
398     ostringstream buf;
399     buf << "[" << id << "]";
400     m_id = buf.str();
401
402     // Create the child thread
403     m_child = Thread::create(server_thread_fn, (void*)this);
404     m_child->detach();
405 }
406
407 ServerThread::~ServerThread()
408 {
409     // Then lock the children map, remove this socket/thread, signal waiters, and return
410     m_listener->m_child_lock->lock();
411     m_listener->m_children.erase(m_sock);
412     m_listener->m_child_lock->unlock();
413     m_listener->m_child_wait->signal();
414   
415     delete m_child;
416 }
417
418 void ServerThread::run()
419 {
420     NDC ndc(m_id);
421
422     // Before starting up, make sure we fully "own" this socket.
423     m_listener->m_child_lock->lock();
424     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
425         m_listener->m_child_wait->wait(m_listener->m_child_lock);
426     m_listener->m_children[m_sock] = m_child;
427     m_listener->m_child_lock->unlock();
428     
429     int result;
430     fd_set readfds;
431     struct timeval tv = { 0, 0 };
432
433     while(!*(m_listener->m_shutdown)) {
434         FD_ZERO(&readfds);
435         FD_SET(m_sock, &readfds);
436         tv.tv_sec = 1;
437
438         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
439 #ifdef WIN32
440         case SOCKET_ERROR:
441 #else
442         case -1:
443 #endif
444             if (errno == EINTR) continue;
445             m_listener->log_error();
446             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
447             return;
448
449         case 0:
450             break;
451
452         default:
453             result = job();
454             if (result) {
455                 if (result < 0) {
456                     m_listener->log_error();
457                     m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
458                 }
459                 m_listener->close(m_sock);
460                 return;
461             }
462         }
463     }
464 }
465
466 int ServerThread::job()
467 {
468     Category& log = Category::getInstance(SHIBSP_LOGCAT".Listener");
469
470     bool incomingError = true;  // set false once incoming message is received
471     ostringstream sink;
472 #ifdef WIN32
473     u_long len;
474 #else
475     uint32_t len;
476 #endif
477
478     try {
479         // Read the message.
480         int readlength = m_listener->recv(m_sock,(char*)&len,sizeof(len));
481         if (readlength == 0) {
482             log.info("detected socket closure, shutting down worker thread");
483             return 1;
484         }
485         else if (readlength != sizeof(len)) {
486             log.error("error reading size of input message");
487             return -1;
488         }
489         len = ntohl(len);
490         
491         int size_read;
492         stringstream is;
493         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
494             is.write(m_buf, size_read);
495             len -= size_read;
496         }
497         
498         if (len) {
499             log.error("error reading input message from socket");
500             return -1;
501         }
502         
503         // Unmarshall the message.
504         DDF in;
505         DDFJanitor jin(in);
506         is >> in;
507
508         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
509
510         incomingError = false;
511
512         // Dispatch the message.
513         m_listener->receive(in, sink);
514     }
515     catch (XMLToolingException& e) {
516         if (incomingError)
517             log.error("error processing incoming message: %s", e.what());
518         DDF out=DDF("exception").string(e.toString().c_str());
519         DDFJanitor jout(out);
520         sink << out;
521     }
522     catch (exception& e) {
523         if (incomingError)
524             log.error("error processing incoming message: %s", e.what());
525         ListenerException ex(e.what());
526         DDF out=DDF("exception").string(ex.toString().c_str());
527         DDFJanitor jout(out);
528         sink << out;
529     }
530     catch (...) {
531         if (incomingError)
532             log.error("unexpected error processing incoming message");
533         if (!m_listener->m_catchAll)
534             throw;
535         ListenerException ex("An unexpected error occurred while processing an incoming message.");
536         DDF out=DDF("exception").string(ex.toString().c_str());
537         DDFJanitor jout(out);
538         sink << out;
539     }
540     
541     // Return whatever's available.
542     string response(sink.str());
543     int outlen = response.length();
544     len = htonl(outlen);
545     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
546         log.error("error sending output message size");
547         return -1;
548     }
549     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
550         log.error("error sending output message");
551         return -1;
552     }
553     
554     return 0;
555 }