API consolidation around ShibTarget class
[shibboleth/sp.git] / shib-target / XMLRequestMapper.cpp
1 /* 
2  * The Shibboleth License, Version 1. 
3  * Copyright (c) 2002 
4  * University Corporation for Advanced Internet Development, Inc. 
5  * All rights reserved
6  * 
7  * 
8  * Redistribution and use in source and binary forms, with or without 
9  * modification, are permitted provided that the following conditions are met:
10  * 
11  * Redistributions of source code must retain the above copyright notice, this 
12  * list of conditions and the following disclaimer.
13  * 
14  * Redistributions in binary form must reproduce the above copyright notice, 
15  * this list of conditions and the following disclaimer in the documentation 
16  * and/or other materials provided with the distribution, if any, must include 
17  * the following acknowledgment: "This product includes software developed by 
18  * the University Corporation for Advanced Internet Development 
19  * <http://www.ucaid.edu>Internet2 Project. Alternately, this acknowledegement 
20  * may appear in the software itself, if and wherever such third-party 
21  * acknowledgments normally appear.
22  * 
23  * Neither the name of Shibboleth nor the names of its contributors, nor 
24  * Internet2, nor the University Corporation for Advanced Internet Development, 
25  * Inc., nor UCAID may be used to endorse or promote products derived from this 
26  * software without specific prior written permission. For written permission, 
27  * please contact shibboleth@shibboleth.org
28  * 
29  * Products derived from this software may not be called Shibboleth, Internet2, 
30  * UCAID, or the University Corporation for Advanced Internet Development, nor 
31  * may Shibboleth appear in their name, without prior written permission of the 
32  * University Corporation for Advanced Internet Development.
33  * 
34  * 
35  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
36  * AND WITH ALL FAULTS. ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
37  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 
38  * PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED AND THE ENTIRE RISK 
39  * OF SATISFACTORY QUALITY, PERFORMANCE, ACCURACY, AND EFFORT IS WITH LICENSEE. 
40  * IN NO EVENT SHALL THE COPYRIGHT OWNER, CONTRIBUTORS OR THE UNIVERSITY 
41  * CORPORATION FOR ADVANCED INTERNET DEVELOPMENT, INC. BE LIABLE FOR ANY DIRECT, 
42  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 
43  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 
44  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 
45  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
46  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 
47  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
48  */
49
50 /* XMLRequestMapper.cpp - an XML-based map of URLs to application names and settings
51
52    Scott Cantor
53    1/6/04
54
55    $History:$
56 */
57
58 #include "internal.h"
59
60 #include <log4cpp/Category.hh>
61
62 using namespace std;
63 using namespace log4cpp;
64 using namespace saml;
65 using namespace shibboleth;
66 using namespace shibtarget;
67
68 namespace shibtarget {
69
70     class Override : public XMLPropertySet, public DOMNodeFilter
71     {
72     public:
73         Override() : m_base(NULL), m_acl(NULL) {}
74         Override(const DOMElement* e, Category& log, const Override* base=NULL);
75         ~Override();
76         IAccessControl* m_acl;
77
78         // IPropertySet
79         pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
80         pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
81         pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
82         pair<bool,unsigned int> getUnsignedInt(const char* name, const char* ns=NULL) const;
83         pair<bool,int> getInt(const char* name, const char* ns=NULL) const;
84         const IPropertySet* getPropertySet(const char* name, const char* ns="urn:mace:shibboleth:target:config:1.0") const;
85         
86         // Provides filter to exclude special config elements.
87         short acceptNode(const DOMNode* node) const;
88
89         const Override* locate(const char* path) const;
90         
91     protected:
92         void loadACL(const DOMElement* e, Category& log);
93         
94         map<string,Override*> m_map;
95     
96     private:
97         const Override* m_base;
98     };
99
100     class XMLRequestMapperImpl : public ReloadableXMLFileImpl, public Override
101     {
102     public:
103         XMLRequestMapperImpl(const char* pathname) : ReloadableXMLFileImpl(pathname) { init(); }
104         XMLRequestMapperImpl(const DOMElement* e) : ReloadableXMLFileImpl(e) { init(); }
105         void init();
106         ~XMLRequestMapperImpl() {}
107     
108         const Override* findOverride(const char* vhost, const char* path) const;
109         Category* log;
110
111     private:    
112         map<string,Override*> m_extras;
113     };
114
115     // An implementation of the URL->application mapping API using an XML file
116     class XMLRequestMapper : public IRequestMapper, public ReloadableXMLFile
117     {
118     public:
119         XMLRequestMapper(const DOMElement* e) : ReloadableXMLFile(e) {}
120         ~XMLRequestMapper() {}
121
122         virtual Settings getSettings(ShibTarget* st) const;
123
124     protected:
125         virtual ReloadableXMLFileImpl* newImplementation(const char* pathname, bool first=true) const;
126         virtual ReloadableXMLFileImpl* newImplementation(const DOMElement* e, bool first=true) const;
127     };
128 }
129
130 IPlugIn* XMLRequestMapFactory(const DOMElement* e)
131 {
132     auto_ptr<XMLRequestMapper> m(new XMLRequestMapper(e));
133     m->getImplementation();
134     return m.release();
135 }
136
137 short Override::acceptNode(const DOMNode* node) const
138 {
139     if (XMLString::compareString(node->getNamespaceURI(),shibtarget::XML::SHIBTARGET_NS))
140         return FILTER_ACCEPT;
141     const XMLCh* name=node->getLocalName();
142     if (XMLString::compareString(name,SHIBT_L(AccessControlProvider)) ||
143         XMLString::compareString(name,SHIBT_L(Host)) ||
144         XMLString::compareString(name,SHIBT_L(Path)))
145         return FILTER_REJECT;
146
147     return FILTER_ACCEPT;
148 }
149
150 void Override::loadACL(const DOMElement* e, Category& log)
151 {
152     IPlugIn* plugin=NULL;
153     const DOMElement* acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(htaccess));
154     if (acl) {
155         log.info("building Apache htaccess provider...");
156         plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::htAccessControlType,acl);
157     }
158     else {
159         acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(AccessControl));
160         if (acl) {
161             log.info("building XML-based Access Control provider...");
162             plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::XMLAccessControlType,acl);
163         }
164         else {
165             acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(AccessControlProvider));
166             if (acl) {
167                 auto_ptr_char type(acl->getAttributeNS(NULL,SHIBT_L(type)));
168                 log.info("building Access Control provider of type %s...",type.get());
169                 plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(type.get(),acl);
170             }
171         }
172     }
173     if (plugin) {
174         IAccessControl* acl=dynamic_cast<IAccessControl*>(plugin);
175         if (acl)
176             m_acl=acl;
177         else {
178             delete plugin;
179             log.fatal("plugin was not an Access Control provider");
180             throw UnsupportedExtensionException("plugin was not an Access Control provider");
181         }
182     }
183 }
184
185 Override::Override(const DOMElement* e, Category& log, const Override* base) : m_base(base), m_acl(NULL)
186 {
187     try {
188         // Load the property set.
189         load(e,log,this);
190         
191         // Load any AccessControl provider.
192         loadACL(e,log);
193     
194         // Handle nested Paths.
195         DOMNodeList* nlist=e->getElementsByTagNameNS(shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
196         for (int i=0; nlist && i<nlist->getLength(); i++) {
197             DOMElement* path=static_cast<DOMElement*>(nlist->item(i));
198             const XMLCh* n=path->getAttributeNS(NULL,SHIBT_L(name));
199             if (!n || !*n) {
200                 log.warn("skipping Path element (%d) with empty name attribute",i);
201                 continue;
202             }
203             else if (*n==chForwardSlash && !n[1]) {
204                 log.warn("skipping Path element (%d) with a lone slash in the name attribute",i);
205                 continue;
206             }
207             Override* o=new Override(path,log,this);
208             pair<bool,const char*> name=o->getString("name");
209             char* dup=strdup(name.second);
210             for (char* pch=dup; *pch; pch++)
211                 *pch=tolower(*pch);
212             if (m_map.count(dup)) {
213                 log.warn("Skipping duplicate Path element (%s)",dup);
214                 free(dup);
215                 delete o;
216                 continue;
217             }
218             m_map[dup]=o;
219             free(dup);
220         }
221     }
222     catch (...) {
223         this->~Override();
224         throw;
225     }
226 }
227
228 Override::~Override()
229 {
230     delete m_acl;
231     for (map<string,Override*>::iterator i=m_map.begin(); i!=m_map.end(); i++)
232         delete i->second;
233 }
234
235 pair<bool,bool> Override::getBool(const char* name, const char* ns) const
236 {
237     pair<bool,bool> ret=XMLPropertySet::getBool(name,ns);
238     if (ret.first)
239         return ret;
240     return m_base ? m_base->getBool(name,ns) : ret;
241 }
242
243 pair<bool,const char*> Override::getString(const char* name, const char* ns) const
244 {
245     pair<bool,const char*> ret=XMLPropertySet::getString(name,ns);
246     if (ret.first)
247         return ret;
248     return m_base ? m_base->getString(name,ns) : ret;
249 }
250
251 pair<bool,const XMLCh*> Override::getXMLString(const char* name, const char* ns) const
252 {
253     pair<bool,const XMLCh*> ret=XMLPropertySet::getXMLString(name,ns);
254     if (ret.first)
255         return ret;
256     return m_base ? m_base->getXMLString(name,ns) : ret;
257 }
258
259 pair<bool,unsigned int> Override::getUnsignedInt(const char* name, const char* ns) const
260 {
261     pair<bool,unsigned int> ret=XMLPropertySet::getUnsignedInt(name,ns);
262     if (ret.first)
263         return ret;
264     return m_base ? m_base->getUnsignedInt(name,ns) : ret;
265 }
266
267 pair<bool,int> Override::getInt(const char* name, const char* ns) const
268 {
269     pair<bool,int> ret=XMLPropertySet::getInt(name,ns);
270     if (ret.first)
271         return ret;
272     return m_base ? m_base->getInt(name,ns) : ret;
273 }
274
275 const IPropertySet* Override::getPropertySet(const char* name, const char* ns) const
276 {
277     const IPropertySet* ret=XMLPropertySet::getPropertySet(name,ns);
278     if (ret || !m_base)
279         return ret;
280     return m_base->getPropertySet(name,ns);
281 }
282
283 const Override* Override::locate(const char* path) const
284 {
285     char* dup=strdup(path);
286     char* sep=strchr(dup,'?');
287     if (sep)
288         *sep=0;
289     for (char* pch=dup; *pch; pch++)
290         *pch=tolower(*pch);
291         
292     const Override* o=this;
293     
294 #ifdef HAVE_STRTOK_R
295     char* pos=NULL;
296     const char* token=strtok_r(dup,"/",&pos);
297 #else
298     const char* token=strtok(dup,"/");
299 #endif
300     while (token)
301     {
302         map<string,Override*>::const_iterator i=o->m_map.find(token);
303         if (i==o->m_map.end())
304             break;
305         o=i->second;
306 #ifdef HAVE_STRTOK_R
307         token=strtok_r(NULL,"/",&pos);
308 #else
309         token=strtok(NULL,"/");
310 #endif
311     }
312
313     free(dup);
314     return o;
315 }
316
317 void XMLRequestMapperImpl::init()
318 {
319 #ifdef _DEBUG
320     NDC ndc("init");
321 #endif
322     log=&Category::getInstance("shibtarget.RequestMapper");
323
324     try {
325         if (!saml::XML::isElementNamed(ReloadableXMLFileImpl::m_root,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(RequestMap))) {
326             log->error("Construction requires a valid request mapping file: (conf:RequestMap as root element)");
327             throw MalformedException("Construction requires a valid request mapping file: (conf:RequestMap as root element)");
328         }
329
330         // Load the property set.
331         load(ReloadableXMLFileImpl::m_root,*log,this);
332         
333         // Load any AccessControl provider.
334         loadACL(ReloadableXMLFileImpl::m_root,*log);
335     
336         // Loop over the Host elements.
337         DOMNodeList* nlist = ReloadableXMLFileImpl::m_root->getElementsByTagNameNS(shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Host));
338         for (int i=0; nlist && i<nlist->getLength(); i++) {
339             DOMElement* host=static_cast<DOMElement*>(nlist->item(i));
340             const XMLCh* n=host->getAttributeNS(NULL,SHIBT_L(name));
341             if (!n || !*n) {
342                 log->warn("Skipping Host element (%d) with empty name attribute",i);
343                 continue;
344             }
345             
346             Override* o=new Override(host,*log,this);
347             pair<bool,const char*> name=o->getString("name");
348             pair<bool,const char*> scheme=o->getString("scheme");
349             pair<bool,const char*> port=o->getString("port");
350             
351             char* dup=strdup(name.second);
352             for (char* pch=dup; *pch; pch++)
353                 *pch=tolower(*pch);
354             auto_ptr<char> dupwrap(dup);
355
356             if (!scheme.first && port.first) {
357                 // No scheme, but a port, so assume http.
358                 scheme = pair<bool,const char*>(true,"http");
359             }
360             else if (scheme.first && !port.first) {
361                 // Scheme, no port, so default it.
362                 // XXX Use getservbyname instead?
363                 port.first = true;
364                 if (!strcmp(scheme.second,"http"))
365                     port.second = "80";
366                 else if (!strcmp(scheme.second,"https"))
367                     port.second = "443";
368                 else if (!strcmp(scheme.second,"ftp"))
369                     port.second = "21";
370                 else if (!strcmp(scheme.second,"ldap"))
371                     port.second = "389";
372                 else if (!strcmp(scheme.second,"ldaps"))
373                     port.second = "636";
374             }
375
376             if (scheme.first) {
377                 string url(scheme.second);
378                 url=url + "://" + dup;
379                 
380                 // Is this the default port?
381                 if ((!strcmp(scheme.second,"http") && !strcmp(port.second,"80")) ||
382                     (!strcmp(scheme.second,"https") && !strcmp(port.second,"443")) ||
383                     (!strcmp(scheme.second,"ftp") && !strcmp(port.second,"21")) ||
384                     (!strcmp(scheme.second,"ldap") && !strcmp(port.second,"389")) ||
385                     (!strcmp(scheme.second,"ldaps") && !strcmp(port.second,"636"))) {
386                     // First store a port-less version.
387                     if (m_map.count(url) || m_extras.count(url)) {
388                         log->warn("Skipping duplicate Host element (%s)",url.c_str());
389                         delete o;
390                         continue;
391                     }
392                     m_map[url]=o;
393                     log->debug("Added <Host> mapping for %s",url.c_str());
394                     
395                     // Now append the port. We use the extras vector, to avoid double freeing the object later.
396                     url=url + ':' + port.second;
397                     m_extras[url]=o;
398                     log->debug("Added <Host> mapping for %s",url.c_str());
399                 }
400                 else {
401                     url=url + ':' + port.second;
402                     if (m_map.count(url) || m_extras.count(url)) {
403                         log->warn("Skipping duplicate Host element (%s)",url.c_str());
404                         delete o;
405                         continue;
406                     }
407                     m_map[url]=o;
408                     log->debug("Added <Host> mapping for %s",url.c_str());
409                 }
410             }
411             else {
412                 // No scheme or port, so we enter dual hosts on http:80 and https:443
413                 string url("http://");
414                 url = url + dup;
415                 if (m_map.count(url) || m_extras.count(url)) {
416                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
417                     delete o;
418                     continue;
419                 }
420                 m_map[url]=o;
421                 log->debug("Added <Host> mapping for %s",url.c_str());
422                 
423                 url = url + ":80";
424                 if (m_map.count(url) || m_extras.count(url)) {
425                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
426                     continue;
427                 }
428                 m_extras[url]=o;
429                 log->debug("Added <Host> mapping for %s",url.c_str());
430                 
431                 url = "https://";
432                 url = url + dup;
433                 if (m_map.count(url) || m_extras.count(url)) {
434                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
435                     continue;
436                 }
437                 m_extras[url]=o;
438                 log->debug("Added <Host> mapping for %s",url.c_str());
439                 
440                 url = url + ":443";
441                 if (m_map.count(url) || m_extras.count(url)) {
442                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
443                     continue;
444                 }
445                 m_extras[url]=o;
446                 log->debug("Added <Host> mapping for %s",url.c_str());
447             }
448         }
449     }
450     catch (SAMLException& e) {
451         log->errorStream() << "Error while parsing request mapping configuration: " << e.what() << CategoryStream::ENDLINE;
452         throw;
453     }
454 #ifndef _DEBUG
455     catch (...)
456     {
457         log->error("Unexpected error while parsing request mapping configuration");
458         throw;
459     }
460 #endif
461 }
462
463 const Override* XMLRequestMapperImpl::findOverride(const char* vhost, const char* path) const
464 {
465     const Override* o=NULL;
466     map<string,Override*>::const_iterator i=m_map.find(vhost);
467     if (i!=m_map.end())
468         o=i->second;
469     else {
470         i=m_extras.find(vhost);
471         if (i!=m_extras.end())
472             o=i->second;
473     }
474     
475     return o ? o->locate(path) : this;
476 }
477
478 ReloadableXMLFileImpl* XMLRequestMapper::newImplementation(const char* pathname, bool first) const
479 {
480     return new XMLRequestMapperImpl(pathname);
481 }
482
483 ReloadableXMLFileImpl* XMLRequestMapper::newImplementation(const DOMElement* e, bool first) const
484 {
485     return new XMLRequestMapperImpl(e);
486 }
487
488 IRequestMapper::Settings XMLRequestMapper::getSettings(ShibTarget* st) const
489 {
490     ostringstream vhost;
491     vhost << st->getProtocol() << "://" << st->getHostname() << ':' << st->getPort();
492
493     XMLRequestMapperImpl* impl=static_cast<XMLRequestMapperImpl*>(getImplementation());
494     const Override* o=impl->findOverride(vhost.str().c_str(), st->getRequestURI());
495
496     if (impl->log->isDebugEnabled()) {
497 #ifdef _DEBUG
498         saml::NDC ndc("getSettings");
499 #endif
500         pair<bool,const char*> ret=o->getString("applicationId");
501         impl->log->debug("mapped %s%s to %s", vhost.str().c_str(), st->getRequestURI() ? st->getRequestURI() : "", ret.second);
502     }
503
504     return Settings(o,o->m_acl);
505 }