X-Git-Url: http://www.project-moonshot.org/gitweb/?a=blobdiff_plain;f=odbc-store%2Fodbc-store.cpp;h=9af4d0f4f66ec6c946aae02fa1f42a62d6a0dd7c;hb=17dab4f583f7d7293ce994ad562d88e7fe08386f;hp=7efd78ff5054ef448064a27375434d4a15016e70;hpb=bca93b0c1f6e0b4bf4169cfb8506b30acee3a2e1;p=shibboleth%2Fsp.git diff --git a/odbc-store/odbc-store.cpp b/odbc-store/odbc-store.cpp index 7efd78f..9af4d0f 100644 --- a/odbc-store/odbc-store.cpp +++ b/odbc-store/odbc-store.cpp @@ -20,16 +20,16 @@ * Storage Service using ODBC */ -#if defined (_MSC_VER) || defined(__BORLANDC__) -# include "config_win32.h" -#else -# include "config.h" -#endif - -#ifdef WIN32 -# define _CRT_NONSTDC_NO_DEPRECATE 1 -# define _CRT_SECURE_NO_DEPRECATE 1 -#endif +#if defined (_MSC_VER) || defined(__BORLANDC__) +# include "config_win32.h" +#else +# include "config.h" +#endif + +#ifdef WIN32 +# define _CRT_NONSTDC_NO_DEPRECATE 1 +# define _CRT_SECURE_NO_DEPRECATE 1 +#endif #ifdef WIN32 # define ODBCSTORE_EXPORTS __declspec(dllexport) @@ -37,8 +37,8 @@ # define ODBCSTORE_EXPORTS #endif -#include #include +#include #include #include #include @@ -48,9 +48,9 @@ #include #include +using namespace xmltooling::logging; using namespace xmltooling; using namespace xercesc; -using namespace log4cpp; using namespace std; #define PLUGIN_VER_MAJOR 1 @@ -58,35 +58,36 @@ using namespace std; #define LONGDATA_BUFLEN 16384 -#define COLSIZE_KEY 255 #define COLSIZE_CONTEXT 255 +#define COLSIZE_ID 255 #define COLSIZE_STRING_VALUE 255 -#define STRING_TABLE "STRING_TABLE" -#define TEXT_TABLE "TEXT_TABLE" - -/* tables definitions - not used here - -#define STRING_TABLE \ - "CREATE TABLE STRING_TABLE ( " \ - "context varchar(255), " \ - "key varchar(255), " \ - "value varchar(255), " \ - "expires datetime, " \ - "version smallint, " \ - "PRIMARY KEY (context, key)" \ - ")" - - -#define TEXT_TABLE \ - "CREATE TABLE TEXT_TABLE ( "\ - "context varchar(255), " \ - "key varchar(255), " \ - "value text, " \ - "expires datetime, " \ - "version smallint, " \ - "PRIMARY KEY (context, key)" \ - ")" +#define STRING_TABLE "strings" +#define TEXT_TABLE "texts" + +/* table definitions +CREATE TABLE version ( + major tinyint NOT NULL, + minor tinyint NOT NULL + ) + +CREATE TABLE strings ( + context varchar(255) not null, + id varchar(255) not null, + expires datetime not null, + version smallint not null, + value varchar(255) not null, + PRIMARY KEY (context, id) + ) + +CREATE TABLE texts ( + context varchar(255) not null, + id varchar(255) not null, + expires datetime not null, + version smallint not null, + value text not null, + PRIMARY KEY (context, id) + ) */ namespace { @@ -98,6 +99,7 @@ namespace { ODBCConn(SQLHDBC conn) : handle(conn) {} ~ODBCConn() { SQLRETURN sr = SQLEndTran(SQL_HANDLE_DBC, handle, SQL_COMMIT); + SQLDisconnect(handle); SQLFreeHandle(SQL_HANDLE_DBC,handle); if (!SQL_SUCCEEDED(sr)) throw IOException("Failed to commit connection."); @@ -106,20 +108,13 @@ namespace { SQLHDBC handle; }; - struct ODBCStatement { - ODBCStatement(SQLHSTMT statement) : handle(statement) {} - ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);} - operator SQLHSTMT() {return handle;} - SQLHSTMT handle; - }; - class ODBCStorageService : public StorageService { public: ODBCStorageService(const DOMElement* e); virtual ~ODBCStorageService(); - void createString(const char* context, const char* key, const char* value, time_t expiration) { + bool createString(const char* context, const char* key, const char* value, time_t expiration) { return createRow(STRING_TABLE, context, key, value, expiration); } int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) { @@ -132,7 +127,7 @@ namespace { return deleteRow(STRING_TABLE, context, key); } - void createText(const char* context, const char* key, const char* value, time_t expiration) { + bool createText(const char* context, const char* key, const char* value, time_t expiration) { return createRow(TEXT_TABLE, context, key, value, expiration); } int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) { @@ -162,7 +157,7 @@ namespace { private: - void createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration); + bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration); int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text); int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version); bool deleteRow(const char *table, const char* context, const char* key); @@ -174,7 +169,7 @@ namespace { SQLHDBC getHDBC(); SQLHSTMT getHSTMT(SQLHDBC); pair getVersion(SQLHDBC); - void log_error(SQLHANDLE handle, SQLSMALLINT htype); + bool log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL); static void* cleanup_fn(void*); void cleanup(); @@ -235,12 +230,12 @@ namespace { char *s; // see if any conversion needed - for (s=(char*)src; *s; nc++,s++) if (*s=='\''||*s=='\\') ns++; + for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++; if (ns==0) return ((char*)src); char *safe = new char[(nc+2*ns+1)]; for (s=safe; *src; src++) { - if (*src=='\''||*src=='\\') *s++ = '\\'; + if (*src=='\'') *s++ = '\''; *s++ = (char)*src; } *s = '\0'; @@ -312,10 +307,11 @@ ODBCStorageService::~ODBCStorageService() shutdown_wait->signal(); cleanup_thread->join(NULL); delete shutdown_wait; - SQLFreeHandle(SQL_HANDLE_ENV, m_henv); + if (m_henv != SQL_NULL_HANDLE) + SQLFreeHandle(SQL_HANDLE_ENV, m_henv); } -void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype) +bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor) { SQLSMALLINT i = 0; SQLINTEGER native; @@ -324,11 +320,16 @@ void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype) SQLSMALLINT len; SQLRETURN ret; + bool res = false; do { ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len); - if (SQL_SUCCEEDED(ret)) + if (SQL_SUCCEEDED(ret)) { m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text); + if (checkfor && !strcmp(checkfor, (const char*)state)) + res = true; + } } while(SQL_SUCCEEDED(ret)); + return res; } SQLHDBC ODBCStorageService::getHDBC() @@ -378,7 +379,7 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn) pair ODBCStorageService::getVersion(SQLHDBC conn) { // Grab the version number from the database. - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS); if (!SQL_SUCCEEDED(sr)) { @@ -399,7 +400,7 @@ pair ODBCStorageService::getVersion(SQLHDBC conn) throw IOException("ODBC StorageService failed to read version from database."); } -void ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration) +bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration) { #ifdef _DEBUG xmltooling::NDC ndc("createRow"); @@ -410,24 +411,65 @@ void ODBCStorageService::createRow(const char *table, const char* context, const // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and exectute insert statement. - char *scontext = makeSafeSQL(context); - char *skey = makeSafeSQL(key); - char *svalue = makeSafeSQL(value); - string q = string("INSERT ") + table + " VALUES ('" + scontext + "','" + skey + "','" + svalue + "'," + timebuf + "', 1)"; - freeSafeSQL(scontext, context); - freeSafeSQL(skey, key); - freeSafeSQL(svalue, value); - m_log.debug("SQL: %s", q.c_str()); + //char *scontext = makeSafeSQL(context); + //char *skey = makeSafeSQL(key); + //char *svalue = makeSafeSQL(value); + string q = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)"; - SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); if (!SQL_SUCCEEDED(sr)) { - m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key); + m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to insert record."); + } + m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str()); + + SQLINTEGER b_ind = SQL_NTS; + sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(context), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (context = %s)", context); log_error(stmt, SQL_HANDLE_STMT); throw IOException("ODBC StorageService failed to insert record."); } + m_log.debug("SQLBindParam succeded (context = %s)", context); + + sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(key), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (key = %s)", key); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to insert record."); + } + m_log.debug("SQLBindParam succeded (key = %s)", key); + + if (strcmp(table, TEXT_TABLE)==0) + sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast(value), &b_ind); + else + sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(value), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (value = %s)", value); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to insert record."); + } + m_log.debug("SQLBindParam succeded (value = %s)", value); + + //freeSafeSQL(scontext, context); + //freeSafeSQL(skey, key); + //freeSafeSQL(svalue, value); + //m_log.debug("SQL: %s", q.c_str()); + + sr=SQLExecute(stmt); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key); + if (log_error(stmt, SQL_HANDLE_STMT, "23000")) + return false; // supposedly integrity violation? + throw IOException("ODBC StorageService failed to insert record."); + } + + m_log.debug("SQLExecute of insert succeeded"); + return true; } int ODBCStorageService::readRow( @@ -438,26 +480,28 @@ int ODBCStorageService::readRow( xmltooling::NDC ndc("readRow"); #endif - SQLCHAR *tvalue = NULL; - // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and exectute select statement. + char timebuf[32]; + timestampFromTime(time(NULL), timebuf); char *scontext = makeSafeSQL(context); char *skey = makeSafeSQL(key); - string q("SELECT version"); + ostringstream q; + q << "SELECT version"; if (pexpiration) - q += ",expires"; + q << ",expires"; if (pvalue) - q += ",value"; - q = q + " FROM " + table + " WHERE context='" + scontext + "' AND key='" + skey + "' AND expires > NOW()"; + q << ",CASE version WHEN " << version << " THEN NULL ELSE value END"; + q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf; freeSafeSQL(scontext, context); freeSafeSQL(skey, key); - m_log.debug("SQL: %s", q.c_str()); + if (m_log.isDebugEnabled()) + m_log.debug("SQL: %s", q.str().c_str()); - SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS); if (!SQL_SUCCEEDED(sr)) { m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key); log_error(stmt, SQL_HANDLE_STMT); @@ -481,16 +525,16 @@ int ODBCStorageService::readRow( return version; // nothing's changed, so just echo back the version if (pvalue) { - SQLINTEGER len; - SQLCHAR buf[LONGDATA_BUFLEN]; - while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) { - if (!SQL_SUCCEEDED(sr)) { - m_log.error("error while reading text field from result set"); - log_error(stmt, SQL_HANDLE_STMT); - throw IOException("ODBC StorageService search failed to read data from result set."); - } - pvalue->append((char*)buf); - } + SQLINTEGER len; + SQLCHAR buf[LONGDATA_BUFLEN]; + while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) { + if (!SQL_SUCCEEDED(sr)) { + m_log.error("error while reading text field from result set"); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService search failed to read data from result set."); + } + pvalue->append((char*)buf); + } } return ver; @@ -507,13 +551,15 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // First, fetch the current version for later, which also ensures the record still exists. + char timebuf[32]; + timestampFromTime(time(NULL), timebuf); char *scontext = makeSafeSQL(context); char *skey = makeSafeSQL(key); string q("SELECT version FROM "); - q = q + table + " WHERE context='" + scontext + "' AND key='" + key + "' AND expires > NOW()"; + q = q + table + " WHERE context='" + scontext + "' AND id='" + key + "' AND expires > " + timebuf; m_log.debug("SQL: %s", q.c_str()); @@ -541,29 +587,49 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const return -1; } + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + stmt = getHSTMT(conn); + // Prepare and exectute update statement. q = string("UPDATE ") + table + " SET "; - if (value) { - char *svalue = makeSafeSQL(value); - q = q + "value='" + svalue + "'" + ",version=version+1"; - freeSafeSQL(svalue, value); - } + if (value) + q = q + "value=?, version=version+1"; if (expiration) { - char timebuf[32]; timestampFromTime(expiration, timebuf); if (value) q += ','; - q = q + "expires = '" + timebuf + "' "; + q = q + "expires = " + timebuf; } - q = q + " WHERE context='" + scontext + "' AND key='" + key + "'"; + q = q + " WHERE context='" + scontext + "' AND id='" + key + "'"; freeSafeSQL(scontext, context); freeSafeSQL(skey, key); - m_log.debug("SQL: %s", q.c_str()); - sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to update record."); + } + m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str()); + + SQLINTEGER b_ind = SQL_NTS; + if (value) { + if (strcmp(table, TEXT_TABLE)==0) + sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast(value), &b_ind); + else + sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(value), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (context = %s)", context); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to update record."); + } + m_log.debug("SQLBindParam succeded (context = %s)", context); + } + + sr=SQLExecute(stmt); if (sr==SQL_NO_DATA) return 0; // went missing? else if (!SQL_SUCCEEDED(sr)) { @@ -572,6 +638,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const throw IOException("ODBC StorageService failed to update record."); } + m_log.debug("SQLExecute of update succeeded"); return ver + 1; } @@ -583,12 +650,12 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and execute delete statement. char *scontext = makeSafeSQL(context); char *skey = makeSafeSQL(key); - string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND key='" + skey + "'"; + string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'"; freeSafeSQL(scontext, context); freeSafeSQL(skey, key); m_log.debug("SQL: %s", q.c_str()); @@ -659,14 +726,17 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); char timebuf[32]; timestampFromTime(expiration, timebuf); + char nowbuf[32]; + timestampFromTime(time(NULL), nowbuf); + char *scontext = makeSafeSQL(context); string q("UPDATE "); - q = q + table + " SET expires = '" + timebuf + "' WHERE context='" + scontext + "' AND expires > NOW()"; + q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf; freeSafeSQL(scontext, context); m_log.debug("SQL: %s", q.c_str()); @@ -687,17 +757,19 @@ void ODBCStorageService::reap(const char *table, const char* context) // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and execute delete statement. + char nowbuf[32]; + timestampFromTime(time(NULL), nowbuf); string q; if (context) { char *scontext = makeSafeSQL(context); - q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= NOW()"; + q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf; freeSafeSQL(scontext, context); } else { - q = string("DELETE FROM ") + table + " WHERE expires <= NOW()"; + q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf; } m_log.debug("SQL: %s", q.c_str()); @@ -717,7 +789,7 @@ void ODBCStorageService::deleteContext(const char *table, const char* context) // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and execute delete statement. char *scontext = makeSafeSQL(context); @@ -733,14 +805,14 @@ void ODBCStorageService::deleteContext(const char *table, const char* context) } } -extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*) -{ - // Register this SS type - XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory); - return 0; -} - -extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term() -{ - XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC"); -} +extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*) +{ + // Register this SS type + XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory); + return 0; +} + +extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term() +{ + XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC"); +}