Merge branch '1.x' of ssh://authdev.it.ohio-state.edu/~scantor/git/cpp-xmltooling...
[shibboleth/cpp-xmltooling.git] / xmltooling / util / CurlURLInputStream.cpp
index f3b6dcd..4d1e4b3 100644 (file)
@@ -1,18 +1,21 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
+/**
+ * Licensed to the University Corporation for Advanced Internet
+ * Development, Inc. (UCAID) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for
+ * additional information regarding copyright ownership.
+ *
+ * UCAID licenses this file to you under the Apache License,
+ * Version 2.0 (the "License"); you may not use this file except
+ * in compliance with the License. You may obtain a copy of the
+ * License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+ * either express or implied. See the License for the specific
+ * language governing permissions and limitations under the License.
  */
 
 /**
@@ -24,6 +27,7 @@
 #include "internal.h"
 
 #include <xmltooling/util/CurlURLInputStream.h>
+#include <xmltooling/util/ParserPool.h>
 #include <xmltooling/util/XMLHelper.h>
 
 #include <openssl/ssl.h>
 
 using namespace xmltooling;
 using namespace xercesc;
+using namespace std;
 
 namespace {
-    static const XMLCh  _CURL[] =           UNICODE_LITERAL_4(C,U,R,L);
+    static const XMLCh _CURL[] =            UNICODE_LITERAL_4(C,U,R,L);  // provider value that routes a TransportOption to libcurl
+    static const XMLCh _OpenSSL[] =         UNICODE_LITERAL_7(O,p,e,n,S,S,L);  // provider value that routes a TransportOption to OpenSSL
     static const XMLCh _option[] =          UNICODE_LITERAL_6(o,p,t,i,o,n);
     static const XMLCh _provider[] =        UNICODE_LITERAL_8(p,r,o,v,i,d,e,r);
     static const XMLCh TransportOption[] =  UNICODE_LITERAL_15(T,r,a,n,s,p,o,r,t,O,p,t,i,o,n);
@@ -52,70 +58,158 @@ namespace {
     // callback to invoke a caller-defined SSL callback
     CURLcode ssl_ctx_callback(CURL* curl, SSL_CTX* ssl_ctx, void* userptr)
     {
-        // Manually disable SSLv2 so we're not dependent on libcurl to do it.
+        CurlURLInputStream* str = reinterpret_cast<CurlURLInputStream*>(userptr);  // userptr carries the stream whose option mask is applied below
+
+        // Default flags manually disable SSLv2 so we're not dependent on libcurl to do it.
         // Also disable the ticket option where implemented, since this breaks a variety
         // of servers. Newer libcurl also does this for us.
 #ifdef SSL_OP_NO_TICKET
-        SSL_CTX_set_options(ssl_ctx, SSL_OP_ALL|SSL_OP_NO_SSLv2|SSL_OP_NO_TICKET);
+        SSL_CTX_set_options(ssl_ctx, str->getOpenSSLOps()|SSL_OP_NO_TICKET);  // stream's mask plus ticket suppression
 #else
-        SSL_CTX_set_options(ssl_ctx, SSL_OP_ALL|SSL_OP_NO_SSLv2);
+        SSL_CTX_set_options(ssl_ctx, str->getOpenSSLOps());  // stream's mask only; SSL_OP_NO_TICKET is not defined here
 #endif
 
         return CURLE_OK;
     }
+
+    // True if the len bytes at hdr (not NUL-terminated) begin with the field name, colon
+    // included. Case is ignored because HTTP field names are case-insensitive.
+    bool header_name_matches(const char* hdr, size_t len, const char* name)
+    {
+        const size_t namelen = strlen(name);
+        if (len < namelen)
+            return false;
+        for (size_t i = 0; i < namelen; ++i) {
+            if (tolower(static_cast<unsigned char>(hdr[i])) != tolower(static_cast<unsigned char>(name[i])))
+                return false;
+        }
+        return true;
+    }
+
+    /**
+     * libcurl header callback: turns a response validator into the conditional request
+     * header for the next fetch and stores it in the caller's cache tag.
+     *
+     * ETag yields "If-None-Match: <tag>" and always replaces the stored tag. Last-Modified
+     * yields "If-Modified-Since: <date>" only if the stored tag is empty or date-based, so
+     * a date never displaces an entity tag. Values are trimmed of surrounding whitespace
+     * but otherwise kept whole, because HTTP dates contain spaces.
+     *
+     * @param ptr     header line supplied by libcurl, not NUL-terminated
+     * @param size    size of one element
+     * @param nmemb   number of elements
+     * @param stream  std::string cache tag to update; a null pointer skips processing
+     * @return size * nmemb, so the whole line is always reported as consumed
+     */
+    size_t curl_header_hook(void* ptr, size_t size, size_t nmemb, void* stream)
+    {
+        static const char dateRule[] = "If-Modified-Since: ";
+        const size_t len = size * nmemb;
+        if (!ptr || !stream)
+            return len;
+        string* cacheTag = static_cast<string*>(stream);
+        const char* hdr = static_cast<const char*>(ptr);
+
+        const char* rule = "If-None-Match: ";
+        size_t namelen = 5;
+        if (!header_name_matches(hdr, len, "ETag:")) {
+            // Not an entity tag, so only an applicable Last-Modified header is of interest.
+            if (!header_name_matches(hdr, len, "Last-Modified:") ||
+                    !(cacheTag->empty() || cacheTag->compare(0, sizeof(dateRule) - 1, dateRule) == 0))
+                return len;
+            rule = dateRule;
+            namelen = 14;
+        }
+
+        // Trim the value at both ends; isspace() must be given an unsigned char value.
+        const char* first = hdr + namelen;
+        const char* last = hdr + len;
+        while (first < last && isspace(static_cast<unsigned char>(*first)))
+            ++first;
+        while (last > first && isspace(static_cast<unsigned char>(*(last - 1))))
+            --last;
+        cacheTag->erase();
+        if (first < last)
+            cacheTag->assign(rule).append(first, last);
+        return len;
+    }
 }
 
-CurlURLInputStream::CurlURLInputStream(const char* url)
+CurlURLInputStream::CurlURLInputStream(const char* url, string* cacheTag)  // throws IOException when url is null or empty
     : fLog(logging::Category::getInstance(XMLTOOLING_LOGCAT".libcurl.InputStream"))
-    , fURL(url)
+    , fCacheTag(cacheTag)  // caller-supplied conditional-request header; may be null
+    , fURL(url ? url : "")  // a null url becomes an empty string, rejected in the body
+    , fOpenSSLOps(SSL_OP_ALL|SSL_OP_NO_SSLv2)  // default OpenSSL option mask
     , fMulti(0)
     , fEasy(0)
+    , fHeaders(0)  // request header list, filled in by init()
     , fTotalBytesRead(0)
     , fWritePtr(0)
     , fBytesRead(0)
     , fBytesToRead(0)
     , fDataAvailable(false)
-    , fBufferHeadPtr(fBuffer)
-    , fBufferTailPtr(fBuffer)
+    , fBuffer(0)  // holding buffer starts out unallocated
+    , fBufferHeadPtr(0)
+    , fBufferTailPtr(0)
+    , fBufferSize(0)
     , fContentType(0)
+    , fStatusCode(200)  // assumed until init() reads the actual response code
 {
+    if (fURL.empty())  // covers both a null and an empty url argument
+        throw IOException("No URL supplied to CurlURLInputStream constructor.");
     init();
 }
 
-CurlURLInputStream::CurlURLInputStream(const XMLCh* url)
+CurlURLInputStream::CurlURLInputStream(const XMLCh* url, string* cacheTag)  // throws IOException when url is null or empty
     : fLog(logging::Category::getInstance(XMLTOOLING_LOGCAT".libcurl.InputStream"))
+    , fCacheTag(cacheTag)  // caller-supplied conditional-request header; may be null
+    , fOpenSSLOps(SSL_OP_ALL|SSL_OP_NO_SSLv2)  // default OpenSSL option mask
     , fMulti(0)
     , fEasy(0)
+    , fHeaders(0)  // request header list, filled in by init()
     , fTotalBytesRead(0)
     , fWritePtr(0)
     , fBytesRead(0)
     , fBytesToRead(0)
     , fDataAvailable(false)
-    , fBufferHeadPtr(fBuffer)
-    , fBufferTailPtr(fBuffer)
+    , fBuffer(0)  // holding buffer starts out unallocated
+    , fBufferHeadPtr(0)
+    , fBufferTailPtr(0)
+    , fBufferSize(0)
     , fContentType(0)
+    , fStatusCode(200)  // assumed until init() reads the actual response code
 {
-    auto_ptr_char temp(url);
-    fURL = temp.get();
+    if (url) {  // convert the XMLCh string only when one was supplied
+        auto_ptr_char temp(url);
+        fURL = temp.get();
+    }
+    if (fURL.empty())  // nothing usable was supplied
+        throw IOException("No URL supplied to CurlURLInputStream constructor.");
     init();
 }
 
-CurlURLInputStream::CurlURLInputStream(const DOMElement* e)
+CurlURLInputStream::CurlURLInputStream(const DOMElement* e, string* cacheTag)
     : fLog(logging::Category::getInstance(XMLTOOLING_LOGCAT".libcurl.InputStream"))
+    , fCacheTag(cacheTag)
+    , fOpenSSLOps(SSL_OP_ALL|SSL_OP_NO_SSLv2)
     , fMulti(0)
     , fEasy(0)
+    , fHeaders(0)
     , fTotalBytesRead(0)
     , fWritePtr(0)
     , fBytesRead(0)
     , fBytesToRead(0)
     , fDataAvailable(false)
-    , fBufferHeadPtr(fBuffer)
-    , fBufferTailPtr(fBuffer)
+    , fBuffer(0)
+    , fBufferHeadPtr(0)
+    , fBufferTailPtr(0)
+    , fBufferSize(0)
     , fContentType(0)
+    , fStatusCode(200)
 {
-    const XMLCh* attr = e->getAttributeNS(NULL, url);
+    const XMLCh* attr = e->getAttributeNS(nullptr, url);
     if (!attr || !*attr) {
-        attr = e->getAttributeNS(NULL, uri);
+        attr = e->getAttributeNS(nullptr, uri);
         if (!attr || !*attr)
             throw IOException("No URL supplied via DOM to CurlURLInputStream constructor.");
     }
@@ -140,7 +234,12 @@ CurlURLInputStream::~CurlURLInputStream()
         curl_multi_cleanup(fMulti);
     }
 
+    if (fHeaders) {
+        curl_slist_free_all(fHeaders);
+    }
+
     XMLString::release(&fContentType);
+    free(fBuffer);
 }
 
 void CurlURLInputStream::init(const DOMElement* e)
@@ -165,25 +264,48 @@ void CurlURLInputStream::init(const DOMElement* e)
     curl_easy_setopt(fEasy, CURLOPT_MAXREDIRS, 6);
 
     // Default settings.
-    curl_easy_setopt(fEasy, CURLOPT_CONNECTTIMEOUT,30);
-    curl_easy_setopt(fEasy, CURLOPT_TIMEOUT,60);
-    curl_easy_setopt(fEasy, CURLOPT_HTTPAUTH,0);
-    curl_easy_setopt(fEasy, CURLOPT_USERPWD,NULL);
+    curl_easy_setopt(fEasy, CURLOPT_CONNECTTIMEOUT, 10);
+    curl_easy_setopt(fEasy, CURLOPT_TIMEOUT, 60);
+    curl_easy_setopt(fEasy, CURLOPT_HTTPAUTH, 0);
+    curl_easy_setopt(fEasy, CURLOPT_USERPWD, nullptr);
     curl_easy_setopt(fEasy, CURLOPT_SSL_VERIFYHOST, 2);
     curl_easy_setopt(fEasy, CURLOPT_SSL_VERIFYPEER, 0);
+    curl_easy_setopt(fEasy, CURLOPT_CAINFO, nullptr);
     curl_easy_setopt(fEasy, CURLOPT_SSL_CIPHER_LIST, "ALL:!aNULL:!LOW:!EXPORT:!SSLv2");
     curl_easy_setopt(fEasy, CURLOPT_NOPROGRESS, 1);
     curl_easy_setopt(fEasy, CURLOPT_NOSIGNAL, 1);
     curl_easy_setopt(fEasy, CURLOPT_FAILONERROR, 1);
+    curl_easy_setopt(fEasy, CURLOPT_ENCODING, "");
 
     // Install SSL callback.
     curl_easy_setopt(fEasy, CURLOPT_SSL_CTX_FUNCTION, ssl_ctx_callback);
+    curl_easy_setopt(fEasy, CURLOPT_SSL_CTX_DATA, this);
 
     fError[0] = 0;
     curl_easy_setopt(fEasy, CURLOPT_ERRORBUFFER, fError);
 
+    // Check for cache tag.
+    if (fCacheTag) {
+        // Outgoing tag.
+        if (!fCacheTag->empty()) {
+            fHeaders = curl_slist_append(fHeaders, fCacheTag->c_str());
+        }
+        // Incoming tag.
+        curl_easy_setopt(fEasy, CURLOPT_HEADERFUNCTION, curl_header_hook);
+        curl_easy_setopt(fEasy, CURLOPT_HEADERDATA, fCacheTag);
+    }
+
+    // Add User-Agent as a header for now. TODO: Add private member to hold the
+    // value for the standard UA option.
+    string ua = string("User-Agent: ") + XMLToolingConfig::getConfig().user_agent +
+        " libcurl/" + LIBCURL_VERSION + ' ' + OPENSSL_VERSION_TEXT;
+    fHeaders = curl_slist_append(fHeaders, ua.c_str());
+
+    // Add User-Agent and cache headers.
+    curl_easy_setopt(fEasy, CURLOPT_HTTPHEADER, fHeaders);
+
     if (e) {
-        const XMLCh* flag = e->getAttributeNS(NULL, verifyHost);
+        const XMLCh* flag = e->getAttributeNS(nullptr, verifyHost);
         if (flag && (*flag == chLatin_f || *flag == chDigit_0))
             curl_easy_setopt(fEasy, CURLOPT_SSL_VERIFYHOST, 0);
 
@@ -191,32 +313,61 @@ void CurlURLInputStream::init(const DOMElement* e)
         bool success;
         DOMElement* child = XMLHelper::getLastChildElement(e, TransportOption);
         while (child) {
-            if (child->hasChildNodes() && XMLString::equals(child->getAttributeNS(NULL,_provider), _CURL)) {
-                auto_ptr_char option(child->getAttributeNS(NULL,_option));
+            if (child->hasChildNodes() && XMLString::equals(child->getAttributeNS(nullptr,_provider), _OpenSSL)) {
+                auto_ptr_char option(child->getAttributeNS(nullptr,_option));
+                auto_ptr_char value(child->getFirstChild()->getNodeValue());
+                if (option.get() && value.get() && !strcmp(option.get(), "SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION") &&
+                    (*value.get()=='1' || *value.get()=='t')) {
+                    // If the new option to enable buggy renegotiation is available, set it.
+                    // Otherwise, signal false if this is newer than 0.9.8k, because that
+                    // means it's 0.9.8l, which blocks renegotiation, and therefore will
+                    // not honor this request. Older versions are buggy, so behave as though
+                    // the flag was set anyway, so we signal true.
+#if defined(SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION)
+                    fOpenSSLOps |= SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION;
+                    success = true;
+#elif (OPENSSL_VERSION_NUMBER > 0x009080bfL)
+                    success = false;
+#else
+                    success = true;
+#endif
+                }
+                else {
+                    success = false;
+                }
+                if (!success)
+                    fLog.error("failed to set OpenSSL transport option (%s)", option.get());
+            }
+            else if (child->hasChildNodes() && XMLString::equals(child->getAttributeNS(nullptr,_provider), _CURL)) {
+                auto_ptr_char option(child->getAttributeNS(nullptr,_option));
                 auto_ptr_char value(child->getFirstChild()->getNodeValue());
                 if (option.get() && *option.get() && value.get() && *value.get()) {
                     // For libcurl, the option is an enum and the value type depends on the option.
-                    CURLoption opt = static_cast<CURLoption>(strtol(option.get(), NULL, 10));
+                    CURLoption opt = static_cast<CURLoption>(strtol(option.get(), nullptr, 10));
                     if (opt < CURLOPTTYPE_OBJECTPOINT)
-                        success = (curl_easy_setopt(fEasy, opt, strtol(value.get(), NULL, 10)) == CURLE_OK);
+                        success = (curl_easy_setopt(fEasy, opt, strtol(value.get(), nullptr, 10)) == CURLE_OK);
 #ifdef CURLOPTTYPE_OFF_T
-                    else if (opt < CURLOPTTYPE_OFF_T)
-                        success = (curl_easy_setopt(fEasy, opt, value.get()) == CURLE_OK);
+                    else if (opt < CURLOPTTYPE_OFF_T) {
+                        fSavedOptions.push_back(value.get());
+                        success = (curl_easy_setopt(fEasy, opt, fSavedOptions.back().c_str()) == CURLE_OK);
+                    }
 # ifdef HAVE_CURL_OFF_T
                     else if (sizeof(curl_off_t) == sizeof(long))
-                        success = (curl_easy_setopt(fEasy, opt, strtol(value.get(), NULL, 10)) == CURLE_OK);
+                        success = (curl_easy_setopt(fEasy, opt, strtol(value.get(), nullptr, 10)) == CURLE_OK);
 # else
                     else if (sizeof(off_t) == sizeof(long))
-                        success = (curl_easy_setopt(fEasy, opt, strtol(value.get(), NULL, 10)) == CURLE_OK);
+                        success = (curl_easy_setopt(fEasy, opt, strtol(value.get(), nullptr, 10)) == CURLE_OK);
 # endif
                     else
                         success = false;
 #else
-                    else
-                        success = (curl_easy_setopt(fEasy, opt, value.get()) == CURLE_OK);
+                    else {
+                        fSavedOptions.push_back(value.get());
+                        success = (curl_easy_setopt(fEasy, opt, fSavedOptions.back().c_str()) == CURLE_OK);
+                    }
 #endif
                     if (!success)
-                        fLog.error("failed to set transport option (%s)", option.get());
+                        fLog.error("failed to set CURL transport option (%s)", option.get());
                 }
             }
             child = XMLHelper::getPreviousSiblingElement(child, TransportOption);
@@ -234,22 +385,46 @@ void CurlURLInputStream::init(const DOMElement* e)
         try {
             readMore(&runningHandles);
         }
-        catch (XMLException& ex) {
+        catch (XMLException&) {
             curl_multi_remove_handle(fMulti, fEasy);
             curl_easy_cleanup(fEasy);
-            fEasy = NULL;
+            fEasy = nullptr;
             curl_multi_cleanup(fMulti);
-            fMulti = NULL;
-            auto_ptr_char msg(ex.getMessage());
-            throw IOException(msg.get());
+            fMulti = nullptr;
+            throw;
         }
         if(runningHandles == 0) break;
     }
 
+    // Check for a response code.
+    if (curl_easy_getinfo(fEasy, CURLINFO_RESPONSE_CODE, &fStatusCode) == CURLE_OK) {
+        if (fStatusCode >= 300 ) {
+            // Short-circuit usual processing by storing a special XML document in the buffer.
+            ostringstream specialdoc;
+            specialdoc << '<' << URLInputSource::asciiStatusCodeElementName << " xmlns=\"http://www.opensaml.org/xmltooling\">"
+                << fStatusCode
+                << "</" << URLInputSource::asciiStatusCodeElementName << '>';
+            string specialxml = specialdoc.str();
+            fBufferTailPtr = fBuffer = reinterpret_cast<XMLByte*>(malloc(specialxml.length()));
+            if (!fBuffer) {
+                curl_multi_remove_handle(fMulti, fEasy);
+                curl_easy_cleanup(fEasy);
+                fEasy = nullptr;
+                curl_multi_cleanup(fMulti);
+                fMulti = nullptr;
+                throw bad_alloc();
+            }
+            memcpy(fBuffer, specialxml.c_str(), specialxml.length());
+            fBufferHeadPtr = fBuffer + specialxml.length();
+        }
+    }
+    else {
+        fStatusCode = 200;  // reset to 200 to ensure no special processing occurs
+    }
+
     // Find the content type
-    char* contentType8 = NULL;
-    curl_easy_getinfo(fEasy, CURLINFO_CONTENT_TYPE, &contentType8);
-    if(contentType8)
+    char* contentType8 = nullptr;
+    if(curl_easy_getinfo(fEasy, CURLINFO_CONTENT_TYPE, &contentType8) == CURLE_OK && contentType8)
         fContentType = XMLString::transcode(contentType8);
 }
 
@@ -272,7 +447,7 @@ size_t CurlURLInputStream::writeCallback(char* buffer, size_t size, size_t nitem
     fTotalBytesRead += consume;
     fBytesToRead    -= consume;
 
-    //fLog.debug("write callback consuming %d bytes", consume);
+    fLog.debug("write callback consuming %u bytes", consume);
 
     // If bytes remain, rebuffer as many as possible into our holding buffer
     buffer          += consume;
@@ -280,13 +455,22 @@ size_t CurlURLInputStream::writeCallback(char* buffer, size_t size, size_t nitem
     cnt             -= consume;
     if (cnt > 0)
     {
-        size_t bufAvail = sizeof(fBuffer) - (fBufferHeadPtr - fBuffer);
-        consume = (cnt > bufAvail) ? bufAvail : cnt;
-        memcpy(fBufferHeadPtr, buffer, consume);
-        fBufferHeadPtr  += consume;
-        buffer          += consume;
-        totalConsumed   += consume;
-        //fLog.debug("write callback rebuffering %d bytes", consume);
+        size_t bufAvail = fBufferSize - (fBufferHeadPtr - fBuffer);
+        if (bufAvail < cnt) {
+            // Enlarge the buffer. TODO: limit max size
+            XMLByte* newbuf = reinterpret_cast<XMLByte*>(realloc(fBuffer, fBufferSize + (cnt - bufAvail)));
+            if (newbuf) {
+                fBufferSize = fBufferSize + (cnt - bufAvail);
+                fLog.debug("enlarged buffer to %u bytes", fBufferSize);
+                fBufferHeadPtr = newbuf + (fBufferHeadPtr - fBuffer);
+                fBuffer = fBufferTailPtr = newbuf;
+            }
+        }
+        memcpy(fBufferHeadPtr, buffer, cnt);
+        fBufferHeadPtr  += cnt;
+        buffer          += cnt;
+        totalConsumed   += cnt;
+        fLog.debug("write callback rebuffering %u bytes", cnt);
     }
 
     // Return the total amount we've consumed. If we don't consume all the bytes
@@ -303,9 +487,9 @@ bool CurlURLInputStream::readMore(int* runningHandles)
 
     // Process messages from curl
     int msgsInQueue = 0;
-    for (CURLMsg* msg = NULL; (msg = curl_multi_info_read(fMulti, &msgsInQueue)) != NULL; )
+    for (CURLMsg* msg = nullptr; (msg = curl_multi_info_read(fMulti, &msgsInQueue)) != nullptr; )
     {
-        //fLog.debug("msg %d, %d from curl", msg->msg, msg->data.result);
+        fLog.debug("msg %d, %d from curl", msg->msg, msg->data.result);
 
         if (msg->msg != CURLMSG_DONE)
             return true;
@@ -329,6 +513,10 @@ bool CurlURLInputStream::readMore(int* runningHandles)
             ThrowXML1(NetAccessorException, XMLExcepts::NetAcc_ConnSocket, fURL.c_str());
             break;
 
+        case CURLE_OPERATION_TIMEDOUT:
+            ThrowXML1(NetAccessorException, XMLExcepts::NetAcc_ConnSocket, fURL.c_str());
+            break;
+
         case CURLE_RECV_ERROR:
             ThrowXML1(NetAccessorException, XMLExcepts::NetAcc_ReadSocket, fURL.c_str());
             break;
@@ -393,12 +581,16 @@ xsecsize_t CurlURLInputStream::readBytes(XMLByte* const toFill, const xsecsize_t
             if (fBufferTailPtr == fBufferHeadPtr)
                 fBufferHeadPtr = fBufferTailPtr = fBuffer;
 
-            //fLog.debug("consuming %d buffered bytes", bufCnt);
+            fLog.debug("consuming %d buffered bytes", bufCnt);
 
             tryAgain = true;
             continue;
         }
 
+        // A status of 300 or above means the curl response is ignored, so stop reading.
+        if (fStatusCode >= 300)
+            break;
+
         // Ask the curl to do some work
         int runningHandles = 0;
         tryAgain = readMore(&runningHandles);