Sync up with policy API changes.
[shibboleth/cpp-sp.git] / adfs / adfs.cpp
index 05b20b8..529e013 100644 (file)
 #include <shibsp/handler/AssertionConsumerService.h>
 #include <shibsp/handler/LogoutHandler.h>
 #include <shibsp/handler/SessionInitiator.h>
+#include <xmltooling/logging.h>
 #include <xmltooling/util/NDC.h>
 #include <xmltooling/util/URLEncoder.h>
 #include <xmltooling/util/XMLHelper.h>
-#include <log4cpp/Category.hh>
 #include <xercesc/util/XMLUniDefs.hpp>
 
 #ifndef SHIBSP_LITE
@@ -67,9 +67,9 @@ using namespace opensaml::saml2md;
 #endif
 using namespace shibsp;
 using namespace opensaml;
+using namespace xmltooling::logging;
 using namespace xmltooling;
 using namespace xercesc;
-using namespace log4cpp;
 using namespace std;
 
 #define WSFED_NS "http://schemas.xmlsoap.org/ws/2003/07/secext"
@@ -77,6 +77,42 @@ using namespace std;
 
 namespace {
 
+#ifndef SHIBSP_LITE
+    class SHIBSP_DLLLOCAL ADFSDecoder : public MessageDecoder
+    {
+        auto_ptr_XMLCh m_ns;
+    public:
+        ADFSDecoder() : m_ns(WSTRUST_NS) {}
+        virtual ~ADFSDecoder() {}
+        
+        XMLObject* decode(string& relayState, const GenericRequest& genericRequest, SecurityPolicy& policy) const;
+    };
+
+    MessageDecoder* ADFSDecoderFactory(const pair<const DOMElement*,const XMLCh*>& p)
+    {
+        return new ADFSDecoder();
+    }
+
+    class SHIBSP_DLLLOCAL ADFSMessageRule : public SecurityPolicyRule
+    {
+    public:
+        ADFSMessageRule(const DOMElement* e) : m_protocol(WSFED_NS) {}
+        virtual ~ADFSMessageRule() {}
+        
+        const char* getType() const {
+            return "ADFSMessage";
+        }
+        void evaluate(const XMLObject& message, const GenericRequest* request, const XMLCh* protocol, SecurityPolicy& policy) const;
+
+        auto_ptr_XMLCh m_protocol;
+    };
+
+    SecurityPolicyRule* ADFSMessageRuleFactory(const DOMElement* const & e)
+    {
+        return new ADFSMessageRule(e);
+    }
+#endif
+
 #if defined (_MSC_VER)
     #pragma warning( push )
     #pragma warning( disable : 4250 )
@@ -129,7 +165,7 @@ namespace {
         ADFSConsumer(const DOMElement* e, const char* appId)
             : shibsp::AssertionConsumerService(e, appId, Category::getInstance(SHIBSP_LOGCAT".SSO.ADFS"))
 #ifndef SHIBSP_LITE
-                ,m_binding(WSFED_NS)
+                ,m_messageRule(e)
 #endif
             {}
         virtual ~ADFSConsumer() {}
@@ -143,7 +179,7 @@ namespace {
             const PropertySet* settings,
             const XMLObject& xmlObject
             ) const;
-        auto_ptr_XMLCh m_binding;
+        ADFSMessageRule m_messageRule;
 #endif
     };
 
@@ -209,23 +245,6 @@ namespace {
     #pragma warning( pop )
 #endif
 
-#ifndef SHIBSP_LITE
-    class ADFSDecoder : public MessageDecoder
-    {
-        auto_ptr_XMLCh m_ns;
-    public:
-        ADFSDecoder() : m_ns(WSTRUST_NS) {}
-        virtual ~ADFSDecoder() {}
-        
-        XMLObject* decode(string& relayState, const GenericRequest& genericRequest, SecurityPolicy& policy) const;
-    };
-
-    MessageDecoder* ADFSDecoderFactory(const pair<const DOMElement*,const XMLCh*>& p)
-    {
-        return new ADFSDecoder();
-    }
-#endif
-
     SessionInitiator* ADFSSessionInitiatorFactory(const pair<const DOMElement*,const char*>& p)
     {
         return new ADFSSessionInitiator(p.first, p.second);
@@ -254,8 +273,9 @@ extern "C" int ADFS_EXPORTS xmltooling_extension_init(void*)
     conf.AssertionConsumerServiceManager.registerFactory(WSFED_NS, ADFSLogoutFactory);
 #ifndef SHIBSP_LITE
     SAMLConfig::getConfig().MessageDecoderManager.registerFactory(WSFED_NS, ADFSDecoderFactory);
+    SAMLConfig::getConfig().SecurityPolicyRuleManager.registerFactory("ADFSMessage", ADFSMessageRuleFactory);
     XMLObjectBuilder::registerBuilder(QName(WSTRUST_NS,"RequestedSecurityToken"), new AnyElementBuilder());
-    XMLObjectBuilder::registerBuilder(QName(WSTRUST_NS,"RequestedSecurityTokenResponse"), new AnyElementBuilder());
+    XMLObjectBuilder::registerBuilder(QName(WSTRUST_NS,"RequestSecurityTokenResponse"), new AnyElementBuilder());
 #endif
     return 0;
 }
@@ -270,6 +290,7 @@ extern "C" void ADFS_EXPORTS xmltooling_extension_term()
     conf.AssertionConsumerServiceManager.deregisterFactory(WSFED_NS);
 #ifndef SHIBSP_LITE
     SAMLConfig::getConfig().MessageDecoderManager.deregisterFactory(WSFED_NS);
+    SAMLConfig::getConfig().SecurityPolicyRuleManager.deregisterFactory("ADFSMessage");
 #endif
     */
 }
@@ -489,12 +510,67 @@ XMLObject* ADFSDecoder::decode(string& relayState, const GenericRequest& generic
     if (!policy.getValidating())
         SchemaValidators.validate(xmlObject.get());
 
-    // Run through the policy.
-    policy.evaluate(*xmlObject.get(), &genericRequest);
+    // Skip policy step here, there's no security in the wrapper.
+    // policy.evaluate(*xmlObject.get(), &genericRequest);
     
     return xmlObject.release();
 }
 
+void ADFSMessageRule::evaluate(const XMLObject& message, const GenericRequest* request, const XMLCh* protocol, SecurityPolicy& policy) const
+{
+    Category& log=Category::getInstance(SHIBSP_LOGCAT".SecurityPolicyRule.ADFSMessage");
+
+    if (!XMLString::equals(protocol, m_protocol.get()))
+        return;
+
+    const QName& q = message.getElementQName();
+    if (!XMLString::equals(q.getNamespaceURI(), samlconstants::SAML1_NS) ||
+        !XMLString::equals(q.getLocalPart(), saml1::Assertion::LOCAL_NAME))
+        return;
+
+    try {
+        const saml1::Assertion& token = dynamic_cast<const saml1::Assertion&>(message);
+        policy.setMessageID(token.getAssertionID());
+        policy.setIssueInstant(token.getIssueInstantEpoch());
+
+        log.debug("extracting issuer from message");
+
+        policy.setIssuer(token.getIssuer());
+
+        if (log.isDebugEnabled()) {
+            auto_ptr_char iname(token.getIssuer());
+            log.debug("message from (%s)", iname.get());
+        }
+        
+        if (policy.getIssuerMetadata()) {
+            log.debug("metadata for issuer already set, leaving in place");
+            return;
+        }
+        
+        if (policy.getMetadataProvider() && policy.getRole()) {
+            log.debug("searching metadata for message issuer...");
+            const EntityDescriptor* entity = policy.getMetadataProvider()->getEntityDescriptor(token.getIssuer());
+            if (!entity) {
+                auto_ptr_char temp(token.getIssuer());
+                log.warn("no metadata found, can't establish identity of issuer (%s)", temp.get());
+                return;
+            }
+    
+            log.debug("matched message issuer against metadata, searching for applicable role...");
+            const RoleDescriptor* roledesc=entity->getRoleDescriptor(*policy.getRole(), m_protocol.get());
+            if (!roledesc) {
+                log.warn("unable to find compatible role (%s) in metadata", policy.getRole()->toString().c_str());
+                return;
+            }
+            policy.setIssuerMetadata(roledesc);
+        }
+    }
+    catch (bad_cast&) {
+        // Just trap it.
+        log.warn("caught a bad_cast while examining message");
+    }
+}
+
 string ADFSConsumer::implementProtocol(
     const Application& application,
     const HTTPRequest& httpRequest,
@@ -520,7 +596,8 @@ string ADFSConsumer::implementProtocol(
 
     // Run the policy over the assertion. Handles issuer consistency, replay, freshness,
     // and signature verification, assuming the relevant rules are configured.
-    policy.evaluate(*token);
+    policy.getRules().insert(policy.getRules().begin(), &m_messageRule);
+    policy.evaluate(*token, NULL, m_messageRule.m_protocol.get());
     
     // If no security is in place now, we kick it.
     if (!policy.isSecure())
@@ -561,7 +638,7 @@ string ADFSConsumer::implementProtocol(
 
     // We've successfully "accepted" the SSO token.
     // To complete processing, we need to extract and resolve attributes and then create the session.
-    multimap<string,Attribute*> resolvedAttributes;
+    vector<Attribute*> resolvedAttributes;
     AttributeExtractor* extractor = application.getAttributeExtractor();
     if (extractor) {
         m_log.debug("extracting pushed attributes...");
@@ -591,7 +668,7 @@ string ADFSConsumer::implementProtocol(
             catch (exception& ex) {
                 m_log.error("caught exception filtering attributes: %s", ex.what());
                 m_log.error("dumping extracted attributes due to filtering exception");
-                for_each(resolvedAttributes.begin(), resolvedAttributes.end(), cleanup_pair<string,shibsp::Attribute>());
+                for_each(resolvedAttributes.begin(), resolvedAttributes.end(), xmltooling::cleanup<shibsp::Attribute>());
                 resolvedAttributes.clear();
             }
         }
@@ -611,7 +688,7 @@ string ADFSConsumer::implementProtocol(
         resolveAttributes(
             application,
             issuerMetadata,
-            m_binding.get(),
+            m_messageRule.m_protocol.get(),
             nameid.get(),
             ssoStatement->getAuthenticationMethod(),
             NULL,
@@ -625,7 +702,7 @@ string ADFSConsumer::implementProtocol(
         tokens.insert(tokens.end(), ctx->getResolvedAssertions().begin(), ctx->getResolvedAssertions().end());
 
         // Copy over new attributes, and transfer ownership.
-        resolvedAttributes.insert(ctx->getResolvedAttributes().begin(), ctx->getResolvedAttributes().end());
+        resolvedAttributes.insert(resolvedAttributes.end(), ctx->getResolvedAttributes().begin(), ctx->getResolvedAttributes().end());
         ctx->getResolvedAttributes().clear();
     }
 
@@ -635,7 +712,7 @@ string ADFSConsumer::implementProtocol(
             application,
             httpRequest.getRemoteAddr().c_str(),
             issuerMetadata,
-            m_binding.get(),
+            m_messageRule.m_protocol.get(),
             nameid.get(),
             ssoStatement->getAuthenticationInstant() ? ssoStatement->getAuthenticationInstant()->getRawData() : NULL,
             NULL,
@@ -644,11 +721,11 @@ string ADFSConsumer::implementProtocol(
             &tokens,
             &resolvedAttributes
             );
-        for_each(resolvedAttributes.begin(), resolvedAttributes.end(), cleanup_pair<string,Attribute>());
+        for_each(resolvedAttributes.begin(), resolvedAttributes.end(), xmltooling::cleanup<shibsp::Attribute>());
         return key;
     }
     catch (exception&) {
-        for_each(resolvedAttributes.begin(), resolvedAttributes.end(), cleanup_pair<string,Attribute>());
+        for_each(resolvedAttributes.begin(), resolvedAttributes.end(), xmltooling::cleanup<shibsp::Attribute>());
         throw;
     }
 }