Switch to auto-commit for all the non-update transactions.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
index 11405e4..2fd13c4 100644 (file)
@@ -96,23 +96,19 @@ namespace {
 
     // RAII for ODBC handles
     struct ODBCConn {
-        ODBCConn(SQLHDBC conn) : handle(conn) {}
+        ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
         ~ODBCConn() {
-            SQLRETURN sr = SQLEndTran(SQL_HANDLE_DBC, handle, SQL_COMMIT);
+            SQLRETURN sr = SQL_SUCCESS;
+            if (!autoCommit)
+                sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, NULL);
             SQLDisconnect(handle);
             SQLFreeHandle(SQL_HANDLE_DBC,handle);
             if (!SQL_SUCCEEDED(sr))
-                throw IOException("Failed to commit connection.");
+                throw IOException("Failed to commit connection and return to auto-commit mode.");
         }
         operator SQLHDBC() {return handle;}
         SQLHDBC handle;
-    };
-
-    struct ODBCStatement {
-        ODBCStatement(SQLHSTMT statement) : handle(statement) {}
-        ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);}
-        operator SQLHSTMT() {return handle;}
-        SQLHSTMT handle;
+        bool autoCommit;
     };
 
     class ODBCStorageService : public StorageService
@@ -314,7 +310,8 @@ ODBCStorageService::~ODBCStorageService()
     shutdown_wait->signal();
     cleanup_thread->join(NULL);
     delete shutdown_wait;
-    SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
+    if (m_henv != SQL_NULL_HANDLE)
+        SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
 }
 
 bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
@@ -360,9 +357,6 @@ SQLHDBC ODBCStorageService::getHDBC()
         throw IOException("ODBC StorageService failed to connect to database.");
     }
 
-    sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
-    if (!SQL_SUCCEEDED(sr))
-        throw IOException("ODBC StorageService failed to disable auto-commit mode.");
     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_SERIALIZABLE, NULL);
     if (!SQL_SUCCEEDED(sr))
         throw IOException("ODBC StorageService failed to enable transaction isolation.");
@@ -385,7 +379,7 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
 {
     // Grab the version number from the database.
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
     
     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
@@ -417,7 +411,7 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute insert statement.
     //char *scontext = makeSafeSQL(context);
@@ -431,10 +425,10 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to insert record.");
     }
-    m_log.debug("SQLPrepare() succeded. SQL: %s", q.c_str());
+    m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
 
     SQLINTEGER b_ind = SQL_NTS;
-    sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 0, 0, const_cast<char*>(context), &b_ind);
+    sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("SQLBindParam failed (context = %s)", context);
         log_error(stmt, SQL_HANDLE_STMT);
@@ -442,7 +436,7 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
     }
     m_log.debug("SQLBindParam succeded (context = %s)", context);
 
-    sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 0, 0, const_cast<char*>(key), &b_ind);
+    sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("SQLBindParam failed (key = %s)", key);
         log_error(stmt, SQL_HANDLE_STMT);
@@ -450,7 +444,10 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
     }
     m_log.debug("SQLBindParam succeded (key = %s)", key);
 
-    sr = SQLBindParam(stmt, 3, SQL_C_CHAR, (strcmp(table, TEXT_TABLE)==0 ? SQL_LONGVARCHAR : SQL_VARCHAR), 0, 0, const_cast<char*>(value), &b_ind);
+    if (strcmp(table, TEXT_TABLE)==0)
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+    else
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("SQLBindParam failed (value = %s)", value);
         log_error(stmt, SQL_HANDLE_STMT);
@@ -470,6 +467,8 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
             return false;   // supposedly integrity violation?
         throw IOException("ODBC StorageService failed to insert record.");
     }
+
+    m_log.debug("SQLExecute of insert succeeded");
     return true;
 }
 
@@ -483,7 +482,7 @@ int ODBCStorageService::readRow(
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute select statement.
     char timebuf[32];
@@ -550,9 +549,13 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     if (!value && !expiration)
         throw IOException("ODBC StorageService given invalid update instructions.");
 
-    // Get statement handle.
+    // Get statement handle. Disable auto-commit mode to wrap select + update.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
+    if (!SQL_SUCCEEDED(sr))
+        throw IOException("ODBC StorageService failed to disable auto-commit mode.");
+    conn.autoCommit = false;
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // First, fetch the current version for later, which also ensures the record still exists.
     char timebuf[32];
@@ -564,7 +567,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
 
     m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         freeSafeSQL(scontext, context);
         freeSafeSQL(skey, key);
@@ -588,6 +591,9 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         return -1;
     }
 
+    SQLFreeHandle(SQL_HANDLE_STMT, stmt);
+    stmt = getHSTMT(conn);
+
     // Prepare and exectute update statement.
     q = string("UPDATE ") + table + " SET ";
 
@@ -611,11 +617,14 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to update record.");
     }
-    m_log.debug("SQLPrepare() succeded. SQL: %s", q.c_str());
+    m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
 
     SQLINTEGER b_ind = SQL_NTS;
     if (value) {
-        sr = SQLBindParam(stmt, 1, SQL_C_CHAR, (strcmp(table, TEXT_TABLE)==0 ? SQL_LONGVARCHAR : SQL_VARCHAR), 0, 0, const_cast<char*>(value), &b_ind);
+        if (strcmp(table, TEXT_TABLE)==0)
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+        else
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
         if (!SQL_SUCCEEDED(sr)) {
             m_log.error("SQLBindParam failed (context = %s)", context);
             log_error(stmt, SQL_HANDLE_STMT);
@@ -633,6 +642,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         throw IOException("ODBC StorageService failed to update record.");
     }
 
+    m_log.debug("SQLExecute of update succeeded");
     return ver + 1;
 }
 
@@ -644,7 +654,7 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);
@@ -720,7 +730,7 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     char timebuf[32];
     timestampFromTime(expiration, timebuf);
@@ -751,7 +761,7 @@ void ODBCStorageService::reap(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char nowbuf[32];
@@ -783,7 +793,7 @@ void ODBCStorageService::deleteContext(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);