Merge branch '1.x' of ssh://authdev.it.ohio-state.edu/~scantor/git/cpp-xmltooling...
[shibboleth/cpp-xmltooling.git] / xmltooling / encryption / impl / Encrypter.cpp
index 4f5b354..2e7cbe7 100644 (file)
@@ -1,17 +1,21 @@
-/*
- *  Copyright 2001-2006 Internet2
- * 
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
+/**
+ * Licensed to the University Corporation for Advanced Internet
+ * Development, Inc. (UCAID) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for
+ * additional information regarding copyright ownership.
+ *
+ * UCAID licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"); you may not use this file except
+ * in compliance with the License. You may obtain a copy of the
+ * License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
  */
 
 /**
 
 #include "internal.h"
 #include "encryption/Encrypter.h"
+#include "encryption/Encryption.h"
+#include "security/Credential.h"
+#include "signature/KeyInfo.h"
 
-#include <xsec/enc/openssl/OpenSSLCryptoSymmetricKey.hpp>
 #include <xsec/enc/XSECCryptoException.hpp>
 #include <xsec/framework/XSECException.hpp>
+#include <xsec/framework/XSECAlgorithmMapper.hpp>
+#include <xsec/framework/XSECAlgorithmHandler.hpp>
+#include <xsec/xenc/XENCCipher.hpp>
 #include <xsec/xenc/XENCEncryptedData.hpp>
 #include <xsec/xenc/XENCEncryptedKey.hpp>
 
 using namespace xmlencryption;
 using namespace xmlsignature;
 using namespace xmltooling;
+using namespace xercesc;
 using namespace std;
 
+Encrypter::EncryptionParams::EncryptionParams(
+    const XMLCh* algorithm, const unsigned char* keyBuffer, unsigned int keyBufferSize, const Credential* credential, bool compact
+    ) : m_algorithm(algorithm), m_keyBuffer(keyBuffer), m_keyBufferSize(keyBufferSize), m_credential(credential), m_compact(compact)
+{
+}
+
+Encrypter::EncryptionParams::~EncryptionParams()
+{
+}
+
+Encrypter::KeyEncryptionParams::KeyEncryptionParams(const Credential& credential, const XMLCh* algorithm, const XMLCh* recipient)
+    : m_credential(credential), m_algorithm(algorithm), m_recipient(recipient)
+{
+}
+
+Encrypter::KeyEncryptionParams::~KeyEncryptionParams()
+{
+}
+
+Encrypter::Encrypter() : m_cipher(nullptr)
+{
+}
+
 Encrypter::~Encrypter()
 {
     XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
@@ -43,11 +76,11 @@ Encrypter::~Encrypter()
 void Encrypter::checkParams(EncryptionParams& encParams, KeyEncryptionParams* kencParams)
 {
     if (encParams.m_keyBufferSize==0) {
-        if (encParams.m_key) {
+        if (encParams.m_credential) {
             if (kencParams)
                 throw EncryptionException("Generating EncryptedKey inline requires the encryption key in raw form.");
         }
-        else if (!encParams.m_key) {
+        else {
             if (!kencParams)
                 throw EncryptionException("Using a generated encryption key requires a KeyEncryptionParams object.");
 
@@ -59,28 +92,42 @@ void Encrypter::checkParams(EncryptionParams& encParams, KeyEncryptionParams* ke
         }
     }
     
-    if (!encParams.m_key) {
+    XSECCryptoKey* key=nullptr;
+    if (encParams.m_credential) {
+        key = encParams.m_credential->getPrivateKey();
+        if (!key)
+            throw EncryptionException("Credential in EncryptionParams structure did not supply a private/secret key.");
+        // Set the encryption key.
+        m_cipher->setKey(key->clone());
+    }
+    else {
         // We have to have a raw key now, so we need to build a wrapper around it.
-        if (XMLString::equals(encParams.m_algorithm,DSIGConstants::s_unicodeStrURI3DES_CBC)) {
-            encParams.m_key=new OpenSSLCryptoSymmetricKey(XSECCryptoSymmetricKey::KEY_3DES_192);
-        }
-        else if (XMLString::equals(encParams.m_algorithm,DSIGConstants::s_unicodeStrURIAES128_CBC)) {
-            encParams.m_key=new OpenSSLCryptoSymmetricKey(XSECCryptoSymmetricKey::KEY_AES_128);
-        }
-        else if (XMLString::equals(encParams.m_algorithm,DSIGConstants::s_unicodeStrURIAES192_CBC)) {
-            encParams.m_key=new OpenSSLCryptoSymmetricKey(XSECCryptoSymmetricKey::KEY_AES_192);
-        }
-        else if (XMLString::equals(encParams.m_algorithm,DSIGConstants::s_unicodeStrURIAES256_CBC)) {
-            encParams.m_key=new OpenSSLCryptoSymmetricKey(XSECCryptoSymmetricKey::KEY_AES_256);
-        }
-        else {
-            throw EncryptionException("Unrecognized encryption algorithm, unable to build key wrapper.");
+        XSECAlgorithmHandler* handler =XSECPlatformUtils::g_algorithmMapper->mapURIToHandler(encParams.m_algorithm);
+        if (handler != nullptr)
+            key = handler->createKeyForURI(
+                encParams.m_algorithm,const_cast<unsigned char*>(encParams.m_keyBuffer),encParams.m_keyBufferSize
+                );
+
+        if (!key)
+            throw EncryptionException("Unable to build wrapper for key, unknown algorithm?");
+        // Overwrite the length if known.
+        switch (static_cast<XSECCryptoSymmetricKey*>(key)->getSymmetricKeyType()) {
+            case XSECCryptoSymmetricKey::KEY_3DES_192:
+                encParams.m_keyBufferSize = 192/8;
+                break;
+            case XSECCryptoSymmetricKey::KEY_AES_128:
+                encParams.m_keyBufferSize = 128/8;
+                break;
+            case XSECCryptoSymmetricKey::KEY_AES_192:
+                encParams.m_keyBufferSize = 192/8;
+                break;
+            case XSECCryptoSymmetricKey::KEY_AES_256:
+                encParams.m_keyBufferSize = 256/8;
+                break;
         }
-        static_cast<OpenSSLCryptoSymmetricKey*>(encParams.m_key)->setKey(encParams.m_keyBuffer, encParams.m_keyBufferSize);
+        // Set the encryption key.
+        m_cipher->setKey(key);
     }
-    
-    // Set the encryption key.
-    m_cipher->setKey(encParams.m_key->clone());
 }
 
 EncryptedData* Encrypter::encryptElement(DOMElement* element, EncryptionParams& encParams, KeyEncryptionParams* kencParams)
@@ -89,11 +136,13 @@ EncryptedData* Encrypter::encryptElement(DOMElement* element, EncryptionParams&
     
     if (m_cipher && m_cipher->getDocument()!=element->getOwnerDocument()) {
         XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
-        m_cipher=NULL;
+        m_cipher=nullptr;
     }
     
-    if (!m_cipher)
+    if (!m_cipher) {
         m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(element->getOwnerDocument());
+        m_cipher->setExclusiveC14nSerialisation(false);
+    }
     
     try {
         checkParams(encParams,kencParams);
@@ -115,11 +164,13 @@ EncryptedData* Encrypter::encryptElementContent(DOMElement* element, EncryptionP
 
     if (m_cipher && m_cipher->getDocument()!=element->getOwnerDocument()) {
         XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
-        m_cipher=NULL;
+        m_cipher=nullptr;
     }
     
-    if (!m_cipher)
+    if (!m_cipher) {
         m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(element->getOwnerDocument());
+        m_cipher->setExclusiveC14nSerialisation(false);
+    }
     
     try {
         checkParams(encParams,kencParams);
@@ -141,34 +192,28 @@ EncryptedData* Encrypter::encryptStream(istream& input, EncryptionParams& encPar
 
     if (m_cipher) {
         XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
-        m_cipher=NULL;
+        m_cipher=nullptr;
     }
     
-    DOMDocument* doc=NULL;
+    DOMDocument* doc=nullptr;
     try {
         doc=XMLToolingConfig::getConfig().getParser().newDocument();
+        XercesJanitor<DOMDocument> janitor(doc);
         m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(doc);
+        m_cipher->setExclusiveC14nSerialisation(false);
         
         checkParams(encParams,kencParams);
         StreamInputSource::StreamBinInputStream xstream(input);
         m_cipher->encryptBinInputStream(&xstream, ENCRYPT_NONE, encParams.m_algorithm);
-        EncryptedData* xmlEncData = decorateAndUnmarshall(encParams, kencParams);
-        doc->release();
-        return xmlEncData;
+        return decorateAndUnmarshall(encParams, kencParams);
     }
     catch(XSECException& e) {
-        doc->release();
         auto_ptr_char temp(e.getMsg());
         throw EncryptionException(string("XMLSecurity exception while encrypting: ") + temp.get());
     }
     catch(XSECCryptoException& e) {
-        doc->release();
         throw EncryptionException(string("XMLSecurity exception while encrypting: ") + e.getMsg());
     }
-    catch (...) {
-        doc->release();
-        throw;
-    }
 }
 
 EncryptedData* Encrypter::decorateAndUnmarshall(EncryptionParams& encParams, KeyEncryptionParams* kencParams)
@@ -178,7 +223,7 @@ EncryptedData* Encrypter::decorateAndUnmarshall(EncryptionParams& encParams, Key
         throw EncryptionException("No EncryptedData element found?");
 
     // Unmarshall a tooling version of EncryptedData around the DOM.
-    EncryptedData* xmlEncData=NULL;
+    EncryptedData* xmlEncData=nullptr;
     auto_ptr<XMLObject> xmlObject(XMLObjectBuilder::buildOneFromElement(encData->getElement()));
     if (!(xmlObject.get()) || !(xmlEncData=dynamic_cast<EncryptedData*>(xmlObject.get())))
         throw EncryptionException("Unable to unmarshall into EncryptedData object.");
@@ -187,38 +232,144 @@ EncryptedData* Encrypter::decorateAndUnmarshall(EncryptionParams& encParams, Key
     xmlEncData->releaseThisAndChildrenDOM();
     
     // KeyInfo?
-    if (encParams.m_keyInfo) {
-        xmlEncData->setKeyInfo(encParams.m_keyInfo);
-        encParams.m_keyInfo=NULL;   // transfer ownership
-    }
+    KeyInfo* kinfo = encParams.m_credential ? encParams.m_credential->getKeyInfo(encParams.m_compact) : nullptr;
+    if (kinfo)
+        xmlEncData->setKeyInfo(kinfo);
     
     // Are we doing a key encryption?
     if (kencParams) {
-        m_cipher->setKEK(kencParams->m_key->clone());
+        XSECCryptoKey* kek = kencParams->m_credential.getPublicKey();
+        if (!kek)
+            throw EncryptionException("Credential in KeyEncryptionParams structure did not supply a public key.");
+        if (!kencParams->m_algorithm)
+            kencParams->m_algorithm = getKeyTransportAlgorithm(kencParams->m_credential, encParams.m_algorithm);
+        if (!kencParams->m_algorithm)
+            throw EncryptionException("Unable to derive a supported key encryption algorithm.");
+
+        m_cipher->setKEK(kek->clone());
         // ownership of this belongs to us, for some reason...
         auto_ptr<XENCEncryptedKey> encKey(
             m_cipher->encryptKey(encParams.m_keyBuffer, encParams.m_keyBufferSize, ENCRYPT_NONE, kencParams->m_algorithm)
             );
-        EncryptedKey* xmlEncKey=NULL;
+        EncryptedKey* xmlEncKey=nullptr;
         auto_ptr<XMLObject> xmlObjectKey(XMLObjectBuilder::buildOneFromElement(encKey->getElement()));
         if (!(xmlObjectKey.get()) || !(xmlEncKey=dynamic_cast<EncryptedKey*>(xmlObjectKey.get())))
             throw EncryptionException("Unable to unmarshall into EncryptedKey object.");
         
         xmlEncKey->releaseThisAndChildrenDOM();
         
+        // Recipient?
+        if (kencParams->m_recipient)
+            xmlEncKey->setRecipient(kencParams->m_recipient);
+        
         // KeyInfo?
-        if (kencParams->m_keyInfo) {
-            xmlEncKey->setKeyInfo(kencParams->m_keyInfo);
-            kencParams->m_keyInfo=NULL;   // transfer ownership
-        }
+        kinfo = kencParams->m_credential.getKeyInfo(encParams.m_compact);
+        if (kinfo)
+            xmlEncKey->setKeyInfo(kinfo);
         
-        // Add the EncryptedKey.
+        // Add the EncryptedKey inline.
         if (!xmlEncData->getKeyInfo())
             xmlEncData->setKeyInfo(KeyInfoBuilder::buildKeyInfo());
-        xmlEncData->getKeyInfo()->getOthers().push_back(xmlEncKey);
+        xmlEncData->getKeyInfo()->getUnknownXMLObjects().push_back(xmlEncKey);
         xmlObjectKey.release();
     }
     
     xmlObject.release();
     return xmlEncData;
 }
+
+EncryptedKey* Encrypter::encryptKey(
+    const unsigned char* keyBuffer, unsigned int keyBufferSize, KeyEncryptionParams& kencParams, bool compact
+    )
+{
+    if (!kencParams.m_algorithm)
+        throw EncryptionException("KeyEncryptionParams structure did not include a key encryption algorithm.");
+
+    // Get a fresh cipher object and document.
+
+    if (m_cipher) {
+        XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->releaseCipher(m_cipher);
+        m_cipher=nullptr;
+    }
+
+    XSECCryptoKey* kek = kencParams.m_credential.getPublicKey();
+    if (!kek)
+        throw EncryptionException("Credential in KeyEncryptionParams structure did not supply a public key.");
+
+    DOMDocument* doc=nullptr;
+    try {
+        doc=XMLToolingConfig::getConfig().getParser().newDocument();
+        XercesJanitor<DOMDocument> janitor(doc);
+        m_cipher=XMLToolingInternalConfig::getInternalConfig().m_xsecProvider->newCipher(doc);
+        m_cipher->setExclusiveC14nSerialisation(false);
+        m_cipher->setKEK(kek->clone());
+        auto_ptr<XENCEncryptedKey> encKey(m_cipher->encryptKey(keyBuffer, keyBufferSize, ENCRYPT_NONE, kencParams.m_algorithm));
+        
+        EncryptedKey* xmlEncKey=nullptr;
+        auto_ptr<XMLObject> xmlObjectKey(XMLObjectBuilder::buildOneFromElement(encKey->getElement()));
+        if (!(xmlObjectKey.get()) || !(xmlEncKey=dynamic_cast<EncryptedKey*>(xmlObjectKey.get())))
+            throw EncryptionException("Unable to unmarshall into EncryptedKey object.");
+        
+        xmlEncKey->releaseThisAndChildrenDOM();
+        
+        // Recipient?
+        if (kencParams.m_recipient)
+            xmlEncKey->setRecipient(kencParams.m_recipient);
+
+        // KeyInfo?
+        KeyInfo* kinfo = kencParams.m_credential.getKeyInfo(compact);
+        if (kinfo)
+            xmlEncKey->setKeyInfo(kinfo);
+
+        xmlObjectKey.release();
+        return xmlEncKey;
+    }
+    catch(XSECException& e) {
+        auto_ptr_char temp(e.getMsg());
+        throw EncryptionException(string("XMLSecurity exception while encrypting: ") + temp.get());
+    }
+    catch(XSECCryptoException& e) {
+        throw EncryptionException(string("XMLSecurity exception while encrypting: ") + e.getMsg());
+    }
+}
+
+const XMLCh* Encrypter::getKeyTransportAlgorithm(const Credential& credential, const XMLCh* encryptionAlg)
+{
+    XMLToolingConfig& conf = XMLToolingConfig::getConfig();
+    const char* alg = credential.getAlgorithm();
+    if (!alg || !strcmp(alg, "RSA")) {
+        if (XMLString::equals(encryptionAlg,DSIGConstants::s_unicodeStrURI3DES_CBC)) {
+            if (conf.isXMLAlgorithmSupported(DSIGConstants::s_unicodeStrURIRSA_1_5, XMLToolingConfig::ALGTYPE_KEYENCRYPT))
+                return DSIGConstants::s_unicodeStrURIRSA_1_5;
+            else if (conf.isXMLAlgorithmSupported(DSIGConstants::s_unicodeStrURIRSA_OAEP_MGFP1, XMLToolingConfig::ALGTYPE_KEYENCRYPT))
+                return DSIGConstants::s_unicodeStrURIRSA_OAEP_MGFP1;
+        }
+        else {
+            if (conf.isXMLAlgorithmSupported(DSIGConstants::s_unicodeStrURIRSA_OAEP_MGFP1, XMLToolingConfig::ALGTYPE_KEYENCRYPT))
+                return DSIGConstants::s_unicodeStrURIRSA_OAEP_MGFP1;
+            else if (conf.isXMLAlgorithmSupported(DSIGConstants::s_unicodeStrURIRSA_1_5, XMLToolingConfig::ALGTYPE_KEYENCRYPT))
+                return DSIGConstants::s_unicodeStrURIRSA_1_5;
+        }
+    }
+    else if (!strcmp(alg, "AES")) {
+        const XMLCh* ret = nullptr;
+        switch (credential.getKeySize()) {
+            case 128:
+                ret = DSIGConstants::s_unicodeStrURIKW_AES128;
+            case 192:
+                ret = DSIGConstants::s_unicodeStrURIKW_AES192;
+            case 256:
+                ret = DSIGConstants::s_unicodeStrURIKW_AES256;
+            default:
+                return nullptr;
+        }
+        if (conf.isXMLAlgorithmSupported(ret, XMLToolingConfig::ALGTYPE_KEYENCRYPT))
+            return ret;
+    }
+    else if (!strcmp(alg, "DESede")) {
+        if (conf.isXMLAlgorithmSupported(DSIGConstants::s_unicodeStrURIKW_3DES, XMLToolingConfig::ALGTYPE_KEYENCRYPT))
+            return DSIGConstants::s_unicodeStrURIKW_3DES;
+    }
+
+    return nullptr;
+}