Add a redundant safety check to insert
[shibboleth/cpp-sp.git] / shibsp / impl / StorageServiceSessionCache.cpp
index 20c02d0..0c5ffe8 100644 (file)
@@ -75,7 +75,16 @@ using namespace xmltooling;
 using namespace boost;
 using namespace std;
 
-namespace shibsp {
+namespace {
+
+    // Allows the cache to bind sessions to multiple client address
+    // families based on whatever this function returns.
+    static const char* getAddressFamily(const char* addr) {
+        if (strchr(addr, ':'))
+            return "6";
+        else
+            return "4";
+    }
 
     class StoredSession;
     class SSCache : public SessionCacheEx
@@ -133,7 +142,9 @@ namespace shibsp {
             const set<string>* indexes,
             time_t expires,
             vector<string>& sessions
-            );
+            ) {
+            return _logout(app, issuer, nameid, indexes, expires, sessions, 0);
+        }
         bool matches(
             const Application& app,
             const HTTPRequest& request,
@@ -203,7 +214,16 @@ namespace shibsp {
     private:
 #ifndef SHIBSP_LITE
         // maintain back-mappings of NameID/SessionIndex -> session key
-        void insert(const char* key, time_t expires, const char* name, const char* index);
+        void insert(const char* key, time_t expires, const char* name, const char* index, short attempts=0);
+        vector<string>::size_type _logout(
+            const Application& app,
+            const EntityDescriptor* issuer,
+            const saml2::NameID& nameid,
+            const set<string>* indexes,
+            time_t expires,
+            vector<string>& sessions,
+            short attempts
+            );
         bool stronglyMatches(const XMLCh* idp, const XMLCh* sp, const saml2::NameID& n1, const saml2::NameID& n2) const;
         LogoutEvent* newLogoutEvent(const Application& app) const;
 
@@ -231,6 +251,15 @@ namespace shibsp {
     {
     public:
         StoredSession(SSCache* cache, DDF& obj) : m_obj(obj), m_cache(cache), m_expires(0), m_lastAccess(time(nullptr)) {
+            // Check for old address format.
+            if (m_obj["client_addr"].isstring()) {
+                const char* saddr = m_obj["client_addr"].string();
+                DDF addrobj = m_obj["client_addr"].structure();
+                if (saddr && *saddr) {
+                    addrobj.addmember(getAddressFamily(saddr)).string(saddr);
+                }
+            }
+
             auto_ptr_XMLCh exp(m_obj["expires"].string());
             if (exp.get()) {
                 DateTime iso(exp.get());
@@ -278,8 +307,21 @@ namespace shibsp {
             return m_obj["application_id"].string();
         }
         const char* getClientAddress() const {
-            return m_obj["client_addr"].string();
+            return m_obj["client_addr"].first().string();
+        }
+
+        const char* getClientAddress(const char* family) const {
+            if (family)
+                return m_obj["client_addr"][family].string();
+            return nullptr;
         }
+        void setClientAddress(const char* client_addr) {
+            DDF obj = m_obj["client_addr"];
+            if (!obj.isstruct())
+                obj = m_obj.addmember("client_addr").structure();
+            obj.addmember(getAddressFamily(client_addr)).string(client_addr);
+        }
+
         const char* getEntityID() const {
             return m_obj["entity_id"].string();
         }
@@ -419,16 +461,23 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
 
     // Address check?
     if (client_addr) {
-        if (!XMLString::equals(getClientAddress(),client_addr)) {
-            m_cache->m_log.warn("client address mismatch, client (%s), session (%s)", client_addr, getClientAddress());
-            throw RetryableProfileException(
-                "Your IP address ($1) does not match the address recorded at the time the session was established.",
-                params(1,client_addr)
-                );
+        const char* saddr = getClientAddress(getAddressFamily(client_addr));
+        if (saddr && *saddr) {
+            if (!XMLString::equals(saddr, client_addr)) {
+                m_cache->m_log.warn("client address mismatch, client (%s), session (%s)", client_addr, saddr);
+                throw RetryableProfileException(
+                    "Your IP address ($1) does not match the address recorded at the time the session was established.",
+                    params(1, client_addr)
+                    );
+            }
+            client_addr = nullptr;  // clear out parameter as signal that session need not be updated below
+        }
+        else {
+            m_cache->m_log.info("session (%s) not yet bound to client address type, binding it to (%s)", getID(), client_addr);
         }
     }
 
-    if (!timeout)
+    if (!timeout && !client_addr)
         return;
 
     if (!SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
@@ -438,13 +487,15 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
         in.addmember("key").string(getID());
         in.addmember("version").integer(m_obj["version"].integer());
         in.addmember("application_id").string(app.getId());
-        if (*timeout) {
+        if (client_addr)    // signals we need to bind an additional address to the session
+            in.addmember("client_addr").string(client_addr);
+        if (timeout && *timeout) {
             // On 64-bit Windows, time_t doesn't fit in a long, so I'm using ISO timestamps.
 #ifndef HAVE_GMTIME_R
-            struct tm* ptime=gmtime(timeout);
+            struct tm* ptime = gmtime(timeout);
 #else
             struct tm res;
-            struct tm* ptime=gmtime_r(timeout,&res);
+            struct tm* ptime = gmtime_r(timeout,&res);
 #endif
             char timebuf[32];
             strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
@@ -461,6 +512,7 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
 
         if (out.isstruct()) {
             // We got an updated record back.
+            m_cache->m_log.debug("session updated, reconstituting it");
             m_ids.clear();
             for_each(m_attributes.begin(), m_attributes.end(), xmltooling::cleanup<Attribute>());
             m_attributes.clear();
@@ -474,7 +526,7 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
         if (!m_cache->m_storage)
             throw ConfigurationException("Session touch requires a StorageService.");
 
-        // Do a versioned read.
+        // Versioned read, since we already have the data in hand if it's current.
         string record;
         time_t lastAccess;
         int curver = m_obj["version"].integer();
@@ -484,20 +536,22 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
             throw RetryableProfileException("Your session has expired, and you must re-authenticate.");
         }
 
-        // Adjust for expiration to recover last access time and check timeout.
-        unsigned long cacheTimeout = m_cache->getCacheTimeout(app);
-        lastAccess -= cacheTimeout;
-        if (*timeout > 0 && now - lastAccess >= *timeout) {
-            m_cache->m_log.info("session timed out (ID: %s)", getID());
-            throw RetryableProfileException("Your session has expired, and you must re-authenticate.");
-        }
+        if (timeout) {
+            // Adjust for expiration to recover last access time and check timeout.
+            unsigned long cacheTimeout = m_cache->getCacheTimeout(app);
+            lastAccess -= cacheTimeout;
+            if (*timeout > 0 && now - lastAccess >= *timeout) {
+                m_cache->m_log.info("session timed out (ID: %s)", getID());
+                throw RetryableProfileException("Your session has expired, and you must re-authenticate.");
+            }
 
-        // Update storage expiration, if possible.
-        try {
-            m_cache->m_storage->updateContext(getID(), now + cacheTimeout);
-        }
-        catch (std::exception& ex) {
-            m_cache->m_log.error("failed to update session expiration: %s", ex.what());
+            // Update storage expiration, if possible.
+            try {
+                m_cache->m_storage->updateContext(getID(), now + cacheTimeout);
+            }
+            catch (std::exception& ex) {
+                m_cache->m_log.error("failed to update session expiration: %s", ex.what());
+            }
         }
 
         if (ver > curver) {
@@ -512,6 +566,82 @@ void StoredSession::validate(const Application& app, const char* client_addr, ti
             m_obj.destroy();
             m_obj = newobj;
         }
+
+        // We may need to write back a new address into the session using a versioned update loop.
+        if (client_addr) {
+            short attempts = 0;
+            do {
+                const char* saddr = getClientAddress(getAddressFamily(client_addr));
+                if (saddr) {
+                    // Something snuck in and bound the session to this address type, so it better match what we have.
+                    if (!XMLString::equals(saddr, client_addr)) {
+                        m_cache->m_log.warn("client address mismatch, client (%s), session (%s)", client_addr, saddr);
+                        throw RetryableProfileException(
+                            "Your IP address ($1) does not match the address recorded at the time the session was established.",
+                            params(1, client_addr)
+                            );
+                    }
+                    break;  // No need to update.
+                }
+                else {
+                    // Bind it into the session.
+                    setClientAddress(client_addr);
+                }
+
+                // Tentatively increment the version.
+                m_obj["version"].integer(m_obj["version"].integer() + 1);
+
+                ostringstream str;
+                str << m_obj;
+                record = str.str();
+
+                try {
+                    ver = m_cache->m_storage->updateText(getID(), "session", record.c_str(), 0, m_obj["version"].integer() - 1);
+                }
+                catch (std::exception&) {
+                    m_obj["version"].integer(m_obj["version"].integer() - 1);
+                    throw;
+                }
+
+                if (ver <= 0) {
+                    m_obj["version"].integer(m_obj["version"].integer() - 1);
+                }
+
+                if (!ver) {
+                    // Fatal problem with update.
+                    m_cache->m_log.error("updateText failed on StorageService for session (%s)", getID());
+                    throw IOException("Unable to update stored session.");
+                }
+                else if (ver < 0) {
+                    // Out of sync.
+                    if (++attempts > 10) {
+                        m_cache->m_log.error("failed to bind client address, update attempts exceeded limit");
+                        throw IOException("Unable to update stored session, exceeded retry limit.");
+                    }
+                    m_cache->m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
+                    ver = m_cache->m_storage->readText(getID(), "session", &record, nullptr);
+                    if (!ver) {
+                        m_cache->m_log.error("readText failed on StorageService for session (%s)", getID());
+                        throw IOException("Unable to read back stored session.");
+                    }
+
+                    // Reset object.
+                    DDF newobj;
+                    istringstream in(record);
+                    in >> newobj;
+
+                    m_ids.clear();
+                    for_each(m_attributes.begin(), m_attributes.end(), xmltooling::cleanup<Attribute>());
+                    m_attributes.clear();
+                    m_attributeIndex.clear();
+                    newobj["version"].integer(ver);
+                    m_obj.destroy();
+                    m_obj = newobj;
+
+                    ver = -1;
+                }
+            } while (ver < 0); // negative indicates a sync issue so we retry
+        }
 #else
         throw ConfigurationException("Session touch requires a StorageService.");
 #endif
@@ -534,6 +664,7 @@ void StoredSession::addAttributes(const vector<Attribute*>& attributes)
     m_cache->m_log.debug("adding attributes to session (%s)", getID());
 
     int ver;
+    short attempts = 0;
     do {
         DDF attr;
         DDF attrs = m_obj["attributes"];
@@ -576,6 +707,10 @@ void StoredSession::addAttributes(const vector<Attribute*>& attributes)
         }
         else if (ver < 0) {
             // Out of sync.
+            if (++attempts > 10) {
+                m_cache->m_log.error("failed to update stored session, update attempts exceeded limit");
+                throw IOException("Unable to update stored session, exceeded retry limit.");
+            }
             m_cache->m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
             ver = m_cache->m_storage->readText(getID(), "session", &record, nullptr);
             if (!ver) {
@@ -661,12 +796,13 @@ void StoredSession::addAssertion(Assertion* assertion)
         throw IOException("Attempted to insert duplicate assertion ID into session.");
 
     int ver;
+    short attempts = 0;
     do {
         DDF token = DDF(nullptr).string(id.get());
         m_obj["assertions"].add(token);
 
         // Tentatively increment the version.
-        m_obj["version"].integer(m_obj["version"].integer()+1);
+        m_obj["version"].integer(m_obj["version"].integer() + 1);
 
         ostringstream str;
         str << m_obj;
@@ -677,7 +813,7 @@ void StoredSession::addAssertion(Assertion* assertion)
         }
         catch (std::exception&) {
             token.destroy();
-            m_obj["version"].integer(m_obj["version"].integer()-1);
+            m_obj["version"].integer(m_obj["version"].integer() - 1);
             m_cache->m_storage->deleteText(getID(), id.get());
             throw;
         }
@@ -694,6 +830,10 @@ void StoredSession::addAssertion(Assertion* assertion)
         }
         else if (ver < 0) {
             // Out of sync.
+            if (++attempts > 10) {
+                m_cache->m_log.error("failed to update stored session, update attempts exceeded limit");
+                throw IOException("Unable to update stored session, exceeded retry limit.");
+            }
             m_cache->m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
             ver = m_cache->m_storage->readText(getID(), "session", &record, nullptr);
             if (!ver) {
@@ -786,7 +926,10 @@ SSCache::SSCache(const DOMElement* e)
     static const XMLCh _StorageService[] =      UNICODE_LITERAL_14(S,t,o,r,a,g,e,S,e,r,v,i,c,e);
     static const XMLCh _StorageServiceLite[] =  UNICODE_LITERAL_18(S,t,o,r,a,g,e,S,e,r,v,i,c,e,L,i,t,e);
 
-    m_cacheTimeout = XMLHelper::getAttrInt(e, 0, cacheTimeout);
+    if (e && e->hasAttributeNS(nullptr, cacheTimeout)) {
+        m_log.warn("cacheTimeout property is deprecated in favor of cacheAllowance (see documentation)");
+        m_cacheTimeout = XMLHelper::getAttrInt(e, 0, cacheTimeout);
+    }
     m_cacheAllowance = XMLHelper::getAttrInt(e, 0, cacheAllowance);
     if (inproc)
         m_inprocTimeout = XMLHelper::getAttrInt(e, 900, inprocTimeout);
@@ -803,7 +946,7 @@ SSCache::SSCache(const DOMElement* e)
             if (m_storage)
                 m_log.info("bound to StorageService (%s)", ssid.c_str());
             else
-                m_log.warn("specified StorageService (%s) not found", ssid.c_str());
+                throw ConfigurationException("SessionCache unable to locate StorageService ($1), check configuration.", params(1, ssid.c_str()));
         }
         if (!m_storage) {
             m_storage = conf.getServiceProvider()->getStorageService(nullptr);
@@ -819,7 +962,7 @@ SSCache::SSCache(const DOMElement* e)
             if (m_storage_lite)
                 m_log.info("bound to 'lite' StorageService (%s)", ssid.c_str());
             else
-                m_log.warn("specified 'lite' StorageService (%s) not found", ssid.c_str());
+                throw ConfigurationException("SessionCache unable to locate 'lite' StorageService ($1), check configuration.", params(1, ssid.c_str()));
         }
         if (!m_storage_lite) {
             m_log.info("StorageService for 'lite' use not set, using standard StorageService");
@@ -888,13 +1031,24 @@ SSCache::~SSCache()
 
 void SSCache::test()
 {
-    auto_ptr_char temp(SAMLConfig::getConfig().generateIdentifier());
+    XMLCh* wide = SAMLConfig::getConfig().generateIdentifier();
+    auto_ptr_char temp(wide);
+    XMLString::release(&wide);
     m_storage->createString("SessionCacheTest", temp.get(), "Test", time(nullptr) + 60);
     m_storage->deleteString("SessionCacheTest", temp.get());
 }
 
-void SSCache::insert(const char* key, time_t expires, const char* name, const char* index)
+void SSCache::insert(const char* key, time_t expires, const char* name, const char* index, short attempts)
 {
+    if (attempts > 10) {
+        throw IOException("Exceeded retry limit.");
+    }
+
+    if (!name || !*name) {
+        m_log.warn("NameID value was empty or null, ignoring request to store for logout");
+        return;
+    }
+
     string dup;
     unsigned int storageLimit = m_storage_lite->getCapabilities().getKeySize();
     if (strlen(name) > storageLimit) {
@@ -936,12 +1090,12 @@ void SSCache::insert(const char* key, time_t expires, const char* name, const ch
         ver = m_storage_lite->updateText("NameID", name, out.str().c_str(), max(expires, recordexp), ver);
         if (ver <= 0) {
             // Out of sync, or went missing, so retry.
-            return insert(key, expires, name, index);
+            return insert(key, expires, name, index, attempts + 1);
         }
     }
     else if (!m_storage_lite->createText("NameID", name, out.str().c_str(), expires)) {
         // Hit a dup, so just retry, hopefully hitting the other branch.
-        return insert(key, expires, name, index);
+        return insert(key, expires, name, index, attempts + 1);
     }
 }
 
@@ -1025,7 +1179,12 @@ void SSCache::insert(
     strftime(timebuf,32,"%Y-%m-%dT%H:%M:%SZ",ptime);
     obj.addmember("expires").string(timebuf);
 
-    obj.addmember("client_addr").string(httpRequest.getRemoteAddr().c_str());
+    string caddr(httpRequest.getRemoteAddr());
+    if (!caddr.empty()) {
+        DDF addrobj = obj.addmember("client_addr").structure();
+        addrobj.addmember(getAddressFamily(caddr.c_str())).string(caddr.c_str());
+    }
+
     if (issuer)
         obj.addmember("entity_id").string(entity_id.get());
     if (protocol) {
@@ -1080,7 +1239,8 @@ void SSCache::insert(
         throw FatalProfileException("Attempted to create a session with a duplicate key.");
 
     // Store the reverse mapping for logout.
-    if (nameid && m_reverseIndex && (m_excludedNames.size() == 0 || m_excludedNames.count(nameid->getName()) == 0)) {
+    if (name.get() && *name.get() && m_reverseIndex
+            && (m_excludedNames.size() == 0 || m_excludedNames.count(nameid->getName()) == 0)) {
         try {
             insert(key.get(), expires, name.get(), index.get());
         }
@@ -1160,13 +1320,14 @@ bool SSCache::matches(
     return false;
 }
 
-vector<string>::size_type SSCache::logout(
+vector<string>::size_type SSCache::_logout(
     const Application& app,
     const saml2md::EntityDescriptor* issuer,
     const saml2::NameID& nameid,
     const set<string>* indexes,
     time_t expires,
-    vector<string>& sessionsKilled
+    vector<string>& sessionsKilled,
+    short attempts
     )
 {
 #ifdef _DEBUG
@@ -1175,6 +1336,8 @@ vector<string>::size_type SSCache::logout(
 
     if (!m_storage)
         throw ConfigurationException("SessionCache logout requires a StorageService.");
+    else if (attempts > 10)
+        throw IOException("Exceeded retry limit.");
 
     auto_ptr_char entityID(issuer ? issuer->getEntityID() : nullptr);
     auto_ptr_char name(nameid.getName());
@@ -1230,12 +1393,12 @@ vector<string>::size_type SSCache::logout(
             ver = m_storage_lite->updateText("Logout", name.get(), lout.str().c_str(), max(expires, oldexp), ver);
             if (ver <= 0) {
                 // Out of sync, or went missing, so retry.
-                return logout(app, issuer, nameid, indexes, expires, sessionsKilled);
+                return _logout(app, issuer, nameid, indexes, expires, sessionsKilled, attempts + 1);
             }
         }
         else if (!m_storage_lite->createText("Logout", name.get(), lout.str().c_str(), expires)) {
             // Hit a dup, so just retry, hopefully hitting the other branch.
-            return logout(app, issuer, nameid, indexes, expires, sessionsKilled);
+            return _logout(app, issuer, nameid, indexes, expires, sessionsKilled, attempts + 1);
         }
 
         obj.destroy();
@@ -1774,6 +1937,7 @@ void SSCache::receive(DDF& in, ostream& out)
         string record;
         time_t lastAccess;
         if (!m_storage->readText(key, "session", &record, &lastAccess)) {
+            m_log.debug("session not found in cache (%s)", key);
             DDF ret(nullptr);
             DDFJanitor jan(ret);
             out << ret;
@@ -1821,14 +1985,15 @@ void SSCache::receive(DDF& in, ostream& out)
         const char* key=in["key"].string();
         if (!key)
             throw ListenerException("Required parameters missing for session check.");
+        const char* client_addr = in["client_addr"].string();
 
-        // Do a versioned read.
+        // Do a read. May be unversioned if we need to bind a new client address.
         string record;
         time_t lastAccess;
         int curver = in["version"].integer();
-        int ver = m_storage->readText(key, "session", &record, &lastAccess, curver);
+        int ver = m_storage->readText(key, "session", &record, &lastAccess, client_addr ? 0 : curver);
         if (ver == 0) {
-            m_log.warn("unsuccessful versioned read of session (ID: %s), caches out of sync?", key);
+            m_log.warn("unsuccessful read of session (ID: %s), caches out of sync?", key);
             throw RetryableProfileException("Your session has expired, and you must re-authenticate.");
         }
 
@@ -1859,6 +2024,65 @@ void SSCache::receive(DDF& in, ostream& out)
             m_log.error("failed to update session expiration: %s", ex.what());
         }
 
+        // We may need to write back a new address into the session using a versioned update loop.
+        if (client_addr) {
+            short attempts = 0;
+            m_log.info("binding session (%s) to new client address (%s)", key, client_addr);
+            do {
+                // We have to reconstitute the session object ourselves.
+                DDF sessionobj;
+                DDFJanitor sessionjan(sessionobj);
+                istringstream src(record);
+                src >> sessionobj;
+                ver = sessionobj["version"].integer();
+                const char* saddr = sessionobj["client_addr"][getAddressFamily(client_addr)].string();
+                if (saddr) {
+                    // Something snuck in and bound the session to this address type, so it better match what we have.
+                    if (!XMLString::equals(saddr, client_addr)) {
+                        m_log.warn("client address mismatch, client (%s), session (%s)", client_addr, saddr);
+                        throw RetryableProfileException(
+                            "Your IP address ($1) does not match the address recorded at the time the session was established.",
+                            params(1, client_addr)
+                            );
+                    }
+                    break;  // No need to update.
+                }
+                else {
+                    // Bind it into the session.
+                    sessionobj["client_addr"].addmember(getAddressFamily(client_addr)).string(client_addr);
+                }
+
+                // Tentatively increment the version.
+                sessionobj["version"].integer(sessionobj["version"].integer() + 1);
+
+                ostringstream str;
+                str << sessionobj;
+                record = str.str();
+
+                ver = m_storage->updateText(key, "session", record.c_str(), 0, ver);
+                if (!ver) {
+                    // Fatal problem with update.
+                    m_log.error("updateText failed on StorageService for session (%s)", key);
+                    throw IOException("Unable to update stored session.");
+                }
+                if (ver < 0) {
+                    // Out of sync.
+                    if (++attempts > 10) {
+                        m_log.error("failed to bind client address, update attempts exceeded limit");
+                        throw IOException("Unable to update stored session, exceeded retry limit.");
+                    }
+                    m_log.warn("storage service indicates the record is out of sync, updating with a fresh copy...");
+                    sessionobj["version"].integer(sessionobj["version"].integer() - 1);
+                    ver = m_storage->readText(key, "session", &record);
+                    if (!ver) {
+                        m_log.error("readText failed on StorageService for session (%s)", key);
+                        throw IOException("Unable to read back stored session.");
+                    }
+                    ver = -1;
+                }
+            } while (ver < 0); // negative indicates a sync issue so we retry
+        }
+
         if (ver > curver) {
             // Send the record back.
             out << record;