VS10 solution files, convert from NULL macro to nullptr.
[shibboleth/sp.git] / shibsp / attribute / resolver / impl / SimpleAggregationAttributeResolver.cpp
index 547a679..7f8cc26 100644 (file)
@@ -1,5 +1,5 @@
 /*
- *  Copyright 2009 Internet2
+ *  Copyright 2009-2010 Internet2
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
  */
 
 #include "internal.h"
+#include "exceptions.h"
 #include "Application.h"
 #include "ServiceProvider.h"
 #include "SessionCache.h"
 #include "attribute/resolver/ResolutionContext.h"
 #include "binding/SOAPClient.h"
 #include "metadata/MetadataProviderCriteria.h"
+#include "security/SecurityPolicy.h"
 #include "util/SPConstants.h"
 
 #include <saml/exceptions.h>
 #include <saml/SAMLConfig.h>
-#include <saml/binding/SecurityPolicy.h>
 #include <saml/saml2/binding/SAML2SOAPClient.h>
 #include <saml/saml2/core/Protocols.h>
+#include <saml/saml2/metadata/Metadata.h>
+#include <saml/saml2/metadata/MetadataCredentialCriteria.h>
 #include <saml/saml2/metadata/MetadataProvider.h>
+#include <xmltooling/XMLToolingConfig.h>
+#include <xmltooling/security/TrustEngine.h>
 #include <xmltooling/util/NDC.h>
 #include <xmltooling/util/XMLHelper.h>
 #include <xercesc/util/XMLUniDefs.hpp>
@@ -60,23 +65,26 @@ namespace shibsp {
         SimpleAggregationContext(const Application& application, const Session& session)
             : m_app(application),
               m_session(&session),
-              m_nameid(NULL),
+              m_nameid(nullptr),
+              m_entityid(nullptr),
               m_class(XMLString::transcode(session.getAuthnContextClassRef())),
               m_decl(XMLString::transcode(session.getAuthnContextDeclRef())),
-              m_inputTokens(NULL),
-              m_inputAttributes(NULL) {
+              m_inputTokens(nullptr),
+              m_inputAttributes(nullptr) {
         }
 
         SimpleAggregationContext(
             const Application& application,
-            const NameID* nameid=NULL,
-            const XMLCh* authncontext_class=NULL,
-            const XMLCh* authncontext_decl=NULL,
-            const vector<const opensaml::Assertion*>* tokens=NULL,
-            const vector<shibsp::Attribute*>* attributes=NULL
+            const NameID* nameid=nullptr,
+            const XMLCh* entityID=nullptr,
+            const XMLCh* authncontext_class=nullptr,
+            const XMLCh* authncontext_decl=nullptr,
+            const vector<const opensaml::Assertion*>* tokens=nullptr,
+            const vector<shibsp::Attribute*>* attributes=nullptr
             ) : m_app(application),
-                m_session(NULL),
+                m_session(nullptr),
                 m_nameid(nameid),
+                m_entityid(entityID ? XMLString::transcode(entityID) : nullptr),
                 m_class(const_cast<XMLCh*>(authncontext_class)),
                 m_decl(const_cast<XMLCh*>(authncontext_decl)),
                 m_inputTokens(tokens),
@@ -90,11 +98,15 @@ namespace shibsp {
                 XMLString::release(&m_class);
                 XMLString::release(&m_decl);
             }
+            XMLString::release(&m_entityid);
         }
 
         const Application& getApplication() const {
             return m_app;
         }
+        const char* getEntityID() const {
+            return m_session ? m_session->getEntityID() : m_entityid;
+        }
         const NameID* getNameID() const {
             return m_session ? m_session->getNameID() : m_nameid;
         }
@@ -124,6 +136,7 @@ namespace shibsp {
         const Application& m_app;
         const Session* m_session;
         const NameID* m_nameid;
+        char* m_entityid;
         XMLCh* m_class;
         XMLCh* m_decl;
         const vector<const opensaml::Assertion*>* m_inputTokens;
@@ -149,13 +162,15 @@ namespace shibsp {
             const Application& application,
             const EntityDescriptor* issuer,
             const XMLCh* protocol,
-            const NameID* nameid=NULL,
-            const XMLCh* authncontext_class=NULL,
-            const XMLCh* authncontext_decl=NULL,
-            const vector<const opensaml::Assertion*>* tokens=NULL,
-            const vector<shibsp::Attribute*>* attributes=NULL
+            const NameID* nameid=nullptr,
+            const XMLCh* authncontext_class=nullptr,
+            const XMLCh* authncontext_decl=nullptr,
+            const vector<const opensaml::Assertion*>* tokens=nullptr,
+            const vector<shibsp::Attribute*>* attributes=nullptr
             ) const {
-            return new SimpleAggregationContext(application,nameid,authncontext_class,authncontext_decl,tokens,attributes);
+            return new SimpleAggregationContext(
+                application, nameid, (issuer ? issuer->getEntityID() : nullptr), authncontext_class, authncontext_decl, tokens, attributes
+                );
         }
 
         ResolutionContext* createResolutionContext(const Application& application, const Session& session) const {
@@ -173,6 +188,7 @@ namespace shibsp {
 
         Category& m_log;
         string m_policyId;
+        bool m_subjectMatch;
         vector<string> m_attributeIds;
         xstring m_format;
         MetadataProvider* m_metadata;
@@ -192,24 +208,29 @@ namespace shibsp {
     static const XMLCh format[] =               UNICODE_LITERAL_6(f,o,r,m,a,t);
     static const XMLCh _MetadataProvider[] =    UNICODE_LITERAL_16(M,e,t,a,d,a,t,a,P,r,o,v,i,d,e,r);
     static const XMLCh policyId[] =             UNICODE_LITERAL_8(p,o,l,i,c,y,I,d);
+    static const XMLCh subjectMatch[] =         UNICODE_LITERAL_12(s,u,b,j,e,c,t,M,a,t,c,h);
     static const XMLCh _TrustEngine[] =         UNICODE_LITERAL_11(T,r,u,s,t,E,n,g,i,n,e);
     static const XMLCh _type[] =                UNICODE_LITERAL_4(t,y,p,e);
 };
 
 SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
-    : m_log(Category::getInstance(SHIBSP_LOGCAT".AttributeResolver.SimpleAggregation")), m_metadata(NULL), m_trust(NULL)
+    : m_log(Category::getInstance(SHIBSP_LOGCAT".AttributeResolver.SimpleAggregation")), m_subjectMatch(false), m_metadata(nullptr), m_trust(nullptr)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("SimpleAggregationResolver");
 #endif
 
-    const XMLCh* pid = e ? e->getAttributeNS(NULL, policyId) : NULL;
+    const XMLCh* pid = e ? e->getAttributeNS(nullptr, policyId) : nullptr;
     if (pid && *pid) {
         auto_ptr_char temp(pid);
         m_policyId = temp.get();
     }
 
-    pid = e ? e->getAttributeNS(NULL, attributeId) : NULL;
+    pid = e ? e->getAttributeNS(nullptr, subjectMatch) : nullptr;
+    if (pid && (*pid == chLatin_t || *pid == chDigit_1))
+        m_subjectMatch = true;
+
+    pid = e ? e->getAttributeNS(nullptr, attributeId) : nullptr;
     if (pid && *pid) {
         char* dup = XMLString::transcode(pid);
         char* pos;
@@ -223,19 +244,18 @@ SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
             if (pos)
                 *pos=0;
             m_attributeIds.push_back(start);
-            start = pos ? pos+1 : NULL;
+            start = pos ? pos+1 : nullptr;
         }
         XMLString::release(&dup);
 
-        pid = e->getAttributeNS(NULL, format);
+        pid = e->getAttributeNS(nullptr, format);
         if (pid && *pid)
             m_format = pid;
-
     }
 
     DOMElement* child = XMLHelper::getFirstChildElement(e, _MetadataProvider);
     if (child) {
-        auto_ptr_char type(child->getAttributeNS(NULL, _type));
+        auto_ptr_char type(child->getAttributeNS(nullptr, _type));
         if (!type.get() || !*type.get())
             throw ConfigurationException("MetadataProvider element missing type attribute.");
         m_log.info("building MetadataProvider of type %s...", type.get());
@@ -247,7 +267,7 @@ SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
     child = XMLHelper::getFirstChildElement(e,  _TrustEngine);
     if (child) {
         try {
-            auto_ptr_char type(child->getAttributeNS(NULL, _type));
+            auto_ptr_char type(child->getAttributeNS(nullptr, _type));
             if (!type.get() || !*type.get())
                 throw ConfigurationException("TrustEngine element missing type attribute.");
             m_log.info("building TrustEngine of type %s...", type.get());
@@ -290,7 +310,6 @@ SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
         }
         child = XMLHelper::getNextSiblingElement(child);
     }
-
 }
 
 bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const char* entityID, const NameID* name) const
@@ -301,11 +320,11 @@ bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
     const Application& application = ctx.getApplication();
     MetadataProviderCriteria mc(application, entityID, &AttributeAuthorityDescriptor::ELEMENT_QNAME, samlconstants::SAML20P_NS);
     Locker mlocker(m_metadata);
-    const AttributeAuthorityDescriptor* AA=NULL;
+    const AttributeAuthorityDescriptor* AA=nullptr;
     pair<const EntityDescriptor*,const RoleDescriptor*> mdresult =
         (m_metadata ? m_metadata : application.getMetadataProvider())->getEntityDescriptor(mc);
     if (!mdresult.first) {
-        m_log.warn("unable to locate metadata for provider (%s)", entityID);\r
+        m_log.warn("unable to locate metadata for provider (%s)", entityID);
         return false;
     }
     else if (!(AA=dynamic_cast<const AttributeAuthorityDescriptor*>(mdresult.second))) {
@@ -325,7 +344,7 @@ bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
     pair<bool,bool> signedAssertions = relyingParty->getBool("requireSignedAssertions");
     pair<bool,const char*> encryption = relyingParty->getString("encryption");
 
-    shibsp::SecurityPolicy policy(application, NULL, validate.first && validate.second, policyId);
+    shibsp::SecurityPolicy policy(application, nullptr, validate.first && validate.second, policyId);
     if (m_metadata)
         policy.setMetadataProvider(m_metadata);
     if (m_trust)
@@ -336,7 +355,7 @@ bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
     shibsp::SOAPClient soaper(policy);
 
     auto_ptr_XMLCh binding(samlconstants::SAML20_BINDING_SOAP);
-    saml2p::StatusResponseType* srt=NULL;
+    saml2p::StatusResponseType* srt=nullptr;
     const vector<AttributeService*>& endpoints=AA->getAttributeServices();
     for (vector<AttributeService*>::const_iterator ep=endpoints.begin(); !srt && ep!=endpoints.end(); ++ep) {
         if (!XMLString::equals((*ep)->getBinding(),binding.get())  || !(*ep)->getLocation())
@@ -348,7 +367,6 @@ bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
             // Encrypt the NameID?
             if (encryption.first && (!strcmp(encryption.second, "true") || !strcmp(encryption.second, "back"))) {
                 auto_ptr<EncryptedID> encrypted(EncryptedIDBuilder::buildEncryptedID());
-                MetadataCredentialCriteria mcc(*AA);
                 encrypted->encrypt(
                     *name,
                     *(policy.getMetadataProvider()),
@@ -384,33 +402,73 @@ bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
         m_log.error("unable to obtain a SAML response from attribute authority (%s)", entityID);
         return false;
     }
+
+    auto_ptr<saml2p::StatusResponseType> wrapper(srt);
+
     saml2p::Response* response = dynamic_cast<saml2p::Response*>(srt);
     if (!response) {
-        delete srt;
         m_log.error("message was not a samlp:Response");
         return true;
     }
     else if (!response->getStatus() || !response->getStatus()->getStatusCode() ||
             !XMLString::equals(response->getStatus()->getStatusCode()->getValue(), saml2p::StatusCode::SUCCESS)) {
-        delete srt;
         m_log.error("attribute authority (%s) returned a SAML error", entityID);
         return true;
     }
 
+    saml2::Assertion* newtoken = nullptr;
     const vector<saml2::Assertion*>& assertions = const_cast<const saml2p::Response*>(response)->getAssertions();
     if (assertions.empty()) {
-        delete srt;
-        m_log.warn("response from attribute authority (%s) was empty", entityID);
-        return true;
-    }
-    else if (assertions.size()>1)
-        m_log.warn("resolver only supports one assertion in the query response");
+        // Check for encryption.
+        const vector<saml2::EncryptedAssertion*>& encassertions =
+            const_cast<const saml2p::Response*>(response)->getEncryptedAssertions();
+        if (encassertions.empty()) {
+            m_log.warn("response from attribute authority was empty");
+            return true;
+        }
+        else if (encassertions.size() > 1) {
+            m_log.warn("simple resolver only supports one assertion in the query response");
+        }
 
-    auto_ptr<saml2p::StatusResponseType> wrapper(srt);
-    saml2::Assertion* newtoken = assertions.front();
+        CredentialResolver* cr=application.getCredentialResolver();
+        if (!cr) {
+            m_log.warn("found encrypted assertion, but no CredentialResolver was available");
+            return true;
+        }
+
+        // Attempt to decrypt it.
+        try {
+            Locker credlocker(cr);
+            auto_ptr<XMLObject> tokenwrapper(encassertions.front()->decrypt(*cr, relyingParty->getXMLString("entityID").second, &mcc));
+            newtoken = dynamic_cast<saml2::Assertion*>(tokenwrapper.get());
+            if (newtoken) {
+                tokenwrapper.release();
+                if (m_log.isDebugEnabled())
+                    m_log.debugStream() << "decrypted Assertion: " << *newtoken << logging::eol;
+            }
+        }
+        catch (exception& ex) {
+            m_log.error(ex.what());
+        }
+        if (newtoken) {
+            // Free the Response now, so we know this is a stand-alone token later.
+            delete wrapper.release();
+        }
+        else {
+            // Nothing decrypted, should already be logged.
+            return true;
+        }
+    }
+    else {
+        if (assertions.size() > 1)
+            m_log.warn("simple resolver only supports one assertion in the query response");
+        newtoken = assertions.front();
+    }
 
     if (!newtoken->getSignature() && signedAssertions.first && signedAssertions.second) {
         m_log.error("assertion unsigned, rejecting it based on signedAssertions policy");
+        if (!wrapper.get())
+            delete newtoken;
         return true;
     }
 
@@ -426,14 +484,60 @@ bool SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
         // Now we can check the security status of the policy.
         if (!policy.isAuthenticated())
             throw SecurityPolicyException("Security of SAML 2.0 query result not established.");
+
+        if (m_subjectMatch) {
+            // Check for subject match.
+            bool ownedName = false;
+            NameID* respName = newtoken->getSubject() ? newtoken->getSubject()->getNameID() : nullptr;
+            if (!respName) {
+                // Check for encryption.
+                EncryptedID* encname = newtoken->getSubject() ? newtoken->getSubject()->getEncryptedID() : nullptr;
+                if (encname) {
+                    CredentialResolver* cr=application.getCredentialResolver();
+                    if (!cr)
+                        m_log.warn("found EncryptedID, but no CredentialResolver was available");
+                    else {
+                        Locker credlocker(cr);
+                        auto_ptr<XMLObject> decryptedID(encname->decrypt(*cr, relyingParty->getXMLString("entityID").second, &mcc));
+                        respName = dynamic_cast<NameID*>(decryptedID.get());
+                        if (respName) {
+                            ownedName = true;
+                            decryptedID.release();
+                            if (m_log.isDebugEnabled())
+                                m_log.debugStream() << "decrypted NameID: " << *respName << logging::eol;
+                        }
+                    }
+                }
+            }
+
+            auto_ptr<NameID> nameIDwrapper(ownedName ? respName : nullptr);
+
+            if (!respName || !XMLString::equals(respName->getName(), name->getName()) ||
+                !XMLString::equals(respName->getFormat(), name->getFormat()) ||
+                !XMLString::equals(respName->getNameQualifier(), name->getNameQualifier()) ||
+                !XMLString::equals(respName->getSPNameQualifier(), name->getSPNameQualifier())) {
+                if (respName)
+                    m_log.warnStream() << "ignoring Assertion without strongly matching NameID in Subject: " <<
+                        *respName << logging::eol;
+                else
+                    m_log.warn("ignoring Assertion without NameID in Subject");
+                if (!wrapper.get())
+                    delete newtoken;
+                return true;
+            }
+        }
     }
     catch (exception& ex) {
         m_log.error("assertion failed policy validation: %s", ex.what());
+        if (!wrapper.get())
+            delete newtoken;
         return true;
     }
 
-    newtoken->detach();
-    wrapper.release();
+    if (wrapper.get()) {
+        newtoken->detach();
+        wrapper.release();  // detach blows away the Response
+    }
     ctx.getResolvedAssertions().push_back(newtoken);
 
     // Finally, extract and filter the result.
@@ -469,9 +573,9 @@ void SimpleAggregationResolver::resolveAttributes(ResolutionContext& ctx) const
     SimpleAggregationContext& qctx = dynamic_cast<SimpleAggregationContext&>(ctx);
 
     // First we manufacture the appropriate NameID to use.
-    NameID* n=NULL;
+    NameID* n=nullptr;
     for (vector<string>::const_iterator a = m_attributeIds.begin(); !n && a != m_attributeIds.end(); ++a) {
-        const Attribute* attr=NULL;
+        const Attribute* attr=nullptr;
         if (qctx.getSession()) {
             // Input attributes should be available via multimap.
             pair<multimap<string,const Attribute*>::const_iterator, multimap<string,const Attribute*>::const_iterator> range =
@@ -544,11 +648,24 @@ void SimpleAggregationResolver::resolveAttributes(ResolutionContext& ctx) const
 
     auto_ptr<NameID> wrapper(n);
 
+    set<string> history;
+
+    // Put initial IdP into history to prevent extra query.
+    if (qctx.getEntityID())
+        history.insert(qctx.getEntityID());
+
     // We have a master loop over all the possible sources of material.
     for (vector< pair<string,bool> >::const_iterator source = m_sources.begin(); source != m_sources.end(); ++source) {
         if (source->second) {
             // A literal entityID to query.
-            doQuery(qctx, source->first.c_str(), n ? n : qctx.getNameID());
+            if (history.count(source->first) == 0) {
+                m_log.debug("issuing SAML query to (%s)", source->first.c_str());
+                doQuery(qctx, source->first.c_str(), n ? n : qctx.getNameID());
+                history.insert(source->first);
+            }
+            else {
+                m_log.debug("skipping previously queried attribute source (%s)", source->first.c_str());
+            }
         }
         else {
             m_log.debug("using attribute sources referenced in attribute (%s)", source->first.c_str());
@@ -558,8 +675,16 @@ void SimpleAggregationResolver::resolveAttributes(ResolutionContext& ctx) const
                     qctx.getSession()->getIndexedAttributes().equal_range(source->first);
                 for (; range.first != range.second; ++range.first) {
                     const vector<string>& links = range.first->second->getSerializedValues();
-                    for (vector<string>::const_iterator link = links.begin(); link != links.end(); ++link)
-                        doQuery(qctx, link->c_str(), n ? n : qctx.getNameID());
+                    for (vector<string>::const_iterator link = links.begin(); link != links.end(); ++link) {
+                        if (history.count(*link) == 0) {
+                            m_log.debug("issuing SAML query to (%s)", link->c_str());
+                            doQuery(qctx, link->c_str(), n ? n : qctx.getNameID());
+                            history.insert(*link);
+                        }
+                        else {
+                            m_log.debug("skipping previously queried attribute source (%s)", link->c_str());
+                        }
+                    }
                 }
             }
             else if (qctx.getInputAttributes()) {
@@ -568,8 +693,16 @@ void SimpleAggregationResolver::resolveAttributes(ResolutionContext& ctx) const
                 for (vector<Attribute*>::const_iterator match = matches->begin(); match != matches->end(); ++match) {
                     if (source->first == (*match)->getId()) {
                         const vector<string>& links = (*match)->getSerializedValues();
-                        for (vector<string>::const_iterator link = links.begin(); link != links.end(); ++link)
-                            doQuery(qctx, link->c_str(), n ? n : qctx.getNameID());
+                        for (vector<string>::const_iterator link = links.begin(); link != links.end(); ++link) {
+                            if (history.count(*link) == 0) {
+                                m_log.debug("issuing SAML query to (%s)", link->c_str());
+                                doQuery(qctx, link->c_str(), n ? n : qctx.getNameID());
+                                history.insert(*link);
+                            }
+                            else {
+                                m_log.debug("skipping previously queried attribute source (%s)", link->c_str());
+                            }
+                        }
                     }
                 }
             }