fdd0ce77ff99d5a2505c209720d9cb82f78edbc5
[shibboleth/cpp-opensaml.git] / samltest / binding.h
1 /*\r
2  *  Copyright 2001-2006 Internet2\r
3  * \r
4  * Licensed under the Apache License, Version 2.0 (the "License");\r
5  * you may not use this file except in compliance with the License.\r
6  * You may obtain a copy of the License at\r
7  *\r
8  *     http://www.apache.org/licenses/LICENSE-2.0\r
9  *\r
10  * Unless required by applicable law or agreed to in writing, software\r
11  * distributed under the License is distributed on an "AS IS" BASIS,\r
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
13  * See the License for the specific language governing permissions and\r
14  * limitations under the License.\r
15  */\r
16 \r
17 #include "internal.h"\r
18 \r
19 #include <saml/SAMLConfig.h>\r
20 #include <saml/binding/HTTPRequest.h>\r
21 #include <saml/binding/HTTPResponse.h>\r
22 #include <saml/binding/MessageDecoder.h>\r
23 #include <saml/binding/MessageEncoder.h>\r
24 #include <saml/binding/URLEncoder.h>\r
25 #include <saml/saml2/metadata/MetadataProvider.h>\r
26 #include <saml/security/TrustEngine.h>\r
27 \r
28 using namespace saml2md;\r
29 using namespace xmlsignature;\r
30 \r
31 class SAMLBindingBaseTestCase : public HTTPRequest, public HTTPResponse\r
32 {\r
33 protected:\r
34     CredentialResolver* m_creds; \r
35     MetadataProvider* m_metadata;\r
36     opensaml::TrustEngine* m_trust;\r
37     map<string,string> m_fields;\r
38     map<string,string> m_headers;\r
39     string m_method,m_url;\r
40     vector<XSECCryptoX509*> m_clientCerts;\r
41     vector<const SecurityPolicyRule*> m_rules;\r
42 \r
43 public:\r
44     void setUp() {\r
45         m_creds=NULL;\r
46         m_metadata=NULL;\r
47         m_trust=NULL;\r
48         m_fields.clear();\r
49         m_headers.clear();\r
50         m_method.erase();\r
51         m_url.erase();\r
52 \r
53         try {\r
54             string config = data_path + "binding/ExampleMetadataProvider.xml";\r
55             ifstream in(config.c_str());\r
56             DOMDocument* doc=XMLToolingConfig::getConfig().getParser().parse(in);\r
57             XercesJanitor<DOMDocument> janitor(doc);\r
58     \r
59             auto_ptr_XMLCh path("path");\r
60             string s = data_path + "binding/example-metadata.xml";\r
61             auto_ptr_XMLCh file(s.c_str());\r
62             doc->getDocumentElement()->setAttributeNS(NULL,path.get(),file.get());\r
63     \r
64             m_metadata = SAMLConfig::getConfig().MetadataProviderManager.newPlugin(\r
65                 FILESYSTEM_METADATA_PROVIDER,doc->getDocumentElement()\r
66                 );\r
67             m_metadata->init();\r
68 \r
69             config = data_path + "FilesystemCredentialResolver.xml";\r
70             ifstream in2(config.c_str());\r
71             DOMDocument* doc2=XMLToolingConfig::getConfig().getParser().parse(in2);\r
72             XercesJanitor<DOMDocument> janitor2(doc2);\r
73             m_creds = XMLToolingConfig::getConfig().CredentialResolverManager.newPlugin(\r
74                 FILESYSTEM_CREDENTIAL_RESOLVER,doc2->getDocumentElement()\r
75                 );\r
76                 \r
77             m_trust = SAMLConfig::getConfig().TrustEngineManager.newPlugin(EXPLICIT_KEY_SAMLTRUSTENGINE, NULL);\r
78 \r
79             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(MESSAGEFLOW_POLICY_RULE,NULL));\r
80             m_rules.push_back(SAMLConfig::getConfig().SecurityPolicyRuleManager.newPlugin(MESSAGESIGNING_POLICY_RULE,NULL));\r
81         }\r
82         catch (XMLToolingException& ex) {\r
83             TS_TRACE(ex.what());\r
84             tearDown();\r
85             throw;\r
86         }\r
87 \r
88     }\r
89     \r
90     void tearDown() {\r
91         for_each(m_rules.begin(), m_rules.end(), xmltooling::cleanup<SecurityPolicyRule>());\r
92         delete m_creds;\r
93         delete m_metadata;\r
94         delete m_trust;\r
95         m_creds=NULL;\r
96         m_metadata=NULL;\r
97         m_trust=NULL;\r
98         m_fields.clear();\r
99         m_headers.clear();\r
100         m_method.erase();\r
101         m_url.erase();\r
102     }\r
103 \r
104     // HTTPRequest methods\r
105 \r
106     const char* getMethod() const {\r
107         return m_method.c_str();\r
108     }\r
109 \r
110     const char* getScheme() const {\r
111         return "https";\r
112     }\r
113 \r
114     bool isSecure() const {\r
115         return true;\r
116     }\r
117 \r
118     string getContentType() const {\r
119         return "application/x-www-form-urlencoded";\r
120     }\r
121 \r
122     long getContentLength() const {\r
123         return -1;\r
124     }\r
125 \r
126     const char* getRequestURL() const {\r
127         return m_url.c_str();\r
128     }\r
129     \r
130     const char* getRequestBody() const {\r
131         return NULL;\r
132     }\r
133     \r
134     const char* getQueryString() const {\r
135         return NULL;\r
136     }\r
137     \r
138     string getRemoteUser() const {\r
139         return "";\r
140     }\r
141 \r
142     string getRemoteAddr() const {\r
143         return "127.0.0.1";\r
144     }\r
145 \r
146     const std::vector<XSECCryptoX509*>& getClientCertificates() const {\r
147         return m_clientCerts;\r
148     }\r
149 \r
150     string getHeader(const char* name) const {\r
151         map<string,string>::const_iterator i=m_headers.find(name);\r
152         return i==m_headers.end() ? "" : i->second;\r
153     }\r
154     \r
155     const char* getParameter(const char* name) const {\r
156         map<string,string>::const_iterator i=m_fields.find(name);\r
157         return i==m_fields.end() ? NULL : i->second.c_str();\r
158     }\r
159 \r
160     vector<const char*>::size_type getParameters(const char* name, vector<const char*>& values) const {\r
161         values.clear();\r
162         map<string,string>::const_iterator i=m_fields.find(name);\r
163         if (i!=m_fields.end())\r
164             values.push_back(i->second.c_str());\r
165         return values.size();\r
166     }\r
167     \r
168     // HTTPResponse methods\r
169     \r
170     void setHeader(const char* name, const char* value) {\r
171         m_headers[name] = value ? value : "";\r
172     }\r
173 \r
174     void setContentType(const char* type) {\r
175         setHeader("Content-Type", type);\r
176     }\r
177     \r
178     void setCookie(const char* name, const char* value) {\r
179         m_headers["Set-Cookie"] = string(name) + "=" + (value ? value : "");\r
180     }\r
181     \r
182     // The amount of error checking missing from this is incredible, but as long\r
183     // as the test data isn't unexpected or malformed, it should work.\r
184     \r
185     long sendRedirect(const char* url) {\r
186         m_method = "GET";\r
187         char* dup = strdup(url);\r
188         char* pch = strchr(dup,'?');\r
189         if (pch) {\r
190             *pch++=0;\r
191             char* name=pch;\r
192             while (name && *name) {\r
193                 pch=strchr(pch,'=');\r
194                 *pch++=0;\r
195                 char* value=pch;\r
196                 pch=strchr(pch,'&');\r
197                 if (pch)\r
198                     *pch++=0;\r
199                 SAMLConfig::getConfig().getURLEncoder()->decode(value);\r
200                 m_fields[name] = value;\r
201                 name = pch; \r
202             }\r
203         }\r
204         m_url = dup;\r
205         free(dup);\r
206         return m_fields.size();\r
207     }\r
208     \r
209     string html_decode(const string& s) const {\r
210         string decoded;\r
211         const char* ch=s.c_str();\r
212         while (*ch) {\r
213             if (*ch=='&') {\r
214                 if (!strncmp(ch,"&lt;",4)) {\r
215                     decoded+='<'; ch+=4;\r
216                 }\r
217                 else if (!strncmp(ch,"&gt;",4)) {\r
218                     decoded+='>'; ch+=4;\r
219                 }\r
220                 else if (!strncmp(ch,"&quot;",6)) {\r
221                     decoded+='"'; ch+=6;\r
222                 }\r
223                 else if (*++ch=='#') {\r
224                     decoded+=(char)atoi(++ch);\r
225                     ch=strchr(ch,';')+1;\r
226                 }\r
227             }\r
228             else {\r
229                 decoded+=*ch++;\r
230             }\r
231         }\r
232         return decoded;\r
233     }\r
234     \r
235     long sendResponse(std::istream& inputStream, long status) {\r
236         m_method="POST";\r
237         string page,line;\r
238         while (getline(inputStream,line))\r
239             page += line + '\n';\r
240             \r
241         const char* pch=strstr(page.c_str(),"action=\"");\r
242         pch+=strlen("action=\"");\r
243         m_url = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));\r
244 \r
245         while (pch=strstr(pch,"<input type=\"hidden\" name=\"")) {\r
246             pch+=strlen("<input type=\"hidden\" name=\"");\r
247             string name = page.substr(pch-page.c_str(),strchr(pch,'"')-pch);\r
248             pch=strstr(pch,"value=\"");\r
249             pch+=strlen("value=\"");\r
250             m_fields[name] = html_decode(page.substr(pch-page.c_str(),strchr(pch,'"')-pch));\r
251         }\r
252         return m_fields.size();\r
253     }\r
254 };\r