Convert logging to log4shib via compile time switch.
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based ListenerService implementation
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "remoting/impl/SocketListener.h"
26
27 #include <errno.h>
28 #include <stack>
29 #include <sstream>
30 #include <shibsp/SPConfig.h>
31 #include <xmltooling/util/NDC.h>
32
33 #ifndef WIN32
34 # include <netinet/in.h>
35 #endif
36
37 using namespace shibsp;
38 using namespace xmltooling;
39 using namespace std;
40 using xercesc::DOMElement;
41
42 namespace shibsp {
43   
44     // Manages the pool of connections
45     class SocketPool
46     {
47     public:
48         SocketPool(Category& log, const SocketListener* listener)
49             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
50         ~SocketPool();
51         SocketListener::ShibSocket get();
52         void put(SocketListener::ShibSocket s);
53   
54     private:
55         SocketListener::ShibSocket connect();
56        
57         Category& m_log; 
58         const SocketListener* m_listener;
59         auto_ptr<Mutex> m_lock;
60         stack<SocketListener::ShibSocket> m_pool;
61     };
62   
63     // Worker threads in server
64     class ServerThread {
65     public:
66         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
67         ~ServerThread();
68         void run();
69         bool job();
70
71     private:
72         SocketListener::ShibSocket m_sock;
73         Thread* m_child;
74         SocketListener* m_listener;
75         string m_id;
76         char m_buf[16384];
77     };
78 }
79
80 SocketListener::ShibSocket SocketPool::connect()
81 {
82 #ifdef _DEBUG
83     NDC ndc("connect");
84 #endif
85
86     m_log.debug("trying to connect to listener");
87
88     SocketListener::ShibSocket sock;
89     if (!m_listener->create(sock)) {
90         m_log.error("cannot create socket");
91         throw ListenerException("Cannot create socket");
92     }
93
94     bool connected = false;
95     int num_tries = 3;
96
97     for (int i = num_tries-1; i >= 0; i--) {
98         if (m_listener->connect(sock)) {
99             connected = true;
100             break;
101         }
102     
103         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
104
105         if (i) {
106 #ifdef WIN32
107             Sleep(2000*(num_tries-i));
108 #else
109             sleep(2*(num_tries-i));
110 #endif
111         }
112     }
113
114     if (!connected) {
115         m_log.crit("socket server unavailable, failing");
116         m_listener->close(sock);
117         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
118     }
119
120     m_log.debug("socket (%u) connected successfully", sock);
121     return sock;
122 }
123
124 SocketPool::~SocketPool()
125 {
126     while (!m_pool.empty()) {
127 #ifdef WIN32
128         closesocket(m_pool.top());
129 #else
130         ::close(m_pool.top());
131 #endif
132         m_pool.pop();
133     }
134 }
135
136 SocketListener::ShibSocket SocketPool::get()
137 {
138     m_lock->lock();
139     if (m_pool.empty()) {
140         m_lock->unlock();
141         return connect();
142     }
143     SocketListener::ShibSocket ret=m_pool.top();
144     m_pool.pop();
145     m_lock->unlock();
146     return ret;
147 }
148
149 void SocketPool::put(SocketListener::ShibSocket s)
150 {
151     m_lock->lock();
152     m_pool.push(s);
153     m_lock->unlock();
154 }
155
156 SocketListener::SocketListener(const DOMElement* e) : log(&Category::getInstance(SHIBSP_LOGCAT".Listener")),
157     m_socketpool(NULL), m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socket((ShibSocket)0)
158 {
159     // Are we a client?
160     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
161         m_socketpool=new SocketPool(*log,this);
162     }
163     // Are we a server?
164     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
165         m_child_lock = Mutex::create();
166         m_child_wait = CondWait::create();
167     }
168 }
169
170 SocketListener::~SocketListener()
171 {
172     delete m_socketpool;
173     delete m_child_wait;
174     delete m_child_lock;
175 }
176
177 bool SocketListener::run(bool* shutdown)
178 {
179 #ifdef _DEBUG
180     NDC ndc("run");
181 #endif
182     log->info("listener service starting");
183
184     // Save flag to monitor for shutdown request.
185     m_shutdown=shutdown;
186     unsigned long count = 0;
187
188     if (!create(m_socket)) {
189         log->crit("failed to create socket");
190         return false;
191     }
192     if (!bind(m_socket,true)) {
193         this->close(m_socket);
194         log->crit("failed to bind to socket.");
195         return false;
196     }
197
198     while (!*m_shutdown) {
199         fd_set readfds;
200         FD_ZERO(&readfds);
201         FD_SET(m_socket, &readfds);
202         struct timeval tv = { 0, 0 };
203         tv.tv_sec = 5;
204     
205         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
206 #ifdef WIN32
207             case SOCKET_ERROR:
208 #else
209             case -1:
210 #endif
211                 if (errno == EINTR) continue;
212                 log_error();
213                 log->error("select() on main listener socket failed");
214                 return false;
215         
216             case 0:
217                 continue;
218         
219             default:
220             {
221                 // Accept the connection.
222                 SocketListener::ShibSocket newsock;
223                 if (!accept(m_socket, newsock))
224                     log->crit("failed to accept incoming socket connection");
225
226                 // We throw away the result because the children manage themselves...
227                 try {
228                     new ServerThread(newsock,this,++count);
229                 }
230                 catch (...) {
231                     log->crit("error starting new server thread to service incoming request");
232                 }
233             }
234         }
235     }
236     log->info("listener service shutting down");
237
238     // Wait for all children to exit.
239     m_child_lock->lock();
240     while (!m_children.empty())
241         m_child_wait->wait(m_child_lock);
242     m_child_lock->unlock();
243
244     this->close(m_socket);
245     m_socket=(ShibSocket)0;
246     return true;
247 }
248
249 DDF SocketListener::send(const DDF& in)
250 {
251 #ifdef _DEBUG
252     NDC ndc("send");
253 #endif
254
255     log->debug("sending message (%s)", in.name() ? in.name() : "unnamed");
256
257     // Serialize data for transmission.
258     ostringstream os;
259     os << in;
260     string ostr(os.str());
261
262     // Loop on the RPC in case we lost contact the first time through
263 #ifdef WIN32
264     u_long len;
265 #else
266     uint32_t len;
267 #endif
268     int retry = 1;
269     SocketListener::ShibSocket sock;
270     while (retry >= 0) {
271         sock = m_socketpool->get();
272         
273         int outlen = ostr.length();
274         len = htonl(outlen);
275         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
276             log_error();
277             this->close(sock);
278             if (retry)
279                 retry--;
280             else
281                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
282         }
283         else {
284             // SUCCESS.
285             retry = -1;
286         }
287     }
288
289     log->debug("send completed, reading response message");
290
291     // Read the message.
292     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
293         log->error("error reading size of output message");
294         this->close(sock);
295         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
296     }
297     len = ntohl(len);
298     
299     char buf[16384];
300     int size_read;
301     stringstream is;
302     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
303         is.write(buf, size_read);
304         len -= size_read;
305     }
306     
307     if (len) {
308         log->error("error reading output message from socket");
309         this->close(sock);
310         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
311     }
312     
313     m_socketpool->put(sock);
314
315     // Unmarshall data.
316     DDF out;
317     is >> out;
318     
319     // Check for exception to unmarshall and throw, otherwise return.
320     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
321         // Reconstitute exception object.
322         DDFJanitor jout(out);
323         XMLToolingException* except=NULL;
324         try { 
325             except=XMLToolingException::fromString(out.string());
326             log->error("remoted message returned an error: %s", except->what());
327         }
328         catch (XMLToolingException& e) {
329             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
330             log->error("XML was: %s", out.string());
331             throw ListenerException("Remote call failed with an unparsable exception.");
332         }
333
334         auto_ptr<XMLToolingException> wrapper(except);
335         wrapper->raise();
336     }
337
338     return out;
339 }
340
341 bool SocketListener::log_error() const
342 {
343 #ifdef WIN32
344     int rc=WSAGetLastError();
345 #else
346     int rc=errno;
347 #endif
348 #ifdef HAVE_STRERROR_R
349     char buf[256];
350     memset(buf,0,sizeof(buf));
351     strerror_r(rc,buf,sizeof(buf));
352     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
353 #else
354     const char* buf=strerror(rc);
355     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
356 #endif
357     return false;
358 }
359
360 // actual function run in listener on server threads
361 void* server_thread_fn(void* arg)
362 {
363     ServerThread* child = (ServerThread*)arg;
364
365 #ifndef WIN32
366     // First, let's block all signals
367     Thread::mask_all_signals();
368 #endif
369
370     // Run the child until it exits.
371     child->run();
372
373     // Now we can clean up and exit the thread.
374     delete child;
375     return NULL;
376 }
377
378 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
379     : m_sock(s), m_child(NULL), m_listener(listener)
380 {
381
382     ostringstream buf;
383     buf << "[" << id << "]";
384     m_id = buf.str();
385
386     // Create the child thread
387     m_child = Thread::create(server_thread_fn, (void*)this);
388     m_child->detach();
389 }
390
391 ServerThread::~ServerThread()
392 {
393     // Then lock the children map, remove this socket/thread, signal waiters, and return
394     m_listener->m_child_lock->lock();
395     m_listener->m_children.erase(m_sock);
396     m_listener->m_child_lock->unlock();
397     m_listener->m_child_wait->signal();
398   
399     delete m_child;
400 }
401
402 void ServerThread::run()
403 {
404     NDC ndc(m_id);
405
406     // Before starting up, make sure we fully "own" this socket.
407     m_listener->m_child_lock->lock();
408     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
409         m_listener->m_child_wait->wait(m_listener->m_child_lock);
410     m_listener->m_children[m_sock] = m_child;
411     m_listener->m_child_lock->unlock();
412     
413     fd_set readfds;
414     struct timeval tv = { 0, 0 };
415
416     while(!*(m_listener->m_shutdown)) {
417         FD_ZERO(&readfds);
418         FD_SET(m_sock, &readfds);
419         tv.tv_sec = 1;
420
421         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
422 #ifdef WIN32
423         case SOCKET_ERROR:
424 #else
425         case -1:
426 #endif
427             if (errno == EINTR) continue;
428             m_listener->log_error();
429             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
430             return;
431
432         case 0:
433             break;
434
435         default:
436             if (!job()) {
437                 m_listener->log_error();
438                 m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
439                 m_listener->close(m_sock);
440                 return;
441             }
442         }
443     }
444 }
445
446 bool ServerThread::job()
447 {
448     Category& log = Category::getInstance("shibd.Listener");
449
450     bool incomingError = true;  // set false once incoming message is received
451     ostringstream sink;
452 #ifdef WIN32
453     u_long len;
454 #else
455     uint32_t len;
456 #endif
457
458     try {
459         // Read the message.
460         if (m_listener->recv(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
461             log.error("error reading size of input message");
462             return false;
463         }
464         len = ntohl(len);
465         
466         int size_read;
467         stringstream is;
468         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
469             is.write(m_buf, size_read);
470             len -= size_read;
471         }
472         
473         if (len) {
474             log.error("error reading input message from socket");
475             return false;
476         }
477         
478         // Unmarshall the message.
479         DDF in;
480         DDFJanitor jin(in);
481         is >> in;
482
483         log.debug("dispatching message (%s)", in.name() ? in.name() : "unnamed");
484
485         incomingError = false;
486
487         // Dispatch the message.
488         m_listener->receive(in, sink);
489     }
490     catch (XMLToolingException& e) {
491         if (incomingError)
492             log.error("error processing incoming message: %s", e.what());
493         DDF out=DDF("exception").string(e.toString().c_str());
494         DDFJanitor jout(out);
495         sink << out;
496     }
497     catch (exception& e) {
498         if (incomingError)
499             log.error("error processing incoming message: %s", e.what());
500         ListenerException ex(e.what());
501         DDF out=DDF("exception").string(ex.toString().c_str());
502         DDFJanitor jout(out);
503         sink << out;
504     }
505 #ifndef _DEBUG
506     catch (...) {
507         if (incomingError)
508             log.error("unexpected error processing incoming message");
509         ListenerException ex("An unexpected error occurred while processing an incoming message.");
510         DDF out=DDF("exception").string(ex.toString().c_str());
511         DDFJanitor jout(out);
512         sink << out;
513     }
514 #endif
515     
516     // Return whatever's available.
517     string response(sink.str());
518     int outlen = response.length();
519     len = htonl(outlen);
520     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
521         log.error("error sending output message size");
522         return false;
523     }
524     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
525         log.error("error sending output message");
526         return false;
527     }
528     
529     return true;
530 }