Remove extra header
[shibboleth/cpp-sp.git] / odbc-store / odbc-store.cpp
index 6f7aede..7008e50 100644 (file)
@@ -41,7 +41,6 @@
 # define ODBCSTORE_EXPORTS
 #endif
 
-#include <xercesc/util/XMLUniDefs.hpp>
 #include <xmltooling/logging.h>
 #include <xmltooling/unicode.h>
 #include <xmltooling/XMLToolingConfig.h>
 #include <xmltooling/util/StorageService.h>
 #include <xmltooling/util/Threads.h>
 #include <xmltooling/util/XMLHelper.h>
+#include <xercesc/util/XMLUniDefs.hpp>
 
 #include <sql.h>
 #include <sqlext.h>
 
+#include <boost/lexical_cast.hpp>
+#include <boost/algorithm/string.hpp>
+
 using namespace xmltooling::logging;
 using namespace xmltooling;
 using namespace xercesc;
+using namespace boost;
 using namespace std;
 
 #define PLUGIN_VER_MAJOR 1
@@ -185,7 +189,7 @@ namespace {
 
         SQLHDBC getHDBC();
         SQLHSTMT getHSTMT(SQLHDBC);
-        pair<int,int> getVersion(SQLHDBC);
+        pair<SQLINTEGER,SQLINTEGER> getVersion(SQLHDBC);
         pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=nullptr);
 
         static void* cleanup_fn(void*); 
@@ -194,7 +198,7 @@ namespace {
         Category& m_log;
         Capabilities m_caps;
         int m_cleanupInterval;
-        CondWait* shutdown_wait;
+        scoped_ptr<CondWait> shutdown_wait;
         Thread* cleanup_thread;
         bool shutdown;
 
@@ -241,61 +245,46 @@ namespace {
         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
     }
 
-    // make a string safe for SQL command
-    // result to be free'd only if it isn't the input
-    static char *makeSafeSQL(const char *src)
-    {
-       int ns = 0;
-       int nc = 0;
-       char *s;
-    
-       // see if any conversion needed
-       for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
-       if (ns==0) return ((char*)src);
-    
-       char *safe = new char[(nc+2*ns+1)];
-       for (s=safe; *src; src++) {
-           if (*src=='\'') *s++ = '\'';
-           *s++ = (char)*src;
-       }
-       *s = '\0';
-       return (safe);
-    }
+    class SQLString {
+        const char* m_src;
+        string m_copy;
+    public:
+        SQLString(const char* src) : m_src(src) {
+            if (strchr(src, '\'')) {
+                m_copy = src;
+                replace_all(m_copy, "'", "''");
+            }
+        }
 
-    void freeSafeSQL(char *safe, const char *src)
-    {
-        if (safe!=src)
-            delete[](safe);
-    }
+        operator const char*() const {
+            return tostr();
+        }
+
+        const char* tostr() const {
+            return m_copy.empty() ? m_src : m_copy.c_str();
+        }
+    };
 };
 
 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
     m_caps(XMLHelper::getAttrInt(e, 255, contextSize), XMLHelper::getAttrInt(e, 255, keySize), XMLHelper::getAttrInt(e, 255, stringSize)),
-    m_cleanupInterval(900), shutdown_wait(nullptr), cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
+    m_cleanupInterval(XMLHelper::getAttrInt(e, 900, cleanupInterval)),
+    cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("ODBCStorageService");
 #endif
-
-    const XMLCh* tag=e ? e->getAttributeNS(nullptr,cleanupInterval) : nullptr;
-    if (tag && *tag)
-        m_cleanupInterval = XMLString::parseInt(tag);
-    if (!m_cleanupInterval)
-        m_cleanupInterval = 900;
-
-    auto_ptr_char iso(e ? e->getAttributeNS(nullptr,isolationLevel) : nullptr);
-    if (iso.get() && *iso.get()) {
-        if (!strcmp(iso.get(),"SERIALIZABLE"))
-            m_isolation = SQL_TXN_SERIALIZABLE;
-        else if (!strcmp(iso.get(),"REPEATABLE_READ"))
-            m_isolation = SQL_TXN_REPEATABLE_READ;
-        else if (!strcmp(iso.get(),"READ_COMMITTED"))
-            m_isolation = SQL_TXN_READ_COMMITTED;
-        else if (!strcmp(iso.get(),"READ_UNCOMMITTED"))
-            m_isolation = SQL_TXN_READ_UNCOMMITTED;
-        else
-            throw XMLToolingException("Unknown transaction isolationLevel property.");
-    }
+    string iso(XMLHelper::getAttrString(e, "SERIALIZABLE", isolationLevel));
+    if (iso == "SERIALIZABLE")
+        m_isolation = SQL_TXN_SERIALIZABLE;
+    else if (iso == "REPEATABLE_READ")
+        m_isolation = SQL_TXN_REPEATABLE_READ;
+    else if (iso == "READ_COMMITTED")
+        m_isolation = SQL_TXN_READ_COMMITTED;
+    else if (iso == "READ_UNCOMMITTED")
+        m_isolation = SQL_TXN_READ_UNCOMMITTED;
+    else
+        throw XMLToolingException("Unknown transaction isolationLevel property.");
 
     if (m_henv == SQL_NULL_HANDLE) {
         // Enable connection pooling.
@@ -312,17 +301,17 @@ ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::ge
     }
 
     // Grab connection string from the configuration.
-    e = e ? XMLHelper::getFirstChildElement(e,ConnectionString) : nullptr;
-    if (!e || !e->hasChildNodes()) {
+    e = e ? XMLHelper::getFirstChildElement(e, ConnectionString) : nullptr;
+    auto_ptr_char arg(e ? e->getTextContent() : nullptr);
+    if (!arg.get() || !*arg.get()) {
         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
     }
-    auto_ptr_char arg(e->getFirstChild()->getNodeValue());
-    m_connstring=arg.get();
+    m_connstring = arg.get();
 
     // Connect and check version.
     ODBCConn conn(getHDBC());
-    pair<int,int> v=getVersion(conn);
+    pair<SQLINTEGER,SQLINTEGER> v = getVersion(conn);
 
     // Make sure we've got the right version.
     if (v.first != PLUGIN_VER_MAJOR) {
@@ -332,17 +321,17 @@ ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::ge
     }
 
     // Load any retry errors to check.
-    e = XMLHelper::getNextSiblingElement(e,RetryOnError);
+    e = XMLHelper::getNextSiblingElement(e, RetryOnError);
     while (e) {
         if (e->hasChildNodes()) {
-            m_retries.push_back(XMLString::parseInt(e->getFirstChild()->getNodeValue()));
+            m_retries.push_back(XMLString::parseInt(e->getTextContent()));
             m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
         }
-        e = XMLHelper::getNextSiblingElement(e,RetryOnError);
+        e = XMLHelper::getNextSiblingElement(e, RetryOnError);
     }
 
     // Initialize the cleanup thread
-    shutdown_wait = CondWait::create();
+    shutdown_wait.reset(CondWait::create());
     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
 }
 
@@ -351,7 +340,6 @@ ODBCStorageService::~ODBCStorageService()
     shutdown = true;
     shutdown_wait->signal();
     cleanup_thread->join(nullptr);
-    delete shutdown_wait;
     if (m_henv != SQL_NULL_HANDLE)
         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
 }
@@ -411,7 +399,7 @@ SQLHDBC ODBCStorageService::getHDBC()
 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
 {
     SQLHSTMT hstmt;
-    SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
+    SQLRETURN sr = SQLAllocHandle(SQL_HANDLE_STMT, conn, &hstmt);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("failed to allocate statement handle");
         log_error(conn, SQL_HANDLE_DBC);
@@ -420,12 +408,12 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
     return hstmt;
 }
 
-pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
+pair<SQLINTEGER,SQLINTEGER> ODBCStorageService::getVersion(SQLHDBC conn)
 {
     // Grab the version number from the database.
     SQLHSTMT stmt = getHSTMT(conn);
     
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
+    SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("failed to read version from database");
         log_error(stmt, SQL_HANDLE_STMT);
@@ -434,11 +422,11 @@ pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
 
     SQLINTEGER major;
     SQLINTEGER minor;
-    SQLBindCol(stmt,1,SQL_C_SLONG,&major,0,nullptr);
-    SQLBindCol(stmt,2,SQL_C_SLONG,&minor,0,nullptr);
+    SQLBindCol(stmt, 1, SQL_C_SLONG, &major, 0, nullptr);
+    SQLBindCol(stmt, 2, SQL_C_SLONG, &minor, 0, nullptr);
 
-    if ((sr=SQLFetch(stmt)) != SQL_NO_DATA)
-        return pair<int,int>(major,minor);
+    if ((sr = SQLFetch(stmt)) != SQL_NO_DATA)
+        return make_pair(major,minor);
 
     m_log.error("no rows returned in version query");
     throw IOException("ODBC StorageService failed to read version from database.");
@@ -457,10 +445,6 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
     ODBCConn conn(getHDBC());
     SQLHSTMT stmt = getHSTMT(conn);
 
-    // Prepare and exectute insert statement.
-    //char *scontext = makeSafeSQL(context);
-    //char *skey = makeSafeSQL(key);
-    //char *svalue = makeSafeSQL(value);
     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
 
     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
@@ -499,17 +483,12 @@ bool ODBCStorageService::createRow(const char* table, const char* context, const
     }
     m_log.debug("SQLBindParam succeeded (value = %s)", value);
     
-    //freeSafeSQL(scontext, context);
-    //freeSafeSQL(skey, key);
-    //freeSafeSQL(svalue, value);
-    //m_log.debug("SQL: %s", q.c_str());
-
     int attempts = 3;
     pair<bool,bool> logres;
     do {
         logres = make_pair(false,false);
         attempts--;
-        sr=SQLExecute(stmt);
+        sr = SQLExecute(stmt);
         if (SQL_SUCCEEDED(sr)) {
             m_log.debug("SQLExecute of insert succeeded");
             return true;
@@ -538,21 +517,18 @@ int ODBCStorageService::readRow(
     // Prepare and exectute select statement.
     char timebuf[32];
     timestampFromTime(time(nullptr), timebuf);
-    char *scontext = makeSafeSQL(context);
-    char *skey = makeSafeSQL(key);
-    ostringstream q;
-    q << "SELECT version";
+    SQLString scontext(context);
+    SQLString skey(key);
+    string q("SELECT version");
     if (pexpiration)
-        q << ",expires";
+        q += ",expires";
     if (pvalue)
-        q << ",CASE version WHEN " << version << " THEN null ELSE value END";
-    q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
-    freeSafeSQL(scontext, context);
-    freeSafeSQL(skey, key);
+        q = q + ",CASE version WHEN " + lexical_cast<string>(version) + " THEN null ELSE value END";
+    q = q + " FROM " + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
     if (m_log.isDebugEnabled())
-        m_log.debug("SQL: %s", q.str().c_str());
+        m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
+    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
         log_error(stmt, SQL_HANDLE_STMT);
@@ -562,11 +538,11 @@ int ODBCStorageService::readRow(
     SQLSMALLINT ver;
     SQL_TIMESTAMP_STRUCT expiration;
 
-    SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,nullptr);
+    SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
     if (pexpiration)
-        SQLBindCol(stmt,2,SQL_C_TYPE_TIMESTAMP,&expiration,0,nullptr);
+        SQLBindCol(stmt, 2, SQL_C_TYPE_TIMESTAMP, &expiration, 0, nullptr);
 
-    if ((sr=SQLFetch(stmt)) == SQL_NO_DATA)
+    if ((sr = SQLFetch(stmt)) == SQL_NO_DATA)
         return 0;
 
     if (pexpiration)
@@ -578,7 +554,7 @@ int ODBCStorageService::readRow(
     if (pvalue) {
         SQLLEN len;
         SQLCHAR buf[LONGDATA_BUFLEN];
-        while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
+        while ((sr = SQLGetData(stmt, (pexpiration ? 3 : 2), SQL_C_CHAR, buf, sizeof(buf), &len)) != SQL_NO_DATA) {
             if (!SQL_SUCCEEDED(sr)) {
                 m_log.error("error while reading text field from result set");
                 log_error(stmt, SQL_HANDLE_STMT);
@@ -611,34 +587,28 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     // First, fetch the current version for later, which also ensures the record still exists.
     char timebuf[32];
     timestampFromTime(time(nullptr), timebuf);
-    char *scontext = makeSafeSQL(context);
-    char *skey = makeSafeSQL(key);
+    SQLString scontext(context);
+    SQLString skey(key);
     string q("SELECT version FROM ");
-    q = q + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf;
+    q = q + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
 
     m_log.debug("SQL: %s", q.c_str());
 
-    sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
-        freeSafeSQL(scontext, context);
-        freeSafeSQL(skey, key);
         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService search failed.");
     }
 
     SQLSMALLINT ver;
-    SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,nullptr);
-    if ((sr=SQLFetch(stmt)) == SQL_NO_DATA) {
-        freeSafeSQL(scontext, context);
-        freeSafeSQL(skey, key);
+    SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
+    if ((sr = SQLFetch(stmt)) == SQL_NO_DATA) {
         return 0;
     }
 
     // Check version?
     if (version > 0 && version != ver) {
-        freeSafeSQL(scontext, context);
-        freeSafeSQL(skey, key);
         return -1;
     }
 
@@ -658,9 +628,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
         q = q + "expires = " + timebuf;
     }
 
-    q = q + " WHERE context='" + scontext + "' AND id='" + skey + "'";
-    freeSafeSQL(scontext, context);
-    freeSafeSQL(skey, key);
+    q = q + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "'";
 
     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
@@ -689,8 +657,8 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const
     do {
         logres = make_pair(false,false);
         attempts--;
-        sr=SQLExecute(stmt);
-        if (sr==SQL_NO_DATA)
+        sr = SQLExecute(stmt);
+        if (sr == SQL_NO_DATA)
             return 0;   // went missing?
         else if (SQL_SUCCEEDED(sr)) {
             m_log.debug("SQLExecute of update succeeded");
@@ -715,15 +683,13 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const
     SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
-    char *scontext = makeSafeSQL(context);
-    char *skey = makeSafeSQL(key);
-    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
-    freeSafeSQL(scontext, context);
-    freeSafeSQL(skey, key);
+    SQLString scontext(context);
+    SQLString skey(key);
+    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "'";
     m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
-     if (sr==SQL_NO_DATA)
+    SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+     if (sr == SQL_NO_DATA)
         return false;
     else if (!SQL_SUCCEEDED(sr)) {
         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
@@ -741,20 +707,20 @@ void ODBCStorageService::cleanup()
     xmltooling::NDC ndc("cleanup");
 #endif
 
-    Mutex* mutex = Mutex::create();
+    scoped_ptr<Mutex> mutex(Mutex::create());
 
     mutex->lock();
 
     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
 
     while (!shutdown) {
-        shutdown_wait->timedwait(mutex, m_cleanupInterval);
+        shutdown_wait->timedwait(mutex.get(), m_cleanupInterval);
         if (shutdown)
             break;
         try {
             reap(nullptr);
         }
-        catch (exception& ex) {
+        catch (std::exception& ex) {
             m_log.error("cleanup thread swallowed exception: %s", ex.what());
         }
     }
@@ -762,7 +728,6 @@ void ODBCStorageService::cleanup()
     m_log.info("cleanup thread exiting...");
 
     mutex->unlock();
-    delete mutex;
     Thread::exit(nullptr);
 }
 
@@ -796,15 +761,13 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t
     char nowbuf[32];
     timestampFromTime(time(nullptr), nowbuf);
 
-    char *scontext = makeSafeSQL(context);
-    string q("UPDATE ");
-    q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
-    freeSafeSQL(scontext, context);
+    SQLString scontext(context);
+    string q = string("UPDATE ") + table + " SET expires = " + timebuf + " WHERE context='" + scontext.tostr() + "' AND expires > " + nowbuf;
 
     m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
-    if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
+    SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to update context expiration.");
@@ -826,17 +789,16 @@ void ODBCStorageService::reap(const char *table, const char* context)
     timestampFromTime(time(nullptr), nowbuf);
     string q;
     if (context) {
-        char *scontext = makeSafeSQL(context);
-        q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
-        freeSafeSQL(scontext, context);
+        SQLString scontext(context);
+        q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "' AND expires <= " + nowbuf;
     }
     else {
         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
     }
     m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
-    if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
+    SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to purge expired records.");
@@ -854,13 +816,12 @@ void ODBCStorageService::deleteContext(const char *table, const char* context)
     SQLHSTMT stmt = getHSTMT(conn);
 
     // Prepare and execute delete statement.
-    char *scontext = makeSafeSQL(context);
-    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
-    freeSafeSQL(scontext, context);
+    SQLString scontext(context);
+    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "'";
     m_log.debug("SQL: %s", q.c_str());
 
-    SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
-    if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
+    SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
+    if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
         m_log.error("error deleting context (t=%s, c=%s)", table, context);
         log_error(stmt, SQL_HANDLE_STMT);
         throw IOException("ODBC StorageService failed to delete context.");