Move config logic into an "XML" SP plugin, divorce shibd and modules from old libs.
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based ListenerService implementation
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "remoting/impl/SocketListener.h"
26
27 #include <errno.h>
28 #include <stack>
29 #include <sstream>
30 #include <shibsp/SPConfig.h>
31 #include <xmltooling/util/NDC.h>
32
33 #ifdef HAVE_UNISTD_H
34 # include <unistd.h>
35 #endif
36
37 using namespace shibsp;
38 using namespace xmltooling;
39 using namespace log4cpp;
40 using namespace std;
41 using xercesc::DOMElement;
42
43 namespace shibsp {
44   
45     // Manages the pool of connections
46     class SocketPool
47     {
48     public:
49         SocketPool(Category& log, const SocketListener* listener)
50             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
51         ~SocketPool();
52         SocketListener::ShibSocket get();
53         void put(SocketListener::ShibSocket s);
54   
55     private:
56         SocketListener::ShibSocket connect();
57         
58         const SocketListener* m_listener;
59         Category& m_log;
60         auto_ptr<Mutex> m_lock;
61         stack<SocketListener::ShibSocket> m_pool;
62     };
63   
64     // Worker threads in server
65     class ServerThread {
66     public:
67         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
68         ~ServerThread();
69         void run();
70         bool job();
71
72     private:
73         SocketListener::ShibSocket m_sock;
74         Thread* m_child;
75         SocketListener* m_listener;
76         string m_id;
77         char m_buf[16384];
78     };
79 }
80
81 SocketListener::ShibSocket SocketPool::connect()
82 {
83 #ifdef _DEBUG
84     NDC ndc("connect");
85 #endif
86
87     m_log.debug("trying to connect to listener");
88
89     SocketListener::ShibSocket sock;
90     if (!m_listener->create(sock)) {
91         m_log.error("cannot create socket");
92         throw ListenerException("Cannot create socket");
93     }
94
95     bool connected = false;
96     int num_tries = 3;
97
98     for (int i = num_tries-1; i >= 0; i--) {
99         if (m_listener->connect(sock)) {
100             connected = true;
101             break;
102         }
103     
104         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
105
106         if (i) {
107 #ifdef WIN32
108             Sleep(2000*(num_tries-i));
109 #else
110             sleep(2*(num_tries-i));
111 #endif
112         }
113     }
114
115     if (!connected) {
116         m_log.crit("socket server unavailable, failing");
117         m_listener->close(sock);
118         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
119     }
120
121     m_log.debug("socket (%u) connected successfully", sock);
122     return sock;
123 }
124
125 SocketPool::~SocketPool()
126 {
127     while (!m_pool.empty()) {
128 #ifdef WIN32
129         closesocket(m_pool.top());
130 #else
131         ::close(m_pool.top());
132 #endif
133         m_pool.pop();
134     }
135 }
136
137 SocketListener::ShibSocket SocketPool::get()
138 {
139     m_lock->lock();
140     if (m_pool.empty()) {
141         m_lock->unlock();
142         return connect();
143     }
144     SocketListener::ShibSocket ret=m_pool.top();
145     m_pool.pop();
146     m_lock->unlock();
147     return ret;
148 }
149
150 void SocketPool::put(SocketListener::ShibSocket s)
151 {
152     m_lock->lock();
153     m_pool.push(s);
154     m_lock->unlock();
155 }
156
157 SocketListener::SocketListener(const DOMElement* e) : log(&Category::getInstance(SHIBSP_LOGCAT".Listener")),
158     m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socketpool(NULL), m_socket((ShibSocket)0)
159 {
160     // Are we a client?
161     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
162         m_socketpool=new SocketPool(*log,this);
163     }
164     // Are we a server?
165     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
166         m_child_lock = Mutex::create();
167         m_child_wait = CondWait::create();
168     }
169 }
170
171 SocketListener::~SocketListener()
172 {
173     delete m_socketpool;
174     delete m_child_wait;
175     delete m_child_lock;
176 }
177
178 bool SocketListener::run(bool* shutdown)
179 {
180 #ifdef _DEBUG
181     NDC ndc("run");
182 #endif
183     log->info("listener service starting");
184
185     // Save flag to monitor for shutdown request.
186     m_shutdown=shutdown;
187     unsigned long count = 0;
188
189     if (!create(m_socket)) {
190         log->crit("failed to create socket");
191         return false;
192     }
193     if (!bind(m_socket,true)) {
194         this->close(m_socket);
195         log->crit("failed to bind to socket.");
196         return false;
197     }
198
199     while (!*m_shutdown) {
200         fd_set readfds;
201         FD_ZERO(&readfds);
202         FD_SET(m_socket, &readfds);
203         struct timeval tv = { 0, 0 };
204         tv.tv_sec = 5;
205     
206         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
207 #ifdef WIN32
208             case SOCKET_ERROR:
209 #else
210             case -1:
211 #endif
212                 if (errno == EINTR) continue;
213                 log_error();
214                 log->error("select() on main listener socket failed");
215                 return false;
216         
217             case 0:
218                 continue;
219         
220             default:
221             {
222                 // Accept the connection.
223                 SocketListener::ShibSocket newsock;
224                 if (!accept(m_socket, newsock))
225                     log->crit("failed to accept incoming socket connection");
226
227                 // We throw away the result because the children manage themselves...
228                 try {
229                     new ServerThread(newsock,this,++count);
230                 }
231                 catch (...) {
232                     log->crit("error starting new server thread to service incoming request");
233                 }
234             }
235         }
236     }
237     log->info("listener service shutting down");
238
239     // Wait for all children to exit.
240     m_child_lock->lock();
241     while (!m_children.empty())
242         m_child_wait->wait(m_child_lock);
243     m_child_lock->unlock();
244
245     this->close(m_socket);
246     m_socket=(ShibSocket)0;
247     return true;
248 }
249
250 DDF SocketListener::send(const DDF& in)
251 {
252 #ifdef _DEBUG
253     NDC ndc("send");
254 #endif
255
256     log->debug("sending message: %s", in.name());
257
258     // Serialize data for transmission.
259     ostringstream os;
260     os << in;
261     string ostr(os.str());
262
263     // Loop on the RPC in case we lost contact the first time through
264 #ifdef WIN32
265     u_long len;
266 #else
267     uint32_t len;
268 #endif
269     int retry = 1;
270     SocketListener::ShibSocket sock;
271     while (retry >= 0) {
272         sock = m_socketpool->get();
273         
274         int outlen = ostr.length();
275         len = htonl(outlen);
276         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
277             log_error();
278             this->close(sock);
279             if (retry)
280                 retry--;
281             else
282                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
283         }
284         else {
285             // SUCCESS.
286             retry = -1;
287         }
288     }
289
290     log->debug("send completed, reading response message");
291
292     // Read the message.
293     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
294         log->error("error reading size of output message");
295         this->close(sock);
296         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
297     }
298     len = ntohl(len);
299     
300     char buf[16384];
301     int size_read;
302     stringstream is;
303     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
304         is.write(buf, size_read);
305         len -= size_read;
306     }
307     
308     if (len) {
309         log->error("error reading output message from socket");
310         this->close(sock);
311         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
312     }
313     
314     m_socketpool->put(sock);
315
316     // Unmarshall data.
317     DDF out;
318     is >> out;
319     
320     // Check for exception to unmarshall and throw, otherwise return.
321     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
322         // Reconstitute exception object.
323         DDFJanitor jout(out);
324         XMLToolingException* except=NULL;
325         try { 
326             except=XMLToolingException::fromString(out.string());
327         }
328         catch (XMLToolingException& e) {
329             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
330             log->error("XML was: %s", out.string());
331             throw ListenerException("Remote call failed with an unparsable exception.");
332         }
333
334         auto_ptr<XMLToolingException> wrapper(except);
335         wrapper->raise();
336     }
337
338     return out;
339 }
340
341 bool SocketListener::log_error() const
342 {
343 #ifdef WIN32
344     int rc=WSAGetLastError();
345 #else
346     int rc=errno;
347 #endif
348 #ifdef HAVE_STRERROR_R
349     char buf[256];
350     memset(buf,0,sizeof(buf));
351     strerror_r(rc,buf,sizeof(buf));
352     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
353 #else
354     const char* buf=strerror(rc);
355     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
356 #endif
357     return false;
358 }
359
360 // actual function run in listener on server threads
361 void* server_thread_fn(void* arg)
362 {
363     ServerThread* child = (ServerThread*)arg;
364
365 #ifndef WIN32
366     // First, let's block all signals
367     Thread::mask_all_signals();
368 #endif
369
370     // Run the child until it exits.
371     child->run();
372
373     // Now we can clean up and exit the thread.
374     delete child;
375     return NULL;
376 }
377
378 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
379     : m_sock(s), m_child(NULL), m_listener(listener)
380 {
381
382     ostringstream buf;
383     buf << "[" << id << "]";
384     m_id = buf.str();
385
386     // Create the child thread
387     m_child = Thread::create(server_thread_fn, (void*)this);
388     m_child->detach();
389 }
390
391 ServerThread::~ServerThread()
392 {
393     // Then lock the children map, remove this socket/thread, signal waiters, and return
394     m_listener->m_child_lock->lock();
395     m_listener->m_children.erase(m_sock);
396     m_listener->m_child_lock->unlock();
397     m_listener->m_child_wait->signal();
398   
399     delete m_child;
400 }
401
402 void ServerThread::run()
403 {
404     NDC ndc(m_id);
405
406     // Before starting up, make sure we fully "own" this socket.
407     m_listener->m_child_lock->lock();
408     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
409         m_listener->m_child_wait->wait(m_listener->m_child_lock);
410     m_listener->m_children[m_sock] = m_child;
411     m_listener->m_child_lock->unlock();
412     
413     fd_set readfds;
414     struct timeval tv = { 0, 0 };
415
416     while(!*(m_listener->m_shutdown)) {
417         FD_ZERO(&readfds);
418         FD_SET(m_sock, &readfds);
419         tv.tv_sec = 1;
420
421         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
422 #ifdef WIN32
423         case SOCKET_ERROR:
424 #else
425         case -1:
426 #endif
427             if (errno == EINTR) continue;
428             m_listener->log_error();
429             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
430             return;
431
432         case 0:
433             break;
434
435         default:
436             if (!job()) {
437                 m_listener->log_error();
438                 m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
439                 m_listener->close(m_sock);
440                 return;
441             }
442         }
443     }
444 }
445
446 bool ServerThread::job()
447 {
448     Category& log = Category::getInstance("shibd.Listener");
449
450     ostringstream sink;
451 #ifdef WIN32
452     u_long len;
453 #else
454     uint32_t len;
455 #endif
456
457     try {
458         // Read the message.
459         if (m_listener->recv(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
460             log.error("error reading size of input message");
461             return false;
462         }
463         len = ntohl(len);
464         
465         int size_read;
466         stringstream is;
467         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
468             is.write(m_buf, size_read);
469             len -= size_read;
470         }
471         
472         if (len) {
473             log.error("error reading input message from socket");
474             return false;
475         }
476         
477         // Unmarshall the message.
478         DDF in;
479         DDFJanitor jin(in);
480         is >> in;
481
482         // Dispatch the message.
483         m_listener->receive(in, sink);
484     }
485     catch (XMLToolingException& e) {
486         log.error("error processing incoming message: %s", e.what());
487         DDF out=DDF("exception").string(e.toString().c_str());
488         DDFJanitor jout(out);
489         sink << out;
490     }
491     catch (exception& e) {
492         log.error("error processing incoming message: %s", e.what());
493         ListenerException ex(e.what());
494         DDF out=DDF("exception").string(ex.toString().c_str());
495         DDFJanitor jout(out);
496         sink << out;
497     }
498 #ifndef _DEBUG
499     catch (...) {
500         log.error("unexpected error processing incoming message");
501         ListenerException ex("An unexpected error occurred while processing an incoming message.");
502         DDF out=DDF("exception").string(ex.toString().c_str());
503         DDFJanitor jout(out);
504         sink << out;
505     }
506 #endif
507     
508     // Return whatever's available.
509     string response(sink.str());
510     int outlen = response.length();
511     len = htonl(outlen);
512     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
513         log.error("error sending output message size");
514         return false;
515     }
516     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
517         log.error("error sending output message");
518         return false;
519     }
520     
521     return true;
522 }