Backport key compare approach to certificate validation.
[shibboleth/cpp-sp.git] / shib / BasicTrust.cpp
index 3de204e..c588a42 100644 (file)
@@ -25,6 +25,8 @@
 #include "internal.h"
 
 #include <openssl/x509.h>
+#include <xsec/enc/OpenSSL/OpenSSLCryptoKeyDSA.hpp>
+#include <xsec/enc/OpenSSL/OpenSSLCryptoKeyRSA.hpp>
 #include <xsec/enc/OpenSSL/OpenSSLCryptoX509.hpp>
 
 using namespace shibboleth::logging;
@@ -104,7 +106,7 @@ bool BasicTrust::validate(void* certEE, const Iterator<void*>& certChain, const
     // The new "basic" trust implementation relies solely on certificates living within the
     // role interface to verify the EE certificate.
 
-    log.debug("comparing certificate to KeyDescriptors");
+    log.debug("comparing key inside certificate to KeyDescriptors");
     Iterator<const IKeyDescriptor*> kd_i=role->getKeyDescriptors();
     while (kd_i.hasNext()) {
         const IKeyDescriptor* kd=kd_i.next();
@@ -115,24 +117,56 @@ bool BasicTrust::validate(void* certEE, const Iterator<void*>& certChain, const
             continue;
         Iterator<KeyInfoResolver*> resolvers(m_resolvers);
         while (resolvers.hasNext()) {
-            XSECCryptoX509* cert=resolvers.next()->resolveCert(KIL);
-            if (cert) {
-                log.debug("KeyDescriptor resolved into a certificate, comparing it...");
-                if (cert->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
-                    log.warn("only the OpenSSL XSEC provider is supported");
+            XSECCryptoKey* key=((XSECKeyInfoResolver*)*resolvers.next())->resolveKey(KIL);
+            if (key) {
+                log.debug("KeyDescriptor resolved into a key, comparing it...");
+                if (key->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
+                    log.error("only the OpenSSL XSEC provider is supported");
                     continue;
                 }
-                else if (!X509_cmp(reinterpret_cast<X509*>(certEE),static_cast<OpenSSLCryptoX509*>(cert)->getOpenSSLX509())) {
-                    log.info("certificate match found in KeyDescriptor");
-                    return true;
+
+                switch (key->getKeyType()) {
+                    case XSECCryptoKey::KEY_RSA_PUBLIC:
+                    case XSECCryptoKey::KEY_RSA_PAIR:
+                    {
+                        RSA* rsa = static_cast<OpenSSLCryptoKeyRSA*>(key)->getOpenSSLRSA();
+                        EVP_PKEY* evp = X509_PUBKEY_get(X509_get_X509_PUBKEY(reinterpret_cast<X509*>(certEE)));
+                        if (rsa && evp && evp->type == EVP_PKEY_RSA &&
+                                BN_cmp(rsa->n,evp->pkey.rsa->n) == 0 && BN_cmp(rsa->e,evp->pkey.rsa->e) == 0) {
+                            if (evp)
+                                EVP_PKEY_free(evp);
+                            log.debug("matching key found in KeyDescriptor");
+                            return true;
+                        }
+                        if (evp)
+                            EVP_PKEY_free(evp);
+                        break;
+                    }
+                
+                    case XSECCryptoKey::KEY_DSA_PUBLIC:
+                    case XSECCryptoKey::KEY_DSA_PAIR:
+                    {
+                        DSA* dsa = static_cast<OpenSSLCryptoKeyDSA*>(key)->getOpenSSLDSA();
+                        EVP_PKEY* evp = X509_PUBKEY_get(X509_get_X509_PUBKEY(reinterpret_cast<X509*>(certEE)));
+                        if (dsa && evp && evp->type == EVP_PKEY_DSA && BN_cmp(dsa->pub_key,evp->pkey.dsa->pub_key) == 0) {
+                            if (evp)
+                                EVP_PKEY_free(evp);
+                            log.debug("matching key found in KeyDescriptor");
+                            return true;
+                        }
+                        if (evp)
+                            EVP_PKEY_free(evp);
+                        break;
+                    }
+
+                    default:
+                        log.warn("unknown key type in KeyDescriptor, skipping...");
                 }
-                else
-                    log.debug("certificate did not match");
             }
         }
     }
     
-    log.debug("failed to find an exact match for certificate in KeyDescriptors");
+    log.debug("failed to find a matching key for certificate in KeyDescriptors");
     return false;
 }