Add cache method to find but not remove sessions by name.
[shibboleth/sp.git] / shibsp / handler / impl / SAML2SessionInitiator.cpp
index 600f45e..8b297ca 100644 (file)
 #include "handler/SessionInitiator.h"
 #include "util/SPConstants.h"
 
-#include <saml/SAMLConfig.h>
-#include <saml/binding/MessageEncoder.h>
-#include <saml/saml2/core/Protocols.h>
-#include <saml/saml2/metadata/EndpointManager.h>
-#include <saml/saml2/metadata/Metadata.h>
-#include <saml/saml2/metadata/MetadataCredentialCriteria.h>
-
-using namespace shibsp;
+#ifndef SHIBSP_LITE
+# include <saml/SAMLConfig.h>
+# include <saml/saml2/core/Protocols.h>
+# include <saml/saml2/metadata/EndpointManager.h>
+# include <saml/saml2/metadata/Metadata.h>
+# include <saml/saml2/metadata/MetadataCredentialCriteria.h>
 using namespace opensaml::saml2;
 using namespace opensaml::saml2p;
 using namespace opensaml::saml2md;
+#endif
+
+using namespace shibsp;
 using namespace opensaml;
 using namespace xmltooling;
 using namespace log4cpp;
@@ -58,10 +59,13 @@ namespace shibsp {
     public:
         SAML2SessionInitiator(const DOMElement* e, const char* appId);
         virtual ~SAML2SessionInitiator() {
+#ifndef SHIBSP_LITE
             if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
                 XMLString::release(&m_outgoing);
                 for_each(m_encoders.begin(), m_encoders.end(), cleanup_pair<const XMLCh*,MessageEncoder>());
+                delete m_requestTemplate;
             }
+#endif
         }
         
         void setParent(const PropertySet* parent);
@@ -74,15 +78,22 @@ namespace shibsp {
             HTTPResponse& httpResponse,
             const char* entityID,
             const XMLCh* acsIndex,
-            const XMLCh* acsLocation,
+            const char* acsLocation,
             const XMLCh* acsBinding,
+            bool isPassive,
+            bool forceAuthn,
+            const char* authnContextClassRef,
+            const char* authnContextComparison,
             string& relayState
             ) const;
 
         string m_appId;
+#ifndef SHIBSP_LITE
         XMLCh* m_outgoing;
         vector<const XMLCh*> m_bindings;
         map<const XMLCh*,MessageEncoder*> m_encoders;
+        AuthnRequest* m_requestTemplate;
+#endif
     };
 
 #if defined (_MSC_VER)
@@ -97,26 +108,28 @@ namespace shibsp {
 };
 
 SAML2SessionInitiator::SAML2SessionInitiator(const DOMElement* e, const char* appId)
-    : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT".SessionInitiator")), m_appId(appId), m_outgoing(NULL)
+    : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT".SessionInitiator")), m_appId(appId)
 {
-    // If Location isn't set, defer address registration until the setParent call.
-    pair<bool,const char*> loc = getString("Location");
-    if (loc.first) {
-        string address = m_appId + loc.second + "::run::SAML2SI";
-        setAddress(address.c_str());
-    }
-
+#ifndef SHIBSP_LITE
+    m_outgoing=NULL;
+    m_requestTemplate=NULL;
     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
+        // Check for a template AuthnRequest to build from.
+        DOMElement* child = XMLHelper::getFirstChildElement(e, samlconstants::SAML20P_NS, AuthnRequest::LOCAL_NAME);
+        if (child)
+            m_requestTemplate = dynamic_cast<AuthnRequest*>(AuthnRequestBuilder::buildOneFromElement(child));
+
+        // Handle outgoing binding setup.
         pair<bool,const XMLCh*> outgoing = getXMLString("outgoingBindings");
         if (outgoing.first) {
             m_outgoing = XMLString::replicate(outgoing.second);
+            XMLString::trim(m_outgoing);
         }
         else {
             // No override, so we'll install a default binding precedence.
             string prec = string(samlconstants::SAML20_BINDING_HTTP_REDIRECT) + ' ' + samlconstants::SAML20_BINDING_HTTP_POST + ' ' +
                 samlconstants::SAML20_BINDING_HTTP_POST_SIMPLESIGN + ' ' + samlconstants::SAML20_BINDING_HTTP_ARTIFACT;
             m_outgoing = XMLString::transcode(prec.c_str());
-            XMLString::trim(m_outgoing);
         }
 
         int pos;
@@ -130,7 +143,7 @@ SAML2SessionInitiator::SAML2SessionInitiator(const DOMElement* e, const char* ap
                 auto_ptr_char b(start);
                 MessageEncoder * encoder = SAMLConfig::getConfig().MessageEncoderManager.newPlugin(b.get(),e);
                 m_encoders[start] = encoder;
-                m_log.info("supporting outgoing binding (%s)", b.get());
+                m_log.debug("supporting outgoing binding (%s)", b.get());
             }
             catch (exception& ex) {
                 m_log.error("error building MessageEncoder: %s", ex.what());
@@ -141,6 +154,14 @@ SAML2SessionInitiator::SAML2SessionInitiator(const DOMElement* e, const char* ap
                 break;
         }
     }
+#endif
+
+    // If Location isn't set, defer address registration until the setParent call.
+    pair<bool,const char*> loc = getString("Location");
+    if (loc.first) {
+        string address = m_appId + loc.second + "::run::SAML2SI";
+        setAddress(address.c_str());
+    }
 }
 
 void SAML2SessionInitiator::setParent(const PropertySet* parent)
@@ -165,38 +186,78 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, const char* entit
     string target;
     const Handler* ACS=NULL;
     const char* option;
+    pair<bool,const char*> acClass;
+    pair<bool,const char*> acComp;
+    bool isPassive=false,forceAuthn=false;
     const Application& app=request.getApplication();
     pair<bool,bool> acsByIndex = getBool("acsByIndex");
 
     if (isHandler) {
         option=request.getParameter("acsIndex");
-        if (option)
+        if (option) {
             ACS = app.getAssertionConsumerServiceByIndex(atoi(option));
+            if (!ACS)
+                request.log(SPRequest::SPWarn, "invalid acsIndex specified in request, using default ACS location");
+        }
 
         option = request.getParameter("target");
         if (option)
             target = option;
-        if (!acsByIndex.first || !acsByIndex.second) {
+        if (acsByIndex.first && !acsByIndex.second) {
             // Since we're passing the ACS by value, we need to compute the return URL,
             // so we'll need the target resource for real.
             recoverRelayState(request.getApplication(), request, target, false);
         }
+
+        option = request.getParameter("isPassive");
+        isPassive = (option && (*option=='1' || *option=='t'));
+        if (!isPassive) {
+            option = request.getParameter("forceAuthn");
+            forceAuthn = (option && (*option=='1' || *option=='t'));
+        }
+
+        acClass.second = request.getParameter("authnContextClassRef");
+        acClass.first = (acClass.second!=NULL);
+        acComp.second = request.getParameter("authnContextComparison");
+        acComp.first = (acComp.second!=NULL);
     }
     else {
         // We're running as a "virtual handler" from within the filter.
         // The target resource is the current one and everything else is defaulted.
         target=request.getRequestURL();
+        const PropertySet* settings = request.getRequestSettings().first;
+
+        pair<bool,bool> flag = settings->getBool("isPassive");
+        isPassive = flag.first && flag.second;
+        if (!isPassive) {
+            flag = settings->getBool("forceAuthn");
+            forceAuthn = flag.first && flag.second;
+        }
+
+        acClass = settings->getString("authnContextClassRef");
+        acComp = settings->getString("authnContextComparison");
     }
 
     m_log.debug("attempting to initiate session using SAML 2.0 with provider (%s)", entityID);
 
-    // To invoke the request builder, the key requirement is to figure out how and whether
+    if (!ACS) {
+        pair<bool,unsigned int> index = getUnsignedInt("defaultACSIndex");
+        if (index.first) {
+            ACS = app.getAssertionConsumerServiceByIndex(index.second);
+            if (!ACS)
+                request.log(SPRequest::SPWarn, "invalid defaultACSIndex, using default ACS location");
+        }
+        if (!ACS)
+            ACS = app.getDefaultAssertionConsumerService();
+    }
+
+    // To invoke the request builder, the key requirement is to figure out how
     // to express the ACS, by index or value, and if by value, where.
 
     SPConfig& conf = SPConfig::getConfig();
     if (conf.isEnabled(SPConfig::OutOfProcess)) {
-        if (acsByIndex.first && acsByIndex.second) {
-            // Pass by Index. This also allows for defaulting it entirely and sending nothing.
+        if (!acsByIndex.first || acsByIndex.second) {
+            // Pass by Index.
             if (isHandler) {
                 // We may already have RelayState set if we looped back here,
                 // but just in case target is a resource, we reset it back.
@@ -205,13 +266,17 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, const char* entit
                 if (option)
                     target = option;
             }
-            return doRequest(app, request, entityID, ACS ? ACS->getXMLString("index").second : NULL, NULL, NULL, target);
+            return doRequest(
+                app, request, entityID,
+                ACS ? ACS->getXMLString("index").second : NULL, NULL, NULL,
+                isPassive, forceAuthn,
+                acClass.first ? acClass.second : NULL,
+                acComp.first ? acComp.second : NULL,
+                target
+                );
         }
 
         // Since we're not passing by index, we need to fully compute the return URL and binding.
-        if (!ACS)
-            ACS = app.getDefaultAssertionConsumerService();
-
         // Compute the ACS URL. We add the ACS location to the base handlerURL.
         string ACSloc=request.getHandlerURL(target.c_str());
         pair<bool,const char*> loc=ACS ? ACS->getString("Location") : pair<bool,const char*>(false,NULL);
@@ -226,8 +291,14 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, const char* entit
                 target = option;
         }
 
-        auto_ptr_XMLCh wideloc(ACSloc.c_str());
-        return doRequest(app, request, entityID, NULL, wideloc.get(), ACS ? ACS->getXMLString("Binding").second : NULL, target);
+        return doRequest(
+            app, request, entityID,
+            NULL, ACSloc.c_str(), ACS ? ACS->getXMLString("Binding").second : NULL,
+            isPassive, forceAuthn,
+            acClass.first ? acClass.second : NULL,
+            acComp.first ? acComp.second : NULL,
+            target
+            );
     }
 
     // Remote the call.
@@ -235,15 +306,20 @@ pair<bool,long> SAML2SessionInitiator::run(SPRequest& request, const char* entit
     DDFJanitor jin(in), jout(out);
     in.addmember("application_id").string(app.getId());
     in.addmember("entity_id").string(entityID);
-    if (acsByIndex.first && acsByIndex.second) {
+    if (isPassive)
+        in.addmember("isPassive").integer(1);
+    else if (forceAuthn)
+        in.addmember("forceAuthn").integer(1);
+    if (acClass.first)
+        in.addmember("authnContextClassRef").string(acClass.second);
+    if (acComp.first)
+        in.addmember("authnContextComparison").string(acComp.second);
+    if (!acsByIndex.first || acsByIndex.second) {
         if (ACS)
             in.addmember("acsIndex").string(ACS->getString("index").second);
     }
     else {
         // Since we're not passing by index, we need to fully compute the return URL and binding.
-        if (!ACS)
-            ACS = app.getDefaultAssertionConsumerService();
-
         // Compute the ACS URL. We add the ACS location to the base handlerURL.
         string ACSloc=request.getHandlerURL(target.c_str());
         pair<bool,const char*> loc=ACS ? ACS->getString("Location") : pair<bool,const char*>(false,NULL);
@@ -291,7 +367,6 @@ void SAML2SessionInitiator::receive(DDF& in, ostream& out)
     auto_ptr<HTTPResponse> http(getResponse(ret));
 
     auto_ptr_XMLCh index(in["acsIndex"].string());
-    auto_ptr_XMLCh loc(in["acsLocation"].string());
     auto_ptr_XMLCh bind(in["acsBinding"].string());
 
     string relayState(in["RelayState"].string() ? in["RelayState"].string() : "");
@@ -299,7 +374,13 @@ void SAML2SessionInitiator::receive(DDF& in, ostream& out)
     // Since we're remoted, the result should either be a throw, which we pass on,
     // a false/0 return, which we just return as an empty structure, or a response/redirect,
     // which we capture in the facade and send back.
-    doRequest(*app, *http.get(), entityID, index.get(), loc.get(), bind.get(), relayState);
+    doRequest(
+        *app, *http.get(), entityID,
+        index.get(), in["acsLocation"].string(), bind.get(),
+        in["isPassive"].integer()==1, in["forceAuthn"].integer()==1,
+        in["authnContextClassRef"].string(), in["authnContextComparison"].string(),
+        relayState
+        );
     out << ret;
 }
 
@@ -308,18 +389,24 @@ pair<bool,long> SAML2SessionInitiator::doRequest(
     HTTPResponse& httpResponse,
     const char* entityID,
     const XMLCh* acsIndex,
-    const XMLCh* acsLocation,
+    const char* acsLocation,
     const XMLCh* acsBinding,
+    bool isPassive,
+    bool forceAuthn,
+    const char* authnContextClassRef,
+    const char* authnContextComparison,
     string& relayState
     ) const
 {
+#ifndef SHIBSP_LITE
     // Use metadata to locate the IdP's SSO service.
     MetadataProvider* m=app.getMetadataProvider();
     Locker locker(m);
     const EntityDescriptor* entity=m->getEntityDescriptor(entityID);
     if (!entity) {
         m_log.error("unable to locate metadata for provider (%s)", entityID);
-        return make_pair(false,0);
+        throw MetadataException("Unable to locate metadata for identity provider ($entityID)",
+            namedparams(1, "entityID", entityID));
     }
     const IDPSSODescriptor* role=entity->getIDPSSODescriptor(samlconstants::SAML20P_NS);
     if (!role) {
@@ -346,31 +433,71 @@ pair<bool,long> SAML2SessionInitiator::doRequest(
 
     preserveRelayState(app, httpResponse, relayState);
 
-    // For now just build a dummy AuthnRequest.
-    auto_ptr<AuthnRequest> req(AuthnRequestBuilder::buildAuthnRequest());
+    auto_ptr<AuthnRequest> req(m_requestTemplate ? m_requestTemplate->cloneAuthnRequest() : AuthnRequestBuilder::buildAuthnRequest());
+    if (m_requestTemplate) {
+        // Freshen TS and ID.
+        req->setID(NULL);
+        req->setIssueInstant(time(NULL));
+    }
+
     req->setDestination(ep->getLocation());
-    if (acsIndex)
+    if (acsIndex && *acsIndex)
         req->setAssertionConsumerServiceIndex(acsIndex);
-    if (acsLocation)
-        req->setAssertionConsumerServiceURL(acsLocation);
-    if (acsBinding)
+    if (acsLocation) {
+        auto_ptr_XMLCh wideloc(acsLocation);
+        req->setAssertionConsumerServiceURL(wideloc.get());
+    }
+    if (acsBinding && *acsBinding)
         req->setProtocolBinding(acsBinding);
-    Issuer* issuer = IssuerBuilder::buildIssuer();
-    req->setIssuer(issuer);
-    issuer->setName(app.getXMLString("providerId").second);
+    if (isPassive)
+        req->IsPassive(isPassive);
+    else if (forceAuthn)
+        req->ForceAuthn(forceAuthn);
+    if (!req->getIssuer()) {
+        Issuer* issuer = IssuerBuilder::buildIssuer();
+        req->setIssuer(issuer);
+        issuer->setName(app.getXMLString("entityID").second);
+    }
+    if (!req->getNameIDPolicy()) {
+        NameIDPolicy* namepol = NameIDPolicyBuilder::buildNameIDPolicy();
+        req->setNameIDPolicy(namepol);
+        namepol->AllowCreate(true);
+    }
+    if (authnContextClassRef || authnContextComparison) {
+        RequestedAuthnContext* reqContext = req->getRequestedAuthnContext();
+        if (!reqContext) {
+            reqContext = RequestedAuthnContextBuilder::buildRequestedAuthnContext();
+            req->setRequestedAuthnContext(reqContext);
+        }
+        if (authnContextClassRef) {
+            reqContext->getAuthnContextDeclRefs().clear();
+            auto_ptr_XMLCh wideclass(authnContextClassRef);
+            AuthnContextClassRef* cref = AuthnContextClassRefBuilder::buildAuthnContextClassRef();
+            cref->setReference(wideclass.get());
+            reqContext->getAuthnContextClassRefs().push_back(cref);
+        }
+        if (authnContextComparison &&
+                (!reqContext->getAuthnContextClassRefs().empty() || !reqContext->getAuthnContextDeclRefs().empty())) {
+            auto_ptr_XMLCh widecomp(authnContextComparison);
+            reqContext->setComparison(widecomp.get());
+        }
+    }
 
     auto_ptr_char dest(ep->getLocation());
 
     // Check for signing.
     const PropertySet* relyingParty = app.getRelyingParty(entity);
-    pair<bool,bool> flag = relyingParty->getBool("signRequests");
-    if ((flag.first && flag.second) || role->WantAuthnRequestsSigned()) {
+    pair<bool,const char*> flag = relyingParty->getString("signRequests");
+    if (role->WantAuthnRequestsSigned() || (flag.first && (!strcmp(flag.second, "true") || !strcmp(flag.second, "front")))) {
         CredentialResolver* credResolver=app.getCredentialResolver();
         if (credResolver) {
             Locker credLocker(credResolver);
             // Fill in criteria to use.
             MetadataCredentialCriteria mcc(*role);
             mcc.setUsage(CredentialCriteria::SIGNING_CREDENTIAL);
+            pair<bool,const char*> keyName = relyingParty->getString("keyName");
+            if (keyName.first)
+                mcc.getKeyNames().insert(keyName.second);
             pair<bool,const XMLCh*> sigalg = relyingParty->getXMLString("signatureAlg");
             if (sigalg.first)
                 mcc.setXMLAlgorithm(sigalg.second);
@@ -381,8 +508,9 @@ pair<bool,long> SAML2SessionInitiator::doRequest(
                     httpResponse,
                     req.get(),
                     dest.get(),
-                    entityID,
+                    entity,
                     relayState.c_str(),
+                    &app,
                     cred,
                     sigalg.second,
                     relyingParty->getXMLString("digestAlg").second
@@ -397,7 +525,10 @@ pair<bool,long> SAML2SessionInitiator::doRequest(
     }
 
     // Unsigned request.
-    long ret = encoder->encode(httpResponse, req.get(), dest.get(), entityID, relayState.c_str());
+    long ret = encoder->encode(httpResponse, req.get(), dest.get(), entity, relayState.c_str(), &app);
     req.release();  // freed by encoder
     return make_pair(true,ret);
+#else
+    return make_pair(false,0);
+#endif
 }