Add name-based API to GSSRequest
[shibboleth/cpp-sp-resolver.git] / src / shibresolver / resolver.cpp
index 79c3537..dd7bbbf 100644 (file)
@@ -1,9 +1,8 @@
-/*
- *  Copyright 2010-2011 JANET(UK)
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
+/**
+ * See the NOTICE file distributed with this work for information
+ * regarding copyright ownership. Licensed under the Apache License,
+ * Version 2.0 (the "License"); you may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
@@ -50,6 +49,7 @@
 #include <xmltooling/XMLToolingConfig.h>
 #include <xmltooling/impl/AnyElement.h>
 #include <xmltooling/util/ParserPool.h>
+#include <xmltooling/util/Threads.h>
 #include <xmltooling/util/XMLHelper.h>
 #include <xercesc/util/Base64.hpp>
 
@@ -84,13 +84,35 @@ namespace shibresolver {
         void resolve(
             const Application& app,
             const char* issuer,
+            const XMLCh* protocol,
             const vector<const XMLObject*>& tokens,
             const vector<Attribute*>& inputAttrs,
-            vector <Attribute*>& resolvedAttrs
+            vector<Attribute*>& resolvedAttrs
+            ) const;
+
+    private:
+#ifndef SHIBSP_LITE
+        void resolve(
+            AttributeExtractor* extractor,
+            const Application& app,
+            const RoleDescriptor* issuer,
+            const XMLObject& token,
+            vector<Attribute*>& resolvedAttrs
+            ) const;
+
+        const RoleDescriptor* lookup(
+            const Application& app,
+            MetadataProvider* m,
+            const char* entityID,
+            const XMLCh* protocol
             ) const;
+#endif
     };
 
     static RemotedResolver g_Remoted;
+
+    static int g_initCount = 0;
+    static auto_ptr<Mutex> g_lock(Mutex::create());
 };
 
 ShibbolethResolver* ShibbolethResolver::create()
@@ -122,9 +144,16 @@ void ShibbolethResolver::setRequest(const SPRequest* request)
     if (request) {
         const GSSRequest* gss = dynamic_cast<const GSSRequest*>(request);
         if (gss) {
-            // TODO: fix API to prevent destruction of contexts
+#ifdef SHIBRESOLVER_HAVE_GSSAPI_NAMINGEXTS
+            gss_name_t name = gss->getGSSName();
+            if (name != GSS_C_NO_NAME) {
+                addToken(name);
+                return;
+            }
+#endif
             gss_ctx_id_t ctx = gss->getGSSContext();
-            addToken(&ctx);
+            if (ctx != GSS_C_NO_CONTEXT)
+                addToken(&ctx);
         }
     }
 #endif
@@ -144,6 +173,13 @@ void ShibbolethResolver::setIssuer(const char* issuer)
         m_issuer = issuer;
 }
 
+void ShibbolethResolver::setProtocol(const XMLCh* protocol)
+{
+    m_protocol.erase();
+    if (protocol)
+        m_protocol = protocol;
+}
+
 void ShibbolethResolver::addToken(const XMLObject* token)
 {
     if (token)
@@ -281,6 +317,7 @@ void ShibbolethResolver::resolve()
         g_Remoted.resolve(
             *app,
             m_issuer.c_str(),
+            m_protocol.c_str(),
             m_tokens,
             m_inputAttributes,
             m_resolvedAttributes
@@ -293,6 +330,10 @@ void ShibbolethResolver::resolve()
         in.addmember("application_id").string(app->getId());
         if (!m_issuer.empty())
             in.addmember("issuer").string(m_issuer.c_str());
+        if (!m_protocol.empty()) {
+            auto_ptr_char prot(m_protocol.c_str());
+            in.addmember("protocol").string(prot.get());
+        }
         if (!m_tokens.empty()) {
             DDF& tokens = in.addmember("tokens").list();
             for (vector<const XMLObject*>::const_iterator t = m_tokens.begin(); t != m_tokens.end(); ++t) {
@@ -374,7 +415,9 @@ void RemotedResolver::receive(DDF& in, ostream& out)
         attr = alist.next();
     }
 
-    resolve(*app, in["issuer"].string(), t.tokens, t.inputAttrs, t.resolvedAttrs);
+    auto_ptr_XMLCh prot(in["protocol"].string());
+
+    resolve(*app, in["issuer"].string(), prot.get(), t.tokens, t.inputAttrs, t.resolvedAttrs);
 
     if (!t.resolvedAttrs.empty()) {
         ret.list();
@@ -390,67 +433,34 @@ void RemotedResolver::receive(DDF& in, ostream& out)
 void RemotedResolver::resolve(
     const Application& app,
     const char* issuer,
+    const XMLCh* protocol,
     const vector<const XMLObject*>& tokens,
     const vector<Attribute*>& inputAttrs,
-    vector <Attribute*>& resolvedAttrs
+    vector<Attribute*>& resolvedAttrs
     ) const
 {
 #ifndef SHIBSP_LITE
     Category& log = Category::getInstance(SHIBRESOLVER_LOGCAT);
-    string issuerstr(issuer ? issuer : "");
-    pair<const EntityDescriptor*,const RoleDescriptor*> entity = make_pair((EntityDescriptor*)NULL, (RoleDescriptor*)NULL);
     MetadataProvider* m = app.getMetadataProvider(false);
     Locker locker(m);
-    if (!m) {
-        log.warn("no metadata providers are configured");
-    }
-    else {
-        if (issuerstr.empty()) {
-            // Attempt to locate an issuer based on input token.
-            for (vector<const XMLObject*>::const_iterator t = tokens.begin(); t!=tokens.end(); ++t) {
-                const saml2::Assertion* assertion = dynamic_cast<const saml2::Assertion*>(*t);
-                if (assertion && assertion->getIssuer()) {
-                    auto_ptr_char iss(assertion->getIssuer()->getName());
-                    if (iss.get() && *iss.get()) {
-                        issuerstr = iss.get();
-                        break;
-                    }
-                }
-            }
-            if (!issuerstr.empty()) {
-                log.info("setting issuer based on input token (%s)", issuerstr.c_str());
-            }
-        }
 
-        if (!issuerstr.empty()) {
-            // Lookup metadata for the issuer.
-            MetadataProviderCriteria idpmc(app, issuerstr.c_str(), &IDPSSODescriptor::ELEMENT_QNAME, samlconstants::SAML20P_NS);
-            entity = m->getEntityDescriptor(idpmc);
-            if (!entity.first) {
-                log.warn("unable to locate metadata for provider (%s)", issuerstr.c_str());
-            }
-            else if (!entity.second) {
-                MetadataProviderCriteria aamc(app, issuerstr.c_str(), &AttributeAuthorityDescriptor::ELEMENT_QNAME, samlconstants::SAML20P_NS);
-                entity = m->getEntityDescriptor(aamc);
-                if (!entity.second) {
-                    log.warn("unable to locate SAML 2.0 IdP or AA role for provider (%s)", issuerstr.c_str());
-                }
-            }
-        }
-    }
+    const RoleDescriptor* role = NULL;
+    if (issuer && *issuer)
+        role = lookup(app, m, issuer, protocol);
 
     vector<const Assertion*> assertions;
 
     AttributeExtractor* extractor = app.getAttributeExtractor();
     if (extractor) {
         Locker extlocker(extractor);
-        if (entity.second) {
+        // Support metadata-based attributes for only the "top-level" issuer.
+        if (role) {
             pair<bool,const char*> mprefix = app.getString("metadataAttributePrefix");
             if (mprefix.first) {
                 log.debug("extracting metadata-derived attributes...");
                 try {
-                    // We pass NULL for "issuer" because the IdP isn't the one asserting metadata-based attributes.
-                    extractor->extractAttributes(app, NULL, *entity.second, resolvedAttrs);
+                    // We pass NULL for "issuer" because the issuer isn't the one asserting metadata-based attributes.
+                    extractor->extractAttributes(app, NULL, *role, resolvedAttrs);
                     for (vector<Attribute*>::iterator a = resolvedAttrs.begin(); a != resolvedAttrs.end(); ++a) {
                         vector<string>& ids = (*a)->getAliases();
                         for (vector<string>::iterator id = ids.begin(); id != ids.end(); ++id)
@@ -462,33 +472,23 @@ void RemotedResolver::resolve(
                 }
             }
         }
-        log.debug("extracting pushed attributes...");
-        for (vector<const XMLObject*>::const_iterator t = tokens.begin(); t!=tokens.end(); ++t) {
-            try {
-                // Save off any assertions for later use by resolver.
-                const Assertion* assertion = dynamic_cast<const Assertion*>(*t);
-                if (assertion)
-                    assertions.push_back(assertion);
-                extractor->extractAttributes(app, entity.second, *(*t), resolvedAttrs);
-            }
-            catch (exception& ex) {
-                log.error("caught exception extracting attributes: %s", ex.what());
-            }
-        }
 
-        AttributeFilter* filter = app.getAttributeFilter();
-        if (filter && !resolvedAttrs.empty()) {
-            BasicFilteringContext fc(app, resolvedAttrs, entity.second);
-            Locker filtlocker(filter);
-            try {
-                filter->filterAttributes(fc, resolvedAttrs);
-            }
-            catch (exception& ex) {
-                log.error("caught exception filtering attributes: %s", ex.what());
-                log.error("dumping extracted attributes due to filtering exception");
-                for_each(resolvedAttrs.begin(), resolvedAttrs.end(), xmltooling::cleanup<shibsp::Attribute>());
-                resolvedAttrs.clear();
+        log.debug("extracting pushed attributes...");
+        const RoleDescriptor* role2;
+        for (vector<const XMLObject*>::const_iterator t = tokens.begin(); t != tokens.end(); ++t) {
+            // Save off any assertions for later use by resolver.
+            role2 = NULL;
+            const Assertion* assertion = dynamic_cast<const Assertion*>(*t);
+            if (assertion) {
+                assertions.push_back(assertion);
+                const saml2::Assertion* saml2token = dynamic_cast<const saml2::Assertion*>(assertion);
+                if (saml2token && saml2token->getIssuer() && (saml2token->getIssuer()->getFormat() == NULL ||
+                        XMLString::equals(saml2token->getIssuer()->getFormat(), saml2::NameID::ENTITY))) {
+                    auto_ptr_char tokenissuer(saml2token->getIssuer()->getName());
+                    role2 = lookup(app, m, tokenissuer.get(), protocol);
+                }
             }
+            resolve(extractor, app, (role2 ? role2 : role), *(*t), resolvedAttrs);
         }
     }
     else {
@@ -498,7 +498,7 @@ void RemotedResolver::resolve(
     try {
         AttributeResolver* resolver = app.getAttributeResolver();
         if (resolver) {
-            log.debug("resolving attributes...");
+            log.debug("resolving additional attributes...");
 
             vector<Attribute*> inputs = inputAttrs;
             inputs.insert(inputs.end(), resolvedAttrs.begin(), resolvedAttrs.end());
@@ -507,8 +507,8 @@ void RemotedResolver::resolve(
             auto_ptr<ResolutionContext> ctx(
                 resolver->createResolutionContext(
                     app,
-                    entity.first,
-                    samlconstants::SAML20P_NS,
+                    role ? dynamic_cast<const EntityDescriptor*>(role->getParent()) : NULL,
+                    protocol ? protocol : samlconstants::SAML20P_NS,
                     NULL,
                     NULL,
                     NULL,
@@ -529,16 +529,95 @@ void RemotedResolver::resolve(
 #endif
 }
 
+#ifndef SHIBSP_LITE
+
+void RemotedResolver::resolve(
+    AttributeExtractor* extractor,
+    const Application& app,
+    const RoleDescriptor* issuer,
+    const XMLObject& token,
+    vector<Attribute*>& resolvedAttrs
+    ) const
+{
+    vector<Attribute*> extractedAttrs;
+    try {
+        extractor->extractAttributes(app, issuer, token, extractedAttrs);
+    }
+    catch (exception& ex) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT).error("caught exception extracting attributes: %s", ex.what());
+    }
+
+    AttributeFilter* filter = app.getAttributeFilter();
+    if (filter && !extractedAttrs.empty()) {
+        BasicFilteringContext fc(app, extractedAttrs, issuer);
+        Locker filtlocker(filter);
+        try {
+            filter->filterAttributes(fc, extractedAttrs);
+        }
+        catch (exception& ex) {
+            Category::getInstance(SHIBRESOLVER_LOGCAT).error("caught exception filtering attributes: %s", ex.what());
+            Category::getInstance(SHIBRESOLVER_LOGCAT).error("dumping extracted attributes due to filtering exception");
+            for_each(extractedAttrs.begin(), extractedAttrs.end(), xmltooling::cleanup<shibsp::Attribute>());
+            extractedAttrs.clear();
+        }
+    }
+
+    resolvedAttrs.insert(resolvedAttrs.end(), extractedAttrs.begin(), extractedAttrs.end());
+}
+
+const RoleDescriptor* RemotedResolver::lookup(
+    const Application& app, MetadataProvider* m, const char* entityID, const XMLCh* protocol
+    ) const
+{
+    if (!m)
+        return NULL;
+
+    MetadataProviderCriteria idpmc(app, entityID, &IDPSSODescriptor::ELEMENT_QNAME, protocol ? protocol : samlconstants::SAML20P_NS);
+    if (protocol)
+        idpmc.protocol2 = samlconstants::SAML20P_NS;
+    pair<const EntityDescriptor*,const RoleDescriptor*> entity = m->getEntityDescriptor(idpmc);
+    if (!entity.first) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT).warn("unable to locate metadata for provider (%s)", entityID);
+    }
+    else if (!entity.second) {
+        MetadataProviderCriteria aamc(
+            app, entityID, &AttributeAuthorityDescriptor::ELEMENT_QNAME, protocol ? protocol : samlconstants::SAML20P_NS
+            );
+        if (protocol)
+            aamc.protocol2 = samlconstants::SAML20P_NS;
+        entity = m->getEntityDescriptor(aamc);
+        if (!entity.second) {
+            Category::getInstance(SHIBRESOLVER_LOGCAT).warn("unable to locate compatible IdP or AA role for provider (%s)", entityID);
+        }
+    }
+
+    return entity.second;
+}
+
+#endif
+
 bool ShibbolethResolver::init(unsigned long features, const char* config, bool rethrow)
 {
-    if (features && SPConfig::OutOfProcess) {
+    Lock initLock(g_lock.get());
+
+    if (g_initCount == INT_MAX) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT".Config").crit("library initialized too many times");
+        return false;
+    }
+
+    if (g_initCount >= 1) {
+        ++g_initCount;
+        return true;
+    }
+
+    if (features & SPConfig::OutOfProcess) {
 #ifndef SHIBSP_LITE
         features = features | SPConfig::AttributeResolution | SPConfig::Metadata | SPConfig::Trust | SPConfig::Credentials;
 #endif
-        if (!(features && SPConfig::InProcess))
+        if (!(features & SPConfig::InProcess))
             features |= SPConfig::Listener;
     }
-    else if (features && SPConfig::InProcess) {
+    else if (features & SPConfig::InProcess) {
         features |= SPConfig::Listener;
     }
     SPConfig::getConfig().setFeatures(features);
@@ -546,16 +625,22 @@ bool ShibbolethResolver::init(unsigned long features, const char* config, bool r
         return false;
     if (!SPConfig::getConfig().instantiate(config, rethrow))
         return false;
+
+    ++g_initCount;
     return true;
 }
 
-/**
-    * Shuts down runtime.
-    *
-    * Each process using the library SHOULD call this function exactly once before terminating itself.
-    */
 void ShibbolethResolver::term()
 {
+    Lock initLock(g_lock.get());
+    if (g_initCount == 0) {
+        Category::getInstance(SHIBRESOLVER_LOGCAT".Config").crit("term without corresponding init");
+        return;
+    }
+    else if (--g_initCount > 0) {
+        return;
+    }
+
     SPConfig::getConfig().term();
 }