SQL Server wants column lengths filled in.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
index 5eb2a2b..9af4d0f 100644 (file)
@@ -37,8 +37,8 @@
 # define ODBCSTORE_EXPORTS
 #endif
 
-#include <log4cpp/Category.hh>
 #include <xercesc/util/XMLUniDefs.hpp>
+#include <xmltooling/logging.h>
 #include <xmltooling/XMLToolingConfig.h>
 #include <xmltooling/util/NDC.h>
 #include <xmltooling/util/StorageService.h>
@@ -48,9 +48,9 @@
 #include <sql.h>
 #include <sqlext.h>
 
+using namespace xmltooling::logging;
 using namespace xmltooling;
 using namespace xercesc;
-using namespace log4cpp;
 using namespace std;
 
 #define PLUGIN_VER_MAJOR 1
@@ -65,7 +65,7 @@ using namespace std;
 #define STRING_TABLE "strings"
 #define TEXT_TABLE "texts"
 
-/* tables definitions
+/* table definitions
 CREATE TABLE version (
     major tinyint NOT NULL,
     minor tinyint NOT NULL
@@ -108,13 +108,6 @@ namespace {
         SQLHDBC handle;
     };
 
-    struct ODBCStatement {
-        ODBCStatement(SQLHSTMT statement) : handle(statement) {}
-        ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);}
-        operator SQLHSTMT() {return handle;}
-        SQLHSTMT handle;
-    };
-
     class ODBCStorageService : public StorageService
     {
     public:
@@ -314,7 +307,8 @@ ODBCStorageService::~ODBCStorageService()
     shutdown_wait->signal();
     cleanup_thread->join(NULL);
     delete shutdown_wait;
-    SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
+    if (m_henv != SQL_NULL_HANDLE)
+        SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
 }
 
 bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
@@ -385,7 +379,7 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
 {
     // Grab the version number from the database.
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
     
     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
@@ -406,7 +400,7 @@ pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
     throw IOException("ODBC StorageService failed to read version from database.");
 }
 
-bool ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
+bool ODBCStorageService::createRow(const chartable, const char* context, const char* key, const char* value, time_t expiration)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("createRow");
@@ -417,25 +411,64 @@ bool ODBCStorageService::createRow(const char *table, const char* context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute insert statement.
-    char *scontext = makeSafeSQL(context);
-    char *skey = makeSafeSQL(key);
-    char *svalue = makeSafeSQL(value);
-    string q  = string("INSERT ") + table + " VALUES ('" + scontext + "','" + skey + "'," + timebuf + ",1,'" + svalue + "')";
-    freeSafeSQL(scontext, context);
-    freeSafeSQL(skey, key);
-    freeSafeSQL(svalue, value);
-    m_log.debug("SQL: %s", q.c_str());
+    //char *scontext = makeSafeSQL(context);
+    //char *skey = makeSafeSQL(key);
+    //char *svalue = makeSafeSQL(value);
+    string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
+
+    SQLINTEGER b_ind = SQL_NTS;
+    sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (context = %s)", context);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeded (context = %s)", context);
+
+    sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (key = %s)", key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeded (key = %s)", key);
+
+    if (strcmp(table, TEXT_TABLE)==0)
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+    else
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (value = %s)", value);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeded (value = %s)", value);
+    
+    //freeSafeSQL(scontext, context);
+    //freeSafeSQL(skey, key);
+    //freeSafeSQL(svalue, value);
+    //m_log.debug("SQL: %s", q.c_str());
+
+    sr=SQLExecute(stmt);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
         if (log_error(stmt, SQL_HANDLE_STMT, "23000"))
             return false;   // supposedly integrity violation?
         throw IOException("ODBC StorageService failed to insert record.");
     }
+
+    m_log.debug("SQLExecute of insert succeeded");
     return true;
 }
 
@@ -449,7 +482,7 @@ int ODBCStorageService::readRow(
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute select statement.
     char timebuf[32];
@@ -518,7 +551,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // First, fetch the current version for later, which also ensures the record still exists.
     char timebuf[32];
@@ -554,14 +587,14 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         return -1;
     }
 
+    SQLFreeHandle(SQL_HANDLE_STMT, stmt);
+    stmt = getHSTMT(conn);
+
     // Prepare and exectute update statement.
     q = string("UPDATE ") + table + " SET ";
 
-    if (value) {
-        char *svalue = makeSafeSQL(value);
-        q = q + "value='" + svalue + "'" + ",version=version+1";
-        freeSafeSQL(svalue, value);
-    }
+    if (value)
+        q = q + "value=?, version=version+1";
 
     if (expiration) {
         timestampFromTime(expiration, timebuf);
@@ -574,8 +607,29 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     freeSafeSQL(scontext, context);
     freeSafeSQL(skey, key);
 
-    m_log.debug("SQL: %s", q.c_str());
-    sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to update record.");
+    }
+    m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
+
+    SQLINTEGER b_ind = SQL_NTS;
+    if (value) {
+        if (strcmp(table, TEXT_TABLE)==0)
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+        else
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
+        if (!SQL_SUCCEEDED(sr)) {
+            m_log.error("SQLBindParam failed (context = %s)", context);
+            log_error(stmt, SQL_HANDLE_STMT);
+            throw IOException("ODBC StorageService failed to update record.");
+        }
+        m_log.debug("SQLBindParam succeded (context = %s)", context);
+    }
+
+    sr=SQLExecute(stmt);
     if (sr==SQL_NO_DATA)
         return 0;   // went missing?
     else if (!SQL_SUCCEEDED(sr)) {
@@ -584,6 +638,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         throw IOException("ODBC StorageService failed to update record.");
     }
 
+    m_log.debug("SQLExecute of update succeeded");
     return ver + 1;
 }
 
@@ -595,7 +650,7 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);
@@ -671,7 +726,7 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     char timebuf[32];
     timestampFromTime(expiration, timebuf);
@@ -702,7 +757,7 @@ void ODBCStorageService::reap(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char nowbuf[32];
@@ -734,7 +789,7 @@ void ODBCStorageService::deleteContext(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);