Blacklist groups also.
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / BlacklistMetadataFilter.cpp
1 /*
2  *  Copyright 2001-2006 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * BlacklistMetadataFilter.cpp
19  * 
20  * Removes blacklisted entities from a metadata instance
21  */
22
23 #include "internal.h"
24 #include "saml2/metadata/MetadataFilter.h"
25
26 #include <log4cpp/Category.hh>
27 #include <xmltooling/util/NDC.h>
28
29 using namespace opensaml::saml2md;
30 using namespace xmltooling;
31 using namespace log4cpp;
32 using namespace std;
33
34 namespace opensaml {
35     namespace saml2md {
36                 
37         class SAML_DLLLOCAL BlacklistMetadataFilter : public MetadataFilter
38         {
39         public:
40             BlacklistMetadataFilter(const DOMElement* e);
41             ~BlacklistMetadataFilter() {}
42             
43             const char* getId() const { return BLACKLIST_METADATA_FILTER; }
44             void doFilter(XMLObject& xmlObject) const;
45
46         private:
47             void doFilter(EntitiesDescriptor& entities) const;
48             
49             bool found(const XMLCh* id) const {
50                 if (!id)
51                     return false;
52 #ifdef HAVE_GOOD_STL
53                 return m_set.count(id)==1;
54 #else
55                 auto_ptr_char id2(id);
56                 return m_set.count(id2.get())==1;
57 #endif
58             }
59
60 #ifdef HAVE_GOOD_STL
61             set<xstring> m_set;
62 #else
63             set<string> m_set;
64 #endif
65         }; 
66
67         MetadataFilter* SAML_DLLLOCAL BlacklistMetadataFilterFactory(const DOMElement* const & e)
68         {
69             return new BlacklistMetadataFilter(e);
70         }
71
72     };
73 };
74
75 static const XMLCh Exclude[] =  UNICODE_LITERAL_7(E,x,c,l,u,d,e);
76
77 BlacklistMetadataFilter::BlacklistMetadataFilter(const DOMElement* e)
78 {
79     e = XMLHelper::getFirstChildElement(e);
80     while (e) {
81         if (XMLString::equals(e->getLocalName(), Exclude) && e->hasChildNodes()) {
82 #ifdef HAVE_GOOD_STL
83             m_set.insert(e->getFirstChild()->getNodeValue());
84 #else
85             auto_ptr_char id(e->getFirstChild()->getNodeValue());
86             m_set.insert(id.get());
87 #endif
88         }
89         e = XMLHelper::getNextSiblingElement(e);
90     }
91 }
92
93 void BlacklistMetadataFilter::doFilter(XMLObject& xmlObject) const
94 {
95 #ifdef _DEBUG
96     NDC ndc("doFilter");
97 #endif
98     
99     try {
100         EntitiesDescriptor& entities = dynamic_cast<EntitiesDescriptor&>(xmlObject);
101         if (found(entities.getName()))
102             throw MetadataFilterException("BlacklistMetadataFilter instructed to filter the root/only group in the metadata.");
103         doFilter(entities);
104         return;
105     }
106     catch (bad_cast) {
107     }
108
109     try {
110         EntityDescriptor& entity = dynamic_cast<EntityDescriptor&>(xmlObject);
111         if (found(entity.getEntityID()))
112             throw MetadataFilterException("BlacklistMetadataFilter instructed to filter the root/only entity in the metadata.");
113         return;
114     }
115     catch (bad_cast) {
116     }
117      
118     throw MetadataFilterException("BlacklistMetadataFilter was given an improper metadata instance to filter.");
119 }
120
121 void BlacklistMetadataFilter::doFilter(EntitiesDescriptor& entities) const
122 {
123     Category& log=Category::getInstance(SAML_LOGCAT".Metadata");
124     
125     VectorOf(EntityDescriptor) v=entities.getEntityDescriptors();
126     for (VectorOf(EntityDescriptor)::size_type i=0; i<v.size(); ) {
127         const XMLCh* id=v[i]->getEntityID();
128         if (found(id)) {
129             auto_ptr_char id2(id);
130             log.info("filtering out blacklisted entity (%s)", id2.get());
131             v.erase(v.begin() + i);
132         }
133         else {
134             i++;
135         }
136     }
137     
138     VectorOf(EntitiesDescriptor) w=entities.getEntitiesDescriptors();
139     for (VectorOf(EntitiesDescriptor)::size_type j=0; j<w.size(); ) {
140         const XMLCh* name=w[j]->getName();
141         if (found(name)) {
142             auto_ptr_char name2(name);
143             log.info("filtering out blacklisted group (%s)", name2.get());
144             w.erase(w.begin() + j);
145         }
146         else {
147             doFilter(*(w[j]));
148             j++;
149         }
150     }
151 }