SSPCPP-616 - clean up concatenated string literals
[shibboleth/cpp-sp.git] / adfs / adfs.cpp
index 7a6172f..7e6d20d 100644 (file)
@@ -49,7 +49,6 @@
 #include <shibsp/handler/AssertionConsumerService.h>
 #include <shibsp/handler/LogoutInitiator.h>
 #include <shibsp/handler/SessionInitiator.h>
-#include <boost/scoped_ptr.hpp>
 #include <xmltooling/logging.h>
 #include <xmltooling/util/DateTime.h>
 #include <xmltooling/util/NDC.h>
@@ -125,7 +124,7 @@ namespace {
     {
     public:
         ADFSSessionInitiator(const DOMElement* e, const char* appId)
-            : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT".SessionInitiator.ADFS"), nullptr, &m_remapper), m_appId(appId), m_binding(WSFED_NS) {
+            : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT ".SessionInitiator.ADFS"), nullptr, &m_remapper), m_appId(appId), m_binding(WSFED_NS) {
             // If Location isn't set, defer address registration until the setParent call.
             pair<bool,const char*> loc = getString("Location");
             if (loc.first) {
@@ -155,6 +154,12 @@ namespace {
             return m_binding.get();
         }
 
+#ifndef SHIBSP_LITE
+        void generateMetadata(saml2md::SPSSODescriptor& role, const char* handlerURL) const {
+            doGenerateMetadata(role, handlerURL);
+        }
+#endif
+
     private:
         pair<bool,long> doRequest(
             const Application& application,
@@ -174,7 +179,7 @@ namespace {
         auto_ptr_XMLCh m_protocol;
     public:
         ADFSConsumer(const DOMElement* e, const char* appId)
-            : shibsp::AssertionConsumerService(e, appId, Category::getInstance(SHIBSP_LOGCAT".SSO.ADFS")), m_protocol(WSFED_NS) {}
+            : shibsp::AssertionConsumerService(e, appId, Category::getInstance(SHIBSP_LOGCAT ".SSO.ADFS")), m_protocol(WSFED_NS) {}
         virtual ~ADFSConsumer() {}
 
 #ifndef SHIBSP_LITE
@@ -203,7 +208,7 @@ namespace {
     {
     public:
         ADFSLogoutInitiator(const DOMElement* e, const char* appId)
-                : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT".LogoutInitiator.ADFS")), m_appId(appId), m_binding(WSFED_NS) {
+                : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT ".LogoutInitiator.ADFS")), m_appId(appId), m_binding(WSFED_NS) {
             // If Location isn't set, defer address registration until the setParent call.
             pair<bool,const char*> loc = getString("Location");
             if (loc.first) {
@@ -243,7 +248,7 @@ namespace {
     {
     public:
         ADFSLogout(const DOMElement* e, const char* appId)
-                : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT".Logout.ADFS")), m_login(e, appId) {
+                : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT ".Logout.ADFS")), m_login(e, appId) {
             m_initiator = false;
 #ifndef SHIBSP_LITE
             m_preserve.push_back("wreply");
@@ -361,6 +366,7 @@ pair<bool,long> ADFSSessionInitiator::run(SPRequest& request, string& entityID,
         // Since we're passing the ACS by value, we need to compute the return URL,
         // so we'll need the target resource for real.
         recoverRelayState(app, request, request, target, false);
+        app.limitRedirect(request, target.c_str());
 
         acClass = getString("authnContextClassRef", request);
     }
@@ -560,7 +566,7 @@ XMLObject* ADFSDecoder::decode(string& relayState, const GenericRequest& generic
 #ifdef _DEBUG
     xmltooling::NDC ndc("decode");
 #endif
-    Category& log = Category::getInstance(SHIBSP_LOGCAT".MessageDecoder.ADFS");
+    Category& log = Category::getInstance(SHIBSP_LOGCAT ".MessageDecoder.ADFS");
 
     log.debug("validating input");
     const HTTPRequest* httpRequest=dynamic_cast<const HTTPRequest*>(&genericRequest);
@@ -766,6 +772,7 @@ void ADFSConsumer::implementProtocol(
             &httpRequest,
             policy.getIssuerMetadata(),
             m_protocol.get(),
+            nullptr,
             saml1name,
             saml1statement,
             (saml1name ? nameid.get() : saml2name),
@@ -961,6 +968,10 @@ pair<bool,long> ADFSLogoutInitiator::doRequest(
                 );
         }
 
+        const char* returnloc = httpRequest.getParameter("return");
+        if (returnloc)
+            application.limitRedirect(httpRequest, returnloc);
+
         // Log the request.
         scoped_ptr<LogoutEvent> logout_event(newLogoutEvent(application, &httpRequest, session));
         if (logout_event) {
@@ -968,12 +979,19 @@ pair<bool,long> ADFSLogoutInitiator::doRequest(
             application.getServiceProvider().getTransactionLog()->write(*logout_event);
         }
 
-        const URLEncoder* urlenc = XMLToolingConfig::getConfig().getURLEncoder();
-        const char* returnloc = httpRequest.getParameter("return");
         auto_ptr_char dest(ep->getLocation());
         string req=string(dest.get()) + (strchr(dest.get(),'?') ? '&' : '?') + "wa=wsignout1.0";
-        if (returnloc)
-            req += "&wreply=" + urlenc->encode(returnloc);
+        if (returnloc) {
+            req += "&wreply=";
+            if (*returnloc == '/') {
+                string s(returnloc);
+                httpRequest.absolutize(s);
+                req += XMLToolingConfig::getConfig().getURLEncoder()->encode(s.c_str());
+            }
+            else {
+                req += XMLToolingConfig::getConfig().getURLEncoder()->encode(returnloc);
+            }
+        }
         ret.second = httpResponse.sendRedirect(req.c_str());
         ret.first = true;
 
@@ -1045,7 +1063,16 @@ pair<bool,long> ADFSLogout::run(SPRequest& request, bool isHandler) const
         }
     }
 
-    if (param)
-        return make_pair(true, request.sendRedirect(param));
+    if (param) {
+        if (*param == '/') {
+            string p(param);
+            request.absolutize(p);
+            return make_pair(true, request.sendRedirect(p.c_str()));
+        }
+        else {
+            app.limitRedirect(request, param);
+            return make_pair(true, request.sendRedirect(param));
+        }
+    }
     return sendLogoutPage(app, request, request, "global");
 }