Replaced RPC remoting with plain sockets and length-prefixed XML.
[shibboleth/cpp-sp.git] / shib-target / SocketListener.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SocketListener.cpp
19  * 
20  * Berkeley Socket-based Listener implementation
21  */
22
23 #include "SocketListener.h"
24
25 #include <errno.h>
26 #include <sstream>
27
28 #ifdef HAVE_UNISTD_H
29 # include <unistd.h>
30 #endif
31
32 using namespace std;
33 using namespace log4cpp;
34 using namespace saml;
35 using namespace shibboleth;
36 using namespace shibtarget;
37
38 namespace shibtarget {
39   
40     // Manages the pool of connections
41     class SocketPool
42     {
43     public:
44         SocketPool(Category& log, const SocketListener* listener)
45             : m_log(log), m_listener(listener), m_lock(shibboleth::Mutex::create()) {}
46         ~SocketPool();
47         SocketListener::ShibSocket get();
48         void put(SocketListener::ShibSocket s);
49   
50     private:
51         SocketListener::ShibSocket connect();
52         
53         const SocketListener* m_listener;
54         Category& m_log;
55         auto_ptr<Mutex> m_lock;
56         stack<SocketListener::ShibSocket> m_pool;
57     };
58   
59     // Worker threads in server
60     class ServerThread {
61     public:
62         ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
63         ~ServerThread();
64         void run();
65         bool job();
66
67     private:
68         SocketListener::ShibSocket m_sock;
69         Thread* m_child;
70         SocketListener* m_listener;
71         string m_id;
72         char m_buf[16384];
73     };
74 }
75
76 SocketListener::ShibSocket SocketPool::connect()
77 {
78 #ifdef _DEBUG
79     saml::NDC ndc("connect");
80 #endif
81
82     m_log.debug("trying to connect to listener");
83
84     SocketListener::ShibSocket sock;
85     if (!m_listener->create(sock)) {
86         m_log.error("cannot create socket");
87         throw ListenerException("Cannot create socket");
88     }
89
90     bool connected = false;
91     int num_tries = 3;
92
93     for (int i = num_tries-1; i >= 0; i--) {
94         if (m_listener->connect(sock)) {
95             connected = true;
96             break;
97         }
98     
99         m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
100
101         if (i) {
102 #ifdef WIN32
103             Sleep(2000*(num_tries-i));
104 #else
105             sleep(2*(num_tries-i));
106 #endif
107         }
108     }
109
110     if (!connected) {
111         m_log.crit("socket server unavailable, failing");
112         m_listener->close(sock);
113         throw ListenerException("Cannot connect to listener process, a site adminstrator should be notified.");
114     }
115
116     m_log.debug("socket (%u) connected successfully", sock);
117     return sock;
118 }
119
120 SocketPool::~SocketPool()
121 {
122     while (!m_pool.empty()) {
123 #ifdef WIN32
124         closesocket(m_pool.top());
125 #else
126         ::close(m_pool.top());
127 #endif
128         m_pool.pop();
129     }
130 }
131
132 SocketListener::ShibSocket SocketPool::get()
133 {
134     m_lock->lock();
135     if (m_pool.empty()) {
136         m_lock->unlock();
137         return connect();
138     }
139     SocketListener::ShibSocket ret=m_pool.top();
140     m_pool.pop();
141     m_lock->unlock();
142     return ret;
143 }
144
145 void SocketPool::put(SocketListener::ShibSocket s)
146 {
147     m_lock->lock();
148     m_pool.push(s);
149     m_lock->unlock();
150 }
151
152 SocketListener::SocketListener(const DOMElement* e) : log(&Category::getInstance(SHIBT_LOGCAT".Listener")),
153     m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socketpool(NULL), m_socket((ShibSocket)0)
154 {
155     // Are we a client?
156     if (ShibTargetConfig::getConfig().isEnabled(ShibTargetConfig::InProcess)) {
157         m_socketpool=new SocketPool(*log,this);
158     }
159     // Are we a server?
160     if (ShibTargetConfig::getConfig().isEnabled(ShibTargetConfig::OutOfProcess)) {
161         m_child_lock = Mutex::create();
162         m_child_wait = CondWait::create();
163     }
164 }
165
166 SocketListener::~SocketListener()
167 {
168     delete m_socketpool;
169     delete m_child_wait;
170     delete m_child_lock;
171 }
172
173 bool SocketListener::run(bool* shutdown)
174 {
175 #ifdef _DEBUG
176     saml::NDC ndc("run");
177 #endif
178
179     // Save flag to monitor for shutdown request.
180     m_shutdown=shutdown;
181     unsigned long count = 0;
182
183     if (!create(m_socket)) {
184         log->crit("failed to create socket");
185         return false;
186     }
187     if (!bind(m_socket,true)) {
188         this->close(m_socket);
189         log->crit("failed to bind to socket.");
190         return false;
191     }
192
193     while (!*m_shutdown) {
194         fd_set readfds;
195         FD_ZERO(&readfds);
196         FD_SET(m_socket, &readfds);
197         struct timeval tv = { 0, 0 };
198         tv.tv_sec = 5;
199     
200         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
201 #ifdef WIN32
202             case SOCKET_ERROR:
203 #else
204             case -1:
205 #endif
206                 if (errno == EINTR) continue;
207                 log_error();
208                 log->error("select() on main listener socket failed");
209                 return false;
210         
211             case 0:
212                 continue;
213         
214             default:
215             {
216                 // Accept the connection.
217                 SocketListener::ShibSocket newsock;
218                 if (!accept(m_socket, newsock))
219                     log->crit("failed to accept incoming socket connection");
220
221                 // We throw away the result because the children manage themselves...
222                 try {
223                     new ServerThread(newsock,this,++count);
224                 }
225                 catch (...) {
226                     log->crit("error starting new server thread to service incoming request");
227                 }
228             }
229         }
230     }
231     log->info("listener service shutting down");
232
233     // Wait for all children to exit.
234     m_child_lock->lock();
235     while (!m_children.empty())
236         m_child_wait->wait(m_child_lock);
237     m_child_lock->unlock();
238
239     this->close(m_socket);
240     m_socket=(ShibSocket)0;
241     return true;
242 }
243
244 DDF SocketListener::send(const DDF& in)
245 {
246 #ifdef _DEBUG
247     saml::NDC ndc("send");
248 #endif
249
250     log->debug("sending message: %s", in.name());
251
252     // Serialize data for transmission.
253     ostringstream os;
254     os << in;
255     string ostr(os.str());
256
257     // Loop on the RPC in case we lost contact the first time through
258 #ifdef WIN32
259     u_long len;
260 #else
261     uint32_t len;
262 #endif
263     int retry = 1;
264     SocketListener::ShibSocket sock;
265     while (retry >= 0) {
266         sock = m_socketpool->get();
267         
268         int outlen = ostr.length();
269         len = htonl(outlen);
270         if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
271             log_error();
272             this->close(sock);
273             if (retry)
274                 retry--;
275             else
276                 throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
277         }
278         else {
279             // SUCCESS.
280             retry = -1;
281         }
282     }
283
284     log->debug("send completed, reading response message");
285
286     // Read the message.
287     if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
288         log->error("error reading size of output message");
289         this->close(sock);
290         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
291     }
292     len = ntohl(len);
293     
294     char buf[16384];
295     int size_read;
296     stringstream is;
297     while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
298         is.write(buf, size_read);
299         len -= size_read;
300     }
301     
302     if (len) {
303         log->error("error reading output message from socket");
304         this->close(sock);
305         throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
306     }
307     
308     m_socketpool->put(sock);
309
310     // Unmarshall data.
311     DDF out;
312     is >> out;
313     
314     // Check for exception to unmarshall and throw, otherwise return.
315     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
316         // Reconstitute exception object.
317         DDFJanitor jout(out);
318         SAMLException* except=NULL;
319         try { 
320             istringstream es(out.string());
321             except=SAMLException::getInstance(es);
322         }
323         catch (SAMLException& e) {
324             log->error("caught SAML Exception while building the SAMLException: %s", e.what());
325             log->error("XML was: %s", out.string());
326             throw ListenerException("Remote call failed with an unparsable exception.");
327         }
328
329         auto_ptr<SAMLException> wrapper(except);
330         wrapper->raise();
331     }
332
333     return out;
334 }
335
336 bool SocketListener::log_error() const
337 {
338 #ifdef WIN32
339     int rc=WSAGetLastError();
340 #else
341     int rc=errno;
342 #endif
343 #ifdef HAVE_STRERROR_R
344     char buf[256];
345     memset(buf,0,sizeof(buf));
346     strerror_r(rc,buf,sizeof(buf));
347     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
348 #else
349     const char* buf=strerror(rc);
350     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
351 #endif
352     return false;
353 }
354
355 // actual function run in listener on server threads
356 void* server_thread_fn(void* arg)
357 {
358     ServerThread* child = (ServerThread*)arg;
359
360     // First, let's block all signals
361     Thread::mask_all_signals();
362
363     // Run the child until it exits.
364     child->run();
365
366     // Now we can clean up and exit the thread.
367     delete child;
368     return NULL;
369 }
370
371 ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
372     : m_sock(s), m_child(NULL), m_listener(listener)
373 {
374
375     ostringstream buf;
376     buf << "[" << id << "]";
377     m_id = buf.str();
378
379     // Create the child thread
380     m_child = Thread::create(server_thread_fn, (void*)this);
381     m_child->detach();
382 }
383
384 ServerThread::~ServerThread()
385 {
386     // Then lock the children map, remove this socket/thread, signal waiters, and return
387     m_listener->m_child_lock->lock();
388     m_listener->m_children.erase(m_sock);
389     m_listener->m_child_lock->unlock();
390     m_listener->m_child_wait->signal();
391   
392     delete m_child;
393 }
394
395 void ServerThread::run()
396 {
397     saml::NDC ndc(m_id);
398
399     // Before starting up, make sure we fully "own" this socket.
400     m_listener->m_child_lock->lock();
401     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
402         m_listener->m_child_wait->wait(m_listener->m_child_lock);
403     m_listener->m_children[m_sock] = m_child;
404     m_listener->m_child_lock->unlock();
405     
406     fd_set readfds;
407     struct timeval tv = { 0, 0 };
408
409     while(!*(m_listener->m_shutdown)) {
410         FD_ZERO(&readfds);
411         FD_SET(m_sock, &readfds);
412         tv.tv_sec = 1;
413
414         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
415 #ifdef WIN32
416         case SOCKET_ERROR:
417 #else
418         case -1:
419 #endif
420             if (errno == EINTR) continue;
421             m_listener->log_error();
422             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
423             return;
424
425         case 0:
426             break;
427
428         default:
429             if (!job()) {
430                 m_listener->log_error();
431                 m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
432                 m_listener->close(m_sock);
433                 return;
434             }
435         }
436     }
437 }
438
439 bool ServerThread::job()
440 {
441     Category& log = Category::getInstance("shibd.Listener");
442
443     DDF out;
444     DDFJanitor jout(out);
445 #ifdef WIN32
446     u_long len;
447 #else
448     uint32_t len;
449 #endif
450
451     try {
452         // Lock the configuration.
453         IConfig* conf=ShibTargetConfig::getConfig().getINI();
454         Locker locker(conf);
455
456         // Read the message.
457         if (m_listener->recv(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
458             log.error("error reading size of input message");
459             return false;
460         }
461         len = ntohl(len);
462         
463         int size_read;
464         stringstream is;
465         while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
466             is.write(m_buf, size_read);
467             len -= size_read;
468         }
469         
470         if (len) {
471             log.error("error reading input message from socket");
472             return false;
473         }
474         
475         // Unmarshall the message.
476         DDF in;
477         DDFJanitor jin(in);
478         is >> in;
479
480         // Dispatch the message.
481         out=m_listener->receive(in);
482     }
483     catch (SAMLException &e) {
484         log.error("error processing incoming message: %s", e.what());
485         ostringstream os;
486         os << e;
487         out=DDF("exception").string(os.str().c_str());
488     }
489 #ifndef _DEBUG
490     catch (...) {
491         log.error("unexpected error processing incoming message");
492         ListenerException ex("An unexpected error occurred while processing an incoming message.");
493         ostringstream os;
494         os << ex;
495         out=DDF("exception").string(os.str().c_str());
496     }
497 #endif
498     
499     // Return whatever's available.
500     ostringstream xmlout;
501     xmlout << out;
502     string response(xmlout.str());
503     int outlen = response.length();
504     len = htonl(outlen);
505     if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
506         log.error("error sending output message size");
507         return false;
508     }
509     if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
510         log.error("error sending output message");
511         return false;
512     }
513     
514     return true;
515 }