Support for application-specific attribute IDs.
[shibboleth/sp.git] / shibsp / attribute / resolver / impl / SimpleAttributeResolver.cpp
index 43a2c6c..b09b71b 100644 (file)
@@ -72,7 +72,7 @@ namespace shibsp {
             const Application& application,\r
             const char* client_addr,\r
             const EntityDescriptor* issuer,\r
-            const NameID& nameid,\r
+            const NameID* nameid,\r
             const vector<const opensaml::Assertion*>* tokens=NULL\r
             ) : m_app(application), m_session(NULL), m_client_addr(client_addr), m_metadata(NULL), m_entity(issuer),\r
                 m_nameid(nameid), m_tokens(tokens) {\r
@@ -103,7 +103,7 @@ namespace shibsp {
             }\r
             return NULL;\r
         }\r
-        const NameID& getNameID() const {\r
+        const NameID* getNameID() const {\r
             return m_nameid;\r
         }\r
         const vector<const opensaml::Assertion*>* getTokens() const {\r
@@ -125,7 +125,7 @@ namespace shibsp {
         const char* m_client_addr;\r
         mutable MetadataProvider* m_metadata;\r
         mutable const EntityDescriptor* m_entity;\r
-        const NameID& m_nameid;\r
+        const NameID* m_nameid;\r
         const vector<const opensaml::Assertion*>* m_tokens;\r
         vector<shibsp::Attribute*> m_attributes;\r
         vector<opensaml::Assertion*> m_assertions;\r
@@ -151,27 +151,32 @@ namespace shibsp {
         }\r
 \r
         void query(\r
-            ResolutionContext& ctx, const NameIdentifier& nameid, const vector<const char*>* attributes=NULL\r
+            ResolutionContext& ctx, const NameIdentifier& nameid, const set<string>* attributes=NULL\r
             ) const;\r
         void query(\r
-            ResolutionContext& ctx, const NameID& nameid, const vector<const char*>* attributes=NULL\r
+            ResolutionContext& ctx, const NameID& nameid, const set<string>* attributes=NULL\r
             ) const;\r
         void resolve(\r
-            ResolutionContext& ctx, const saml1::Assertion* token, const vector<const char*>* attributes=NULL\r
+            ResolutionContext& ctx, const saml1::Assertion* token, const set<string>* attributes=NULL\r
             ) const;\r
         void resolve(\r
-            ResolutionContext& ctx, const saml2::Assertion* token, const vector<const char*>* attributes=NULL\r
+            ResolutionContext& ctx, const saml2::Assertion* token, const set<string>* attributes=NULL\r
             ) const;\r
 \r
         bool m_allowQuery;\r
+\r
     private:\r
+        void populateQuery(saml1p::AttributeQuery& query, const string& id) const;\r
+        void populateQuery(saml2p::AttributeQuery& query, const string& id) const;\r
+\r
         DOMDocument* m_document;\r
         map<string,AttributeDecoder*> m_decoderMap;\r
 #ifdef HAVE_GOOD_STL\r
-        map< pair<xstring,xstring>,pair<const AttributeDecoder*,string> > m_attrMap;\r
+        typedef map< pair<xstring,xstring>,pair<const AttributeDecoder*,string> > attrmap_t;\r
 #else\r
-        map< pair<string,string>,pair<const AttributeDecoder*,string> > m_attrMap;\r
+        typedef map< pair<string,string>,pair<const AttributeDecoder*,string> > attrmap_t;\r
 #endif\r
+        attrmap_t m_attrMap;\r
     };\r
     \r
     class SimpleResolver : public AttributeResolver, public ReloadableXMLFile\r
@@ -188,7 +193,7 @@ namespace shibsp {
             const Application& application,\r
             const char* client_addr,\r
             const EntityDescriptor* issuer,\r
-            const NameID& nameid,\r
+            const NameID* nameid,\r
             const vector<const opensaml::Assertion*>* tokens=NULL\r
             ) const {\r
             return new SimpleContext(application,client_addr,issuer,nameid,tokens);\r
@@ -198,7 +203,7 @@ namespace shibsp {
             return new SimpleContext(application,session);\r
         }\r
         \r
-        void resolveAttributes(ResolutionContext& ctx, const vector<const char*>* attributes=NULL) const;\r
+        void resolveAttributes(ResolutionContext& ctx, const set<string>* attributes=NULL) const;\r
 \r
     protected:\r
         pair<bool,DOMElement*> load();\r
@@ -314,14 +319,9 @@ SimpleResolverImpl::SimpleResolverImpl(const DOMElement* e) : m_document(NULL),
 }\r
 \r
 void SimpleResolverImpl::resolve(\r
-    ResolutionContext& ctx, const saml1::Assertion* token, const vector<const char*>* attributes\r
+    ResolutionContext& ctx, const saml1::Assertion* token, const set<string>* attributes\r
     ) const\r
 {\r
-    set<string> aset;\r
-    if (attributes)\r
-        for(vector<const char*>::const_iterator i=attributes->begin(); i!=attributes->end(); ++i)\r
-            aset.insert(*i);\r
-\r
     vector<shibsp::Attribute*>& resolved = ctx.getResolvedAttributes();\r
 \r
     auto_ptr_char assertingParty(ctx.getEntityDescriptor() ? ctx.getEntityDescriptor()->getEntityID() : NULL);\r
@@ -333,21 +333,24 @@ void SimpleResolverImpl::resolve(
     map< pair<string,string>,pair<const AttributeDecoder*,string> >::const_iterator rule;\r
 #endif\r
 \r
-    // Check the NameID based on the format.\r
     const XMLCh* name;\r
-    const XMLCh* format = ctx.getNameID().getFormat();\r
-    if (!format) {\r
-        format = NameID::UNSPECIFIED;\r
+    const XMLCh* format;\r
+    \r
+    // Check the NameID based on the format.\r
+    if (ctx.getNameID()) {\r
+        format = ctx.getNameID()->getFormat();\r
+        if (!format || !*format)\r
+            format = NameID::UNSPECIFIED;\r
 #ifdef HAVE_GOOD_STL\r
         if ((rule=m_attrMap.find(make_pair(format,xstring()))) != m_attrMap.end()) {\r
 #else\r
         auto_ptr_char temp(format);\r
         if ((rule=m_attrMap.find(make_pair(temp.get(),string()))) != m_attrMap.end()) {\r
 #endif\r
-            if (aset.empty() || aset.count(rule->second.second)) {\r
+            if (!attributes || attributes->count(rule->second.second)) {\r
                 resolved.push_back(\r
                     rule->second.first->decode(\r
-                        rule->second.second.c_str(), &ctx.getNameID(), assertingParty.get(), relyingParty\r
+                        rule->second.second.c_str(), ctx.getNameID(), assertingParty.get(), relyingParty\r
                         )\r
                     );\r
             }\r
@@ -362,7 +365,7 @@ void SimpleResolverImpl::resolve(
             format = (*a)->getAttributeNamespace();\r
             if (!name || !*name)\r
                 continue;\r
-            if (!format)\r
+            if (!format || XMLString::equals(format, shibspconstants::SHIB1_ATTRIBUTE_NAMESPACE_URI))\r
                 format = &chNull;\r
 #ifdef HAVE_GOOD_STL\r
             if ((rule=m_attrMap.find(make_pair(name,format))) != m_attrMap.end()) {\r
@@ -371,7 +374,7 @@ void SimpleResolverImpl::resolve(
             auto_ptr_char temp2(format);\r
             if ((rule=m_attrMap.find(make_pair(temp1.get(),temp2.get()))) != m_attrMap.end()) {\r
 #endif\r
-                if (aset.empty() || aset.count(rule->second.second)) {\r
+                if (!attributes || attributes->count(rule->second.second)) {\r
                     resolved.push_back(\r
                         rule->second.first->decode(rule->second.second.c_str(), *a, assertingParty.get(), relyingParty)\r
                         );\r
@@ -382,14 +385,9 @@ void SimpleResolverImpl::resolve(
 }\r
 \r
 void SimpleResolverImpl::resolve(\r
-    ResolutionContext& ctx, const saml2::Assertion* token, const vector<const char*>* attributes\r
+    ResolutionContext& ctx, const saml2::Assertion* token, const set<string>* attributes\r
     ) const\r
 {\r
-    set<string> aset;\r
-    if (attributes)\r
-        for(vector<const char*>::const_iterator i=attributes->begin(); i!=attributes->end(); ++i)\r
-            aset.insert(*i);\r
-\r
     vector<shibsp::Attribute*>& resolved = ctx.getResolvedAttributes();\r
 \r
     auto_ptr_char assertingParty(ctx.getEntityDescriptor() ? ctx.getEntityDescriptor()->getEntityID() : NULL);\r
@@ -401,21 +399,24 @@ void SimpleResolverImpl::resolve(
     map< pair<string,string>,pair<const AttributeDecoder*,string> >::const_iterator rule;\r
 #endif\r
 \r
-    // Check the NameID based on the format.\r
     const XMLCh* name;\r
-    const XMLCh* format = ctx.getNameID().getFormat();\r
-    if (!format) {\r
-        format = NameID::UNSPECIFIED;\r
+    const XMLCh* format;\r
+    \r
+    // Check the NameID based on the format.\r
+    if (ctx.getNameID()) {\r
+        format = ctx.getNameID()->getFormat();\r
+        if (!format || !*format)\r
+            format = NameID::UNSPECIFIED;\r
 #ifdef HAVE_GOOD_STL\r
         if ((rule=m_attrMap.find(make_pair(format,xstring()))) != m_attrMap.end()) {\r
 #else\r
         auto_ptr_char temp(format);\r
         if ((rule=m_attrMap.find(make_pair(temp.get(),string()))) != m_attrMap.end()) {\r
 #endif\r
-            if (aset.empty() || aset.count(rule->second.second)) {\r
+            if (!attributes || attributes->count(rule->second.second)) {\r
                 resolved.push_back(\r
                     rule->second.first->decode(\r
-                        rule->second.second.c_str(), &ctx.getNameID(), assertingParty.get(), relyingParty\r
+                        rule->second.second.c_str(), ctx.getNameID(), assertingParty.get(), relyingParty\r
                         )\r
                     );\r
             }\r
@@ -430,7 +431,9 @@ void SimpleResolverImpl::resolve(
             format = (*a)->getNameFormat();\r
             if (!name || !*name)\r
                 continue;\r
-            if (!format)\r
+            if (!format || !*format)\r
+                format = saml2::Attribute::UNSPECIFIED;\r
+            else if (XMLString::equals(format, saml2::Attribute::URI_REFERENCE))\r
                 format = &chNull;\r
 #ifdef HAVE_GOOD_STL\r
             if ((rule=m_attrMap.find(make_pair(name,format))) != m_attrMap.end()) {\r
@@ -439,17 +442,60 @@ void SimpleResolverImpl::resolve(
             auto_ptr_char temp2(format);\r
             if ((rule=m_attrMap.find(make_pair(temp1.get(),temp2.get()))) != m_attrMap.end()) {\r
 #endif\r
-                if (aset.empty() || aset.count(rule->second.second)) {\r
+                if (!attributes || attributes->count(rule->second.second)) {\r
                     resolved.push_back(\r
                         rule->second.first->decode(rule->second.second.c_str(), *a, assertingParty.get(), relyingParty)\r
                         );\r
                 }\r
             }\r
         }\r
+\r
+        const vector<saml2::EncryptedAttribute*>& encattrs = const_cast<const saml2::AttributeStatement*>(*s)->getEncryptedAttributes();\r
+        if (!encattrs.empty()) {\r
+            const XMLCh* recipient = ctx.getApplication().getXMLString("providerId").second;\r
+            CredentialResolver* cr = ctx.getApplication().getCredentialResolver();\r
+            if (!cr) {\r
+                Category::getInstance(SHIBSP_LOGCAT".AttributeResolver").warn(\r
+                    "found encrypted attributes, but no CredentialResolver was available"\r
+                    );\r
+                return;\r
+            }\r
+\r
+            // We look up credentials based on the peer who did the encrypting.\r
+            CredentialCriteria cc;\r
+            cc.setPeerName(assertingParty.get());\r
+\r
+            Locker credlocker(cr);\r
+            for (vector<saml2::EncryptedAttribute*>::const_iterator ea = encattrs.begin(); ea!=encattrs.end(); ++ea) {\r
+                auto_ptr<XMLObject> decrypted((*ea)->decrypt(*cr, recipient, &cc));\r
+                const saml2::Attribute* decattr = dynamic_cast<const saml2::Attribute*>(decrypted.get());\r
+                name = decattr->getName();\r
+                format = decattr->getNameFormat();\r
+                if (!name || !*name)\r
+                    continue;\r
+                if (!format || !*format)\r
+                    format = saml2::Attribute::UNSPECIFIED;\r
+                else if (XMLString::equals(format, saml2::Attribute::URI_REFERENCE))\r
+                    format = &chNull;\r
+#ifdef HAVE_GOOD_STL\r
+                if ((rule=m_attrMap.find(make_pair(name,format))) != m_attrMap.end()) {\r
+#else\r
+                auto_ptr_char temp1(name);\r
+                auto_ptr_char temp2(format);\r
+                if ((rule=m_attrMap.find(make_pair(temp1.get(),temp2.get()))) != m_attrMap.end()) {\r
+#endif\r
+                    if (!attributes || attributes->count(rule->second.second)) {\r
+                        resolved.push_back(\r
+                            rule->second.first->decode(rule->second.second.c_str(), decattr, assertingParty.get(), relyingParty)\r
+                            );\r
+                    }\r
+                }\r
+            }\r
+        }\r
     }\r
 }\r
 \r
-void SimpleResolverImpl::query(ResolutionContext& ctx, const NameIdentifier& nameid, const vector<const char*>* attributes) const\r
+void SimpleResolverImpl::query(ResolutionContext& ctx, const NameIdentifier& nameid, const set<string>* attributes) const\r
 {\r
 #ifdef _DEBUG\r
     xmltooling::NDC ndc("query");\r
@@ -474,6 +520,7 @@ void SimpleResolverImpl::query(ResolutionContext& ctx, const NameIdentifier& nam
     }\r
 \r
     SecurityPolicy policy;\r
+    MetadataCredentialCriteria mcc(*AA);\r
     shibsp::SOAPClient soaper(ctx.getApplication(),policy);\r
     const PropertySet* policySettings = ctx.getApplication().getServiceProvider().getPolicySettings(ctx.getApplication().getString("policyId").second);\r
     pair<bool,bool> signedAssertions = policySettings->getBool("signedAssertions");\r
@@ -495,8 +542,13 @@ void SimpleResolverImpl::query(ResolutionContext& ctx, const NameIdentifier& nam
             request->setAttributeQuery(query);\r
             query->setResource(issuer.get());\r
             request->setMinorVersion(version);\r
+            if (attributes) {\r
+                for (set<string>::const_iterator a = attributes->begin(); a!=attributes->end(); ++a)\r
+                    populateQuery(*query, *a);\r
+            }\r
+\r
             SAML1SOAPClient client(soaper);\r
-            client.sendSAML(request, *AA, loc.get());\r
+            client.sendSAML(request, mcc, loc.get());\r
             response = client.receiveSAML();\r
         }\r
         catch (exception& ex) {\r
@@ -538,7 +590,30 @@ void SimpleResolverImpl::query(ResolutionContext& ctx, const NameIdentifier& nam
     resolve(ctx, newtoken, attributes);\r
 }\r
 \r
-void SimpleResolverImpl::query(ResolutionContext& ctx, const NameID& nameid, const vector<const char*>* attributes) const\r
+void SimpleResolverImpl::populateQuery(saml1p::AttributeQuery& query, const string& id) const\r
+{\r
+    for (attrmap_t::const_iterator i = m_attrMap.begin(); i!=m_attrMap.end(); ++i) {\r
+        if (i->second.second == id) {\r
+            AttributeDesignator* a = AttributeDesignatorBuilder::buildAttributeDesignator();\r
+#ifdef HAVE_GOOD_STL\r
+            a->setAttributeName(i->first.first.c_str());\r
+            a->setAttributeNamespace(i->first.second.empty() ? shibspconstants::SHIB1_ATTRIBUTE_NAMESPACE_URI : i->first.second.c_str());\r
+#else\r
+            auto_ptr_XMLCh n(i->first.first.c_str());\r
+            a->setAttributeName(n.get());\r
+            if (i->first.second.empty())\r
+                a->setAttributeNamespace(shibspconstants::SHIB1_ATTRIBUTE_NAMESPACE_URI);\r
+            else {\r
+                auto_ptr_XMLCh ns(i->first.second.c_str());\r
+                a->setAttributeNamespace(ns.get());\r
+            }\r
+#endif\r
+            query.getAttributeDesignators().push_back(a);\r
+        }\r
+    }\r
+}\r
+\r
+void SimpleResolverImpl::query(ResolutionContext& ctx, const NameID& nameid, const set<string>* attributes) const\r
 {\r
 #ifdef _DEBUG\r
     xmltooling::NDC ndc("query");\r
@@ -557,6 +632,7 @@ void SimpleResolverImpl::query(ResolutionContext& ctx, const NameID& nameid, con
     }\r
 \r
     SecurityPolicy policy;\r
+    MetadataCredentialCriteria mcc(*AA);\r
     shibsp::SOAPClient soaper(ctx.getApplication(),policy);\r
     const PropertySet* policySettings = ctx.getApplication().getServiceProvider().getPolicySettings(ctx.getApplication().getString("policyId").second);\r
     pair<bool,bool> signedAssertions = policySettings->getBool("signedAssertions");\r
@@ -577,8 +653,13 @@ void SimpleResolverImpl::query(ResolutionContext& ctx, const NameID& nameid, con
             Issuer* iss = IssuerBuilder::buildIssuer();\r
             query->setIssuer(iss);\r
             iss->setName(issuer.get());\r
+            if (attributes) {\r
+                for (set<string>::const_iterator a = attributes->begin(); a!=attributes->end(); ++a)\r
+                    populateQuery(*query, *a);\r
+            }\r
+\r
             SAML2SOAPClient client(soaper);\r
-            client.sendSAML(query, *AA, loc.get());\r
+            client.sendSAML(query, mcc, loc.get());\r
             srt = client.receiveSAML();\r
         }\r
         catch (exception& ex) {\r
@@ -626,7 +707,30 @@ void SimpleResolverImpl::query(ResolutionContext& ctx, const NameID& nameid, con
     resolve(ctx, newtoken, attributes);\r
 }\r
 \r
-void SimpleResolver::resolveAttributes(ResolutionContext& ctx, const vector<const char*>* attributes) const\r
+void SimpleResolverImpl::populateQuery(saml2p::AttributeQuery& query, const string& id) const\r
+{\r
+    for (attrmap_t::const_iterator i = m_attrMap.begin(); i!=m_attrMap.end(); ++i) {\r
+        if (i->second.second == id) {\r
+            saml2::Attribute* a = saml2::AttributeBuilder::buildAttribute();\r
+#ifdef HAVE_GOOD_STL\r
+            a->setName(i->first.first.c_str());\r
+            a->setNameFormat(i->first.second.empty() ? saml2::Attribute::URI_REFERENCE : i->first.second.c_str());\r
+#else\r
+            auto_ptr_XMLCh n(i->first.first.c_str());\r
+            a->setName(n.get());\r
+            if (i->first.second.empty())\r
+                a->setNameFormat(saml2::Attribute::URI_REFERENCE);\r
+            else {\r
+                auto_ptr_XMLCh ns(i->first.second.c_str());\r
+                a->setNameFormat(ns.get());\r
+            }\r
+#endif\r
+            query.getAttributes().push_back(a);\r
+        }\r
+    }\r
+}\r
+\r
+void SimpleResolver::resolveAttributes(ResolutionContext& ctx, const set<string>* attributes) const\r
 {\r
 #ifdef _DEBUG\r
     xmltooling::NDC ndc("resolveAttributes");\r
@@ -659,11 +763,17 @@ void SimpleResolver::resolveAttributes(ResolutionContext& ctx, const vector<cons
 \r
     if (query) {\r
         if (token1 && !token1->getAuthenticationStatements().empty()) {\r
-            log.debug("attempting SAML 1.x attribute query");\r
-            return m_impl->query(ctx, *(token1->getAuthenticationStatements().front()->getSubject()->getNameIdentifier()), attributes);\r
+            const AuthenticationStatement* statement = token1->getAuthenticationStatements().front();\r
+            if (statement && statement->getSubject() && statement->getSubject()->getNameIdentifier()) {\r
+                log.debug("attempting SAML 1.x attribute query");\r
+                return m_impl->query(ctx, *(statement->getSubject()->getNameIdentifier()), attributes);\r
+            }\r
+        }\r
+        else if (token2 && ctx.getNameID()) {\r
+            log.debug("attempting SAML 2.0 attribute query");\r
+            return m_impl->query(ctx, *ctx.getNameID(), attributes);\r
         }\r
-        log.debug("attempting SAML 2.0 attribute query");\r
-        m_impl->query(ctx, ctx.getNameID(), attributes);\r
+        log.warn("can't attempt attribute query, no identifier in assertion subject");\r
     }\r
 }\r
 \r