Add protocol family property to protocol handlers, and fix up error handling to recog...
[shibboleth/cpp-sp.git] / adfs / adfs.cpp
index c957c19..333249f 100644 (file)
@@ -90,6 +90,10 @@ namespace {
         ADFSDecoder() : m_ns(WSTRUST_NS) {}
         virtual ~ADFSDecoder() {}
 
+        const XMLCh* getProtocolFamily() const {
+            return m_ns.get();
+        }
+
         XMLObject* decode(string& relayState, const GenericRequest& genericRequest, SecurityPolicy& policy) const;
 
     protected:
@@ -140,6 +144,10 @@ namespace {
         pair<bool,long> unwrap(SPRequest& request, DDF& out) const;
         pair<bool,long> run(SPRequest& request, string& entityID, bool isHandler=true) const;
 
+        const XMLCh* getProtocolFamily() const {
+            return m_binding.get();
+        }
+
     private:
         pair<bool,long> doRequest(
             const Application& application,
@@ -156,13 +164,10 @@ namespace {
 
     class SHIBSP_DLLLOCAL ADFSConsumer : public shibsp::AssertionConsumerService
     {
+        auto_ptr_XMLCh m_protocol;
     public:
         ADFSConsumer(const DOMElement* e, const char* appId)
-            : shibsp::AssertionConsumerService(e, appId, Category::getInstance(SHIBSP_LOGCAT".SSO.ADFS"))
-#ifndef SHIBSP_LITE
-                ,m_protocol(WSFED_NS)
-#endif
-            {}
+            : shibsp::AssertionConsumerService(e, appId, Category::getInstance(SHIBSP_LOGCAT".SSO.ADFS")), m_protocol(WSFED_NS) {}
         virtual ~ADFSConsumer() {}
 
 #ifndef SHIBSP_LITE
@@ -171,8 +176,6 @@ namespace {
             role.addSupport(m_protocol.get());
         }
 
-        auto_ptr_XMLCh m_protocol;
-
     private:
         void implementProtocol(
             const Application& application,
@@ -182,6 +185,10 @@ namespace {
             const PropertySet*,
             const XMLObject& xmlObject
             ) const;
+#else
+        const XMLCh* getProtocolFamily() const {
+            return m_protocol.get();
+        }
 #endif
     };
 
@@ -219,6 +226,9 @@ namespace {
             return "LogoutInitiator";
         }
 #endif
+        const XMLCh* getProtocolFamily() const {
+            return m_binding.get();
+        }
 
     private:
         pair<bool,long> doRequest(const Application& application, const HTTPRequest& httpRequest, HTTPResponse& httpResponse, Session* session) const;
@@ -254,7 +264,7 @@ namespace {
             auto_ptr_XMLCh widen(hurl.c_str());
             SingleLogoutService* ep = SingleLogoutServiceBuilder::buildSingleLogoutService();
             ep->setLocation(widen.get());
-            ep->setBinding(m_login.m_protocol.get());
+            ep->setBinding(m_login.getProtocolFamily());
             role.getSingleLogoutServices().push_back(ep);
         }
 
@@ -262,6 +272,9 @@ namespace {
             return m_login.getType();
         }
 #endif
+        const XMLCh* getProtocolFamily() const {
+            return m_login.getProtocolFamily();
+        }
 
     private:
         ADFSConsumer m_login;
@@ -378,12 +391,9 @@ pair<bool,long> ADFSSessionInitiator::run(SPRequest& request, string& entityID,
     }
 
     // Validate the ACS for use with this protocol.
-    pair<bool,const XMLCh*> ACSbinding = ACS->getXMLString("Binding");
-    if (ACSbinding.first) {
-        if (!XMLString::equals(ACSbinding.second, m_binding.get())) {
-            m_log.error("configured or requested ACS has non-ADFS binding");
-            throw ConfigurationException("Configured or requested ACS has non-ADFS binding ($1).", params(1, ACSbinding.second));
-        }
+    if (!XMLString::equals(getProtocolFamily(), ACS->getProtocolFamily())) {
+        m_log.error("configured or requested ACS has non-ADFS binding");
+        throw ConfigurationException("Configured or requested ACS has non-ADFS binding ($1).", params(1, ACS->getString("Binding").second));
     }
 
     // Since we're not passing by index, we need to fully compute the return URL.