Merge commit '2.5.0' into moonshot-packaging-fixes
[shibboleth/sp.git] / shibsp / impl / XMLAccessControl.cpp
index fd6d1ed..dc96225 100644 (file)
@@ -32,6 +32,9 @@
 #include "attribute/Attribute.h"
 
 #include <algorithm>
+#include <boost/bind.hpp>
+#include <boost/algorithm/string.hpp>
+#include <boost/ptr_container/ptr_vector.hpp>
 #include <xmltooling/unicode.h>
 #include <xmltooling/util/ReloadableXMLFile.h>
 #include <xmltooling/util/Threads.h>
@@ -45,6 +48,7 @@
 
 using namespace shibsp;
 using namespace xmltooling;
+using namespace boost;
 using namespace std;
 
 namespace shibsp {
@@ -62,16 +66,14 @@ namespace shibsp {
 
     private:
         string m_alias;
-        vector <string> m_vals;
+        set <string> m_vals;
     };
 
     class RuleRegex : public AccessControl
     {
     public:
         RuleRegex(const DOMElement* e);
-        ~RuleRegex() {
-            delete m_re;
-        }
+        ~RuleRegex() {}
 
         Lockable* lock() {return this;}
         void unlock() {}
@@ -81,14 +83,14 @@ namespace shibsp {
     private:
         string m_alias;
         auto_arrayptr<char> m_exp;
-        RegularExpression* m_re;
+        scoped_ptr<RegularExpression> m_re;
     };
 
     class Operator : public AccessControl
     {
     public:
         Operator(const DOMElement* e);
-        ~Operator();
+        ~Operator() {}
 
         Lockable* lock() {return this;}
         void unlock() {}
@@ -97,7 +99,7 @@ namespace shibsp {
 
     private:
         enum operator_t { OP_NOT, OP_AND, OP_OR } m_op;
-        vector<AccessControl*> m_operands;
+        ptr_vector<AccessControl> m_operands;
     };
 
 #if defined (_MSC_VER)
@@ -109,13 +111,12 @@ namespace shibsp {
     {
     public:
         XMLAccessControl(const DOMElement* e)
-                : ReloadableXMLFile(e, Category::getInstance(SHIBSP_LOGCAT".AccessControl.XML")), m_rootAuthz(nullptr) {
+                : ReloadableXMLFile(e, Category::getInstance(SHIBSP_LOGCAT".AccessControl.XML")) {
             background_load(); // guarantees an exception or the policy is loaded
         }
 
         ~XMLAccessControl() {
             shutdown();
-            delete m_rootAuthz;
         }
 
         aclresult_t authorized(const SPRequest& request, const Session* session) const;
@@ -124,7 +125,7 @@ namespace shibsp {
         pair<bool,DOMElement*> background_load();
 
     private:
-        AccessControl* m_rootAuthz;
+        scoped_ptr<AccessControl> m_rootAuthz;
     };
 
 #if defined (_MSC_VER)
@@ -137,6 +138,7 @@ namespace shibsp {
     }
 
     static const XMLCh _AccessControl[] =   UNICODE_LITERAL_13(A,c,c,e,s,s,C,o,n,t,r,o,l);
+    static const XMLCh _Handler[] =         UNICODE_LITERAL_7(H,a,n,d,l,e,r);
     static const XMLCh ignoreCase[] =       UNICODE_LITERAL_10(i,g,n,o,r,e,C,a,s,e);
     static const XMLCh ignoreOption[] =     UNICODE_LITERAL_1(i);
     static const XMLCh _list[] =            UNICODE_LITERAL_4(l,i,s,t);
@@ -152,32 +154,23 @@ Rule::Rule(const DOMElement* e) : m_alias(XMLHelper::getAttrString(e, nullptr, r
 {
     if (m_alias.empty())
         throw ConfigurationException("Access control rule missing require attribute");
+    if (!e->hasChildNodes())
+        return; // empty rule
 
-    auto_arrayptr<char> vals(toUTF8(e->hasChildNodes() ? e->getFirstChild()->getNodeValue() : nullptr));
-    if (!vals.get())
-        return;
+    auto_arrayptr<char> vals(toUTF8(e->getTextContent()));
+    if (!vals.get() || !*vals.get())
+        throw ConfigurationException("Unable to convert Rule content into UTF-8.");
 
     bool listflag = XMLHelper::getAttrBool(e, true, _list);
     if (!listflag) {
-        if (*vals.get())
-            m_vals.push_back(vals.get());
+        m_vals.insert(vals.get());
         return;
     }
 
-#ifdef HAVE_STRTOK_R
-    char* pos=nullptr;
-    const char* token=strtok_r(const_cast<char*>(vals.get())," ",&pos);
-#else
-    const char* token=strtok(const_cast<char*>(vals.get())," ");
-#endif
-    while (token) {
-        m_vals.push_back(token);
-#ifdef HAVE_STRTOK_R
-        token=strtok_r(nullptr," ",&pos);
-#else
-        token=strtok(nullptr," ");
-#endif
-    }
+    string temp(vals.get());
+    split(m_vals, temp, boost::is_space(), algorithm::token_compress_on);
+    if (m_vals.empty())
+        throw ConfigurationException("Rule did not contain any usable values.");
 }
 
 AccessControl::aclresult_t Rule::authorized(const SPRequest& request, const Session* session) const
@@ -199,31 +192,25 @@ AccessControl::aclresult_t Rule::authorized(const SPRequest& request, const Sess
         return shib_acl_false;
     }
     if (m_alias == "user") {
-        for (vector<string>::const_iterator i=m_vals.begin(); i!=m_vals.end(); ++i) {
-            if (*i == request.getRemoteUser()) {
-                request.log(SPRequest::SPDebug, string("AccessControl plugin expecting REMOTE_USER (") + *i + "), authz granted");
-                return shib_acl_true;
-            }
+        if (m_vals.find(request.getRemoteUser()) != m_vals.end()) {
+            request.log(SPRequest::SPDebug, string("AccessControl plugin expecting REMOTE_USER (") + request.getRemoteUser() + "), authz granted");
+            return shib_acl_true;
         }
         return shib_acl_false;
     }
     else if (m_alias == "authnContextClassRef") {
         const char* ref = session->getAuthnContextClassRef();
-        for (vector<string>::const_iterator i=m_vals.begin(); ref && i!=m_vals.end(); ++i) {
-            if (!strcmp(i->c_str(),ref)) {
-                request.log(SPRequest::SPDebug, string("AccessControl plugin expecting authnContextClassRef (") + *i + "), authz granted");
-                return shib_acl_true;
-            }
+        if (ref && m_vals.find(ref) != m_vals.end()) {
+            request.log(SPRequest::SPDebug, string("AccessControl plugin expecting authnContextClassRef (") + ref + "), authz granted");
+            return shib_acl_true;
         }
         return shib_acl_false;
     }
     else if (m_alias == "authnContextDeclRef") {
         const char* ref = session->getAuthnContextDeclRef();
-        for (vector<string>::const_iterator i=m_vals.begin(); ref && i!=m_vals.end(); ++i) {
-            if (!strcmp(i->c_str(),ref)) {
-                request.log(SPRequest::SPDebug, string("AccessControl plugin expecting authnContextDeclRef (") + *i + "), authz granted");
-                return shib_acl_true;
-            }
+        if (ref && m_vals.find(ref) != m_vals.end()) {
+            request.log(SPRequest::SPDebug, string("AccessControl plugin expecting authnContextDeclRef (") + ref + "), authz granted");
+            return shib_acl_true;
         }
         return shib_acl_false;
     }
@@ -235,14 +222,18 @@ AccessControl::aclresult_t Rule::authorized(const SPRequest& request, const Sess
         request.log(SPRequest::SPWarn, string("rule requires attribute (") + m_alias + "), not found in session");
         return shib_acl_false;
     }
+    else if (m_vals.empty()) {
+        request.log(SPRequest::SPDebug, string("AccessControl plugin requires presence of attribute (") + m_alias + "), authz granted");
+        return shib_acl_true;
+    }
 
     for (; attrs.first != attrs.second; ++attrs.first) {
         bool caseSensitive = attrs.first->second->isCaseSensitive();
 
         // Now we have to intersect the attribute's values against the rule's list.
         const vector<string>& vals = attrs.first->second->getSerializedValues();
-        for (vector<string>::const_iterator i=m_vals.begin(); i!=m_vals.end(); ++i) {
-            for (vector<string>::const_iterator j=vals.begin(); j!=vals.end(); ++j) {
+        for (set<string>::const_iterator i = m_vals.begin(); i != m_vals.end(); ++i) {
+            for (vector<string>::const_iterator j = vals.begin(); j != vals.end(); ++j) {
                 if ((caseSensitive && *i == *j) || (!caseSensitive && !strcasecmp(i->c_str(),j->c_str()))) {
                     request.log(SPRequest::SPDebug, string("AccessControl plugin expecting (") + *j + "), authz granted");
                     return shib_acl_true;
@@ -263,7 +254,7 @@ RuleRegex::RuleRegex(const DOMElement* e)
 
     bool ignore = XMLHelper::getAttrBool(e, false, ignoreCase);
     try {
-        m_re = new RegularExpression(e->getFirstChild()->getNodeValue(), (ignore ? ignoreOption : &chNull));
+        m_re.reset(new RegularExpression(e->getFirstChild()->getNodeValue(), (ignore ? ignoreOption : &chNull)));
     }
     catch (XMLException& ex) {
         auto_ptr_char tmp(ex.getMessage());
@@ -321,7 +312,7 @@ AccessControl::aclresult_t RuleRegex::authorized(const SPRequest& request, const
         for (; attrs.first != attrs.second; ++attrs.first) {
             // Now we have to intersect the attribute's values against the regular expression.
             const vector<string>& vals = attrs.first->second->getSerializedValues();
-            for (vector<string>::const_iterator j=vals.begin(); j!=vals.end(); ++j) {
+            for (vector<string>::const_iterator j = vals.begin(); j != vals.end(); ++j) {
                 if (m_re->matches(j->c_str())) {
                     request.log(SPRequest::SPDebug, string("AccessControl plugin expecting (") + m_exp.get() + "), authz granted");
                     return shib_acl_true;
@@ -348,45 +339,34 @@ Operator::Operator(const DOMElement* e)
     else
         throw ConfigurationException("Unrecognized operator in access control rule");
 
-    try {
-        e=XMLHelper::getFirstChildElement(e);
+    e=XMLHelper::getFirstChildElement(e);
+    if (XMLString::equals(e->getLocalName(),_Rule))
+        m_operands.push_back(new Rule(e));
+    else if (XMLString::equals(e->getLocalName(),_RuleRegex))
+        m_operands.push_back(new RuleRegex(e));
+    else
+        m_operands.push_back(new Operator(e));
+
+    if (m_op==OP_NOT)
+        return;
+
+    e=XMLHelper::getNextSiblingElement(e);
+    while (e) {
         if (XMLString::equals(e->getLocalName(),_Rule))
             m_operands.push_back(new Rule(e));
         else if (XMLString::equals(e->getLocalName(),_RuleRegex))
             m_operands.push_back(new RuleRegex(e));
         else
             m_operands.push_back(new Operator(e));
-
-        if (m_op==OP_NOT)
-            return;
-
         e=XMLHelper::getNextSiblingElement(e);
-        while (e) {
-            if (XMLString::equals(e->getLocalName(),_Rule))
-                m_operands.push_back(new Rule(e));
-            else if (XMLString::equals(e->getLocalName(),_RuleRegex))
-                m_operands.push_back(new RuleRegex(e));
-            else
-                m_operands.push_back(new Operator(e));
-            e=XMLHelper::getNextSiblingElement(e);
-        }
-    }
-    catch (exception&) {
-        for_each(m_operands.begin(),m_operands.end(),xmltooling::cleanup<AccessControl>());
-        throw;
     }
 }
 
-Operator::~Operator()
-{
-    for_each(m_operands.begin(),m_operands.end(),xmltooling::cleanup<AccessControl>());
-}
-
 AccessControl::aclresult_t Operator::authorized(const SPRequest& request, const Session* session) const
 {
     switch (m_op) {
         case OP_NOT:
-            switch (m_operands.front()->authorized(request,session)) {
+            switch (m_operands.front().authorized(request,session)) {
                 case shib_acl_true:
                     return shib_acl_false;
                 case shib_acl_false:
@@ -397,20 +377,28 @@ AccessControl::aclresult_t Operator::authorized(const SPRequest& request, const
 
         case OP_AND:
         {
-            for (vector<AccessControl*>::const_iterator i=m_operands.begin(); i!=m_operands.end(); i++) {
-                if ((*i)->authorized(request,session) != shib_acl_true)
+            // Look for a rule that returns non-true.
+            for (ptr_vector<AccessControl>::const_iterator i = m_operands.begin(); i != m_operands.end(); ++i) {
+                if (i->authorized(request,session) != shib_acl_true)
                     return shib_acl_false;
             }
             return shib_acl_true;
+
+            ptr_vector<AccessControl>::const_iterator i = find_if(
+                m_operands.begin(), m_operands.end(),
+                boost::bind(&AccessControl::authorized, _1, boost::cref(request), session) != shib_acl_true
+                );
+            return (i != m_operands.end()) ? shib_acl_false : shib_acl_true;
         }
 
         case OP_OR:
         {
-            for (vector<AccessControl*>::const_iterator i=m_operands.begin(); i!=m_operands.end(); i++) {
-                if ((*i)->authorized(request,session) == shib_acl_true)
-                    return shib_acl_true;
-            }
-            return shib_acl_false;
+            // Look for a rule that returns true.
+            ptr_vector<AccessControl>::const_iterator i = find_if(
+                m_operands.begin(), m_operands.end(),
+                boost::bind(&AccessControl::authorized, _1, boost::cref(request), session) == shib_acl_true
+                );
+            return (i != m_operands.end()) ? shib_acl_true : shib_acl_false;
         }
     }
     request.log(SPRequest::SPWarn,"unknown operation in access control policy, denying access");
@@ -426,23 +414,30 @@ pair<bool,DOMElement*> XMLAccessControl::background_load()
     XercesJanitor<DOMDocument> docjanitor(raw.first ? raw.second->getOwnerDocument() : nullptr);
 
     // Check for AccessControl wrapper and drop a level.
-    if (XMLString::equals(raw.second->getLocalName(),_AccessControl))
+    if (XMLString::equals(raw.second->getLocalName(),_AccessControl)) {
+        raw.second = XMLHelper::getFirstChildElement(raw.second);
+        if (!raw.second)
+            throw ConfigurationException("No child element found in AccessControl parent element.");
+    }
+    else if (XMLString::equals(raw.second->getLocalName(),_Handler)) {
         raw.second = XMLHelper::getFirstChildElement(raw.second);
+        if (!raw.second)
+            throw ConfigurationException("No child element found in Handler parent element.");
+    }
 
-    AccessControl* authz;
+    scoped_ptr<AccessControl> authz;
     if (XMLString::equals(raw.second->getLocalName(),_Rule))
-        authz=new Rule(raw.second);
+        authz.reset(new Rule(raw.second));
     else if (XMLString::equals(raw.second->getLocalName(),_RuleRegex))
-        authz=new RuleRegex(raw.second);
+        authz.reset(new RuleRegex(raw.second));
     else
-        authz=new Operator(raw.second);
+        authz.reset(new Operator(raw.second));
 
     // Perform the swap inside a lock.
     if (m_lock)
         m_lock->wrlock();
     SharedLock locker(m_lock, false);
-    delete m_rootAuthz;
-    m_rootAuthz = authz;
+    m_rootAuthz.swap(authz);
 
     return make_pair(false,(DOMElement*)nullptr);
 }