214a770f6ff6f099a5c8e62ba51db7f018b6c95e
[shibboleth/cpp-opensaml.git] / samltest / binding.h
1 /*\r
2  *  Copyright 2001-2009 Internet2\r
3  * \r
4  * Licensed under the Apache License, Version 2.0 (the "License");\r
5  * you may not use this file except in compliance with the License.\r
6  * You may obtain a copy of the License at\r
7  *\r
8  *     http://www.apache.org/licenses/LICENSE-2.0\r
9  *\r
10  * Unless required by applicable law or agreed to in writing, software\r
11  * distributed under the License is distributed on an "AS IS" BASIS,\r
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
13  * See the License for the specific language governing permissions and\r
14  * limitations under the License.\r
15  */\r
16 \r
17 #include "internal.h"\r
18 \r
19 #include <saml/SAMLConfig.h>\r
20 #include <saml/binding/MessageDecoder.h>\r
21 #include <saml/binding/MessageEncoder.h>\r
22 #include <saml/binding/SecurityPolicy.h>\r
23 #include <saml/binding/SecurityPolicyRule.h>\r
24 #include <saml/saml2/metadata/Metadata.h>\r
25 #include <saml/saml2/metadata/MetadataProvider.h>\r
26 #include <xmltooling/io/HTTPRequest.h>\r
27 #include <xmltooling/io/HTTPResponse.h>\r
28 #include <xmltooling/security/Credential.h>\r
29 #include <xmltooling/security/CredentialCriteria.h>\r
30 #include <xmltooling/security/TrustEngine.h>\r
31 #include <xmltooling/util/URLEncoder.h>\r
32 \r
33 using namespace opensaml::saml2md;\r
34 using namespace opensaml;\r
35 using namespace xmlsignature;\r
36 \r
37 class SAMLBindingBaseTestCase : public HTTPRequest, public HTTPResponse\r
38 {\r
39 protected:\r
40     CredentialResolver* m_creds; \r
41     MetadataProvider* m_metadata;\r
42     TrustEngine* m_trust;\r
43     map<string,string> m_fields;\r
44     map<string,string> m_headers;\r
45     string m_method,m_url,m_query;\r
46     vector<XSECCryptoX509*> m_clientCerts;\r
47     vector<const SecurityPolicyRule*> m_rules;\r
48 \r
49 public:\r
50     void setUp() {\r
51         m_creds=NULL;\r
52         m_metadata=NULL;\r
53         m_trust=NULL;\r
54         m_fields.clear();\r
55         m_headers.clear();\r
56         m_method.erase();\r
57         m_url.erase();\r
58         m_query.erase();\r
59 \r
60         try {\r
61             string config = data_path + "binding/ExampleMetadataProvider.xml";\r
62             ifstream in(config.c_str());\r
63             DOMDocument* doc=XMLToolingConfig::getConfig().getParser().parse(in);\r
64             XercesJanitor<DOMDocument> janitor(doc);\r
65     \r
66             auto_ptr_XMLCh path("path");\r
67             string s = data_path + "binding/example-metadata.xml";\r
68             auto_ptr_XMLCh file(s.c_str());\r
69             doc->getDocumentElement()->setAttributeNS(NULL,path.get(),file.get());\r
70     \r
71             m_metadata = SAMLConfig::getConfig().MetadataProviderManager.newPlugin(\r
72                 XML_METADATA_PROVIDER,doc->getDocumentElement()\r
73                 );\r
74             m_metadata->init();\r
75 \r
76             config = data_path + "FilesystemCredentialResolver.xml";\r
77             ifstream in2(config.c_str());\r
78             DOMDocument* doc2=XMLToolingConfig::getConfig().getParser().parse(in2);\r
79             XercesJanitor<DOMDocument> janitor2(doc2);\r
80             m_creds = XMLToolingConfig::getConfig().CredentialResolverManager.newPlugin(\r
81                 FILESYSTEM_CREDENTIAL_RESOLVER,doc2->getDocumentElement()\r
82                 );\r
83                 \r
84             m_trust = XMLToolingConfig::getConfig().TrustEngineManager.newPlugin(EXPLICIT_KEY_TRUSTENGINE, NULL);\r
85 \r
86             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(MESSAGEFLOW_POLICY_RULE,NULL));\r
87             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(SIMPLESIGNING_POLICY_RULE,NULL));\r
88             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(XMLSIGNING_POLICY_RULE,NULL));\r
89         }\r
90         catch (XMLToolingException& ex) {\r
91             TS_TRACE(ex.what());\r
92             tearDown();\r
93             throw;\r
94         }\r
95 \r
96     }\r
97     \r
98     void tearDown() {\r
99         for_each(m_rules.begin(), m_rules.end(), xmltooling::cleanup<SecurityPolicyRule>());\r
100         m_rules.clear();\r
101         delete m_creds;\r
102         delete m_metadata;\r
103         delete m_trust;\r
104         m_creds=NULL;\r
105         m_metadata=NULL;\r
106         m_trust=NULL;\r
107         m_fields.clear();\r
108         m_headers.clear();\r
109         m_method.erase();\r
110         m_url.erase();\r
111         m_query.erase();\r
112     }\r
113 \r
114     // HTTPRequest methods\r
115 \r
116     const char* getMethod() const {\r
117         return m_method.c_str();\r
118     }\r
119 \r
120     const char* getScheme() const {\r
121         return "https";\r
122     }\r
123 \r
124     const char* getHostname() const {\r
125         return "localhost";\r
126     }\r
127 \r
128     int getPort() const {\r
129         return 443;\r
130     }\r
131 \r
132     string getContentType() const {\r
133         return "application/x-www-form-urlencoded";\r
134     }\r
135 \r
136     long getContentLength() const {\r
137         return -1;\r
138     }\r
139 \r
140     const char* getRequestURI() const {\r
141         return "/";\r
142     }\r
143 \r
144     const char* getRequestURL() const {\r
145         return m_url.c_str();\r
146     }\r
147     \r
148     const char* getRequestBody() const {\r
149         return NULL;\r
150     }\r
151     \r
152     const char* getQueryString() const {\r
153         return m_query.c_str();\r
154     }\r
155     \r
156     string getRemoteUser() const {\r
157         return "";\r
158     }\r
159 \r
160     string getRemoteAddr() const {\r
161         return "127.0.0.1";\r
162     }\r
163 \r
164     const std::vector<XSECCryptoX509*>& getClientCertificates() const {\r
165         return m_clientCerts;\r
166     }\r
167 \r
168     string getHeader(const char* name) const {\r
169         map<string,string>::const_iterator i=m_headers.find(name);\r
170         return i==m_headers.end() ? "" : i->second;\r
171     }\r
172     \r
173     const char* getParameter(const char* name) const {\r
174         map<string,string>::const_iterator i=m_fields.find(name);\r
175         return i==m_fields.end() ? NULL : i->second.c_str();\r
176     }\r
177 \r
178     vector<const char*>::size_type getParameters(const char* name, vector<const char*>& values) const {\r
179         values.clear();\r
180         map<string,string>::const_iterator i=m_fields.find(name);\r
181         if (i!=m_fields.end())\r
182             values.push_back(i->second.c_str());\r
183         return values.size();\r
184     }\r
185     \r
186     // HTTPResponse methods\r
187     \r
188     void setResponseHeader(const char* name, const char* value) {\r
189         m_headers[name] = value ? value : "";\r
190     }\r
191 \r
192     // The amount of error checking missing from this is incredible, but as long\r
193     // as the test data isn't unexpected or malformed, it should work.\r
194     \r
195     long sendRedirect(const char* url) {\r
196         m_method = "GET";\r
197         char* dup = strdup(url);\r
198         char* pch = strchr(dup,'?');\r
199         if (pch) {\r
200             *pch++=0;\r
201             m_query = pch;\r
202             char* name=pch;\r
203             while (name && *name) {\r
204                 pch=strchr(pch,'=');\r
205                 *pch++=0;\r
206                 char* value=pch;\r
207                 pch=strchr(pch,'&');\r
208                 if (pch)\r
209                     *pch++=0;\r
210                 XMLToolingConfig::getConfig().getURLEncoder()->decode(value);\r
211                 m_fields[name] = value;\r
212                 name = pch; \r
213             }\r
214         }\r
215         m_url = dup;\r
216         free(dup);\r
217         return m_fields.size();\r
218     }\r
219     \r
220     string html_decode(const string& s) const {\r
221         string decoded;\r
222         const char* ch=s.c_str();\r
223         while (*ch) {\r
224             if (*ch=='&') {\r
225                 if (!strncmp(ch,"&lt;",4)) {\r
226                     decoded+='<'; ch+=4;\r
227                 }\r
228                 else if (!strncmp(ch,"&gt;",4)) {\r
229                     decoded+='>'; ch+=4;\r
230                 }\r
231                 else if (!strncmp(ch,"&quot;",6)) {\r
232                     decoded+='"'; ch+=6;\r
233                 }\r
234                 else if (*++ch=='#') {\r
235                     decoded+=(char)atoi(++ch);\r
236                     ch=strchr(ch,';')+1;\r
237                 }\r
238             }\r
239             else {\r
240                 decoded+=*ch++;\r
241             }\r
242         }\r
243         return decoded;\r
244     }\r
245     \r
246     using HTTPResponse::sendResponse;\r
247 \r
248     long sendResponse(std::istream& inputStream, long status) {\r
249         m_method="POST";\r
250         string page,line;\r
251         while (getline(inputStream,line))\r
252             page += line + '\n';\r
253             \r
254         const char* pch=strstr(page.c_str(),"action=\"");\r
255         pch+=strlen("action=\"");\r
256         m_url = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));\r
257 \r
258         while (pch=strstr(pch,"<input type=\"hidden\" name=\"")) {\r
259             pch+=strlen("<input type=\"hidden\" name=\"");\r
260             string name = page.substr(pch-page.c_str(),strchr(pch,'"')-pch);\r
261             pch=strstr(pch,"value=\"");\r
262             pch+=strlen("value=\"");\r
263             m_fields[name] = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));\r
264         }\r
265         return m_fields.size();\r
266     }\r
267 };\r