New base class for XMLRequestMap.
[shibboleth/sp.git] / shib-target / XMLRequestMapper.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /* XMLRequestMapper.cpp - an XML-based map of URLs to application names and settings
18
19    Scott Cantor
20    1/6/04
21
22    $History:$
23 */
24
25 #include "internal.h"
26
27 #include <algorithm>
28 #include <shibsp/DOMPropertySet.h>
29 #include <xmltooling/util/ReloadableXMLFile.h>
30 #include <xmltooling/util/XMLHelper.h>
31
32 using namespace shibsp;
33 using namespace shibtarget;
34 using namespace xmltooling;
35 using namespace log4cpp;
36 using namespace std;
37
38 namespace shibtarget {
39
40     // Blocks access when an ACL plugin fails to load. 
41     class AccessControlDummy : public IAccessControl
42     {
43     public:
44         Lockable* lock() {
45             return this;
46         }
47         
48         void unlock() {}
49     
50         bool authorized(ShibTarget* st, ISessionCacheEntry* entry) const {
51             return false;
52         }
53     };
54
55     class Override : public DOMPropertySet, public DOMNodeFilter
56     {
57     public:
58         Override() : m_base(NULL), m_acl(NULL) {}
59         Override(const DOMElement* e, Category& log, const Override* base=NULL);
60         ~Override();
61
62         // PropertySet
63         pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
64         pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
65         pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
66         pair<bool,unsigned int> getUnsignedInt(const char* name, const char* ns=NULL) const;
67         pair<bool,int> getInt(const char* name, const char* ns=NULL) const;
68         const PropertySet* getPropertySet(const char* name, const char* ns="urn:mace:shibboleth:target:config:1.0") const;
69         
70         // Provides filter to exclude special config elements.
71         short acceptNode(const DOMNode* node) const;
72
73         const Override* locate(const char* path) const;
74         IAccessControl* getAC() const { return (m_acl ? m_acl : (m_base ? m_base->getAC() : NULL)); }
75         
76     protected:
77         void loadACL(const DOMElement* e, Category& log);
78         
79         map<string,Override*> m_map;
80     
81     private:
82         const Override* m_base;
83         IAccessControl* m_acl;
84     };
85
86     class XMLRequestMapperImpl : public Override
87     {
88     public:
89         XMLRequestMapperImpl(const DOMElement* e, Category& log);
90
91         ~XMLRequestMapperImpl() {
92             if (m_document)
93                 m_document->release();
94         }
95
96         void setDocument(DOMDocument* doc) {
97             m_document = doc;
98         }
99     
100         const Override* findOverride(const char* vhost, const char* path) const;
101
102     private:    
103         map<string,Override*> m_extras;
104         DOMDocument* m_document;
105     };
106
107 #if defined (_MSC_VER)
108     #pragma warning( push )
109     #pragma warning( disable : 4250 )
110 #endif
111
112     class XMLRequestMapper : public IRequestMapper, public xmltooling::ReloadableXMLFile
113     {
114     public:
115         XMLRequestMapper(const DOMElement* e)
116                 : xmltooling::ReloadableXMLFile(e), m_impl(NULL), m_log(Category::getInstance(SHIBT_LOGCAT".RequestMapper")) {
117             load();
118         }
119
120         ~XMLRequestMapper() {}
121
122         virtual Settings getSettings(ShibTarget* st) const;
123
124     protected:
125         pair<bool,DOMElement*> load();
126
127     private:
128         XMLRequestMapperImpl* m_impl;
129         Category& m_log;
130     };
131
132 #if defined (_MSC_VER)
133     #pragma warning( pop )
134 #endif
135
136     static const XMLCh AccessControl[] =            UNICODE_LITERAL_13(A,c,c,e,s,s,C,o,n,t,r,o,l);
137     static const XMLCh AccessControlProvider[] =    UNICODE_LITERAL_21(A,c,c,e,s,s,C,o,n,t,r,o,l,P,r,o,v,i,d,e,r);
138     static const XMLCh htaccess[] =                 UNICODE_LITERAL_8(h,t,a,c,c,e,s,s);
139     static const XMLCh Host[] =                     UNICODE_LITERAL_4(H,o,s,t);
140     static const XMLCh Path[] =                     UNICODE_LITERAL_4(P,a,t,h);
141     static const XMLCh name[] =                     UNICODE_LITERAL_4(n,a,m,e);
142     static const XMLCh type[] =                     UNICODE_LITERAL_4(t,y,p,e);
143 }
144
145 saml::IPlugIn* XMLRequestMapFactory(const DOMElement* e)
146 {
147     return new XMLRequestMapper(e);
148 }
149
150 short Override::acceptNode(const DOMNode* node) const
151 {
152     if (!XMLString::equals(node->getNamespaceURI(),shibtarget::XML::SHIBTARGET_NS))
153         return FILTER_ACCEPT;
154     const XMLCh* name=node->getLocalName();
155     if (XMLString::equals(name,Host) ||
156         XMLString::equals(name,Path) ||
157         XMLString::equals(name,AccessControl) ||
158         XMLString::equals(name,htaccess) ||
159         XMLString::equals(name,AccessControlProvider))
160         return FILTER_REJECT;
161
162     return FILTER_ACCEPT;
163 }
164
165 void Override::loadACL(const DOMElement* e, Category& log)
166 {
167     try {
168         saml::IPlugIn* plugin=NULL;
169         const DOMElement* acl=XMLHelper::getFirstChildElement(e,htaccess);
170         if (acl) {
171             log.info("building Apache htaccess provider...");
172             plugin=saml::SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::htAccessControlType,acl);
173         }
174         else {
175             acl=XMLHelper::getFirstChildElement(e,AccessControl);
176             if (acl) {
177                 log.info("building XML-based Access Control provider...");
178                 plugin=saml::SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::XMLAccessControlType,acl);
179             }
180             else {
181                 acl=XMLHelper::getFirstChildElement(e,AccessControlProvider);
182                 if (acl) {
183                     xmltooling::auto_ptr_char type(acl->getAttributeNS(NULL,type));
184                     log.info("building Access Control provider of type %s...",type.get());
185                     plugin=saml::SAMLConfig::getConfig().getPlugMgr().newPlugin(type.get(),acl);
186                 }
187             }
188         }
189         if (plugin) {
190             IAccessControl* acl=dynamic_cast<IAccessControl*>(plugin);
191             if (acl)
192                 m_acl=acl;
193             else {
194                 delete plugin;
195                 throw UnknownExtensionException("plugin was not an Access Control provider");
196             }
197         }
198     }
199     catch (exception& ex) {
200         log.crit("exception building AccessControl provider: %s", ex.what());
201         m_acl = new AccessControlDummy();
202     }
203 }
204
205 Override::Override(const DOMElement* e, Category& log, const Override* base) : m_base(base), m_acl(NULL)
206 {
207     try {
208         // Load the property set.
209         load(e,log,this);
210         
211         // Load any AccessControl provider.
212         loadACL(e,log);
213     
214         // Handle nested Paths.
215         DOMElement* path = XMLHelper::getFirstChildElement(e,Path);
216         for (int i=1; path; ++i, path=XMLHelper::getNextSiblingElement(path,Path)) {
217             const XMLCh* n=path->getAttributeNS(NULL,name);
218             
219             // Skip any leading slashes.
220             while (n && *n==chForwardSlash)
221                 n++;
222             
223             // Check for empty name.
224             if (!n || !*n) {
225                 log.warn("skipping Path element (%d) with empty name attribute", i);
226                 continue;
227             }
228
229             // Check for an embedded slash.
230             int slash=XMLString::indexOf(n,chForwardSlash);
231             if (slash>0) {
232                 // Copy the first path segment.
233                 XMLCh* namebuf=new XMLCh[slash + 1];
234                 for (int pos=0; pos < slash; pos++)
235                     namebuf[pos]=n[pos];
236                 namebuf[slash]=chNull;
237                 
238                 // Move past the slash in the original pathname.
239                 n=n+slash+1;
240                 
241                 // Skip any leading slashes again.
242                 while (*n==chForwardSlash)
243                     n++;
244                 
245                 if (*n) {
246                     // Create a placeholder Path element for the first path segment and replant under it.
247                     DOMElement* newpath=path->getOwnerDocument()->createElementNS(shibtarget::XML::SHIBTARGET_NS,Path);
248                     newpath->setAttributeNS(NULL,name,namebuf);
249                     path->setAttributeNS(NULL,name,n);
250                     path->getParentNode()->replaceChild(newpath,path);
251                     newpath->appendChild(path);
252                     
253                     // Repoint our locals at the new parent.
254                     path=newpath;
255                     n=path->getAttributeNS(NULL,name);
256                 }
257                 else {
258                     // All we had was a pathname with trailing slash(es), so just reset it without them.
259                     path->setAttributeNS(NULL,name,namebuf);
260                     n=path->getAttributeNS(NULL,name);
261                 }
262                 delete[] namebuf;
263             }
264             
265             Override* o=new Override(path,log,this);
266             pair<bool,const char*> name=o->getString("name");
267             char* dup=strdup(name.second);
268             for (char* pch=dup; *pch; pch++)
269                 *pch=tolower(*pch);
270             if (m_map.count(dup)) {
271                 log.warn("Skipping duplicate Path element (%s)",dup);
272                 free(dup);
273                 delete o;
274                 continue;
275             }
276             m_map[dup]=o;
277             free(dup);
278         }
279     }
280     catch (exception&) {
281         delete m_acl;
282         for_each(m_map.begin(),m_map.end(),xmltooling::cleanup_pair<string,Override>());
283         throw;
284     }
285 }
286
287 Override::~Override()
288 {
289     delete m_acl;
290     for_each(m_map.begin(),m_map.end(),xmltooling::cleanup_pair<string,Override>());
291 }
292
293 pair<bool,bool> Override::getBool(const char* name, const char* ns) const
294 {
295     pair<bool,bool> ret=DOMPropertySet::getBool(name,ns);
296     if (ret.first)
297         return ret;
298     return m_base ? m_base->getBool(name,ns) : ret;
299 }
300
301 pair<bool,const char*> Override::getString(const char* name, const char* ns) const
302 {
303     pair<bool,const char*> ret=DOMPropertySet::getString(name,ns);
304     if (ret.first)
305         return ret;
306     return m_base ? m_base->getString(name,ns) : ret;
307 }
308
309 pair<bool,const XMLCh*> Override::getXMLString(const char* name, const char* ns) const
310 {
311     pair<bool,const XMLCh*> ret=DOMPropertySet::getXMLString(name,ns);
312     if (ret.first)
313         return ret;
314     return m_base ? m_base->getXMLString(name,ns) : ret;
315 }
316
317 pair<bool,unsigned int> Override::getUnsignedInt(const char* name, const char* ns) const
318 {
319     pair<bool,unsigned int> ret=DOMPropertySet::getUnsignedInt(name,ns);
320     if (ret.first)
321         return ret;
322     return m_base ? m_base->getUnsignedInt(name,ns) : ret;
323 }
324
325 pair<bool,int> Override::getInt(const char* name, const char* ns) const
326 {
327     pair<bool,int> ret=DOMPropertySet::getInt(name,ns);
328     if (ret.first)
329         return ret;
330     return m_base ? m_base->getInt(name,ns) : ret;
331 }
332
333 const PropertySet* Override::getPropertySet(const char* name, const char* ns) const
334 {
335     const PropertySet* ret=DOMPropertySet::getPropertySet(name,ns);
336     if (ret || !m_base)
337         return ret;
338     return m_base->getPropertySet(name,ns);
339 }
340
341 const Override* Override::locate(const char* path) const
342 {
343     char* dup=strdup(path);
344     char* sep=strchr(dup,'?');
345     if (sep)
346         *sep=0;
347     for (char* pch=dup; *pch; pch++)
348         *pch=tolower(*pch);
349         
350     const Override* o=this;
351     
352 #ifdef HAVE_STRTOK_R
353     char* pos=NULL;
354     const char* token=strtok_r(dup,"/",&pos);
355 #else
356     const char* token=strtok(dup,"/");
357 #endif
358     while (token)
359     {
360         map<string,Override*>::const_iterator i=o->m_map.find(token);
361         if (i==o->m_map.end())
362             break;
363         o=i->second;
364 #ifdef HAVE_STRTOK_R
365         token=strtok_r(NULL,"/",&pos);
366 #else
367         token=strtok(NULL,"/");
368 #endif
369     }
370
371     free(dup);
372     return o;
373 }
374
375 XMLRequestMapperImpl::XMLRequestMapperImpl(const DOMElement* e, Category& log) : m_document(NULL)
376 {
377 #ifdef _DEBUG
378     xmltooling::NDC ndc("XMLRequestMapperImpl");
379 #endif
380
381     // Load the property set.
382     load(e,log,this);
383     
384     // Load any AccessControl provider.
385     loadACL(e,log);
386
387     // Loop over the Host elements.
388     const DOMElement* host = XMLHelper::getFirstChildElement(e,Host);
389     for (int i=1; host; ++i, host=XMLHelper::getNextSiblingElement(host,Host)) {
390         const XMLCh* n=host->getAttributeNS(NULL,name);
391         if (!n || !*n) {
392             log.warn("Skipping Host element (%d) with empty name attribute",i);
393             continue;
394         }
395         
396         Override* o=new Override(host,log,this);
397         pair<bool,const char*> name=o->getString("name");
398         pair<bool,const char*> scheme=o->getString("scheme");
399         pair<bool,const char*> port=o->getString("port");
400         
401         char* dup=strdup(name.second);
402         for (char* pch=dup; *pch; pch++)
403             *pch=tolower(*pch);
404         auto_ptr<char> dupwrap(dup);
405
406         if (!scheme.first && port.first) {
407             // No scheme, but a port, so assume http.
408             scheme = pair<bool,const char*>(true,"http");
409         }
410         else if (scheme.first && !port.first) {
411             // Scheme, no port, so default it.
412             // XXX Use getservbyname instead?
413             port.first = true;
414             if (!strcmp(scheme.second,"http"))
415                 port.second = "80";
416             else if (!strcmp(scheme.second,"https"))
417                 port.second = "443";
418             else if (!strcmp(scheme.second,"ftp"))
419                 port.second = "21";
420             else if (!strcmp(scheme.second,"ldap"))
421                 port.second = "389";
422             else if (!strcmp(scheme.second,"ldaps"))
423                 port.second = "636";
424         }
425
426         if (scheme.first) {
427             string url(scheme.second);
428             url=url + "://" + dup;
429             
430             // Is this the default port?
431             if ((!strcmp(scheme.second,"http") && !strcmp(port.second,"80")) ||
432                 (!strcmp(scheme.second,"https") && !strcmp(port.second,"443")) ||
433                 (!strcmp(scheme.second,"ftp") && !strcmp(port.second,"21")) ||
434                 (!strcmp(scheme.second,"ldap") && !strcmp(port.second,"389")) ||
435                 (!strcmp(scheme.second,"ldaps") && !strcmp(port.second,"636"))) {
436                 // First store a port-less version.
437                 if (m_map.count(url) || m_extras.count(url)) {
438                     log.warn("Skipping duplicate Host element (%s)",url.c_str());
439                     delete o;
440                     continue;
441                 }
442                 m_map[url]=o;
443                 log.debug("Added <Host> mapping for %s",url.c_str());
444                 
445                 // Now append the port. We use the extras vector, to avoid double freeing the object later.
446                 url=url + ':' + port.second;
447                 m_extras[url]=o;
448                 log.debug("Added <Host> mapping for %s",url.c_str());
449             }
450             else {
451                 url=url + ':' + port.second;
452                 if (m_map.count(url) || m_extras.count(url)) {
453                     log.warn("Skipping duplicate Host element (%s)",url.c_str());
454                     delete o;
455                     continue;
456                 }
457                 m_map[url]=o;
458                 log.debug("Added <Host> mapping for %s",url.c_str());
459             }
460         }
461         else {
462             // No scheme or port, so we enter dual hosts on http:80 and https:443
463             string url("http://");
464             url = url + dup;
465             if (m_map.count(url) || m_extras.count(url)) {
466                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
467                 delete o;
468                 continue;
469             }
470             m_map[url]=o;
471             log.debug("Added <Host> mapping for %s",url.c_str());
472             
473             url = url + ":80";
474             if (m_map.count(url) || m_extras.count(url)) {
475                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
476                 continue;
477             }
478             m_extras[url]=o;
479             log.debug("Added <Host> mapping for %s",url.c_str());
480             
481             url = "https://";
482             url = url + dup;
483             if (m_map.count(url) || m_extras.count(url)) {
484                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
485                 continue;
486             }
487             m_extras[url]=o;
488             log.debug("Added <Host> mapping for %s",url.c_str());
489             
490             url = url + ":443";
491             if (m_map.count(url) || m_extras.count(url)) {
492                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
493                 continue;
494             }
495             m_extras[url]=o;
496             log.debug("Added <Host> mapping for %s",url.c_str());
497         }
498     }
499 }
500
501 const Override* XMLRequestMapperImpl::findOverride(const char* vhost, const char* path) const
502 {
503     const Override* o=NULL;
504     map<string,Override*>::const_iterator i=m_map.find(vhost);
505     if (i!=m_map.end())
506         o=i->second;
507     else {
508         i=m_extras.find(vhost);
509         if (i!=m_extras.end())
510             o=i->second;
511     }
512     
513     return o ? o->locate(path) : this;
514 }
515
516 pair<bool,DOMElement*> XMLRequestMapper::load()
517 {
518     // Load from source using base class.
519     pair<bool,DOMElement*> raw = ReloadableXMLFile::load();
520     
521     // If we own it, wrap it.
522     XercesJanitor<DOMDocument> docjanitor(raw.first ? raw.second->getOwnerDocument() : NULL);
523
524     XMLRequestMapperImpl* impl = new XMLRequestMapperImpl(raw.second,m_log);
525     
526     // If we held the document, transfer it to the impl. If we didn't, it's a no-op.
527     impl->setDocument(docjanitor.release());
528
529     delete m_impl;
530     m_impl = impl;
531
532     return make_pair(false,(DOMElement*)NULL);
533 }
534
535 IRequestMapper::Settings XMLRequestMapper::getSettings(ShibTarget* st) const
536 {
537     ostringstream vhost;
538     vhost << st->getProtocol() << "://" << st->getHostname() << ':' << st->getPort();
539
540     const Override* o=m_impl->findOverride(vhost.str().c_str(), st->getRequestURI());
541
542     if (m_log.isDebugEnabled()) {
543 #ifdef _DEBUG
544         xmltooling::NDC ndc("getSettings");
545 #endif
546         pair<bool,const char*> ret=o->getString("applicationId");
547         m_log.debug("mapped %s%s to %s", vhost.str().c_str(), st->getRequestURI() ? st->getRequestURI() : "", ret.second);
548     }
549
550     return Settings(o,o->getAC());
551 }