https://issues.shibboleth.net/jira/browse/SSPCPP-97
authorScott Cantor <cantor.2@osu.edu>
Tue, 27 May 2008 00:58:37 +0000 (00:58 +0000)
committerScott Cantor <cantor.2@osu.edu>
Tue, 27 May 2008 00:58:37 +0000 (00:58 +0000)
shibsp/attribute/filtering/impl/XMLAttributeFilter.cpp

index 47857f2..55052e0 100644 (file)
@@ -33,6 +33,7 @@
 #include <xmltooling/util/XMLHelper.h>
 #include <xercesc/util/XMLUniDefs.hpp>
 
+using shibspconstants::SHIB2ATTRIBUTEFILTER_NS;
 using namespace shibsp;
 using namespace opensaml::saml2md;
 using namespace opensaml;
@@ -46,11 +47,13 @@ namespace shibsp {
     #pragma warning( disable : 4250 )
 #endif
 
+    // Each Policy has a functor for determining applicability and a map of
+    // attribute IDs to Accept/Deny functor pairs (which can include NULLs).
     struct SHIBSP_DLLLOCAL Policy
     {
         Policy() : m_applies(NULL) {}
         const MatchFunctor* m_applies;
-        typedef multimap<string,const MatchFunctor*> rules_t;
+        typedef multimap< string,pair<const MatchFunctor*,const MatchFunctor*> > rules_t;
         rules_t m_rules;
     };
 
@@ -63,6 +66,7 @@ namespace shibsp {
                 m_document->release();
             for_each(m_policyReqRules.begin(), m_policyReqRules.end(), cleanup_pair<string,MatchFunctor>());
             for_each(m_permitValRules.begin(), m_permitValRules.end(), cleanup_pair<string,MatchFunctor>());
+            for_each(m_denyValRules.begin(), m_denyValRules.end(), cleanup_pair<string,MatchFunctor>());
         }
 
         void setDocument(DOMDocument* doc) {
@@ -75,14 +79,17 @@ namespace shibsp {
         MatchFunctor* buildFunctor(
             const DOMElement* e, const FilterPolicyContext& functorMap, const char* logname, bool standalone
             );
-        pair<string,const MatchFunctor*> buildAttributeRule(const DOMElement* e, const FilterPolicyContext& functorMap, bool standalone);
+        pair< string,pair<const MatchFunctor*,const MatchFunctor*> > buildAttributeRule(
+            const DOMElement* e, const FilterPolicyContext& permMap, const FilterPolicyContext& denyMap, bool standalone
+            );
 
         Category& m_log;
         DOMDocument* m_document;
         vector<Policy> m_policies;
-        map< string,pair<string,const MatchFunctor*> > m_attrRules;
+        map< string,pair<string,pair<const MatchFunctor*,const MatchFunctor*> > > m_attrRules;
         multimap<string,MatchFunctor*> m_policyReqRules;
         multimap<string,MatchFunctor*> m_permitValRules;
+        multimap<string,MatchFunctor*> m_denyValRules;
     };
     
     class SHIBSP_DLLLOCAL XMLFilter : public AttributeFilter, public ReloadableXMLFile
@@ -119,6 +126,8 @@ namespace shibsp {
     static const XMLCh AttributeFilterPolicy[] =        UNICODE_LITERAL_21(A,t,t,r,i,b,u,t,e,F,i,l,t,e,r,P,o,l,i,c,y);
     static const XMLCh AttributeRule[] =                UNICODE_LITERAL_13(A,t,t,r,i,b,u,t,e,R,u,l,e);
     static const XMLCh AttributeRuleReference[] =       UNICODE_LITERAL_22(A,t,t,r,i,b,u,t,e,R,u,l,e,R,e,f,e,r,e,n,c,e);
+    static const XMLCh DenyValueRule[] =                UNICODE_LITERAL_13(D,e,n,y,V,a,l,u,e,R,u,l,e);
+    static const XMLCh DenyValueRuleReference[] =       UNICODE_LITERAL_22(D,e,n,y,V,a,l,u,e,R,u,l,e,R,e,f,e,r,e,n,c,e);
     static const XMLCh PermitValueRule[] =              UNICODE_LITERAL_15(P,e,r,m,i,t,V,a,l,u,e,R,u,l,e);
     static const XMLCh PermitValueRuleReference[] =     UNICODE_LITERAL_24(P,e,r,m,i,t,V,a,l,u,e,R,u,l,e,R,e,f,e,r,e,n,c,e);
     static const XMLCh PolicyRequirementRule[] =        UNICODE_LITERAL_21(P,o,l,i,c,y,R,e,q,u,i,r,e,m,e,n,t,R,u,l,e);
@@ -134,30 +143,34 @@ XMLFilterImpl::XMLFilterImpl(const DOMElement* e, Category& log) : m_log(log), m
     xmltooling::NDC ndc("XMLFilterImpl");
 #endif
     
-    if (!XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, AttributeFilterPolicyGroup))
+    if (!XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, AttributeFilterPolicyGroup))
         throw ConfigurationException("XML AttributeFilter requires afp:AttributeFilterPolicyGroup at root of configuration.");
 
     FilterPolicyContext reqFunctors(m_policyReqRules);
-    FilterPolicyContext valFunctors(m_permitValRules);
+    FilterPolicyContext permFunctors(m_permitValRules);
+    FilterPolicyContext denyFunctors(m_denyValRules);
 
     DOMElement* child = XMLHelper::getFirstChildElement(e);
     while (child) {
-        if (XMLHelper::isNodeNamed(child, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, PolicyRequirementRule)) {
+        if (XMLHelper::isNodeNamed(child, SHIB2ATTRIBUTEFILTER_NS, PolicyRequirementRule)) {
             buildFunctor(child, reqFunctors, "PolicyRequirementRule", true);
         }
-        else if (XMLHelper::isNodeNamed(child, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, PermitValueRule)) {
-            buildFunctor(child, valFunctors, "PermitValueRule", true);
+        else if (XMLHelper::isNodeNamed(child, SHIB2ATTRIBUTEFILTER_NS, PermitValueRule)) {
+            buildFunctor(child, permFunctors, "PermitValueRule", true);
+        }
+        else if (XMLHelper::isNodeNamed(child, SHIB2ATTRIBUTEFILTER_NS, DenyValueRule)) {
+            buildFunctor(child, denyFunctors, "DenyValueRule", true);
         }
-        else if (XMLHelper::isNodeNamed(child, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, AttributeRule)) {
-            buildAttributeRule(child, valFunctors, true);
+        else if (XMLHelper::isNodeNamed(child, SHIB2ATTRIBUTEFILTER_NS, AttributeRule)) {
+            buildAttributeRule(child, permFunctors, denyFunctors, true);
         }
-        else if (XMLHelper::isNodeNamed(child, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, AttributeFilterPolicy)) {
+        else if (XMLHelper::isNodeNamed(child, SHIB2ATTRIBUTEFILTER_NS, AttributeFilterPolicy)) {
             e = XMLHelper::getFirstChildElement(child);
             MatchFunctor* func = NULL;
-            if (e && XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, PolicyRequirementRule)) {
+            if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, PolicyRequirementRule)) {
                 func = buildFunctor(e, reqFunctors, "PolicyRequirementRule", false);
             }
-            else if (e && XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, PolicyRequirementRuleReference)) {
+            else if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, PolicyRequirementRuleReference)) {
                 auto_ptr_char ref(e->getAttributeNS(NULL, _ref));
                 if (ref.get() && *ref.get()) {
                     multimap<string,MatchFunctor*>::const_iterator prr = m_policyReqRules.find(ref.get());
@@ -169,15 +182,15 @@ XMLFilterImpl::XMLFilterImpl(const DOMElement* e, Category& log) : m_log(log), m
                 m_policies.back().m_applies = func;
                 e = XMLHelper::getNextSiblingElement(e);
                 while (e) {
-                    if (e && XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, AttributeRule)) {
-                        pair<string,const MatchFunctor*> rule = buildAttributeRule(e, valFunctors, false);
-                        if (rule.second)
+                    if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, AttributeRule)) {
+                        pair< string,pair<const MatchFunctor*,const MatchFunctor*> > rule = buildAttributeRule(e, permFunctors, denyFunctors, false);
+                        if (rule.second.first || rule.second.second)
                             m_policies.back().m_rules.insert(Policy::rules_t::value_type(rule.first, rule.second));
                     }
-                    else if (e && XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, AttributeRuleReference)) {
+                    else if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, AttributeRuleReference)) {
                         auto_ptr_char ref(e->getAttributeNS(NULL, _ref));
                         if (ref.get() && *ref.get()) {
-                            map< string,pair<string,const MatchFunctor*> >::const_iterator ar = m_attrRules.find(ref.get());
+                            map< string,pair< string,pair< const MatchFunctor*,const MatchFunctor*> > >::const_iterator ar = m_attrRules.find(ref.get());
                             if (ar != m_attrRules.end())
                                 m_policies.back().m_rules.insert(Policy::rules_t::value_type(ar->second.first, ar->second.second));
                             else
@@ -234,19 +247,21 @@ MatchFunctor* XMLFilterImpl::buildFunctor(
     return NULL;
 }
 
-pair<string,const MatchFunctor*> XMLFilterImpl::buildAttributeRule(const DOMElement* e, const FilterPolicyContext& functorMap, bool standalone)
+pair< string,pair<const MatchFunctor*,const MatchFunctor*> > XMLFilterImpl::buildAttributeRule(
+    const DOMElement* e, const FilterPolicyContext& permMap, const FilterPolicyContext& denyMap, bool standalone
+    )
 {
     auto_ptr_char temp(e->getAttributeNS(NULL,_id));
     const char* id = (temp.get() && *temp.get()) ? temp.get() : "";
 
     if (standalone && !*id) {
         m_log.warn("skipping stand-alone AttributeRule with no id");
-        return make_pair(string(),(const MatchFunctor*)NULL);
+        return make_pair(string(),pair<const MatchFunctor*,const MatchFunctor*>(NULL,NULL));
     }
     else if (*id && m_attrRules.count(id)) {
         if (standalone) {
             m_log.warn("skipping duplicate stand-alone AttributeRule with id (%s)", id);
-            return make_pair(string(),(const MatchFunctor*)NULL);
+            return make_pair(string(),pair<const MatchFunctor*,const MatchFunctor*>(NULL,NULL));
         }
         else
             id = "";
@@ -256,28 +271,43 @@ pair<string,const MatchFunctor*> XMLFilterImpl::buildAttributeRule(const DOMElem
     if (!attrID.get() || !*attrID.get())
         m_log.warn("skipping AttributeRule with no attributeID");
 
+    MatchFunctor* perm=NULL;
+    MatchFunctor* deny=NULL;
+
     e = XMLHelper::getFirstChildElement(e);
-    MatchFunctor* func=NULL;
-    if (e && XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, PermitValueRule)) {
-        func = buildFunctor(e, functorMap, "PermitValueRule", false);
+    if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, PermitValueRule)) {
+        perm = buildFunctor(e, permMap, "PermitValueRule", false);
+        e = XMLHelper::getNextSiblingElement(e);
     }
-    else if (e && XMLHelper::isNodeNamed(e, shibspconstants::SHIB2ATTRIBUTEFILTER_NS, PermitValueRuleReference)) {
+    else if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, PermitValueRuleReference)) {
         auto_ptr_char ref(e->getAttributeNS(NULL, _ref));
         if (ref.get() && *ref.get()) {
             multimap<string,MatchFunctor*>::const_iterator pvr = m_permitValRules.find(ref.get());
-            func = (pvr!=m_permitValRules.end()) ? pvr->second : NULL;
+            perm = (pvr!=m_permitValRules.end()) ? pvr->second : NULL;
+        }
+        e = XMLHelper::getNextSiblingElement(e);
+    }
+
+    if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, DenyValueRule)) {
+        deny = buildFunctor(e, denyMap, "DenyValueRule", false);
+    }
+    else if (e && XMLHelper::isNodeNamed(e, SHIB2ATTRIBUTEFILTER_NS, DenyValueRuleReference)) {
+        auto_ptr_char ref(e->getAttributeNS(NULL, _ref));
+        if (ref.get() && *ref.get()) {
+            multimap<string,MatchFunctor*>::const_iterator pvr = m_denyValRules.find(ref.get());
+            deny = (pvr!=m_denyValRules.end()) ? pvr->second : NULL;
         }
     }
 
-    if (func) {
+    if (perm || deny) {
         if (*id)
-            return m_attrRules[id] = pair<string,const MatchFunctor*>(attrID.get(), func);
+            return m_attrRules[id] = pair< string,pair<const MatchFunctor*,const MatchFunctor*> >(attrID.get(), make_pair(perm,deny));
         else
-            return pair<string,const MatchFunctor*>(attrID.get(), func);
+            return pair< string,pair<const MatchFunctor*,const MatchFunctor*> >(attrID.get(), make_pair(perm,deny));
     }
 
-    m_log.warn("skipping AttributeRule (%s), PermitValueRule invalid or missing", id);
-    return make_pair(string(),(const MatchFunctor*)NULL);
+    m_log.warn("skipping AttributeRule (%s), permit and denial rule(s) invalid or missing", id);
+    return make_pair(string(),pair<const MatchFunctor*,const MatchFunctor*>(NULL,NULL));
 }
 
 void XMLFilterImpl::filterAttributes(const FilteringContext& context, vector<Attribute*>& attributes) const
@@ -293,87 +323,126 @@ void XMLFilterImpl::filterAttributes(const FilteringContext& context, vector<Att
         return;
     }
 
-    size_t count,index;
+    // We have to evaluate every policy that applies against each attribute before deciding what to keep.
 
-    // Test each Policy.
+    // For efficiency, we build an array of the policies that apply in advance.
+    vector<const Policy*> applicablePolicies;
     for (vector<Policy>::const_iterator p=m_policies.begin(); p!=m_policies.end(); ++p) {
-        if (p->m_applies->evaluatePolicyRequirement(context)) {
-            // Loop over the attributes and look for possible rules to run.
-            for (vector<Attribute*>::size_type a=0; a<attributes.size();) {
-                bool ruleFound = false;
-                Attribute* attr = attributes[a];
-                pair<Policy::rules_t::const_iterator,Policy::rules_t::const_iterator> rules = p->m_rules.equal_range(attr->getId());
-                if (rules.first != rules.second) {
-                    ruleFound = true;
-                    // Run each rule in sequence.
-                    m_log.debug(
-                        "applying filtering rule(s) for attribute (%s) from (%s)",
-                        attr->getId(), issuer.get() ? issuer.get() : "unknown source"
-                        );
-                    for (; rules.first!=rules.second; ++rules.first) {
-                        count = attr->valueCount();
-                        for (index=0; index < count;) {
-                            // The return value tells us whether to index past the accepted value, or stay put and decrement the count.
-                            if (rules.first->second->evaluatePermitValue(context, *attr, index)) {
-                                index++;
-                            }
-                            else {
-                                m_log.warn(
-                                    "removed value at position (%lu) of attribute (%s) from (%s)",
-                                    index, attr->getId(), issuer.get() ? issuer.get() : "unknown source"
-                                    );
-                                attr->removeValue(index);
-                                count--;
-                            }
-                        }
-                    }
-                }
+        if (p->m_applies->evaluatePolicyRequirement(context))
+            applicablePolicies.push_back(&(*p));
+    }
 
-                rules = p->m_rules.equal_range("*");
-                if (rules.first != rules.second) {
-                    // Run each rule in sequence.
-                    if (!ruleFound) {
-                        m_log.debug(
-                            "applying wildcard rule(s) for attribute (%s) from (%s)",
-                            attr->getId(), issuer.get() ? issuer.get() : "unknown source"
-                            );
-                        ruleFound = true;
-                    }
-                    for (; rules.first!=rules.second; ++rules.first) {
-                        count = attr->valueCount();
-                        for (index=0; index < count;) {
-                            // The return value tells us whether to index past the accepted value, or stay put and decrement the count.
-                            if (rules.first->second->evaluatePermitValue(context, *attr, index)) {
-                                index++;
-                            }
-                            else {
-                                m_log.warn(
-                                    "removed value at position (%lu) of attribute (%s) from (%s)",
-                                    index, attr->getId(), issuer.get() ? issuer.get() : "unknown source"
-                                    );
-                                attr->removeValue(index);
-                                count--;
-                            }
-                        }
-                    }
-                }
+    // For further efficiency, we declare arrays to store the applicable rules for an Attribute.
+    vector< pair<const MatchFunctor*,const MatchFunctor*> > applicableRules;
+    vector< pair<const MatchFunctor*,const MatchFunctor*> > wildcardRules;
 
-                if (!ruleFound || attr->valueCount() == 0) {
-                    if (!ruleFound) {
-                        // No rule found, so we're filtering it out.
-                        m_log.warn(
-                            "no rule found, removing all values of attribute (%s) from (%s)",
-                            attr->getId(), issuer.get() ? issuer.get() : "unknown source"
-                            );
-                    }
-                    delete attr;
-                    attributes.erase(attributes.begin() + a);
-                }
-                else {
-                    ++a;
-                }
+    // Store off the wildcards ahead of time.
+    for (vector<const Policy*>::const_iterator pol=applicablePolicies.begin(); pol!=applicablePolicies.end(); ++pol) {
+        pair<Policy::rules_t::const_iterator,Policy::rules_t::const_iterator> rules = (*pol)->m_rules.equal_range("*");
+        for (; rules.first!=rules.second; ++rules.first)
+            wildcardRules.push_back(rules.first->second);
+    }
+
+    // To track what to keep without removing anything from the original set until the end, we maintain
+    // a map of each Attribute object to a boolean array with true flags indicating what to delete.
+    // A single dimension array tracks attributes being removed entirely.
+    vector<bool> deletedAttributes(attributes.size(), false);
+    map< Attribute*, vector<bool> > deletedPositions;
+
+    // Loop over each attribute to filter them.
+    for (vector<Attribute*>::size_type a=0; a<attributes.size(); ++a) {
+        Attribute* attr = attributes[a];
+
+        // Clear the rule store.
+        applicableRules.clear();
+
+        // Look for rules to run in each policy.
+        for (vector<const Policy*>::const_iterator pol=applicablePolicies.begin(); pol!=applicablePolicies.end(); ++pol) {
+            pair<Policy::rules_t::const_iterator,Policy::rules_t::const_iterator> rules = (*pol)->m_rules.equal_range(attr->getId());
+            for (; rules.first!=rules.second; ++rules.first)
+                applicableRules.push_back(rules.first->second);
+        }
+
+        // If no rules found, apply wildcards.
+        const vector< pair<const MatchFunctor*,const MatchFunctor*> >& rulesToRun = applicableRules.empty() ? wildcardRules : applicableRules;
+
+        // If no rules apply, remove the attribute entirely.
+        if (rulesToRun.empty()) {
+            m_log.warn(
+                "no rule found, removing attribute (%s) from (%s)",
+                attr->getId(), issuer.get() ? issuer.get() : "unknown source"
+                );
+            deletedAttributes[a] = true;
+            continue;
+        }
+
+        // Run each permit/deny rule.
+        m_log.debug(
+            "applying filtering rule(s) for attribute (%s) from (%s)",
+            attr->getId(), issuer.get() ? issuer.get() : "unknown source"
+            );
+
+        bool kickit;
+
+        // Examine each value.
+        for (size_t count = attr->valueCount(), index = 0; index < count; ++index) {
+
+            // Assume we're kicking it out.
+            kickit=true;
+
+            for (vector< pair<const MatchFunctor*,const MatchFunctor*> >::const_iterator r=rulesToRun.begin(); r!=rulesToRun.end(); ++r) {
+                // If there's a permit rule that passes, don't kick it.
+                if (r->first && r->first->evaluatePermitValue(context, *attr, index))
+                    kickit = false;
+                if (!kickit && r->second && r->second->evaluatePermitValue(context, *attr, index))
+                    kickit = true;
+            }
+
+            // If we're kicking it, record that in the tracker.
+            if (kickit) {
+                m_log.warn(
+                    "removed value at position (%lu) of attribute (%s) from (%s)",
+                    index, attr->getId(), issuer.get() ? issuer.get() : "unknown source"
+                    );
+                deletedPositions[attr].resize(index+1);
+                deletedPositions[attr][index] = true;
+            }
+        }
+    }
+
+    // Final step: go over the deletedPositions matrix and apply the actual changes. In order to delete
+    // any attributes that end up with no values, we have to do it by looping over the originals.
+    for (vector<Attribute*>::size_type a=0; a<attributes.size();) {
+        Attribute* attr = attributes[a];
+
+        if (deletedAttributes[a]) {
+            delete attr;
+            deletedAttributes.erase(deletedAttributes.begin() + a);
+            attributes.erase(attributes.begin() + a);
+            continue;
+        }
+        else if (deletedPositions.count(attr) > 0) {
+            // To do the removal, we loop over the bits backwards so that the
+            // underlying value sequence doesn't get distorted by any removals.
+            // Index has to be offset by one because size_type is unsigned.
+            const vector<bool>& row = deletedPositions[attr];
+            for (vector<bool>::size_type index = row.size(); index > 0; --index) {
+                if (row[index-1])
+                    attr->removeValue(index-1);
+            }
+
+            // Check for no values.
+            if (attr->valueCount() == 0) {
+                m_log.warn(
+                    "no values left, removing attribute (%s) from (%s)",
+                    attr->getId(), issuer.get() ? issuer.get() : "unknown source"
+                    );
+                delete attr;
+                attributes.erase(attributes.begin() + a);
+                continue;
             }
         }
+        ++a;
     }
 }