https://issues.shibboleth.net/jira/browse/CPPOST-87
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / DiscoverableMetadataProvider.cpp
index a2d89bb..2b26618 100644 (file)
 
 #include "internal.h"
 #include "binding/SAMLArtifact.h"
+#include "saml2/metadata/EntityMatcher.h"
 #include "saml2/metadata/Metadata.h"
 #include "saml2/metadata/DiscoverableMetadataProvider.h"
 
 #include <fstream>
 #include <sstream>
-#include <boost/bind.hpp>
+#include <boost/lambda/bind.hpp>
+#include <boost/lambda/casts.hpp>
+#include <boost/lambda/lambda.hpp>
 #include <boost/iterator/indirect_iterator.hpp>
 #include <xmltooling/logging.h>
 #include <xmltooling/XMLToolingConfig.h>
 
+using namespace opensaml::saml2;
 using namespace opensaml::saml2md;
+using namespace xmltooling::logging;
 using namespace xmltooling;
+using namespace boost::lambda;
 using namespace boost;
 using namespace std;
 
 DiscoverableMetadataProvider::DiscoverableMetadataProvider(const DOMElement* e) : MetadataProvider(e), m_legacyOrgNames(false)
 {
-    static const XMLCh legacyOrgNames[] = UNICODE_LITERAL_14(l,e,g,a,c,y,O,r,g,N,a,m,e,s);
+    static const XMLCh legacyOrgNames[] =   UNICODE_LITERAL_14(l,e,g,a,c,y,O,r,g,N,a,m,e,s);
+    static const XMLCh matcher[] =          UNICODE_LITERAL_7(m,a,t,c,h,e,r);
+    static const XMLCh tagsInFeed[] =       UNICODE_LITERAL_10(t,a,g,s,I,n,F,e,e,d);
+    static const XMLCh _type[] =            UNICODE_LITERAL_4(t,y,p,e);
+    static const XMLCh DiscoveryFilter[] =  UNICODE_LITERAL_15(D,i,s,c,o,v,e,r,y,F,i,l,t,e,r);
+
     m_legacyOrgNames = XMLHelper::getAttrBool(e, false, legacyOrgNames);
+    m_entityAttributes = XMLHelper::getAttrBool(e, false, tagsInFeed);
+
+    e = e ? XMLHelper::getFirstChildElement(e, DiscoveryFilter) : nullptr;
+    while (e) {
+        string t(XMLHelper::getAttrString(e, nullptr, _type));
+        if (t == "Whitelist" || t == "Blacklist") {
+            string m(XMLHelper::getAttrString(e, nullptr, matcher));
+            if (!m.empty()) {
+                try {
+                    boost::shared_ptr<EntityMatcher> temp(SAMLConfig::getConfig().EntityMatcherManager.newPlugin(m, e));
+                    m_discoFilters.push_back(make_pair(t == "Whitelist", temp));
+                }
+                catch (std::exception& ex) {
+                    Category::getInstance(SAML_LOGCAT".MetadataProvider.Discoverable").error(
+                        "exception creating <DiscoveryFilter> EntityMatcher: %s", ex.what()
+                        );
+                }
+            }
+            else {
+                Category::getInstance(SAML_LOGCAT".MetadataProvider.Discoverable").error("<DiscoveryFilter> requires matcher attribute");
+            }
+        }
+        else {
+            Category::getInstance(SAML_LOGCAT".MetadataProvider.Discoverable").error(
+                "unknown <DiscoveryFilter> type (%s)", t.empty() ? "none" : t.c_str()
+                );
+        }
+        e = XMLHelper::getNextSiblingElement(e, DiscoveryFilter);
+    }
 }
 
 DiscoverableMetadataProvider::~DiscoverableMetadataProvider()
@@ -83,41 +123,52 @@ void DiscoverableMetadataProvider::outputFeed(ostream& os, bool& first, bool wra
         os << "\n]";
 }
 
-static string& json_safe(string& s, const char* buf)
-{
-    for (; *buf; ++buf) {
-        switch (*buf) {
-            case '\\':
-            case '"':
-                s += '\\';
-                s += *buf;
-                break;
-            case '\b':
-                s += "\\b";
-                break;
-            case '\t':
-                s += "\\t";
-                break;
-            case '\n':
-                s += "\\n";
-                break;
-            case '\f':
-                s += "\\f";
-                break;
-            case '\r':
-                s += "\\r";
-                break;
-            default:
-                s += *buf;
+namespace {
+    static string& json_safe(string& s, const char* buf)
+    {
+        for (; *buf; ++buf) {
+            switch (*buf) {
+                case '\\':
+                case '"':
+                    s += '\\';
+                    s += *buf;
+                    break;
+                case '\b':
+                    s += "\\b";
+                    break;
+                case '\t':
+                    s += "\\t";
+                    break;
+                case '\n':
+                    s += "\\n";
+                    break;
+                case '\f':
+                    s += "\\f";
+                    break;
+                case '\r':
+                    s += "\\r";
+                    break;
+                default:
+                    s += *buf;
+            }
         }
+        return s;
     }
-    return s;
-}
+};
 
 void DiscoverableMetadataProvider::discoEntity(string& s, const EntityDescriptor* entity, bool& first) const
 {
     time_t now = time(nullptr);
     if (entity && entity->isValid(now)) {
+
+        // Check filter(s).
+        for (vector< pair < bool, boost::shared_ptr<EntityMatcher> > >::const_iterator f = m_discoFilters.begin(); f != m_discoFilters.end(); ++f) {
+            // The flag is true for a whitelist and false for a blacklist,
+            // so we omit the entity if the match outcome is the inverse.
+            if (f->first != f->second->matches(*entity))
+                return;
+        }
+
         const vector<IDPSSODescriptor*>& idps = entity->getIDPSSODescriptors();
         if (!idps.empty()) {
             auto_ptr_char entityid(entity->getEntityID());
@@ -130,6 +181,7 @@ void DiscoverableMetadataProvider::discoEntity(string& s, const EntityDescriptor
             json_safe(s, entityid.get());
             s += '\"';
             bool extFound = false;
+            bool displayNameFound = false;
             for (indirect_iterator<vector<IDPSSODescriptor*>::const_iterator> idp = make_indirect_iterator(idps.begin());
                     !extFound && idp != make_indirect_iterator(idps.end()); ++idp) {
                 if (idp->isValid(now) && idp->getExtensions()) {
@@ -140,6 +192,7 @@ void DiscoverableMetadataProvider::discoEntity(string& s, const EntityDescriptor
                             extFound = true;
                             const vector<DisplayName*>& dispnames = info->getDisplayNames();
                             if (!dispnames.empty()) {
+                                displayNameFound = true;
                                 s += ",\n \"DisplayNames\": [";
                                 for (indirect_iterator<vector<DisplayName*>::const_iterator> dispname = make_indirect_iterator(dispnames.begin());
                                         dispname != make_indirect_iterator(dispnames.end()); ++dispname) {
@@ -263,7 +316,7 @@ void DiscoverableMetadataProvider::discoEntity(string& s, const EntityDescriptor
                 }
             }
 
-            if (m_legacyOrgNames && !extFound) {
+            if (m_legacyOrgNames && !displayNameFound) {
                 const Organization* org = nullptr;
                 for (indirect_iterator<vector<IDPSSODescriptor*>::const_iterator> idp = make_indirect_iterator(idps.begin());
                         !org && idp != make_indirect_iterator(idps.end()); ++idp) {
@@ -293,6 +346,32 @@ void DiscoverableMetadataProvider::discoEntity(string& s, const EntityDescriptor
                 }
             }
 
+            if (m_entityAttributes) {
+                bool tagfirst = true;
+                // Check for an EntityAttributes extension in the entity and its parent(s).
+                const Extensions* exts = entity->getExtensions();
+                if (exts) {
+                    const vector<XMLObject*>& children = exts->getUnknownXMLObjects();
+                    const XMLObject* xo = find_if(children, ll_dynamic_cast<EntityAttributes*>(_1) != ((EntityAttributes*)nullptr));
+                    if (xo)
+                        discoEntityAttributes(s, *dynamic_cast<const EntityAttributes*>(xo), tagfirst);
+                }
+
+                const EntitiesDescriptor* group = dynamic_cast<EntitiesDescriptor*>(entity->getParent());
+                while (group) {
+                    exts = group->getExtensions();
+                    if (exts) {
+                        const vector<XMLObject*>& children = exts->getUnknownXMLObjects();
+                        const XMLObject* xo = find_if(children, ll_dynamic_cast<EntityAttributes*>(_1) != ((EntityAttributes*)nullptr));
+                        if (xo)
+                            discoEntityAttributes(s, *dynamic_cast<const EntityAttributes*>(xo), tagfirst);
+                    }
+                    group = dynamic_cast<EntitiesDescriptor*>(group->getParent());
+                }
+                if (!tagfirst)
+                    s += "\n ]";
+            }
+
             // Close the struct;
             s += "\n}";
         }
@@ -304,11 +383,55 @@ void DiscoverableMetadataProvider::discoGroup(string& s, const EntitiesDescripto
     if (group) {
         for_each(
             group->getEntitiesDescriptors().begin(), group->getEntitiesDescriptors().end(),
-            boost::bind(&DiscoverableMetadataProvider::discoGroup, this, boost::ref(s), _1, boost::ref(first))
+            lambda::bind(&DiscoverableMetadataProvider::discoGroup, this, boost::ref(s), _1, boost::ref(first))
             );
         for_each(
             group->getEntityDescriptors().begin(), group->getEntityDescriptors().end(),
-            boost::bind(&DiscoverableMetadataProvider::discoEntity, this, boost::ref(s), _1, boost::ref(first))
+            lambda::bind(&DiscoverableMetadataProvider::discoEntity, this, boost::ref(s), _1, boost::ref(first))
             );
     }
 }
+
+void DiscoverableMetadataProvider::discoEntityAttributes(std::string& s, const EntityAttributes& ea, bool& first) const
+{
+    discoAttributes(s, ea.getAttributes(), first);
+    const vector<saml2::Assertion*>& tokens = ea.getAssertions();
+    for (vector<saml2::Assertion*>::const_iterator t = tokens.begin(); t != tokens.end(); ++t) {
+        const vector<AttributeStatement*> statements = const_cast<const saml2::Assertion*>(*t)->getAttributeStatements();
+        for (vector<AttributeStatement*>::const_iterator st = statements.begin(); st != statements.end(); ++st) {
+            discoAttributes(s, const_cast<const AttributeStatement*>(*st)->getAttributes(), first);
+        }
+    }
+}
+
+void DiscoverableMetadataProvider::discoAttributes(std::string& s, const vector<Attribute*>& attrs, bool& first) const
+{
+    for (indirect_iterator<vector<Attribute*>::const_iterator> a = make_indirect_iterator(attrs.begin());
+            a != make_indirect_iterator(attrs.end()); ++a) {
+
+        if (first) {
+            s += ",\n \"EntityAttributes\": [";
+            first = false;
+        }
+        else {
+            s += ',';
+        }
+
+        auto_ptr_char n(a->getName());
+        s += "\n  {\n  \"name\": \"";
+        json_safe(s, n.get());
+        s += "\",\n  \"values\": [";
+        const vector<XMLObject*>& vals = const_cast<const Attribute&>(*a).getAttributeValues();
+        for (indirect_iterator<vector<XMLObject*>::const_iterator> v = make_indirect_iterator(vals.begin());
+                v != make_indirect_iterator(vals.end()); ++v) {
+            if (v.base() != vals.begin())
+                s += ',';
+            auto_arrayptr<char> val(toUTF8(v->getTextContent()));
+            s += "\n     \"";
+            if (val.get())
+                json_safe(s, val.get());
+            s += '\"';
+        }
+        s += "\n  ]\n  }";
+    }
+}