a455499a894eda4d78c7154869d8d51657c6e285
[shibboleth/cpp-opensaml.git] / samltest / binding.h
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 #include "internal.h"
22
23 #include <saml/SAMLConfig.h>
24 #include <saml/binding/MessageDecoder.h>
25 #include <saml/binding/MessageEncoder.h>
26 #include <saml/binding/SecurityPolicy.h>
27 #include <saml/binding/SecurityPolicyRule.h>
28 #include <saml/saml2/metadata/Metadata.h>
29 #include <saml/saml2/metadata/MetadataProvider.h>
30 #include <xmltooling/io/HTTPRequest.h>
31 #include <xmltooling/io/HTTPResponse.h>
32 #include <xmltooling/security/Credential.h>
33 #include <xmltooling/security/CredentialCriteria.h>
34 #include <xmltooling/security/TrustEngine.h>
35 #include <xmltooling/util/URLEncoder.h>
36
37 using namespace opensaml::saml2md;
38 using namespace opensaml;
39 using namespace xmlsignature;
40
41 class SAMLBindingBaseTestCase : public HTTPRequest, public HTTPResponse
42 {
43 protected:
44     CredentialResolver* m_creds; 
45     MetadataProvider* m_metadata;
46     TrustEngine* m_trust;
47     map<string,string> m_fields;
48     map<string,string> m_headers;
49     string m_method,m_url,m_query;
50     vector<XSECCryptoX509*> m_clientCerts;
51     vector<const SecurityPolicyRule*> m_rules;
52
53 public:
54     void setUp() {
55         m_creds=nullptr;
56         m_metadata=nullptr;
57         m_trust=nullptr;
58         m_fields.clear();
59         m_headers.clear();
60         m_method.erase();
61         m_url.erase();
62         m_query.erase();
63
64         try {
65             string config = data_path + "binding/ExampleMetadataProvider.xml";
66             ifstream in(config.c_str());
67             DOMDocument* doc=XMLToolingConfig::getConfig().getParser().parse(in);
68             XercesJanitor<DOMDocument> janitor(doc);
69     
70             auto_ptr_XMLCh path("path");
71             string s = data_path + "binding/example-metadata.xml";
72             auto_ptr_XMLCh file(s.c_str());
73             doc->getDocumentElement()->setAttributeNS(nullptr,path.get(),file.get());
74     
75             m_metadata = SAMLConfig::getConfig().MetadataProviderManager.newPlugin(
76                 XML_METADATA_PROVIDER,doc->getDocumentElement()
77                 );
78             m_metadata->init();
79
80             config = data_path + "FilesystemCredentialResolver.xml";
81             ifstream in2(config.c_str());
82             DOMDocument* doc2=XMLToolingConfig::getConfig().getParser().parse(in2);
83             XercesJanitor<DOMDocument> janitor2(doc2);
84             m_creds = XMLToolingConfig::getConfig().CredentialResolverManager.newPlugin(
85                 FILESYSTEM_CREDENTIAL_RESOLVER,doc2->getDocumentElement()
86                 );
87                 
88             m_trust = XMLToolingConfig::getConfig().TrustEngineManager.newPlugin(EXPLICIT_KEY_TRUSTENGINE, nullptr);
89
90             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(MESSAGEFLOW_POLICY_RULE,nullptr));
91             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(SIMPLESIGNING_POLICY_RULE,nullptr));
92             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(XMLSIGNING_POLICY_RULE,nullptr));
93         }
94         catch (XMLToolingException& ex) {
95             TS_TRACE(ex.what());
96             tearDown();
97             throw;
98         }
99
100     }
101     
102     void tearDown() {
103         for_each(m_rules.begin(), m_rules.end(), xmltooling::cleanup<SecurityPolicyRule>());
104         m_rules.clear();
105         delete m_creds;
106         delete m_metadata;
107         delete m_trust;
108         m_creds=nullptr;
109         m_metadata=nullptr;
110         m_trust=nullptr;
111         m_fields.clear();
112         m_headers.clear();
113         m_method.erase();
114         m_url.erase();
115         m_query.erase();
116     }
117
118     // HTTPRequest methods
119
120     const char* getMethod() const {
121         return m_method.c_str();
122     }
123
124     const char* getScheme() const {
125         return "https";
126     }
127
128     const char* getHostname() const {
129         return "localhost";
130     }
131
132     int getPort() const {
133         return 443;
134     }
135
136     string getContentType() const {
137         return "application/x-www-form-urlencoded";
138     }
139
140     long getContentLength() const {
141         return -1;
142     }
143
144     const char* getRequestURI() const {
145         return "/";
146     }
147
148     const char* getRequestURL() const {
149         return m_url.c_str();
150     }
151     
152     const char* getRequestBody() const {
153         return nullptr;
154     }
155     
156     const char* getQueryString() const {
157         return m_query.c_str();
158     }
159     
160     string getRemoteUser() const {
161         return "";
162     }
163
164     string getRemoteAddr() const {
165         return "127.0.0.1";
166     }
167
168     const std::vector<XSECCryptoX509*>& getClientCertificates() const {
169         return m_clientCerts;
170     }
171
172     string getHeader(const char* name) const {
173         map<string,string>::const_iterator i=m_headers.find(name);
174         return i==m_headers.end() ? "" : i->second;
175     }
176     
177     const char* getParameter(const char* name) const {
178         map<string,string>::const_iterator i=m_fields.find(name);
179         return i==m_fields.end() ? nullptr : i->second.c_str();
180     }
181
182     vector<const char*>::size_type getParameters(const char* name, vector<const char*>& values) const {
183         values.clear();
184         map<string,string>::const_iterator i=m_fields.find(name);
185         if (i!=m_fields.end())
186             values.push_back(i->second.c_str());
187         return values.size();
188     }
189     
190     // HTTPResponse methods
191     
192     void setResponseHeader(const char* name, const char* value) {
193         m_headers[name] = value ? value : "";
194     }
195
196     // The amount of error checking missing from this is incredible, but as long
197     // as the test data isn't unexpected or malformed, it should work.
198     
199     long sendRedirect(const char* url) {
200         m_method = "GET";
201         char* dup = strdup(url);
202         char* pch = strchr(dup,'?');
203         if (pch) {
204             *pch++=0;
205             m_query = pch;
206             char* name=pch;
207             while (name && *name) {
208                 pch=strchr(pch,'=');
209                 *pch++=0;
210                 char* value=pch;
211                 pch=strchr(pch,'&');
212                 if (pch)
213                     *pch++=0;
214                 XMLToolingConfig::getConfig().getURLEncoder()->decode(value);
215                 m_fields[name] = value;
216                 name = pch; 
217             }
218         }
219         m_url = dup;
220         free(dup);
221         return m_fields.size();
222     }
223     
224     string html_decode(const string& s) const {
225         string decoded;
226         const char* ch=s.c_str();
227         while (*ch) {
228             if (*ch=='&') {
229                 if (!strncmp(ch,"&lt;",4)) {
230                     decoded+='<'; ch+=4;
231                 }
232                 else if (!strncmp(ch,"&gt;",4)) {
233                     decoded+='>'; ch+=4;
234                 }
235                 else if (!strncmp(ch,"&quot;",6)) {
236                     decoded+='"'; ch+=6;
237                 }
238                 else if (*++ch=='#') {
239                     decoded+=(char)atoi(++ch);
240                     ch=strchr(ch,';')+1;
241                 }
242             }
243             else {
244                 decoded+=*ch++;
245             }
246         }
247         return decoded;
248     }
249     
250     using HTTPResponse::sendResponse;
251
252     long sendResponse(std::istream& inputStream, long status) {
253         m_method="POST";
254         string page,line;
255         while (getline(inputStream,line))
256             page += line + '\n';
257             
258         const char* pch=strstr(page.c_str(),"action=\"");
259         pch+=strlen("action=\"");
260         m_url = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));
261
262         while (pch=strstr(pch,"<input type=\"hidden\" name=\"")) {
263             pch+=strlen("<input type=\"hidden\" name=\"");
264             string name = page.substr(pch-page.c_str(),strchr(pch,'"')-pch);
265             pch=strstr(pch,"value=\"");
266             pch+=strlen("value=\"");
267             m_fields[name] = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));
268         }
269         return m_fields.size();
270     }
271 };