https://bugs.internet2.edu/jira/browse/SSPCPP-303
[shibboleth/sp.git] / shibsp / handler / impl / SAMLDSSessionInitiator.cpp
1 /*
2  *  Copyright 2001-2010 Internet2
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * SAMLDSSessionInitiator.cpp
19  *
20  * SAML Discovery Service support.
21  */
22
23 #include "internal.h"
24 #include "Application.h"
25 #include "exceptions.h"
26 #include "handler/AbstractHandler.h"
27 #include "handler/SessionInitiator.h"
28
29 #include <xmltooling/XMLToolingConfig.h>
30 #include <xmltooling/impl/AnyElement.h>
31 #include <xmltooling/util/URLEncoder.h>
32
33 using namespace shibsp;
34 using namespace opensaml;
35 using namespace xmltooling;
36 using namespace std;
37
38 #ifndef SHIBSP_LITE
39 # include <saml/saml2/metadata/Metadata.h>
40 # include <saml/saml2/metadata/MetadataProvider.h>
41 using namespace opensaml::saml2md;
42 #endif
43
44 namespace shibsp {
45
46 #if defined (_MSC_VER)
47     #pragma warning( push )
48     #pragma warning( disable : 4250 )
49 #endif
50
51     class SHIBSP_DLLLOCAL SAMLDSSessionInitiator : public SessionInitiator, public AbstractHandler
52     {
53     public:
54         SAMLDSSessionInitiator(const DOMElement* e, const char* appId);
55         virtual ~SAMLDSSessionInitiator() {}
56
57         pair<bool,long> run(SPRequest& request, string& entityID, bool isHandler=true) const;
58
59 #ifndef SHIBSP_LITE
60         void generateMetadata(SPSSODescriptor& role, const char* handlerURL) const {
61             static const XMLCh LOCAL_NAME[] = UNICODE_LITERAL_17(D,i,s,c,o,v,e,r,y,R,e,s,p,o,n,s,e);
62
63             // Initial guess at index to use.
64             pair<bool,unsigned int> ix = getUnsignedInt("index");
65             if (!ix.first)
66                 ix.second = 1;
67
68             // Find maximum index in use and go one higher.
69             if (role.getExtensions()) {
70                 const vector<XMLObject*>& exts = const_cast<const Extensions*>(role.getExtensions())->getUnknownXMLObjects();
71                 for (vector<XMLObject*>::const_reverse_iterator i = exts.rbegin(); i != exts.rend(); ++i) {
72                     if (XMLString::equals((*i)->getElementQName().getLocalPart(), LOCAL_NAME) &&
73                         XMLString::equals((*i)->getElementQName().getNamespaceURI(), m_discoNS.get())) {
74                         const AttributeExtensibleXMLObject* sub = dynamic_cast<const AttributeExtensibleXMLObject*>(*i);
75                         if (sub) {
76                             const XMLCh* val = sub->getAttribute(xmltooling::QName(nullptr,IndexedEndpointType::INDEX_ATTRIB_NAME));
77                             if (val) {
78                                 int maxindex = XMLString::parseInt(val);
79                                 if (ix.second <= maxindex)
80                                     ix.second = maxindex + 1;
81                                 break;
82                             }
83                         }
84                     }
85                 }
86             }
87
88             const char* loc = getString("Location").second;
89             string hurl(handlerURL);
90             if (*loc != '/')
91                 hurl += '/';
92             hurl += loc;
93             auto_ptr_XMLCh widen(hurl.c_str());
94
95             ostringstream os;
96             os << ix.second;
97             auto_ptr_XMLCh widen2(os.str().c_str());
98
99             ElementProxy* ep = new AnyElementImpl(m_discoNS.get(), LOCAL_NAME);
100             ep->setAttribute(xmltooling::QName(nullptr,EndpointType::LOCATION_ATTRIB_NAME), widen.get());
101             ep->setAttribute(xmltooling::QName(nullptr,EndpointType::BINDING_ATTRIB_NAME), m_discoNS.get());
102             ep->setAttribute(xmltooling::QName(nullptr,IndexedEndpointType::INDEX_ATTRIB_NAME), widen2.get());
103             Extensions* ext = role.getExtensions();
104             if (!ext) {
105                 ext = ExtensionsBuilder::buildExtensions();
106                 role.setExtensions(ext);
107             }
108             ext->getUnknownXMLObjects().push_back(ep);
109         }
110 #endif
111
112     private:
113         const char* m_url;
114         const char* m_returnParam;
115         vector<string> m_preservedOptions;
116 #ifndef SHIBSP_LITE
117         auto_ptr_XMLCh m_discoNS;
118 #endif
119     };
120
121 #if defined (_MSC_VER)
122     #pragma warning( pop )
123 #endif
124
125     SessionInitiator* SHIBSP_DLLLOCAL SAMLDSSessionInitiatorFactory(const pair<const DOMElement*,const char*>& p)
126     {
127         return new SAMLDSSessionInitiator(p.first, p.second);
128     }
129
130 };
131
132 SAMLDSSessionInitiator::SAMLDSSessionInitiator(const DOMElement* e, const char* appId)
133         : AbstractHandler(e, Category::getInstance(SHIBSP_LOGCAT".SessionInitiator.SAMLDS")), m_url(nullptr), m_returnParam(nullptr)
134 #ifndef SHIBSP_LITE
135             ,m_discoNS("urn:oasis:names:tc:SAML:profiles:SSO:idp-discovery-protocol")
136 #endif
137 {
138     pair<bool,const char*> url = getString("URL");
139     if (!url.first)
140         throw ConfigurationException("SAMLDS SessionInitiator requires a URL property.");
141     m_url = url.second;
142     url = getString("entityIDParam");
143     if (url.first)
144         m_returnParam = url.second;
145
146     pair<bool,const char*> options = getString("preservedOptions");
147     if (options.first) {
148         int j = 0;
149         string opt = options.second;
150         for (unsigned int i = 0;  i < opt.length();  i++) {
151             if (opt.at(i) == ' ') {
152                 m_preservedOptions.push_back(opt.substr(j, i-j));
153                 j = i+1;
154             }
155         }
156         m_preservedOptions.push_back(opt.substr(j, opt.length()-j));
157     }
158     else {
159         m_preservedOptions.push_back("isPassive");
160         m_preservedOptions.push_back("forceAuthn");
161         m_preservedOptions.push_back("authnContextClassRef");
162         m_preservedOptions.push_back("authnContextComparison");
163         m_preservedOptions.push_back("NameIDFormat");
164         m_preservedOptions.push_back("SPNameQualifier");
165         m_preservedOptions.push_back("acsIndex");
166     }
167
168     m_supportedOptions.insert("isPassive");
169 }
170
171 pair<bool,long> SAMLDSSessionInitiator::run(SPRequest& request, string& entityID, bool isHandler) const
172 {
173     // The IdP CANNOT be specified for us to run. Otherwise, we'd be redirecting to a DS
174     // anytime the IdP's metadata was wrong.
175     if (!entityID.empty() || !checkCompatibility(request, isHandler))
176         return make_pair(false,0L);
177
178     string target;
179     pair<bool,const char*> prop;
180     bool isPassive=false;
181     const Application& app=request.getApplication();
182     pair<bool,const char*> discoveryURL = pair<bool,const char*>(true, m_url);
183
184     if (isHandler) {
185         prop.second = request.getParameter("SAMLDS");
186         if (prop.second && !strcmp(prop.second,"1")) {
187             saml2md::MetadataException ex("No identity provider was selected by user.");
188             ex.addProperty("statusCode", "urn:oasis:names:tc:SAML:2.0:status:Requester");
189             ex.addProperty("statusCode2", "urn:oasis:names:tc:SAML:2.0:status:NoAvailableIDP");
190             ex.raise();
191         }
192
193         prop = getString("target", request);
194         if (prop.first)
195             target = prop.second;
196
197         recoverRelayState(app, request, request, target, false);
198
199         pair<bool,bool> passopt = getBool("isPassive", request);
200         isPassive = passopt.first && passopt.second;
201
202         prop.second = request.getParameter("discoveryURL");
203         if (prop.second && *prop.second)
204             discoveryURL.second = prop.second;
205     }
206     else {
207         // Check for a hardwired target value in the map or handler.
208         prop = getString("target", request, HANDLER_PROPERTY_MAP|HANDLER_PROPERTY_FIXED);
209         if (prop.first)
210             target = prop.second;
211         else
212             target = request.getRequestURL();
213
214         pair<bool,bool> passopt = getBool("isPassive", request, HANDLER_PROPERTY_MAP|HANDLER_PROPERTY_FIXED);
215         isPassive = passopt.first && passopt.second;
216         discoveryURL = request.getRequestSettings().first->getString("discoveryURL");
217     }
218
219     if (!discoveryURL.first)
220         discoveryURL.second = m_url;
221     m_log.debug("sending request to SAMLDS (%s)", discoveryURL.second);
222
223     // Compute the return URL. We start with a self-referential link.
224     string returnURL = request.getHandlerURL(target.c_str());
225     prop = getString("Location");
226     if (prop.first)
227         returnURL += prop.second;
228     returnURL += "?SAMLDS=1"; // signals us not to loop if we get no answer back
229
230     if (isHandler) {
231         // We may already have RelayState set if we looped back here,
232         // but we've turned it back into a resource by this point, so if there's
233         // a target on the URL, reset to that value.
234         prop.second = request.getParameter("target");
235         if (prop.second && *prop.second)
236             target = prop.second;
237     }
238     preserveRelayState(app, request, target);
239     if (!isHandler)
240         preservePostData(app, request, request, target.c_str());
241
242     const URLEncoder* urlenc = XMLToolingConfig::getConfig().getURLEncoder();
243     if (isHandler) {
244         // Now the hard part. The base assumption is to append the entire query string, if any,
245         // to the self-link. But we want to replace target with the RelayState-preserved value
246         // to hide it from the DS.
247         const char* query = request.getQueryString();
248         if (query) {
249             // See if it starts with target.
250             if (!strncmp(query, "target=", 7)) {
251                 // We skip this altogether and advance the query past it to the first separator.
252                 query = strchr(query, '&');
253                 // If we still have more, just append it.
254                 if (query && *(++query))
255                     returnURL = returnURL + '&' + query;
256             }
257             else {
258                 // There's something in the query before target appears, so we have to find it.
259                 prop.second = strstr(query, "&target=");
260                 if (prop.second) {
261                     // We found it, so first append everything up to it.
262                     returnURL += '&';
263                     returnURL.append(query, prop.second - query);
264                     query = prop.second + 8; // move up just past the equals sign.
265                     prop.second = strchr(query, '&');
266                     if (prop.second)
267                         returnURL += prop.second;
268                 }
269                 else {
270                     // No target in the existing query, so just append it as is.
271                     returnURL = returnURL + '&' + query;
272                 }
273             }
274         }
275
276         // Now append the sanitized target as needed.
277         if (!target.empty())
278             returnURL = returnURL + "&target=" + urlenc->encode(target.c_str());
279     }
280     else {
281         // For a virtual handler, we append target to the return link.
282          if (!target.empty())
283             returnURL = returnURL + "&target=" + urlenc->encode(target.c_str());
284          // Preserve designated request settings on the URL.
285          for (vector<string>::const_iterator opt = m_preservedOptions.begin(); opt != m_preservedOptions.end(); ++ opt) {
286              prop = request.getRequestSettings().first->getString(opt->c_str());
287              if (prop.first)
288                  returnURL = returnURL + '&' + (*opt) + '=' + urlenc->encode(prop.second);
289          }
290     }
291
292     string req=string(discoveryURL.second) + (strchr(discoveryURL.second,'?') ? '&' : '?') + "entityID=" + urlenc->encode(app.getString("entityID").second) +
293         "&return=" + urlenc->encode(returnURL.c_str());
294     if (m_returnParam)
295         req = req + "&returnIDParam=" + m_returnParam;
296     if (isPassive)
297         req += "&isPassive=true";
298
299     return make_pair(true, request.sendRedirect(req.c_str()));
300 }