Checked in test by accident.
[shibboleth/cpp-opensaml.git] / samltest / binding.h
1 /*\r
2  *  Copyright 2001-2007 Internet2\r
3  * \r
4  * Licensed under the Apache License, Version 2.0 (the "License");\r
5  * you may not use this file except in compliance with the License.\r
6  * You may obtain a copy of the License at\r
7  *\r
8  *     http://www.apache.org/licenses/LICENSE-2.0\r
9  *\r
10  * Unless required by applicable law or agreed to in writing, software\r
11  * distributed under the License is distributed on an "AS IS" BASIS,\r
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
13  * See the License for the specific language governing permissions and\r
14  * limitations under the License.\r
15  */\r
16 \r
17 #include "internal.h"\r
18 \r
19 #include <saml/SAMLConfig.h>\r
20 #include <saml/binding/HTTPRequest.h>\r
21 #include <saml/binding/HTTPResponse.h>\r
22 #include <saml/binding/MessageDecoder.h>\r
23 #include <saml/binding/MessageEncoder.h>\r
24 #include <saml/binding/SecurityPolicyRule.h>\r
25 #include <saml/saml2/metadata/Metadata.h>\r
26 #include <saml/saml2/metadata/MetadataProvider.h>\r
27 #include <xmltooling/security/TrustEngine.h>\r
28 #include <xmltooling/util/URLEncoder.h>\r
29 \r
30 using namespace opensaml::saml2md;\r
31 using namespace opensaml;\r
32 using namespace xmlsignature;\r
33 \r
34 class SAMLBindingBaseTestCase : public HTTPRequest, public HTTPResponse\r
35 {\r
36 protected:\r
37     CredentialResolver* m_creds; \r
38     MetadataProvider* m_metadata;\r
39     TrustEngine* m_trust;\r
40     map<string,string> m_fields;\r
41     map<string,string> m_headers;\r
42     string m_method,m_url,m_query;\r
43     vector<XSECCryptoX509*> m_clientCerts;\r
44     vector<const SecurityPolicyRule*> m_rules1;\r
45     vector<const SecurityPolicyRule*> m_rules2;\r
46 \r
47 public:\r
48     void setUp() {\r
49         m_creds=NULL;\r
50         m_metadata=NULL;\r
51         m_trust=NULL;\r
52         m_fields.clear();\r
53         m_headers.clear();\r
54         m_method.erase();\r
55         m_url.erase();\r
56         m_query.erase();\r
57 \r
58         try {\r
59             string config = data_path + "binding/ExampleMetadataProvider.xml";\r
60             ifstream in(config.c_str());\r
61             DOMDocument* doc=XMLToolingConfig::getConfig().getParser().parse(in);\r
62             XercesJanitor<DOMDocument> janitor(doc);\r
63     \r
64             auto_ptr_XMLCh path("path");\r
65             string s = data_path + "binding/example-metadata.xml";\r
66             auto_ptr_XMLCh file(s.c_str());\r
67             doc->getDocumentElement()->setAttributeNS(NULL,path.get(),file.get());\r
68     \r
69             m_metadata = SAMLConfig::getConfig().MetadataProviderManager.newPlugin(\r
70                 XML_METADATA_PROVIDER,doc->getDocumentElement()\r
71                 );\r
72             m_metadata->init();\r
73 \r
74             config = data_path + "FilesystemCredentialResolver.xml";\r
75             ifstream in2(config.c_str());\r
76             DOMDocument* doc2=XMLToolingConfig::getConfig().getParser().parse(in2);\r
77             XercesJanitor<DOMDocument> janitor2(doc2);\r
78             m_creds = XMLToolingConfig::getConfig().CredentialResolverManager.newPlugin(\r
79                 FILESYSTEM_CREDENTIAL_RESOLVER,doc2->getDocumentElement()\r
80                 );\r
81                 \r
82             m_trust = XMLToolingConfig::getConfig().TrustEngineManager.newPlugin(EXPLICIT_KEY_TRUSTENGINE, NULL);\r
83 \r
84             m_rules1.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(SAML1MESSAGE_POLICY_RULE,NULL));\r
85             m_rules1.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(MESSAGEFLOW_POLICY_RULE,NULL));\r
86             m_rules1.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(SIMPLESIGNING_POLICY_RULE,NULL));\r
87             m_rules1.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(XMLSIGNING_POLICY_RULE,NULL));\r
88 \r
89             m_rules2.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(SAML2MESSAGE_POLICY_RULE,NULL));\r
90             m_rules2.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(MESSAGEFLOW_POLICY_RULE,NULL));\r
91             m_rules2.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(SIMPLESIGNING_POLICY_RULE,NULL));\r
92             m_rules2.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(XMLSIGNING_POLICY_RULE,NULL));\r
93         }\r
94         catch (XMLToolingException& ex) {\r
95             TS_TRACE(ex.what());\r
96             tearDown();\r
97             throw;\r
98         }\r
99 \r
100     }\r
101     \r
102     void tearDown() {\r
103         for_each(m_rules1.begin(), m_rules1.end(), xmltooling::cleanup<SecurityPolicyRule>());\r
104         m_rules1.clear();\r
105         for_each(m_rules2.begin(), m_rules2.end(), xmltooling::cleanup<SecurityPolicyRule>());\r
106         m_rules2.clear();\r
107         delete m_creds;\r
108         delete m_metadata;\r
109         delete m_trust;\r
110         m_creds=NULL;\r
111         m_metadata=NULL;\r
112         m_trust=NULL;\r
113         m_fields.clear();\r
114         m_headers.clear();\r
115         m_method.erase();\r
116         m_url.erase();\r
117         m_query.erase();\r
118     }\r
119 \r
120     // HTTPRequest methods\r
121 \r
122     const char* getMethod() const {\r
123         return m_method.c_str();\r
124     }\r
125 \r
126     const char* getScheme() const {\r
127         return "https";\r
128     }\r
129 \r
130     const char* getHostname() const {\r
131         return "localhost";\r
132     }\r
133 \r
134     int getPort() const {\r
135         return 443;\r
136     }\r
137 \r
138     string getContentType() const {\r
139         return "application/x-www-form-urlencoded";\r
140     }\r
141 \r
142     long getContentLength() const {\r
143         return -1;\r
144     }\r
145 \r
146     const char* getRequestURI() const {\r
147         return "/";\r
148     }\r
149 \r
150     const char* getRequestURL() const {\r
151         return m_url.c_str();\r
152     }\r
153     \r
154     const char* getRequestBody() const {\r
155         return NULL;\r
156     }\r
157     \r
158     const char* getQueryString() const {\r
159         return m_query.c_str();\r
160     }\r
161     \r
162     string getRemoteUser() const {\r
163         return "";\r
164     }\r
165 \r
166     string getRemoteAddr() const {\r
167         return "127.0.0.1";\r
168     }\r
169 \r
170     const std::vector<XSECCryptoX509*>& getClientCertificates() const {\r
171         return m_clientCerts;\r
172     }\r
173 \r
174     string getHeader(const char* name) const {\r
175         map<string,string>::const_iterator i=m_headers.find(name);\r
176         return i==m_headers.end() ? "" : i->second;\r
177     }\r
178     \r
179     const char* getParameter(const char* name) const {\r
180         map<string,string>::const_iterator i=m_fields.find(name);\r
181         return i==m_fields.end() ? NULL : i->second.c_str();\r
182     }\r
183 \r
184     vector<const char*>::size_type getParameters(const char* name, vector<const char*>& values) const {\r
185         values.clear();\r
186         map<string,string>::const_iterator i=m_fields.find(name);\r
187         if (i!=m_fields.end())\r
188             values.push_back(i->second.c_str());\r
189         return values.size();\r
190     }\r
191     \r
192     // HTTPResponse methods\r
193     \r
194     void setResponseHeader(const char* name, const char* value) {\r
195         m_headers[name] = value ? value : "";\r
196     }\r
197 \r
198     // The amount of error checking missing from this is incredible, but as long\r
199     // as the test data isn't unexpected or malformed, it should work.\r
200     \r
201     long sendRedirect(const char* url) {\r
202         m_method = "GET";\r
203         char* dup = strdup(url);\r
204         char* pch = strchr(dup,'?');\r
205         if (pch) {\r
206             *pch++=0;\r
207             m_query = pch;\r
208             char* name=pch;\r
209             while (name && *name) {\r
210                 pch=strchr(pch,'=');\r
211                 *pch++=0;\r
212                 char* value=pch;\r
213                 pch=strchr(pch,'&');\r
214                 if (pch)\r
215                     *pch++=0;\r
216                 XMLToolingConfig::getConfig().getURLEncoder()->decode(value);\r
217                 m_fields[name] = value;\r
218                 name = pch; \r
219             }\r
220         }\r
221         m_url = dup;\r
222         free(dup);\r
223         return m_fields.size();\r
224     }\r
225     \r
226     string html_decode(const string& s) const {\r
227         string decoded;\r
228         const char* ch=s.c_str();\r
229         while (*ch) {\r
230             if (*ch=='&') {\r
231                 if (!strncmp(ch,"&lt;",4)) {\r
232                     decoded+='<'; ch+=4;\r
233                 }\r
234                 else if (!strncmp(ch,"&gt;",4)) {\r
235                     decoded+='>'; ch+=4;\r
236                 }\r
237                 else if (!strncmp(ch,"&quot;",6)) {\r
238                     decoded+='"'; ch+=6;\r
239                 }\r
240                 else if (*++ch=='#') {\r
241                     decoded+=(char)atoi(++ch);\r
242                     ch=strchr(ch,';')+1;\r
243                 }\r
244             }\r
245             else {\r
246                 decoded+=*ch++;\r
247             }\r
248         }\r
249         return decoded;\r
250     }\r
251     \r
252     long sendResponse(std::istream& inputStream, long status) {\r
253         m_method="POST";\r
254         string page,line;\r
255         while (getline(inputStream,line))\r
256             page += line + '\n';\r
257             \r
258         const char* pch=strstr(page.c_str(),"action=\"");\r
259         pch+=strlen("action=\"");\r
260         m_url = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));\r
261 \r
262         while (pch=strstr(pch,"<input type=\"hidden\" name=\"")) {\r
263             pch+=strlen("<input type=\"hidden\" name=\"");\r
264             string name = page.substr(pch-page.c_str(),strchr(pch,'"')-pch);\r
265             pch=strstr(pch,"value=\"");\r
266             pch+=strlen("value=\"");\r
267             m_fields[name] = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));\r
268         }\r
269         return m_fields.size();\r
270     }\r
271 };\r