Fix line feeds again, VS is also broken.
[shibboleth/cpp-sp.git] / odbc-store / odbc-store.cpp
index 6460068..1ad49a2 100644 (file)
@@ -54,7 +54,6 @@
 #include <sqlext.h>
 
 #include <boost/lexical_cast.hpp>
-#include <boost/scoped_ptr.hpp>
 #include <boost/algorithm/string.hpp>
 
 using namespace xmltooling::logging;
@@ -64,7 +63,7 @@ using namespace boost;
 using namespace std;
 
 #define PLUGIN_VER_MAJOR 1
-#define PLUGIN_VER_MINOR 0
+#define PLUGIN_VER_MINOR 1
 
 #define LONGDATA_BUFLEN 16384
 
@@ -85,7 +84,7 @@ CREATE TABLE strings (
     context varchar(255) not null,
     id varchar(255) not null,
     expires datetime not null,
-    version smallint not null,
+    version int not null,
     value varchar(255) not null,
     PRIMARY KEY (context, id)
     )
@@ -94,7 +93,7 @@ CREATE TABLE texts (
     context varchar(255) not null,
     id varchar(255) not null,
     expires datetime not null,
-    version smallint not null,
+    version int not null,
     value text not null,
     PRIMARY KEY (context, id)
     )
@@ -113,13 +112,15 @@ namespace {
     struct ODBCConn {
         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
         ~ODBCConn() {
-            SQLRETURN sr = SQL_SUCCESS;
-            if (!autoCommit)
-                sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0);
-            SQLDisconnect(handle);
-            SQLFreeHandle(SQL_HANDLE_DBC,handle);
-            if (!SQL_SUCCEEDED(sr))
-                throw IOException("Failed to commit connection and return to auto-commit mode.");
+            if (handle != SQL_NULL_HDBC) {
+                SQLRETURN sr = SQL_SUCCESS;
+                if (!autoCommit)
+                    sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0);
+                SQLDisconnect(handle);
+                SQLFreeHandle(SQL_HANDLE_DBC, handle);
+                if (!SQL_SUCCEEDED(sr))
+                    throw IOException("Failed to commit connection and return to auto-commit mode.");
+            }
         }
         operator SQLHDBC() {return handle;}
         SQLHDBC handle;
@@ -140,7 +141,7 @@ namespace {
             return createRow(STRING_TABLE, context, key, value, expiration);
         }
         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
-            return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
+            return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version);
         }
         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
             return updateRow(STRING_TABLE, context, key, value, expiration, version);
@@ -153,7 +154,7 @@ namespace {
             return createRow(TEXT_TABLE, context, key, value, expiration);
         }
         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
-            return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
+            return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version);
         }
         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
@@ -180,7 +181,7 @@ namespace {
 
     private:
         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
-        int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
+        int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version);
         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
         bool deleteRow(const char *table, const char* context, const char* key);
 
@@ -206,6 +207,7 @@ namespace {
         SQLHENV m_henv;
         string m_connstring;
         long m_isolation;
+        bool m_wideVersion;
         vector<SQLINTEGER> m_retries;
     };
 
@@ -270,7 +272,7 @@ namespace {
 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
     m_caps(XMLHelper::getAttrInt(e, 255, contextSize), XMLHelper::getAttrInt(e, 255, keySize), XMLHelper::getAttrInt(e, 255, stringSize)),
     m_cleanupInterval(XMLHelper::getAttrInt(e, 900, cleanupInterval)),
-    cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
+    cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HENV), m_isolation(SQL_TXN_SERIALIZABLE), m_wideVersion(false)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("ODBCStorageService");
@@ -287,7 +289,7 @@ ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::ge
     else
         throw XMLToolingException("Unknown transaction isolationLevel property.");
 
-    if (m_henv == SQL_NULL_HANDLE) {
+    if (m_henv == SQL_NULL_HENV) {
         // Enable connection pooling.
         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
 
@@ -320,29 +322,50 @@ ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::ge
         m_log.crit("unknown database version: %d.%d", v.first, v.second);
         throw XMLToolingException("Unknown database version for ODBC StorageService.");
     }
+    
+    if (v.first > 1 || v.second > 0) {
+        m_log.info("using 32-bit int type for version fields in tables");
+        m_wideVersion = true;
+    }
 
     // Load any retry errors to check.
     e = XMLHelper::getNextSiblingElement(e, RetryOnError);
     while (e) {
         if (e->hasChildNodes()) {
-            m_retries.push_back(XMLString::parseInt(e->getTextContent()));
-            m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
+            try {
+                int code = XMLString::parseInt(e->getTextContent());
+                m_retries.push_back(code);
+                m_log.info("will retry operations when native ODBC error (%d) is returned", code);
+            }
+            catch (XMLException&) {
+                m_log.error("skipping non-numeric ODBC retry code");
+            }
         }
         e = XMLHelper::getNextSiblingElement(e, RetryOnError);
     }
 
-    // Initialize the cleanup thread
-    shutdown_wait.reset(CondWait::create());
-    cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
+    if (m_cleanupInterval > 0) {
+        // Initialize the cleanup thread
+        shutdown_wait.reset(CondWait::create());
+        cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
+    }
+    else {
+        m_log.info("no cleanup interval configured, no cleanup thread will be started");
+    }
 }
 
 ODBCStorageService::~ODBCStorageService()
 {
     shutdown = true;
-    shutdown_wait->signal();
-    cleanup_thread->join(nullptr);
-    if (m_henv != SQL_NULL_HANDLE)
+    if (shutdown_wait.get()) {
+        shutdown_wait->signal();
+    }
+    if (cleanup_thread) {
+        cleanup_thread->join(nullptr);
+    }
+    if (m_henv != SQL_NULL_HANDLE) {
         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
+    }
 }
 
 pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
@@ -375,33 +398,37 @@ SQLHDBC ODBCStorageService::getHDBC()
 #endif
 
     // Get a handle.
-    SQLHDBC handle;
-    SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
-    if (!SQL_SUCCEEDED(sr)) {
+    SQLHDBC handle = SQL_NULL_HDBC;
+    SQLRETURN sr = SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
+    if (!SQL_SUCCEEDED(sr) || handle == SQL_NULL_HDBC) {
         m_log.error("failed to allocate connection handle");
         log_error(m_henv, SQL_HANDLE_ENV);
         throw IOException("ODBC StorageService failed to allocate a connection handle.");
     }
 
-    sr=SQLDriverConnect(handle,nullptr,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),nullptr,0,nullptr,SQL_DRIVER_NOPROMPT);
+    sr = SQLDriverConnect(handle,nullptr,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),nullptr,0,nullptr,SQL_DRIVER_NOPROMPT);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("failed to connect to database");
         log_error(handle, SQL_HANDLE_DBC);
+        SQLFreeHandle(SQL_HANDLE_DBC, handle);
         throw IOException("ODBC StorageService failed to connect to database.");
     }
 
     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, 0);
-    if (!SQL_SUCCEEDED(sr))
+    if (!SQL_SUCCEEDED(sr)) {
+        SQLDisconnect(handle);
+        SQLFreeHandle(SQL_HANDLE_DBC, handle);
         throw IOException("ODBC StorageService failed to set transaction isolation level.");
+    }
 
     return handle;
 }
 
 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
 {
-    SQLHSTMT hstmt;
+    SQLHSTMT hstmt = SQL_NULL_HSTMT;
     SQLRETURN sr = SQLAllocHandle(SQL_HANDLE_STMT, conn, &hstmt);
-    if (!SQL_SUCCEEDED(sr)) {
+    if (!SQL_SUCCEEDED(sr) || hstmt == SQL_NULL_HSTMT) {
         m_log.error("failed to allocate statement handle");
         log_error(conn, SQL_HANDLE_DBC);
         throw IOException("ODBC StorageService failed to allocate a statement handle.");
@@ -496,16 +523,22 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
         }
         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
         logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
-        if (logres.second)
-            return false;   // supposedly integrity violation?
+        if (logres.second) {
+            // Supposedly integrity violation.
+            // Try and delete any expired record still hanging around until the final attempt.
+            if (attempts > 0) {
+                reap(table, context);
+                logres.first = true;    // force it to treat as a retryable error
+                continue;
+            }
+            return false;
+        }
     } while (attempts && logres.first);
 
     throw IOException("ODBC StorageService failed to insert record.");
 }
 
-int ODBCStorageService::readRow(
-    const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
-    )
+int ODBCStorageService::readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("readRow");
@@ -523,8 +556,10 @@ int ODBCStorageService::readRow(
     string q("SELECT version");
     if (pexpiration)
         q += ",expires";
-    if (pvalue)
+    if (pvalue) {
+        pvalue->erase();
         q = q + ",CASE version WHEN " + lexical_cast<string>(version) + " THEN null ELSE value END";
+    }
     q = q + " FROM " + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
     if (m_log.isDebugEnabled())
         m_log.debug("SQL: %s", q.c_str());
@@ -537,20 +572,30 @@ int ODBCStorageService::readRow(
     }
 
     SQLSMALLINT ver;
+    SQLINTEGER widever;
     SQL_TIMESTAMP_STRUCT expiration;
 
-    SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
+    if (m_wideVersion)
+        SQLBindCol(stmt, 1, SQL_C_SLONG, &widever, 0, nullptr);
+    else
+        SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
     if (pexpiration)
         SQLBindCol(stmt, 2, SQL_C_TYPE_TIMESTAMP, &expiration, 0, nullptr);
 
-    if ((sr = SQLFetch(stmt)) == SQL_NO_DATA)
+    if ((sr = SQLFetch(stmt)) == SQL_NO_DATA) {
+        if (m_log.isDebugEnabled())
+            m_log.debug("search returned no data (t=%s, c=%s, k=%s)", table, context, key);
         return 0;
+    }
 
     if (pexpiration)
         *pexpiration = timeFromTimestamp(expiration);
 
-    if (version == ver)
+    if (version == (m_wideVersion ? widever : ver)) {
+        if (m_log.isDebugEnabled())
+            m_log.debug("versioned search detected no change (t=%s, c=%s, k=%s)", table, context, key);
         return version; // nothing's changed, so just echo back the version
+    }
 
     if (pvalue) {
         SQLLEN len;
@@ -565,7 +610,7 @@ int ODBCStorageService::readRow(
         }
     }
     
-    return ver;
+    return (m_wideVersion ? widever : ver);
 }
 
 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
@@ -603,15 +648,23 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     }
 
     SQLSMALLINT ver;
-    SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
+    SQLINTEGER widever;
+    if (m_wideVersion)
+        SQLBindCol(stmt, 1, SQL_C_SLONG, &widever, 0, nullptr);
+    else
+        SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
     if ((sr = SQLFetch(stmt)) == SQL_NO_DATA) {
         return 0;
     }
 
     // Check version?
-    if (version > 0 && version != ver) {
+    if (version > 0 && version != (m_wideVersion ? widever : ver)) {
         return -1;
     }
+    else if ((m_wideVersion && widever == INT_MAX) || (!m_wideVersion && ver == 32767)) {
+        m_log.error("record version overflow (t=%s, c=%s, k=%s)", table, context, key);
+        throw IOException("Version overflow, record in ODBC StorageService could not be updated.");
+    }
 
     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
     stmt = getHSTMT(conn);
@@ -646,11 +699,11 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         else
             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
         if (!SQL_SUCCEEDED(sr)) {
-            m_log.error("SQLBindParam failed (context = %s)", context);
+            m_log.error("SQLBindParam failed (value = %s)", value);
             log_error(stmt, SQL_HANDLE_STMT);
             throw IOException("ODBC StorageService failed to update record.");
         }
-        m_log.debug("SQLBindParam succeeded (context = %s)", context);
+        m_log.debug("SQLBindParam succeeded (value = %s)", value);
     }
 
     int attempts = 3;
@@ -663,10 +716,10 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
             return 0;   // went missing?
         else if (SQL_SUCCEEDED(sr)) {
             m_log.debug("SQLExecute of update succeeded");
-            return ver + 1;
+            return (m_wideVersion ? widever : ver) + 1;
         }
 
-        m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
+        m_log.error("update of record failed (t=%s, c=%s, k=%s)", table, context, key);
         logres = log_error(stmt, SQL_HANDLE_STMT);
     } while (attempts && logres.first);