Update ctors to use new attribute shortcuts.
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / AbstractMetadataProvider.cpp
1 /*
2  *  Copyright 2001-2010 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * AbstractMetadataProvider.cpp
19  * 
20  * Base class for caching metadata providers.
21  */
22
23 #include "internal.h"
24 #include "binding/SAMLArtifact.h"
25 #include "saml2/metadata/Metadata.h"
26 #include "saml2/metadata/AbstractMetadataProvider.h"
27 #include "saml2/metadata/MetadataCredentialContext.h"
28 #include "saml2/metadata/MetadataCredentialCriteria.h"
29
30 #include <xercesc/util/XMLUniDefs.hpp>
31 #include <xmltooling/logging.h>
32 #include <xmltooling/XMLToolingConfig.h>
33 #include <xmltooling/security/Credential.h>
34 #include <xmltooling/security/KeyInfoResolver.h>
35 #include <xmltooling/security/SecurityHelper.h>
36 #include <xmltooling/util/Threads.h>
37 #include <xmltooling/util/XMLHelper.h>
38
39 using namespace opensaml::saml2md;
40 using namespace xmltooling::logging;
41 using namespace xmltooling;
42 using namespace std;
43 using opensaml::SAMLArtifact;
44
45 static const XMLCh _KeyInfoResolver[] = UNICODE_LITERAL_15(K,e,y,I,n,f,o,R,e,s,o,l,v,e,r);
46 static const XMLCh type[] =             UNICODE_LITERAL_4(t,y,p,e);
47
48 AbstractMetadataProvider::AbstractMetadataProvider(const DOMElement* e)
49     : ObservableMetadataProvider(e), m_resolver(nullptr), m_credentialLock(nullptr)
50 {
51     e = XMLHelper::getFirstChildElement(e, _KeyInfoResolver);
52     if (e) {
53         string t = XMLHelper::getAttrString(e, nullptr, type);
54         if (!t.empty())
55             m_resolver = XMLToolingConfig::getConfig().KeyInfoResolverManager.newPlugin(t.c_str(), e);
56         else
57             throw UnknownExtensionException("<KeyInfoResolver> element found with no type attribute");
58     }
59     m_credentialLock = Mutex::create();
60 }
61
62 AbstractMetadataProvider::~AbstractMetadataProvider()
63 {
64     for (credmap_t::iterator c = m_credentialMap.begin(); c!=m_credentialMap.end(); ++c)
65         for_each(c->second.begin(), c->second.end(), xmltooling::cleanup<Credential>());
66     delete m_credentialLock;
67     delete m_resolver;
68 }
69
70 void AbstractMetadataProvider::emitChangeEvent() const
71 {
72     for (credmap_t::iterator c = m_credentialMap.begin(); c!=m_credentialMap.end(); ++c)
73         for_each(c->second.begin(), c->second.end(), xmltooling::cleanup<Credential>());
74     m_credentialMap.clear();
75     ObservableMetadataProvider::emitChangeEvent();
76 }
77
78 void AbstractMetadataProvider::indexEntity(EntityDescriptor* site, time_t& validUntil, bool replace) const
79 {
80     // If child expires later than input, reset child, otherwise lower input to match.
81     if (validUntil < site->getValidUntilEpoch())
82         site->setValidUntil(validUntil);
83     else
84         validUntil = site->getValidUntilEpoch();
85
86     auto_ptr_char id(site->getEntityID());
87     if (id.get()) {
88         if (replace) {
89             m_sites.erase(id.get());
90             for (sitemap_t::iterator s = m_sources.begin(); s != m_sources.end();) {
91                 if (s->second == site) {
92                     sitemap_t::iterator temp = s;
93                     ++s;
94                     m_sources.erase(temp);
95                 }
96                 else {
97                     ++s;
98                 }
99             }
100         }
101         m_sites.insert(sitemap_t::value_type(id.get(),site));
102     }
103     
104     // Process each IdP role.
105     const vector<IDPSSODescriptor*>& roles = const_cast<const EntityDescriptor*>(site)->getIDPSSODescriptors();
106     for (vector<IDPSSODescriptor*>::const_iterator i = roles.begin(); i != roles.end(); i++) {
107         // SAML 1.x?
108         if ((*i)->hasSupport(samlconstants::SAML10_PROTOCOL_ENUM) || (*i)->hasSupport(samlconstants::SAML11_PROTOCOL_ENUM)) {
109             // Check for SourceID extension element.
110             const Extensions* exts = (*i)->getExtensions();
111             if (exts && exts->hasChildren()) {
112                 const vector<XMLObject*>& children = exts->getUnknownXMLObjects();
113                 for (vector<XMLObject*>::const_iterator ext = children.begin(); ext != children.end(); ++ext) {
114                     SourceID* sid = dynamic_cast<SourceID*>(*ext);
115                     if (sid) {
116                         auto_ptr_char sourceid(sid->getID());
117                         if (sourceid.get()) {
118                             m_sources.insert(sitemap_t::value_type(sourceid.get(),site));
119                             break;
120                         }
121                     }
122                 }
123             }
124             
125             // Hash the ID.
126             m_sources.insert(sitemap_t::value_type(SecurityHelper::doHash("SHA1", id.get(), strlen(id.get())),site));
127                 
128             // Load endpoints for type 0x0002 artifacts.
129             const vector<ArtifactResolutionService*>& locs = const_cast<const IDPSSODescriptor*>(*i)->getArtifactResolutionServices();
130             for (vector<ArtifactResolutionService*>::const_iterator loc = locs.begin(); loc != locs.end(); loc++) {
131                 auto_ptr_char location((*loc)->getLocation());
132                 if (location.get())
133                     m_sources.insert(sitemap_t::value_type(location.get(),site));
134             }
135         }
136         
137         // SAML 2.0?
138         if ((*i)->hasSupport(samlconstants::SAML20P_NS)) {
139             // Hash the ID.
140             m_sources.insert(sitemap_t::value_type(SecurityHelper::doHash("SHA1", id.get(), strlen(id.get())),site));
141         }
142     }
143 }
144
145 void AbstractMetadataProvider::indexGroup(EntitiesDescriptor* group, time_t& validUntil) const
146 {
147     // If child expires later than input, reset child, otherwise lower input to match.
148     if (validUntil < group->getValidUntilEpoch())
149         group->setValidUntil(validUntil);
150     else
151         validUntil = group->getValidUntilEpoch();
152
153     auto_ptr_char name(group->getName());
154     if (name.get()) {
155         m_groups.insert(groupmap_t::value_type(name.get(),group));
156     }
157     
158     // Track the smallest validUntil amongst the children.
159     time_t minValidUntil = validUntil;
160
161     const vector<EntitiesDescriptor*>& groups = const_cast<const EntitiesDescriptor*>(group)->getEntitiesDescriptors();
162     for (vector<EntitiesDescriptor*>::const_iterator i = groups.begin(); i != groups.end(); i++) {
163         // Use the current validUntil fence for each child, but track the smallest we find.
164         time_t subValidUntil = validUntil;
165         indexGroup(*i, subValidUntil);
166         if (subValidUntil < minValidUntil)
167             minValidUntil = subValidUntil;
168     }
169
170     const vector<EntityDescriptor*>& sites = const_cast<const EntitiesDescriptor*>(group)->getEntityDescriptors();
171     for (vector<EntityDescriptor*>::const_iterator j = sites.begin(); j != sites.end(); j++) {
172         // Use the current validUntil fence for each child, but track the smallest we find.
173         time_t subValidUntil = validUntil;
174         indexEntity(*j, subValidUntil);
175         if (subValidUntil < minValidUntil)
176             minValidUntil = subValidUntil;
177     }
178
179     // Pass back up the smallest child we found.
180     if (minValidUntil < validUntil)
181         validUntil = minValidUntil;
182 }
183
184 void AbstractMetadataProvider::index(EntityDescriptor* site, time_t validUntil, bool replace) const
185 {
186     indexEntity(site, validUntil, replace);
187 }
188
189 void AbstractMetadataProvider::index(EntitiesDescriptor* group, time_t validUntil) const
190 {
191     indexGroup(group, validUntil);
192 }
193
194 void AbstractMetadataProvider::clearDescriptorIndex(bool freeSites)
195 {
196     if (freeSites)
197         for_each(m_sites.begin(), m_sites.end(), cleanup_const_pair<string,EntityDescriptor>());
198     m_sites.clear();
199     m_groups.clear();
200     m_sources.clear();
201 }
202
203 const EntitiesDescriptor* AbstractMetadataProvider::getEntitiesDescriptor(const char* name, bool strict) const
204 {
205     pair<groupmap_t::const_iterator,groupmap_t::const_iterator> range=const_cast<const groupmap_t&>(m_groups).equal_range(name);
206
207     time_t now=time(nullptr);
208     for (groupmap_t::const_iterator i=range.first; i!=range.second; i++)
209         if (now < i->second->getValidUntilEpoch())
210             return i->second;
211     
212     if (range.first != range.second) {
213         Category& log = Category::getInstance(SAML_LOGCAT".MetadataProvider");
214         if (strict) {
215             log.warn("ignored expired metadata group (%s)", range.first->first.c_str());
216         }
217         else {
218             log.info("no valid metadata found, returning expired metadata group (%s)", range.first->first.c_str());
219             return range.first->second;
220         }
221     }
222
223     return nullptr;
224 }
225
226 pair<const EntityDescriptor*,const RoleDescriptor*> AbstractMetadataProvider::getEntityDescriptor(const Criteria& criteria) const
227 {
228     pair<sitemap_t::const_iterator,sitemap_t::const_iterator> range;
229     if (criteria.entityID_ascii)
230         range = const_cast<const sitemap_t&>(m_sites).equal_range(criteria.entityID_ascii);
231     else if (criteria.entityID_unicode) {
232         auto_ptr_char id(criteria.entityID_unicode);
233         range = const_cast<const sitemap_t&>(m_sites).equal_range(id.get());
234     }
235     else if (criteria.artifact)
236         range = const_cast<const sitemap_t&>(m_sources).equal_range(criteria.artifact->getSource());
237     else
238         return pair<const EntityDescriptor*,const RoleDescriptor*>(nullptr,nullptr);
239     
240     pair<const EntityDescriptor*,const RoleDescriptor*> result;
241     result.first = nullptr;
242     result.second = nullptr;
243     
244     time_t now=time(nullptr);
245     for (sitemap_t::const_iterator i=range.first; i!=range.second; i++) {
246         if (now < i->second->getValidUntilEpoch()) {
247             result.first = i->second;
248             break;
249         }
250     }
251     
252     if (!result.first && range.first!=range.second) {
253         Category& log = Category::getInstance(SAML_LOGCAT".MetadataProvider");
254         if (criteria.validOnly) {
255             log.warn("ignored expired metadata instance for (%s)", range.first->first.c_str());
256         }
257         else {
258             log.info("no valid metadata found, returning expired instance for (%s)", range.first->first.c_str());
259             result.first = range.first->second;
260         }
261     }
262
263     if (result.first && criteria.role) {
264         result.second = result.first->getRoleDescriptor(*criteria.role, criteria.protocol);
265         if (!result.second && criteria.protocol2)
266             result.second = result.first->getRoleDescriptor(*criteria.role, criteria.protocol2);
267     }
268     
269     return result;
270 }
271
272 const Credential* AbstractMetadataProvider::resolve(const CredentialCriteria* criteria) const
273 {
274     const MetadataCredentialCriteria* metacrit = dynamic_cast<const MetadataCredentialCriteria*>(criteria);
275     if (!metacrit)
276         throw MetadataException("Cannot resolve credentials without a MetadataCredentialCriteria object.");
277
278     Lock lock(m_credentialLock);
279     const credmap_t::mapped_type& creds = resolveCredentials(metacrit->getRole());
280
281     for (credmap_t::mapped_type::const_iterator c = creds.begin(); c!=creds.end(); ++c)
282         if (metacrit->matches(*(*c)))
283             return *c;
284     return nullptr;
285 }
286
287 vector<const Credential*>::size_type AbstractMetadataProvider::resolve(
288     vector<const Credential*>& results, const CredentialCriteria* criteria
289     ) const
290 {
291     const MetadataCredentialCriteria* metacrit = dynamic_cast<const MetadataCredentialCriteria*>(criteria);
292     if (!metacrit)
293         throw MetadataException("Cannot resolve credentials without a MetadataCredentialCriteria object.");
294
295     Lock lock(m_credentialLock);
296     const credmap_t::mapped_type& creds = resolveCredentials(metacrit->getRole());
297
298     for (credmap_t::mapped_type::const_iterator c = creds.begin(); c!=creds.end(); ++c)
299         if (metacrit->matches(*(*c)))
300             results.push_back(*c);
301     return results.size();
302 }
303
304 const AbstractMetadataProvider::credmap_t::mapped_type& AbstractMetadataProvider::resolveCredentials(const RoleDescriptor& role) const
305 {
306     credmap_t::const_iterator i = m_credentialMap.find(&role);
307     if (i!=m_credentialMap.end())
308         return i->second;
309
310     const KeyInfoResolver* resolver = m_resolver ? m_resolver : XMLToolingConfig::getConfig().getKeyInfoResolver();
311     const vector<KeyDescriptor*>& keys = role.getKeyDescriptors();
312     AbstractMetadataProvider::credmap_t::mapped_type& resolved = m_credentialMap[&role];
313     for (vector<KeyDescriptor*>::const_iterator k = keys.begin(); k!=keys.end(); ++k) {
314         if ((*k)->getKeyInfo()) {
315             auto_ptr<MetadataCredentialContext> mcc(new MetadataCredentialContext(*(*k)));
316             Credential* c = resolver->resolve(mcc.get());
317             mcc.release();
318             resolved.push_back(c);
319         }
320     }
321     return resolved;
322 }