Old config class ported, all config files now loading with new parser.
[shibboleth/sp.git] / shib-target / XMLRequestMapper.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /* XMLRequestMapper.cpp - an XML-based map of URLs to application names and settings
18
19    Scott Cantor
20    1/6/04
21
22    $History:$
23 */
24
25 #include "internal.h"
26
27 #include <algorithm>
28 #include <shibsp/DOMPropertySet.h>
29 #include <xmltooling/util/ReloadableXMLFile.h>
30 #include <xmltooling/util/XMLHelper.h>
31
32 using namespace shibsp;
33 using namespace shibtarget;
34 using namespace xmltooling;
35 using namespace log4cpp;
36 using namespace std;
37
38 namespace shibtarget {
39
40     // Blocks access when an ACL plugin fails to load. 
41     class AccessControlDummy : public IAccessControl
42     {
43     public:
44         Lockable* lock() {
45             return this;
46         }
47         
48         void unlock() {}
49     
50         bool authorized(ShibTarget* st, ISessionCacheEntry* entry) const {
51             return false;
52         }
53     };
54
55     class Override : public DOMPropertySet, public DOMNodeFilter
56     {
57     public:
58         Override() : m_base(NULL), m_acl(NULL) {}
59         Override(const DOMElement* e, Category& log, const Override* base=NULL);
60         ~Override();
61
62         // PropertySet
63         pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
64         pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
65         pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
66         pair<bool,unsigned int> getUnsignedInt(const char* name, const char* ns=NULL) const;
67         pair<bool,int> getInt(const char* name, const char* ns=NULL) const;
68         const PropertySet* getPropertySet(const char* name, const char* ns="urn:mace:shibboleth:target:config:1.0") const;
69         
70         // Provides filter to exclude special config elements.
71         short acceptNode(const DOMNode* node) const;
72
73         const Override* locate(const char* path) const;
74         IAccessControl* getAC() const { return (m_acl ? m_acl : (m_base ? m_base->getAC() : NULL)); }
75         
76     protected:
77         void loadACL(const DOMElement* e, Category& log);
78         
79         map<string,Override*> m_map;
80     
81     private:
82         const Override* m_base;
83         IAccessControl* m_acl;
84     };
85
86     class XMLRequestMapperImpl : public Override
87     {
88     public:
89         XMLRequestMapperImpl(const DOMElement* e, Category& log);
90
91         ~XMLRequestMapperImpl() {
92             if (m_document)
93                 m_document->release();
94         }
95
96         void setDocument(DOMDocument* doc) {
97             m_document = doc;
98         }
99     
100         const Override* findOverride(const char* vhost, const char* path) const;
101
102     private:    
103         map<string,Override*> m_extras;
104         DOMDocument* m_document;
105     };
106
107 #if defined (_MSC_VER)
108     #pragma warning( push )
109     #pragma warning( disable : 4250 )
110 #endif
111
112     class XMLRequestMapper : public IRequestMapper, public ReloadableXMLFile
113     {
114     public:
115         XMLRequestMapper(const DOMElement* e)
116                 : ReloadableXMLFile(e), m_impl(NULL), m_log(Category::getInstance(SHIBT_LOGCAT".RequestMapper")) {
117             load();
118         }
119
120         ~XMLRequestMapper() {
121             delete m_impl;
122         }
123
124         virtual Settings getSettings(ShibTarget* st) const;
125
126     protected:
127         pair<bool,DOMElement*> load();
128
129     private:
130         XMLRequestMapperImpl* m_impl;
131         Category& m_log;
132     };
133
134 #if defined (_MSC_VER)
135     #pragma warning( pop )
136 #endif
137
138     static const XMLCh AccessControl[] =            UNICODE_LITERAL_13(A,c,c,e,s,s,C,o,n,t,r,o,l);
139     static const XMLCh AccessControlProvider[] =    UNICODE_LITERAL_21(A,c,c,e,s,s,C,o,n,t,r,o,l,P,r,o,v,i,d,e,r);
140     static const XMLCh htaccess[] =                 UNICODE_LITERAL_8(h,t,a,c,c,e,s,s);
141     static const XMLCh Host[] =                     UNICODE_LITERAL_4(H,o,s,t);
142     static const XMLCh Path[] =                     UNICODE_LITERAL_4(P,a,t,h);
143     static const XMLCh name[] =                     UNICODE_LITERAL_4(n,a,m,e);
144     static const XMLCh type[] =                     UNICODE_LITERAL_4(t,y,p,e);
145 }
146
147 saml::IPlugIn* XMLRequestMapFactory(const DOMElement* e)
148 {
149     return new XMLRequestMapper(e);
150 }
151
152 short Override::acceptNode(const DOMNode* node) const
153 {
154     if (!XMLString::equals(node->getNamespaceURI(),shibspconstants::SHIB1SPCONFIG_NS))
155         return FILTER_ACCEPT;
156     const XMLCh* name=node->getLocalName();
157     if (XMLString::equals(name,Host) ||
158         XMLString::equals(name,Path) ||
159         XMLString::equals(name,AccessControl) ||
160         XMLString::equals(name,htaccess) ||
161         XMLString::equals(name,AccessControlProvider))
162         return FILTER_REJECT;
163
164     return FILTER_ACCEPT;
165 }
166
167 void Override::loadACL(const DOMElement* e, Category& log)
168 {
169     try {
170         saml::IPlugIn* plugin=NULL;
171         const DOMElement* acl=XMLHelper::getFirstChildElement(e,htaccess);
172         if (acl) {
173             log.info("building Apache htaccess provider...");
174             plugin=saml::SAMLConfig::getConfig().getPlugMgr().newPlugin(HTACCESS_ACCESSCONTROL,acl);
175         }
176         else {
177             acl=XMLHelper::getFirstChildElement(e,AccessControl);
178             if (acl) {
179                 log.info("building XML-based Access Control provider...");
180                 plugin=saml::SAMLConfig::getConfig().getPlugMgr().newPlugin(XML_ACCESSCONTROL,acl);
181             }
182             else {
183                 acl=XMLHelper::getFirstChildElement(e,AccessControlProvider);
184                 if (acl) {
185                     xmltooling::auto_ptr_char type(acl->getAttributeNS(NULL,type));
186                     log.info("building Access Control provider of type %s...",type.get());
187                     plugin=saml::SAMLConfig::getConfig().getPlugMgr().newPlugin(type.get(),acl);
188                 }
189             }
190         }
191         if (plugin) {
192             IAccessControl* acl=dynamic_cast<IAccessControl*>(plugin);
193             if (acl)
194                 m_acl=acl;
195             else {
196                 delete plugin;
197                 throw UnknownExtensionException("plugin was not an Access Control provider");
198             }
199         }
200     }
201     catch (exception& ex) {
202         log.crit("exception building AccessControl provider: %s", ex.what());
203         m_acl = new AccessControlDummy();
204     }
205 }
206
207 Override::Override(const DOMElement* e, Category& log, const Override* base) : m_base(base), m_acl(NULL)
208 {
209     try {
210         // Load the property set.
211         load(e,log,this);
212         
213         // Load any AccessControl provider.
214         loadACL(e,log);
215     
216         // Handle nested Paths.
217         DOMElement* path = XMLHelper::getFirstChildElement(e,Path);
218         for (int i=1; path; ++i, path=XMLHelper::getNextSiblingElement(path,Path)) {
219             const XMLCh* n=path->getAttributeNS(NULL,name);
220             
221             // Skip any leading slashes.
222             while (n && *n==chForwardSlash)
223                 n++;
224             
225             // Check for empty name.
226             if (!n || !*n) {
227                 log.warn("skipping Path element (%d) with empty name attribute", i);
228                 continue;
229             }
230
231             // Check for an embedded slash.
232             int slash=XMLString::indexOf(n,chForwardSlash);
233             if (slash>0) {
234                 // Copy the first path segment.
235                 XMLCh* namebuf=new XMLCh[slash + 1];
236                 for (int pos=0; pos < slash; pos++)
237                     namebuf[pos]=n[pos];
238                 namebuf[slash]=chNull;
239                 
240                 // Move past the slash in the original pathname.
241                 n=n+slash+1;
242                 
243                 // Skip any leading slashes again.
244                 while (*n==chForwardSlash)
245                     n++;
246                 
247                 if (*n) {
248                     // Create a placeholder Path element for the first path segment and replant under it.
249                     DOMElement* newpath=path->getOwnerDocument()->createElementNS(shibspconstants::SHIB1SPCONFIG_NS,Path);
250                     newpath->setAttributeNS(NULL,name,namebuf);
251                     path->setAttributeNS(NULL,name,n);
252                     path->getParentNode()->replaceChild(newpath,path);
253                     newpath->appendChild(path);
254                     
255                     // Repoint our locals at the new parent.
256                     path=newpath;
257                     n=path->getAttributeNS(NULL,name);
258                 }
259                 else {
260                     // All we had was a pathname with trailing slash(es), so just reset it without them.
261                     path->setAttributeNS(NULL,name,namebuf);
262                     n=path->getAttributeNS(NULL,name);
263                 }
264                 delete[] namebuf;
265             }
266             
267             Override* o=new Override(path,log,this);
268             pair<bool,const char*> name=o->getString("name");
269             char* dup=strdup(name.second);
270             for (char* pch=dup; *pch; pch++)
271                 *pch=tolower(*pch);
272             if (m_map.count(dup)) {
273                 log.warn("Skipping duplicate Path element (%s)",dup);
274                 free(dup);
275                 delete o;
276                 continue;
277             }
278             m_map[dup]=o;
279             free(dup);
280         }
281     }
282     catch (exception&) {
283         delete m_acl;
284         for_each(m_map.begin(),m_map.end(),xmltooling::cleanup_pair<string,Override>());
285         throw;
286     }
287 }
288
289 Override::~Override()
290 {
291     delete m_acl;
292     for_each(m_map.begin(),m_map.end(),xmltooling::cleanup_pair<string,Override>());
293 }
294
295 pair<bool,bool> Override::getBool(const char* name, const char* ns) const
296 {
297     pair<bool,bool> ret=DOMPropertySet::getBool(name,ns);
298     if (ret.first)
299         return ret;
300     return m_base ? m_base->getBool(name,ns) : ret;
301 }
302
303 pair<bool,const char*> Override::getString(const char* name, const char* ns) const
304 {
305     pair<bool,const char*> ret=DOMPropertySet::getString(name,ns);
306     if (ret.first)
307         return ret;
308     return m_base ? m_base->getString(name,ns) : ret;
309 }
310
311 pair<bool,const XMLCh*> Override::getXMLString(const char* name, const char* ns) const
312 {
313     pair<bool,const XMLCh*> ret=DOMPropertySet::getXMLString(name,ns);
314     if (ret.first)
315         return ret;
316     return m_base ? m_base->getXMLString(name,ns) : ret;
317 }
318
319 pair<bool,unsigned int> Override::getUnsignedInt(const char* name, const char* ns) const
320 {
321     pair<bool,unsigned int> ret=DOMPropertySet::getUnsignedInt(name,ns);
322     if (ret.first)
323         return ret;
324     return m_base ? m_base->getUnsignedInt(name,ns) : ret;
325 }
326
327 pair<bool,int> Override::getInt(const char* name, const char* ns) const
328 {
329     pair<bool,int> ret=DOMPropertySet::getInt(name,ns);
330     if (ret.first)
331         return ret;
332     return m_base ? m_base->getInt(name,ns) : ret;
333 }
334
335 const PropertySet* Override::getPropertySet(const char* name, const char* ns) const
336 {
337     const PropertySet* ret=DOMPropertySet::getPropertySet(name,ns);
338     if (ret || !m_base)
339         return ret;
340     return m_base->getPropertySet(name,ns);
341 }
342
343 const Override* Override::locate(const char* path) const
344 {
345     char* dup=strdup(path);
346     char* sep=strchr(dup,'?');
347     if (sep)
348         *sep=0;
349     for (char* pch=dup; *pch; pch++)
350         *pch=tolower(*pch);
351         
352     const Override* o=this;
353     
354 #ifdef HAVE_STRTOK_R
355     char* pos=NULL;
356     const char* token=strtok_r(dup,"/",&pos);
357 #else
358     const char* token=strtok(dup,"/");
359 #endif
360     while (token)
361     {
362         map<string,Override*>::const_iterator i=o->m_map.find(token);
363         if (i==o->m_map.end())
364             break;
365         o=i->second;
366 #ifdef HAVE_STRTOK_R
367         token=strtok_r(NULL,"/",&pos);
368 #else
369         token=strtok(NULL,"/");
370 #endif
371     }
372
373     free(dup);
374     return o;
375 }
376
377 XMLRequestMapperImpl::XMLRequestMapperImpl(const DOMElement* e, Category& log) : m_document(NULL)
378 {
379 #ifdef _DEBUG
380     xmltooling::NDC ndc("XMLRequestMapperImpl");
381 #endif
382
383     // Load the property set.
384     load(e,log,this);
385     
386     // Load any AccessControl provider.
387     loadACL(e,log);
388
389     // Loop over the Host elements.
390     const DOMElement* host = XMLHelper::getFirstChildElement(e,Host);
391     for (int i=1; host; ++i, host=XMLHelper::getNextSiblingElement(host,Host)) {
392         const XMLCh* n=host->getAttributeNS(NULL,name);
393         if (!n || !*n) {
394             log.warn("Skipping Host element (%d) with empty name attribute",i);
395             continue;
396         }
397         
398         Override* o=new Override(host,log,this);
399         pair<bool,const char*> name=o->getString("name");
400         pair<bool,const char*> scheme=o->getString("scheme");
401         pair<bool,const char*> port=o->getString("port");
402         
403         char* dup=strdup(name.second);
404         for (char* pch=dup; *pch; pch++)
405             *pch=tolower(*pch);
406         auto_ptr<char> dupwrap(dup);
407
408         if (!scheme.first && port.first) {
409             // No scheme, but a port, so assume http.
410             scheme = pair<bool,const char*>(true,"http");
411         }
412         else if (scheme.first && !port.first) {
413             // Scheme, no port, so default it.
414             // XXX Use getservbyname instead?
415             port.first = true;
416             if (!strcmp(scheme.second,"http"))
417                 port.second = "80";
418             else if (!strcmp(scheme.second,"https"))
419                 port.second = "443";
420             else if (!strcmp(scheme.second,"ftp"))
421                 port.second = "21";
422             else if (!strcmp(scheme.second,"ldap"))
423                 port.second = "389";
424             else if (!strcmp(scheme.second,"ldaps"))
425                 port.second = "636";
426         }
427
428         if (scheme.first) {
429             string url(scheme.second);
430             url=url + "://" + dup;
431             
432             // Is this the default port?
433             if ((!strcmp(scheme.second,"http") && !strcmp(port.second,"80")) ||
434                 (!strcmp(scheme.second,"https") && !strcmp(port.second,"443")) ||
435                 (!strcmp(scheme.second,"ftp") && !strcmp(port.second,"21")) ||
436                 (!strcmp(scheme.second,"ldap") && !strcmp(port.second,"389")) ||
437                 (!strcmp(scheme.second,"ldaps") && !strcmp(port.second,"636"))) {
438                 // First store a port-less version.
439                 if (m_map.count(url) || m_extras.count(url)) {
440                     log.warn("Skipping duplicate Host element (%s)",url.c_str());
441                     delete o;
442                     continue;
443                 }
444                 m_map[url]=o;
445                 log.debug("Added <Host> mapping for %s",url.c_str());
446                 
447                 // Now append the port. We use the extras vector, to avoid double freeing the object later.
448                 url=url + ':' + port.second;
449                 m_extras[url]=o;
450                 log.debug("Added <Host> mapping for %s",url.c_str());
451             }
452             else {
453                 url=url + ':' + port.second;
454                 if (m_map.count(url) || m_extras.count(url)) {
455                     log.warn("Skipping duplicate Host element (%s)",url.c_str());
456                     delete o;
457                     continue;
458                 }
459                 m_map[url]=o;
460                 log.debug("Added <Host> mapping for %s",url.c_str());
461             }
462         }
463         else {
464             // No scheme or port, so we enter dual hosts on http:80 and https:443
465             string url("http://");
466             url = url + dup;
467             if (m_map.count(url) || m_extras.count(url)) {
468                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
469                 delete o;
470                 continue;
471             }
472             m_map[url]=o;
473             log.debug("Added <Host> mapping for %s",url.c_str());
474             
475             url = url + ":80";
476             if (m_map.count(url) || m_extras.count(url)) {
477                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
478                 continue;
479             }
480             m_extras[url]=o;
481             log.debug("Added <Host> mapping for %s",url.c_str());
482             
483             url = "https://";
484             url = url + dup;
485             if (m_map.count(url) || m_extras.count(url)) {
486                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
487                 continue;
488             }
489             m_extras[url]=o;
490             log.debug("Added <Host> mapping for %s",url.c_str());
491             
492             url = url + ":443";
493             if (m_map.count(url) || m_extras.count(url)) {
494                 log.warn("Skipping duplicate Host element (%s)",url.c_str());
495                 continue;
496             }
497             m_extras[url]=o;
498             log.debug("Added <Host> mapping for %s",url.c_str());
499         }
500     }
501 }
502
503 const Override* XMLRequestMapperImpl::findOverride(const char* vhost, const char* path) const
504 {
505     const Override* o=NULL;
506     map<string,Override*>::const_iterator i=m_map.find(vhost);
507     if (i!=m_map.end())
508         o=i->second;
509     else {
510         i=m_extras.find(vhost);
511         if (i!=m_extras.end())
512             o=i->second;
513     }
514     
515     return o ? o->locate(path) : this;
516 }
517
518 pair<bool,DOMElement*> XMLRequestMapper::load()
519 {
520     // Load from source using base class.
521     pair<bool,DOMElement*> raw = ReloadableXMLFile::load();
522     
523     // If we own it, wrap it.
524     XercesJanitor<DOMDocument> docjanitor(raw.first ? raw.second->getOwnerDocument() : NULL);
525
526     XMLRequestMapperImpl* impl = new XMLRequestMapperImpl(raw.second,m_log);
527     
528     // If we held the document, transfer it to the impl. If we didn't, it's a no-op.
529     impl->setDocument(docjanitor.release());
530
531     delete m_impl;
532     m_impl = impl;
533
534     return make_pair(false,(DOMElement*)NULL);
535 }
536
537 IRequestMapper::Settings XMLRequestMapper::getSettings(ShibTarget* st) const
538 {
539     ostringstream vhost;
540     vhost << st->getProtocol() << "://" << st->getHostname() << ':' << st->getPort();
541
542     const Override* o=m_impl->findOverride(vhost.str().c_str(), st->getRequestURI());
543
544     if (m_log.isDebugEnabled()) {
545 #ifdef _DEBUG
546         xmltooling::NDC ndc("getSettings");
547 #endif
548         pair<bool,const char*> ret=o->getString("applicationId");
549         m_log.debug("mapped %s%s to %s", vhost.str().c_str(), st->getRequestURI() ? st->getRequestURI() : "", ret.second);
550     }
551
552     return Settings(o,o->getAC());
553 }