https://issues.shibboleth.net/jira/browse/SSPCPP-339
[shibboleth/cpp-sp.git] / shibsp / handler / impl / AssertionConsumerService.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * AssertionConsumerService.cpp
23  *
24  * Base class for handlers that create sessions by consuming SSO protocol responses.
25  */
26
27 #include "internal.h"
28 #include "exceptions.h"
29 #include "Application.h"
30 #include "ServiceProvider.h"
31 #include "SPRequest.h"
32 #include "handler/AssertionConsumerService.h"
33 #include "util/SPConstants.h"
34
35 # include <ctime>
36 #ifndef SHIBSP_LITE
37 # include "attribute/Attribute.h"
38 # include "attribute/filtering/AttributeFilter.h"
39 # include "attribute/filtering/BasicFilteringContext.h"
40 # include "attribute/resolver/AttributeExtractor.h"
41 # include "attribute/resolver/AttributeResolver.h"
42 # include "attribute/resolver/ResolutionContext.h"
43 # include "metadata/MetadataProviderCriteria.h"
44 # include "security/SecurityPolicy.h"
45 # include "security/SecurityPolicyProvider.h"
46 # include <boost/iterator/indirect_iterator.hpp>
47 # include <saml/exceptions.h>
48 # include <saml/SAMLConfig.h>
49 # include <saml/saml1/core/Assertions.h>
50 # include <saml/saml1/core/Protocols.h>
51 # include <saml/saml2/core/Protocols.h>
52 # include <saml/saml2/metadata/Metadata.h>
53 # include <saml/util/CommonDomainCookie.h>
54 using namespace samlconstants;
55 using opensaml::saml2md::MetadataProvider;
56 using opensaml::saml2md::RoleDescriptor;
57 using opensaml::saml2md::EntityDescriptor;
58 using opensaml::saml2md::IDPSSODescriptor;
59 using opensaml::saml2md::SPSSODescriptor;
60 #else
61 # include "lite/CommonDomainCookie.h"
62 #endif
63
64 using namespace shibspconstants;
65 using namespace shibsp;
66 using namespace opensaml;
67 using namespace xmltooling;
68 using namespace boost;
69 using namespace std;
70
71 AssertionConsumerService::AssertionConsumerService(
72     const DOMElement* e, const char* appId, Category& log, DOMNodeFilter* filter, const map<string,string>* remapper
73     ) : AbstractHandler(e, log, filter, remapper)
74 {
75     if (!e)
76         return;
77     string address(appId);
78     address += getString("Location").second;
79     setAddress(address.c_str());
80 #ifndef SHIBSP_LITE
81     if (SPConfig::getConfig().isEnabled(SPConfig::OutOfProcess)) {
82         m_decoder.reset(
83             SAMLConfig::getConfig().MessageDecoderManager.newPlugin(
84                 getString("Binding").second, pair<const DOMElement*,const XMLCh*>(e,shibspconstants::SHIB2SPCONFIG_NS)
85                 )
86             );
87         m_decoder->setArtifactResolver(SPConfig::getConfig().getArtifactResolver());
88     }
89 #endif
90 }
91
92 AssertionConsumerService::~AssertionConsumerService()
93 {
94 }
95
96 pair<bool,long> AssertionConsumerService::run(SPRequest& request, bool isHandler) const
97 {
98     string relayState;
99     SPConfig& conf = SPConfig::getConfig();
100
101     if (conf.isEnabled(SPConfig::OutOfProcess)) {
102         // When out of process, we run natively and directly process the message.
103         return processMessage(request.getApplication(), request, request);
104     }
105     else {
106         // When not out of process, we remote all the message processing.
107         vector<string> headers(1, "Cookie");
108         headers.push_back("User-Agent");
109         headers.push_back("Accept-Language");
110         DDF out,in = wrap(request, &headers);
111         DDFJanitor jin(in), jout(out);
112         out=request.getServiceProvider().getListenerService()->send(in);
113         return unwrap(request, out);
114     }
115 }
116
117 void AssertionConsumerService::receive(DDF& in, ostream& out)
118 {
119     // Find application.
120     const char* aid=in["application_id"].string();
121     const Application* app=aid ? SPConfig::getConfig().getServiceProvider()->getApplication(aid) : nullptr;
122     if (!app) {
123         // Something's horribly wrong.
124         m_log.error("couldn't find application (%s) for new session", aid ? aid : "(missing)");
125         throw ConfigurationException("Unable to locate application for new session, deleted?");
126     }
127
128     // Unpack the request.
129     scoped_ptr<HTTPRequest> req(getRequest(in));
130
131     // Wrap a response shim.
132     DDF ret(nullptr);
133     DDFJanitor jout(ret);
134     scoped_ptr<HTTPResponse> resp(getResponse(ret));
135
136     // Since we're remoted, the result should either be a throw, a false/0 return,
137     // which we just return as an empty structure, or a response/redirect,
138     // which we capture in the facade and send back.
139     processMessage(*app, *req, *resp);
140     out << ret;
141 }
142
143 pair<bool,long> AssertionConsumerService::processMessage(
144     const Application& application, const HTTPRequest& httpRequest, HTTPResponse& httpResponse
145     ) const
146 {
147 #ifndef SHIBSP_LITE
148     // Locate policy key.
149     pair<bool,const char*> policyId = getString("policyId", m_configNS.get());  // namespace-qualified if inside handler element
150     if (!policyId.first)
151         policyId = application.getString("policyId");   // unqualified in Application(s) element
152
153     // Lock metadata for use by policy.
154     Locker metadataLocker(application.getMetadataProvider());
155
156     // Create the policy.
157     scoped_ptr<opensaml::SecurityPolicy> policy(
158         application.getServiceProvider().getSecurityPolicyProvider()->createSecurityPolicy(application, &IDPSSODescriptor::ELEMENT_QNAME, policyId.second)
159         );
160
161     string relayState;
162     bool relayStateOK = true;
163     scoped_ptr<XMLObject> msg;
164     try {
165         // Decode the message and process it in a protocol-specific way.
166         msg.reset(m_decoder->decode(relayState, httpRequest, *(policy.get())));
167         if (!msg)
168             throw BindingException("Failed to decode an SSO protocol response.");
169         DDF postData = recoverPostData(application, httpRequest, httpResponse, relayState.c_str());
170         DDFJanitor postjan(postData);
171         recoverRelayState(application, httpRequest, httpResponse, relayState);
172         limitRelayState(m_log, application, httpRequest, relayState.c_str());
173         implementProtocol(application, httpRequest, httpResponse, *policy, nullptr, *msg);
174
175         auto_ptr_char issuer(policy->getIssuer() ? policy->getIssuer()->getName() : nullptr);
176
177         // History cookie.
178         if (issuer.get() && *issuer.get())
179             maintainHistory(application, httpRequest, httpResponse, issuer.get());
180
181         // Now redirect to the state value. By now, it should be set to *something* usable.
182         // First check for POST data.
183         if (!postData.islist()) {
184             m_log.debug("ACS returning via redirect to: %s", relayState.c_str());
185             return make_pair(true, httpResponse.sendRedirect(relayState.c_str()));
186         }
187         else {
188             m_log.debug("ACS returning via POST to: %s", relayState.c_str());
189             return make_pair(true, sendPostResponse(application, httpResponse, relayState.c_str(), postData));
190         }
191     }
192     catch (XMLToolingException& ex) {
193         if (relayStateOK) {
194             // Check for isPassive error condition.
195             const char* sc2 = ex.getProperty("statusCode2");
196             if (sc2 && !strcmp(sc2, "urn:oasis:names:tc:SAML:2.0:status:NoPassive")) {
197                 pair<bool,bool> ignore = getBool("ignoreNoPassive", m_configNS.get());  // namespace-qualified if inside handler element
198                 if (ignore.first && ignore.second && !relayState.empty()) {
199                     m_log.debug("ignoring SAML status of NoPassive and redirecting to resource...");
200                     return make_pair(true, httpResponse.sendRedirect(relayState.c_str()));
201                 }
202             }
203         }
204         if (!relayState.empty())
205             ex.addProperty("RelayState", relayState.c_str());
206
207         // Log the error.
208         try {
209             scoped_ptr<TransactionLog::Event> event(SPConfig::getConfig().EventManager.newPlugin(LOGIN_EVENT, nullptr));
210             LoginEvent* error_event = dynamic_cast<LoginEvent*>(event.get());
211             if (error_event) {
212                 error_event->m_exception = &ex;
213                 error_event->m_request = &httpRequest;
214                 error_event->m_app = &application;
215                 if (policy->getIssuerMetadata())
216                     error_event->m_peer = dynamic_cast<const EntityDescriptor*>(policy->getIssuerMetadata()->getParent());
217                 auto_ptr_char prot(getProtocolFamily());
218                 error_event->m_protocol = prot.get();
219                 error_event->m_binding = getString("Binding").second;
220                 error_event->m_saml2Response = dynamic_cast<const saml2p::StatusResponseType*>(msg.get());
221                 if (!error_event->m_saml2Response)
222                     error_event->m_saml1Response = dynamic_cast<const saml1p::Response*>(msg.get());
223                 application.getServiceProvider().getTransactionLog()->write(*error_event);
224             }
225             else {
226                 m_log.warn("unable to audit event, log event object was of an incorrect type");
227             }
228         }
229         catch (std::exception& ex) {
230             m_log.warn("exception auditing event: %s", ex.what());
231         }
232
233         throw;
234     }
235 #else
236     throw ConfigurationException("Cannot process message using lite version of shibsp library.");
237 #endif
238 }
239
240 void AssertionConsumerService::checkAddress(const Application& application, const HTTPRequest& httpRequest, const char* issuedTo) const
241 {
242     if (!issuedTo || !*issuedTo)
243         return;
244
245     const PropertySet* props = application.getPropertySet("Sessions");
246     pair<bool,bool> checkAddress = props ? props->getBool("checkAddress") : make_pair(false,true);
247     if (!checkAddress.first)
248         checkAddress.second = true;
249
250     if (checkAddress.second) {
251         m_log.debug("checking client address");
252         if (httpRequest.getRemoteAddr() != issuedTo) {
253             throw FatalProfileException(
254                "Your client's current address ($client_addr) differs from the one used when you authenticated "
255                 "to your identity provider. To correct this problem, you may need to bypass a proxy server. "
256                 "Please contact your local support staff or help desk for assistance.",
257                 namedparams(1, "client_addr", httpRequest.getRemoteAddr().c_str())
258                 );
259         }
260     }
261 }
262
263 #ifndef SHIBSP_LITE
264
265 const XMLCh* AssertionConsumerService::getProtocolFamily() const
266 {
267     return m_decoder ? m_decoder->getProtocolFamily() : nullptr;
268 }
269
270 const char* AssertionConsumerService::getType() const
271 {
272     return "AssertionConsumerService";
273 }
274
275 void AssertionConsumerService::generateMetadata(SPSSODescriptor& role, const char* handlerURL) const
276 {
277     // Initial guess at index to use.
278     pair<bool,unsigned int> ix = pair<bool,unsigned int>(false,0);
279     if (!strncmp(handlerURL, "https", 5))
280         ix = getUnsignedInt("sslIndex", shibspconstants::ASCII_SHIB2SPCONFIG_NS);
281     if (!ix.first)
282         ix = getUnsignedInt("index");
283     if (!ix.first)
284         ix.second = 1;
285
286     // Find maximum index in use and go one higher.
287     const vector<saml2md::AssertionConsumerService*>& services = const_cast<const SPSSODescriptor&>(role).getAssertionConsumerServices();
288     if (!services.empty() && ix.second <= services.back()->getIndex().second)
289         ix.second = services.back()->getIndex().second + 1;
290
291     const char* loc = getString("Location").second;
292     string hurl(handlerURL);
293     if (*loc != '/')
294         hurl += '/';
295     hurl += loc;
296     auto_ptr_XMLCh widen(hurl.c_str());
297
298     saml2md::AssertionConsumerService* ep = saml2md::AssertionConsumerServiceBuilder::buildAssertionConsumerService();
299     ep->setLocation(widen.get());
300     ep->setBinding(getXMLString("Binding").second);
301     ep->setIndex(ix.second);
302     role.getAssertionConsumerServices().push_back(ep);
303 }
304
305 opensaml::SecurityPolicy* AssertionConsumerService::createSecurityPolicy(
306     const Application& application, const xmltooling::QName* role, bool validate, const char* policyId
307     ) const
308 {
309     return new SecurityPolicy(application, role, validate, policyId);
310 }
311
312 class SHIBSP_DLLLOCAL DummyContext : public ResolutionContext
313 {
314 public:
315     DummyContext(const vector<Attribute*>& attributes) : m_attributes(attributes) {
316     }
317
318     virtual ~DummyContext() {
319         for_each(m_attributes.begin(), m_attributes.end(), xmltooling::cleanup<Attribute>());
320     }
321
322     vector<Attribute*>& getResolvedAttributes() {
323         return m_attributes;
324     }
325     vector<Assertion*>& getResolvedAssertions() {
326         return m_tokens;
327     }
328
329 private:
330     vector<Attribute*> m_attributes;
331     static vector<Assertion*> m_tokens; // never any tokens, so just share an empty vector
332 };
333
334 vector<Assertion*> DummyContext::m_tokens;
335
336 ResolutionContext* AssertionConsumerService::resolveAttributes(
337     const Application& application,
338     const saml2md::RoleDescriptor* issuer,
339     const XMLCh* protocol,
340     const saml1::NameIdentifier* v1nameid,
341     const saml2::NameID* nameid,
342     const XMLCh* authncontext_class,
343     const XMLCh* authncontext_decl,
344     const vector<const Assertion*>* tokens
345     ) const
346 {
347     return resolveAttributes(
348         application,
349         nullptr,
350         issuer,
351         protocol,
352         v1nameid,
353         nullptr,
354         nameid,
355         nullptr,
356         authncontext_class,
357         authncontext_decl,
358         tokens
359         );
360 }
361
362 ResolutionContext* AssertionConsumerService::resolveAttributes(
363     const Application& application,
364     const GenericRequest* request,
365     const saml2md::RoleDescriptor* issuer,
366     const XMLCh* protocol,
367     const saml1::NameIdentifier* v1nameid,
368     const saml1::AuthenticationStatement* v1statement,
369     const saml2::NameID* nameid,
370     const saml2::AuthnStatement* statement,
371     const XMLCh* authncontext_class,
372     const XMLCh* authncontext_decl,
373     const vector<const Assertion*>* tokens
374     ) const
375 {
376     // First we do the extraction of any pushed information, including from metadata.
377     vector<Attribute*> resolvedAttributes;
378     AttributeExtractor* extractor = application.getAttributeExtractor();
379     if (extractor) {
380         Locker extlocker(extractor);
381         if (issuer) {
382             pair<bool,const char*> mprefix = application.getString("metadataAttributePrefix");
383             if (mprefix.first) {
384                 m_log.debug("extracting metadata-derived attributes...");
385                 try {
386                     // We pass nullptr for "issuer" because the IdP isn't the one asserting metadata-based attributes.
387                     extractor->extractAttributes(application, request, nullptr, *issuer, resolvedAttributes);
388                     for (indirect_iterator<vector<Attribute*>::iterator> a = make_indirect_iterator(resolvedAttributes.begin());
389                             a != make_indirect_iterator(resolvedAttributes.end()); ++a) {
390                         vector<string>& ids = a->getAliases();
391                         for (vector<string>::iterator id = ids.begin(); id != ids.end(); ++id)
392                             *id = mprefix.second + *id;
393                     }
394                 }
395                 catch (std::exception& ex) {
396                     m_log.error("caught exception extracting attributes: %s", ex.what());
397                 }
398             }
399         }
400
401         m_log.debug("extracting pushed attributes...");
402
403         if (v1nameid || nameid) {
404             try {
405                 if (v1nameid)
406                     extractor->extractAttributes(application, request, issuer, *v1nameid, resolvedAttributes);
407                 else
408                     extractor->extractAttributes(application, request, issuer, *nameid, resolvedAttributes);
409             }
410             catch (std::exception& ex) {
411                 m_log.error("caught exception extracting attributes: %s", ex.what());
412             }
413         }
414
415         if (v1statement || statement) {
416             try {
417                 if (v1statement)
418                     extractor->extractAttributes(application, request, issuer, *v1statement, resolvedAttributes);
419                 else
420                     extractor->extractAttributes(application, request, issuer, *statement, resolvedAttributes);
421             }
422             catch (std::exception& ex) {
423                 m_log.error("caught exception extracting attributes: %s", ex.what());
424             }
425         }
426
427         if (tokens) {
428             for (indirect_iterator<vector<const Assertion*>::const_iterator> t = make_indirect_iterator(tokens->begin());
429                     t != make_indirect_iterator(tokens->end()); ++t) {
430                 try {
431                     extractor->extractAttributes(application, request, issuer, *t, resolvedAttributes);
432                 }
433                 catch (std::exception& ex) {
434                     m_log.error("caught exception extracting attributes: %s", ex.what());
435                 }
436             }
437         }
438
439         AttributeFilter* filter = application.getAttributeFilter();
440         if (filter && !resolvedAttributes.empty()) {
441             BasicFilteringContext fc(application, resolvedAttributes, issuer, authncontext_class);
442             Locker filtlocker(filter);
443             try {
444                 filter->filterAttributes(fc, resolvedAttributes);
445             }
446             catch (std::exception& ex) {
447                 m_log.error("caught exception filtering attributes: %s", ex.what());
448                 m_log.error("dumping extracted attributes due to filtering exception");
449                 for_each(resolvedAttributes.begin(), resolvedAttributes.end(), xmltooling::cleanup<shibsp::Attribute>());
450                 resolvedAttributes.clear();
451             }
452         }
453     }
454     else {
455         m_log.warn("no AttributeExtractor plugin installed, check log during startup");
456     }
457
458     try {
459         AttributeResolver* resolver = application.getAttributeResolver();
460         if (resolver) {
461             m_log.debug("resolving attributes...");
462
463             Locker locker(resolver);
464             auto_ptr<ResolutionContext> ctx(
465                 resolver->createResolutionContext(
466                     application,
467                     issuer ? dynamic_cast<const saml2md::EntityDescriptor*>(issuer->getParent()) : nullptr,
468                     protocol,
469                     nameid,
470                     authncontext_class,
471                     authncontext_decl,
472                     tokens,
473                     &resolvedAttributes
474                     )
475                 );
476             resolver->resolveAttributes(*ctx);
477             // Copy over any pushed attributes.
478             while (!resolvedAttributes.empty()) {
479                 ctx->getResolvedAttributes().push_back(resolvedAttributes.back());
480                 resolvedAttributes.pop_back();
481             }
482             return ctx.release();
483         }
484     }
485     catch (std::exception& ex) {
486         m_log.error("attribute resolution failed: %s", ex.what());
487     }
488
489     if (!resolvedAttributes.empty()) {
490         try {
491             return new DummyContext(resolvedAttributes);
492         }
493         catch (bad_alloc&) {
494             for_each(resolvedAttributes.begin(), resolvedAttributes.end(), xmltooling::cleanup<shibsp::Attribute>());
495         }
496     }
497     return nullptr;
498 }
499
500 void AssertionConsumerService::extractMessageDetails(const Assertion& assertion, const XMLCh* protocol, opensaml::SecurityPolicy& policy) const
501 {
502     policy.setMessageID(assertion.getID());
503     policy.setIssueInstant(assertion.getIssueInstantEpoch());
504
505     if (XMLString::equals(assertion.getElementQName().getNamespaceURI(), samlconstants::SAML20_NS)) {
506         const saml2::Assertion* a2 = dynamic_cast<const saml2::Assertion*>(&assertion);
507         if (a2) {
508             m_log.debug("extracting issuer from SAML 2.0 assertion");
509             policy.setIssuer(a2->getIssuer());
510         }
511     }
512     else {
513         const saml1::Assertion* a1 = dynamic_cast<const saml1::Assertion*>(&assertion);
514         if (a1) {
515             m_log.debug("extracting issuer from SAML 1.x assertion");
516             policy.setIssuer(a1->getIssuer());
517         }
518     }
519
520     if (policy.getIssuer() && !policy.getIssuerMetadata() && policy.getMetadataProvider()) {
521         if (policy.getIssuer()->getFormat() && !XMLString::equals(policy.getIssuer()->getFormat(), saml2::NameIDType::ENTITY)) {
522             m_log.warn("non-system entity issuer, skipping metadata lookup");
523             return;
524         }
525         m_log.debug("searching metadata for assertion issuer...");
526         pair<const EntityDescriptor*,const RoleDescriptor*> entity;
527         MetadataProvider::Criteria& mc = policy.getMetadataProviderCriteria();
528         mc.entityID_unicode = policy.getIssuer()->getName();
529         mc.role = &IDPSSODescriptor::ELEMENT_QNAME;
530         mc.protocol = protocol;
531         entity = policy.getMetadataProvider()->getEntityDescriptor(mc);
532         if (!entity.first) {
533             auto_ptr_char iname(policy.getIssuer()->getName());
534             m_log.warn("no metadata found, can't establish identity of issuer (%s)", iname.get());
535         }
536         else if (!entity.second) {
537             m_log.warn("unable to find compatible IdP role in metadata");
538         }
539         else {
540             policy.setIssuerMetadata(entity.second);
541         }
542     }
543 }
544
545 LoginEvent* AssertionConsumerService::newLoginEvent(const Application& application, const xmltooling::HTTPRequest& request) const
546 {
547     if (!SPConfig::getConfig().isEnabled(SPConfig::Logging))
548         return nullptr;
549     try {
550         auto_ptr<TransactionLog::Event> event(SPConfig::getConfig().EventManager.newPlugin(LOGIN_EVENT, nullptr));
551         LoginEvent* login_event = dynamic_cast<LoginEvent*>(event.get());
552         if (login_event) {
553             login_event->m_request = &request;
554             login_event->m_app = &application;
555             login_event->m_binding = getString("Binding").second;
556             event.release();
557             return login_event;
558         }
559         else {
560             m_log.warn("unable to audit event, log event object was of an incorrect type");
561         }
562     }
563     catch (std::exception& ex) {
564         m_log.warn("exception auditing event: %s", ex.what());
565     }
566     return nullptr;
567 }
568
569 #endif
570
571 void AssertionConsumerService::maintainHistory(
572     const Application& application, const HTTPRequest& request, HTTPResponse& response, const char* entityID
573     ) const
574 {
575     static const char* defProps="; path=/";
576
577     const PropertySet* sessionProps=application.getPropertySet("Sessions");
578     pair<bool,bool> idpHistory=sessionProps->getBool("idpHistory");
579
580     if (idpHistory.first && idpHistory.second) {
581         pair<bool,const char*> cookieProps=sessionProps->getString("idpHistoryProps");
582         if (!cookieProps.first)
583             cookieProps=sessionProps->getString("cookieProps");
584         if (!cookieProps.first)
585             cookieProps.second=defProps;
586
587         // Set an IdP history cookie locally (essentially just a CDC).
588         CommonDomainCookie cdc(request.getCookie(CommonDomainCookie::CDCName));
589
590         // Either leave in memory or set an expiration.
591         pair<bool,unsigned int> days=sessionProps->getUnsignedInt("idpHistoryDays");
592         if (!days.first || days.second==0) {
593             string c = string(cdc.set(entityID)) + cookieProps.second;
594             response.setCookie(CommonDomainCookie::CDCName, c.c_str());
595         }
596         else {
597             time_t now=time(nullptr) + (days.second * 24 * 60 * 60);
598 #ifdef HAVE_GMTIME_R
599             struct tm res;
600             struct tm* ptime=gmtime_r(&now,&res);
601 #else
602             struct tm* ptime=gmtime(&now);
603 #endif
604             char timebuf[64];
605             strftime(timebuf,64,"%a, %d %b %Y %H:%M:%S GMT",ptime);
606             string c = string(cdc.set(entityID)) + cookieProps.second + "; expires=" + timebuf;
607             response.setCookie(CommonDomainCookie::CDCName, c.c_str());
608         }
609     }
610 }