https://issues.shibboleth.net/jira/browse/SSPCPP-515
[shibboleth/sp.git] / shibsp / impl / StorageServiceSessionCache.cpp
index f1af3f6..9e5d48d 100644 (file)
@@ -75,7 +75,16 @@ using namespace xmltooling;
 using namespace boost;
 using namespace std;
 
-namespace shibsp {
+namespace {
+
+    // Allows the cache to bind sessions to multiple client address
+    // families based on whatever this function returns.
+    static const char* getAddressFamily(const char* addr) {
+        if (strchr(addr, ':'))
+            return "6";
+        else
+            return "4";
+    }
 
     class StoredSession;
     class SSCache : public SessionCacheEx
@@ -133,7 +142,9 @@ namespace shibsp {
             const set<string>* indexes,
             time_t expires,
             vector<string>& sessions
-            );
+            ) {
+            return _logout(app, issuer, nameid, indexes, expires, sessions, 0);
+        }
         bool matches(
             const Application& app,
             const HTTPRequest& request,
@@ -203,7 +214,16 @@ namespace shibsp {
     private:
 #ifndef SHIBSP_LITE
         // maintain back-mappings of NameID/SessionIndex -> session key
-        void insert(const char* key, time_t expires, const char* name, const char* index);
+        void insert(const char* key, time_t expires, const char* name, const char* index, short attempts=0);
+        vector<string>::size_type _logout(
+            const Application& app,
+            const EntityDescriptor* issuer,
+            const saml2::NameID& nameid,
+            const set<string>* indexes,
+            time_t expires,
+            vector<string>& sessions,
+            short attempts
+            );
         bool stronglyMatches(const XMLCh* idp, const XMLCh* sp, const saml2::NameID& n1, const saml2::NameID& n2) const;
         LogoutEvent* newLogoutEvent(const Application& app) const;
 
@@ -236,10 +256,7 @@ namespace shibsp {
                 const char* saddr = m_obj["client_addr"].string();
                 DDF addrobj = m_obj["client_addr"].structure();
                 if (saddr && *saddr) {
-                    if (strchr(saddr, ':'))
-                        addrobj.addmember("6").string(saddr);
-                    else
-                        addrobj.addmember("4").string(saddr);
+                    addrobj.addmember(getAddressFamily(saddr)).string(saddr);
                 }
             }
 
@@ -293,20 +310,16 @@ namespace shibsp {
             return m_obj["client_addr"].first().string();
         }
 
-        const char* getClientAddressV4() const {
-            return m_obj["client_addr"]["4"].string();
-        }
-        const char* getClientAddressV6() const {
-            return m_obj["client_addr"]["6"].string();
+        const char* getClientAddress(const char* family) const {
+            if (family)
+                return m_obj["client_addr"][family].string();
+            return nullptr;
         }
         void setClientAddress(const char* client_addr) {
             DDF obj = m_obj["client_addr"];
             if (!obj.isstruct())
                 obj = m_obj.addmember("client_addr").structure();
-            if (strchr(client_addr, ':'))
-                obj.addmember("6").string(client_addr);
-            else
-                obj.addmember("4").string(client_addr);
+            obj.addmember(getAddressFamily(client_addr)).string(client_addr);
         }
 
         const char* getEntityID() const {
@@ -448,11 +461,7 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
 
     // Address check?
     if (client_addr) {
-        const char* saddr = nullptr;
-        if (strchr(client_addr, ':'))
-            saddr = getClientAddressV6();
-        else
-            saddr = getClientAddressV4();
+        const char* saddr = getClientAddress(getAddressFamily(client_addr));
         if (saddr && *saddr) {
             if (!XMLString::equals(saddr, client_addr)) {
                 m_cache->m_log.warn("client address mismatch, client (%s), session (%s)", client_addr, saddr);
@@ -560,14 +569,11 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
 
         // We may need to write back a new address into the session using a versioned update loop.
         if (client_addr) {
+            short attempts = 0;
             do {
-                const char* saddr = nullptr;
-                if (strchr(client_addr, ':'))
-                    saddr = getClientAddressV6();
-                else
-                    saddr = getClientAddressV4();
-                // Something snuck in and bound the session to this address type, so it better match what we have.
+                const char* saddr = getClientAddress(getAddressFamily(client_addr));
                 if (saddr) {
+                    // Something snuck in and bound the session to this address type, so it better match what we have.
                     if (!XMLString::equals(saddr, client_addr)) {
                         m_cache->m_log.warn("client address mismatch, client (%s), session (%s)", client_addr, saddr);
                         throw RetryableProfileException(
@@ -608,6 +614,10 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
                 }
                 else if (ver < 0) {
                     // Out of sync.
+                    if (++attempts > 10) {
+                        m_cache->m_log.error("failed to bind client address, update attempts exceeded limit");
+                        throw IOException("Unable to update stored session, exceeded retry limit.");
+                    }
                     m_cache->m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
                     ver = m_cache->m_storage->readText(getID(), "session", &record, nullptr);
                     if (!ver) {
@@ -654,6 +664,7 @@ void StoredSession::addAttributes(const vector<Attribute*>& attributes)
     m_cache->m_log.debug("adding attributes to session (%s)", getID());
 
     int ver;
+    short attempts = 0;
     do {
         DDF attr;
         DDF attrs = m_obj["attributes"];
@@ -696,6 +707,10 @@ void StoredSession::addAttributes(const vector<Attribute*>& attributes)
         }
         else if (ver < 0) {
             // Out of sync.
+            if (++attempts > 10) {
+                m_cache->m_log.error("failed to update stored session, update attempts exceeded limit");
+                throw IOException("Unable to update stored session, exceeded retry limit.");
+            }
             m_cache->m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
             ver = m_cache->m_storage->readText(getID(), "session", &record, nullptr);
             if (!ver) {
@@ -781,6 +796,7 @@ void StoredSession::addAssertion(Assertion* assertion)
         throw IOException("Attempted to insert duplicate assertion ID into session.");
 
     int ver;
+    short attempts = 0;
     do {
         DDF token = DDF(nullptr).string(id.get());
         m_obj["assertions"].add(token);
@@ -814,6 +830,10 @@ void StoredSession::addAssertion(Assertion* assertion)
         }
         else if (ver < 0) {
             // Out of sync.
+            if (++attempts > 10) {
+                m_cache->m_log.error("failed to update stored session, update attempts exceeded limit");
+                throw IOException("Unable to update stored session, exceeded retry limit.");
+            }
             m_cache->m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
             ver = m_cache->m_storage->readText(getID(), "session", &record, nullptr);
             if (!ver) {
@@ -926,7 +946,7 @@ SSCache::SSCache(const DOMElement* e)
             if (m_storage)
                 m_log.info("bound to StorageService (%s)", ssid.c_str());
             else
-                m_log.warn("specified StorageService (%s) not found", ssid.c_str());
+                throw ConfigurationException("SessionCache unable to locate StorageService ($1), check configuration.", params(1, ssid.c_str()));
         }
         if (!m_storage) {
             m_storage = conf.getServiceProvider()->getStorageService(nullptr);
@@ -942,7 +962,7 @@ SSCache::SSCache(const DOMElement* e)
             if (m_storage_lite)
                 m_log.info("bound to 'lite' StorageService (%s)", ssid.c_str());
             else
-                m_log.warn("specified 'lite' StorageService (%s) not found", ssid.c_str());
+                throw ConfigurationException("SessionCache unable to locate 'lite' StorageService ($1), check configuration.", params(1, ssid.c_str()));
         }
         if (!m_storage_lite) {
             m_log.info("StorageService for 'lite' use not set, using standard StorageService");
@@ -1018,8 +1038,11 @@ void SSCache::test()
     m_storage->deleteString("SessionCacheTest", temp.get());
 }
 
-void SSCache::insert(const char* key, time_t expires, const char* name, const char* index)
+void SSCache::insert(const char* key, time_t expires, const char* name, const char* index, short attempts)
 {
+    if (attempts > 10)
+        throw IOException("Exceeded retry limit.");
+
     string dup;
     unsigned int storageLimit = m_storage_lite->getCapabilities().getKeySize();
     if (strlen(name) > storageLimit) {
@@ -1061,12 +1084,12 @@ void SSCache::insert(const char* key, time_t expires, const char* name, const ch
         ver = m_storage_lite->updateText("NameID", name, out.str().c_str(), max(expires, recordexp), ver);
         if (ver <= 0) {
             // Out of sync, or went missing, so retry.
-            return insert(key, expires, name, index);
+            return insert(key, expires, name, index, attempts + 1);
         }
     }
     else if (!m_storage_lite->createText("NameID", name, out.str().c_str(), expires)) {
         // Hit a dup, so just retry, hopefully hitting the other branch.
-        return insert(key, expires, name, index);
+        return insert(key, expires, name, index, attempts + 1);
     }
 }
 
@@ -1153,10 +1176,7 @@ void SSCache::insert(
     string caddr(httpRequest.getRemoteAddr());
     if (!caddr.empty()) {
         DDF addrobj = obj.addmember("client_addr").structure();
-        if (caddr.find(':') != string::npos)
-            addrobj.addmember("6").string(caddr.c_str());
-        else
-            addrobj.addmember("4").string(caddr.c_str());
+        addrobj.addmember(getAddressFamily(caddr.c_str())).string(caddr.c_str());
     }
 
     if (issuer)
@@ -1293,13 +1313,14 @@ bool SSCache::matches(
     return false;
 }
 
-vector<string>::size_type SSCache::logout(
+vector<string>::size_type SSCache::_logout(
     const Application& app,
     const saml2md::EntityDescriptor* issuer,
     const saml2::NameID& nameid,
     const set<string>* indexes,
     time_t expires,
-    vector<string>& sessionsKilled
+    vector<string>& sessionsKilled,
+    short attempts
     )
 {
 #ifdef _DEBUG
@@ -1308,6 +1329,8 @@ vector<string>::size_type SSCache::logout(
 
     if (!m_storage)
         throw ConfigurationException("SessionCache logout requires a StorageService.");
+    else if (attempts > 10)
+        throw IOException("Exceeded retry limit.");
 
     auto_ptr_char entityID(issuer ? issuer->getEntityID() : nullptr);
     auto_ptr_char name(nameid.getName());
@@ -1363,12 +1386,12 @@ vector<string>::size_type SSCache::logout(
             ver = m_storage_lite->updateText("Logout", name.get(), lout.str().c_str(), max(expires, oldexp), ver);
             if (ver <= 0) {
                 // Out of sync, or went missing, so retry.
-                return logout(app, issuer, nameid, indexes, expires, sessionsKilled);
+                return _logout(app, issuer, nameid, indexes, expires, sessionsKilled, attempts + 1);
             }
         }
         else if (!m_storage_lite->createText("Logout", name.get(), lout.str().c_str(), expires)) {
             // Hit a dup, so just retry, hopefully hitting the other branch.
-            return logout(app, issuer, nameid, indexes, expires, sessionsKilled);
+            return _logout(app, issuer, nameid, indexes, expires, sessionsKilled, attempts + 1);
         }
 
         obj.destroy();
@@ -1996,6 +2019,7 @@ void SSCache::receive(DDF& in, ostream& out)
 
         // We may need to write back a new address into the session using a versioned update loop.
         if (client_addr) {
+            short attempts = 0;
             m_log.info("binding session (%s) to new client address (%s)", key, client_addr);
             do {
                 // We have to reconstitute the session object ourselves.
@@ -2004,11 +2028,7 @@ void SSCache::receive(DDF& in, ostream& out)
                 istringstream src(record);
                 src >> sessionobj;
                 ver = sessionobj["version"].integer();
-                const char* saddr = nullptr;
-                if (strchr(client_addr, ':'))
-                    saddr = sessionobj["client_addr"]["6"].string();
-                else
-                    saddr = sessionobj["client_addr"]["4"].string();
+                const char* saddr = sessionobj["client_addr"][getAddressFamily(client_addr)].string();
                 if (saddr) {
                     // Something snuck in and bound the session to this address type, so it better match what we have.
                     if (!XMLString::equals(saddr, client_addr)) {
@@ -2022,10 +2042,7 @@ void SSCache::receive(DDF& in, ostream& out)
                 }
                 else {
                     // Bind it into the session.
-                    if (strchr(client_addr, ':'))
-                        sessionobj["client_addr"].addmember("6").string(client_addr);
-                    else
-                        sessionobj["client_addr"].addmember("4").string(client_addr);
+                    sessionobj["client_addr"].addmember(getAddressFamily(client_addr)).string(client_addr);
                 }
 
                 // Tentatively increment the version.
@@ -2043,6 +2060,10 @@ void SSCache::receive(DDF& in, ostream& out)
                 }
                 if (ver < 0) {
                     // Out of sync.
+                    if (++attempts > 10) {
+                        m_log.error("failed to bind client address, update attempts exceeded limit");
+                        throw IOException("Unable to update stored session, exceeded retry limit.");
+                    }
                     m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
                     sessionobj["version"].integer(sessionobj["version"].integer() - 1);
                     ver = m_storage->readText(key, "session", &record);