Large reorg of shibsp lib, new SPRequest API, ported modules, shifted code out of...
[shibboleth/cpp-sp.git] / shibsp / remoting / impl / SocketListener.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based ListenerService implementation
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "remoting/impl/SocketListener.h"
26
27 #include <errno.h>
28 #include <stack>
29 #include <sstream>
30 #include <shibsp/SPConfig.h>
31 #include <xmltooling/util/NDC.h>
32
33 #ifdef HAVE_UNISTD_H
34 # include <unistd.h>
35 #endif
36
37 using namespace shibsp;
38 using namespace xmltooling;
39 using namespace log4cpp;
40 using namespace std;
41 using xercesc::DOMElement;
42
43 namespace shibsp {
44   
45     // Manages the pool of connections
46     class SocketPool
47     {
48     public:
49         SocketPool(Category& log, const SocketListener* listener)
50             : m_log(log), m_listener(listener), m_lock(Mutex::create()) {}
51         ~SocketPool();
52         SocketListener::ShibSocket get();
53         void put(SocketListener::ShibSocket s);
54   
55     private:
56         SocketListener::ShibSocket connect();
57         
58         const SocketListener* m_listener;
59         Category& m_log;
60         auto_ptr<Mutex> m_lock;
61         stack<SocketListener::ShibSocket> m_pool;
62     };
63   
64     // Worker threads in server
65     class ServerThread {
66     public:
67         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
68         ~ServerThread();
69         void run();
70         bool job();
71
72     private:
73         SocketListener::ShibSocket m_sock;
74         Thread* m_child;
75         SocketListener* m_listener;
76         string m_id;
77         char m_buf[16384];
78     };
79 }
80
81 SocketListener::ShibSocket SocketPool::connect()
82 {
83 #ifdef _DEBUG
84     NDC ndc("connect");
85 #endif
86
87     m_log.debug("trying to connect to listener");
88
89     SocketListener::ShibSocket sock;
90     if (!m_listener->create(sock)) {
91         m_log.error("cannot create socket");
92         throw ListenerException("Cannot create socket");
93     }
94
95     bool connected = false;
96     int num_tries = 3;
97
98     for (int i = num_tries-1; i >= 0; i--) {
99         if (m_listener->connect(sock)) {
100             connected = true;
101             break;
102         }
103     
104         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
105
106         if (i) {
107 #ifdef WIN32
108             Sleep(2000*(num_tries-i));
109 #else
110             sleep(2*(num_tries-i));
111 #endif
112         }
113     }
114
115     if (!connected) {
116         m_log.crit("socket server unavailable, failing");
117         m_listener->close(sock);
118         throw ListenerException("Cannot connect to shibd process, a site adminstrator should be notified.");
119     }
120
121     m_log.debug("socket (%u) connected successfully", sock);
122     return sock;
123 }
124
125 SocketPool::~SocketPool()
126 {
127     while (!m_pool.empty()) {
128 #ifdef WIN32
129         closesocket(m_pool.top());
130 #else
131         ::close(m_pool.top());
132 #endif
133         m_pool.pop();
134     }
135 }
136
137 SocketListener::ShibSocket SocketPool::get()
138 {
139     m_lock->lock();
140     if (m_pool.empty()) {
141         m_lock->unlock();
142         return connect();
143     }
144     SocketListener::ShibSocket ret=m_pool.top();
145     m_pool.pop();
146     m_lock->unlock();
147     return ret;
148 }
149
150 void SocketPool::put(SocketListener::ShibSocket s)
151 {
152     m_lock->lock();
153     m_pool.push(s);
154     m_lock->unlock();
155 }
156
157 SocketListener::SocketListener(const DOMElement* e) : log(&Category::getInstance(SHIBSP_LOGCAT".Listener")),
158     m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socketpool(NULL), m_socket((ShibSocket)0)
159 {
160     // Are we a client?
161     if (SPConfig::getConfig().isEnabled(SPConfig::InProcess)) {
162         m_socketpool=new SocketPool(*log,this);
163     }
164     // Are we a server?
165     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
166         m_child_lock = Mutex::create();
167         m_child_wait = CondWait::create();
168     }
169 }
170
171 SocketListener::~SocketListener()
172 {
173     delete m_socketpool;
174     delete m_child_wait;
175     delete m_child_lock;
176 }
177
178 bool SocketListener::run(bool* shutdown)
179 {
180 #ifdef _DEBUG
181     NDC ndc("run");
182 #endif
183
184     // Save flag to monitor for shutdown request.
185     m_shutdown=shutdown;
186     unsigned long count = 0;
187
188     if (!create(m_socket)) {
189         log->crit("failed to create socket");
190         return false;
191     }
192     if (!bind(m_socket,true)) {
193         this->close(m_socket);
194         log->crit("failed to bind to socket.");
195         return false;
196     }
197
198     while (!*m_shutdown) {
199         fd_set readfds;
200         FD_ZERO(&readfds);
201         FD_SET(m_socket, &readfds);
202         struct timeval tv = { 0, 0 };
203         tv.tv_sec = 5;
204     
205         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
206 #ifdef WIN32
207             case SOCKET_ERROR:
208 #else
209             case -1:
210 #endif
211                 if (errno == EINTR) continue;
212                 log_error();
213                 log->error("select() on main listener socket failed");
214                 return false;
215         
216             case 0:
217                 continue;
218         
219             default:
220             {
221                 // Accept the connection.
222                 SocketListener::ShibSocket newsock;
223                 if (!accept(m_socket, newsock))
224                     log->crit("failed to accept incoming socket connection");
225
226                 // We throw away the result because the children manage themselves...
227                 try {
228                     new ServerThread(newsock,this,++count);
229                 }
230                 catch (...) {
231                     log->crit("error starting new server thread to service incoming request");
232                 }
233             }
234         }
235     }
236     log->info("listener service shutting down");
237
238     // Wait for all children to exit.
239     m_child_lock->lock();
240     while (!m_children.empty())
241         m_child_wait->wait(m_child_lock);
242     m_child_lock->unlock();
243
244     this->close(m_socket);
245     m_socket=(ShibSocket)0;
246     return true;
247 }
248
249 DDF SocketListener::send(const DDF& in)
250 {
251 #ifdef _DEBUG
252     NDC ndc("send");
253 #endif
254
255     log->debug("sending message: %s", in.name());
256
257     // Serialize data for transmission.
258     ostringstream os;
259     os << in;
260     string ostr(os.str());
261
262     // Loop on the RPC in case we lost contact the first time through
263 #ifdef WIN32
264     u_long len;
265 #else
266     uint32_t len;
267 #endif
268     int retry = 1;
269     SocketListener::ShibSocket sock;
270     while (retry >= 0) {
271         sock = m_socketpool->get();
272         
273         int outlen = ostr.length();
274         len = htonl(outlen);
275         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
276             log_error();
277             this->close(sock);
278             if (retry)
279                 retry--;
280             else
281                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
282         }
283         else {
284             // SUCCESS.
285             retry = -1;
286         }
287     }
288
289     log->debug("send completed, reading response message");
290
291     // Read the message.
292     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
293         log->error("error reading size of output message");
294         this->close(sock);
295         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
296     }
297     len = ntohl(len);
298     
299     char buf[16384];
300     int size_read;
301     stringstream is;
302     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
303         is.write(buf, size_read);
304         len -= size_read;
305     }
306     
307     if (len) {
308         log->error("error reading output message from socket");
309         this->close(sock);
310         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
311     }
312     
313     m_socketpool->put(sock);
314
315     // Unmarshall data.
316     DDF out;
317     is >> out;
318     
319     // Check for exception to unmarshall and throw, otherwise return.
320     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
321         // Reconstitute exception object.
322         DDFJanitor jout(out);
323         XMLToolingException* except=NULL;
324         try { 
325             except=XMLToolingException::fromString(out.string());
326         }
327         catch (XMLToolingException& e) {
328             log->error("caught XMLToolingException while building the XMLToolingException: %s", e.what());
329             log->error("XML was: %s", out.string());
330             throw ListenerException("Remote call failed with an unparsable exception.");
331         }
332
333         auto_ptr<XMLToolingException> wrapper(except);
334         wrapper->raise();
335     }
336
337     return out;
338 }
339
340 bool SocketListener::log_error() const
341 {
342 #ifdef WIN32
343     int rc=WSAGetLastError();
344 #else
345     int rc=errno;
346 #endif
347 #ifdef HAVE_STRERROR_R
348     char buf[256];
349     memset(buf,0,sizeof(buf));
350     strerror_r(rc,buf,sizeof(buf));
351     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
352 #else
353     const char* buf=strerror(rc);
354     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
355 #endif
356     return false;
357 }
358
359 // actual function run in listener on server threads
360 void* server_thread_fn(void* arg)
361 {
362     ServerThread* child = (ServerThread*)arg;
363
364 #ifndef WIN32
365     // First, let's block all signals
366     Thread::mask_all_signals();
367 #endif
368
369     // Run the child until it exits.
370     child->run();
371
372     // Now we can clean up and exit the thread.
373     delete child;
374     return NULL;
375 }
376
377 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
378     : m_sock(s), m_child(NULL), m_listener(listener)
379 {
380
381     ostringstream buf;
382     buf << "[" << id << "]";
383     m_id = buf.str();
384
385     // Create the child thread
386     m_child = Thread::create(server_thread_fn, (void*)this);
387     m_child->detach();
388 }
389
390 ServerThread::~ServerThread()
391 {
392     // Then lock the children map, remove this socket/thread, signal waiters, and return
393     m_listener->m_child_lock->lock();
394     m_listener->m_children.erase(m_sock);
395     m_listener->m_child_lock->unlock();
396     m_listener->m_child_wait->signal();
397   
398     delete m_child;
399 }
400
401 void ServerThread::run()
402 {
403     NDC ndc(m_id);
404
405     // Before starting up, make sure we fully "own" this socket.
406     m_listener->m_child_lock->lock();
407     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
408         m_listener->m_child_wait->wait(m_listener->m_child_lock);
409     m_listener->m_children[m_sock] = m_child;
410     m_listener->m_child_lock->unlock();
411     
412     fd_set readfds;
413     struct timeval tv = { 0, 0 };
414
415     while(!*(m_listener->m_shutdown)) {
416         FD_ZERO(&readfds);
417         FD_SET(m_sock, &readfds);
418         tv.tv_sec = 1;
419
420         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
421 #ifdef WIN32
422         case SOCKET_ERROR:
423 #else
424         case -1:
425 #endif
426             if (errno == EINTR) continue;
427             m_listener->log_error();
428             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
429             return;
430
431         case 0:
432             break;
433
434         default:
435             if (!job()) {
436                 m_listener->log_error();
437                 m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
438                 m_listener->close(m_sock);
439                 return;
440             }
441         }
442     }
443 }
444
445 bool ServerThread::job()
446 {
447     Category& log = Category::getInstance("shibd.Listener");
448
449     DDF out;
450     DDFJanitor jout(out);
451 #ifdef WIN32
452     u_long len;
453 #else
454     uint32_t len;
455 #endif
456
457     try {
458         // Read the message.
459         if (m_listener->recv(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
460             log.error("error reading size of input message");
461             return false;
462         }
463         len = ntohl(len);
464         
465         int size_read;
466         stringstream is;
467         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
468             is.write(m_buf, size_read);
469             len -= size_read;
470         }
471         
472         if (len) {
473             log.error("error reading input message from socket");
474             return false;
475         }
476         
477         // Unmarshall the message.
478         DDF in;
479         DDFJanitor jin(in);
480         is >> in;
481
482         // Dispatch the message.
483         out=m_listener->receive(in);
484     }
485     catch (XMLToolingException& e) {
486         log.error("error processing incoming message: %s", e.what());
487         out=DDF("exception").string(e.toString().c_str());
488     }
489     catch (exception& e) {
490         log.error("error processing incoming message: %s", e.what());
491         ListenerException ex(e.what());
492         out=DDF("exception").string(ex.toString().c_str());
493     }
494 #ifndef _DEBUG
495     catch (...) {
496         log.error("unexpected error processing incoming message");
497         ListenerException ex("An unexpected error occurred while processing an incoming message.");
498         out=DDF("exception").string(ex.toString().c_str());
499     }
500 #endif
501     
502     // Return whatever's available.
503     ostringstream xmlout;
504     xmlout << out;
505     string response(xmlout.str());
506     int outlen = response.length();
507     len = htonl(outlen);
508     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
509         log.error("error sending output message size");
510         return false;
511     }
512     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
513         log.error("error sending output message");
514         return false;
515     }
516     
517     return true;
518 }