Change license header, remove stale pkg files.
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / AbstractMetadataProvider.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * AbstractMetadataProvider.cpp
23  * 
24  * Base class for caching metadata providers.
25  */
26
27 #include "internal.h"
28 #include "binding/SAMLArtifact.h"
29 #include "saml2/metadata/Metadata.h"
30 #include "saml2/metadata/AbstractMetadataProvider.h"
31 #include "saml2/metadata/MetadataCredentialContext.h"
32 #include "saml2/metadata/MetadataCredentialCriteria.h"
33
34 #include <xercesc/util/XMLUniDefs.hpp>
35 #include <xmltooling/logging.h>
36 #include <xmltooling/XMLToolingConfig.h>
37 #include <xmltooling/security/Credential.h>
38 #include <xmltooling/security/KeyInfoResolver.h>
39 #include <xmltooling/security/SecurityHelper.h>
40 #include <xmltooling/util/Threads.h>
41 #include <xmltooling/util/XMLHelper.h>
42
43 using namespace opensaml::saml2md;
44 using namespace xmltooling::logging;
45 using namespace xmltooling;
46 using namespace std;
47 using opensaml::SAMLArtifact;
48
49 static const XMLCh _KeyInfoResolver[] = UNICODE_LITERAL_15(K,e,y,I,n,f,o,R,e,s,o,l,v,e,r);
50 static const XMLCh type[] =             UNICODE_LITERAL_4(t,y,p,e);
51
52 AbstractMetadataProvider::AbstractMetadataProvider(const DOMElement* e)
53     : ObservableMetadataProvider(e), m_resolver(nullptr), m_credentialLock(nullptr)
54 {
55     e = XMLHelper::getFirstChildElement(e, _KeyInfoResolver);
56     if (e) {
57         string t = XMLHelper::getAttrString(e, nullptr, type);
58         if (!t.empty())
59             m_resolver = XMLToolingConfig::getConfig().KeyInfoResolverManager.newPlugin(t.c_str(), e);
60         else
61             throw UnknownExtensionException("<KeyInfoResolver> element found with no type attribute");
62     }
63     m_credentialLock = Mutex::create();
64 }
65
66 AbstractMetadataProvider::~AbstractMetadataProvider()
67 {
68     for (credmap_t::iterator c = m_credentialMap.begin(); c!=m_credentialMap.end(); ++c)
69         for_each(c->second.begin(), c->second.end(), xmltooling::cleanup<Credential>());
70     delete m_credentialLock;
71     delete m_resolver;
72 }
73
74 void AbstractMetadataProvider::emitChangeEvent() const
75 {
76     for (credmap_t::iterator c = m_credentialMap.begin(); c!=m_credentialMap.end(); ++c)
77         for_each(c->second.begin(), c->second.end(), xmltooling::cleanup<Credential>());
78     m_credentialMap.clear();
79     ObservableMetadataProvider::emitChangeEvent();
80 }
81
82 void AbstractMetadataProvider::indexEntity(EntityDescriptor* site, time_t& validUntil, bool replace) const
83 {
84     // If child expires later than input, reset child, otherwise lower input to match.
85     if (validUntil < site->getValidUntilEpoch())
86         site->setValidUntil(validUntil);
87     else
88         validUntil = site->getValidUntilEpoch();
89
90     auto_ptr_char id(site->getEntityID());
91     if (id.get()) {
92         if (replace) {
93             m_sites.erase(id.get());
94             for (sitemap_t::iterator s = m_sources.begin(); s != m_sources.end();) {
95                 if (s->second == site) {
96                     sitemap_t::iterator temp = s;
97                     ++s;
98                     m_sources.erase(temp);
99                 }
100                 else {
101                     ++s;
102                 }
103             }
104         }
105         m_sites.insert(sitemap_t::value_type(id.get(),site));
106     }
107     
108     // Process each IdP role.
109     const vector<IDPSSODescriptor*>& roles = const_cast<const EntityDescriptor*>(site)->getIDPSSODescriptors();
110     for (vector<IDPSSODescriptor*>::const_iterator i = roles.begin(); i != roles.end(); i++) {
111         // SAML 1.x?
112         if ((*i)->hasSupport(samlconstants::SAML10_PROTOCOL_ENUM) || (*i)->hasSupport(samlconstants::SAML11_PROTOCOL_ENUM)) {
113             // Check for SourceID extension element.
114             const Extensions* exts = (*i)->getExtensions();
115             if (exts && exts->hasChildren()) {
116                 const vector<XMLObject*>& children = exts->getUnknownXMLObjects();
117                 for (vector<XMLObject*>::const_iterator ext = children.begin(); ext != children.end(); ++ext) {
118                     SourceID* sid = dynamic_cast<SourceID*>(*ext);
119                     if (sid) {
120                         auto_ptr_char sourceid(sid->getID());
121                         if (sourceid.get()) {
122                             m_sources.insert(sitemap_t::value_type(sourceid.get(),site));
123                             break;
124                         }
125                     }
126                 }
127             }
128             
129             // Hash the ID.
130             m_sources.insert(sitemap_t::value_type(SecurityHelper::doHash("SHA1", id.get(), strlen(id.get())),site));
131                 
132             // Load endpoints for type 0x0002 artifacts.
133             const vector<ArtifactResolutionService*>& locs = const_cast<const IDPSSODescriptor*>(*i)->getArtifactResolutionServices();
134             for (vector<ArtifactResolutionService*>::const_iterator loc = locs.begin(); loc != locs.end(); loc++) {
135                 auto_ptr_char location((*loc)->getLocation());
136                 if (location.get())
137                     m_sources.insert(sitemap_t::value_type(location.get(),site));
138             }
139         }
140         
141         // SAML 2.0?
142         if ((*i)->hasSupport(samlconstants::SAML20P_NS)) {
143             // Hash the ID.
144             m_sources.insert(sitemap_t::value_type(SecurityHelper::doHash("SHA1", id.get(), strlen(id.get())),site));
145         }
146     }
147 }
148
149 void AbstractMetadataProvider::indexGroup(EntitiesDescriptor* group, time_t& validUntil) const
150 {
151     // If child expires later than input, reset child, otherwise lower input to match.
152     if (validUntil < group->getValidUntilEpoch())
153         group->setValidUntil(validUntil);
154     else
155         validUntil = group->getValidUntilEpoch();
156
157     auto_ptr_char name(group->getName());
158     if (name.get()) {
159         m_groups.insert(groupmap_t::value_type(name.get(),group));
160     }
161     
162     // Track the smallest validUntil amongst the children.
163     time_t minValidUntil = validUntil;
164
165     const vector<EntitiesDescriptor*>& groups = const_cast<const EntitiesDescriptor*>(group)->getEntitiesDescriptors();
166     for (vector<EntitiesDescriptor*>::const_iterator i = groups.begin(); i != groups.end(); i++) {
167         // Use the current validUntil fence for each child, but track the smallest we find.
168         time_t subValidUntil = validUntil;
169         indexGroup(*i, subValidUntil);
170         if (subValidUntil < minValidUntil)
171             minValidUntil = subValidUntil;
172     }
173
174     const vector<EntityDescriptor*>& sites = const_cast<const EntitiesDescriptor*>(group)->getEntityDescriptors();
175     for (vector<EntityDescriptor*>::const_iterator j = sites.begin(); j != sites.end(); j++) {
176         // Use the current validUntil fence for each child, but track the smallest we find.
177         time_t subValidUntil = validUntil;
178         indexEntity(*j, subValidUntil);
179         if (subValidUntil < minValidUntil)
180             minValidUntil = subValidUntil;
181     }
182
183     // Pass back up the smallest child we found.
184     if (minValidUntil < validUntil)
185         validUntil = minValidUntil;
186 }
187
188 void AbstractMetadataProvider::index(EntityDescriptor* site, time_t validUntil, bool replace) const
189 {
190     indexEntity(site, validUntil, replace);
191 }
192
193 void AbstractMetadataProvider::index(EntitiesDescriptor* group, time_t validUntil) const
194 {
195     indexGroup(group, validUntil);
196 }
197
198 void AbstractMetadataProvider::clearDescriptorIndex(bool freeSites)
199 {
200     if (freeSites)
201         for_each(m_sites.begin(), m_sites.end(), cleanup_const_pair<string,EntityDescriptor>());
202     m_sites.clear();
203     m_groups.clear();
204     m_sources.clear();
205 }
206
207 const EntitiesDescriptor* AbstractMetadataProvider::getEntitiesDescriptor(const char* name, bool strict) const
208 {
209     pair<groupmap_t::const_iterator,groupmap_t::const_iterator> range=const_cast<const groupmap_t&>(m_groups).equal_range(name);
210
211     time_t now=time(nullptr);
212     for (groupmap_t::const_iterator i=range.first; i!=range.second; i++)
213         if (now < i->second->getValidUntilEpoch())
214             return i->second;
215     
216     if (range.first != range.second) {
217         Category& log = Category::getInstance(SAML_LOGCAT".MetadataProvider");
218         if (strict) {
219             log.warn("ignored expired metadata group (%s)", range.first->first.c_str());
220         }
221         else {
222             log.info("no valid metadata found, returning expired metadata group (%s)", range.first->first.c_str());
223             return range.first->second;
224         }
225     }
226
227     return nullptr;
228 }
229
230 pair<const EntityDescriptor*,const RoleDescriptor*> AbstractMetadataProvider::getEntityDescriptor(const Criteria& criteria) const
231 {
232     pair<sitemap_t::const_iterator,sitemap_t::const_iterator> range;
233     if (criteria.entityID_ascii)
234         range = const_cast<const sitemap_t&>(m_sites).equal_range(criteria.entityID_ascii);
235     else if (criteria.entityID_unicode) {
236         auto_ptr_char id(criteria.entityID_unicode);
237         range = const_cast<const sitemap_t&>(m_sites).equal_range(id.get());
238     }
239     else if (criteria.artifact)
240         range = const_cast<const sitemap_t&>(m_sources).equal_range(criteria.artifact->getSource());
241     else
242         return pair<const EntityDescriptor*,const RoleDescriptor*>(nullptr,nullptr);
243     
244     pair<const EntityDescriptor*,const RoleDescriptor*> result;
245     result.first = nullptr;
246     result.second = nullptr;
247     
248     time_t now=time(nullptr);
249     for (sitemap_t::const_iterator i=range.first; i!=range.second; i++) {
250         if (now < i->second->getValidUntilEpoch()) {
251             result.first = i->second;
252             break;
253         }
254     }
255     
256     if (!result.first && range.first!=range.second) {
257         Category& log = Category::getInstance(SAML_LOGCAT".MetadataProvider");
258         if (criteria.validOnly) {
259             log.warn("ignored expired metadata instance for (%s)", range.first->first.c_str());
260         }
261         else {
262             log.info("no valid metadata found, returning expired instance for (%s)", range.first->first.c_str());
263             result.first = range.first->second;
264         }
265     }
266
267     if (result.first && criteria.role) {
268         result.second = result.first->getRoleDescriptor(*criteria.role, criteria.protocol);
269         if (!result.second && criteria.protocol2)
270             result.second = result.first->getRoleDescriptor(*criteria.role, criteria.protocol2);
271     }
272     
273     return result;
274 }
275
276 const Credential* AbstractMetadataProvider::resolve(const CredentialCriteria* criteria) const
277 {
278     const MetadataCredentialCriteria* metacrit = dynamic_cast<const MetadataCredentialCriteria*>(criteria);
279     if (!metacrit)
280         throw MetadataException("Cannot resolve credentials without a MetadataCredentialCriteria object.");
281
282     Lock lock(m_credentialLock);
283     const credmap_t::mapped_type& creds = resolveCredentials(metacrit->getRole());
284
285     for (credmap_t::mapped_type::const_iterator c = creds.begin(); c!=creds.end(); ++c)
286         if (metacrit->matches(*(*c)))
287             return *c;
288     return nullptr;
289 }
290
291 vector<const Credential*>::size_type AbstractMetadataProvider::resolve(
292     vector<const Credential*>& results, const CredentialCriteria* criteria
293     ) const
294 {
295     const MetadataCredentialCriteria* metacrit = dynamic_cast<const MetadataCredentialCriteria*>(criteria);
296     if (!metacrit)
297         throw MetadataException("Cannot resolve credentials without a MetadataCredentialCriteria object.");
298
299     Lock lock(m_credentialLock);
300     const credmap_t::mapped_type& creds = resolveCredentials(metacrit->getRole());
301
302     for (credmap_t::mapped_type::const_iterator c = creds.begin(); c!=creds.end(); ++c)
303         if (metacrit->matches(*(*c)))
304             results.push_back(*c);
305     return results.size();
306 }
307
308 const AbstractMetadataProvider::credmap_t::mapped_type& AbstractMetadataProvider::resolveCredentials(const RoleDescriptor& role) const
309 {
310     credmap_t::const_iterator i = m_credentialMap.find(&role);
311     if (i!=m_credentialMap.end())
312         return i->second;
313
314     const KeyInfoResolver* resolver = m_resolver ? m_resolver : XMLToolingConfig::getConfig().getKeyInfoResolver();
315     const vector<KeyDescriptor*>& keys = role.getKeyDescriptors();
316     AbstractMetadataProvider::credmap_t::mapped_type& resolved = m_credentialMap[&role];
317     for (vector<KeyDescriptor*>::const_iterator k = keys.begin(); k!=keys.end(); ++k) {
318         if ((*k)->getKeyInfo()) {
319             auto_ptr<MetadataCredentialContext> mcc(new MetadataCredentialContext(*(*k)));
320             Credential* c = resolver->resolve(mcc.get());
321             mcc.release();
322             resolved.push_back(c);
323         }
324     }
325     return resolved;
326 }