Change audience handling and validators to separate out entityID.
[shibboleth/sp.git] / shibsp / handler / impl / SAML2LogoutInitiator.cpp
index b68c44f..400989c 100644 (file)
@@ -240,7 +240,7 @@ void SAML2LogoutInitiator::receive(DDF& in, ostream& out)
     
     Session* session = NULL;
     try {
-         session = app->getServiceProvider().getSessionCache()->find(*req.get(), *app, NULL, NULL);
+         session = app->getServiceProvider().getSessionCache()->find(*app, *req.get(), NULL, NULL);
     }
     catch (exception& ex) {
         m_log.error("error accessing current session: %s", ex.what());
@@ -257,7 +257,7 @@ void SAML2LogoutInitiator::receive(DDF& in, ostream& out)
         else {
              m_log.error("no NameID or issuing entityID found in session");
              session->unlock();
-             app->getServiceProvider().getSessionCache()->remove(*req.get(), resp.get(), *app);
+             app->getServiceProvider().getSessionCache()->remove(*app, *req.get(), resp.get());
         }
     }
     out << ret;
@@ -274,8 +274,8 @@ pair<bool,long> SAML2LogoutInitiator::doRequest(
     vector<string> sessions(1, session->getID());
     if (!notifyBackChannel(application, httpRequest.getRequestURL(), sessions, false)) {
         session->unlock();
-        application.getServiceProvider().getSessionCache()->remove(httpRequest, &httpResponse, application);
-        return sendLogoutPage(application, httpResponse, true, "Partial logout failure.");
+        application.getServiceProvider().getSessionCache()->remove(application, httpRequest, &httpResponse);
+        return sendLogoutPage(application, httpRequest, httpResponse, true, "Partial logout failure.");
     }
 
 #ifndef SHIBSP_LITE
@@ -340,30 +340,43 @@ pair<bool,long> SAML2LogoutInitiator::doRequest(
             }
 
             if (!logoutResponse)
-                ret = sendLogoutPage(application, httpResponse, false, "Identity provider did not respond to logout request.");
+                ret = sendLogoutPage(application, httpRequest, httpResponse, false, "Identity provider did not respond to logout request.");
             else if (!logoutResponse->getStatus() || !logoutResponse->getStatus()->getStatusCode() ||
                    !XMLString::equals(logoutResponse->getStatus()->getStatusCode()->getValue(), saml2p::StatusCode::SUCCESS)) {
                 delete logoutResponse;
-                ret = sendLogoutPage(application, httpResponse, false, "Identity provider returned a SAML error in response to logout request.");
+                ret = sendLogoutPage(application, httpRequest, httpResponse, false, "Identity provider returned a SAML error in response to logout request.");
             }
             else {
                 delete logoutResponse;
-                ret = sendLogoutPage(application, httpResponse, false, "Logout completed successfully.");
+                const char* returnloc = httpRequest.getParameter("return");
+                if (returnloc) {
+                    ret.second = httpResponse.sendRedirect(returnloc);
+                    ret.first = true;
+                }
+                ret = sendLogoutPage(application, httpRequest, httpResponse, false, "Logout completed successfully.");
             }
 
             if (session) {
                 session->unlock();
                 session = NULL;
-                application.getServiceProvider().getSessionCache()->remove(httpRequest, &httpResponse, application);
+                application.getServiceProvider().getSessionCache()->remove(application, httpRequest, &httpResponse);
             }
             return ret;
         }
 
+        // Save off return location as RelayState.
+        string relayState;
+        const char* returnloc = httpRequest.getParameter("return");
+        if (returnloc) {
+            relayState = returnloc;
+            preserveRelayState(application, httpResponse, relayState);
+        }
+
         auto_ptr<LogoutRequest> msg(buildRequest(application, *session, *role, encoder));
 
         msg->setDestination(ep->getLocation());
         auto_ptr_char dest(ep->getLocation());
-        ret.second = sendMessage(*encoder, msg.get(), NULL, dest.get(), role, application, httpResponse);
+        ret.second = sendMessage(*encoder, msg.get(), relayState.c_str(), dest.get(), role, application, httpResponse);
         ret.first = true;
         msg.release();  // freed by encoder
     }
@@ -374,13 +387,13 @@ pair<bool,long> SAML2LogoutInitiator::doRequest(
     if (session) {
         session->unlock();
         session = NULL;
-        application.getServiceProvider().getSessionCache()->remove(httpRequest, &httpResponse, application);
+        application.getServiceProvider().getSessionCache()->remove(application, httpRequest, &httpResponse);
     }
 
     return ret;
 #else
     session->unlock();
-    application.getServiceProvider().getSessionCache()->remove(httpRequest, &httpResponse, application);
+    application.getServiceProvider().getSessionCache()->remove(application, httpRequest, &httpResponse);
     throw ConfigurationException("Cannot perform logout using lite version of shibsp library.");
 #endif
 }
@@ -391,10 +404,12 @@ LogoutRequest* SAML2LogoutInitiator::buildRequest(
     const Application& application, const Session& session, const RoleDescriptor& role, const MessageEncoder* encoder
     ) const
 {
+    const PropertySet* relyingParty = application.getRelyingParty(dynamic_cast<EntityDescriptor*>(role.getParent()));
+
     auto_ptr<LogoutRequest> msg(LogoutRequestBuilder::buildLogoutRequest());
     Issuer* issuer = IssuerBuilder::buildIssuer();
     msg->setIssuer(issuer);
-    issuer->setName(application.getXMLString("entityID").second);
+    issuer->setName(relyingParty->getXMLString("entityID").second);
     auto_ptr_XMLCh index(session.getSessionIndex());
     if (index.get() && *index.get()) {
         SessionIndex* si = SessionIndexBuilder::buildSessionIndex();
@@ -403,7 +418,6 @@ LogoutRequest* SAML2LogoutInitiator::buildRequest(
     }
 
     const NameID* nameid = session.getNameID();
-    const PropertySet* relyingParty = application.getRelyingParty(dynamic_cast<EntityDescriptor*>(role.getParent()));
     pair<bool,const char*> flag = relyingParty->getString("encryption");
     if (flag.first &&
         (!strcmp(flag.second, "true") || (encoder && !strcmp(flag.second, "front")) || (!encoder && !strcmp(flag.second, "back")))) {