Add protocol to log.
[shibboleth/sp.git] / shibsp / impl / StorageServiceSessionCache.cpp
index 8c1d921..b687d33 100644 (file)
@@ -32,7 +32,7 @@
 #include "Application.h"
 #include "exceptions.h"
 #include "ServiceProvider.h"
-#include "SessionCache.h"
+#include "SessionCacheEx.h"
 #include "TransactionLog.h"
 #include "attribute/Attribute.h"
 #include "remoting/ListenerService.h"
@@ -59,7 +59,7 @@ using namespace std;
 namespace shibsp {
 
     class StoredSession;
-    class SSCache : public SessionCache
+    class SSCache : public SessionCacheEx
 #ifndef SHIBSP_LITE
         ,public virtual Remoted
 #endif
@@ -71,10 +71,11 @@ namespace shibsp {
 #ifndef SHIBSP_LITE
         void receive(DDF& in, ostream& out);
 
-        string insert(
-            time_t expires,
+        void insert(
             const Application& application,
-            const char* client_addr=NULL,
+            const HTTPRequest& httpRequest,
+            HTTPResponse& httpResponse,
+            time_t expires,
             const saml2md::EntityDescriptor* issuer=NULL,
             const XMLCh* protocol=NULL,
             const saml2::NameID* nameid=NULL,
@@ -86,25 +87,48 @@ namespace shibsp {
             const vector<Attribute*>* attributes=NULL
             );
         vector<string>::size_type logout(
+            const Application& application,
             const saml2md::EntityDescriptor* issuer,
             const saml2::NameID& nameid,
             const set<string>* indexes,
             time_t expires,
-            const Application& application,
             vector<string>& sessions
             );
         bool matches(
-            const char* key,
+            const Application& application,
+            const xmltooling::HTTPRequest& request,
             const saml2md::EntityDescriptor* issuer,
             const saml2::NameID& nameid,
-            const set<string>* indexes,
-            const Application& application
+            const set<string>* indexes
             );
 #endif
-        Session* find(const char* key, const Application& application, const char* client_addr=NULL, time_t* timeout=NULL);
-        void remove(const char* key, const Application& application);
+        Session* find(const Application& application, const char* key, const char* client_addr=NULL, time_t* timeout=NULL);
+        void remove(const Application& application, const char* key);
         void test();
 
+        string active(const Application& application, const xmltooling::HTTPRequest& request) {
+            pair<string,const char*> shib_cookie = application.getCookieNameProps("_shibsession_");
+            const char* session_id = request.getCookie(shib_cookie.first.c_str());
+            return (session_id ? session_id : "");
+        }
+
+        Session* find(const Application& application, const HTTPRequest& request, const char* client_addr=NULL, time_t* timeout=NULL) {
+            string id = active(application, request);
+            if (!id.empty())
+                return find(application, id.c_str(), client_addr, timeout);
+            return NULL;
+        }
+
+        void remove(const Application& application, const HTTPRequest& request, HTTPResponse* response=NULL) {
+            pair<string,const char*> shib_cookie = application.getCookieNameProps("_shibsession_");
+            const char* session_id = request.getCookie(shib_cookie.first.c_str());
+            if (session_id && *session_id) {
+                if (response)
+                    response->setCookie(shib_cookie.first.c_str(), shib_cookie.second);
+                remove(application, session_id);
+            }
+        }
+
         void cleanup();
 
         Category& m_log;
@@ -285,10 +309,10 @@ namespace shibsp {
     }
 }
 
-void SHIBSP_API shibsp::registerSessionCaches()\r
-{\r
-    SPConfig::getConfig().SessionCacheManager.registerFactory(STORAGESERVICE_SESSION_CACHE, StorageServiceCacheFactory);\r
-}\r
+void SHIBSP_API shibsp::registerSessionCaches()
+{
+    SPConfig::getConfig().SessionCacheManager.registerFactory(STORAGESERVICE_SESSION_CACHE, StorageServiceCacheFactory);
+}
 
 void StoredSession::unmarshallAttributes() const
 {
@@ -656,27 +680,27 @@ SSCache::SSCache(const DOMElement* e)
 #endif
         m_root(e), m_inprocTimeout(900), m_lock(NULL), shutdown(false), shutdown_wait(NULL), cleanup_thread(NULL)
 {
-    static const XMLCh cacheTimeout[] =     UNICODE_LITERAL_12(c,a,c,h,e,T,i,m,e,o,u,t);\r
-    static const XMLCh inprocTimeout[] =    UNICODE_LITERAL_13(i,n,p,r,o,c,T,i,m,e,o,u,t);\r
+    static const XMLCh cacheTimeout[] =     UNICODE_LITERAL_12(c,a,c,h,e,T,i,m,e,o,u,t);
+    static const XMLCh inprocTimeout[] =    UNICODE_LITERAL_13(i,n,p,r,o,c,T,i,m,e,o,u,t);
     static const XMLCh _StorageService[] =  UNICODE_LITERAL_14(S,t,o,r,a,g,e,S,e,r,v,i,c,e);
 
     SPConfig& conf = SPConfig::getConfig();
     inproc = conf.isEnabled(SPConfig::InProcess);
 
     if (e) {
-        const XMLCh* tag=e->getAttributeNS(NULL,cacheTimeout);\r
-        if (tag && *tag) {\r
-            m_cacheTimeout = XMLString::parseInt(tag);\r
-            if (!m_cacheTimeout)\r
-                m_cacheTimeout=3600;\r
-        }\r
+        const XMLCh* tag=e->getAttributeNS(NULL,cacheTimeout);
+        if (tag && *tag) {
+            m_cacheTimeout = XMLString::parseInt(tag);
+            if (!m_cacheTimeout)
+                m_cacheTimeout=3600;
+        }
         if (inproc) {
-            const XMLCh* tag=e->getAttributeNS(NULL,inprocTimeout);\r
-            if (tag && *tag) {\r
-                m_inprocTimeout = XMLString::parseInt(tag);\r
-                if (!m_inprocTimeout)\r
-                    m_inprocTimeout=900;\r
-            }\r
+            const XMLCh* tag=e->getAttributeNS(NULL,inprocTimeout);
+            if (tag && *tag) {
+                m_inprocTimeout = XMLString::parseInt(tag);
+                if (!m_inprocTimeout)
+                    m_inprocTimeout=900;
+            }
         }
     }
 
@@ -801,10 +825,11 @@ void SSCache::insert(const char* key, time_t expires, const char* name, const ch
     }
 }
 
-string SSCache::insert(
-    time_t expires,
+void SSCache::insert(
     const Application& application,
-    const char* client_addr,
+    const HTTPRequest& httpRequest,
+    HTTPResponse& httpResponse,
+    time_t expires,
     const saml2md::EntityDescriptor* issuer,
     const XMLCh* protocol,
     const saml2::NameID* nameid,
@@ -841,7 +866,7 @@ string SSCache::insert(
             istringstream pstr(pending);
             pstr >> pendobj;
             // IdP.SP.index contains logout expiration, if any.
-            DDF deadmenwalking = pendobj[issuer ? entity_id.get() : "_shibnull"][application.getString("entityID").second];
+            DDF deadmenwalking = pendobj[issuer ? entity_id.get() : "_shibnull"][application.getRelyingParty(issuer)->getString("entityID").second];
             const char* logexpstr = deadmenwalking[session_index ? index.get() : "_shibnull"].string();
             if (!logexpstr && session_index)    // we tried an exact session match, now try for NULL
                 logexpstr = deadmenwalking["_shibnull"].string();
@@ -874,8 +899,7 @@ string SSCache::insert(
     strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
     obj.addmember("expires").string(timebuf);
 
-    if (client_addr)
-        obj.addmember("client_addr").string(client_addr);
+    obj.addmember("client_addr").string(httpRequest.getRemoteAddr().c_str());
     if (issuer)
         obj.addmember("entity_id").string(entity_id.get());
     if (protocol) {
@@ -953,7 +977,9 @@ string SSCache::insert(
     }
 
     const char* pid = obj["entity_id"].string();
-    m_log.info("new session created: SessionID (%s) IdP (%s) Address (%s)", key.get(), pid ? pid : "none", client_addr);
+    const char* prot = obj["protocol"].string();
+    m_log.info("new session created: ID (%s) IdP (%s) Protocol(%s) Address (%s)",
+        key.get(), pid ? pid : "none", prot ? prot : "none", httpRequest.getRemoteAddr().c_str());
 
     // Transaction Logging
     TransactionLog* xlog = application.getServiceProvider().getTransactionLog();
@@ -966,9 +992,11 @@ string SSCache::insert(
         ") for principal from (IdP: " <<
             (pid ? pid : "none") <<
         ") at (ClientAddress: " <<
-            (client_addr ? client_addr : "none") <<
+            httpRequest.getRemoteAddr() <<
         ") with (NameIdentifier: " <<
             (nameid ? name.get() : "none") <<
+        ") using (Protocol: " <<
+            (prot ? prot : "none") <<
         ")";
     
     if (attributes) {
@@ -983,40 +1011,43 @@ string SSCache::insert(
         xlog->log.info("}");
     }
 
-    return key.get();
+    pair<string,const char*> shib_cookie = application.getCookieNameProps("_shibsession_");
+    string k(key.get());
+    k += shib_cookie.second;
+    httpResponse.setCookie(shib_cookie.first.c_str(), k.c_str());
 }
 
 bool SSCache::matches(
-    const char* key,
+    const Application& application,
+    const xmltooling::HTTPRequest& request,
     const saml2md::EntityDescriptor* issuer,
     const saml2::NameID& nameid,
-    const set<string>* indexes,
-    const Application& application
+    const set<string>* indexes
     )
 {
     auto_ptr_char entityID(issuer ? issuer->getEntityID() : NULL);
     try {
-        Session* session = find(key, application);
+        Session* session = find(application, request);
         if (session) {
             Locker locker(session, false);
             if (XMLString::equals(session->getEntityID(), entityID.get()) && session->getNameID() &&
-                    stronglyMatches(issuer->getEntityID(), application.getXMLString("entityID").second, nameid, *session->getNameID())) {
+                    stronglyMatches(issuer->getEntityID(), application.getRelyingParty(issuer)->getXMLString("entityID").second, nameid, *session->getNameID())) {
                 return (!indexes || indexes->empty() || (session->getSessionIndex() ? (indexes->count(session->getSessionIndex())>0) : false));
             }
         }
     }
     catch (exception& ex) {
-        m_log.error("error while matching session (%s): %s", key, ex.what());
+        m_log.error("error while matching session: %s", ex.what());
     }
     return false;
 }
 
 vector<string>::size_type SSCache::logout(
+    const Application& application,
     const saml2md::EntityDescriptor* issuer,
     const saml2::NameID& nameid,
     const set<string>* indexes,
     time_t expires,
-    const Application& application,
     vector<string>& sessionsKilled
     )
 {
@@ -1063,7 +1094,7 @@ vector<string>::size_type SSCache::logout(
         }
 
         // Structure is keyed by the IdP and SP, with a member per session index containing the expiration.
-        DDF root = obj.addmember(issuer ? entityID.get() : "_shibnull").addmember(application.getString("entityID").second);
+        DDF root = obj.addmember(issuer ? entityID.get() : "_shibnull").addmember(application.getRelyingParty(issuer)->getString("entityID").second);
         if (indexes) {
             for (set<string>::const_iterator x = indexes->begin(); x!=indexes->end(); ++x)
                 root.addmember(x->c_str()).string(timebuf);
@@ -1080,12 +1111,12 @@ vector<string>::size_type SSCache::logout(
             ver = m_storage->updateText("Logout", name.get(), lout.str().c_str(), max(expires, oldexp), ver);
             if (ver <= 0) {
                 // Out of sync, or went missing, so retry.
-                return logout(issuer, nameid, indexes, expires, application, sessionsKilled);
+                return logout(application, issuer, nameid, indexes, expires, sessionsKilled);
             }
         }
         else if (!m_storage->createText("Logout", name.get(), lout.str().c_str(), expires)) {
             // Hit a dup, so just retry, hopefully hitting the other branch.
-            return logout(issuer, nameid, indexes, expires, application, sessionsKilled);
+            return logout(application, issuer, nameid, indexes, expires, sessionsKilled);
         }
 
         obj.destroy();
@@ -1112,7 +1143,7 @@ vector<string>::size_type SSCache::logout(
                 // Fetch the session for comparison.
                 Session* session = NULL;
                 try {
-                    session = find(key.string(), application);
+                    session = find(application, key.string());
                 }
                 catch (exception& ex) {
                     m_log.error("error locating session (%s): %s", key.string(), ex.what());
@@ -1123,7 +1154,7 @@ vector<string>::size_type SSCache::logout(
                     // Same issuer?
                     if (XMLString::equals(session->getEntityID(), entityID.get())) {
                         // Same NameID?
-                        if (stronglyMatches(issuer->getEntityID(), application.getXMLString("entityID").second, nameid, *session->getNameID())) {
+                        if (stronglyMatches(issuer->getEntityID(), application.getRelyingParty(issuer)->getXMLString("entityID").second, nameid, *session->getNameID())) {
                             sessionsKilled.push_back(key.string());
                             key.destroy();
                         }
@@ -1209,7 +1240,7 @@ bool SSCache::stronglyMatches(const XMLCh* idp, const XMLCh* sp, const saml2::Na
 
 #endif
 
-Session* SSCache::find(const char* key, const Application& application, const char* client_addr, time_t* timeout)
+Session* SSCache::find(const Application& application, const char* key, const char* client_addr, time_t* timeout)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("find");
@@ -1297,18 +1328,15 @@ Session* SSCache::find(const char* key, const Application& application, const ch
             
             if (timeout && *timeout > 0 && now - lastAccess >= *timeout) {
                 m_log.info("session timed out (ID: %s)", key);
-                remove(key, application);
-                RetryableProfileException ex("Your session has expired, and you must re-authenticate.");
+                remove(application, key);
                 const char* eid = obj["entity_id"].string();
                 if (!eid) {
                     obj.destroy();
-                    throw ex;
+                    throw RetryableProfileException("Your session has expired, and you must re-authenticate.");
                 }
                 string eid2(eid);
                 obj.destroy();
-                MetadataProvider* m=application.getMetadataProvider();
-                Locker locker(m);
-                annotateException(&ex,m->getEntityDescriptor(MetadataProvider::Criteria(eid2.c_str(),NULL,NULL,false)).first); // throws it
+                throw RetryableProfileException("Your session has expired, and you must re-authenticate.", namedparams(1, "entityID", eid2.c_str()));
             }
             
             if (timeout) {
@@ -1359,14 +1387,14 @@ Session* SSCache::find(const char* key, const Application& application, const ch
     }
     catch (...) {
         session->unlock();
-        remove(key, application);
+        remove(application, key);
         throw;
     }
     
     return session;
 }
 
-void SSCache::remove(const char* key, const Application& application)
+void SSCache::remove(const Application& application, const char* key)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("remove");
@@ -1554,7 +1582,7 @@ void SSCache::receive(DDF& in, ostream& out)
                     
             if (timeout > 0 && now - lastAccess >= timeout) {
                 m_log.info("session timed out (ID: %s)", key);
-                remove(key,*app);
+                remove(*app, key);
                 throw RetryableProfileException("Your session has expired, and you must re-authenticate.");
             } 
 
@@ -1630,7 +1658,7 @@ void SSCache::receive(DDF& in, ostream& out)
         if (!app)
             throw ConfigurationException("Application not found, check configuration?");
 
-        remove(key,*app);
+        remove(*app, key);
         DDF ret(NULL);
         DDFJanitor jan(ret);
         out << ret;