069641e2dbd4a1e532a381da589285dd3d61ccd5
[shibboleth/cpp-sp.git] / shib-target / XMLRequestMapper.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /* XMLRequestMapper.cpp - an XML-based map of URLs to application names and settings
18
19    Scott Cantor
20    1/6/04
21
22    $History:$
23 */
24
25 #include "internal.h"
26
27 #include <xercesc/util/XMLUniDefs.hpp>
28 #include <xercesc/util/regx/RegularExpression.hpp>
29
30 using namespace shibtarget::logging;
31 using namespace shibtarget;
32 using namespace shibboleth;
33 using namespace saml;
34 using namespace std;
35
36 namespace shibtarget {
37
38     class Override : public XMLPropertySet, public DOMNodeFilter
39     {
40     public:
41         Override() : m_base(NULL), m_acl(NULL) {}
42         Override(const DOMElement* e, Category& log, const Override* base=NULL);
43         ~Override();
44
45         // IPropertySet
46         pair<bool,bool> getBool(const char* name, const char* ns=NULL) const;
47         pair<bool,const char*> getString(const char* name, const char* ns=NULL) const;
48         pair<bool,const XMLCh*> getXMLString(const char* name, const char* ns=NULL) const;
49         pair<bool,unsigned int> getUnsignedInt(const char* name, const char* ns=NULL) const;
50         pair<bool,int> getInt(const char* name, const char* ns=NULL) const;
51         const IPropertySet* getPropertySet(const char* name, const char* ns="urn:mace:shibboleth:target:config:1.0") const;
52         
53         // Provides filter to exclude special config elements.
54         short acceptNode(const DOMNode* node) const {
55             return FILTER_REJECT;
56         }
57
58         const Override* locate(const char* path) const;
59         IAccessControl* getAC() const { return (m_acl ? m_acl : (m_base ? m_base->getAC() : NULL)); }
60         
61     protected:
62         void loadACL(const DOMElement* e, Category& log);
63         
64         map<string,Override*> m_map;
65         vector< pair<RegularExpression*,Override*> > m_regexps;
66     
67     private:
68         const Override* m_base;
69         IAccessControl* m_acl;
70     };
71
72     class XMLRequestMapperImpl : public ReloadableXMLFileImpl, public Override
73     {
74     public:
75         XMLRequestMapperImpl(const char* pathname) : ReloadableXMLFileImpl(pathname) { init(); }
76         XMLRequestMapperImpl(const DOMElement* e) : ReloadableXMLFileImpl(e) { init(); }
77         void init();
78         ~XMLRequestMapperImpl() {}
79     
80         const Override* findOverride(const char* vhost, const char* path) const;
81         Category* log;
82
83     private:    
84         map<string,Override*> m_extras;
85     };
86
87     // An implementation of the URL->application mapping API using an XML file
88     class XMLRequestMapper : public IRequestMapper, public ReloadableXMLFile
89     {
90     public:
91         XMLRequestMapper(const DOMElement* e) : ReloadableXMLFile(e) {}
92         ~XMLRequestMapper() {}
93
94         virtual Settings getSettings(ShibTarget* st) const;
95
96     protected:
97         virtual ReloadableXMLFileImpl* newImplementation(const char* pathname, bool first=true) const;
98         virtual ReloadableXMLFileImpl* newImplementation(const DOMElement* e, bool first=true) const;
99     };
100
101     static const XMLCh HostRegex[] =    { chLatin_H, chLatin_o, chLatin_s, chLatin_t, chLatin_R, chLatin_e, chLatin_g, chLatin_e, chLatin_x, chNull };
102     static const XMLCh ignoreCase[] =   { chLatin_i, chLatin_g, chLatin_n, chLatin_o, chLatin_r, chLatin_e, chLatin_C, chLatin_a, chLatin_s, chLatin_e, chNull };
103     static const XMLCh ignoreOption[] = { chLatin_i, chNull };
104     static const XMLCh PathRegex[] =    { chLatin_P, chLatin_a, chLatin_t, chLatin_h, chLatin_R, chLatin_e, chLatin_g, chLatin_e, chLatin_x, chNull };
105     static const XMLCh regex[] =        { chLatin_r, chLatin_e, chLatin_g, chLatin_e, chLatin_x, chNull };
106 }
107
108 IPlugIn* XMLRequestMapFactory(const DOMElement* e)
109 {
110     auto_ptr<XMLRequestMapper> m(new XMLRequestMapper(e));
111     m->getImplementation();
112     return m.release();
113 }
114
115 void Override::loadACL(const DOMElement* e, Category& log)
116 {
117     IPlugIn* plugin=NULL;
118     const DOMElement* acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(htaccess));
119     if (acl) {
120         log.info("building Apache htaccess provider...");
121         plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::htAccessControlType,acl);
122     }
123     else {
124         acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(AccessControl));
125         if (acl) {
126             log.info("building XML-based AccessControl provider...");
127             plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::XMLAccessControlType,acl);
128         }
129         else {
130             acl=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(AccessControlProvider));
131             if (acl) {
132                 auto_ptr_char type(acl->getAttributeNS(NULL,SHIBT_L(type)));
133                 log.info("building AccessControl provider of type %s...",type.get());
134                 plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(type.get(),acl);
135             }
136         }
137     }
138     if (plugin) {
139         IAccessControl* acl=dynamic_cast<IAccessControl*>(plugin);
140         if (acl)
141             m_acl=acl;
142         else {
143             delete plugin;
144             log.fatal("plugin was not an AccessControl provider");
145             throw UnsupportedExtensionException("plugin was not an Access Control provider");
146         }
147     }
148 }
149
150 Override::Override(const DOMElement* e, Category& log, const Override* base) : m_base(base), m_acl(NULL)
151 {
152     try {
153         // Load the property set.
154         load(e,log,this);
155         
156         // Load any AccessControl provider.
157         loadACL(e,log);
158     
159         // Handle nested Paths.
160         DOMElement* path=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
161         while (path) {
162             const XMLCh* n=path->getAttributeNS(NULL,SHIBT_L(name));
163             
164             // Skip any leading slashes.
165             while (n && *n==chForwardSlash)
166                 n++;
167             
168             // Check for empty name.
169             if (!n || !*n) {
170                 log.warn("skipping Path element with empty name attribute");
171                 path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
172                 continue;
173             }
174
175             // Check for an embedded slash.
176             int slash=XMLString::indexOf(n,chForwardSlash);
177             if (slash>0) {
178                 // Copy the first path segment.
179                 XMLCh* namebuf=new XMLCh[slash + 1];
180                 for (int pos=0; pos < slash; pos++)
181                     namebuf[pos]=n[pos];
182                 namebuf[slash]=chNull;
183                 
184                 // Move past the slash in the original pathname.
185                 n=n+slash+1;
186                 
187                 // Skip any leading slashes again.
188                 while (*n==chForwardSlash)
189                     n++;
190                 
191                 if (*n) {
192                     // Create a placeholder Path element for the first path segment and replant under it.
193                     DOMElement* newpath=path->getOwnerDocument()->createElementNS(shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
194                     newpath->setAttributeNS(NULL,SHIBT_L(name),namebuf);
195                     path->setAttributeNS(NULL,SHIBT_L(name),n);
196                     path->getParentNode()->replaceChild(newpath,path);
197                     newpath->appendChild(path);
198                     
199                     // Repoint our locals at the new parent.
200                     path=newpath;
201                     n=path->getAttributeNS(NULL,SHIBT_L(name));
202                 }
203                 else {
204                     // All we had was a pathname with trailing slash(es), so just reset it without them.
205                     path->setAttributeNS(NULL,SHIBT_L(name),namebuf);
206                     n=path->getAttributeNS(NULL,SHIBT_L(name));
207                 }
208                 delete[] namebuf;
209             }
210             
211             Override* o=new Override(path,log,this);
212             pair<bool,const char*> name=o->getString("name");
213             char* dup=strdup(name.second);
214             for (char* pch=dup; *pch; pch++)
215                 *pch=tolower(*pch);
216             if (m_map.count(dup)) {
217                 log.warn("Skipping duplicate Path element (%s)",dup);
218                 free(dup);
219                 delete o;
220                 path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
221                 continue;
222             }
223             m_map[dup]=o;
224             free(dup);
225             
226             path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Path));
227         }
228
229         if (!XMLString::equals(e->getLocalName(), PathRegex)) {
230             // Handle nested PathRegexs.
231             path = saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,PathRegex);
232             for (int i=1; path; ++i, path=saml::XML::getNextSiblingElement(path,shibtarget::XML::SHIBTARGET_NS,PathRegex)) {
233                 const XMLCh* n=path->getAttributeNS(NULL,regex);
234                 if (!n || !*n) {
235                     log.warn("skipping PathRegex element (%d) with empty regex attribute",i);
236                     continue;
237                 }
238
239                 auto_ptr<Override> o(new Override(path,log,this));
240
241                 const XMLCh* flag=path->getAttributeNS(NULL,ignoreCase);
242                 try {
243                     auto_ptr<RegularExpression> re(
244                         new RegularExpression(n, (flag && (*flag==chLatin_f || *flag==chDigit_0)) ? &chNull : ignoreOption)
245                         );
246                     m_regexps.push_back(make_pair(re.release(), o.release()));
247                 }
248                 catch (XMLException& ex) {
249                     auto_ptr_char tmp(ex.getMessage());
250                     log.error("caught exception while parsing PathRegex regular expression (%d): %s", i, tmp.get());
251                     throw ConfigurationException("Invalid regular expression in PathRegex element.");
252                 }
253
254                 if (log.isDebugEnabled())
255                     log.debug("added <PathRegex> mapping (%s)", m_regexps.back().second->getString("regex").second);
256             }
257         }
258     }
259     catch (...) {
260         delete m_acl;
261
262         for (map<string,Override*>::iterator m=m_map.begin(); m!=m_map.end(); m++)
263             delete m->second;
264         for (vector< pair<RegularExpression*,Override*> >::iterator i = m_regexps.begin(); i != m_regexps.end(); ++i) {
265             delete i->first;
266             delete i->second;
267         }
268         throw;
269     }
270 }
271
272 Override::~Override()
273 {
274     delete m_acl;
275     for (map<string,Override*>::iterator m=m_map.begin(); m!=m_map.end(); m++)
276         delete m->second;
277     for (vector< pair<RegularExpression*,Override*> >::iterator i = m_regexps.begin(); i != m_regexps.end(); ++i) {
278         delete i->first;
279         delete i->second;
280     }
281 }
282
283 pair<bool,bool> Override::getBool(const char* name, const char* ns) const
284 {
285     pair<bool,bool> ret=XMLPropertySet::getBool(name,ns);
286     if (ret.first)
287         return ret;
288     return m_base ? m_base->getBool(name,ns) : ret;
289 }
290
291 pair<bool,const char*> Override::getString(const char* name, const char* ns) const
292 {
293     pair<bool,const char*> ret=XMLPropertySet::getString(name,ns);
294     if (ret.first)
295         return ret;
296     return m_base ? m_base->getString(name,ns) : ret;
297 }
298
299 pair<bool,const XMLCh*> Override::getXMLString(const char* name, const char* ns) const
300 {
301     pair<bool,const XMLCh*> ret=XMLPropertySet::getXMLString(name,ns);
302     if (ret.first)
303         return ret;
304     return m_base ? m_base->getXMLString(name,ns) : ret;
305 }
306
307 pair<bool,unsigned int> Override::getUnsignedInt(const char* name, const char* ns) const
308 {
309     pair<bool,unsigned int> ret=XMLPropertySet::getUnsignedInt(name,ns);
310     if (ret.first)
311         return ret;
312     return m_base ? m_base->getUnsignedInt(name,ns) : ret;
313 }
314
315 pair<bool,int> Override::getInt(const char* name, const char* ns) const
316 {
317     pair<bool,int> ret=XMLPropertySet::getInt(name,ns);
318     if (ret.first)
319         return ret;
320     return m_base ? m_base->getInt(name,ns) : ret;
321 }
322
323 const IPropertySet* Override::getPropertySet(const char* name, const char* ns) const
324 {
325     const IPropertySet* ret=XMLPropertySet::getPropertySet(name,ns);
326     if (ret || !m_base)
327         return ret;
328     return m_base->getPropertySet(name,ns);
329 }
330
331 const Override* Override::locate(const char* path) const
332 {
333     // This function is confusing because it's *not* recursive.
334     // The whole path is tokenized and mapped in a loop, so the
335     // path parameter starts with the entire request path and
336     // we can skip the leading slash as irrelevant.
337     if (*path == '/')
338         path++;
339
340     // Now we copy the path, chop the query string, and lower case it.
341     char* dup=strdup(path);
342     char* sep=strchr(dup,'?');
343     if (sep)
344         *sep=0;
345     for (char* pch=dup; *pch; pch++)
346         *pch=tolower(*pch);
347         
348     // Default is for the current object to provide settings.
349     const Override* o=this;
350     
351     // Tokenize the path by segment and try and map each segment.
352 #ifdef HAVE_STRTOK_R
353     char* pos=NULL;
354     const char* token=strtok_r(dup,"/",&pos);
355 #else
356     const char* token=strtok(dup,"/");
357 #endif
358     while (token)
359     {
360         map<string,Override*>::const_iterator i=o->m_map.find(token);
361         if (i==o->m_map.end())
362             break;  // Once there's no match, we've consumed as much of the path as possible here.
363         // We found a match, so reset the settings pointer.
364         o=i->second;
365
366         // We descended a step down the path, so we need to advance the original
367         // parameter for the regex step later.
368         path += strlen(token);
369         if (*path == '/')
370             path++;
371         
372 #ifdef HAVE_STRTOK_R
373         token=strtok_r(NULL,"/",&pos);
374 #else
375         token=strtok(NULL,"/");
376 #endif
377     }
378
379     free(dup);
380
381     // If there's anything left, we try for a regex match on the rest of the path minus the query string.
382     if (*path) {
383         string path2(path);
384         path2 = path2.substr(0,path2.find('?'));
385
386         for (vector< pair<RegularExpression*,Override*> >::const_iterator re = o->m_regexps.begin(); re != o->m_regexps.end(); ++re) {
387             if (re->first->matches(path2.c_str())) {
388                 o = re->second;
389                 break;
390             }
391         }
392     }
393     
394     return o;
395 }
396
397 void XMLRequestMapperImpl::init()
398 {
399 #ifdef _DEBUG
400     NDC ndc("init");
401 #endif
402     log=&Category::getInstance("shibtarget.RequestMapper");
403
404     try {
405         if (!saml::XML::isElementNamed(ReloadableXMLFileImpl::m_root,shibtarget::XML::SHIBTARGET_NS,SHIBT_L(RequestMap))) {
406             log->error("Construction requires a valid request mapping file: (conf:RequestMap as root element)");
407             throw MalformedException("Construction requires a valid request mapping file: (conf:RequestMap as root element)");
408         }
409
410         // Load the property set.
411         load(ReloadableXMLFileImpl::m_root,*log,this);
412         
413         // Load any AccessControl provider.
414         loadACL(ReloadableXMLFileImpl::m_root,*log);
415     
416         // Loop over the Host elements.
417         DOMNodeList* nlist = ReloadableXMLFileImpl::m_root->getElementsByTagNameNS(shibtarget::XML::SHIBTARGET_NS,SHIBT_L(Host));
418         for (XMLSize_t i=0; nlist && i<nlist->getLength(); i++) {
419             DOMElement* host=static_cast<DOMElement*>(nlist->item(i));
420             const XMLCh* n=host->getAttributeNS(NULL,SHIBT_L(name));
421             if (!n || !*n) {
422                 log->warn("Skipping Host element (%d) with empty name attribute",i);
423                 continue;
424             }
425             
426             Override* o=new Override(host,*log,this);
427             pair<bool,const char*> name=o->getString("name");
428             pair<bool,const char*> scheme=o->getString("scheme");
429             pair<bool,const char*> port=o->getString("port");
430             
431             char* dup=strdup(name.second);
432             for (char* pch=dup; *pch; pch++)
433                 *pch=tolower(*pch);
434             auto_ptr<char> dupwrap(dup);
435
436             if (!scheme.first && port.first) {
437                 // No scheme, but a port, so assume http.
438                 scheme = pair<bool,const char*>(true,"http");
439             }
440             else if (scheme.first && !port.first) {
441                 // Scheme, no port, so default it.
442                 // XXX Use getservbyname instead?
443                 port.first = true;
444                 if (!strcmp(scheme.second,"http"))
445                     port.second = "80";
446                 else if (!strcmp(scheme.second,"https"))
447                     port.second = "443";
448                 else if (!strcmp(scheme.second,"ftp"))
449                     port.second = "21";
450                 else if (!strcmp(scheme.second,"ldap"))
451                     port.second = "389";
452                 else if (!strcmp(scheme.second,"ldaps"))
453                     port.second = "636";
454             }
455
456             if (scheme.first) {
457                 string url(scheme.second);
458                 url=url + "://" + dup;
459                 
460                 // Is this the default port?
461                 if ((!strcmp(scheme.second,"http") && !strcmp(port.second,"80")) ||
462                     (!strcmp(scheme.second,"https") && !strcmp(port.second,"443")) ||
463                     (!strcmp(scheme.second,"ftp") && !strcmp(port.second,"21")) ||
464                     (!strcmp(scheme.second,"ldap") && !strcmp(port.second,"389")) ||
465                     (!strcmp(scheme.second,"ldaps") && !strcmp(port.second,"636"))) {
466                     // First store a port-less version.
467                     if (m_map.count(url) || m_extras.count(url)) {
468                         log->warn("Skipping duplicate Host element (%s)",url.c_str());
469                         delete o;
470                         continue;
471                     }
472                     m_map[url]=o;
473                     log->debug("Added <Host> mapping for %s",url.c_str());
474                     
475                     // Now append the port. We use the extras vector, to avoid double freeing the object later.
476                     url=url + ':' + port.second;
477                     m_extras[url]=o;
478                     log->debug("Added <Host> mapping for %s",url.c_str());
479                 }
480                 else {
481                     url=url + ':' + port.second;
482                     if (m_map.count(url) || m_extras.count(url)) {
483                         log->warn("Skipping duplicate Host element (%s)",url.c_str());
484                         delete o;
485                         continue;
486                     }
487                     m_map[url]=o;
488                     log->debug("Added <Host> mapping for %s",url.c_str());
489                 }
490             }
491             else {
492                 // No scheme or port, so we enter dual hosts on http:80 and https:443
493                 string url("http://");
494                 url = url + dup;
495                 if (m_map.count(url) || m_extras.count(url)) {
496                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
497                     delete o;
498                     continue;
499                 }
500                 m_map[url]=o;
501                 log->debug("Added <Host> mapping for %s",url.c_str());
502                 
503                 url = url + ":80";
504                 if (m_map.count(url) || m_extras.count(url)) {
505                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
506                     continue;
507                 }
508                 m_extras[url]=o;
509                 log->debug("Added <Host> mapping for %s",url.c_str());
510                 
511                 url = "https://";
512                 url = url + dup;
513                 if (m_map.count(url) || m_extras.count(url)) {
514                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
515                     continue;
516                 }
517                 m_extras[url]=o;
518                 log->debug("Added <Host> mapping for %s",url.c_str());
519                 
520                 url = url + ":443";
521                 if (m_map.count(url) || m_extras.count(url)) {
522                     log->warn("Skipping duplicate Host element (%s)",url.c_str());
523                     continue;
524                 }
525                 m_extras[url]=o;
526                 log->debug("Added <Host> mapping for %s",url.c_str());
527             }
528         }
529     }
530     catch (SAMLException& e) {
531         log->errorStream() << "Error while parsing request mapping configuration: " << e.what() << CategoryStream::ENDLINE;
532         throw;
533     }
534 #ifndef _DEBUG
535     catch (...)
536     {
537         log->error("Unexpected error while parsing request mapping configuration");
538         throw;
539     }
540 #endif
541 }
542
543 const Override* XMLRequestMapperImpl::findOverride(const char* vhost, const char* path) const
544 {
545     const Override* o=NULL;
546     map<string,Override*>::const_iterator i=m_map.find(vhost);
547     if (i!=m_map.end())
548         o=i->second;
549     else {
550         i=m_extras.find(vhost);
551         if (i!=m_extras.end())
552             o=i->second;
553     }
554     
555     return o ? o->locate(path) : this;
556 }
557
558 ReloadableXMLFileImpl* XMLRequestMapper::newImplementation(const char* pathname, bool first) const
559 {
560     return new XMLRequestMapperImpl(pathname);
561 }
562
563 ReloadableXMLFileImpl* XMLRequestMapper::newImplementation(const DOMElement* e, bool first) const
564 {
565     return new XMLRequestMapperImpl(e);
566 }
567
568 IRequestMapper::Settings XMLRequestMapper::getSettings(ShibTarget* st) const
569 {
570     ostringstream vhost;
571     vhost << st->getProtocol() << "://" << st->getHostname() << ':' << st->getPort();
572
573     XMLRequestMapperImpl* impl=static_cast<XMLRequestMapperImpl*>(getImplementation());
574     const Override* o=impl->findOverride(vhost.str().c_str(), st->getRequestURI());
575
576     if (impl->log->isDebugEnabled()) {
577 #ifdef _DEBUG
578         saml::NDC ndc("getSettings");
579 #endif
580         pair<bool,const char*> ret=o->getString("applicationId");
581         impl->log->debug("mapped %s%s to %s", vhost.str().c_str(), st->getRequestURI() ? st->getRequestURI() : "", ret.second);
582     }
583
584     return Settings(o,o->getAC());
585 }