Initial draft of protocol bootstrapper, reworked ACS lookup (again).
[shibboleth/sp.git] / shibsp / handler / impl / SAML2SessionInitiator.cpp
index 9707700..3d903f3 100644 (file)
@@ -310,34 +310,24 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, string& entityID,
 
     if (!ACS) {
         if (ECP) {
-            ACS = app.getAssertionConsumerServiceByBinding(samlconstants::SAML20_BINDING_PAOS);
+            ACS = app.getAssertionConsumerServiceByProtocol(getProtocolFamily(), samlconstants::SAML20_BINDING_PAOS);
             if (!ACS)
                 throw ConfigurationException("Unable to locate PAOS response endpoint.");
         }
         else {
-            // Try fixed index property, or incoming binding set, or default, in order.
+            // Try fixed index property.
             pair<bool,unsigned int> index = getUnsignedInt("acsIndex", request, HANDLER_PROPERTY_MAP|HANDLER_PROPERTY_FIXED);
-            if (index.first) {
+            if (index.first)
                 ACS = app.getAssertionConsumerServiceByIndex(index.second);
-                if (!ACS)
-                    request.log(SPRequest::SPWarn, "invalid acsIndex property, using default ACS location");
-            }
-            /*
-            for (vector<string>::const_iterator b = m_incomingBindings.begin(); !ACS && b != m_incomingBindings.end(); ++b) {
-                ACS = app.getAssertionConsumerServiceByBinding(b->c_str());
-                if (ACS && !XMLString::equals(getProtocolFamily(), ACS->getProtocolFamily()))
-                    ACS = nullptr;
-            }
-            */
-            if (!ACS)
-                ACS = app.getDefaultAssertionConsumerService();
         }
     }
 
-    // Validate the ACS for use with this protocol.
-    if (!ECP && ACS && !XMLString::equals(getProtocolFamily(), ACS->getProtocolFamily())) {
-        m_log.error("configured or requested ACS has non-SAML 2.0 binding");
-        throw ConfigurationException("Configured or requested ACS has non-SAML 2.0 binding ($1).", params(1, ACS->getString("Binding").second));
+    // If we picked by index, validate the ACS for use with this protocol.
+    if (!ECP && (!ACS || !XMLString::equals(getProtocolFamily(), ACS->getProtocolFamily()))) {
+        request.log(SPRequest::SPWarn, "invalid acsIndex property, or non-SAML 2.0 ACS, using default SAML 2.0 ACS");
+        ACS = app.getAssertionConsumerServiceByProtocol(getProtocolFamily());
+        if (!ACS)
+            throw ConfigurationException("Unable to locate a SAML 2.0 ACS endpoint to use for response.");
     }
 
     // To invoke the request builder, the key requirement is to figure out how
@@ -361,21 +351,19 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, string& entityID,
 
             // Determine index to use.
             pair<bool,const XMLCh*> ix = pair<bool,const XMLCh*>(false,nullptr);
-            if (ACS) {
-               if (!strncmp(ACSloc.c_str(), "https", 5)) {
-                       ix = ACS->getXMLString("sslIndex", shibspconstants::ASCII_SHIB2SPCONFIG_NS);
-                       if (!ix.first)
-                               ix = ACS->getXMLString("index");
-               }
-               else {
+            if (!strncmp(ACSloc.c_str(), "https", 5)) {
+               ix = ACS->getXMLString("sslIndex", shibspconstants::ASCII_SHIB2SPCONFIG_NS);
+               if (!ix.first)
                        ix = ACS->getXMLString("index");
-               }
+            }
+            else {
+               ix = ACS->getXMLString("index");
             }
 
             return doRequest(
                 app, &request, request, entityID.c_str(),
                 ix.second,
-                ACS ? XMLString::equals(ACS->getString("Binding").second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT) : false,
+                XMLString::equals(ACS->getString("Binding").second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT),
                 nullptr, nullptr,
                 isPassive, forceAuthn,
                 acClass.first ? acClass.second : nullptr,
@@ -388,7 +376,7 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, string& entityID,
 
         // Since we're not passing by index, we need to fully compute the return URL and binding.
         // Compute the ACS URL. We add the ACS location to the base handlerURL.
-        prop = ACS ? ACS->getString("Location") : pair<bool,const char*>(false,nullptr);
+        prop = ACS->getString("Location");
         if (prop.first)
             ACSloc += prop.second;
 
@@ -404,8 +392,8 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, string& entityID,
         return doRequest(
             app, &request, request, entityID.c_str(),
             nullptr,
-            ACS ? XMLString::equals(ACS->getString("Binding").second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT) : false,
-            ACSloc.c_str(), ACS ? ACS->getXMLString("Binding").second : nullptr,
+            XMLString::equals(ACS->getString("Binding").second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT),
+            ACSloc.c_str(), ACS->getXMLString("Binding").second,
             isPassive, forceAuthn,
             acClass.first ? acClass.second : nullptr,
             acComp.first ? acComp.second : nullptr,
@@ -434,35 +422,31 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, string& entityID,
     if (spQual.first)
         in.addmember("SPNameQualifier").string(spQual.second);
     if (acsByIndex.first && acsByIndex.second) {
-        if (ACS) {
-            // Determine index to use.
-            pair<bool,const char*> ix = pair<bool,const char*>(false,nullptr);
-               if (!strncmp(ACSloc.c_str(), "https", 5)) {
-                       ix = ACS->getString("sslIndex", shibspconstants::ASCII_SHIB2SPCONFIG_NS);
-                       if (!ix.first)
-                               ix = ACS->getString("index");
-               }
-               else {
+        // Determine index to use.
+        pair<bool,const char*> ix = pair<bool,const char*>(false,nullptr);
+        if (!strncmp(ACSloc.c_str(), "https", 5)) {
+               ix = ACS->getString("sslIndex", shibspconstants::ASCII_SHIB2SPCONFIG_NS);
+               if (!ix.first)
                        ix = ACS->getString("index");
-               }
-            in.addmember("acsIndex").string(ix.second);
-            if (XMLString::equals(ACS->getString("Binding").second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT))
-                in.addmember("artifact").integer(1);
         }
+        else {
+               ix = ACS->getString("index");
+        }
+        in.addmember("acsIndex").string(ix.second);
+        if (XMLString::equals(ACS->getString("Binding").second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT))
+            in.addmember("artifact").integer(1);
     }
     else {
         // Since we're not passing by index, we need to fully compute the return URL and binding.
         // Compute the ACS URL. We add the ACS location to the base handlerURL.
-        prop = ACS ? ACS->getString("Location") : pair<bool,const char*>(false,nullptr);
+        prop = ACS->getString("Location");
         if (prop.first)
             ACSloc += prop.second;
         in.addmember("acsLocation").string(ACSloc.c_str());
-        if (ACS) {
-            prop = ACS->getString("Binding");
-            in.addmember("acsBinding").string(prop.second);
-            if (XMLString::equals(prop.second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT))
-                in.addmember("artifact").integer(1);
-        }
+        prop = ACS->getString("Binding");
+        in.addmember("acsBinding").string(prop.second);
+        if (XMLString::equals(prop.second, samlconstants::SAML20_BINDING_HTTP_ARTIFACT))
+            in.addmember("artifact").integer(1);
     }
 
     if (isHandler) {