X-Git-Url: http://www.project-moonshot.org/gitweb/?a=blobdiff_plain;f=odbc-store%2Fodbc-store.cpp;h=9af4d0f4f66ec6c946aae02fa1f42a62d6a0dd7c;hb=17dab4f583f7d7293ce994ad562d88e7fe08386f;hp=2693a9dd24ace4fcfeca99e7d6def7375e01ff9a;hpb=a50127c9f35d03a676ed325787674b305a7498aa;p=shibboleth%2Fsp.git diff --git a/odbc-store/odbc-store.cpp b/odbc-store/odbc-store.cpp index 2693a9d..9af4d0f 100644 --- a/odbc-store/odbc-store.cpp +++ b/odbc-store/odbc-store.cpp @@ -37,8 +37,8 @@ # define ODBCSTORE_EXPORTS #endif -#include #include +#include #include #include #include @@ -48,9 +48,9 @@ #include #include +using namespace xmltooling::logging; using namespace xmltooling; using namespace xercesc; -using namespace log4cpp; using namespace std; #define PLUGIN_VER_MAJOR 1 @@ -65,7 +65,7 @@ using namespace std; #define STRING_TABLE "strings" #define TEXT_TABLE "texts" -/* tables definitions +/* table definitions CREATE TABLE version ( major tinyint NOT NULL, minor tinyint NOT NULL @@ -99,6 +99,7 @@ namespace { ODBCConn(SQLHDBC conn) : handle(conn) {} ~ODBCConn() { SQLRETURN sr = SQLEndTran(SQL_HANDLE_DBC, handle, SQL_COMMIT); + SQLDisconnect(handle); SQLFreeHandle(SQL_HANDLE_DBC,handle); if (!SQL_SUCCEEDED(sr)) throw IOException("Failed to commit connection."); @@ -107,20 +108,13 @@ namespace { SQLHDBC handle; }; - struct ODBCStatement { - ODBCStatement(SQLHSTMT statement) : handle(statement) {} - ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);} - operator SQLHSTMT() {return handle;} - SQLHSTMT handle; - }; - class ODBCStorageService : public StorageService { public: ODBCStorageService(const DOMElement* e); virtual ~ODBCStorageService(); - void createString(const char* context, const char* key, const char* value, time_t expiration) { + bool createString(const char* context, const char* key, const char* value, time_t expiration) { return createRow(STRING_TABLE, context, key, value, expiration); } int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) { @@ -133,7 +127,7 @@ namespace { return deleteRow(STRING_TABLE, context, key); } - void createText(const char* context, const char* key, const char* value, time_t expiration) { + bool createText(const char* context, const char* key, const char* value, time_t expiration) { return createRow(TEXT_TABLE, context, key, value, expiration); } int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) { @@ -163,7 +157,7 @@ namespace { private: - void createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration); + bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration); int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text); int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version); bool deleteRow(const char *table, const char* context, const char* key); @@ -175,7 +169,7 @@ namespace { SQLHDBC getHDBC(); SQLHSTMT getHSTMT(SQLHDBC); pair getVersion(SQLHDBC); - void log_error(SQLHANDLE handle, SQLSMALLINT htype); + bool log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL); static void* cleanup_fn(void*); void cleanup(); @@ -313,10 +307,11 @@ ODBCStorageService::~ODBCStorageService() shutdown_wait->signal(); cleanup_thread->join(NULL); delete shutdown_wait; - SQLFreeHandle(SQL_HANDLE_ENV, m_henv); + if (m_henv != SQL_NULL_HANDLE) + SQLFreeHandle(SQL_HANDLE_ENV, m_henv); } -void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype) +bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor) { SQLSMALLINT i = 0; SQLINTEGER native; @@ -325,11 +320,16 @@ void ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype) SQLSMALLINT len; SQLRETURN ret; + bool res = false; do { ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len); - if (SQL_SUCCEEDED(ret)) + if (SQL_SUCCEEDED(ret)) { m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text); + if (checkfor && !strcmp(checkfor, (const char*)state)) + res = true; + } } while(SQL_SUCCEEDED(ret)); + return res; } SQLHDBC ODBCStorageService::getHDBC() @@ -379,7 +379,7 @@ SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn) pair ODBCStorageService::getVersion(SQLHDBC conn) { // Grab the version number from the database. - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS); if (!SQL_SUCCEEDED(sr)) { @@ -400,7 +400,7 @@ pair ODBCStorageService::getVersion(SQLHDBC conn) throw IOException("ODBC StorageService failed to read version from database."); } -void ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration) +bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration) { #ifdef _DEBUG xmltooling::NDC ndc("createRow"); @@ -411,24 +411,65 @@ void ODBCStorageService::createRow(const char *table, const char* context, const // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and exectute insert statement. - char *scontext = makeSafeSQL(context); - char *skey = makeSafeSQL(key); - char *svalue = makeSafeSQL(value); - string q = string("INSERT ") + table + " VALUES ('" + scontext + "','" + skey + "'," + timebuf + ",1,'" + svalue + "')"; - freeSafeSQL(scontext, context); - freeSafeSQL(skey, key); - freeSafeSQL(svalue, value); - m_log.debug("SQL: %s", q.c_str()); + //char *scontext = makeSafeSQL(context); + //char *skey = makeSafeSQL(key); + //char *svalue = makeSafeSQL(value); + string q = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)"; - SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); if (!SQL_SUCCEEDED(sr)) { - m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key); + m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to insert record."); + } + m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str()); + + SQLINTEGER b_ind = SQL_NTS; + sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(context), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (context = %s)", context); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to insert record."); + } + m_log.debug("SQLBindParam succeded (context = %s)", context); + + sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(key), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (key = %s)", key); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to insert record."); + } + m_log.debug("SQLBindParam succeded (key = %s)", key); + + if (strcmp(table, TEXT_TABLE)==0) + sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast(value), &b_ind); + else + sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(value), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (value = %s)", value); log_error(stmt, SQL_HANDLE_STMT); throw IOException("ODBC StorageService failed to insert record."); } + m_log.debug("SQLBindParam succeded (value = %s)", value); + + //freeSafeSQL(scontext, context); + //freeSafeSQL(skey, key); + //freeSafeSQL(svalue, value); + //m_log.debug("SQL: %s", q.c_str()); + + sr=SQLExecute(stmt); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key); + if (log_error(stmt, SQL_HANDLE_STMT, "23000")) + return false; // supposedly integrity violation? + throw IOException("ODBC StorageService failed to insert record."); + } + + m_log.debug("SQLExecute of insert succeeded"); + return true; } int ODBCStorageService::readRow( @@ -441,24 +482,26 @@ int ODBCStorageService::readRow( // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and exectute select statement. char timebuf[32]; timestampFromTime(time(NULL), timebuf); char *scontext = makeSafeSQL(context); char *skey = makeSafeSQL(key); - string q("SELECT version"); + ostringstream q; + q << "SELECT version"; if (pexpiration) - q += ",expires"; + q << ",expires"; if (pvalue) - q += ",value"; - q = q + " FROM " + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf; + q << ",CASE version WHEN " << version << " THEN NULL ELSE value END"; + q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf; freeSafeSQL(scontext, context); freeSafeSQL(skey, key); - m_log.debug("SQL: %s", q.c_str()); + if (m_log.isDebugEnabled()) + m_log.debug("SQL: %s", q.str().c_str()); - SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS); if (!SQL_SUCCEEDED(sr)) { m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key); log_error(stmt, SQL_HANDLE_STMT); @@ -508,7 +551,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // First, fetch the current version for later, which also ensures the record still exists. char timebuf[32]; @@ -544,14 +587,14 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const return -1; } + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + stmt = getHSTMT(conn); + // Prepare and exectute update statement. q = string("UPDATE ") + table + " SET "; - if (value) { - char *svalue = makeSafeSQL(value); - q = q + "value='" + svalue + "'" + ",version=version+1"; - freeSafeSQL(svalue, value); - } + if (value) + q = q + "value=?, version=version+1"; if (expiration) { timestampFromTime(expiration, timebuf); @@ -564,8 +607,29 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const freeSafeSQL(scontext, context); freeSafeSQL(skey, key); - m_log.debug("SQL: %s", q.c_str()); - sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to update record."); + } + m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str()); + + SQLINTEGER b_ind = SQL_NTS; + if (value) { + if (strcmp(table, TEXT_TABLE)==0) + sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast(value), &b_ind); + else + sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast(value), &b_ind); + if (!SQL_SUCCEEDED(sr)) { + m_log.error("SQLBindParam failed (context = %s)", context); + log_error(stmt, SQL_HANDLE_STMT); + throw IOException("ODBC StorageService failed to update record."); + } + m_log.debug("SQLBindParam succeded (context = %s)", context); + } + + sr=SQLExecute(stmt); if (sr==SQL_NO_DATA) return 0; // went missing? else if (!SQL_SUCCEEDED(sr)) { @@ -574,6 +638,7 @@ int ODBCStorageService::updateRow(const char *table, const char* context, const throw IOException("ODBC StorageService failed to update record."); } + m_log.debug("SQLExecute of update succeeded"); return ver + 1; } @@ -585,7 +650,7 @@ bool ODBCStorageService::deleteRow(const char *table, const char *context, const // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and execute delete statement. char *scontext = makeSafeSQL(context); @@ -661,7 +726,7 @@ void ODBCStorageService::updateContext(const char *table, const char* context, t // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); char timebuf[32]; timestampFromTime(expiration, timebuf); @@ -692,7 +757,7 @@ void ODBCStorageService::reap(const char *table, const char* context) // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and execute delete statement. char nowbuf[32]; @@ -724,7 +789,7 @@ void ODBCStorageService::deleteContext(const char *table, const char* context) // Get statement handle. ODBCConn conn(getHDBC()); - ODBCStatement stmt(getHSTMT(conn)); + SQLHSTMT stmt = getHSTMT(conn); // Prepare and execute delete statement. char *scontext = makeSafeSQL(context);