https://issues.shibboleth.net/jira/browse/SSPCPP-254
[shibboleth/cpp-sp.git] / shibsp / handler / impl / AbstractHandler.cpp
index aa86b54..07319d3 100644 (file)
@@ -67,19 +67,35 @@ using namespace xercesc;
 using namespace std;
 
 namespace shibsp {
+
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML1ConsumerFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2ConsumerFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2ArtifactResolutionFactory;
-    SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory ChainingLogoutInitiatorFactory;
-    SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory LocalLogoutInitiatorFactory;
-    SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2LogoutInitiatorFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2LogoutFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SAML2NameIDMgmtFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory AssertionLookupFactory;
+    SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory DiscoveryFeedFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory MetadataGeneratorFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory StatusHandlerFactory;
     SHIBSP_DLLLOCAL PluginManager< Handler,string,pair<const DOMElement*,const char*> >::Factory SessionHandlerFactory;
 
+    void SHIBSP_DLLLOCAL absolutize(const HTTPRequest& request, string& url) {
+        if (url.empty())
+            url = '/';
+        if (url[0] == '/') {
+            // Compute a URL to the root of the site.
+            int port = request.getPort();
+            const char* scheme = request.getScheme();
+            string root = string(scheme) + "://" + request.getHostname();
+            if ((!strcmp(scheme,"http") && port!=80) || (!strcmp(scheme,"https") && port!=443)) {
+                ostringstream portstr;
+                portstr << port;
+                root += ":" + portstr.str();
+            }
+            url = root + url;
+        }
+    }
+
     void SHIBSP_DLLLOCAL generateRandomHex(std::string& buf, unsigned int len) {
         static char DIGITS[] = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};
         int r;
@@ -95,35 +111,40 @@ namespace shibsp {
             buf += (DIGITS[0x0F & b2]);
         }
     }
+
+
 };
 
 void SHIBSP_API shibsp::registerHandlers()
 {
     SPConfig& conf=SPConfig::getConfig();
 
+    conf.AssertionConsumerServiceManager.registerFactory(SAML1_ASSERTION_CONSUMER_SERVICE, SAML1ConsumerFactory);
     conf.AssertionConsumerServiceManager.registerFactory(SAML1_PROFILE_BROWSER_ARTIFACT, SAML1ConsumerFactory);
     conf.AssertionConsumerServiceManager.registerFactory(SAML1_PROFILE_BROWSER_POST, SAML1ConsumerFactory);
+    conf.AssertionConsumerServiceManager.registerFactory(SAML20_ASSERTION_CONSUMER_SERVICE, SAML2ConsumerFactory);
     conf.AssertionConsumerServiceManager.registerFactory(SAML20_BINDING_HTTP_POST, SAML2ConsumerFactory);
     conf.AssertionConsumerServiceManager.registerFactory(SAML20_BINDING_HTTP_POST_SIMPLESIGN, SAML2ConsumerFactory);
     conf.AssertionConsumerServiceManager.registerFactory(SAML20_BINDING_HTTP_ARTIFACT, SAML2ConsumerFactory);
     conf.AssertionConsumerServiceManager.registerFactory(SAML20_BINDING_PAOS, SAML2ConsumerFactory);
 
+    conf.ArtifactResolutionServiceManager.registerFactory(SAML20_ARTIFACT_RESOLUTION_SERVICE, SAML2ArtifactResolutionFactory);
     conf.ArtifactResolutionServiceManager.registerFactory(SAML20_BINDING_SOAP, SAML2ArtifactResolutionFactory);
 
     conf.HandlerManager.registerFactory(SAML20_BINDING_URI, AssertionLookupFactory);
+    conf.HandlerManager.registerFactory(DISCOVERY_FEED_HANDLER, DiscoveryFeedFactory);
     conf.HandlerManager.registerFactory(METADATA_GENERATOR_HANDLER, MetadataGeneratorFactory);
     conf.HandlerManager.registerFactory(STATUS_HANDLER, StatusHandlerFactory);
     conf.HandlerManager.registerFactory(SESSION_HANDLER, SessionHandlerFactory);
 
-    conf.LogoutInitiatorManager.registerFactory(CHAINING_LOGOUT_INITIATOR, ChainingLogoutInitiatorFactory);
-    conf.LogoutInitiatorManager.registerFactory(LOCAL_LOGOUT_INITIATOR, LocalLogoutInitiatorFactory);
-    conf.LogoutInitiatorManager.registerFactory(SAML2_LOGOUT_INITIATOR, SAML2LogoutInitiatorFactory);
+    conf.SingleLogoutServiceManager.registerFactory(SAML20_LOGOUT_HANDLER, SAML2LogoutFactory);
     conf.SingleLogoutServiceManager.registerFactory(SAML20_BINDING_SOAP, SAML2LogoutFactory);
     conf.SingleLogoutServiceManager.registerFactory(SAML20_BINDING_HTTP_REDIRECT, SAML2LogoutFactory);
     conf.SingleLogoutServiceManager.registerFactory(SAML20_BINDING_HTTP_POST, SAML2LogoutFactory);
     conf.SingleLogoutServiceManager.registerFactory(SAML20_BINDING_HTTP_POST_SIMPLESIGN, SAML2LogoutFactory);
     conf.SingleLogoutServiceManager.registerFactory(SAML20_BINDING_HTTP_ARTIFACT, SAML2LogoutFactory);
 
+    conf.ManageNameIDServiceManager.registerFactory(SAML20_NAMEID_MGMT_SERVICE, SAML2NameIDMgmtFactory);
     conf.ManageNameIDServiceManager.registerFactory(SAML20_BINDING_SOAP, SAML2NameIDMgmtFactory);
     conf.ManageNameIDServiceManager.registerFactory(SAML20_BINDING_HTTP_REDIRECT, SAML2NameIDMgmtFactory);
     conf.ManageNameIDServiceManager.registerFactory(SAML20_BINDING_HTTP_POST, SAML2NameIDMgmtFactory);
@@ -139,6 +160,11 @@ Handler::~Handler()
 {
 }
 
+const XMLCh* Handler::getProtocolFamily() const
+{
+    return nullptr;
+}
+
 void Handler::log(SPRequest::SPLogLevel level, const string& msg) const
 {
     Category::getInstance(SHIBSP_LOGCAT".Handler").log(
@@ -152,11 +178,18 @@ void Handler::log(SPRequest::SPLogLevel level, const string& msg) const
 
 void Handler::preserveRelayState(const Application& application, HTTPResponse& response, string& relayState) const
 {
+    // The empty string implies no state to deal with.
     if (relayState.empty())
         return;
 
-    // No setting means just pass it by value.
-    pair<bool,const char*> mech=getString("relayState");
+    // No setting means just pass state by value.
+    pair<bool,const char*> mech = getString("relayState");
+    if (!mech.first) {
+        // Check for setting on Sessions element.
+        const PropertySet* sessionprop = application.getPropertySet("Sessions");
+        if (sessionprop)
+            mech = sessionprop->getString("relayState");
+    }
     if (!mech.first || !mech.second || !*mech.second)
         return;
 
@@ -185,7 +218,7 @@ void Handler::preserveRelayState(const Application& application, HTTPResponse& r
                     if (storage) {
                         string rsKey;
                         generateRandomHex(rsKey,5);
-                        if (!storage->createString("RelayState", rsKey.c_str(), relayState.c_str(), time(NULL) + 600))
+                        if (!storage->createString("RelayState", rsKey.c_str(), relayState.c_str(), time(nullptr) + 600))
                             throw IOException("Attempted to insert duplicate storage key.");
                         relayState = string(mech.second-3) + ':' + rsKey;
                     }
@@ -234,9 +267,10 @@ void Handler::recoverRelayState(
                     StorageService* storage = conf.getServiceProvider()->getStorageService(ssid.c_str());
                     if (storage) {
                         ssid = key;
-                        if (storage->readString("RelayState",ssid.c_str(),&relayState)>0) {
+                        if (storage->readString("RelayState",ssid.c_str(),&relayState) > 0) {
                             if (clear)
                                 storage->deleteString("RelayState",ssid.c_str());
+                            absolutize(request, relayState);
                             return;
                         }
                         else
@@ -263,6 +297,7 @@ void Handler::recoverRelayState(
                     }
                     else {
                         relayState = out.string();
+                        absolutize(request, relayState);
                         return;
                     }
                 }
@@ -290,6 +325,7 @@ void Handler::recoverRelayState(
                     exp += "; expires=Mon, 01 Jan 2001 00:00:00 GMT";
                     response.setCookie(relay_cookie.first.c_str(), exp.c_str());
                 }
+                absolutize(request, relayState);
                 return;
             }
         }
@@ -301,26 +337,18 @@ void Handler::recoverRelayState(
     if (relayState.empty() || relayState == "default" || relayState == "cookie") {
         pair<bool,const char*> homeURL=application.getString("homeURL");
         if (homeURL.first)
-            relayState=homeURL.second;
-        else {
-            // Compute a URL to the root of the site.
-            int port = request.getPort();
-            const char* scheme = request.getScheme();
-            relayState = string(scheme) + "://" + request.getHostname();
-            if ((!strcmp(scheme,"http") && port!=80) || (!strcmp(scheme,"https") && port!=443)) {
-                ostringstream portstr;
-                portstr << port;
-                relayState += ":" + portstr.str();
-            }
-            relayState += '/';
-        }
+            relayState = homeURL.second;
+        else
+            relayState = '/';
     }
+
+    absolutize(request, relayState);
 }
 
 AbstractHandler::AbstractHandler(
     const DOMElement* e, Category& log, DOMNodeFilter* filter, const map<string,string>* remapper
     ) : m_log(log), m_configNS(shibspconstants::SHIB2SPCONFIG_NS) {
-    load(e,NULL,filter,remapper);
+    load(e,nullptr,filter,remapper);
 }
 
 AbstractHandler::~AbstractHandler()
@@ -352,7 +380,7 @@ void AbstractHandler::checkError(const XMLObject* response, const saml2md::RoleD
         const saml2p::Status* status = r2->getStatus();
         if (status) {
             const saml2p::StatusCode* sc = status->getStatusCode();
-            const XMLCh* code = sc ? sc->getValue() : NULL;
+            const XMLCh* code = sc ? sc->getValue() : nullptr;
             if (code && !XMLString::equals(code,saml2p::StatusCode::SUCCESS)) {
                 FatalProfileException ex("SAML response contained an error.");
                 annotateException(&ex, role, status);   // throws it
@@ -365,7 +393,7 @@ void AbstractHandler::checkError(const XMLObject* response, const saml2md::RoleD
         const saml1p::Status* status = r1->getStatus();
         if (status) {
             const saml1p::StatusCode* sc = status->getStatusCode();
-            const xmltooling::QName* code = sc ? sc->getValue() : NULL;
+            const xmltooling::QName* code = sc ? sc->getValue() : nullptr;
             if (code && *code != saml1p::StatusCode::SUCCESS) {
                 FatalProfileException ex("SAML response contained an error.");
                 ex.addProperty("statusCode", code->toString().c_str());
@@ -417,7 +445,7 @@ long AbstractHandler::sendMessage(
     bool signIfPossible
     ) const
 {
-    const EntityDescriptor* entity = role ? dynamic_cast<const EntityDescriptor*>(role->getParent()) : NULL;
+    const EntityDescriptor* entity = role ? dynamic_cast<const EntityDescriptor*>(role->getParent()) : nullptr;
     const PropertySet* relyingParty = application.getRelyingParty(entity);
     pair<bool,const char*> flag = signIfPossible ? make_pair(true,(const char*)"true") : relyingParty->getString("signing");
     if (role && flag.first &&
@@ -427,7 +455,7 @@ long AbstractHandler::sendMessage(
         CredentialResolver* credResolver=application.getCredentialResolver();
         if (credResolver) {
             Locker credLocker(credResolver);
-            const Credential* cred = NULL;
+            const Credential* cred = nullptr;
             pair<bool,const char*> keyName = relyingParty->getString("keyName");
             pair<bool,const XMLCh*> sigalg = relyingParty->getXMLString("signingAlg");
             if (role) {
@@ -435,9 +463,19 @@ long AbstractHandler::sendMessage(
                 mcc.setUsage(Credential::SIGNING_CREDENTIAL);
                 if (keyName.first)
                     mcc.getKeyNames().insert(keyName.second);
-                if (sigalg.first)
+                if (sigalg.first) {
+                    // Using an explicit algorithm, so resolve a credential directly.
                     mcc.setXMLAlgorithm(sigalg.second);
-                cred = credResolver->resolve(&mcc);
+                    cred = credResolver->resolve(&mcc);
+                }
+                else {
+                    // Prefer credential based on peer's requirements.
+                    pair<const SigningMethod*,const Credential*> p = role->getSigningMethod(*credResolver, mcc);
+                    if (p.first)
+                        sigalg = make_pair(true, p.first->getAlgorithm());
+                    if (p.second)
+                        cred = p.second;
+                }
             }
             else {
                 CredentialCriteria cc;
@@ -450,6 +488,12 @@ long AbstractHandler::sendMessage(
             }
             if (cred) {
                 // Signed request.
+                pair<bool,const XMLCh*> digalg = relyingParty->getXMLString("digestAlg");
+                if (!digalg.first && role) {
+                    const DigestMethod* dm = role->getDigestMethod();
+                    if (dm)
+                        digalg = make_pair(true, dm->getAlgorithm());
+                }
                 return encoder.encode(
                     httpResponse,
                     msg,
@@ -459,7 +503,7 @@ long AbstractHandler::sendMessage(
                     &application,
                     cred,
                     sigalg.second,
-                    relyingParty->getXMLString("digestAlg").second
+                    (digalg.first ? digalg.second : nullptr)
                     );
             }
             else {
@@ -489,7 +533,7 @@ void AbstractHandler::preservePostData(
 
     // No specs mean no save.
     const PropertySet* props=application.getPropertySet("Sessions");
-    pair<bool,const char*> mech = props->getString("postData");
+    pair<bool,const char*> mech = props ? props->getString("postData") : pair<bool,const char*>(false,nullptr);
     if (!mech.first) {
         m_log.info("postData property not supplied, form data will not be preserved across SSO");
         return;
@@ -518,7 +562,7 @@ void AbstractHandler::preservePostData(
                 rsKey = SAMLArtifact::toHex(rsKey);
                 ostringstream out;
                 out << postData;
-                if (!storage->createString("PostData", rsKey.c_str(), out.str().c_str(), time(NULL) + 600))
+                if (!storage->createString("PostData", rsKey.c_str(), out.str().c_str(), time(nullptr) + 600))
                     throw IOException("Attempted to insert duplicate storage key.");
                 postkey = string(mech.second-3) + ':' + rsKey;
             }
@@ -618,9 +662,9 @@ long AbstractHandler::sendPostResponse(
     HTTPResponse::sanitizeURL(url);
 
     const PropertySet* props=application.getPropertySet("Sessions");
-    pair<bool,const char*> postTemplate = props->getString("postTemplate");
+    pair<bool,const char*> postTemplate = props ? props->getString("postTemplate") : pair<bool,const char*>(true,nullptr);
     if (!postTemplate.first)
-        throw ConfigurationException("Missing postTemplate property, unable to recreate form post.");
+        postTemplate.second = "postTemplate.html";
 
     string fname(postTemplate.second);
     ifstream infile(XMLToolingConfig::getConfig().getPathResolver()->resolve(fname, PathResolver::XMLTOOLING_CFG_FILE).c_str());
@@ -640,7 +684,7 @@ long AbstractHandler::sendPostResponse(
     stringstream str;
     XMLToolingConfig::getConfig().getTemplateEngine()->run(infile, str, respParam);
 
-    pair<bool,bool> postExpire = props->getBool("postExpire");
+    pair<bool,bool> postExpire = props ? props->getBool("postExpire") : make_pair(false,false);
 
     httpResponse.setContentType("text/html");
     if (!postExpire.first || postExpire.second) {
@@ -673,12 +717,12 @@ DDF AbstractHandler::getPostData(const Application& application, const HTTPReque
     string contentType = request.getContentType();
     if (contentType.compare("application/x-www-form-urlencoded") == 0) {
         const PropertySet* props=application.getPropertySet("Sessions");
-        pair<bool,unsigned int> plimit = props->getUnsignedInt("postLimit");
+        pair<bool,unsigned int> plimit = props ? props->getUnsignedInt("postLimit") : pair<bool,unsigned int>(false,0);
         if (!plimit.first)
             plimit.second = 1024 * 1024;
         if (plimit.second == 0 || request.getContentLength() <= plimit.second) {
             CGIParser cgi(request);
-            pair<CGIParser::walker,CGIParser::walker> params = cgi.getParameters(NULL);
+            pair<CGIParser::walker,CGIParser::walker> params = cgi.getParameters(nullptr);
             if (params.first == params.second)
                 return DDF("parameters").list();
             DDF child;
@@ -740,7 +784,7 @@ pair<bool,const char*> AbstractHandler::getString(const char* name, const SPRequ
         return getString(name);
     }
 
-    return pair<bool,const char*>(false,NULL);
+    return pair<bool,const char*>(false,nullptr);
 }
 
 pair<bool,unsigned int> AbstractHandler::getUnsignedInt(const char* name, const SPRequest& request, unsigned int type) const
@@ -748,7 +792,7 @@ pair<bool,unsigned int> AbstractHandler::getUnsignedInt(const char* name, const
     if (type & HANDLER_PROPERTY_REQUEST) {
         const char* param = request.getParameter(name);
         if (param && *param)
-            return pair<bool,unsigned int>(true, strtol(param,NULL,10));
+            return pair<bool,unsigned int>(true, strtol(param,nullptr,10));
     }
     
     if (type & HANDLER_PROPERTY_MAP) {