Replaced RPC remoting with plain sockets and length-prefixed XML.
[shibboleth/sp.git] / shib-target / SocketListener.cpp
similarity index 52%
rename from shib-target/RPCListener.cpp
rename to shib-target/SocketListener.cpp
index f70ab86..72a7337 100644 (file)
  * limitations under the License.
  */
 
-/*
- * RPCListener.cpp -- Handles marshalling and connection mgmt for ONC-remoted IListeners
- *
- * Scott Cantor
- * 5/1/05
- *
+/**
+ * SocketListener.cpp
+ * 
+ * Berkeley Socket-based Listener implementation
  */
 
-#include "RPCListener.h"
-
-// Deal with inadequate Sun RPC libraries
-
-#if !HAVE_DECL_SVCFD_CREATE
-  extern "C" SVCXPRT* svcfd_create(int, u_int, u_int);
-#endif
-
-#ifndef HAVE_WORKING_SVC_DESTROY
-struct tcp_conn {  /* kept in xprt->xp_p1 */
-    enum xprt_stat strm_stat;
-    u_long x_id;
-    XDR xdrs;
-    char verf_body[MAX_AUTH_BYTES];
-};
-#endif
-
-extern "C" void shibrpc_prog_3(struct svc_req* rqstp, register SVCXPRT* transp);
+#include "SocketListener.h"
 
 #include <errno.h>
 #include <sstream>
@@ -55,75 +36,125 @@ using namespace shibboleth;
 using namespace shibtarget;
 
 namespace shibtarget {
-    // Wraps the actual RPC connection
-    class RPCHandle
-    {
-    public:
-        RPCHandle(Category& log);
-        ~RPCHandle();
-
-        CLIENT* connect(const RPCListener* listener);         // connects and returns the CLIENT handle
-        void disconnect(const RPCListener* listener=NULL);    // disconnects, should not return disconnected handles to pool!
-
-    private:
-        Category& m_log;
-        CLIENT* m_clnt;
-        RPCListener::ShibSocket m_sock;
-    };
   
     // Manages the pool of connections
-    class RPCHandlePool
+    class SocketPool
     {
     public:
-        RPCHandlePool(Category& log, const RPCListener* listener)
+        SocketPool(Category& log, const SocketListener* listener)
             : m_log(log), m_listener(listener), m_lock(shibboleth::Mutex::create()) {}
-        ~RPCHandlePool();
-        RPCHandle* get();
-        void put(RPCHandle*);
+        ~SocketPool();
+        SocketListener::ShibSocket get();
+        void put(SocketListener::ShibSocket s);
   
     private:
-        const RPCListener* m_listener;
+        SocketListener::ShibSocket connect();
+        
+        const SocketListener* m_listener;
         Category& m_log;
         auto_ptr<Mutex> m_lock;
-        stack<RPCHandle*> m_pool;
+        stack<SocketListener::ShibSocket> m_pool;
     };
   
-    // Cleans up after use
-    class RPC
-    {
-    public:
-        RPC(RPCHandlePool& pool);
-        ~RPC() {delete m_handle;}
-        RPCHandle* operator->() {return m_handle;}
-        void pool() {if (m_handle) m_pool.put(m_handle); m_handle=NULL;}
-    
-    private:
-        RPCHandle* m_handle;
-        RPCHandlePool& m_pool;
-    };
-    
     // Worker threads in server
     class ServerThread {
     public:
-        ServerThread(RPCListener::ShibSocket& s, RPCListener* listener);
+        ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id);
         ~ServerThread();
         void run();
+        bool job();
 
     private:
-        bool svc_create();
-        RPCListener::ShibSocket m_sock;
+        SocketListener::ShibSocket m_sock;
         Thread* m_child;
-        RPCListener* m_listener;
+        SocketListener* m_listener;
+        string m_id;
+        char m_buf[16384];
     };
 }
 
+SocketListener::ShibSocket SocketPool::connect()
+{
+#ifdef _DEBUG
+    saml::NDC ndc("connect");
+#endif
+
+    m_log.debug("trying to connect to listener");
+
+    SocketListener::ShibSocket sock;
+    if (!m_listener->create(sock)) {
+        m_log.error("cannot create socket");
+        throw ListenerException("Cannot create socket");
+    }
+
+    bool connected = false;
+    int num_tries = 3;
+
+    for (int i = num_tries-1; i >= 0; i--) {
+        if (m_listener->connect(sock)) {
+            connected = true;
+            break;
+        }
+    
+        m_log.warn("cannot connect socket (%u)...%s", sock, (i > 0 ? "retrying" : ""));
+
+        if (i) {
+#ifdef WIN32
+            Sleep(2000*(num_tries-i));
+#else
+            sleep(2*(num_tries-i));
+#endif
+        }
+    }
+
+    if (!connected) {
+        m_log.crit("socket server unavailable, failing");
+        m_listener->close(sock);
+        throw ListenerException("Cannot connect to listener process, a site adminstrator should be notified.");
+    }
+
+    m_log.debug("socket (%u) connected successfully", sock);
+    return sock;
+}
+
+SocketPool::~SocketPool()
+{
+    while (!m_pool.empty()) {
+#ifdef WIN32
+        closesocket(m_pool.top());
+#else
+        ::close(m_pool.top());
+#endif
+        m_pool.pop();
+    }
+}
+
+SocketListener::ShibSocket SocketPool::get()
+{
+    m_lock->lock();
+    if (m_pool.empty()) {
+        m_lock->unlock();
+        return connect();
+    }
+    SocketListener::ShibSocket ret=m_pool.top();
+    m_pool.pop();
+    m_lock->unlock();
+    return ret;
+}
+
+void SocketPool::put(SocketListener::ShibSocket s)
+{
+    m_lock->lock();
+    m_pool.push(s);
+    m_lock->unlock();
+}
 
-RPCListener::RPCListener(const DOMElement* e) : log(&Category::getInstance(SHIBT_LOGCAT".Listener")),
-    m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_rpcpool(NULL), m_socket((ShibSocket)0)
+SocketListener::SocketListener(const DOMElement* e) : log(&Category::getInstance(SHIBT_LOGCAT".Listener")),
+    m_shutdown(NULL), m_child_lock(NULL), m_child_wait(NULL), m_socketpool(NULL), m_socket((ShibSocket)0)
 {
     // Are we a client?
     if (ShibTargetConfig::getConfig().isEnabled(ShibTargetConfig::InProcess)) {
-        m_rpcpool=new RPCHandlePool(*log,this);
+        m_socketpool=new SocketPool(*log,this);
     }
     // Are we a server?
     if (ShibTargetConfig::getConfig().isEnabled(ShibTargetConfig::OutOfProcess)) {
@@ -132,14 +163,14 @@ RPCListener::RPCListener(const DOMElement* e) : log(&Category::getInstance(SHIBT
     }
 }
 
-RPCListener::~RPCListener()
+SocketListener::~SocketListener()
 {
-    delete m_rpcpool;
+    delete m_socketpool;
     delete m_child_wait;
     delete m_child_lock;
 }
 
-bool RPCListener::run(bool* shutdown)
+bool SocketListener::run(bool* shutdown)
 {
 #ifdef _DEBUG
     saml::NDC ndc("run");
@@ -147,6 +178,7 @@ bool RPCListener::run(bool* shutdown)
 
     // Save flag to monitor for shutdown request.
     m_shutdown=shutdown;
+    unsigned long count = 0;
 
     if (!create(m_socket)) {
         log->crit("failed to create socket");
@@ -182,13 +214,13 @@ bool RPCListener::run(bool* shutdown)
             default:
             {
                 // Accept the connection.
-                RPCListener::ShibSocket newsock;
+                SocketListener::ShibSocket newsock;
                 if (!accept(m_socket, newsock))
                     log->crit("failed to accept incoming socket connection");
 
                 // We throw away the result because the children manage themselves...
                 try {
-                    new ServerThread(newsock,this);
+                    new ServerThread(newsock,this,++count);
                 }
                 catch (...) {
                     log->crit("error starting new server thread to service incoming request");
@@ -209,63 +241,76 @@ bool RPCListener::run(bool* shutdown)
     return true;
 }
 
-DDF RPCListener::send(const DDF& in)
+DDF SocketListener::send(const DDF& in)
 {
 #ifdef _DEBUG
     saml::NDC ndc("send");
 #endif
 
+    log->debug("sending message: %s", in.name());
+
     // Serialize data for transmission.
     ostringstream os;
     os << in;
-    shibrpc_args_3 arg;
     string ostr(os.str());
-    arg.xml = const_cast<char*>(ostr.c_str());
-
-    log->debug("sending message: %s", in.name());
-
-    shibrpc_ret_3 ret;
-    memset(&ret, 0, sizeof(ret));
 
     // Loop on the RPC in case we lost contact the first time through
+#ifdef WIN32
+    u_long len;
+#else
+    uint32_t len;
+#endif
     int retry = 1;
-    CLIENT* clnt;
-    RPC rpc(*m_rpcpool);
-    do {
-        clnt = rpc->connect(this);
-        clnt_stat status = shibrpc_call_3(&arg, &ret, clnt);
-        if (status != RPC_SUCCESS) {
-            // FAILED.  Release, disconnect, and retry
-            log->error("RPC Failure: (CLIENT: %p) (%d): %s", clnt, status, clnt_spcreateerror("shibrpc_call_3"));
-            rpc->disconnect(this);
+    SocketListener::ShibSocket sock;
+    while (retry >= 0) {
+        sock = m_socketpool->get();
+        
+        int outlen = ostr.length();
+        len = htonl(outlen);
+        if (send(sock,(char*)&len,sizeof(len)) != sizeof(len) || send(sock,ostr.c_str(),outlen) != outlen) {
+            log_error();
+            this->close(sock);
             if (retry)
                 retry--;
             else
-                throw ListenerException("Failure sending remoted message ($1).",params(1,in.name()));
+                throw ListenerException("Failure sending remoted message ($1).", params(1,in.name()));
         }
         else {
-            // SUCCESS.  Pool and continue
+            // SUCCESS.
             retry = -1;
         }
-    } while (retry>=0);
+    }
 
-    log->debug("call completed, unmarshalling response message");
+    log->debug("send completed, reading response message");
 
-    // Deserialize data.
-    DDF out;
-    try {
-        istringstream is(ret.xml);
-        is >> out;
-        clnt_freeres(clnt, (xdrproc_t)xdr_shibrpc_ret_3, (caddr_t)&ret);
-        rpc.pool();
+    // Read the message.
+    if (recv(sock,(char*)&len,sizeof(len)) != sizeof(len)) {
+        log->error("error reading size of output message");
+        this->close(sock);
+        throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
     }
-    catch (...) {
-        log->error("caught exception while unmarshalling response message");
-        clnt_freeres(clnt, (xdrproc_t)xdr_shibrpc_ret_3, (caddr_t)&ret);
-        rpc.pool();
-        throw;
+    len = ntohl(len);
+    
+    char buf[16384];
+    int size_read;
+    stringstream is;
+    while (len && (size_read = recv(sock, buf, sizeof(buf))) > 0) {
+        is.write(buf, size_read);
+        len -= size_read;
     }
     
+    if (len) {
+        log->error("error reading output message from socket");
+        this->close(sock);
+        throw ListenerException("Failure receiving response to remoted message ($1).", params(1,in.name()));
+    }
+    
+    m_socketpool->put(sock);
+
+    // Unmarshall data.
+    DDF out;
+    is >> out;
+    
     // Check for exception to unmarshall and throw, otherwise return.
     if (out.isstring() && out.name() && !strcmp(out.name(),"exception")) {
         // Reconstitute exception object.
@@ -280,13 +325,7 @@ DDF RPCListener::send(const DDF& in)
             log->error("XML was: %s", out.string());
             throw ListenerException("Remote call failed with an unparsable exception.");
         }
-#ifndef _DEBUG
-        catch (...) {
-            log->error("caught unknown exception building SAMLException");
-            log->error("XML was: %s", out.string());
-            throw;
-        }
-#endif
+
         auto_ptr<SAMLException> wrapper(except);
         wrapper->raise();
     }
@@ -294,7 +333,7 @@ DDF RPCListener::send(const DDF& in)
     return out;
 }
 
-bool RPCListener::log_error() const
+bool SocketListener::log_error() const
 {
 #ifdef WIN32
     int rc=WSAGetLastError();
@@ -313,135 +352,6 @@ bool RPCListener::log_error() const
     return false;
 }
 
-RPCHandle::RPCHandle(Category& log) : m_clnt(NULL), m_sock((RPCListener::ShibSocket)0), m_log(log)
-{
-    m_log.debug("new RPCHandle created: %p", this);
-}
-
-RPCHandle::~RPCHandle()
-{
-    m_log.debug("destroying RPC Handle: %p", this);
-    disconnect();
-}
-
-void RPCHandle::disconnect(const RPCListener* listener)
-{
-    if (m_clnt) {
-        clnt_destroy(m_clnt);
-        m_clnt=NULL;
-        if (listener) {
-            listener->close(m_sock);
-            m_sock=(RPCListener::ShibSocket)0;
-        }
-        else {
-#ifdef WIN32
-            ::closesocket(m_sock);
-#else
-            ::close(m_sock);
-#endif
-            m_sock=(RPCListener::ShibSocket)0;
-        }
-    }
-}
-
-CLIENT* RPCHandle::connect(const RPCListener* listener)
-{
-#ifdef _DEBUG
-    saml::NDC ndc("connect");
-#endif
-    if (m_clnt) {
-        m_log.debug("returning existing connection: %p -> %p", this, m_clnt);
-        return m_clnt;
-    }
-
-    m_log.debug("trying to connect to socket");
-
-    RPCListener::ShibSocket sock;
-    if (!listener->create(sock)) {
-        m_log.error("cannot create socket");
-        throw ListenerException("Cannot create socket");
-    }
-
-    bool connected = false;
-    int num_tries = 3;
-
-    for (int i = num_tries-1; i >= 0; i--) {
-        if (listener->connect(sock)) {
-            connected = true;
-            break;
-        }
-    
-        m_log.warn("cannot connect %p to socket...%s", this, (i > 0 ? "retrying" : ""));
-
-        if (i) {
-#ifdef WIN32
-            Sleep(2000*(num_tries-i));
-#else
-            sleep(2*(num_tries-i));
-#endif
-        }
-    }
-
-    if (!connected) {
-        m_log.crit("socket server unavailable, failing");
-        listener->close(sock);
-        throw ListenerException("Cannot connect to listener process, a site adminstrator should be notified.");
-    }
-
-    CLIENT* clnt = (CLIENT*)listener->getClientHandle(sock, SHIBRPC_PROG, SHIBRPC_VERS_3);
-    if (!clnt) {
-        const char* rpcerror = clnt_spcreateerror("RPCHandle::connect");
-        m_log.crit("RPC failed for %p: %s", this, rpcerror);
-        listener->close(sock);
-        throw ListenerException(rpcerror);
-    }
-
-    // Set the RPC timeout to a fairly high value...
-    struct timeval tv;
-    tv.tv_sec = 300;    /* change timeout to 5 minutes */
-    tv.tv_usec = 0;     /* this should always be set  */
-    clnt_control(clnt, CLSET_TIMEOUT, (char*)&tv);
-
-    m_clnt = clnt;
-    m_sock = sock;
-
-    m_log.debug("success: %p -> %p", this, m_clnt);
-    return m_clnt;
-}
-
-RPCHandlePool::~RPCHandlePool()
-{
-    while (!m_pool.empty()) {
-        delete m_pool.top();
-        m_pool.pop();
-    }
-}
-
-RPCHandle* RPCHandlePool::get()
-{
-    m_lock->lock();
-    if (m_pool.empty()) {
-        m_lock->unlock();
-        return new RPCHandle(m_log);
-    }
-    RPCHandle* ret=m_pool.top();
-    m_pool.pop();
-    m_lock->unlock();
-    return ret;
-}
-
-void RPCHandlePool::put(RPCHandle* handle)
-{
-    m_lock->lock();
-    m_pool.push(handle);
-    m_lock->unlock();
-}
-
-RPC::RPC(RPCHandlePool& pool) : m_pool(pool)
-{
-    m_handle=m_pool.get();
-}
-
 // actual function run in listener on server threads
 void* server_thread_fn(void* arg)
 {
@@ -458,9 +368,14 @@ void* server_thread_fn(void* arg)
     return NULL;
 }
 
-ServerThread::ServerThread(RPCListener::ShibSocket& s, RPCListener* listener)
+ServerThread::ServerThread(SocketListener::ShibSocket& s, SocketListener* listener, unsigned long id)
     : m_sock(s), m_child(NULL), m_listener(listener)
 {
+
+    ostringstream buf;
+    buf << "[" << id << "]";
+    m_id = buf.str();
+
     // Create the child thread
     m_child = Thread::create(server_thread_fn, (void*)this);
     m_child->detach();
@@ -479,6 +394,8 @@ ServerThread::~ServerThread()
 
 void ServerThread::run()
 {
+    saml::NDC ndc(m_id);
+
     // Before starting up, make sure we fully "own" this socket.
     m_listener->m_child_lock->lock();
     while (m_listener->m_children.find(m_sock)!=m_listener->m_children.end())
@@ -486,13 +403,10 @@ void ServerThread::run()
     m_listener->m_children[m_sock] = m_child;
     m_listener->m_child_lock->unlock();
     
-    if (!svc_create())
-        return;
-
     fd_set readfds;
     struct timeval tv = { 0, 0 };
 
-    while(!*(m_listener->m_shutdown) && FD_ISSET(m_sock, &svc_fdset)) {
+    while(!*(m_listener->m_shutdown)) {
         FD_ZERO(&readfds);
         FD_SET(m_sock, &readfds);
         tv.tv_sec = 1;
@@ -512,99 +426,59 @@ void ServerThread::run()
             break;
 
         default:
-            svc_getreqset(&readfds);
-        }
-    }
-}
-
-bool ServerThread::svc_create()
-{
-    /* Wrap an RPC Service around the new connection socket. */
-    SVCXPRT* transp = svcfd_create(m_sock, 0, 0);
-    if (!transp) {
-#ifdef _DEBUG
-        NDC ndc("svc_create");
-#endif
-        m_listener->log->error("failed to wrap RPC service around socket");
-        return false;
-    }
-
-    /* Register the SHIBRPC RPC Program */
-    if (!svc_register (transp, SHIBRPC_PROG, SHIBRPC_VERS_3, shibrpc_prog_3, 0)) {
-#ifdef HAVE_WORKING_SVC_DESTROY
-        svc_destroy(transp);
-#else
-        /* we have to inline svc_destroy because we can't pass in the xprt variable */
-        struct tcp_conn *cd = (struct tcp_conn *)transp->xp_p1;
-        xprt_unregister(transp);
-        close(transp->xp_sock);
-        if (transp->xp_port != 0) {
-            /* a rendezvouser socket */
-            transp->xp_port = 0;
-        }
-        else {
-            /* an actual connection socket */
-            XDR_DESTROY(&(cd->xdrs));
+            if (!job()) {
+                m_listener->log_error();
+                m_listener->log->error("I/O failure processing request on socket (%u)", m_sock);
+                m_listener->close(m_sock);
+                return;
+            }
         }
-        mem_free((caddr_t)cd, sizeof(struct tcp_conn));
-        mem_free((caddr_t)transp, sizeof(SVCXPRT));
-#endif
-#ifdef _DEBUG
-        NDC ndc("svc_create");
-#endif
-        m_listener->log->error("failed to register RPC program");
-        return false;
     }
-
-    return true;
 }
 
-static string get_threadid()
+bool ServerThread::job()
 {
-  static u_long counter = 0;
-  ostringstream buf;
-  buf << "[" << counter++ << "]";
-  return buf.str();
-}
-
-extern "C" bool_t shibrpc_call_3_svc(
-    shibrpc_args_3 *argp,
-    shibrpc_ret_3 *result,
-    struct svc_req *rqstp
-    )
-{
-    string ctx=get_threadid();
-    saml::NDC ndc(ctx);
     Category& log = Category::getInstance("shibd.Listener");
 
-    if (!argp || !result) {
-        log.error("RPC Argument Error");
-        return FALSE;
-    }
-
-    memset(result, 0, sizeof (*result));
-
     DDF out;
     DDFJanitor jout(out);
+#ifdef WIN32
+    u_long len;
+#else
+    uint32_t len;
+#endif
 
     try {
         // Lock the configuration.
         IConfig* conf=ShibTargetConfig::getConfig().getINI();
         Locker locker(conf);
 
-        // Get listener interface.
-        IListener* listener=conf->getListener();
-        if (!listener)
-            throw ListenerException("No listener implementation found to process incoming message.");
+        // Read the message.
+        if (m_listener->recv(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
+            log.error("error reading size of input message");
+            return false;
+        }
+        len = ntohl(len);
         
-        // Unmarshal the message.
+        int size_read;
+        stringstream is;
+        while (len && (size_read = m_listener->recv(m_sock, m_buf, sizeof(m_buf))) > 0) {
+            is.write(m_buf, size_read);
+            len -= size_read;
+        }
+        
+        if (len) {
+            log.error("error reading input message from socket");
+            return false;
+        }
+        
+        // Unmarshall the message.
         DDF in;
         DDFJanitor jin(in);
-        istringstream is(argp->xml);
         is >> in;
 
         // Dispatch the message.
-        out=listener->receive(in);
+        out=m_listener->receive(in);
     }
     catch (SAMLException &e) {
         log.error("error processing incoming message: %s", e.what());
@@ -625,18 +499,17 @@ extern "C" bool_t shibrpc_call_3_svc(
     // Return whatever's available.
     ostringstream xmlout;
     xmlout << out;
-    result->xml=strdup(xmlout.str().c_str());
-    return TRUE;
-}
-
-extern "C" int
-shibrpc_prog_3_freeresult (SVCXPRT *transp, xdrproc_t xdr_result, caddr_t result)
-{
-       xdr_free (xdr_result, result);
-
-       /*
-        * Insert additional freeing code here, if needed
-        */
-
-       return 1;
+    string response(xmlout.str());
+    int outlen = response.length();
+    len = htonl(outlen);
+    if (m_listener->send(m_sock,(char*)&len,sizeof(len)) != sizeof(len)) {
+        log.error("error sending output message size");
+        return false;
+    }
+    if (m_listener->send(m_sock,response.c_str(),outlen) != outlen) {
+        log.error("error sending output message");
+        return false;
+    }
+    
+    return true;
 }