Revert to string-based return values, add linefeed option.
[shibboleth/cpp-xmltooling.git] / xmltooling / security / impl / SecurityHelper.cpp
index 0ee81ab..6370079 100644 (file)
@@ -24,6 +24,7 @@
 #include "logging.h"
 #include "security/OpenSSLCryptoX509CRL.h"
 #include "security/SecurityHelper.h"
+#include "security/X509Credential.h"
 #include "util/NDC.h"
 
 #include <fstream>
@@ -435,67 +436,72 @@ vector<XSECCryptoX509CRL*>::size_type SecurityHelper::loadCRLsFromURL(
     return loadCRLsFromFile(crls, backing, format);
 }
 
-bool SecurityHelper::matches(const XSECCryptoKey* key1, const XSECCryptoKey* key2)
+bool SecurityHelper::matches(const XSECCryptoKey& key1, const XSECCryptoKey& key2)
 {
-    if (key1->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL ||
-        key2->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
+    if (key1.getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL ||
+        key2.getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
         Category::getInstance(XMLTOOLING_LOGCAT".SecurityHelper").warn("comparison of non-OpenSSL keys not supported");
         return false;
     }
 
     // If one key is public or both, just compare the public key half.
-    if (key1->getKeyType()==XSECCryptoKey::KEY_RSA_PUBLIC || key1->getKeyType()==XSECCryptoKey::KEY_RSA_PAIR) {
-        if (key2->getKeyType()!=XSECCryptoKey::KEY_RSA_PUBLIC && key2->getKeyType()!=XSECCryptoKey::KEY_RSA_PAIR)
+    if (key1.getKeyType()==XSECCryptoKey::KEY_RSA_PUBLIC || key1.getKeyType()==XSECCryptoKey::KEY_RSA_PAIR) {
+        if (key2.getKeyType()!=XSECCryptoKey::KEY_RSA_PUBLIC && key2.getKeyType()!=XSECCryptoKey::KEY_RSA_PAIR)
             return false;
-        const RSA* rsa1 = static_cast<const OpenSSLCryptoKeyRSA*>(key1)->getOpenSSLRSA();
-        const RSA* rsa2 = static_cast<const OpenSSLCryptoKeyRSA*>(key2)->getOpenSSLRSA();
-        return (BN_cmp(rsa1->n,rsa2->n) == 0 && BN_cmp(rsa1->e,rsa2->e) == 0);
+        const RSA* rsa1 = static_cast<const OpenSSLCryptoKeyRSA&>(key1).getOpenSSLRSA();
+        const RSA* rsa2 = static_cast<const OpenSSLCryptoKeyRSA&>(key2).getOpenSSLRSA();
+        return (rsa1 && rsa2 && BN_cmp(rsa1->n,rsa2->n) == 0 && BN_cmp(rsa1->e,rsa2->e) == 0);
     }
 
     // For a private key, compare the private half.
-    if (key1->getKeyType()==XSECCryptoKey::KEY_RSA_PRIVATE) {
-        if (key2->getKeyType()!=XSECCryptoKey::KEY_RSA_PRIVATE && key2->getKeyType()!=XSECCryptoKey::KEY_RSA_PAIR)
+    if (key1.getKeyType()==XSECCryptoKey::KEY_RSA_PRIVATE) {
+        if (key2.getKeyType()!=XSECCryptoKey::KEY_RSA_PRIVATE && key2.getKeyType()!=XSECCryptoKey::KEY_RSA_PAIR)
             return false;
-        const RSA* rsa1 = static_cast<const OpenSSLCryptoKeyRSA*>(key1)->getOpenSSLRSA();
-        const RSA* rsa2 = static_cast<const OpenSSLCryptoKeyRSA*>(key2)->getOpenSSLRSA();
-        return (BN_cmp(rsa1->n,rsa2->n) == 0 && BN_cmp(rsa1->d,rsa2->d) == 0);
+        const RSA* rsa1 = static_cast<const OpenSSLCryptoKeyRSA&>(key1).getOpenSSLRSA();
+        const RSA* rsa2 = static_cast<const OpenSSLCryptoKeyRSA&>(key2).getOpenSSLRSA();
+        return (rsa1 && rsa2 && BN_cmp(rsa1->n,rsa2->n) == 0 && BN_cmp(rsa1->d,rsa2->d) == 0);
     }
 
     // If one key is public or both, just compare the public key half.
-    if (key1->getKeyType()==XSECCryptoKey::KEY_DSA_PUBLIC || key1->getKeyType()==XSECCryptoKey::KEY_DSA_PAIR) {
-        if (key2->getKeyType()!=XSECCryptoKey::KEY_DSA_PUBLIC && key2->getKeyType()!=XSECCryptoKey::KEY_DSA_PAIR)
+    if (key1.getKeyType()==XSECCryptoKey::KEY_DSA_PUBLIC || key1.getKeyType()==XSECCryptoKey::KEY_DSA_PAIR) {
+        if (key2.getKeyType()!=XSECCryptoKey::KEY_DSA_PUBLIC && key2.getKeyType()!=XSECCryptoKey::KEY_DSA_PAIR)
             return false;
-        const DSA* dsa1 = static_cast<const OpenSSLCryptoKeyDSA*>(key1)->getOpenSSLDSA();
-        const DSA* dsa2 = static_cast<const OpenSSLCryptoKeyDSA*>(key2)->getOpenSSLDSA();
-        return (BN_cmp(dsa1->pub_key,dsa2->pub_key) == 0);
+        const DSA* dsa1 = static_cast<const OpenSSLCryptoKeyDSA&>(key1).getOpenSSLDSA();
+        const DSA* dsa2 = static_cast<const OpenSSLCryptoKeyDSA&>(key2).getOpenSSLDSA();
+        return (dsa1 && dsa2 && BN_cmp(dsa1->pub_key,dsa2->pub_key) == 0);
     }
 
     // For a private key, compare the private half.
-    if (key1->getKeyType()==XSECCryptoKey::KEY_DSA_PRIVATE) {
-        if (key2->getKeyType()!=XSECCryptoKey::KEY_DSA_PRIVATE && key2->getKeyType()!=XSECCryptoKey::KEY_DSA_PAIR)
+    if (key1.getKeyType()==XSECCryptoKey::KEY_DSA_PRIVATE) {
+        if (key2.getKeyType()!=XSECCryptoKey::KEY_DSA_PRIVATE && key2.getKeyType()!=XSECCryptoKey::KEY_DSA_PAIR)
             return false;
-        const DSA* dsa1 = static_cast<const OpenSSLCryptoKeyDSA*>(key1)->getOpenSSLDSA();
-        const DSA* dsa2 = static_cast<const OpenSSLCryptoKeyDSA*>(key2)->getOpenSSLDSA();
-        return (BN_cmp(dsa1->priv_key,dsa2->priv_key) == 0);
+        const DSA* dsa1 = static_cast<const OpenSSLCryptoKeyDSA&>(key1).getOpenSSLDSA();
+        const DSA* dsa2 = static_cast<const OpenSSLCryptoKeyDSA&>(key2).getOpenSSLDSA();
+        return (dsa1 && dsa2 && BN_cmp(dsa1->priv_key,dsa2->priv_key) == 0);
     }
 
     Category::getInstance(XMLTOOLING_LOGCAT".SecurityHelper").warn("unsupported key type for comparison");
     return false;
 }
 
-string SecurityHelper::getDEREncoding(const XSECCryptoKey* key)
+string SecurityHelper::getDEREncoding(const XSECCryptoKey& key, bool nowrap)
 {
     string ret;
 
-    if (key->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
+    if (key.getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
         Category::getInstance(XMLTOOLING_LOGCAT".SecurityHelper").warn("encoding of non-OpenSSL keys not supported");
         return ret;
     }
 
-    if (key->getKeyType() == XSECCryptoKey::KEY_RSA_PUBLIC || key->getKeyType() == XSECCryptoKey::KEY_RSA_PAIR) {
-        const RSA* rsa = static_cast<const OpenSSLCryptoKeyRSA*>(key)->getOpenSSLRSA();
+    if (key.getKeyType() == XSECCryptoKey::KEY_RSA_PUBLIC || key.getKeyType() == XSECCryptoKey::KEY_RSA_PAIR) {
+        const RSA* rsa = static_cast<const OpenSSLCryptoKeyRSA&>(key).getOpenSSLRSA();
+        if (!rsa) {
+            Category::getInstance(XMLTOOLING_LOGCAT".SecurityHelper").warn("key was not populated");
+            return ret;
+        }
         BIO* base64 = BIO_new(BIO_f_base64());
-        BIO_set_flags(base64, BIO_FLAGS_BASE64_NO_NL);
+        if (nowrap)
+            BIO_set_flags(base64, BIO_FLAGS_BASE64_NO_NL);
         BIO* mem = BIO_new(BIO_s_mem());
         BIO_push(base64, mem);
         i2d_RSA_PUBKEY_bio(base64, const_cast<RSA*>(rsa));
@@ -506,10 +512,15 @@ string SecurityHelper::getDEREncoding(const XSECCryptoKey* key)
             ret.append(bptr->data, bptr->length);
         BIO_free_all(base64);
     }
-    else if (key->getKeyType() == XSECCryptoKey::KEY_DSA_PUBLIC || key->getKeyType() == XSECCryptoKey::KEY_DSA_PAIR) {
-        const DSA* dsa = static_cast<const OpenSSLCryptoKeyDSA*>(key)->getOpenSSLDSA();
+    else if (key.getKeyType() == XSECCryptoKey::KEY_DSA_PUBLIC || key.getKeyType() == XSECCryptoKey::KEY_DSA_PAIR) {
+        const DSA* dsa = static_cast<const OpenSSLCryptoKeyDSA&>(key).getOpenSSLDSA();
+        if (!dsa) {
+            Category::getInstance(XMLTOOLING_LOGCAT".SecurityHelper").warn("key was not populated");
+            return ret;
+        }
         BIO* base64 = BIO_new(BIO_f_base64());
-        BIO_set_flags(base64, BIO_FLAGS_BASE64_NO_NL);
+        if (nowrap)
+            BIO_set_flags(base64, BIO_FLAGS_BASE64_NO_NL);
         BIO* mem = BIO_new(BIO_s_mem());
         BIO_push(base64, mem);
         i2d_DSA_PUBKEY_bio(base64, const_cast<DSA*>(dsa));
@@ -526,20 +537,21 @@ string SecurityHelper::getDEREncoding(const XSECCryptoKey* key)
     return ret;
 }
 
-string SecurityHelper::getDEREncoding(const XSECCryptoX509* cert)
+string SecurityHelper::getDEREncoding(const XSECCryptoX509& cert, bool nowrap)
 {
     string ret;
 
-    if (cert->getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
+    if (cert.getProviderName()!=DSIGConstants::s_unicodeStrPROVOpenSSL) {
         Category::getInstance(XMLTOOLING_LOGCAT".SecurityHelper").warn("encoding of non-OpenSSL keys not supported");
         return ret;
     }
 
-    const X509* x = static_cast<const OpenSSLCryptoX509*>(cert)->getOpenSSLX509();
+    const X509* x = static_cast<const OpenSSLCryptoX509&>(cert).getOpenSSLX509();
     EVP_PKEY* key = X509_get_pubkey(const_cast<X509*>(x));
 
     BIO* base64 = BIO_new(BIO_f_base64());
-    BIO_set_flags(base64, BIO_FLAGS_BASE64_NO_NL);
+    if (nowrap)
+        BIO_set_flags(base64, BIO_FLAGS_BASE64_NO_NL);
     BIO* mem = BIO_new(BIO_s_mem());
     BIO_push(base64, mem);
     i2d_PUBKEY_bio(base64, key);
@@ -552,3 +564,13 @@ string SecurityHelper::getDEREncoding(const XSECCryptoX509* cert)
     BIO_free_all(base64);
     return ret;
 }
+
+string SecurityHelper::getDEREncoding(const Credential& cred, bool nowrap)
+{
+    const X509Credential* x509 = dynamic_cast<const X509Credential*>(&cred);
+    if (x509 && !x509->getEntityCertificateChain().empty())
+        return getDEREncoding(*(x509->getEntityCertificateChain().front()), nowrap);
+    else if (cred.getPublicKey())
+        return getDEREncoding(*(cred.getPublicKey()), nowrap);
+    return "";
+}