21a0246224dd991b1cf0d470b6f78d7343cca299
[shibboleth/cpp-sp.git] / plugins / TransformAttributeResolver.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * TransformAttributeResolver.cpp
23  * 
24  * Attribute Resolver plugin for transforming input values.
25  */
26
#include "internal.h"

#include <algorithm>
#include <boost/shared_ptr.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <boost/tuple/tuple.hpp>
#include <shibsp/exceptions.h>
#include <shibsp/SessionCache.h>
#include <shibsp/attribute/SimpleAttribute.h>
#include <shibsp/attribute/resolver/AttributeResolver.h>
#include <shibsp/attribute/resolver/ResolutionContext.h>
#include <xmltooling/XMLToolingConfig.h>
#include <xmltooling/util/XMLHelper.h>
#include <xercesc/util/Janitor.hpp>
#include <xercesc/util/PlatformUtils.hpp>
#include <xercesc/util/XMLUniDefs.hpp>
#include <xercesc/util/regx/RegularExpression.hpp>

using namespace shibsp;
using namespace xmltooling;
using namespace xercesc;
using namespace boost;
using namespace std;
48
49 namespace shibsp {
50
51     class SHIBSP_DLLLOCAL TransformContext : public ResolutionContext
52     {
53     public:
54         TransformContext(const vector<Attribute*>* attributes) : m_inputAttributes(attributes) {
55         }
56
57         ~TransformContext() {
58             for_each(m_attributes.begin(), m_attributes.end(), xmltooling::cleanup<Attribute>());
59         }
60
61         const vector<Attribute*>* getInputAttributes() const {
62             return m_inputAttributes;
63         }
64         vector<Attribute*>& getResolvedAttributes() {
65             return m_attributes;
66         }
67         vector<opensaml::Assertion*>& getResolvedAssertions() {
68             return m_assertions;
69         }
70
71     private:
72         const vector<Attribute*>* m_inputAttributes;
73         vector<Attribute*> m_attributes;
74         static vector<opensaml::Assertion*> m_assertions;   // empty dummy
75     };
76
77
78     class SHIBSP_DLLLOCAL TransformAttributeResolver : public AttributeResolver
79     {
80     public:
81         TransformAttributeResolver(const DOMElement* e);
82         virtual ~TransformAttributeResolver() {}
83
84         Lockable* lock() {
85             return this;
86         }
87         void unlock() {
88         }
89
90         ResolutionContext* createResolutionContext(
91             const Application& application,
92             const opensaml::saml2md::EntityDescriptor* issuer,
93             const XMLCh* protocol,
94             const opensaml::saml2::NameID* nameid=nullptr,
95             const XMLCh* authncontext_class=nullptr,
96             const XMLCh* authncontext_decl=nullptr,
97             const vector<const opensaml::Assertion*>* tokens=nullptr,
98             const vector<Attribute*>* attributes=nullptr
99             ) const {
100             // Make sure new method gets run.
101             return createResolutionContext(application, nullptr, issuer, protocol, nameid, authncontext_class, authncontext_decl, tokens, attributes);
102         }
103
104         ResolutionContext* createResolutionContext(
105             const Application& application,
106             const GenericRequest* request,
107             const opensaml::saml2md::EntityDescriptor* issuer,
108             const XMLCh* protocol,
109             const opensaml::saml2::NameID* nameid=nullptr,
110             const XMLCh* authncontext_class=nullptr,
111             const XMLCh* authncontext_decl=nullptr,
112             const vector<const opensaml::Assertion*>* tokens=nullptr,
113             const vector<Attribute*>* attributes=nullptr
114             ) const {
115             return new TransformContext(attributes);
116         }
117
118         ResolutionContext* createResolutionContext(const Application& application, const Session& session) const {
119             return new TransformContext(&session.getAttributes());
120         }
121
122         void resolveAttributes(ResolutionContext& ctx) const;
123
124         void getAttributeIds(vector<string>& attributes) const {
125             for (vector<regex_t>::const_iterator r = m_regex.begin(); r != m_regex.end(); ++r) {
126                 if (!r->get<0>().empty())
127                     attributes.push_back(r->get<0>());
128             }
129         }
130
131     private:
132         Category& m_log;
133         string m_source;
134         // dest id, regex to apply, replacement string
135         typedef tuple<string,boost::shared_ptr<RegularExpression>,const XMLCh*> regex_t;
136         vector<regex_t> m_regex;
137     };
138
139     static const XMLCh dest[] =             UNICODE_LITERAL_4(d,e,s,t);
140     static const XMLCh match[] =            UNICODE_LITERAL_5(m,a,t,c,h);
141     static const XMLCh caseSensitive[] =    UNICODE_LITERAL_13(c,a,s,e,S,e,n,s,i,t,i,v,e);
142     static const XMLCh source[] =           UNICODE_LITERAL_6(s,o,u,r,c,e);
143     static const XMLCh Regex[] =            UNICODE_LITERAL_5(R,e,g,e,x);
144
145     AttributeResolver* SHIBSP_DLLLOCAL TransformAttributeResolverFactory(const DOMElement* const & e)
146     {
147         return new TransformAttributeResolver(e);
148     }
149
150 };
151
152 vector<opensaml::Assertion*> TransformContext::m_assertions;
153
154 TransformAttributeResolver::TransformAttributeResolver(const DOMElement* e)
155     : m_log(Category::getInstance(SHIBSP_LOGCAT ".AttributeResolver.Transform")),
156         m_source(XMLHelper::getAttrString(e, nullptr, source))
157 {
158     if (m_source.empty())
159         throw ConfigurationException("Transform AttributeResolver requires source attribute.");
160
161     e = XMLHelper::getFirstChildElement(e, Regex);
162     while (e) {
163         if (e->hasChildNodes() && e->hasAttributeNS(nullptr, match)) {
164             const XMLCh* repl(e->getTextContent());
165             string destId(XMLHelper::getAttrString(e, nullptr, dest));
166             bool caseflag(XMLHelper::getAttrBool(e, true, caseSensitive));
167             if (repl && *repl) {
168                 try {
169                     static XMLCh options[] = { chLatin_i, chNull };
170                     boost::shared_ptr<RegularExpression> re(new RegularExpression(e->getAttributeNS(nullptr, match), (caseflag ? &chNull : options)));
171                     m_regex.push_back(make_tuple(destId, re, repl));
172                 }
173                 catch (XMLException& ex) {
174                     auto_ptr_char msg(ex.getMessage());
175                     auto_ptr_char m(e->getAttributeNS(nullptr, match));
176                     m_log.error("exception parsing regular expression (%s): %s", m.get(), msg.get());
177                 }
178             }
179         }
180         e = XMLHelper::getNextSiblingElement(e, Regex);
181     }
182
183     if (m_regex.empty())
184         throw ConfigurationException("Transform AttributeResolver requires at least one Regex element.");
185 }
186
187
188 void TransformAttributeResolver::resolveAttributes(ResolutionContext& ctx) const
189 {
190     TransformContext& tctx = dynamic_cast<TransformContext&>(ctx);
191     if (!tctx.getInputAttributes())
192         return;
193
194     for (vector<Attribute*>::const_iterator a = tctx.getInputAttributes()->begin(); a != tctx.getInputAttributes()->end(); ++a) {
195         if (m_source != (*a)->getId() || (*a)->valueCount() == 0) {
196             continue;
197         }
198
199         // We run each transform expression against each value of the input. Each transform either generates
200         // a new attribute from its dest property, or overwrites a SimpleAttribute's values in place.
201
202         for (vector<regex_t>::const_iterator r = m_regex.begin(); r != m_regex.end(); ++r) {
203             SimpleAttribute* dest = nullptr;
204             auto_ptr<SimpleAttribute> destwrapper;
205
206             // First tuple element is the destination attribute ID, if any.
207             if (r->get<0>().empty()) {
208                 // Can we transform in-place?
209                 dest = dynamic_cast<SimpleAttribute*>(*a);
210                 if (!dest) {
211                     m_log.warn("can't transform non-simple attribute (%s) 'in place'", m_source.c_str());
212                     continue;
213                 }
214             }
215             else {
216                 // Create a destination attribute.
217                 vector<string> ids(1, r->get<0>());
218                 destwrapper.reset(new SimpleAttribute(ids));
219             }
220
221             if (dest)
222                 m_log.debug("applying in-place transform to source attribute (%s)", m_source.c_str());
223             else
224                 m_log.debug("applying transform from source attribute (%s) to dest attribute (%s)", m_source.c_str(), r->get<0>().c_str());
225
226             for (size_t i = 0; i < (*a)->valueCount(); ++i) {
227                 try {
228                     auto_arrayptr<XMLCh> srcval(fromUTF8((*a)->getSerializedValues()[i].c_str()));
229                     XMLCh* destval = r->get<1>()->replace(srcval.get(), r->get<2>());
230                     if (!destval)
231                         continue;
232                     // For some reason, it returns the source string if the match doesn't succeed.
233                     if (!XMLString::equals(destval, srcval.get())) {
234                         auto_arrayptr<char> narrow(toUTF8(destval));
235                         XMLString::release(&destval);
236                         if (dest) {
237                             // Modify in place.
238                             dest->getValues()[i] = narrow.get();
239                             trim(dest->getValues()[i]);
240                         }
241                         else {
242                             // Add to new object.
243                             destwrapper->getValues().push_back(narrow.get());
244                             trim(destwrapper->getValues().back());
245                         }
246                     }
247                     else {
248                         XMLString::release(&destval);
249                     }
250                 }
251                 catch (XMLException& ex) {
252                     auto_ptr_char msg(ex.getMessage());
253                     m_log.error("caught error applying regular expression: %s", msg.get());
254                 }
255             }
256
257             // Save off new object.
258             if (destwrapper.get()) {
259                 ctx.getResolvedAttributes().push_back(destwrapper.get());
260                 destwrapper.release();
261             }
262         }
263     }
264 }