Add implicit support for a "null" protocol.
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / MetadataImpl.cpp
index d6f4ce9..91615d9 100644 (file)
@@ -34,6 +34,7 @@
 #include <ctime>
 #include <xercesc/util/XMLUniDefs.hpp>
 
+using namespace samlconstants;
 using namespace opensaml::saml2md;
 using namespace opensaml::saml2;
 using namespace opensaml;
@@ -43,8 +44,6 @@ using namespace xmltooling;
 using namespace std;
 using xmlconstants::XMLSIG_NS;
 using xmlconstants::XML_BOOL_NULL;
-using samlconstants::SAML20_NS;
-using samlconstants::SAML20MD_NS;
 
 #if defined (_MSC_VER)
     #pragma warning( push )
@@ -966,6 +965,8 @@ namespace opensaml {
             IMPL_TYPED_CHILDREN(ContactPerson,m_pos_ContactPerson);
 
             bool hasSupport(const XMLCh* protocol) const {
+                if (!protocol || !*protocol)
+                    return true;
                 if (m_ProtocolSupportEnumeration) {
                     // Look for first character.
                     unsigned int len=XMLString::stringLen(protocol);
@@ -1007,10 +1008,11 @@ namespace opensaml {
                     setProtocolSupportEnumeration(pse.c_str());
 #else
                     auto_ptr_char temp(m_ProtocolSupportEnumeration);
+                    auto_ptr_char temp2(protocol);
                     string pse(temp.get());
-                    pse = pse + ' ' + protocol;
-                    auto_ptr_XMLCh temp2(pse.c_str());
-                    setProtocolSupportEnumeration(temp2.get());
+                    pse = pse + ' ' + temp2.get();
+                    auto_ptr_XMLCh temp3(pse.c_str());
+                    setProtocolSupportEnumeration(temp3.get());
 #endif
                 }
                 else {
@@ -2197,105 +2199,28 @@ namespace opensaml {
                 AbstractAttributeExtensibleXMLObject::setAttribute(qualifiedName, value, ID);
             }
 
-            const IDPSSODescriptor* getIDPSSODescriptor(const XMLCh* protocol) const {
-                for (vector<IDPSSODescriptor*>::const_iterator i=m_IDPSSODescriptors.begin(); i!=m_IDPSSODescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const SPSSODescriptor* getSPSSODescriptor(const XMLCh* protocol) const {
-                for (vector<SPSSODescriptor*>::const_iterator i=m_SPSSODescriptors.begin(); i!=m_SPSSODescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const AuthnAuthorityDescriptor* getAuthnAuthorityDescriptor(const XMLCh* protocol) const {
-                for (vector<AuthnAuthorityDescriptor*>::const_iterator i=m_AuthnAuthorityDescriptors.begin(); i!=m_AuthnAuthorityDescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const AttributeAuthorityDescriptor* getAttributeAuthorityDescriptor(const XMLCh* protocol) const {
-                for (vector<AttributeAuthorityDescriptor*>::const_iterator i=m_AttributeAuthorityDescriptors.begin(); i!=m_AttributeAuthorityDescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const PDPDescriptor* getPDPDescriptor(const XMLCh* protocol) const {
-                for (vector<PDPDescriptor*>::const_iterator i=m_PDPDescriptors.begin(); i!=m_PDPDescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-
-            const AuthnQueryDescriptorType* getAuthnQueryDescriptorType(const XMLCh* protocol) const {
-                for (vector<AuthnQueryDescriptorType*>::const_iterator i=m_AuthnQueryDescriptorTypes.begin(); i!=m_AuthnQueryDescriptorTypes.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-
-            const AttributeQueryDescriptorType* getAttributeQueryDescriptorType(const XMLCh* protocol) const {
-                for (vector<AttributeQueryDescriptorType*>::const_iterator i=m_AttributeQueryDescriptorTypes.begin(); i!=m_AttributeQueryDescriptorTypes.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const AuthzDecisionQueryDescriptorType* getAuthzDecisionQueryDescriptorType(const XMLCh* protocol) const {
-                for (vector<AuthzDecisionQueryDescriptorType*>::const_iterator i=m_AuthzDecisionQueryDescriptorTypes.begin(); i!=m_AuthzDecisionQueryDescriptorTypes.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-
             const RoleDescriptor* getRoleDescriptor(const xmltooling::QName& qname, const XMLCh* protocol) const {
                 // Check for "known" elements/types.
-                QName q;
-                q.setNamespaceURI(SAML20MD_NS);
-                q.setLocalPart(IDPSSODescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getIDPSSODescriptor(protocol);
-                q.setLocalPart(SPSSODescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getSPSSODescriptor(protocol);
-                q.setLocalPart(AuthnAuthorityDescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getAuthnAuthorityDescriptor(protocol);
-                q.setLocalPart(AttributeAuthorityDescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getAttributeAuthorityDescriptor(protocol);
-                q.setLocalPart(PDPDescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getPDPDescriptor(protocol);
-                q.setNamespaceURI(samlconstants::SAML20MD_QUERY_EXT_NS);
-                q.setLocalPart(AuthnQueryDescriptorType::TYPE_NAME);
-                if (q == qname)
-                    return getAuthnQueryDescriptorType(protocol);
-                q.setLocalPart(AttributeQueryDescriptorType::TYPE_NAME);
-                if (q == qname)
-                    return getAttributeQueryDescriptorType(protocol);
-                q.setLocalPart(AuthzDecisionQueryDescriptorType::TYPE_NAME);
-                if (q == qname)
-                    return getAuthzDecisionQueryDescriptorType(protocol);
+                if (qname == IDPSSODescriptor::ELEMENT_QNAME)
+                    return find_if(m_IDPSSODescriptors, isValidForProtocol(protocol));
+                if (qname == SPSSODescriptor::ELEMENT_QNAME)
+                    return find_if(m_SPSSODescriptors, isValidForProtocol(protocol));
+                if (qname == AuthnAuthorityDescriptor::ELEMENT_QNAME)
+                    return find_if(m_AuthnAuthorityDescriptors, isValidForProtocol(protocol));
+                if (qname == AttributeAuthorityDescriptor::ELEMENT_QNAME)
+                    return find_if(m_AttributeAuthorityDescriptors, isValidForProtocol(protocol));
+                if (qname == PDPDescriptor::ELEMENT_QNAME)
+                    return find_if(m_PDPDescriptors, isValidForProtocol(protocol));
+                if (qname == AuthnQueryDescriptorType::TYPE_QNAME)
+                    return find_if(m_AuthnQueryDescriptorTypes, isValidForProtocol(protocol));
+                if (qname == AttributeQueryDescriptorType::TYPE_QNAME)
+                    return find_if(m_AttributeQueryDescriptorTypes, isValidForProtocol(protocol));
+                if (qname == AuthzDecisionQueryDescriptorType::TYPE_QNAME)
+                    return find_if(m_AuthzDecisionQueryDescriptorTypes, isValidForProtocol(protocol));
                 
-                for (vector<RoleDescriptor*>::const_iterator i=m_RoleDescriptors.begin(); i!=m_RoleDescriptors.end(); i++) {
-                    if ((*i)->getSchemaType() && qname==(*((*i)->getSchemaType())) && (*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
+                vector<RoleDescriptor*>::const_iterator i =
+                    find_if(m_RoleDescriptors.begin(), m_RoleDescriptors.end(), ofTypeValidForProtocol(qname,protocol));
+                return (i!=m_RoleDescriptors.end()) ? *i : NULL;
             }
 
         protected:
@@ -2452,6 +2377,15 @@ namespace opensaml {
     #pragma warning( pop )
 #endif
 
+IMPL_ELEMENT_QNAME(IDPSSODescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(SPSSODescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(AuthnAuthorityDescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(AttributeAuthorityDescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(PDPDescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_TYPE_QNAME(AuthnQueryDescriptorType, SAML20MD_QUERY_EXT_NS, SAML20MD_QUERY_EXT_PREFIX);
+IMPL_TYPE_QNAME(AttributeQueryDescriptorType, SAML20MD_QUERY_EXT_NS, SAML20MD_QUERY_EXT_PREFIX);
+IMPL_TYPE_QNAME(AuthzDecisionQueryDescriptorType, SAML20MD_QUERY_EXT_NS, SAML20MD_QUERY_EXT_PREFIX);
+
 // Builder Implementations
 
 IMPL_XMLOBJECTBUILDER(AdditionalMetadataLocation);