0ba3c0e2ccedd21c29d4b293d7a2e6c96a62d92e
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / ChainingMetadataProvider.cpp
1 /*
2  *  Copyright 2001-2010 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * ChainingMetadataProvider.cpp
19  * 
20  * MetadataProvider that uses multiple providers in sequence.
21  */
22
23 #include "internal.h"
24 #include "exceptions.h"
25 #include "saml/binding/SAMLArtifact.h"
26 #include "saml2/metadata/Metadata.h"
27 #include "saml2/metadata/DiscoverableMetadataProvider.h"
28 #include "saml2/metadata/ObservableMetadataProvider.h"
29 #include "saml2/metadata/MetadataCredentialCriteria.h"
30
31 #include <memory>
32 #include <xercesc/util/XMLUniDefs.hpp>
33 #include <xmltooling/logging.h>
34 #include <xmltooling/util/Threads.h>
35 #include <xmltooling/util/XMLHelper.h>
36
37
38 using namespace opensaml::saml2md;
39 using namespace opensaml;
40 using namespace xmlsignature;
41 using namespace xmltooling::logging;
42 using namespace xmltooling;
43 using namespace std;
44
45 namespace opensaml {
46     namespace saml2md {
47
48         // per-thread structure allocated to track locks and role->provider mappings
49         struct SAML_DLLLOCAL tracker_t;
50         
51         class SAML_DLLLOCAL ChainingMetadataProvider
52             : public DiscoverableMetadataProvider, public ObservableMetadataProvider, public ObservableMetadataProvider::Observer {
53         public:
54             ChainingMetadataProvider(const xercesc::DOMElement* e=nullptr);
55             virtual ~ChainingMetadataProvider();
56     
57             using MetadataProvider::getEntityDescriptor;
58             using MetadataProvider::getEntitiesDescriptor;
59
60             Lockable* lock();
61             void unlock();
62             void init();
63             const XMLObject* getMetadata() const;
64             const EntitiesDescriptor* getEntitiesDescriptor(const char* name, bool requireValidMetadata=true) const;
65             pair<const EntityDescriptor*,const RoleDescriptor*> getEntityDescriptor(const Criteria& criteria) const;
66     
67             const Credential* resolve(const CredentialCriteria* criteria=nullptr) const;
68             vector<const Credential*>::size_type resolve(vector<const Credential*>& results, const CredentialCriteria* criteria=nullptr) const;
69
70             string getCacheTag() const {
71                 Lock lock(m_trackerLock);
72                 return m_feedTag;
73             }
74
75             void outputFeed(ostream& os, bool& first, bool wrapArray=true) const {
76                 if (wrapArray)
77                     os << '[';
78                 // Lock each provider in turn and suck in its feed.
79                 for (vector<MetadataProvider*>::const_iterator m = m_providers.begin(); m != m_providers.end(); ++m) {
80                     DiscoverableMetadataProvider* d = dynamic_cast<DiscoverableMetadataProvider*>(*m);
81                     if (d) {
82                         Locker locker(d);
83                         d->outputFeed(os, first, false);
84                     }
85                 }
86                 if (wrapArray)
87                     os << "\n]";
88             }
89
90             void onEvent(const ObservableMetadataProvider& provider) const {
91                 // Reset the cache tag for the feed.
92                 Lock lock(m_trackerLock);
93                 SAMLConfig::getConfig().generateRandomBytes(m_feedTag, 4);
94                 m_feedTag = SAMLArtifact::toHex(m_feedTag);
95                 emitChangeEvent();
96             }
97
98         protected:
99             void generateFeed() {
100                 // No-op.
101             }
102
103         private:
104             bool m_firstMatch;
105             mutable Mutex* m_trackerLock;
106             ThreadKey* m_tlsKey;
107             vector<MetadataProvider*> m_providers;
108             mutable set<tracker_t*> m_trackers;
109             static void tracker_cleanup(void*);
110             Category& m_log;
111             friend struct tracker_t;
112         };
113
114         struct SAML_DLLLOCAL tracker_t {
115             tracker_t(const ChainingMetadataProvider* m) : m_metadata(m) {
116                 Lock lock(m_metadata->m_trackerLock);
117                 m_metadata->m_trackers.insert(this);
118             }
119
120             void lock_if(MetadataProvider* m) {
121                 if (m_locked.count(m) == 0)
122                     m->lock();
123             }
124
125             void unlock_if(MetadataProvider* m) {
126                 if (m_locked.count(m) == 0)
127                     m->unlock();
128             }
129
130             void remember(MetadataProvider* m, const EntityDescriptor* entity=nullptr) {
131                 m_locked.insert(m);
132                 if (entity)
133                     m_objectMap.insert(pair<const XMLObject*,const MetadataProvider*>(entity,m));
134             }
135
136             const MetadataProvider* getProvider(const RoleDescriptor& role) {
137                 map<const XMLObject*,const MetadataProvider*>::const_iterator i = m_objectMap.find(role.getParent());
138                 return (i != m_objectMap.end()) ? i->second : nullptr;
139             }
140
141             const ChainingMetadataProvider* m_metadata;
142             set<MetadataProvider*> m_locked;
143             map<const XMLObject*,const MetadataProvider*> m_objectMap;
144         };
145
146         MetadataProvider* SAML_DLLLOCAL ChainingMetadataProviderFactory(const DOMElement* const & e)
147         {
148             return new ChainingMetadataProvider(e);
149         }
150
151         static const XMLCh _MetadataProvider[] =    UNICODE_LITERAL_16(M,e,t,a,d,a,t,a,P,r,o,v,i,d,e,r);
152         static const XMLCh precedence[] =           UNICODE_LITERAL_10(p,r,e,c,e,d,e,n,c,e);
153         static const XMLCh last[] =                 UNICODE_LITERAL_4(l,a,s,t);
154         static const XMLCh _type[] =                 UNICODE_LITERAL_4(t,y,p,e);
155     };
156 };
157
158 void ChainingMetadataProvider::tracker_cleanup(void* ptr)
159 {
160     if (ptr) {
161         // free the tracker after removing it from the parent plugin's tracker set
162         tracker_t* t = reinterpret_cast<tracker_t*>(ptr);
163         Lock lock(t->m_metadata->m_trackerLock);
164         t->m_metadata->m_trackers.erase(t);
165         delete t;
166     }
167 }
168
169 ChainingMetadataProvider::ChainingMetadataProvider(const DOMElement* e)
170     : ObservableMetadataProvider(e), m_firstMatch(true), m_trackerLock(nullptr), m_tlsKey(nullptr),
171         m_log(Category::getInstance(SAML_LOGCAT".Metadata.Chaining"))
172 {
173     if (XMLString::equals(e ? e->getAttributeNS(nullptr, precedence) : nullptr, last))
174         m_firstMatch = false;
175
176     e = XMLHelper::getFirstChildElement(e, _MetadataProvider);
177     while (e) {
178         string t = XMLHelper::getAttrString(e, nullptr, _type);
179         if (!t.empty()) {
180             try {
181                 m_log.info("building MetadataProvider of type %s", t.c_str());
182                 auto_ptr<MetadataProvider> provider(SAMLConfig::getConfig().MetadataProviderManager.newPlugin(t.c_str(), e));
183                 ObservableMetadataProvider* obs = dynamic_cast<ObservableMetadataProvider*>(provider.get());
184                 if (obs)
185                     obs->addObserver(this);
186                 m_providers.push_back(provider.get());
187                 provider.release();
188             }
189             catch (exception& ex) {
190                 m_log.error("error building MetadataProvider: %s", ex.what());
191             }
192         }
193         e = XMLHelper::getNextSiblingElement(e, _MetadataProvider);
194     }
195     m_trackerLock = Mutex::create();
196     m_tlsKey = ThreadKey::create(tracker_cleanup);
197 }
198
199 ChainingMetadataProvider::~ChainingMetadataProvider()
200 {
201     delete m_tlsKey;
202     delete m_trackerLock;
203     for_each(m_trackers.begin(), m_trackers.end(), xmltooling::cleanup<tracker_t>());
204     for_each(m_providers.begin(), m_providers.end(), xmltooling::cleanup<MetadataProvider>());
205 }
206
207 void ChainingMetadataProvider::init()
208 {
209     for (vector<MetadataProvider*>::const_iterator i=m_providers.begin(); i!=m_providers.end(); ++i) {
210         try {
211             (*i)->init();
212         }
213         catch (exception& ex) {
214             m_log.crit("failure initializing MetadataProvider: %s", ex.what());
215         }
216     }
217
218     // Set an initial cache tag for the state of the plugins.
219     SAMLConfig::getConfig().generateRandomBytes(m_feedTag, 4);
220     m_feedTag = SAMLArtifact::toHex(m_feedTag);
221 }
222
223 Lockable* ChainingMetadataProvider::lock()
224 {
225     return this;   // we're not lockable ourselves...
226 }
227
228 void ChainingMetadataProvider::unlock()
229 {
230     // Check for locked providers and remove role mappings.
231     void* ptr=m_tlsKey->getData();
232     if (ptr) {
233         tracker_t* t = reinterpret_cast<tracker_t*>(ptr);
234         for_each(t->m_locked.begin(), t->m_locked.end(), mem_fun<void,Lockable>(&Lockable::unlock));
235         t->m_locked.clear();
236         t->m_objectMap.clear();
237     }
238 }
239
240 const XMLObject* ChainingMetadataProvider::getMetadata() const
241 {
242     throw MetadataException("getMetadata operation not implemented on this provider.");
243 }
244
245 const EntitiesDescriptor* ChainingMetadataProvider::getEntitiesDescriptor(const char* name, bool requireValidMetadata) const
246 {
247     // Ensure we have a tracker to use.
248     tracker_t* tracker = nullptr;
249     void* ptr=m_tlsKey->getData();
250     if (ptr) {
251         tracker = reinterpret_cast<tracker_t*>(ptr);
252     }
253     else {
254         tracker = new tracker_t(this);
255         m_tlsKey->setData(tracker);
256     }
257
258     MetadataProvider* held = nullptr;
259     const EntitiesDescriptor* ret = nullptr;
260     const EntitiesDescriptor* cur = nullptr;
261     for (vector<MetadataProvider*>::const_iterator i=m_providers.begin(); i!=m_providers.end(); ++i) {
262         tracker->lock_if(*i);
263         if (cur=(*i)->getEntitiesDescriptor(name,requireValidMetadata)) {
264             // Are we using a first match policy?
265             if (m_firstMatch) {
266                 // Save locked provider.
267                 tracker->remember(*i);
268                 return cur;
269             }
270
271             // Using last match wins. Did we already have one?
272             if (held) {
273                 m_log.warn("found duplicate EntitiesDescriptor (%s), using last matching copy", name);
274                 tracker->unlock_if(held);
275             }
276
277             // Save off the latest match.
278             held = *i;
279             ret = cur;
280         }
281         else {
282             // No match, so just unlock this one and move on.
283             tracker->unlock_if(*i);
284         }
285     }
286
287     // Preserve any lock we're holding.
288     if (held)
289         tracker->remember(held);
290     return ret;
291 }
292
293 pair<const EntityDescriptor*,const RoleDescriptor*> ChainingMetadataProvider::getEntityDescriptor(const Criteria& criteria) const
294 {
295     // Ensure we have a tracker to use.
296     tracker_t* tracker = nullptr;
297     void* ptr=m_tlsKey->getData();
298     if (ptr) {
299         tracker = reinterpret_cast<tracker_t*>(ptr);
300     }
301     else {
302         tracker = new tracker_t(this);
303         m_tlsKey->setData(tracker);
304     }
305
306     // Do a search.
307     MetadataProvider* held = nullptr;
308     pair<const EntityDescriptor*,const RoleDescriptor*> ret = pair<const EntityDescriptor*,const RoleDescriptor*>(nullptr,nullptr);
309     pair<const EntityDescriptor*,const RoleDescriptor*> cur = ret;
310     for (vector<MetadataProvider*>::const_iterator i=m_providers.begin(); i!=m_providers.end(); ++i) {
311         tracker->lock_if(*i);
312         cur = (*i)->getEntityDescriptor(criteria);
313         if (cur.first) {
314             if (criteria.role) {
315                 // We want a role also. Did we find one?
316                 if (cur.second) {
317                     // Are we using a first match policy?
318                     if (m_firstMatch) {
319                         // We could have an entity-only match from earlier, so unlock it.
320                         if (held)
321                             tracker->unlock_if(held);
322                         // Save locked provider and role mapping.
323                         tracker->remember(*i, cur.first);
324                         return cur;
325                     }
326
327                     // Using last match wins. Did we already have one?
328                     if (held) {
329                         if (ret.second) {
330                             // We had a "complete" match, so log it.
331                             if (criteria.entityID_ascii) {
332                                 m_log.warn("found duplicate EntityDescriptor (%s) with role (%s), using last matching copy",
333                                     criteria.entityID_ascii, criteria.role->toString().c_str());
334                             }
335                             else if (criteria.entityID_unicode) {
336                                 auto_ptr_char temp(criteria.entityID_unicode);
337                                 m_log.warn("found duplicate EntityDescriptor (%s) with role (%s), using last matching copy",
338                                     temp.get(), criteria.role->toString().c_str());
339                             }
340                             else if (criteria.artifact) {
341                                 m_log.warn("found duplicate EntityDescriptor for artifact source (%s) with role (%s), using last matching copy",
342                                     criteria.artifact->getSource().c_str(), criteria.role->toString().c_str());
343                             }
344                         }
345                         tracker->unlock_if(held);
346                     }
347
348                     // Save off the latest match.
349                     held = *i;
350                     ret = cur;
351                 }
352                 else {
353                     // We didn't find the role, so we're going to keep looking,
354                     // but save this one if we didn't have the role yet.
355                     if (ret.second) {
356                         // We already had a role, so let's stick with that.
357                         tracker->unlock_if(*i);
358                     }
359                     else {
360                         // This is at least as good, so toss anything we had and keep it.
361                         if (held)
362                             tracker->unlock_if(held);
363                         held = *i;
364                         ret = cur;
365                     }
366                 }
367             }
368             else {
369                 // Are we using a first match policy?
370                 if (m_firstMatch) {
371                     // I don't think this can happen, but who cares, check anyway.
372                     if (held)
373                         tracker->unlock_if(held);
374                     
375                     // Save locked provider.
376                     tracker->remember(*i, cur.first);
377                     return cur;
378                 }
379
380                 // Using last match wins. Did we already have one?
381                 if (held) {
382                     if (criteria.entityID_ascii) {
383                         m_log.warn("found duplicate EntityDescriptor (%s), using last matching copy", criteria.entityID_ascii);
384                     }
385                     else if (criteria.entityID_unicode) {
386                         auto_ptr_char temp(criteria.entityID_unicode);
387                         m_log.warn("found duplicate EntityDescriptor (%s), using last matching copy", temp.get());
388                     }
389                     else if (criteria.artifact) {
390                         m_log.warn("found duplicate EntityDescriptor for artifact source (%s), using last matching copy",
391                             criteria.artifact->getSource().c_str());
392                     }
393                     tracker->unlock_if(held);
394                 }
395
396                 // Save off the latest match.
397                 held = *i;
398                 ret = cur;
399             }
400         }
401         else {
402             // No match, so just unlock this one and move on.
403             tracker->unlock_if(*i);
404         }
405     }
406
407     // Preserve any lock we're holding.
408     if (held)
409         tracker->remember(held, ret.first);
410     return ret;
411 }
412
413 const Credential* ChainingMetadataProvider::resolve(const CredentialCriteria* criteria) const
414 {
415     void* ptr=m_tlsKey->getData();
416     if (!ptr)
417         throw MetadataException("No locked MetadataProvider, where did the role object come from?");
418     tracker_t* tracker=reinterpret_cast<tracker_t*>(ptr);
419
420     const MetadataCredentialCriteria* mcc = dynamic_cast<const MetadataCredentialCriteria*>(criteria);
421     if (!mcc)
422         throw MetadataException("Cannot resolve credentials without a MetadataCredentialCriteria object.");
423     const MetadataProvider* m = tracker->getProvider(mcc->getRole());
424     if (!m)
425         throw MetadataException("No record of corresponding MetadataProvider, where did the role object come from?");
426     return m->resolve(mcc);
427 }
428
429 vector<const Credential*>::size_type ChainingMetadataProvider::resolve(
430     vector<const Credential*>& results, const CredentialCriteria* criteria
431     ) const
432 {
433     void* ptr=m_tlsKey->getData();
434     if (!ptr)
435         throw MetadataException("No locked MetadataProvider, where did the role object come from?");
436     tracker_t* tracker=reinterpret_cast<tracker_t*>(ptr);
437
438     const MetadataCredentialCriteria* mcc = dynamic_cast<const MetadataCredentialCriteria*>(criteria);
439     if (!mcc)
440         throw MetadataException("Cannot resolve credentials without a MetadataCredentialCriteria object.");
441     const MetadataProvider* m = tracker->getProvider(mcc->getRole());
442     if (!m)
443         throw MetadataException("No record of corresponding MetadataProvider, where did the role object come from?");
444     return m->resolve(results, mcc);
445 }