Add implicit support for a "null" protocol.
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / MetadataImpl.cpp
index 43162ec..91615d9 100644 (file)
@@ -34,6 +34,7 @@
 #include <ctime>
 #include <xercesc/util/XMLUniDefs.hpp>
 
+using namespace samlconstants;
 using namespace opensaml::saml2md;
 using namespace opensaml::saml2;
 using namespace opensaml;
@@ -43,8 +44,6 @@ using namespace xmltooling;
 using namespace std;
 using xmlconstants::XMLSIG_NS;
 using xmlconstants::XML_BOOL_NULL;
-using samlconstants::SAML20_NS;
-using samlconstants::SAML20MD_NS;
 
 #if defined (_MSC_VER)
     #pragma warning( push )
@@ -966,6 +965,8 @@ namespace opensaml {
             IMPL_TYPED_CHILDREN(ContactPerson,m_pos_ContactPerson);
 
             bool hasSupport(const XMLCh* protocol) const {
+                if (!protocol || !*protocol)
+                    return true;
                 if (m_ProtocolSupportEnumeration) {
                     // Look for first character.
                     unsigned int len=XMLString::stringLen(protocol);
@@ -996,6 +997,28 @@ namespace opensaml {
                 }
                 return false;
             }
+
+            void addSupport(const XMLCh* protocol) {
+                if (hasSupport(protocol))
+                    return;
+                if (m_ProtocolSupportEnumeration && *m_ProtocolSupportEnumeration) {
+#ifdef HAVE_GOOD_STL
+                    xstring pse(m_ProtocolSupportEnumeration);
+                    pse = pse + chSpace + protocol;
+                    setProtocolSupportEnumeration(pse.c_str());
+#else
+                    auto_ptr_char temp(m_ProtocolSupportEnumeration);
+                    auto_ptr_char temp2(protocol);
+                    string pse(temp.get());
+                    pse = pse + ' ' + temp2.get();
+                    auto_ptr_XMLCh temp3(pse.c_str());
+                    setProtocolSupportEnumeration(temp3.get());
+#endif
+                }
+                else {
+                    setProtocolSupportEnumeration(protocol);
+                }
+            }
     
             void setAttribute(const QName& qualifiedName, const XMLCh* value, bool ID=false) {
                 if (!qualifiedName.hasNamespaceURI()) {
@@ -2176,105 +2199,28 @@ namespace opensaml {
                 AbstractAttributeExtensibleXMLObject::setAttribute(qualifiedName, value, ID);
             }
 
-            const IDPSSODescriptor* getIDPSSODescriptor(const XMLCh* protocol) const {
-                for (vector<IDPSSODescriptor*>::const_iterator i=m_IDPSSODescriptors.begin(); i!=m_IDPSSODescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const SPSSODescriptor* getSPSSODescriptor(const XMLCh* protocol) const {
-                for (vector<SPSSODescriptor*>::const_iterator i=m_SPSSODescriptors.begin(); i!=m_SPSSODescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const AuthnAuthorityDescriptor* getAuthnAuthorityDescriptor(const XMLCh* protocol) const {
-                for (vector<AuthnAuthorityDescriptor*>::const_iterator i=m_AuthnAuthorityDescriptors.begin(); i!=m_AuthnAuthorityDescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const AttributeAuthorityDescriptor* getAttributeAuthorityDescriptor(const XMLCh* protocol) const {
-                for (vector<AttributeAuthorityDescriptor*>::const_iterator i=m_AttributeAuthorityDescriptors.begin(); i!=m_AttributeAuthorityDescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const PDPDescriptor* getPDPDescriptor(const XMLCh* protocol) const {
-                for (vector<PDPDescriptor*>::const_iterator i=m_PDPDescriptors.begin(); i!=m_PDPDescriptors.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-
-            const AuthnQueryDescriptorType* getAuthnQueryDescriptorType(const XMLCh* protocol) const {
-                for (vector<AuthnQueryDescriptorType*>::const_iterator i=m_AuthnQueryDescriptorTypes.begin(); i!=m_AuthnQueryDescriptorTypes.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-
-            const AttributeQueryDescriptorType* getAttributeQueryDescriptorType(const XMLCh* protocol) const {
-                for (vector<AttributeQueryDescriptorType*>::const_iterator i=m_AttributeQueryDescriptorTypes.begin(); i!=m_AttributeQueryDescriptorTypes.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-            
-            const AuthzDecisionQueryDescriptorType* getAuthzDecisionQueryDescriptorType(const XMLCh* protocol) const {
-                for (vector<AuthzDecisionQueryDescriptorType*>::const_iterator i=m_AuthzDecisionQueryDescriptorTypes.begin(); i!=m_AuthzDecisionQueryDescriptorTypes.end(); i++) {
-                    if ((*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
-            }
-
             const RoleDescriptor* getRoleDescriptor(const xmltooling::QName& qname, const XMLCh* protocol) const {
                 // Check for "known" elements/types.
-                QName q;
-                q.setNamespaceURI(SAML20MD_NS);
-                q.setLocalPart(IDPSSODescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getIDPSSODescriptor(protocol);
-                q.setLocalPart(SPSSODescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getSPSSODescriptor(protocol);
-                q.setLocalPart(AuthnAuthorityDescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getAuthnAuthorityDescriptor(protocol);
-                q.setLocalPart(AttributeAuthorityDescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getAttributeAuthorityDescriptor(protocol);
-                q.setLocalPart(PDPDescriptor::LOCAL_NAME);
-                if (q == qname)
-                    return getPDPDescriptor(protocol);
-                q.setNamespaceURI(samlconstants::SAML20MD_QUERY_EXT_NS);
-                q.setLocalPart(AuthnQueryDescriptorType::TYPE_NAME);
-                if (q == qname)
-                    return getAuthnQueryDescriptorType(protocol);
-                q.setLocalPart(AttributeQueryDescriptorType::TYPE_NAME);
-                if (q == qname)
-                    return getAttributeQueryDescriptorType(protocol);
-                q.setLocalPart(AuthzDecisionQueryDescriptorType::TYPE_NAME);
-                if (q == qname)
-                    return getAuthzDecisionQueryDescriptorType(protocol);
+                if (qname == IDPSSODescriptor::ELEMENT_QNAME)
+                    return find_if(m_IDPSSODescriptors, isValidForProtocol(protocol));
+                if (qname == SPSSODescriptor::ELEMENT_QNAME)
+                    return find_if(m_SPSSODescriptors, isValidForProtocol(protocol));
+                if (qname == AuthnAuthorityDescriptor::ELEMENT_QNAME)
+                    return find_if(m_AuthnAuthorityDescriptors, isValidForProtocol(protocol));
+                if (qname == AttributeAuthorityDescriptor::ELEMENT_QNAME)
+                    return find_if(m_AttributeAuthorityDescriptors, isValidForProtocol(protocol));
+                if (qname == PDPDescriptor::ELEMENT_QNAME)
+                    return find_if(m_PDPDescriptors, isValidForProtocol(protocol));
+                if (qname == AuthnQueryDescriptorType::TYPE_QNAME)
+                    return find_if(m_AuthnQueryDescriptorTypes, isValidForProtocol(protocol));
+                if (qname == AttributeQueryDescriptorType::TYPE_QNAME)
+                    return find_if(m_AttributeQueryDescriptorTypes, isValidForProtocol(protocol));
+                if (qname == AuthzDecisionQueryDescriptorType::TYPE_QNAME)
+                    return find_if(m_AuthzDecisionQueryDescriptorTypes, isValidForProtocol(protocol));
                 
-                for (vector<RoleDescriptor*>::const_iterator i=m_RoleDescriptors.begin(); i!=m_RoleDescriptors.end(); i++) {
-                    if ((*i)->getSchemaType() && qname==(*((*i)->getSchemaType())) && (*i)->hasSupport(protocol) && (*i)->isValid())
-                        return (*i);
-                }
-                return NULL;
+                vector<RoleDescriptor*>::const_iterator i =
+                    find_if(m_RoleDescriptors.begin(), m_RoleDescriptors.end(), ofTypeValidForProtocol(qname,protocol));
+                return (i!=m_RoleDescriptors.end()) ? *i : NULL;
             }
 
         protected:
@@ -2431,6 +2377,15 @@ namespace opensaml {
     #pragma warning( pop )
 #endif
 
+IMPL_ELEMENT_QNAME(IDPSSODescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(SPSSODescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(AuthnAuthorityDescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(AttributeAuthorityDescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_ELEMENT_QNAME(PDPDescriptor, SAML20MD_NS, SAML20MD_PREFIX);
+IMPL_TYPE_QNAME(AuthnQueryDescriptorType, SAML20MD_QUERY_EXT_NS, SAML20MD_QUERY_EXT_PREFIX);
+IMPL_TYPE_QNAME(AttributeQueryDescriptorType, SAML20MD_QUERY_EXT_NS, SAML20MD_QUERY_EXT_PREFIX);
+IMPL_TYPE_QNAME(AuthzDecisionQueryDescriptorType, SAML20MD_QUERY_EXT_NS, SAML20MD_QUERY_EXT_PREFIX);
+
 // Builder Implementations
 
 IMPL_XMLOBJECTBUILDER(AdditionalMetadataLocation);