Merge up from branch.
[shibboleth/cpp-sp.git] / shib-target / RPCListener.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * RPCListener.cpp -- Handles marshalling and connection mgmt for ONC-remoted IListeners
19  *
20  * Scott Cantor
21  * 5/1/05
22  *
23  */
24
25 #include "RPCListener.h"
26
27 // Deal with inadequate Sun RPC libraries
28
29 #if !HAVE_DECL_SVCFD_CREATE
30   extern "C" SVCXPRT* svcfd_create(int, u_int, u_int);
31 #endif
32
33 #ifndef HAVE_WORKING_SVC_DESTROY
34 struct tcp_conn {  /* kept in xprt->xp_p1 */
35     enum xprt_stat strm_stat;
36     u_long x_id;
37     XDR xdrs;
38     char verf_body[MAX_AUTH_BYTES];
39 };
40 #endif
41
42 extern "C" void shibrpc_prog_3(struct svc_req* rqstp, register SVCXPRT* transp);
43
44 #include <errno.h>
45 #include <sstream>
46
47 #ifdef HAVE_UNISTD_H
48 # include <unistd.h>
49 #endif
50
51 using namespace std;
52 using namespace log4cpp;
53 using namespace saml;
54 using namespace shibboleth;
55 using namespace shibtarget;
56
57 namespace shibtarget {
58     // Wraps the actual RPC connection
59     class RPCHandle
60     {
61     public:
62         RPCHandle(Category& log);
63         ~RPCHandle();
64
65         CLIENT* connect(const RPCListener* listener);         // connects and returns the CLIENT handle
66         void disconnect(const RPCListener* listener=NULL);    // disconnects, should not return disconnected handles to pool!
67
68     private:
69         Category& m_log;
70         CLIENT* m_clnt;
71         RPCListener::ShibSocket m_sock;
72     };
73   
74     // Manages the pool of connections
75     class RPCHandlePool
76     {
77     public:
78         RPCHandlePool(Category& log, const RPCListener* listener)
79             : m_log(log), m_listener(listener), m_lock(shibboleth::Mutex::create()) {}
80         ~RPCHandlePool();
81         RPCHandle* get();
82         void put(RPCHandle*);
83   
84     private:
85         const RPCListener* m_listener;
86         Category& m_log;
87         auto_ptr<Mutex> m_lock;
88         stack<RPCHandle*> m_pool;
89     };
90   
91     // Cleans up after use
92     class RPC
93     {
94     public:
95         RPC(RPCHandlePool& pool);
96         ~RPC() {delete m_handle;}
97         RPCHandle* operator->() {return m_handle;}
98         void pool() {if (m_handle) m_pool.put(m_handle); m_handle=NULL;}
99     
100     private:
101         RPCHandle* m_handle;
102         RPCHandlePool& m_pool;
103     };
104     
105     // Worker threads in server
106     class ServerThread {
107     public:
108         ServerThread(RPCListener::ShibSocket& s, RPCListener* listener);
109         ~ServerThread();
110         void run();
111
112     private:
113         bool svc_create();
114         RPCListener::ShibSocket m_sock;
115         Thread* m_child;
116         RPCListener* m_listener;
117     };
118 }
119
120
121 RPCListener::RPCListener(const DOMElement* e) : log(&Category::getInstance(SHIBT_LOGCAT".Listener")),
122     m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_rpcpool(NULL), m_socket((ShibSocket)0)
123 {
124     // Are we a client?
125     if (ShibTargetConfig::getConfig().isEnabled(ShibTargetConfig::InProcess)) {
126         m_rpcpool=new RPCHandlePool(*log,this);
127     }
128     // Are we a server?
129     if (ShibTargetConfig::getConfig().isEnabled(ShibTargetConfig::OutOfProcess)) {
130         m_child_lock = Mutex::create();
131         m_child_wait = CondWait::create();
132     }
133 }
134
135 RPCListener::~RPCListener()
136 {
137     delete m_rpcpool;
138     delete m_child_wait;
139     delete m_child_lock;
140 }
141
142 bool RPCListener::run(bool* shutdown)
143 {
144 #ifdef _DEBUG
145     saml::NDC ndc("run");
146 #endif
147
148     // Save flag to monitor for shutdown request.
149     m_shutdown=shutdown;
150
151     if (!create(m_socket)) {
152         log->crit("failed to create socket");
153         return false;
154     }
155     if (!bind(m_socket,true)) {
156         this->close(m_socket);
157         log->crit("failed to bind to socket.");
158         return false;
159     }
160
161     while (!*m_shutdown) {
162         fd_set readfds;
163         FD_ZERO(&readfds);
164         FD_SET(m_socket, &readfds);
165         struct timeval tv = { 0, 0 };
166         tv.tv_sec = 5;
167     
168         switch (select(m_socket + 1, &readfds, 0, 0, &tv)) {
169 #ifdef WIN32
170             case SOCKET_ERROR:
171 #else
172             case -1:
173 #endif
174                 if (errno == EINTR) continue;
175                 log_error();
176                 log->error("select() on main listener socket failed");
177                 return false;
178         
179             case 0:
180                 continue;
181         
182             default:
183             {
184                 // Accept the connection.
185                 RPCListener::ShibSocket newsock;
186                 if (!accept(m_socket, newsock))
187                     log->crit("failed to accept incoming socket connection");
188
189                 // We throw away the result because the children manage themselves...
190                 try {
191                     new ServerThread(newsock,this);
192                 }
193                 catch (...) {
194                     log->crit("error starting new server thread to service incoming request");
195                 }
196             }
197         }
198     }
199     log->info("listener service shutting down");
200
201     // Wait for all children to exit.
202     m_child_lock->lock();
203     while (!m_children.empty())
204         m_child_wait->wait(m_child_lock);
205     m_child_lock->unlock();
206
207     this->close(m_socket);
208     m_socket=(ShibSocket)0;
209     return true;
210 }
211
212 DDF RPCListener::send(const DDF& in)
213 {
214 #ifdef _DEBUG
215     saml::NDC ndc("send");
216 #endif
217
218     // Serialize data for transmission.
219     ostringstream os;
220     os << in;
221     shibrpc_args_3 arg;
222     string ostr(os.str());
223     arg.xml = const_cast<char*>(ostr.c_str());
224
225     log->debug("sending message: %s", in.name());
226
227     shibrpc_ret_3 ret;
228     memset(&ret, 0, sizeof(ret));
229
230     // Loop on the RPC in case we lost contact the first time through
231     int retry = 1;
232     CLIENT* clnt;
233     RPC rpc(*m_rpcpool);
234     do {
235         clnt = rpc->connect(this);
236         clnt_stat status = shibrpc_call_3(&arg, &ret, clnt);
237         if (status != RPC_SUCCESS) {
238             // FAILED.  Release, disconnect, and retry
239             log->error("RPC Failure: (CLIENT: %p) (%d): %s", clnt, status, clnt_spcreateerror("shibrpc_call_3"));
240             rpc->disconnect(this);
241             if (retry)
242                 retry--;
243             else
244                 throw ListenerException("Failure sending remoted message ($1).",params(1,in.name()));
245         }
246         else {
247             // SUCCESS.  Pool and continue
248             retry = -1;
249         }
250     } while (retry>=0);
251
252     log->debug("call completed, unmarshalling response message");
253
254     // Deserialize data.
255     DDF out;
256     try {
257         istringstream is(ret.xml);
258         is >> out;
259         clnt_freeres(clnt, (xdrproc_t)xdr_shibrpc_ret_3, (caddr_t)&ret);
260         rpc.pool();
261     }
262     catch (...) {
263         log->error("caught exception while unmarshalling response message");
264         clnt_freeres(clnt, (xdrproc_t)xdr_shibrpc_ret_3, (caddr_t)&ret);
265         rpc.pool();
266         throw;
267     }
268     
269     // Check for exception to unmarshall and throw, otherwise return.
270     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
271         // Reconstitute exception object.
272         DDFJanitor jout(out);
273         SAMLException* except=NULL;
274         try { 
275             istringstream es(out.string());
276             except=SAMLException::getInstance(es);
277         }
278         catch (SAMLException& e) {
279             log->error("caught SAML Exception while building the SAMLException: %s", e.what());
280             log->error("XML was: %s", out.string());
281             throw ListenerException("Remote call failed with an unparsable exception.");
282         }
283 #ifndef _DEBUG
284         catch (...) {
285             log->error("caught unknown exception building SAMLException");
286             log->error("XML was: %s", out.string());
287             throw;
288         }
289 #endif
290         auto_ptr<SAMLException> wrapper(except);
291         wrapper->raise();
292     }
293
294     return out;
295 }
296
297 bool RPCListener::log_error() const
298 {
299 #ifdef WIN32
300     int rc=WSAGetLastError();
301 #else
302     int rc=errno;
303 #endif
304 #ifdef HAVE_STRERROR_R
305     char buf[256];
306     memset(buf,0,sizeof(buf));
307     strerror_r(rc,buf,sizeof(buf));
308     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
309 #else
310     const char* buf=strerror(rc);
311     log->error("socket call resulted in error (%d): %s",rc,isprint(*buf) ? buf : "no message");
312 #endif
313     return false;
314 }
315
316 RPCHandle::RPCHandle(Category& log) : m_clnt(NULL), m_sock((RPCListener::ShibSocket)0), m_log(log)
317 {
318     m_log.debug("new RPCHandle created: %p", this);
319 }
320
321 RPCHandle::~RPCHandle()
322 {
323     m_log.debug("destroying RPC Handle: %p", this);
324     disconnect();
325 }
326
327 void RPCHandle::disconnect(const RPCListener* listener)
328 {
329     if (m_clnt) {
330         clnt_destroy(m_clnt);
331         m_clnt=NULL;
332         if (listener) {
333             listener->close(m_sock);
334             m_sock=(RPCListener::ShibSocket)0;
335         }
336         else {
337 #ifdef WIN32
338             ::closesocket(m_sock);
339 #else
340             ::close(m_sock);
341 #endif
342             m_sock=(RPCListener::ShibSocket)0;
343         }
344     }
345 }
346
347 CLIENT* RPCHandle::connect(const RPCListener* listener)
348 {
349 #ifdef _DEBUG
350     saml::NDC ndc("connect");
351 #endif
352     if (m_clnt) {
353         m_log.debug("returning existing connection: %p -> %p", this, m_clnt);
354         return m_clnt;
355     }
356
357     m_log.debug("trying to connect to socket");
358
359     RPCListener::ShibSocket sock;
360     if (!listener->create(sock)) {
361         m_log.error("cannot create socket");
362         throw ListenerException("Cannot create socket");
363     }
364
365     bool connected = false;
366     int num_tries = 3;
367
368     for (int i = num_tries-1; i >= 0; i--) {
369         if (listener->connect(sock)) {
370             connected = true;
371             break;
372         }
373     
374         m_log.warn("cannot connect %p to socket...%s", this, (i > 0 ? "retrying" : ""));
375
376         if (i) {
377 #ifdef WIN32
378             Sleep(2000*(num_tries-i));
379 #else
380             sleep(2*(num_tries-i));
381 #endif
382         }
383     }
384
385     if (!connected) {
386         m_log.crit("socket server unavailable, failing");
387         listener->close(sock);
388         throw ListenerException("Cannot connect to listener process, a site adminstrator should be notified.");
389     }
390
391     CLIENT* clnt = (CLIENT*)listener->getClientHandle(sock, SHIBRPC_PROG, SHIBRPC_VERS_3);
392     if (!clnt) {
393         const char* rpcerror = clnt_spcreateerror("RPCHandle::connect");
394         m_log.crit("RPC failed for %p: %s", this, rpcerror);
395         listener->close(sock);
396         throw ListenerException(rpcerror);
397     }
398
399     // Set the RPC timeout to a fairly high value...
400     struct timeval tv;
401     tv.tv_sec = 300;    /* change timeout to 5 minutes */
402     tv.tv_usec = 0;     /* this should always be set  */
403     clnt_control(clnt, CLSET_TIMEOUT, (char*)&tv);
404
405     m_clnt = clnt;
406     m_sock = sock;
407
408     m_log.debug("success: %p -> %p", this, m_clnt);
409     return m_clnt;
410 }
411
412 RPCHandlePool::~RPCHandlePool()
413 {
414     while (!m_pool.empty()) {
415         delete m_pool.top();
416         m_pool.pop();
417     }
418 }
419
420 RPCHandle* RPCHandlePool::get()
421 {
422     m_lock->lock();
423     if (m_pool.empty()) {
424         m_lock->unlock();
425         return new RPCHandle(m_log);
426     }
427     RPCHandle* ret=m_pool.top();
428     m_pool.pop();
429     m_lock->unlock();
430     return ret;
431 }
432
433 void RPCHandlePool::put(RPCHandle* handle)
434 {
435     m_lock->lock();
436     m_pool.push(handle);
437     m_lock->unlock();
438 }
439
440 RPC::RPC(RPCHandlePool& pool) : m_pool(pool)
441 {
442     m_handle=m_pool.get();
443 }
444
445 // actual function run in listener on server threads
446 void* server_thread_fn(void* arg)
447 {
448     ServerThread* child = (ServerThread*)arg;
449
450     // First, let's block all signals
451     Thread::mask_all_signals();
452
453     // Run the child until it exits.
454     child->run();
455
456     // Now we can clean up and exit the thread.
457     delete child;
458     return NULL;
459 }
460
461 ServerThread::ServerThread(RPCListener::ShibSocket& s, RPCListener* listener)
462     : m_sock(s), m_child(NULL), m_listener(listener)
463 {
464     // Create the child thread
465     m_child = Thread::create(server_thread_fn, (void*)this);
466     m_child->detach();
467 }
468
469 ServerThread::~ServerThread()
470 {
471     // Then lock the children map, remove this socket/thread, signal waiters, and return
472     m_listener->m_child_lock->lock();
473     m_listener->m_children.erase(m_sock);
474     m_listener->m_child_lock->unlock();
475     m_listener->m_child_wait->signal();
476   
477     delete m_child;
478 }
479
480 void ServerThread::run()
481 {
482     // Before starting up, make sure we fully "own" this socket.
483     m_listener->m_child_lock->lock();
484     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
485         m_listener->m_child_wait->wait(m_listener->m_child_lock);
486     m_listener->m_children[m_sock] = m_child;
487     m_listener->m_child_lock->unlock();
488     
489     if (!svc_create())
490         return;
491
492     fd_set readfds;
493     struct timeval tv = { 0, 0 };
494
495     while(!*(m_listener->m_shutdown) && FD_ISSET(m_sock, &svc_fdset)) {
496         FD_ZERO(&readfds);
497         FD_SET(m_sock, &readfds);
498         tv.tv_sec = 1;
499
500         switch (select(m_sock+1, &readfds, 0, 0, &tv)) {
501 #ifdef WIN32
502         case SOCKET_ERROR:
503 #else
504         case -1:
505 #endif
506             if (errno == EINTR) continue;
507             m_listener->log_error();
508             m_listener->log->error("select() on incoming request socket (%u) returned error", m_sock);
509             return;
510
511         case 0:
512             break;
513
514         default:
515             svc_getreqset(&readfds);
516         }
517     }
518 }
519
520 bool ServerThread::svc_create()
521 {
522     /* Wrap an RPC Service around the new connection socket. */
523     SVCXPRT* transp = svcfd_create(m_sock, 0, 0);
524     if (!transp) {
525 #ifdef _DEBUG
526         NDC ndc("svc_create");
527 #endif
528         m_listener->log->error("failed to wrap RPC service around socket");
529         return false;
530     }
531
532     /* Register the SHIBRPC RPC Program */
533     if (!svc_register (transp, SHIBRPC_PROG, SHIBRPC_VERS_3, shibrpc_prog_3, 0)) {
534 #ifdef HAVE_WORKING_SVC_DESTROY
535         svc_destroy(transp);
536 #else
537         /* we have to inline svc_destroy because we can't pass in the xprt variable */
538         struct tcp_conn *cd = (struct tcp_conn *)transp->xp_p1;
539         xprt_unregister(transp);
540         close(transp->xp_sock);
541         if (transp->xp_port != 0) {
542             /* a rendezvouser socket */
543             transp->xp_port = 0;
544         }
545         else {
546             /* an actual connection socket */
547             XDR_DESTROY(&(cd->xdrs));
548         }
549         mem_free((caddr_t)cd, sizeof(struct tcp_conn));
550         mem_free((caddr_t)transp, sizeof(SVCXPRT));
551 #endif
552 #ifdef _DEBUG
553         NDC ndc("svc_create");
554 #endif
555         m_listener->log->error("failed to register RPC program");
556         return false;
557     }
558
559     return true;
560 }
561
562 static string get_threadid()
563 {
564   static u_long counter = 0;
565   ostringstream buf;
566   buf << "[" << counter++ << "]";
567   return buf.str();
568 }
569
570 extern "C" bool_t shibrpc_call_3_svc(
571     shibrpc_args_3 *argp,
572     shibrpc_ret_3 *result,
573     struct svc_req *rqstp
574     )
575 {
576     string ctx=get_threadid();
577     saml::NDC ndc(ctx);
578     Category& log = Category::getInstance("shibd.Listener");
579
580     if (!argp || !result) {
581         log.error("RPC Argument Error");
582         return FALSE;
583     }
584
585     memset(result, 0, sizeof (*result));
586
587     DDF out;
588     DDFJanitor jout(out);
589
590     try {
591         // Lock the configuration.
592         IConfig* conf=ShibTargetConfig::getConfig().getINI();
593         Locker locker(conf);
594
595         // Get listener interface.
596         IListener* listener=conf->getListener();
597         if (!listener)
598             throw ListenerException("No listener implementation found to process incoming message.");
599         
600         // Unmarshal the message.
601         DDF in;
602         DDFJanitor jin(in);
603         istringstream is(argp->xml);
604         is >> in;
605
606         // Dispatch the message.
607         out=listener->receive(in);
608     }
609     catch (SAMLException &e) {
610         log.error("error processing incoming message: %s", e.what());
611         ostringstream os;
612         os << e;
613         out=DDF("exception").string(os.str().c_str());
614     }
615 #ifndef _DEBUG
616     catch (...) {
617         log.error("unexpected error processing incoming message");
618         ListenerException ex("An unexpected error occurred while processing an incoming message.");
619         ostringstream os;
620         os << ex;
621         out=DDF("exception").string(os.str().c_str());
622     }
623 #endif
624     
625     // Return whatever's available.
626     ostringstream xmlout;
627     xmlout << out;
628     result->xml=strdup(xmlout.str().c_str());
629     return TRUE;
630 }
631
632 extern "C" int
633 shibrpc_prog_3_freeresult (SVCXPRT *transp, xdrproc_t xdr_result, caddr_t result)
634 {
635         xdr_free (xdr_result, result);
636
637         /*
638          * Insert additional freeing code here, if needed
639          */
640
641         return 1;
642 }