2b2661839ae95a4ce79e2e640ceef54b315fc82e
[shibboleth/cpp-opensaml.git] / saml / saml2 / metadata / impl / DiscoverableMetadataProvider.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * DiscoverableMetadataProvider.cpp
23  *
24  * A metadata provider that provides a JSON feed of IdP discovery information.
25  */
26
27 #include "internal.h"
28 #include "binding/SAMLArtifact.h"
29 #include "saml2/metadata/EntityMatcher.h"
30 #include "saml2/metadata/Metadata.h"
31 #include "saml2/metadata/DiscoverableMetadataProvider.h"
32
33 #include <fstream>
34 #include <sstream>
35 #include <boost/lambda/bind.hpp>
36 #include <boost/lambda/casts.hpp>
37 #include <boost/lambda/lambda.hpp>
38 #include <boost/iterator/indirect_iterator.hpp>
39 #include <xmltooling/logging.h>
40 #include <xmltooling/XMLToolingConfig.h>
41
42 using namespace opensaml::saml2;
43 using namespace opensaml::saml2md;
44 using namespace xmltooling::logging;
45 using namespace xmltooling;
46 using namespace boost::lambda;
47 using namespace boost;
48 using namespace std;
49
50 DiscoverableMetadataProvider::DiscoverableMetadataProvider(const DOMElement* e) : MetadataProvider(e), m_legacyOrgNames(false)
51 {
52     static const XMLCh legacyOrgNames[] =   UNICODE_LITERAL_14(l,e,g,a,c,y,O,r,g,N,a,m,e,s);
53     static const XMLCh matcher[] =          UNICODE_LITERAL_7(m,a,t,c,h,e,r);
54     static const XMLCh tagsInFeed[] =       UNICODE_LITERAL_10(t,a,g,s,I,n,F,e,e,d);
55     static const XMLCh _type[] =            UNICODE_LITERAL_4(t,y,p,e);
56     static const XMLCh DiscoveryFilter[] =  UNICODE_LITERAL_15(D,i,s,c,o,v,e,r,y,F,i,l,t,e,r);
57
58     m_legacyOrgNames = XMLHelper::getAttrBool(e, false, legacyOrgNames);
59     m_entityAttributes = XMLHelper::getAttrBool(e, false, tagsInFeed);
60
61     e = e ? XMLHelper::getFirstChildElement(e, DiscoveryFilter) : nullptr;
62     while (e) {
63         string t(XMLHelper::getAttrString(e, nullptr, _type));
64         if (t == "Whitelist" || t == "Blacklist") {
65             string m(XMLHelper::getAttrString(e, nullptr, matcher));
66             if (!m.empty()) {
67                 try {
68                     boost::shared_ptr<EntityMatcher> temp(SAMLConfig::getConfig().EntityMatcherManager.newPlugin(m, e));
69                     m_discoFilters.push_back(make_pair(t == "Whitelist", temp));
70                 }
71                 catch (std::exception& ex) {
72                     Category::getInstance(SAML_LOGCAT".MetadataProvider.Discoverable").error(
73                         "exception creating <DiscoveryFilter> EntityMatcher: %s", ex.what()
74                         );
75                 }
76             }
77             else {
78                 Category::getInstance(SAML_LOGCAT".MetadataProvider.Discoverable").error("<DiscoveryFilter> requires matcher attribute");
79             }
80         }
81         else {
82             Category::getInstance(SAML_LOGCAT".MetadataProvider.Discoverable").error(
83                 "unknown <DiscoveryFilter> type (%s)", t.empty() ? "none" : t.c_str()
84                 );
85         }
86         e = XMLHelper::getNextSiblingElement(e, DiscoveryFilter);
87     }
88 }
89
90 DiscoverableMetadataProvider::~DiscoverableMetadataProvider()
91 {
92 }
93
94 void DiscoverableMetadataProvider::generateFeed()
95 {
96     m_feed.erase();
97     bool first = true;
98     const XMLObject* object = getMetadata();
99     discoGroup(m_feed, dynamic_cast<const EntitiesDescriptor*>(object), first);
100     discoEntity(m_feed, dynamic_cast<const EntityDescriptor*>(object), first);
101
102     SAMLConfig::getConfig().generateRandomBytes(m_feedTag, 4);
103     m_feedTag = SAMLArtifact::toHex(m_feedTag);
104 }
105
106 string DiscoverableMetadataProvider::getCacheTag() const
107 {
108     return m_feedTag;
109 }
110
111 void DiscoverableMetadataProvider::outputFeed(ostream& os, bool& first, bool wrapArray) const
112 {
113     if (wrapArray)
114         os << '[';
115     if (!m_feed.empty()) {
116         if (first)
117             first = false;
118         else
119             os << ",\n";
120         os << m_feed;
121     }
122     if (wrapArray)
123         os << "\n]";
124 }
125
namespace {
    /**
     * Appends a C string to a buffer, escaping it for use inside a JSON
     * string literal (the surrounding quotes are NOT added).
     *
     * Backslash and double quote are backslash-escaped; \b, \t, \n, \f and \r
     * use their short escapes; any other byte below 0x20 is written as
     * \u00XX, since JSON forbids raw control characters inside strings.
     * All remaining bytes (including multi-byte UTF-8 sequences) are copied
     * through unchanged.
     *
     * @param s    buffer to append to
     * @param buf  NUL-terminated input; a null pointer appends nothing
     * @return     reference to s
     */
    std::string& json_safe(std::string& s, const char* buf)
    {
        static const char hexdigits[] = "0123456789abcdef";

        if (!buf)
            return s;

        for (; *buf; ++buf) {
            switch (*buf) {
                case '\\':
                case '"':
                    s += '\\';
                    s += *buf;
                    break;
                case '\b':
                    s += "\\b";
                    break;
                case '\t':
                    s += "\\t";
                    break;
                case '\n':
                    s += "\\n";
                    break;
                case '\f':
                    s += "\\f";
                    break;
                case '\r':
                    s += "\\r";
                    break;
                default: {
                    // Compare as unsigned so UTF-8 bytes >= 0x80 are never mistaken for control characters.
                    const unsigned char ch = static_cast<unsigned char>(*buf);
                    if (ch < 0x20) {
                        s += "\\u00";
                        s += hexdigits[ch >> 4];
                        s += hexdigits[ch & 0x0F];
                    }
                    else {
                        s += *buf;
                    }
                }
            }
        }
        return s;
    }
}
158
159 void DiscoverableMetadataProvider::discoEntity(string& s, const EntityDescriptor* entity, bool& first) const
160 {
161     time_t now = time(nullptr);
162     if (entity && entity->isValid(now)) {
163
164         // Check filter(s).
165         for (vector< pair < bool, boost::shared_ptr<EntityMatcher> > >::const_iterator f = m_discoFilters.begin(); f != m_discoFilters.end(); ++f) {
166             // The flag is true for a whitelist and false for a blacklist,
167             // so we omit the entity if the match outcome is the inverse.
168             if (f->first != f->second->matches(*entity))
169                 return;
170         }
171
172         const vector<IDPSSODescriptor*>& idps = entity->getIDPSSODescriptors();
173         if (!idps.empty()) {
174             auto_ptr_char entityid(entity->getEntityID());
175             // Open a struct and output id: entityID.
176             if (first)
177                 first = false;
178             else
179                 s += ',';
180             s += "\n{\n \"entityID\": \"";
181             json_safe(s, entityid.get());
182             s += '\"';
183             bool extFound = false;
184             bool displayNameFound = false;
185             for (indirect_iterator<vector<IDPSSODescriptor*>::const_iterator> idp = make_indirect_iterator(idps.begin());
186                     !extFound && idp != make_indirect_iterator(idps.end()); ++idp) {
187                 if (idp->isValid(now) && idp->getExtensions()) {
188                     const vector<XMLObject*>& exts =  const_cast<const Extensions*>(idp->getExtensions())->getUnknownXMLObjects();
189                     for (vector<XMLObject*>::const_iterator ext = exts.begin(); !extFound && ext != exts.end(); ++ext) {
190                         const UIInfo* info = dynamic_cast<UIInfo*>(*ext);
191                         if (info) {
192                             extFound = true;
193                             const vector<DisplayName*>& dispnames = info->getDisplayNames();
194                             if (!dispnames.empty()) {
195                                 displayNameFound = true;
196                                 s += ",\n \"DisplayNames\": [";
197                                 for (indirect_iterator<vector<DisplayName*>::const_iterator> dispname = make_indirect_iterator(dispnames.begin());
198                                         dispname != make_indirect_iterator(dispnames.end()); ++dispname) {
199                                     if (dispname.base() != dispnames.begin())
200                                         s += ',';
201                                     auto_arrayptr<char> val(toUTF8(dispname->getName()));
202                                     auto_ptr_char lang(dispname->getLang());
203                                     s += "\n  {\n  \"value\": \"";
204                                     json_safe(s, val.get());
205                                     s += "\",\n  \"lang\": \"";
206                                     s += lang.get();
207                                     s += "\"\n  }";
208                                 }
209                                 s += "\n ]";
210                             }
211
212                             const vector<Description*>& descs = info->getDescriptions();
213                             if (!descs.empty()) {
214                                 s += ",\n \"Descriptions\": [";
215                                 for (indirect_iterator<vector<Description*>::const_iterator> desc = make_indirect_iterator(descs.begin());
216                                         desc != make_indirect_iterator(descs.end()); ++desc) {
217                                     if (desc.base() != descs.begin())
218                                         s += ',';
219                                     auto_arrayptr<char> val(toUTF8(desc->getDescription()));
220                                     auto_ptr_char lang(desc->getLang());
221                                     s += "\n  {\n  \"value\": \"";
222                                     json_safe(s, val.get());
223                                     s += "\",\n  \"lang\": \"";
224                                     s += lang.get();
225                                     s += "\"\n  }";
226                                 }
227                                 s += "\n ]";
228                             }
229
230                             const vector<Keywords*>& keywords = info->getKeywordss();
231                             if (!keywords.empty()) {
232                                 s += ",\n \"Keywords\": [";
233                                 for (indirect_iterator<vector<Keywords*>::const_iterator> words = make_indirect_iterator(keywords.begin());
234                                         words != make_indirect_iterator(keywords.end()); ++words) {
235                                     if (words.base() != keywords.begin())
236                                         s += ',';
237                                     auto_arrayptr<char> val(toUTF8(words->getValues()));
238                                     auto_ptr_char lang(words->getLang());
239                                     s += "\n  {\n  \"value\": \"";
240                                     json_safe(s, val.get());
241                                     s += "\",\n  \"lang\": \"";
242                                     s += lang.get();
243                                     s += "\"\n  }";
244                                 }
245                                 s += "\n ]";
246                             }
247
248                             const vector<InformationURL*>& infurls = info->getInformationURLs();
249                             if (!infurls.empty()) {
250                                 s += ",\n \"InformationURLs\": [";
251                                 for (indirect_iterator<vector<InformationURL*>::const_iterator> infurl = make_indirect_iterator(infurls.begin());
252                                         infurl != make_indirect_iterator(infurls.end()); ++infurl) {
253                                     if (infurl.base() != infurls.begin())
254                                         s += ',';
255                                     auto_ptr_char val(infurl->getURL());
256                                     auto_ptr_char lang(infurl->getLang());
257                                     s += "\n  {\n  \"value\": \"";
258                                     json_safe(s, val.get());
259                                     s += "\",\n  \"lang\": \"";
260                                     s += lang.get();
261                                     s += "\"\n  }";
262                                 }
263                                 s += "\n ]";
264                             }
265
266                             const vector<PrivacyStatementURL*>& privs = info->getPrivacyStatementURLs();
267                             if (!privs.empty()) {
268                                 s += ",\n \"PrivacyStatementURLs\": [";
269                                 for (indirect_iterator<vector<PrivacyStatementURL*>::const_iterator> priv = make_indirect_iterator(privs.begin());
270                                         priv != make_indirect_iterator(privs.end()); ++priv) {
271                                     if (priv.base() != privs.begin())
272                                         s += ',';
273                                     auto_ptr_char val(priv->getURL());
274                                     auto_ptr_char lang(priv->getLang());
275                                     s += "\n  {\n  \"value\": \"";
276                                     json_safe(s, val.get());
277                                     s += "\",\n  \"lang\": \"";
278                                     s += lang.get();
279                                     s += "\"\n  }";
280                                 }
281                                 s += "\n ]";
282                             }
283
284                             const vector<Logo*>& logos = info->getLogos();
285                             if (!logos.empty()) {
286                                 s += ",\n \"Logos\": [";
287                                 for (indirect_iterator<vector<Logo*>::const_iterator> logo = make_indirect_iterator(logos.begin());
288                                         logo != make_indirect_iterator(logos.end()); ++logo) {
289                                     if (logo.base() != logos.begin())
290                                         s += ',';
291                                     s += "\n  {\n";
292                                     auto_ptr_char val(logo->getURL());
293                                     s += "  \"value\": \"";
294                                     json_safe(s, val.get());
295                                     ostringstream ht;
296                                     ht << logo->getHeight().second;
297                                     s += "\",\n  \"height\": \"";
298                                     s += ht.str();
299                                     ostringstream wt;
300                                     wt << logo->getWidth().second;
301                                     s += "\",\n  \"width\": \"";
302                                     s += wt.str();
303                                     s += '\"';
304                                     if (logo->getLang()) {
305                                         auto_ptr_char lang(logo->getLang());
306                                         s += ",\n  \"lang\": \"";
307                                         s += lang.get();
308                                         s += '\"';
309                                     }
310                                     s += "\n  }";
311                                 }
312                                 s += "\n ]";
313                             }
314                         }
315                     }
316                 }
317             }
318
319             if (m_legacyOrgNames && !displayNameFound) {
320                 const Organization* org = nullptr;
321                 for (indirect_iterator<vector<IDPSSODescriptor*>::const_iterator> idp = make_indirect_iterator(idps.begin());
322                         !org && idp != make_indirect_iterator(idps.end()); ++idp) {
323                     if (idp->isValid(now))
324                         org = idp->getOrganization();
325                 }
326                 if (!org)
327                     org = entity->getOrganization();
328                 if (org) {
329                     const vector<OrganizationDisplayName*>& odns = org->getOrganizationDisplayNames();
330                     if (!odns.empty()) {
331                         s += ",\n \"DisplayNames\": [";
332                         for (indirect_iterator<vector<OrganizationDisplayName*>::const_iterator> dispname = make_indirect_iterator(odns.begin());
333                                 dispname != make_indirect_iterator(odns.end()); ++dispname) {
334                             if (dispname.base() != odns.begin())
335                                 s += ',';
336                             auto_arrayptr<char> val(toUTF8(dispname->getName()));
337                             auto_ptr_char lang(dispname->getLang());
338                             s += "\n  {\n  \"value\": \"";
339                             json_safe(s, val.get());
340                             s += "\",\n  \"lang\": \"";
341                             s += lang.get();
342                             s += "\"\n  }";
343                         }
344                         s += "\n ]";
345                     }
346                 }
347             }
348
349             if (m_entityAttributes) {
350                 bool tagfirst = true;
351                 // Check for an EntityAttributes extension in the entity and its parent(s).
352                 const Extensions* exts = entity->getExtensions();
353                 if (exts) {
354                     const vector<XMLObject*>& children = exts->getUnknownXMLObjects();
355                     const XMLObject* xo = find_if(children, ll_dynamic_cast<EntityAttributes*>(_1) != ((EntityAttributes*)nullptr));
356                     if (xo)
357                         discoEntityAttributes(s, *dynamic_cast<const EntityAttributes*>(xo), tagfirst);
358                 }
359
360                 const EntitiesDescriptor* group = dynamic_cast<EntitiesDescriptor*>(entity->getParent());
361                 while (group) {
362                     exts = group->getExtensions();
363                     if (exts) {
364                         const vector<XMLObject*>& children = exts->getUnknownXMLObjects();
365                         const XMLObject* xo = find_if(children, ll_dynamic_cast<EntityAttributes*>(_1) != ((EntityAttributes*)nullptr));
366                         if (xo)
367                             discoEntityAttributes(s, *dynamic_cast<const EntityAttributes*>(xo), tagfirst);
368                     }
369                     group = dynamic_cast<EntitiesDescriptor*>(group->getParent());
370                 }
371                 if (!tagfirst)
372                     s += "\n ]";
373             }
374
375             // Close the struct;
376             s += "\n}";
377         }
378     }
379 }
380
381 void DiscoverableMetadataProvider::discoGroup(string& s, const EntitiesDescriptor* group, bool& first) const
382 {
383     if (group) {
384         for_each(
385             group->getEntitiesDescriptors().begin(), group->getEntitiesDescriptors().end(),
386             lambda::bind(&DiscoverableMetadataProvider::discoGroup, this, boost::ref(s), _1, boost::ref(first))
387             );
388         for_each(
389             group->getEntityDescriptors().begin(), group->getEntityDescriptors().end(),
390             lambda::bind(&DiscoverableMetadataProvider::discoEntity, this, boost::ref(s), _1, boost::ref(first))
391             );
392     }
393 }
394
395 void DiscoverableMetadataProvider::discoEntityAttributes(std::string& s, const EntityAttributes& ea, bool& first) const
396 {
397     discoAttributes(s, ea.getAttributes(), first);
398     const vector<saml2::Assertion*>& tokens = ea.getAssertions();
399     for (vector<saml2::Assertion*>::const_iterator t = tokens.begin(); t != tokens.end(); ++t) {
400         const vector<AttributeStatement*> statements = const_cast<const saml2::Assertion*>(*t)->getAttributeStatements();
401         for (vector<AttributeStatement*>::const_iterator st = statements.begin(); st != statements.end(); ++st) {
402             discoAttributes(s, const_cast<const AttributeStatement*>(*st)->getAttributes(), first);
403         }
404     }
405 }
406
407 void DiscoverableMetadataProvider::discoAttributes(std::string& s, const vector<Attribute*>& attrs, bool& first) const
408 {
409     for (indirect_iterator<vector<Attribute*>::const_iterator> a = make_indirect_iterator(attrs.begin());
410             a != make_indirect_iterator(attrs.end()); ++a) {
411
412         if (first) {
413             s += ",\n \"EntityAttributes\": [";
414             first = false;
415         }
416         else {
417             s += ',';
418         }
419
420         auto_ptr_char n(a->getName());
421         s += "\n  {\n  \"name\": \"";
422         json_safe(s, n.get());
423         s += "\",\n  \"values\": [";
424         const vector<XMLObject*>& vals = const_cast<const Attribute&>(*a).getAttributeValues();
425         for (indirect_iterator<vector<XMLObject*>::const_iterator> v = make_indirect_iterator(vals.begin());
426                 v != make_indirect_iterator(vals.end()); ++v) {
427             if (v.base() != vals.begin())
428                 s += ',';
429             auto_arrayptr<char> val(toUTF8(v->getTextContent()));
430             s += "\n     \"";
431             if (val.get())
432                 json_safe(s, val.get());
433             s += '\"';
434         }
435         s += "\n  ]\n  }";
436     }
437 }