SQL Server wants column lengths filled in.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
index 7efd78f..9af4d0f 100644 (file)
  * Storage Service using ODBC
  */
 
-#if defined (_MSC_VER) || defined(__BORLANDC__)\r
-# include "config_win32.h"\r
-#else\r
-# include "config.h"\r
-#endif\r
-\r
-#ifdef WIN32\r
-# define _CRT_NONSTDC_NO_DEPRECATE 1\r
-# define _CRT_SECURE_NO_DEPRECATE 1\r
-#endif\r
+#if defined (_MSC_VER) || defined(__BORLANDC__)
+# include "config_win32.h"
+#else
+# include "config.h"
+#endif
+
+#ifdef WIN32
+# define _CRT_NONSTDC_NO_DEPRECATE 1
+# define _CRT_SECURE_NO_DEPRECATE 1
+#endif
 
 #ifdef WIN32
 # define ODBCSTORE_EXPORTS __declspec(dllexport)
@@ -37,8 +37,8 @@
 # define ODBCSTORE_EXPORTS
 #endif
 
-#include <log4cpp/Category.hh>
 #include <xercesc/util/XMLUniDefs.hpp>
+#include <xmltooling/logging.h>
 #include <xmltooling/XMLToolingConfig.h>
 #include <xmltooling/util/NDC.h>
 #include <xmltooling/util/StorageService.h>
@@ -48,9 +48,9 @@
 #include <sql.h>
 #include <sqlext.h>
 
+using namespace xmltooling::logging;
 using namespace xmltooling;
 using namespace xercesc;
-using namespace log4cpp;
 using namespace std;
 
 #define PLUGIN_VER_MAJOR 1
@@ -58,35 +58,36 @@ using namespace std;
 
 #define LONGDATA_BUFLEN 16384
 
-#define COLSIZE_KEY 255
 #define COLSIZE_CONTEXT 255
+#define COLSIZE_ID 255
 #define COLSIZE_STRING_VALUE 255
 
-#define STRING_TABLE "STRING_TABLE"
-#define TEXT_TABLE "TEXT_TABLE"
-
-/* tables definitions - not used here
-
-#define STRING_TABLE \
-  "CREATE TABLE STRING_TABLE ( " \
-    "context varchar(255), " \
-    "key varchar(255), " \
-    "value varchar(255), " \
-    "expires datetime, " \
-    "version smallint, " \
-    "PRIMARY KEY (context, key)" \
-    ")"
-
-
-#define TEXT_TABLE \
-  "CREATE TABLE TEXT_TABLE ( "\
-    "context varchar(255), " \
-    "key varchar(255), " \
-    "value text, " \
-    "expires datetime, " \
-    "version smallint, " \
-    "PRIMARY KEY (context, key)" \
-    ")"
+#define STRING_TABLE "strings"
+#define TEXT_TABLE "texts"
+
+/* table definitions
+CREATE TABLE version (
+    major tinyint NOT NULL,
+    minor tinyint NOT NULL
+    )
+
+CREATE TABLE strings (
+    context varchar(255) not null,
+    id varchar(255) not null,
+    expires datetime not null,
+    version smallint not null,
+    value varchar(255) not null,
+    PRIMARY KEY (context, id)
+    )
+
+CREATE TABLE texts (
+    context varchar(255) not null,
+    id varchar(255) not null,
+    expires datetime not null,
+    version smallint not null,
+    value text not null,
+    PRIMARY KEY (context, id)
+    )
 */
 
 namespace {
@@ -98,6 +99,7 @@ namespace {
         ODBCConn(SQLHDBC conn) : handle(conn) {}
         ~ODBCConn() {
             SQLRETURN sr = SQLEndTran(SQL_HANDLE_DBC, handle, SQL_COMMIT);
+            SQLDisconnect(handle);
             SQLFreeHandle(SQL_HANDLE_DBC,handle);
             if (!SQL_SUCCEEDED(sr))
                 throw IOException("Failed to commit connection.");
@@ -106,20 +108,13 @@ namespace {
         SQLHDBC handle;
     };
 
-    struct ODBCStatement {
-        ODBCStatement(SQLHSTMT statement) : handle(statement) {}
-        ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);}
-        operator SQLHSTMT() {return handle;}
-        SQLHSTMT handle;
-    };
-
     class ODBCStorageService : public StorageService
     {
     public:
         ODBCStorageService(const DOMElement* e);
         virtual ~ODBCStorageService();
 
-        void createString(const char* context, const char* key, const char* value, time_t expiration) {
+        bool createString(const char* context, const char* key, const char* value, time_t expiration) {
             return createRow(STRING_TABLE, context, key, value, expiration);
         }
         int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
@@ -132,7 +127,7 @@ namespace {
             return deleteRow(STRING_TABLE, context, key);
         }
 
-        void createText(const char* context, const char* key, const char* value, time_t expiration) {
+        bool createText(const char* context, const char* key, const char* value, time_t expiration) {
             return createRow(TEXT_TABLE, context, key, value, expiration);
         }
         int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
@@ -162,7 +157,7 @@ namespace {
          
 
     private:
-        void createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
+        bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
         bool deleteRow(const char *table, const char* context, const char* key);
@@ -174,7 +169,7 @@ namespace {
         SQLHDBC getHDBC();
         SQLHSTMT getHSTMT(SQLHDBC);
         pair<int,int> getVersion(SQLHDBC);
-        void log_error(SQLHANDLE handle, SQLSMALLINT htype);
+        bool log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL);
 
         static void* cleanup_fn(void*); 
         void cleanup();
@@ -235,12 +230,12 @@ namespace {
        char *s;
     
        // see if any conversion needed
-       for (s=(char*)src; *s; nc++,s++) if (*s=='\''||*s=='\\') ns++;
+       for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
        if (ns==0) return ((char*)src);
     
        char *safe = new char[(nc+2*ns+1)];
        for (s=safe; *src; src++) {
-           if (*src=='\''||*src=='\\') *s++ = '\\';
+           if (*src=='\'') *s++ = '\'';
            *s++ = (char)*src;
        }
        *s = '\0';
@@ -312,10 +307,11 @@ ODBCStorageService::~ODBCStorageService()
     shutdown_wait->signal();
     cleanup_thread->join(NULL);
     delete shutdown_wait;
-    SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
+    if (m_henv != SQL_NULL_HANDLE)
+        SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
 }
 
-void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype)
+bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
 {
     SQLSMALLINT         i = 0;
     SQLINTEGER  native;
@@ -324,11 +320,16 @@ void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype)
     SQLSMALLINT         len;
     SQLRETURN   ret;
 
+    bool res = false;
     do {
         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
-        if (SQL_SUCCEEDED(ret))
+        if (SQL_SUCCEEDED(ret)) {
             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
+            if (checkfor && !strcmp(checkfor, (const char*)state))
+                res = true;
+        }
     } while(SQL_SUCCEEDED(ret));
+    return res;
 }
 
 SQLHDBC ODBCStorageService::getHDBC()
@@ -378,7 +379,7 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
 {
     // Grab the version number from the database.
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
     
     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
@@ -399,7 +400,7 @@ pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
     throw IOException("ODBC StorageService failed to read version from database.");
 }
 
-void ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
+bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("createRow");
@@ -410,24 +411,65 @@ void ODBCStorageService::createRow(const char *table, const char* context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute insert statement.
-    char *scontext = makeSafeSQL(context);
-    char *skey = makeSafeSQL(key);
-    char *svalue = makeSafeSQL(value);
-    string q  = string("INSERT ") + table + " VALUES ('" + scontext + "','" + skey + "','" + svalue + "'," + timebuf + "', 1)";
-    freeSafeSQL(scontext, context);
-    freeSafeSQL(skey, key);
-    freeSafeSQL(svalue, value);
-    m_log.debug("SQL: %s", q.c_str());
+    //char *scontext = makeSafeSQL(context);
+    //char *skey = makeSafeSQL(key);
+    //char *svalue = makeSafeSQL(value);
+    string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
-        m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
+        m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
+
+    SQLINTEGER b_ind = SQL_NTS;
+    sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (context = %s)", context);
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to insert record.");
     }
+    m_log.debug("SQLBindParam succeded (context = %s)", context);
+
+    sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (key = %s)", key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeded (key = %s)", key);
+
+    if (strcmp(table, TEXT_TABLE)==0)
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+    else
+        sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("SQLBindParam failed (value = %s)", value);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+    m_log.debug("SQLBindParam succeded (value = %s)", value);
+    
+    //freeSafeSQL(scontext, context);
+    //freeSafeSQL(skey, key);
+    //freeSafeSQL(svalue, value);
+    //m_log.debug("SQL: %s", q.c_str());
+
+    sr=SQLExecute(stmt);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
+        if (log_error(stmt, SQL_HANDLE_STMT, "23000"))
+            return false;   // supposedly integrity violation?
+        throw IOException("ODBC StorageService failed to insert record.");
+    }
+
+    m_log.debug("SQLExecute of insert succeeded");
+    return true;
 }
 
 int ODBCStorageService::readRow(
@@ -438,26 +480,28 @@ int ODBCStorageService::readRow(
     xmltooling::NDC ndc("readRow");
 #endif
 
-    SQLCHAR *tvalue = NULL;
-
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and exectute select statement.
+    char timebuf[32];
+    timestampFromTime(time(NULL), timebuf);
     char *scontext = makeSafeSQL(context);
     char *skey = makeSafeSQL(key);
-    string q("SELECT version");
+    ostringstream q;
+    q << "SELECT version";
     if (pexpiration)
-        q += ",expires";
+        q << ",expires";
     if (pvalue)
-        q += ",value";
-    q = q + " FROM " + table + " WHERE context='" + scontext + "' AND key='" + skey + "' AND expires > NOW()";
+        q << ",CASE version WHEN " << version << " THEN NULL ELSE value END";
+    q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
     freeSafeSQL(scontext, context);
     freeSafeSQL(skey, key);
-    m_log.debug("SQL: %s", q.c_str());
+    if (m_log.isDebugEnabled())
+        m_log.debug("SQL: %s", q.str().c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
         log_error(stmt, SQL_HANDLE_STMT);
@@ -481,16 +525,16 @@ int ODBCStorageService::readRow(
         return version; // nothing's changed, so just echo back the version
 
     if (pvalue) {
-        SQLINTEGER len;\r
-        SQLCHAR buf[LONGDATA_BUFLEN];\r
-        while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {\r
-            if (!SQL_SUCCEEDED(sr)) {\r
-                m_log.error("error while reading text field from result set");\r
-                log_error(stmt, SQL_HANDLE_STMT);\r
-                throw IOException("ODBC StorageService search failed to read data from result set.");\r
-            }\r
-            pvalue->append((char*)buf);\r
-        }\r
+        SQLINTEGER len;
+        SQLCHAR buf[LONGDATA_BUFLEN];
+        while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
+            if (!SQL_SUCCEEDED(sr)) {
+                m_log.error("error while reading text field from result set");
+                log_error(stmt, SQL_HANDLE_STMT);
+                throw IOException("ODBC StorageService search failed to read data from result set.");
+            }
+            pvalue->append((char*)buf);
+        }
     }
     
     return ver;
@@ -507,13 +551,15 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // First, fetch the current version for later, which also ensures the record still exists.
+    char timebuf[32];
+    timestampFromTime(time(NULL), timebuf);
     char *scontext = makeSafeSQL(context);
     char *skey = makeSafeSQL(key);
     string q("SELECT version FROM ");
-    q = q + table + " WHERE context='" + scontext + "' AND key='" + key + "' AND expires > NOW()";
+    q = q + table + " WHERE context='" + scontext + "' AND id='" + key + "' AND expires > " + timebuf;
 
     m_log.debug("SQL: %s", q.c_str());
 
@@ -541,29 +587,49 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         return -1;
     }
 
+    SQLFreeHandle(SQL_HANDLE_STMT, stmt);
+    stmt = getHSTMT(conn);
+
     // Prepare and exectute update statement.
     q = string("UPDATE ") + table + " SET ";
 
-    if (value) {
-        char *svalue = makeSafeSQL(value);
-        q = q + "value='" + svalue + "'" + ",version=version+1";
-        freeSafeSQL(svalue, value);
-    }
+    if (value)
+        q = q + "value=?, version=version+1";
 
     if (expiration) {
-        char timebuf[32];
         timestampFromTime(expiration, timebuf);
         if (value)
             q += ',';
-        q = q + "expires = '" + timebuf + "' ";
+        q = q + "expires = " + timebuf;
     }
 
-    q = q + " WHERE context='" + scontext + "' AND key='" + key + "'";
+    q = q + " WHERE context='" + scontext + "' AND id='" + key + "'";
     freeSafeSQL(scontext, context);
     freeSafeSQL(skey, key);
 
-    m_log.debug("SQL: %s", q.c_str());
-    sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if (!SQL_SUCCEEDED(sr)) {
+        m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
+        log_error(stmt, SQL_HANDLE_STMT);
+        throw IOException("ODBC StorageService failed to update record.");
+    }
+    m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
+
+    SQLINTEGER b_ind = SQL_NTS;
+    if (value) {
+        if (strcmp(table, TEXT_TABLE)==0)
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
+        else
+            sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
+        if (!SQL_SUCCEEDED(sr)) {
+            m_log.error("SQLBindParam failed (context = %s)", context);
+            log_error(stmt, SQL_HANDLE_STMT);
+            throw IOException("ODBC StorageService failed to update record.");
+        }
+        m_log.debug("SQLBindParam succeded (context = %s)", context);
+    }
+
+    sr=SQLExecute(stmt);
     if (sr==SQL_NO_DATA)
         return 0;   // went missing?
     else if (!SQL_SUCCEEDED(sr)) {
@@ -572,6 +638,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         throw IOException("ODBC StorageService failed to update record.");
     }
 
+    m_log.debug("SQLExecute of update succeeded");
     return ver + 1;
 }
 
@@ -583,12 +650,12 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);
     char *skey = makeSafeSQL(key);
-    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND key='" + skey + "'";
+    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
     freeSafeSQL(scontext, context);
     freeSafeSQL(skey, key);
     m_log.debug("SQL: %s", q.c_str());
@@ -659,14 +726,17 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     char timebuf[32];
     timestampFromTime(expiration, timebuf);
 
+    char nowbuf[32];
+    timestampFromTime(time(NULL), nowbuf);
+
     char *scontext = makeSafeSQL(context);
     string q("UPDATE ");
-    q = q + table + " SET expires = '" + timebuf + "' WHERE context='" + scontext + "' AND expires > NOW()";
+    q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
     freeSafeSQL(scontext, context);
 
     m_log.debug("SQL: %s", q.c_str());
@@ -687,17 +757,19 @@ void ODBCStorageService::reap(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
+    char nowbuf[32];
+    timestampFromTime(time(NULL), nowbuf);
     string q;
     if (context) {
         char *scontext = makeSafeSQL(context);
-        q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= NOW()";
+        q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
         freeSafeSQL(scontext, context);
     }
     else {
-        q = string("DELETE FROM ") + table + " WHERE expires <= NOW()";
+        q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
     }
     m_log.debug("SQL: %s", q.c_str());
 
@@ -717,7 +789,7 @@ void ODBCStorageService::deleteContext(const char *table, const char* context)
 
     // Get statement handle.
     ODBCConn conn(getHDBC());
-    ODBCStatement stmt(getHSTMT(conn));
+    SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
     char *scontext = makeSafeSQL(context);
@@ -733,14 +805,14 @@ void ODBCStorageService::deleteContext(const char *table, const char* context)
     }
 }
 
-extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)\r
-{\r
-    // Register this SS type\r
-    XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);\r
-    return 0;\r
-}\r
-\r
-extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()\r
-{\r
-    XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");\r
-}\r
+extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
+{
+    // Register this SS type
+    XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
+    return 0;
+}
+
+extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
+{
+    XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
+}