Metadata filters, filter auto-registration, and unit tests.
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / BlacklistMetadataFilter.cpp
1 /*
2  *  Copyright 2001-2006 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * BlacklistMetadataFilter.cpp
19  * 
20  * Removes blacklisted entities from a metadata instance
21  */
22
23 #include "internal.h"
24 #include "saml2/metadata/MetadataFilter.h"
25
26 #include <log4cpp/Category.hh>
27 #include <xmltooling/util/NDC.h>
28
29 using namespace opensaml::saml2md;
30 using namespace xmltooling;
31 using namespace log4cpp;
32 using namespace std;
33
34 namespace opensaml {
35     namespace saml2md {
36                 
37         class SAML_DLLLOCAL BlacklistMetadataFilter : public MetadataFilter
38         {
39         public:
40             BlacklistMetadataFilter(const DOMElement* e);
41             ~BlacklistMetadataFilter() {}
42             
43             const char* getId() const { return BLACKLIST_METADATA_FILTER; }
44             void doFilter(XMLObject& xmlObject) const;
45
46         private:
47             void doFilter(EntitiesDescriptor& entities) const;
48             
49             bool found(const XMLCh* id) const {
50                 if (!id)
51                     return false;
52 #ifdef HAVE_GOOD_STL
53                 return m_set.count(id)==1;
54 #else
55                 auto_ptr_char id2(id);
56                 return m_set.count(id2.get())==1;
57 #endif
58             }
59
60 #ifdef HAVE_GOOD_STL
61             set<xstring> m_set;
62 #else
63             set<string> m_set;
64 #endif
65         }; 
66
67         MetadataFilter* SAML_DLLLOCAL BlacklistMetadataFilterFactory(const DOMElement* const & e)
68         {
69             return new BlacklistMetadataFilter(e);
70         }
71
72     };
73 };
74
75 static const XMLCh Exclude[] =  UNICODE_LITERAL_7(E,x,c,l,u,d,e);
76
77 BlacklistMetadataFilter::BlacklistMetadataFilter(const DOMElement* e)
78 {
79     e = XMLHelper::getFirstChildElement(e);
80     while (e) {
81         if (XMLString::equals(e->getLocalName(), Exclude) && e->hasChildNodes()) {
82 #ifdef HAVE_GOOD_STL
83             m_set.insert(e->getFirstChild()->getNodeValue());
84 #else
85             auto_ptr_char id(e->getFirstChild()->getNodeValue());
86             m_set.insert(id.get());
87 #endif
88         }
89         e = XMLHelper::getNextSiblingElement(e);
90     }
91 }
92
93 void BlacklistMetadataFilter::doFilter(XMLObject& xmlObject) const
94 {
95 #ifdef _DEBUG
96     NDC ndc("doFilter");
97 #endif
98     
99     try {
100         doFilter(dynamic_cast<EntitiesDescriptor&>(xmlObject));
101         return;
102     }
103     catch (bad_cast) {
104     }
105
106     try {
107         EntityDescriptor& entity = dynamic_cast<EntityDescriptor&>(xmlObject);
108         if (found(entity.getEntityID()))
109             throw MetadataFilterException("BlacklistMetadataFilter instructed to filter the root/only entity in the metadata.");
110         return;
111     }
112     catch (bad_cast) {
113     }
114      
115     throw MetadataFilterException("BlacklistMetadataFilter was given an improper metadata instance to filter.");
116 }
117
118 void BlacklistMetadataFilter::doFilter(EntitiesDescriptor& entities) const
119 {
120     Category& log=Category::getInstance(SAML_LOGCAT".Metadata");
121     
122     VectorOf(EntityDescriptor) v=entities.getEntityDescriptors();
123     for (VectorOf(EntityDescriptor)::size_type i=0; i<v.size(); ) {
124         const XMLCh* id=v[i]->getEntityID();
125         if (found(id)) {
126             auto_ptr_char id2(id);
127             log.info("filtering out blacklisted entity (%s)", id2.get());
128             v.erase(v.begin() + i);
129         }
130         else {
131             i++;
132         }
133     }
134     
135     const vector<EntitiesDescriptor*>& groups=const_cast<const EntitiesDescriptor&>(entities).getEntitiesDescriptors();
136     for (vector<EntitiesDescriptor*>::const_iterator j=groups.begin(); j!=groups.end(); j++)
137         doFilter(*(*j));
138 }