Add name-based API to GSSRequest
[shibboleth/cpp-sp-resolver.git] / src / shibresolver / resolver.cpp
index 287adda..dd7bbbf 100644 (file)
@@ -1,9 +1,8 @@
-/*
- *  Copyright 2010-2011 JANET(UK)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
+/**
+ * See the NOTICE file distributed with this work for information
+ * regarding copyright ownership. Licensed under the Apache License,
+ * Version 2.0 (the "License"); you may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
 
 #include "internal.h"
 
+#ifdef SHIBRESOLVER_HAVE_GSSAPI_NAMINGEXTS
+# ifdef SHIBRESOLVER_HAVE_GSSMIT
+#  include <gssapi/gssapi_ext.h>
+# endif
+#endif
+
 #include <shibsp/exceptions.h>
 #include <shibsp/Application.h>
 #include <shibsp/GSSRequest.h>
@@ -44,6 +49,7 @@
 #include <xmltooling/XMLToolingConfig.h>
 #include <xmltooling/impl/AnyElement.h>
 #include <xmltooling/util/ParserPool.h>
+#include <xmltooling/util/Threads.h>
 #include <xmltooling/util/XMLHelper.h>
 #include <xercesc/util/Base64.hpp>
 
@@ -78,13 +84,35 @@ namespace shibresolver {
         void resolve(
             const Application& app,
             const char* issuer,
+            const XMLCh* protocol,
             const vector<const XMLObject*>& tokens,
             const vector<Attribute*>& inputAttrs,
-            vector <Attribute*>& resolvedAttrs
+            vector<Attribute*>& resolvedAttrs
+            ) const;
+
+    private:
+#ifndef SHIBSP_LITE
+        void resolve(
+            AttributeExtractor* extractor,
+            const Application& app,
+            const RoleDescriptor* issuer,
+            const XMLObject& token,
+            vector<Attribute*>& resolvedAttrs
+            ) const;
+
+        const RoleDescriptor* lookup(
+            const Application& app,
+            MetadataProvider* m,
+            const char* entityID,
+            const XMLCh* protocol
             ) const;
+#endif
     };
 
     static RemotedResolver g_Remoted;
+
+    static int g_initCount = 0;
+    static auto_ptr<Mutex> g_lock(Mutex::create());
 };
 
 ShibbolethResolver* ShibbolethResolver::create()
@@ -116,9 +144,16 @@ void ShibbolethResolver::setRequest(const SPRequest* request)
     if (request) {
         const GSSRequest* gss = dynamic_cast<const GSSRequest*>(request);
         if (gss) {
-            // TODO: fix API to prevent destruction of contexts
+#ifdef SHIBRESOLVER_HAVE_GSSAPI_NAMINGEXTS
+            gss_name_t name = gss->getGSSName();
+            if (name != GSS_C_NO_NAME) {
+                addToken(name);
+                return;
+            }
+#endif
             gss_ctx_id_t ctx = gss->getGSSContext();
-            addToken(&ctx);
+            if (ctx != GSS_C_NO_CONTEXT)
+                addToken(&ctx);
         }
     }
 #endif
@@ -138,6 +173,13 @@ void ShibbolethResolver::setIssuer(const char* issuer)
         m_issuer = issuer;
 }
 
+void ShibbolethResolver::setProtocol(const XMLCh* protocol)
+{
+    m_protocol.erase();
+    if (protocol)
+        m_protocol = protocol;
+}
+
 void ShibbolethResolver::addToken(const XMLObject* token)
 {
     if (token)
@@ -153,12 +195,28 @@ void ShibbolethResolver::addToken(gss_ctx_id_t* ctx)
     }
 
     if (ctx && *ctx != GSS_C_NO_CONTEXT) {
-        OM_uint32 major, minor;
+        OM_uint32 minor;
         gss_buffer_desc contextbuf = GSS_C_EMPTY_BUFFER;
-
-        major = gss_export_sec_context(&minor, ctx, &contextbuf);
+        OM_uint32 major = gss_export_sec_context(&minor, ctx, &contextbuf);
         if (major == GSS_S_COMPLETE) {
-            addToken(&contextbuf);
+            xsecsize_t len=0;
+            XMLByte* out=Base64::encode(reinterpret_cast<const XMLByte*>(contextbuf.value), contextbuf.length, &len);
+            if (out) {
+                string s;
+                s.append(reinterpret_cast<char*>(out), len);
+                auto_ptr_XMLCh temp(s.c_str());
+#ifdef SHIBSP_XERCESC_HAS_XMLBYTE_RELEASE
+                XMLString::release(&out);
+#else
+                XMLString::release((char**)&out);
+#endif
+                static const XMLCh _GSSAPI[] = UNICODE_LITERAL_13(G,S,S,A,P,I,C,o,n,t,e,x,t);
+                m_gsswrapper = new AnyElementImpl(shibspconstants::SHIB2ATTRIBUTEMAP_NS, _GSSAPI);
+                m_gsswrapper->setTextContent(temp.get());
+            }
+            else {
+                Category::getInstance(SHIBRESOLVER_LOGCAT).error("error while base64-encoding GSS context");
+            }
             gss_release_buffer(&minor, &contextbuf);
         }
         else {
@@ -167,6 +225,27 @@ void ShibbolethResolver::addToken(gss_ctx_id_t* ctx)
     }
 }
 
+#ifdef SHIBRESOLVER_HAVE_GSSAPI_NAMINGEXTS
+void ShibbolethResolver::addToken(gss_name_t name)
+{
+    if (m_gsswrapper) {
+        delete m_gsswrapper;
+        m_gsswrapper = NULL;
+    }
+
+    OM_uint32 minor;
+    gss_buffer_desc namebuf = GSS_C_EMPTY_BUFFER;
+    OM_uint32 major = gss_export_name_composite(&minor, name, &namebuf);
+    if (major == GSS_S_COMPLETE) {
+        addToken(&namebuf);
+        gss_release_buffer(&minor, &namebuf);
+    }
+    else {
+        Category::getInstance(SHIBRESOLVER_LOGCAT).error("error exporting GSS name");
+    }
+}
+#endif
+
 void ShibbolethResolver::addToken(const gss_buffer_t contextbuf)
 {
     if (m_gsswrapper) {
@@ -185,14 +264,15 @@ void ShibbolethResolver::addToken(const gss_buffer_t contextbuf)
 #else
         XMLString::release((char**)&out);
 #endif
-        static const XMLCh _GSSAPI[] = UNICODE_LITERAL_13(G,S,S,A,P,I,C,o,n,t,e,x,t);
+        static const XMLCh _GSSAPI[] = UNICODE_LITERAL_10(G,S,S,A,P,I,N,a,m,e);
         m_gsswrapper = new AnyElementImpl(shibspconstants::SHIB2ATTRIBUTEMAP_NS, _GSSAPI);
         m_gsswrapper->setTextContent(temp.get());
     }
     else {
-        Category::getInstance(SHIBRESOLVER_LOGCAT).error("error while base64-encoding GSS context");
+        Category::getInstance(SHIBRESOLVER_LOGCAT).error("error while base64-encoding GSS name");
     }
 }
+
 #endif
 
 void ShibbolethResolver::addAttribute(Attribute* attr)
@@ -237,6 +317,7 @@ void ShibbolethResolver::resolve()
         g_Remoted.resolve(
             *app,
             m_issuer.c_str(),
+            m_protocol.c_str(),
             m_tokens,
             m_inputAttributes,
             m_resolvedAttributes
@@ -249,6 +330,10 @@ void ShibbolethResolver::resolve()
         in.addmember("application_id").string(app->getId());
         if (!m_issuer.empty())
             in.addmember("issuer").string(m_issuer.c_str());
+        if (!m_protocol.empty()) {
+            auto_ptr_char prot(m_protocol.c_str());
+            in.addmember("protocol").string(prot.get());
+        }
         if (!m_tokens.empty()) {
             DDF& tokens = in.addmember("tokens").list();
             for (vector<const XMLObject*>::const_iterator t = m_tokens.begin(); t != m_tokens.end(); ++t) {
@@ -330,7 +415,9 @@ void RemotedResolver::receive(DDF& in, ostream& out)
         attr = alist.next();
     }
 
-    resolve(*app, in["issuer"].string(), t.tokens, t.inputAttrs, t.resolvedAttrs);
+    auto_ptr_XMLCh prot(in["protocol"].string());
+
+    resolve(*app, in["issuer"].string(), prot.get(), t.tokens, t.inputAttrs, t.resolvedAttrs);
 
     if (!t.resolvedAttrs.empty()) {
         ret.list();
@@ -346,43 +433,34 @@ void RemotedResolver::receive(DDF& in, ostream& out)
 void RemotedResolver::resolve(
     const Application& app,
     const char* issuer,
+    const XMLCh* protocol,
     const vector<const XMLObject*>& tokens,
     const vector<Attribute*>& inputAttrs,
-    vector <Attribute*>& resolvedAttrs
+    vector<Attribute*>& resolvedAttrs
     ) const
 {
 #ifndef SHIBSP_LITE
     Category& log = Category::getInstance(SHIBRESOLVER_LOGCAT);
-    pair<const EntityDescriptor*,const RoleDescriptor*> entity = make_pair((EntityDescriptor*)NULL, (RoleDescriptor*)NULL);
     MetadataProvider* m = app.getMetadataProvider(false);
     Locker locker(m);
-    if (!m) {
-        log.warn("no metadata providers are configured");
-    }
-    else if (issuer && *issuer) {
-        // Lookup metadata for the issuer.
-        MetadataProviderCriteria mc(app, issuer, &IDPSSODescriptor::ELEMENT_QNAME, samlconstants::SAML20P_NS);
-        entity = m->getEntityDescriptor(mc);
-        if (!entity.first) {
-            log.warn("unable to locate metadata for provider (%s)", issuer);
-        }
-        else if (!entity.second) {
-            log.warn("unable to locate SAML 2.0 identity provider role for provider (%s)", issuer);
-        }
-    }
+
+    const RoleDescriptor* role = NULL;
+    if (issuer && *issuer)
+        role = lookup(app, m, issuer, protocol);
 
     vector<const Assertion*> assertions;
 
     AttributeExtractor* extractor = app.getAttributeExtractor();
     if (extractor) {
         Locker extlocker(extractor);
-        if (entity.second) {
+        // Support metadata-based attributes for only the "top-level" issuer.
+        if (role) {
             pair<bool,const char*> mprefix = app.getString("metadataAttributePrefix");
             if (mprefix.first) {
                 log.debug("extracting metadata-derived attributes...");
                 try {
-                    // We pass NULL for "issuer" because the IdP isn't the one asserting metadata-based attributes.
-                    extractor->extractAttributes(app, NULL, *entity.second, resolvedAttrs);
+                    // We pass NULL for "issuer" because the issuer isn't the one asserting metadata-based attributes.
+                    extractor->extractAttributes(app, NULL, *role, resolvedAttrs);
                     for (vector<Attribute*>::iterator a = resolvedAttrs.begin(); a != resolvedAttrs.end(); ++a) {
                         vector<string>& ids = (*a)->getAliases();
                         for (vector<string>::iterator id = ids.begin(); id != ids.end(); ++id)
@@ -394,33 +472,23 @@ void RemotedResolver::resolve(
                 }
             }
         }
-        log.debug("extracting pushed attributes...");
-        for (vector<const XMLObject*>::const_iterator t = tokens.begin(); t!=tokens.end(); ++t) {
-            try {
-                // Save off any assertions for later use by resolver.
-                const Assertion* assertion = dynamic_cast<const Assertion*>(*t);
-                if (assertion)
-                    assertions.push_back(assertion);
-                extractor->extractAttributes(app, entity.second, *(*t), resolvedAttrs);
-            }
-            catch (exception& ex) {
-                log.error("caught exception extracting attributes: %s", ex.what());
-            }
-        }
 
-        AttributeFilter* filter = app.getAttributeFilter();
-        if (filter && !resolvedAttrs.empty()) {
-            BasicFilteringContext fc(app, resolvedAttrs, entity.second);
-            Locker filtlocker(filter);
-            try {
-                filter->filterAttributes(fc, resolvedAttrs);
-            }
-            catch (exception& ex) {
-                log.error("caught exception filtering attributes: %s", ex.what());
-                log.error("dumping extracted attributes due to filtering exception");
-                for_each(resolvedAttrs.begin(), resolvedAttrs.end(), xmltooling::cleanup<shibsp::Attribute>());
-                resolvedAttrs.clear();
+        log.debug("extracting pushed attributes...");
+        const RoleDescriptor* role2;
+        for (vector<const XMLObject*>::const_iterator t = tokens.begin(); t != tokens.end(); ++t) {
+            // Save off any assertions for later use by resolver.
+            role2 = NULL;
+            const Assertion* assertion = dynamic_cast<const Assertion*>(*t);
+            if (assertion) {
+                assertions.push_back(assertion);
+                const saml2::Assertion* saml2token = dynamic_cast<const saml2::Assertion*>(assertion);
+                if (saml2token && saml2token->getIssuer() && (saml2token->getIssuer()->getFormat() == NULL ||
+                        XMLString::equals(saml2token->getIssuer()->getFormat(), saml2::NameID::ENTITY))) {
+                    auto_ptr_char tokenissuer(saml2token->getIssuer()->getName());
+                    role2 = lookup(app, m, tokenissuer.get(), protocol);
+                }
             }
+            resolve(extractor, app, (role2 ? role2 : role), *(*t), resolvedAttrs);
         }
     }
     else {
@@ -430,7 +498,7 @@ void RemotedResolver::resolve(
     try {
         AttributeResolver* resolver = app.getAttributeResolver();
         if (resolver) {
-            log.debug("resolving attributes...");
+            log.debug("resolving additional attributes...");
 
             vector<Attribute*> inputs = inputAttrs;
             inputs.insert(inputs.end(), resolvedAttrs.begin(), resolvedAttrs.end());
@@ -439,8 +507,8 @@ void RemotedResolver::resolve(
             auto_ptr<ResolutionContext> ctx(
                 resolver->createResolutionContext(
                     app,
-                    entity.first,
-                    samlconstants::SAML20P_NS,
+                    role ? dynamic_cast<const EntityDescriptor*>(role->getParent()) : NULL,
+                    protocol ? protocol : samlconstants::SAML20P_NS,
                     NULL,
                     NULL,
                     NULL,
@@ -461,16 +529,95 @@ void RemotedResolver::resolve(
 #endif
 }
 
+#ifndef SHIBSP_LITE
+
+void RemotedResolver::resolve(
+    AttributeExtractor* extractor,
+    const Application& app,
+    const RoleDescriptor* issuer,
+    const XMLObject& token,
+    vector<Attribute*>& resolvedAttrs
+    ) const
+{
+    vector<Attribute*> extractedAttrs;
+    try {
+        extractor->extractAttributes(app, issuer, token, extractedAttrs);
+    }
+    catch (exception& ex) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT).error("caught exception extracting attributes: %s", ex.what());
+    }
+
+    AttributeFilter* filter = app.getAttributeFilter();
+    if (filter && !extractedAttrs.empty()) {
+        BasicFilteringContext fc(app, extractedAttrs, issuer);
+        Locker filtlocker(filter);
+        try {
+            filter->filterAttributes(fc, extractedAttrs);
+        }
+        catch (exception& ex) {
+            Category::getInstance(SHIBRESOLVER_LOGCAT).error("caught exception filtering attributes: %s", ex.what());
+            Category::getInstance(SHIBRESOLVER_LOGCAT).error("dumping extracted attributes due to filtering exception");
+            for_each(extractedAttrs.begin(), extractedAttrs.end(), xmltooling::cleanup<shibsp::Attribute>());
+            extractedAttrs.clear();
+        }
+    }
+
+    resolvedAttrs.insert(resolvedAttrs.end(), extractedAttrs.begin(), extractedAttrs.end());
+}
+
+const RoleDescriptor* RemotedResolver::lookup(
+    const Application& app, MetadataProvider* m, const char* entityID, const XMLCh* protocol
+    ) const
+{
+    if (!m)
+        return NULL;
+
+    MetadataProviderCriteria idpmc(app, entityID, &IDPSSODescriptor::ELEMENT_QNAME, protocol ? protocol : samlconstants::SAML20P_NS);
+    if (protocol)
+        idpmc.protocol2 = samlconstants::SAML20P_NS;
+    pair<const EntityDescriptor*,const RoleDescriptor*> entity = m->getEntityDescriptor(idpmc);
+    if (!entity.first) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT).warn("unable to locate metadata for provider (%s)", entityID);
+    }
+    else if (!entity.second) {
+        MetadataProviderCriteria aamc(
+            app, entityID, &AttributeAuthorityDescriptor::ELEMENT_QNAME, protocol ? protocol : samlconstants::SAML20P_NS
+            );
+        if (protocol)
+            aamc.protocol2 = samlconstants::SAML20P_NS;
+        entity = m->getEntityDescriptor(aamc);
+        if (!entity.second) {
+            Category::getInstance(SHIBRESOLVER_LOGCAT).warn("unable to locate compatible IdP or AA role for provider (%s)", entityID);
+        }
+    }
+
+    return entity.second;
+}
+
+#endif
+
 bool ShibbolethResolver::init(unsigned long features, const char* config, bool rethrow)
 {
-    if (features && SPConfig::OutOfProcess) {
+    Lock initLock(g_lock.get());
+
+    if (g_initCount == INT_MAX) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT".Config").crit("library initialized too many times");
+        return false;
+    }
+
+    if (g_initCount >= 1) {
+        ++g_initCount;
+        return true;
+    }
+
+    if (features & SPConfig::OutOfProcess) {
 #ifndef SHIBSP_LITE
         features = features | SPConfig::AttributeResolution | SPConfig::Metadata | SPConfig::Trust | SPConfig::Credentials;
 #endif
-        if (!(features && SPConfig::InProcess))
+        if (!(features & SPConfig::InProcess))
             features |= SPConfig::Listener;
     }
-    else if (features && SPConfig::InProcess) {
+    else if (features & SPConfig::InProcess) {
         features |= SPConfig::Listener;
     }
     SPConfig::getConfig().setFeatures(features);
@@ -478,16 +625,22 @@ bool ShibbolethResolver::init(unsigned long features, const char* config, bool r
         return false;
     if (!SPConfig::getConfig().instantiate(config, rethrow))
         return false;
+
+    ++g_initCount;
     return true;
 }
 
-/**
-    * Shuts down runtime.
-    *
-    * Each process using the library SHOULD call this function exactly once before terminating itself.
-    */
 void ShibbolethResolver::term()
 {
+    Lock initLock(g_lock.get());
+    if (g_initCount == 0) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT".Config").crit("term without corresponding init");
+        return;
+    }
+    else if (--g_initCount > 0) {
+        return;
+    }
+
     SPConfig::getConfig().term();
 }