2baeb6f4ae87a3ead59a4719d950cbbc144ed3a5
[shibboleth/cpp-sp.git] / plugins / TemplateAttributeResolver.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * TemplateAttributeResolver.cpp
23  * 
24  * AttributeResolver plugin for composing input values.
25  */
26
27 #include "internal.h"
28
29 #include <boost/bind.hpp>
30 #include <boost/algorithm/string.hpp>
31 #include <shibsp/exceptions.h>
32 #include <shibsp/SessionCache.h>
33 #include <shibsp/attribute/SimpleAttribute.h>
34 #include <shibsp/attribute/resolver/AttributeResolver.h>
35 #include <shibsp/attribute/resolver/ResolutionContext.h>
36 #include <xmltooling/XMLToolingConfig.h>
37 #include <xmltooling/util/Predicates.h>
38 #include <xmltooling/util/XMLHelper.h>
39 #include <xercesc/util/XMLUniDefs.hpp>
40
41 using namespace shibsp;
42 using namespace xmltooling;
43 using namespace xercesc;
44 using namespace boost;
45 using namespace std;
46
47 namespace shibsp {
48
49     class SHIBSP_DLLLOCAL TemplateContext : public ResolutionContext
50     {
51     public:
52         TemplateContext(const vector<Attribute*>* attributes) : m_inputAttributes(attributes) {
53         }
54
55         ~TemplateContext() {
56             for_each(m_attributes.begin(), m_attributes.end(), xmltooling::cleanup<Attribute>());
57         }
58
59         const vector<Attribute*>* getInputAttributes() const {
60             return m_inputAttributes;
61         }
62         vector<Attribute*>& getResolvedAttributes() {
63             return m_attributes;
64         }
65         vector<opensaml::Assertion*>& getResolvedAssertions() {
66             return m_assertions;
67         }
68
69     private:
70         const vector<Attribute*>* m_inputAttributes;
71         vector<Attribute*> m_attributes;
72         static vector<opensaml::Assertion*> m_assertions;   // empty dummy
73     };
74
75
76     class SHIBSP_DLLLOCAL TemplateAttributeResolver : public AttributeResolver
77     {
78     public:
79         TemplateAttributeResolver(const DOMElement* e);
80         virtual ~TemplateAttributeResolver() {}
81
82         Lockable* lock() {
83             return this;
84         }
85         void unlock() {
86         }
87
88         ResolutionContext* createResolutionContext(
89             const Application& application,
90             const opensaml::saml2md::EntityDescriptor* issuer,
91             const XMLCh* protocol,
92             const opensaml::saml2::NameID* nameid=nullptr,
93             const XMLCh* authncontext_class=nullptr,
94             const XMLCh* authncontext_decl=nullptr,
95             const vector<const opensaml::Assertion*>* tokens=nullptr,
96             const vector<Attribute*>* attributes=nullptr
97             ) const {
98             // Make sure new method gets run.
99             return createResolutionContext(application, nullptr, issuer, protocol, nameid, authncontext_class, authncontext_decl, tokens, attributes);
100         }
101
102         ResolutionContext* createResolutionContext(
103             const Application& application,
104             const GenericRequest* request,
105             const opensaml::saml2md::EntityDescriptor* issuer,
106             const XMLCh* protocol,
107             const opensaml::saml2::NameID* nameid=nullptr,
108             const XMLCh* authncontext_class=nullptr,
109             const XMLCh* authncontext_decl=nullptr,
110             const vector<const opensaml::Assertion*>* tokens=nullptr,
111             const vector<Attribute*>* attributes=nullptr
112             ) const {
113             return new TemplateContext(attributes);
114         }
115
116         ResolutionContext* createResolutionContext(const Application& application, const Session& session) const {
117             return new TemplateContext(&session.getAttributes());
118         }
119
120         void resolveAttributes(ResolutionContext& ctx) const;
121
122         void getAttributeIds(vector<string>& attributes) const {
123             attributes.push_back(m_dest.front());
124         }
125
126     private:
127         Category& m_log;
128         string m_template;
129         vector<string> m_sources,m_dest;
130     };
131
132     static const XMLCh dest[] =     UNICODE_LITERAL_4(d,e,s,t);
133     static const XMLCh _sources[] = UNICODE_LITERAL_7(s,o,u,r,c,e,s);
134     static const XMLCh Template[] = UNICODE_LITERAL_8(T,e,m,p,l,a,t,e);
135
136     AttributeResolver* SHIBSP_DLLLOCAL TemplateAttributeResolverFactory(const DOMElement* const & e)
137     {
138         return new TemplateAttributeResolver(e);
139     }
140
141 };
142
143 vector<opensaml::Assertion*> TemplateContext::m_assertions;
144
145 TemplateAttributeResolver::TemplateAttributeResolver(const DOMElement* e)
146     : m_log(Category::getInstance(SHIBSP_LOGCAT".AttributeResolver.Template")),
147         m_dest(1, XMLHelper::getAttrString(e, nullptr, dest))
148 {
149     if (m_dest.front().empty())
150         throw ConfigurationException("Template AttributeResolver requires dest attribute.");
151
152     string s(XMLHelper::getAttrString(e, nullptr, _sources));
153     trim(s);
154     split(m_sources, s, is_space(), algorithm::token_compress_on);
155     if (m_sources.empty())
156         throw ConfigurationException("Template AttributeResolver requires sources attribute.");
157
158     e = e ? XMLHelper::getFirstChildElement(e, Template) : nullptr;
159     auto_ptr_char t(e ? e->getTextContent() : nullptr);
160     if (t.get()) {
161         m_template = t.get();
162         trim(m_template);
163     }
164     if (m_template.empty())
165         throw ConfigurationException("Template AttributeResolver requires <Template> child element.");
166 }
167
168
169 void TemplateAttributeResolver::resolveAttributes(ResolutionContext& ctx) const
170 {
171     TemplateContext& tctx = dynamic_cast<TemplateContext&>(ctx);
172     if (!tctx.getInputAttributes())
173         return;
174
175     map<string,const Attribute*> attrmap;
176     for (vector<string>::const_iterator a = m_sources.begin(); a != m_sources.end(); ++a) {
177         static bool (*eq)(const string&, const char*) = operator==;
178         const Attribute* attr = find_if(*tctx.getInputAttributes(), boost::bind(eq, boost::cref(*a), boost::bind(&Attribute::getId, _1)));
179         if (!attr) {
180             m_log.warn("source attribute (%s) missing, cannot resolve attribute (%s)", a->c_str(), m_dest.front().c_str());
181             return;
182         }
183         else if (!attrmap.empty() && attr->valueCount() != attrmap.begin()->second->valueCount()) {
184             m_log.warn("all source attributes must contain equal number of values, cannot resolve attribute (%s)", m_dest.front().c_str());
185             return;
186         }
187         attrmap[*a] = attr;
188     }
189
190     auto_ptr<SimpleAttribute> dest(new SimpleAttribute(m_dest));
191
192     for (size_t ix = 0; ix < attrmap.begin()->second->valueCount(); ++ix) {
193         static const char* legal="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890_-.[]";
194
195         dest->getValues().push_back(string());
196         string& processed = dest->getValues().back();
197
198         string::size_type i=0, start=0;
199         while (start != string::npos && start < m_template.length() && (i = m_template.find("$", start)) != string::npos) {
200             if (i > start)
201                 processed += m_template.substr(start, i - start);   // append everything in between
202             start = i + 1;                                          // move start to the beginning of the token name
203             i = m_template.find_first_not_of(legal, start);         // find token delimiter
204             if (i == start) {                                       // append a non legal character
205                 processed += m_template[start++];
206                 continue;
207             }
208                     
209             map<string,const Attribute*>::const_iterator iter = attrmap.find(m_template.substr(start, (i==string::npos) ? i : i - start));
210             if (iter != attrmap.end())
211                 processed += iter->second->getSerializedValues()[ix];
212             start = i;
213         }
214         if (start != string::npos && start < m_template.length())
215             processed += m_template.substr(start, i);    // append rest of string
216
217         trim(processed);
218         if (processed.empty())
219             dest->getValues().pop_back();
220     }
221
222     // Save off new object.
223     if (dest.get() && dest->valueCount()) {
224         ctx.getResolvedAttributes().push_back(dest.get());
225         dest.release();
226     }
227 }