Merge branch '1.x' of ssh://authdev.it.ohio-state.edu/~scantor/git/cpp-xmltooling...
[shibboleth/cpp-xmltooling.git] / xmltooling / encryption / impl / Decrypter.cpp
index 7a8d553..49af2f3 100644 (file)
@@ -1,17 +1,21 @@
-/*
- *  Copyright 2001-2006 Internet2
- * 
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
+/**
+ * Licensed to the University Corporation for Advanced Internet
+ * Development, Inc. (UCAID) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for
+ * additional information regarding copyright ownership.
+ *
+ * UCAID licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"); you may not use this file except
+ * in compliance with the License. You may obtain a copy of the
+ * License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
  */
 
 /**
  */
 
 #include "internal.h"
+#include "logging.h"
 #include "encryption/Decrypter.h"
 #include "encryption/EncryptedKeyResolver.h"
+#include "encryption/Encryption.h"
+#include "security/Credential.h"
+#include "security/CredentialCriteria.h"
+#include "security/CredentialResolver.h"
 
-#include <log4cpp/Category.hh>
 #include <xsec/enc/XSECCryptoException.hpp>
 #include <xsec/framework/XSECException.hpp>
 #include <xsec/framework/XSECAlgorithmMapper.hpp>
 #include <xsec/framework/XSECAlgorithmHandler.hpp>
+#include <xsec/utils/XSECBinTXFMInputStream.hpp>
+#include <xsec/xenc/XENCCipher.hpp>
 #include <xsec/xenc/XENCEncryptedData.hpp>
 #include <xsec/xenc/XENCEncryptedKey.hpp>
 
 using namespace xmlencryption;
 using namespace xmlsignature;
 using namespace xmltooling;
+using namespace xercesc;
 using namespace std;
 
+Decrypter::Decrypter(const CredentialResolver* credResolver, CredentialCriteria* criteria, const EncryptedKeyResolver* EKResolver)
+    : m_cipher(nullptr), m_credResolver(credResolver), m_criteria(criteria), m_EKResolver(EKResolver)
+{
+}
+
 Decrypter::~Decrypter()
 {
     if (m_cipher)
         XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
-    delete m_resolver;
-    delete m_KEKresolver;
 }
 
-DOMDocumentFragment* Decrypter::decryptData(EncryptedData* encryptedData)
+void Decrypter::setEncryptedKeyResolver(const EncryptedKeyResolver* EKResolver)
+{
+    m_EKResolver=EKResolver;
+}
+
+void Decrypter::setKEKResolver(const CredentialResolver* resolver, CredentialCriteria* criteria)
+{
+    m_credResolver=resolver;
+    m_criteria=criteria;
+}
+
+DOMDocumentFragment* Decrypter::decryptData(const EncryptedData& encryptedData, XSECCryptoKey* key)
 {
-    if (encryptedData->getDOM()==NULL)
+    if (encryptedData.getDOM()==nullptr)
         throw DecryptionException("The object must be marshalled before decryption.");
-    
+
     // We can reuse the cipher object if the document hasn't changed.
 
-    if (m_cipher && m_cipher->getDocument()!=encryptedData->getDOM()->getOwnerDocument()) {
+    if (m_cipher && m_cipher->getDocument()!=encryptedData.getDOM()->getOwnerDocument()) {
         XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
-        m_cipher=NULL;
+        m_cipher=nullptr;
     }
     
     if (!m_cipher)
-        m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(encryptedData->getDOM()->getOwnerDocument());
-    
-    try {
-        // Resolve decryption key.
-        XSECCryptoKey* key=NULL;
-        if (m_resolver)
-            key=m_resolver->resolveKey(encryptedData->getKeyInfo());
-
-        if (!key && m_KEKresolver) {
-            // See if there's an encrypted key available. We'll need the algorithm...
-            const XMLCh* algorithm=
-                encryptedData->getEncryptionMethod() ? encryptedData->getEncryptionMethod()->getAlgorithm() : NULL;
-            if (!algorithm)
-                throw DecryptionException("No EncryptionMethod/@Algorithm set, key decryption cannot proceed.");
-            
-            if (encryptedData->getKeyInfo()) {
-                const vector<XMLObject*>& others=const_cast<const KeyInfo*>(encryptedData->getKeyInfo())->getUnknownXMLObjects();
-                for (vector<XMLObject*>::const_iterator i=others.begin(); i!=others.end(); i++) {
-                    EncryptedKey* encKey=dynamic_cast<EncryptedKey*>(*i);
-                    if (encKey) {
-                        try {
-                            key=decryptKey(encKey, algorithm);
-                        }
-                        catch (DecryptionException& e) {
-                            log4cpp::Category::getInstance(XMLTOOLING_LOGCAT".Decrypter").warn(e.what());
-                        }
-                    }
-                }
-            }
-            
-            if (!key) {
-                // Check for a non-trivial resolver.
-                EncryptedKeyResolver* ekr=dynamic_cast<EncryptedKeyResolver*>(m_resolver);
-                if (ekr) {
-                    EncryptedKey* encKey=ekr->resolveKey(encryptedData);
-                    if (encKey) {
-                        try {
-                            key=decryptKey(encKey, algorithm);
-                        }
-                        catch (DecryptionException& e) {
-                            log4cpp::Category::getInstance(XMLTOOLING_LOGCAT".Decrypter").warn(e.what());
-                        }
-                    }
-                }
-            }
-        }
+        m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(encryptedData.getDOM()->getOwnerDocument());
 
-        if (!key)
-            throw DecryptionException("Unable to resolve a decryption key.");
-        
-        m_cipher->setKey(key);
-        DOMNode* ret=m_cipher->decryptElementDetached(encryptedData->getDOM());
+    try {
+        m_cipher->setKey(key->clone());
+        DOMNode* ret=m_cipher->decryptElementDetached(encryptedData.getDOM());
         if (ret->getNodeType()!=DOMNode::DOCUMENT_FRAGMENT_NODE) {
             ret->release();
             throw DecryptionException("Decryption operation did not result in DocumentFragment.");
@@ -125,40 +103,93 @@ DOMDocumentFragment* Decrypter::decryptData(EncryptedData* encryptedData)
     }
 }
 
-XSECCryptoKey* Decrypter::decryptKey(EncryptedKey* encryptedKey, const XMLCh* algorithm)
+DOMDocumentFragment* Decrypter::decryptData(const EncryptedData& encryptedData, const XMLCh* recipient)
 {
-    if (encryptedKey->getDOM()==NULL)
-        throw DecryptionException("The object must be marshalled before decryption.");
+    if (!m_credResolver)
+        throw DecryptionException("No CredentialResolver supplied to provide decryption keys.");
+
+    // Resolve a decryption key directly.
+    vector<const Credential*> creds;
+    int types = CredentialCriteria::KEYINFO_EXTRACTION_KEY | CredentialCriteria::KEYINFO_EXTRACTION_KEYNAMES;
+    if (m_criteria) {
+        m_criteria->setUsage(Credential::ENCRYPTION_CREDENTIAL);
+        m_criteria->setKeyInfo(encryptedData.getKeyInfo(), types);
+        const EncryptionMethod* meth = encryptedData.getEncryptionMethod();
+        if (meth)
+            m_criteria->setXMLAlgorithm(meth->getAlgorithm());
+        m_credResolver->resolve(creds,m_criteria);
+    }
+    else {
+        CredentialCriteria criteria;
+        criteria.setUsage(Credential::ENCRYPTION_CREDENTIAL);
+        criteria.setKeyInfo(encryptedData.getKeyInfo(), types);
+        const EncryptionMethod* meth = encryptedData.getEncryptionMethod();
+        if (meth)
+            criteria.setXMLAlgorithm(meth->getAlgorithm());
+        m_credResolver->resolve(creds,&criteria);
+    }
+
+    // Loop over them and try each one.
+    XSECCryptoKey* key;
+    for (vector<const Credential*>::const_iterator cred = creds.begin(); cred!=creds.end(); ++cred) {
+        try {
+            key = (*cred)->getPrivateKey();
+            if (!key)
+                continue;
+            return decryptData(encryptedData, key);
+        }
+        catch(DecryptionException& ex) {
+            logging::Category::getInstance(XMLTOOLING_LOGCAT".Decrypter").warn(ex.what());
+        }
+    }
+
+    // We need to find an encrypted decryption key somewhere. We'll need the underlying algorithm...
+    const XMLCh* algorithm=
+        encryptedData.getEncryptionMethod() ? encryptedData.getEncryptionMethod()->getAlgorithm() : nullptr;
+    if (!algorithm)
+        throw DecryptionException("No EncryptionMethod/@Algorithm set, key decryption cannot proceed.");
     
+    // Check for external resolver.
+    const EncryptedKey* encKey=nullptr;
+    if (m_EKResolver)
+        encKey = m_EKResolver->resolveKey(encryptedData, recipient);
+    else {
+        EncryptedKeyResolver ekr;
+        encKey = ekr.resolveKey(encryptedData, recipient);
+    }
+
+    if (!encKey)
+        throw DecryptionException("Unable to locate an encrypted key.");
+
+    auto_ptr<XSECCryptoKey> keywrapper(decryptKey(*encKey, algorithm));
+    if (!keywrapper.get())
+        throw DecryptionException("Unable to decrypt the encrypted key.");
+    return decryptData(encryptedData, keywrapper.get());
+}
+
+void Decrypter::decryptData(ostream& out, const EncryptedData& encryptedData, XSECCryptoKey* key)
+{
+    if (encryptedData.getDOM()==nullptr)
+        throw DecryptionException("The object must be marshalled before decryption.");
+
     // We can reuse the cipher object if the document hasn't changed.
 
-    if (m_cipher && m_cipher->getDocument()!=encryptedKey->getDOM()->getOwnerDocument()) {
+    if (m_cipher && m_cipher->getDocument()!=encryptedData.getDOM()->getOwnerDocument()) {
         XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
-        m_cipher=NULL;
+        m_cipher=nullptr;
     }
     
     if (!m_cipher)
-        m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(encryptedKey->getDOM()->getOwnerDocument());
-    
+        m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(encryptedData.getDOM()->getOwnerDocument());
+
     try {
-        // Resolve key decryption key.
-        XSECCryptoKey* key=NULL;
-        if (m_KEKresolver)
-            key=m_KEKresolver->resolveKey(encryptedKey->getKeyInfo());
-        if (!key)
-            throw DecryptionException("Unable to resolve a key decryption key.");
-        m_cipher->setKEK(key);
+        m_cipher->setKey(key->clone());
+        auto_ptr<XSECBinTXFMInputStream> in(m_cipher->decryptToBinInputStream(encryptedData.getDOM()));
         
-        XMLByte buffer[1024];
-        int keySize = m_cipher->decryptKey(encryptedKey->getDOM(), buffer, 1024);
-        if (keySize > 0) {
-            // Try to map the key.
-            XSECAlgorithmHandler* handler = XSECPlatformUtils::g_algorithmMapper->mapURIToHandler(algorithm);
-            if (handler != NULL)
-                return handler->createKeyForURI(algorithm, buffer, keySize);
-            throw DecryptionException("Unrecognized algorithm, could not build object around decrypted key.");
-        }
-        throw DecryptionException("Unable to decrypt key.");
+        XMLByte buf[8192];
+        xsecsize_t count = in->readBytes(buf, sizeof(buf));
+        while (count > 0)
+            out.write(reinterpret_cast<char*>(buf),count);
     }
     catch(XSECException& e) {
         auto_ptr_char temp(e.getMsg());
@@ -168,3 +199,154 @@ XSECCryptoKey* Decrypter::decryptKey(EncryptedKey* encryptedKey, const XMLCh* al
         throw DecryptionException(string("XMLSecurity exception while decrypting: ") + e.getMsg());
     }
 }
+
+void Decrypter::decryptData(ostream& out, const EncryptedData& encryptedData, const XMLCh* recipient)
+{
+    if (!m_credResolver)
+        throw DecryptionException("No CredentialResolver supplied to provide decryption keys.");
+
+    // Resolve a decryption key directly.
+    vector<const Credential*> creds;
+    int types = CredentialCriteria::KEYINFO_EXTRACTION_KEY | CredentialCriteria::KEYINFO_EXTRACTION_KEYNAMES;
+    if (m_criteria) {
+        m_criteria->setUsage(Credential::ENCRYPTION_CREDENTIAL);
+        m_criteria->setKeyInfo(encryptedData.getKeyInfo(), types);
+        const EncryptionMethod* meth = encryptedData.getEncryptionMethod();
+        if (meth)
+            m_criteria->setXMLAlgorithm(meth->getAlgorithm());
+        m_credResolver->resolve(creds,m_criteria);
+    }
+    else {
+        CredentialCriteria criteria;
+        criteria.setUsage(Credential::ENCRYPTION_CREDENTIAL);
+        criteria.setKeyInfo(encryptedData.getKeyInfo(), types);
+        const EncryptionMethod* meth = encryptedData.getEncryptionMethod();
+        if (meth)
+            criteria.setXMLAlgorithm(meth->getAlgorithm());
+        m_credResolver->resolve(creds,&criteria);
+    }
+
+    // Loop over them and try each one.
+    XSECCryptoKey* key;
+    for (vector<const Credential*>::const_iterator cred = creds.begin(); cred!=creds.end(); ++cred) {
+        try {
+            key = (*cred)->getPrivateKey();
+            if (!key)
+                continue;
+            return decryptData(out, encryptedData, key);
+        }
+        catch(DecryptionException& ex) {
+            logging::Category::getInstance(XMLTOOLING_LOGCAT".Decrypter").warn(ex.what());
+        }
+    }
+
+    // We need to find an encrypted decryption key somewhere. We'll need the underlying algorithm...
+    const XMLCh* algorithm=
+        encryptedData.getEncryptionMethod() ? encryptedData.getEncryptionMethod()->getAlgorithm() : nullptr;
+    if (!algorithm)
+        throw DecryptionException("No EncryptionMethod/@Algorithm set, key decryption cannot proceed.");
+    
+    // Check for external resolver.
+    const EncryptedKey* encKey=nullptr;
+    if (m_EKResolver)
+        encKey = m_EKResolver->resolveKey(encryptedData, recipient);
+    else {
+        EncryptedKeyResolver ekr;
+        encKey = ekr.resolveKey(encryptedData, recipient);
+    }
+
+    if (!encKey)
+        throw DecryptionException("Unable to locate an encrypted key.");
+
+    auto_ptr<XSECCryptoKey> keywrapper(decryptKey(*encKey, algorithm));
+    if (!keywrapper.get())
+        throw DecryptionException("Unable to decrypt the encrypted key.");
+    decryptData(out, encryptedData, keywrapper.get());
+}
+
+XSECCryptoKey* Decrypter::decryptKey(const EncryptedKey& encryptedKey, const XMLCh* algorithm)
+{
+    if (!m_credResolver)
+        throw DecryptionException("No CredentialResolver supplied to provide decryption keys.");
+
+    if (encryptedKey.getDOM()==nullptr)
+        throw DecryptionException("The object must be marshalled before decryption.");
+
+    XSECAlgorithmHandler* handler;
+    try {
+        handler = XSECPlatformUtils::g_algorithmMapper->mapURIToHandler(algorithm);
+        if (!handler)
+            throw DecryptionException("Unrecognized algorithm, no way to build object around decrypted key.");
+    }
+    catch(XSECException& e) {
+        auto_ptr_char temp(e.getMsg());
+        throw DecryptionException(string("XMLSecurity exception while decrypting key: ") + temp.get());
+    }
+    catch(XSECCryptoException& e) {
+        throw DecryptionException(string("XMLSecurity exception while decrypting key: ") + e.getMsg());
+    }
+    
+    // We can reuse the cipher object if the document hasn't changed.
+
+    if (m_cipher && m_cipher->getDocument()!=encryptedKey.getDOM()->getOwnerDocument()) {
+        XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
+        m_cipher=nullptr;
+    }
+    
+    if (!m_cipher)
+        m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(encryptedKey.getDOM()->getOwnerDocument());
+    
+    // Resolve key decryption keys.
+    int types = CredentialCriteria::KEYINFO_EXTRACTION_KEY | CredentialCriteria::KEYINFO_EXTRACTION_KEYNAMES;
+    vector<const Credential*> creds;
+    if (m_criteria) {
+        m_criteria->setUsage(Credential::ENCRYPTION_CREDENTIAL);
+        m_criteria->setKeyInfo(encryptedKey.getKeyInfo(), types);
+        const EncryptionMethod* meth = encryptedKey.getEncryptionMethod();
+        if (meth)
+            m_criteria->setXMLAlgorithm(meth->getAlgorithm());
+        m_credResolver->resolve(creds, m_criteria);
+    }
+    else {
+        CredentialCriteria criteria;
+        criteria.setUsage(Credential::ENCRYPTION_CREDENTIAL);
+        criteria.setKeyInfo(encryptedKey.getKeyInfo(), types);
+        const EncryptionMethod* meth = encryptedKey.getEncryptionMethod();
+        if (meth)
+            criteria.setXMLAlgorithm(meth->getAlgorithm());
+        m_credResolver->resolve(creds, &criteria);
+    }
+    if (creds.empty())
+        throw DecryptionException("Unable to resolve any key decryption keys.");
+
+    XMLByte buffer[1024];
+    for (vector<const Credential*>::const_iterator cred = creds.begin(); cred!=creds.end(); ++cred) {
+        try {
+            if (!(*cred)->getPrivateKey())
+                throw DecryptionException("Credential did not contain a private key.");
+            memset(buffer,0,sizeof(buffer));
+            m_cipher->setKEK((*cred)->getPrivateKey()->clone());
+
+            try {
+                int keySize = m_cipher->decryptKey(encryptedKey.getDOM(), buffer, 1024);
+                if (keySize<=0)
+                    throw DecryptionException("Unable to decrypt key.");
+        
+                // Try to wrap the key.
+                return handler->createKeyForURI(algorithm, buffer, keySize);
+            }
+            catch(XSECException& e) {
+                auto_ptr_char temp(e.getMsg());
+                throw DecryptionException(string("XMLSecurity exception while decrypting key: ") + temp.get());
+            }
+            catch(XSECCryptoException& e) {
+                throw DecryptionException(string("XMLSecurity exception while decrypting key: ") + e.getMsg());
+            }
+        }
+        catch(DecryptionException& ex) {
+            logging::Category::getInstance(XMLTOOLING_LOGCAT".Decrypter").warn(ex.what());
+        }
+    }
+    
+    throw DecryptionException("Unable to decrypt key.");
+}