Use for_each alg for cleanup.
[shibboleth/sp.git] / shib-target / XMLRequestMapper.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /* XMLRequestMapper.cpp - an XML-based map of URLs to application names and settings
18
19    Scott Cantor
20    1/6/04
21
22    $History:$
23 */
24
25 #include "internal.h"
26
27 #include <algorithm>
28 #include <log4cpp/Category.hh>
29
30 using namespace std;
31 using namespace log4cpp;
32 using namespace saml;
33 using namespace shibboleth;
34 using namespace shibtarget;
35
36 namespace shibtarget {
37
38     class Override : public XMLPropertySet, public DOMNodeFilter
39     {
40     public:
41         Override() : m_base(NULL), m_acl(NULL) {}
42         Override(const DOMElement* e, Category& log, const Override* base=NULL);
43         ~Override();
44
45         // IPropertySet
46         pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
47         pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
48         pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
49         pair<bool,unsigned int> getUnsignedInt(const char* name, const char* ns=NULL) const;
50         pair<bool,int> getInt(const char* name, const char* ns=NULL) const;
51         const IPropertySet* getPropertySet(const char* name, const char* ns="urn:mace:shibboleth:target:config:1.0") const;
52         
53         // Provides filter to exclude special config elements.
54         short acceptNode(const DOMNode* node) const;
55
56         const Override* locate(const char* path) const;
57         IAccessControl* getAC() const { return (m_acl ? m_acl : (m_base ? m_base->getAC() : NULL)); }
58         
59     protected:
60         void loadACL(const DOMElement* e, Category& log);
61         
62         map<string,Override*> m_map;
63     
64     private:
65         const Override* m_base;
66         IAccessControl* m_acl;
67     };
68
69     class XMLRequestMapperImpl : public ReloadableXMLFileImpl, public Override
70     {
71     public:
72         XMLRequestMapperImpl(const char* pathname) : ReloadableXMLFileImpl(pathname) { init(); }
73         XMLRequestMapperImpl(const DOMElement* e) : ReloadableXMLFileImpl(e) { init(); }
74         void init();
75         ~XMLRequestMapperImpl() {}
76     
77         const Override* findOverride(const char* vhost, const char* path) const;
78         Category* log;
79
80     private:    
81         map<string,Override*> m_extras;
82     };
83
84     // An implementation of the URL->application mapping API using an XML file
85     class XMLRequestMapper : public IRequestMapper, public ReloadableXMLFile
86     {
87     public:
88         XMLRequestMapper(const DOMElement* e) : ReloadableXMLFile(e) {}
89         ~XMLRequestMapper() {}
90
91         virtual Settings getSettings(ShibTarget* st) const;
92
93     protected:
94         virtual ReloadableXMLFileImpl* newImplementation(const char* pathname, bool first=true) const;
95         virtual ReloadableXMLFileImpl* newImplementation(const DOMElement* e, bool first=true) const;
96     };
97 }
98
99 IPlugIn* XMLRequestMapFactory(const DOMElement* e)
100 {
101     auto_ptr<XMLRequestMapper> m(new XMLRequestMapper(e));
102     m->getImplementation();
103     return m.release();
104 }
105
106 short Override::acceptNode(const DOMNode* node) const
107 {
108     if (XMLString::compareString(node->getNamespaceURI(),shibtarget::XML::SHIBTARGET_NS))
109         return FILTER_ACCEPT;
110     const XMLCh* name=node->getLocalName();
111     if (!XMLString::compareString(name,SHIBT_L(Host)) ||
112         !XMLString::compareString(name,SHIBT_L(Path)) ||
113         !XMLString::compareString(name,SHIBT_L(AccessControl)) ||
114         !XMLString::compareString(name,SHIBT_L(htaccess)) ||
115         !XMLString::compareString(name,SHIBT_L(AccessControlProvider)))
116         return FILTER_REJECT;
117
118     return FILTER_ACCEPT;
119 }
120
121 void Override::loadACL(const DOMElement* e, Category& log)
122 {
123     IPlugIn* plugin=NULL;
124     const DOMElement* acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(htaccess));
125     if (acl) {
126         log.info("building Apache htaccess provider...");
127         plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::htAccessControlType,acl);
128     }
129     else {
130         acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(AccessControl));
131         if (acl) {
132             log.info("building XML-based Access Control provider...");
133             plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::XMLAccessControlType,acl);
134         }
135         else {
136             acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(AccessControlProvider));
137             if (acl) {
138                 auto_ptr_char type(acl->getAttributeNS(NULL,SHIBT_L(type)));
139                 log.info("building Access Control provider of type %s...",type.get());
140                 plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(type.get(),acl);
141             }
142         }
143     }
144     if (plugin) {
145         IAccessControl* acl=dynamic_cast<IAccessControl*>(plugin);
146         if (acl)
147             m_acl=acl;
148         else {
149             delete plugin;
150             log.fatal("plugin was not an Access Control provider");
151             throw UnsupportedExtensionException("plugin was not an Access Control provider");
152         }
153     }
154 }
155
156 Override::Override(const DOMElement* e, Category& log, const Override* base) : m_base(base), m_acl(NULL)
157 {
158     try {
159         // Load the property set.
160         load(e,log,this);
161         
162         // Load any AccessControl provider.
163         loadACL(e,log);
164     
165         // Handle nested Paths.
166         unsigned int count=0;
167         DOMElement* path=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
168         while (path) {
169             const XMLCh* n=path->getAttributeNS(NULL,SHIBT_L(name));
170             
171             // Skip any leading slashes.
172             while (n && *n==chForwardSlash)
173                 n++;
174             
175             // Check for empty name.
176             if (!n || !*n) {
177                 log.warn("skipping Path element with empty name attribute");
178                 path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
179                 continue;
180             }
181
182             // Check for an embedded slash.
183             int slash=XMLString::indexOf(n,chForwardSlash);
184             if (slash>0) {
185                 // Copy the first path segment.
186                 XMLCh* namebuf=new XMLCh[slash + 1];
187                 for (int pos=0; pos < slash; pos++)
188                     namebuf[pos]=n[pos];
189                 namebuf[slash]=chNull;
190                 
191                 // Move past the slash in the original pathname.
192                 n=n+slash+1;
193                 
194                 // Skip any leading slashes again.
195                 while (*n==chForwardSlash)
196                     n++;
197                 
198                 if (*n) {
199                     // Create a placeholder Path element for the first path segment and replant under it.
200                     DOMElement* newpath=path->getOwnerDocument()->createElementNS(shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
201                     newpath->setAttributeNS(NULL,SHIBT_L(name),namebuf);
202                     path->setAttributeNS(NULL,SHIBT_L(name),n);
203                     path->getParentNode()->replaceChild(newpath,path);
204                     newpath->appendChild(path);
205                     
206                     // Repoint our locals at the new parent.
207                     path=newpath;
208                     n=path->getAttributeNS(NULL,SHIBT_L(name));
209                 }
210                 else {
211                     // All we had was a pathname with trailing slash(es), so just reset it without them.
212                     path->setAttributeNS(NULL,SHIBT_L(name),namebuf);
213                     n=path->getAttributeNS(NULL,SHIBT_L(name));
214                 }
215                 delete[] namebuf;
216             }
217             
218             Override* o=new Override(path,log,this);
219             pair<bool,const char*> name=o->getString("name");
220             char* dup=strdup(name.second);
221             for (char* pch=dup; *pch; pch++)
222                 *pch=tolower(*pch);
223             if (m_map.count(dup)) {
224                 log.warn("Skipping duplicate Path element (%s)",dup);
225                 free(dup);
226                 delete o;
227                 path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
228                 continue;
229             }
230             m_map[dup]=o;
231             free(dup);
232             
233             path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
234         }
235     }
236     catch (...) {
237         this->~Override();
238         throw;
239     }
240 }
241
242 Override::~Override()
243 {
244     delete m_acl;
245     for_each(m_map.begin(),m_map.end(),cleanup<string,Override>);
246 }
247
248 pair<bool,bool> Override::getBool(const char* name, const char* ns) const
249 {
250     pair<bool,bool> ret=XMLPropertySet::getBool(name,ns);
251     if (ret.first)
252         return ret;
253     return m_base ? m_base->getBool(name,ns) : ret;
254 }
255
256 pair<bool,const char*> Override::getString(const char* name, const char* ns) const
257 {
258     pair<bool,const char*> ret=XMLPropertySet::getString(name,ns);
259     if (ret.first)
260         return ret;
261     return m_base ? m_base->getString(name,ns) : ret;
262 }
263
264 pair<bool,const XMLCh*> Override::getXMLString(const char* name, const char* ns) const
265 {
266     pair<bool,const XMLCh*> ret=XMLPropertySet::getXMLString(name,ns);
267     if (ret.first)
268         return ret;
269     return m_base ? m_base->getXMLString(name,ns) : ret;
270 }
271
272 pair<bool,unsigned int> Override::getUnsignedInt(const char* name, const char* ns) const
273 {
274     pair<bool,unsigned int> ret=XMLPropertySet::getUnsignedInt(name,ns);
275     if (ret.first)
276         return ret;
277     return m_base ? m_base->getUnsignedInt(name,ns) : ret;
278 }
279
280 pair<bool,int> Override::getInt(const char* name, const char* ns) const
281 {
282     pair<bool,int> ret=XMLPropertySet::getInt(name,ns);
283     if (ret.first)
284         return ret;
285     return m_base ? m_base->getInt(name,ns) : ret;
286 }
287
288 const IPropertySet* Override::getPropertySet(const char* name, const char* ns) const
289 {
290     const IPropertySet* ret=XMLPropertySet::getPropertySet(name,ns);
291     if (ret || !m_base)
292         return ret;
293     return m_base->getPropertySet(name,ns);
294 }
295
296 const Override* Override::locate(const char* path) const
297 {
298     char* dup=strdup(path);
299     char* sep=strchr(dup,'?');
300     if (sep)
301         *sep=0;
302     for (char* pch=dup; *pch; pch++)
303         *pch=tolower(*pch);
304         
305     const Override* o=this;
306     
307 #ifdef HAVE_STRTOK_R
308     char* pos=NULL;
309     const char* token=strtok_r(dup,"/",&pos);
310 #else
311     const char* token=strtok(dup,"/");
312 #endif
313     while (token)
314     {
315         map<string,Override*>::const_iterator i=o->m_map.find(token);
316         if (i==o->m_map.end())
317             break;
318         o=i->second;
319 #ifdef HAVE_STRTOK_R
320         token=strtok_r(NULL,"/",&pos);
321 #else
322         token=strtok(NULL,"/");
323 #endif
324     }
325
326     free(dup);
327     return o;
328 }
329
330 void XMLRequestMapperImpl::init()
331 {
332 #ifdef _DEBUG
333     NDC ndc("init");
334 #endif
335     log=&Category::getInstance("shibtarget.RequestMapper");
336
337     try {
338         if (!saml::XML::isElementNamed(ReloadableXMLFileImpl::m_root,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(RequestMap))) {
339             log->error("Construction requires a valid request mapping file: (conf:RequestMap as root element)");
340             throw MalformedException("Construction requires a valid request mapping file: (conf:RequestMap as root element)");
341         }
342
343         // Load the property set.
344         load(ReloadableXMLFileImpl::m_root,*log,this);
345         
346         // Load any AccessControl provider.
347         loadACL(ReloadableXMLFileImpl::m_root,*log);
348     
349         // Loop over the Host elements.
350         DOMNodeList* nlist = ReloadableXMLFileImpl::m_root->getElementsByTagNameNS(shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Host));
351         for (unsigned int i=0; nlist && i<nlist->getLength(); i++) {
352             DOMElement* host=static_cast<DOMElement*>(nlist->item(i));
353             const XMLCh* n=host->getAttributeNS(NULL,SHIBT_L(name));
354             if (!n || !*n) {
355                 log->warn("Skipping Host element (%d) with empty name attribute",i);
356                 continue;
357             }
358             
359             Override* o=new Override(host,*log,this);
360             pair<bool,const char*> name=o->getString("name");
361             pair<bool,const char*> scheme=o->getString("scheme");
362             pair<bool,const char*> port=o->getString("port");
363             
364             char* dup=strdup(name.second);
365             for (char* pch=dup; *pch; pch++)
366                 *pch=tolower(*pch);
367             auto_ptr<char> dupwrap(dup);
368
369             if (!scheme.first && port.first) {
370                 // No scheme, but a port, so assume http.
371                 scheme = pair<bool,const char*>(true,"http");
372             }
373             else if (scheme.first && !port.first) {
374                 // Scheme, no port, so default it.
375                 // XXX Use getservbyname instead?
376                 port.first = true;
377                 if (!strcmp(scheme.second,"http"))
378                     port.second = "80";
379                 else if (!strcmp(scheme.second,"https"))
380                     port.second = "443";
381                 else if (!strcmp(scheme.second,"ftp"))
382                     port.second = "21";
383                 else if (!strcmp(scheme.second,"ldap"))
384                     port.second = "389";
385                 else if (!strcmp(scheme.second,"ldaps"))
386                     port.second = "636";
387             }
388
389             if (scheme.first) {
390                 string url(scheme.second);
391                 url=url + "://" + dup;
392                 
393                 // Is this the default port?
394                 if ((!strcmp(scheme.second,"http") && !strcmp(port.second,"80")) ||
395                     (!strcmp(scheme.second,"https") && !strcmp(port.second,"443")) ||
396                     (!strcmp(scheme.second,"ftp") && !strcmp(port.second,"21")) ||
397                     (!strcmp(scheme.second,"ldap") && !strcmp(port.second,"389")) ||
398                     (!strcmp(scheme.second,"ldaps") && !strcmp(port.second,"636"))) {
399                     // First store a port-less version.
400                     if (m_map.count(url) || m_extras.count(url)) {
401                         log->warn("Skipping duplicate Host element (%s)",url.c_str());
402                         delete o;
403                         continue;
404                     }
405                     m_map[url]=o;
406                     log->debug("Added <Host> mapping for %s",url.c_str());
407                     
408                     // Now append the port. We use the extras vector, to avoid double freeing the object later.
409                     url=url + ':' + port.second;
410                     m_extras[url]=o;
411                     log->debug("Added <Host> mapping for %s",url.c_str());
412                 }
413                 else {
414                     url=url + ':' + port.second;
415                     if (m_map.count(url) || m_extras.count(url)) {
416                         log->warn("Skipping duplicate Host element (%s)",url.c_str());
417                         delete o;
418                         continue;
419                     }
420                     m_map[url]=o;
421                     log->debug("Added <Host> mapping for %s",url.c_str());
422                 }
423             }
424             else {
425                 // No scheme or port, so we enter dual hosts on http:80 and https:443
426                 string url("http://");
427                 url = url + dup;
428                 if (m_map.count(url) || m_extras.count(url)) {
429                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
430                     delete o;
431                     continue;
432                 }
433                 m_map[url]=o;
434                 log->debug("Added <Host> mapping for %s",url.c_str());
435                 
436                 url = url + ":80";
437                 if (m_map.count(url) || m_extras.count(url)) {
438                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
439                     continue;
440                 }
441                 m_extras[url]=o;
442                 log->debug("Added <Host> mapping for %s",url.c_str());
443                 
444                 url = "https://";
445                 url = url + dup;
446                 if (m_map.count(url) || m_extras.count(url)) {
447                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
448                     continue;
449                 }
450                 m_extras[url]=o;
451                 log->debug("Added <Host> mapping for %s",url.c_str());
452                 
453                 url = url + ":443";
454                 if (m_map.count(url) || m_extras.count(url)) {
455                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
456                     continue;
457                 }
458                 m_extras[url]=o;
459                 log->debug("Added <Host> mapping for %s",url.c_str());
460             }
461         }
462     }
463     catch (SAMLException& e) {
464         log->errorStream() << "Error while parsing request mapping configuration: " << e.what() << CategoryStream::ENDLINE;
465         throw;
466     }
467 #ifndef _DEBUG
468     catch (...)
469     {
470         log->error("Unexpected error while parsing request mapping configuration");
471         throw;
472     }
473 #endif
474 }
475
476 const Override* XMLRequestMapperImpl::findOverride(const char* vhost, const char* path) const
477 {
478     const Override* o=NULL;
479     map<string,Override*>::const_iterator i=m_map.find(vhost);
480     if (i!=m_map.end())
481         o=i->second;
482     else {
483         i=m_extras.find(vhost);
484         if (i!=m_extras.end())
485             o=i->second;
486     }
487     
488     return o ? o->locate(path) : this;
489 }
490
491 ReloadableXMLFileImpl* XMLRequestMapper::newImplementation(const char* pathname, bool first) const
492 {
493     return new XMLRequestMapperImpl(pathname);
494 }
495
496 ReloadableXMLFileImpl* XMLRequestMapper::newImplementation(const DOMElement* e, bool first) const
497 {
498     return new XMLRequestMapperImpl(e);
499 }
500
501 IRequestMapper::Settings XMLRequestMapper::getSettings(ShibTarget* st) const
502 {
503     ostringstream vhost;
504     vhost << st->getProtocol() << "://" << st->getHostname() << ':' << st->getPort();
505
506     XMLRequestMapperImpl* impl=static_cast<XMLRequestMapperImpl*>(getImplementation());
507     const Override* o=impl->findOverride(vhost.str().c_str(), st->getRequestURI());
508
509     if (impl->log->isDebugEnabled()) {
510 #ifdef _DEBUG
511         saml::NDC ndc("getSettings");
512 #endif
513         pair<bool,const char*> ret=o->getString("applicationId");
514         impl->log->debug("mapped %s%s to %s", vhost.str().c_str(), st->getRequestURI() ? st->getRequestURI() : "", ret.second);
515     }
516
517     return Settings(o,o->getAC());
518 }