SSPCPP-356 - Better support message-level security on the back channel
[shibboleth/cpp-sp.git] / shibsp / attribute / resolver / impl / SimpleAggregationAttributeResolver.cpp
index 7e2d1d0..e4068d7 100644 (file)
@@ -75,6 +75,7 @@ namespace shibsp {
     public:
         SimpleAggregationContext(const Application& application, const Session& session)
             : m_app(application),
+              m_request(nullptr),
               m_session(&session),
               m_nameid(nullptr),
               m_class(session.getAuthnContextClassRef()),
@@ -85,6 +86,7 @@ namespace shibsp {
 
         SimpleAggregationContext(
             const Application& application,
+            const GenericRequest* request=nullptr,
             const NameID* nameid=nullptr,
             const XMLCh* entityID=nullptr,
             const XMLCh* authncontext_class=nullptr,
@@ -92,6 +94,7 @@ namespace shibsp {
             const vector<const opensaml::Assertion*>* tokens=nullptr,
             const vector<shibsp::Attribute*>* attributes=nullptr
             ) : m_app(application),
+                m_request(request),
                 m_session(nullptr),
                 m_nameid(nameid),
                 m_entityid(entityID),
@@ -109,6 +112,9 @@ namespace shibsp {
         const Application& getApplication() const {
             return m_app;
         }
+        const GenericRequest* getRequest() const {
+            return m_request;
+        }
         const char* getEntityID() const {
             return m_session ? m_session->getEntityID() : m_entityid.get();
         }
@@ -139,6 +145,7 @@ namespace shibsp {
 
     private:
         const Application& m_app;
+        const GenericRequest* m_request;
         const Session* m_session;
         const NameID* m_nameid;
         auto_ptr_char m_entityid;
@@ -159,8 +166,23 @@ namespace shibsp {
         Lockable* lock() {return this;}
         void unlock() {}
 
+        // deprecated method
+        ResolutionContext* createResolutionContext(
+            const Application& application,
+            const EntityDescriptor* issuer,
+            const XMLCh* protocol,
+            const NameID* nameid=nullptr,
+            const XMLCh* authncontext_class=nullptr,
+            const XMLCh* authncontext_decl=nullptr,
+            const vector<const opensaml::Assertion*>* tokens=nullptr,
+            const vector<shibsp::Attribute*>* attributes=nullptr
+            ) const {
+            return createResolutionContext(application, nullptr, issuer, protocol, nameid, authncontext_class, authncontext_decl, tokens, attributes);
+        }
+
         ResolutionContext* createResolutionContext(
             const Application& application,
+            const GenericRequest* request,
             const EntityDescriptor* issuer,
             const XMLCh* protocol,
             const NameID* nameid=nullptr,
@@ -170,7 +192,7 @@ namespace shibsp {
             const vector<shibsp::Attribute*>* attributes=nullptr
             ) const {
             return new SimpleAggregationContext(
-                application, nameid, (issuer ? issuer->getEntityID() : nullptr), authncontext_class, authncontext_decl, tokens, attributes
+                application, request, nameid, (issuer ? issuer->getEntityID() : nullptr), authncontext_class, authncontext_decl, tokens, attributes
                 );
         }
 
@@ -181,7 +203,8 @@ namespace shibsp {
         void resolveAttributes(ResolutionContext& ctx) const;
 
         void getAttributeIds(vector<string>& attributes) const {
-            // Nothing to do, only the extractor would actually generate them.
+            if (m_extractor)
+                m_extractor->getAttributeIds(attributes);
         }
 
     private:
@@ -194,6 +217,8 @@ namespace shibsp {
         xstring m_format;
         scoped_ptr<MetadataProvider> m_metadata;
         scoped_ptr<TrustEngine> m_trust;
+        scoped_ptr<AttributeExtractor> m_extractor;
+        scoped_ptr<AttributeFilter> m_filter;
         ptr_vector<saml2::Attribute> m_designators;
         vector< pair<string,bool> > m_sources;
         vector<string> m_exceptionId;
@@ -204,6 +229,8 @@ namespace shibsp {
         return new SimpleAggregationResolver(e);
     }
 
+    static const XMLCh _AttributeExtractor[] =  UNICODE_LITERAL_18(A,t,t,r,i,b,u,t,e,E,x,t,r,a,c,t,o,r);
+    static const XMLCh _AttributeFilter[] =     UNICODE_LITERAL_15(A,t,t,r,i,b,u,t,e,F,i,l,t,e,r);
     static const XMLCh attributeId[] =          UNICODE_LITERAL_11(a,t,t,r,i,b,u,t,e,I,d);
     static const XMLCh Entity[] =               UNICODE_LITERAL_6(E,n,t,i,t,y);
     static const XMLCh EntityReference[] =      UNICODE_LITERAL_15(E,n,t,i,t,y,R,e,f,e,r,e,n,c,e);
@@ -217,7 +244,7 @@ namespace shibsp {
 };
 
 SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
-    : m_log(Category::getInstance(SHIBSP_LOGCAT".AttributeResolver.SimpleAggregation")),
+    : m_log(Category::getInstance(SHIBSP_LOGCAT ".AttributeResolver.SimpleAggregation")),
         m_policyId(XMLHelper::getAttrString(e, nullptr, policyId)),
         m_subjectMatch(XMLHelper::getAttrBool(e, false, subjectMatch))
 {
@@ -229,6 +256,7 @@ SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
     if (aid && *aid) {
         auto_ptr_char dup(aid);
         string sdup(dup.get());
+        trim(sdup);
         split(m_attributeIds, sdup, is_space(), algorithm::token_compress_on);
 
         aid = e->getAttributeNS(nullptr, format);
@@ -259,6 +287,24 @@ SimpleAggregationResolver::SimpleAggregationResolver(const DOMElement* e)
         m_trust.reset(XMLToolingConfig::getConfig().TrustEngineManager.newPlugin(t.c_str(), child));
     }
 
+    child = XMLHelper::getFirstChildElement(e,  _AttributeExtractor);
+    if (child) {
+        string t(XMLHelper::getAttrString(child, nullptr, _type));
+        if (t.empty())
+            throw ConfigurationException("AttributeExtractor element missing type attribute.");
+        m_log.info("building AttributeExtractor of type %s...", t.c_str());
+        m_extractor.reset(SPConfig::getConfig().AttributeExtractorManager.newPlugin(t.c_str(), child));
+    }
+
+    child = XMLHelper::getFirstChildElement(e,  _AttributeFilter);
+    if (child) {
+        string t(XMLHelper::getAttrString(child, nullptr, _type));
+        if (t.empty())
+            throw ConfigurationException("AttributeFilter element missing type attribute.");
+        m_log.info("building AttributeFilter of type %s...", t.c_str());
+        m_filter.reset(SPConfig::getConfig().AttributeFilterManager.newPlugin(t.c_str(), child));
+    }
+
     child = XMLHelper::getFirstChildElement(e);
     while (child) {
         if (child->hasChildNodes() && XMLString::equals(child->getLocalName(), Entity)) {
@@ -344,20 +390,35 @@ void SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
             auto_ptr<saml2::Subject> subject(saml2::SubjectBuilder::buildSubject());
 
             // Encrypt the NameID?
-            if (encryption.first && (!strcmp(encryption.second, "true") || !strcmp(encryption.second, "back"))) {
-                auto_ptr<EncryptedID> encrypted(EncryptedIDBuilder::buildEncryptedID());
-                encrypted->encrypt(
-                    *name,
-                    *(policy->getMetadataProvider()),
-                    mcc,
-                    false,
-                    relyingParty->getXMLString("encryptionAlg").second
+            if (SPConfig::shouldSignOrEncrypt(encryption.first ? encryption.second : "conditional", loc.get(), false)) {
+                try {
+                    auto_ptr<EncryptedID> encrypted(EncryptedIDBuilder::buildEncryptedID());
+                    encrypted->encrypt(
+                        *name,
+                        *(policy->getMetadataProvider()),
+                        mcc,
+                        false,
+                        relyingParty->getXMLString("encryptionAlg").second
                     );
-                subject->setEncryptedID(encrypted.get());
-                encrypted.release();
+                    subject->setEncryptedID(encrypted.get());
+                    encrypted.release();
+                }
+                catch (std::exception& ex) {
+                    // If we're encrypting deliberately, failure should be fatal.
+                    if (encryption.first && strcmp(encryption.second, "conditional")) {
+                        throw;
+                    }
+                    // If opportunistically, just log and move on.
+                    m_log.info("Conditional encryption of NameID in AttributeQuery failed: %s", ex.what());
+                    auto_ptr<NameID> namewrapper(name->cloneNameID());
+                    subject->setNameID(namewrapper.get());
+                    namewrapper.release();
+                }
             }
             else {
-                subject->setNameID(name->cloneNameID());
+                auto_ptr<NameID> namewrapper(name->cloneNameID());
+                subject->setNameID(namewrapper.get());
+                namewrapper.release();
             }
 
             saml2p::AttributeQuery* query = saml2p::AttributeQueryBuilder::buildAttributeQuery();
@@ -420,7 +481,7 @@ void SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
 
         // With this flag on, we block unauthenticated ciphertext when decrypting,
         // unless the protocol was authenticated.
-        pair<bool,bool> authenticatedCipher = application.getBool("requireAuthenticatedCipher");
+        pair<bool,bool> authenticatedCipher = application.getBool("requireAuthenticatedEncryption");
         if (policy->isAuthenticated())
             authenticatedCipher.second = false;
 
@@ -437,11 +498,11 @@ void SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
                 tokenwrapper.release();
                 newtokenwrapper.reset(newtoken);
                 if (m_log.isDebugEnabled())
-                    m_log.debugStream() << "decrypted Assertion: " << *newtoken << logging::eol;
+                    m_log.debugStream() << "decrypted assertion: " << *newtoken << logging::eol;
             }
         }
         catch (std::exception& ex) {
-            m_log.error(ex.what());
+            m_log.error("failed to decrypt assertion: %s", ex.what());
             throw;
         }
     }
@@ -523,13 +584,13 @@ void SimpleAggregationResolver::doQuery(SimpleAggregationContext& ctx, const cha
 
     // Finally, extract and filter the result.
     try {
-        AttributeExtractor* extractor = application.getAttributeExtractor();
+        AttributeExtractor* extractor = m_extractor ? m_extractor.get() : application.getAttributeExtractor();
         if (extractor) {
             Locker extlocker(extractor);
-            extractor->extractAttributes(application, AA, *newtoken, ctx.getResolvedAttributes());
+            extractor->extractAttributes(application, ctx.getRequest(), AA, *newtoken, ctx.getResolvedAttributes());
         }
 
-        AttributeFilter* filter = application.getAttributeFilter();
+        AttributeFilter* filter = m_filter ? m_filter.get() : application.getAttributeFilter();
         if (filter) {
             BasicFilteringContext fc(application, ctx.getResolvedAttributes(), AA, ctx.getClassRef(), ctx.getDeclRef());
             Locker filtlocker(filter);