Log message spelling.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
index 2693a9d..f64f091 100644 (file)
@@ -37,8 +37,8 @@
 # define ODBCSTORE_EXPORTS
 #endif
 
-#include <log4cpp/Category.hh>
 #include <xercesc/util/XMLUniDefs.hpp>
+#include <xmltooling/logging.h>
 #include <xmltooling/XMLToolingConfig.h>
 #include <xmltooling/util/NDC.h>
 #include <xmltooling/util/StorageService.h>
@@ -48,9 +48,9 @@
 #include <sql.h>
 #include <sqlext.h>
 
+using namespace xmltooling::logging;
 using namespace xmltooling;
 using namespace xercesc;
-using namespace log4cpp;
 using namespace std;
 
 #define PLUGIN_VER_MAJOR 1
@@ -65,7 +65,7 @@ using namespace std;
 #define STRING_TABLE "strings"
 #define TEXT_TABLE "texts"
 
-/* tables definitions
+/* table definitions
 CREATE TABLE version (
     major tinyint NOT NULL,
     minor tinyint NOT NULL
@@ -92,26 +92,25 @@ CREATE TABLE texts (
 
 namespace {
     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
+    static const XMLCh isolationLevel[] =   UNICODE_LITERAL_14(i,s,o,l,a,t,i,o,n,L,e,v,e,l);
     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
+    static const XMLCh RetryOnError[] =     UNICODE_LITERAL_12(R,e,t,r,y,O,n,E,r,r,o,r);
 
     // RAII for ODBC handles
     struct ODBCConn {
-        ODBCConn(SQLHDBC conn) : handle(conn) {}
+        ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
         ~ODBCConn() {
-            SQLRETURN sr = SQLEndTran(SQL_HANDLE_DBC, handle, SQL_COMMIT);
+            SQLRETURN sr = SQL_SUCCESS;
+            if (!autoCommit)
+                sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, NULL);
+            SQLDisconnect(handle);
             SQLFreeHandle(SQL_HANDLE_DBC,handle);
             if (!SQL_SUCCEEDED(sr))
-                throw IOException("Failed to commit connection.");
+                throw IOException("Failed to commit connection and return to auto-commit mode.");
         }
         operator SQLHDBC() {return handle;}
         SQLHDBC handle;
-    };
-
-    struct ODBCStatement {
-        ODBCStatement(SQLHSTMT statement) : handle(statement) {}
-        ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);}
-        operator SQLHSTMT() {return handle;}
-        SQLHSTMT handle;
+        bool autoCommit;
     };
 
     class ODBCStorageService : public StorageService
@@ -120,7 +119,7 @@ namespace {
         ODBCStorageService(const DOMElement* e);
         virtual ~ODBCStorageService();
 
-        void createString(const char* context, const char* key, const char* value, time_t expiration) {
+        bool createString(const char* context, const char* key, const char* value, time_t expiration) {
             return createRow(STRING_TABLE, context, key, value, expiration);
         }
         int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
@@ -133,7 +132,7 @@ namespace {
             return deleteRow(STRING_TABLE, context, key);
         }
 
-        void createText(const char* context, const char* key, const char* value, time_t expiration) {
+        bool createText(const char* context, const char* key, const char* value, time_t expiration) {
             return createRow(TEXT_TABLE, context, key, value, expiration);
         }
         int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
@@ -163,7 +162,7 @@ namespace {
          
 
     private:
-        void createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
+        bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
         bool deleteRow(const char *table, const char* context, const char* key);
@@ -175,7 +174,7 @@ namespace {
         SQLHDBC getHDBC();
         SQLHSTMT getHSTMT(SQLHDBC);
         pair<int,int> getVersion(SQLHDBC);
-        void log_error(SQLHANDLE handle, SQLSMALLINT htype);
+        pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL);
 
         static void* cleanup_fn(void*); 
         void cleanup();
@@ -188,6 +187,8 @@ namespace {
 
         SQLHENV m_henv;
         string m_connstring;
+        long m_isolation;
+        vector<SQLINTEGER> m_retries;
     };
 
     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
@@ -256,7 +257,7 @@ namespace {
 };
 
 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
-   m_cleanupInterval(900), shutdown_wait(NULL), cleanup_thread(NULL), shutdown(false), m_henv(SQL_NULL_HANDLE)
+   m_cleanupInterval(900), shutdown_wait(NULL), cleanup_thread(NULL), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("ODBCStorageService");
@@ -268,6 +269,20 @@ ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::ge
     if (!m_cleanupInterval)
         m_cleanupInterval = 900;
 
+    auto_ptr_char iso(e ? e->getAttributeNS(NULL,isolationLevel) : NULL);
+    if (iso.get() && *iso.get()) {
+        if (!strcmp(iso.get(),"SERIALIZABLE"))
+            m_isolation = SQL_TXN_SERIALIZABLE;
+        else if (!strcmp(iso.get(),"REPEATABLE_READ"))
+            m_isolation = SQL_TXN_REPEATABLE_READ;
+        else if (!strcmp(iso.get(),"READ_COMMITTED"))
+            m_isolation = SQL_TXN_READ_COMMITTED;
+        else if (!strcmp(iso.get(),"READ_UNCOMMITTED"))
+            m_isolation = SQL_TXN_READ_UNCOMMITTED;
+        else
+            throw XMLToolingException("Unknown transaction isolationLevel property.");
+    }
+
     if (m_henv == SQL_NULL_HANDLE) {
         // Enable connection pooling.
         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
@@ -302,6 +317,16 @@ ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::ge
         throw XMLToolingException("Unknown database version for ODBC StorageService.");
     }
 
+    // Load any retry errors to check.
+    e = XMLHelper::getNextSiblingElement(e,RetryOnError);
+    while (e) {
+        if (e->hasChildNodes()) {
+            m_retries.push_back(XMLString::parseInt(e->getFirstChild()->getNodeValue()));
+            m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
+        }
+        e = XMLHelper::getNextSiblingElement(e,RetryOnError);
+    }
+
     // Initialize the cleanup thread
     shutdown_wait = CondWait::create();
     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
@@ -313,10 +338,11 @@ ODBCStorageService::~ODBCStorageService()
     shutdown_wait->signal();
     cleanup_thread->join(NULL);
     delete shutdown_wait;
-    SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
+    if (m_henv != SQL_NULL_HANDLE)
+        SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
 }
 
-void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype)
+pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
 {
     SQLSMALLINT         i = 0;
     SQLINTEGER  native;
@@ -325,11 +351,18 @@ void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype)
     SQLSMALLINT         len;
     SQLRETURN   ret;
 
+    pair<bool,bool> res = make_pair(false,false);
     do {
         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
-        if (SQL_SUCCEEDED(ret))
+        if (SQL_SUCCEEDED(ret)) {
             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
+            for (vector<SQLINTEGER>::const_iterator n = m_retries.begin(); !res.first && n != m_retries.end(); ++n)
+                res.first = (*n == native);
+            if (checkfor && !strcmp(checkfor, (const char*)state))
+                res.second = true;
+        }
     } while(SQL_SUCCEEDED(ret));
+    return res;
 }
 
 SQLHDBC ODBCStorageService::getHDBC()
@@ -354,12 +387,9 @@ SQLHDBC ODBCStorageService::getHDBC()
         throw IOException("ODBC StorageService failed to connect to database.");
     }
 
-    sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
-    if (!SQL_SUCCEEDED(sr))
-        throw IOException("ODBC StorageService failed to disable auto-commit mode.");
-    sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_SERIALIZABLE, NULL);
+    sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, NULL);
     if (!SQL_SUCCEEDED(sr))
-        throw IOException("ODBC StorageService failed to enable transaction isolation.");
+        throw IOException("ODBC StorageService failed to set transaction isolation level.");
 
     return handle;
 }
@@ -379,7 +409,7 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
 {
     // Grab the version number from the database.
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
     
     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
@@ -400,7 +430,7 @@ pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
     throw IOException("ODBC StorageService failed to read version from database.");
 }
 
-void ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
+bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("createRow");
@@ -411,24 +441,72 @@ void ODBCStorageService::createRow(const char *table, const char* context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute insert statement.
-    char *scontext = makeSafeSQL(context);
-    char *skey = makeSafeSQL(key);
-    char *svalue = makeSafeSQL(value);
-    string q  = string("INSERT ") + table + " VALUES ('" + scontext + "','" + skey + "'," + timebuf + ",1,'" + svalue + "')";
-    freeSafeSQL(scontext, context);
-    freeSafeSQL(skey, key);
-    freeSafeSQL(svalue, value);
-    m_log.debug("SQL: %s", q.c_str());
+    //char *scontext = makeSafeSQL(context);
+    //char *skey = makeSafeSQL(key);
+    //char *svalue = makeSafeSQL(value);
+    string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
-        m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
+        m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
+
+    SQLINTEGER b_ind = SQL_NTS;
+    sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (context = %s)", context);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeeded (context = %s)", context);
+
+    sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (key = %s)", key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeeded (key = %s)", key);
+
+    if (strcmp(table, TEXT_TABLE)==0)
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+    else
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (value = %s)", value);
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to insert record.");
     }
+    m_log.debug("SQLBindParam succeeded (value = %s)", value);
+    
+    //freeSafeSQL(scontext, context);
+    //freeSafeSQL(skey, key);
+    //freeSafeSQL(svalue, value);
+    //m_log.debug("SQL: %s", q.c_str());
+
+    int attempts = 3;
+    pair<bool,bool> logres;
+    do {
+        logres = make_pair(false,false);
+        attempts--;
+        sr=SQLExecute(stmt);
+        if (SQL_SUCCEEDED(sr)) {
+            m_log.debug("SQLExecute of insert succeeded");
+            return true;
+        }
+        m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
+        logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
+        if (logres.second)
+            return false;   // supposedly integrity violation?
+    } while (attempts && logres.first);
+
+    throw IOException("ODBC StorageService failed to insert record.");
 }
 
 int ODBCStorageService::readRow(
@@ -441,24 +519,26 @@ int ODBCStorageService::readRow(
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute select statement.
     char timebuf[32];
     timestampFromTime(time(NULL), timebuf);
     char *scontext = makeSafeSQL(context);
     char *skey = makeSafeSQL(key);
-    string q("SELECT version");
+    ostringstream q;
+    q << "SELECT version";
     if (pexpiration)
-        q += ",expires";
+        q << ",expires";
     if (pvalue)
-        q += ",value";
-    q = q + " FROM " + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf;
+        q << ",CASE version WHEN " << version << " THEN NULL ELSE value END";
+    q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
     freeSafeSQL(scontext, context);
     freeSafeSQL(skey, key);
-    m_log.debug("SQL: %s", q.c_str());
+    if (m_log.isDebugEnabled())
+        m_log.debug("SQL: %s", q.str().c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
         log_error(stmt, SQL_HANDLE_STMT);
@@ -506,9 +586,13 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     if (!value && !expiration)
         throw IOException("ODBC StorageService given invalid update instructions.");
 
-    // Get statement handle.
+    // Get statement handle. Disable auto-commit mode to wrap select + update.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
+    if (!SQL_SUCCEEDED(sr))
+        throw IOException("ODBC StorageService failed to disable auto-commit mode.");
+    conn.autoCommit = false;
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // First, fetch the current version for later, which also ensures the record still exists.
     char timebuf[32];
@@ -516,11 +600,11 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     char *scontext = makeSafeSQL(context);
     char *skey = makeSafeSQL(key);
     string q("SELECT version FROM ");
-    q = q + table + " WHERE context='" + scontext + "' AND id='" + key + "' AND expires > " + timebuf;
+    q = q + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf;
 
     m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         freeSafeSQL(scontext, context);
         freeSafeSQL(skey, key);
@@ -544,14 +628,14 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         return -1;
     }
 
+    SQLFreeHandle(SQL_HANDLE_STMT, stmt);
+    stmt = getHSTMT(conn);
+
     // Prepare and exectute update statement.
     q = string("UPDATE ") + table + " SET ";
 
-    if (value) {
-        char *svalue = makeSafeSQL(value);
-        q = q + "value='" + svalue + "'" + ",version=version+1";
-        freeSafeSQL(svalue, value);
-    }
+    if (value)
+        q = q + "value=?, version=version+1";
 
     if (expiration) {
         timestampFromTime(expiration, timebuf);
@@ -560,21 +644,50 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         q = q + "expires = " + timebuf;
     }
 
-    q = q + " WHERE context='" + scontext + "' AND id='" + key + "'";
+    q = q + " WHERE context='" + scontext + "' AND id='" + skey + "'";
     freeSafeSQL(scontext, context);
     freeSafeSQL(skey, key);
 
-    m_log.debug("SQL: %s", q.c_str());
-    sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
-    if (sr==SQL_NO_DATA)
-        return 0;   // went missing?
-    else if (!SQL_SUCCEEDED(sr)) {
+    sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if (!SQL_SUCCEEDED(sr)) {
         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to update record.");
     }
+    m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
+
+    SQLINTEGER b_ind = SQL_NTS;
+    if (value) {
+        if (strcmp(table, TEXT_TABLE)==0)
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+        else
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
+        if (!SQL_SUCCEEDED(sr)) {
+            m_log.error("SQLBindParam failed (context = %s)", context);
+            log_error(stmt, SQL_HANDLE_STMT);
+            throw IOException("ODBC StorageService failed to update record.");
+        }
+        m_log.debug("SQLBindParam succeeded (context = %s)", context);
+    }
+
+    int attempts = 3;
+    pair<bool,bool> logres;
+    do {
+        logres = make_pair(false,false);
+        attempts--;
+        sr=SQLExecute(stmt);
+        if (sr==SQL_NO_DATA)
+            return 0;   // went missing?
+        else if (SQL_SUCCEEDED(sr)) {
+            m_log.debug("SQLExecute of update succeeded");
+            return ver + 1;
+        }
+
+        m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
+        logres = log_error(stmt, SQL_HANDLE_STMT);
+    } while (attempts && logres.first);
 
-    return ver + 1;
+    throw IOException("ODBC StorageService failed to update record.");
 }
 
 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
@@ -585,7 +698,7 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);
@@ -661,7 +774,7 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     char timebuf[32];
     timestampFromTime(expiration, timebuf);
@@ -692,7 +805,7 @@ void ODBCStorageService::reap(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char nowbuf[32];
@@ -724,7 +837,7 @@ void ODBCStorageService::deleteContext(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);