Tagging 2.4.2 release.
[shibboleth/sp.git] / shibsp / handler / impl / AbstractHandler.cpp
index da9fc01..4944e0c 100644 (file)
@@ -1,5 +1,5 @@
 /*
- *  Copyright 2001-2010 Internet2
+ *  Copyright 2001-2011 Internet2
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -74,6 +74,7 @@ namespace shibsp {
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2LogoutFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2NameIDMgmtFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory AssertionLookupFactory;
+    SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory DiscoveryFeedFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory MetadataGeneratorFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory StatusHandlerFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SessionHandlerFactory;
@@ -111,7 +112,69 @@ namespace shibsp {
         }
     }
 
+    void SHIBSP_DLLLOCAL limitRelayState(
+        Category& log, const Application& application, const HTTPRequest& httpRequest, const char* relayState
+        ) {
+        const PropertySet* sessionProps = application.getPropertySet("Sessions");
+        if (sessionProps) {
+            pair<bool,const char*> relayStateLimit = sessionProps->getString("relayStateLimit");
+            if (relayStateLimit.first && strcmp(relayStateLimit.second, "none")) {
+                vector<string> whitelist;
+                if (!strcmp(relayStateLimit.second, "exact")) {
+                    // Scheme and hostname have to match.
+                    if (!strcmp(httpRequest.getScheme(), "https") && httpRequest.getPort() == 443) {
+                        whitelist.push_back(string("https://") + httpRequest.getHostname() + '/');
+                    }
+                    else if (!strcmp(httpRequest.getScheme(), "http") && httpRequest.getPort() == 80) {
+                        whitelist.push_back(string("http://") + httpRequest.getHostname() + '/');
+                    }
+                    ostringstream portstr;
+                    portstr << httpRequest.getPort();
+                    whitelist.push_back(string(httpRequest.getScheme()) + "://" + httpRequest.getHostname() + ':' + portstr.str() + '/');
+                }
+                else if (!strcmp(relayStateLimit.second, "host")) {
+                    // Allow any scheme or port.
+                    whitelist.push_back(string("https://") + httpRequest.getHostname() + '/');
+                    whitelist.push_back(string("http://") + httpRequest.getHostname() + '/');
+                    whitelist.push_back(string("https://") + httpRequest.getHostname() + ':');
+                    whitelist.push_back(string("http://") + httpRequest.getHostname() + ':');
+                }
+                else if (!strcmp(relayStateLimit.second, "whitelist")) {
+                    // Literal set of comparisons to use.
+                    pair<bool,const char*> whitelistval = sessionProps->getString("relayStateWhitelist");
+                    if (whitelistval.first) {
+#ifdef HAVE_STRTOK_R
+                        char* pos=nullptr;
+                        const char* token = strtok_r(const_cast<char*>(whitelistval.second), " ", &pos);
+#else
+                        const char* token = strtok(const_cast<char*>(whitelistval.second), " ");
+#endif
+                        while (token) {
+                            whitelist.push_back(token);
+#ifdef HAVE_STRTOK_R
+                            token = strtok_r(nullptr, " ", &pos);
+#else
+                            token = strtok(nullptr, " ");
+#endif
+                        }
+                    }
+                }
+                else {
+                    log.warn("unrecognized relayStateLimit policy (%s), blocked redirect to (%s)", relayStateLimit.second, relayState);
+                    throw opensaml::SecurityPolicyException("Unrecognized relayStateLimit setting.");
+                }
 
+                for (vector<string>::const_iterator w = whitelist.begin(); w != whitelist.end(); ++w) {
+                    if (XMLString::startsWithI(relayState, w->c_str())) {
+                        return;
+                    }
+                }
+
+                log.warn("relayStateLimit policy (%s), blocked redirect to (%s)", relayStateLimit.second, relayState);
+                throw opensaml::SecurityPolicyException("Blocked unacceptable redirect location.");
+            }
+        }
+    }
 };
 
 void SHIBSP_API shibsp::registerHandlers()
@@ -131,6 +194,7 @@ void SHIBSP_API shibsp::registerHandlers()
     conf.ArtifactResolutionServiceManager.registerFactory(SAML20_BINDING_SOAP, SAML2ArtifactResolutionFactory);
 
     conf.HandlerManager.registerFactory(SAML20_BINDING_URI, AssertionLookupFactory);
+    conf.HandlerManager.registerFactory(DISCOVERY_FEED_HANDLER, DiscoveryFeedFactory);
     conf.HandlerManager.registerFactory(METADATA_GENERATOR_HANDLER, MetadataGeneratorFactory);
     conf.HandlerManager.registerFactory(STATUS_HANDLER, StatusHandlerFactory);
     conf.HandlerManager.registerFactory(SESSION_HANDLER, SessionHandlerFactory);
@@ -158,6 +222,14 @@ Handler::~Handler()
 {
 }
 
+#ifndef SHIBSP_LITE
+
+void Handler::generateMetadata(SPSSODescriptor& role, const char* handlerURL) const
+{
+}
+
+#endif
+
 const XMLCh* Handler::getProtocolFamily() const
 {
     return nullptr;
@@ -215,7 +287,7 @@ void Handler::preserveRelayState(const Application& application, HTTPResponse& r
                     StorageService* storage = application.getServiceProvider().getStorageService(mech.second);
                     if (storage) {
                         string rsKey;
-                        generateRandomHex(rsKey,5);
+                        generateRandomHex(rsKey,32);
                         if (!storage->createString("RelayState", rsKey.c_str(), relayState.c_str(), time(nullptr) + 600))
                             throw IOException("Attempted to insert duplicate storage key.");
                         relayState = string(mech.second-3) + ':' + rsKey;
@@ -556,7 +628,7 @@ void AbstractHandler::preservePostData(
             if (storage) {
                 // Use a random key
                 string rsKey;
-                SAMLConfig::getConfig().generateRandomBytes(rsKey,20);
+                SAMLConfig::getConfig().generateRandomBytes(rsKey,32);
                 rsKey = SAMLArtifact::toHex(rsKey);
                 ostringstream out;
                 out << postData;