2 * Copyright 2001-2007 Internet2
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * odbc-store.cpp - Storage service using ODBC
23 // eventually we might be able to support autoconf via cygwin...
24 #if defined (_MSC_VER) || defined(__BORLANDC__)
25 # include "config_win32.h"
31 # define _CRT_NONSTDC_NO_DEPRECATE 1
32 # define _CRT_SECURE_NO_DEPRECATE 1
34 # define SHIBODBC_EXPORTS __declspec(dllexport)
36 # define SHIBODBC_EXPORTS
39 #include <shib-target/shib-target.h>
40 #include <shibsp/exceptions.h>
41 #include <log4cpp/Category.hh>
42 #include <xmltooling/util/NDC.h>
43 #include <xmltooling/util/Threads.h>
52 #ifdef HAVE_LIBDMALLOCXX
56 using namespace shibsp;
57 using namespace shibtarget;
58 using namespace opensaml::saml2md;
60 using namespace xmltooling;
61 using namespace log4cpp;
// Database schema version this plugin expects. The major number is compared
// against the value read from the database's version table in the ODBCBase
// constructor.
64 #define PLUGIN_VER_MAJOR 3
65 #define PLUGIN_VER_MINOR 0
// Column widths. COLSIZE_STRING_VALUE is also passed to readRow() as the
// maximum value size when reading from the string table.
67 #define COLSIZE_KEY 64
68 #define COLSIZE_CONTEXT 256
69 #define COLSIZE_STRING_VALUE 256
72 /* table definitions - not used here */
// NOTE(review): STRING_TABLE is defined twice below, first as the table name
// and then as DDL text; the two definitions conflict unless a guard that is
// not visible here separates them. The COLSIZE_* names inside the string
// literals are not expanded by the preprocessor.
74 #define STRING_TABLE "STRING_STORE"
76 #define STRING_TABLE \
77 "CREATE TABLE STRING_TABLE ( "\
78 "context VARCHAR( COLSIZE_CONTEXT ), " \
79 "key VARCHAR( COLSIZE_KEY ), " \
80 "value VARCHAR( COLSIZE_STRING_VALUE ), " \
82 "PRIMARY KEY (context, key), "
86 #define TEXT_TABLE "TEXT_STORE"
89 "CREATE TABLE TEXT_TABLE ( "\
90 "context VARCHAR( COLSIZE_CONTEXT ), " \
91 "key VARCHAR( COLSIZE_KEY ), " \
94 "PRIMARY KEY (context, key), "
// XML names used when reading the plugin configuration.
// "ConnectionString": child element whose text is the ODBC connection string.
100 static const XMLCh ConnectionString[] =
101 { chLatin_C, chLatin_o, chLatin_n, chLatin_n, chLatin_e, chLatin_c, chLatin_t, chLatin_i, chLatin_o, chLatin_n,
102 chLatin_S, chLatin_t, chLatin_r, chLatin_i, chLatin_n, chLatin_g, chNull
// "cleanupInterval": attribute read by the ODBCStorageService constructor.
104 static const XMLCh cleanupInterval[] =
105 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
106 chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
// "cacheTimeout", "odbcTimeout" and "storeAttributes" are not referenced by
// any code visible in this file section.
108 static const XMLCh cacheTimeout[] =
109 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
110 static const XMLCh odbcTimeout[] =
111 { chLatin_o, chLatin_d, chLatin_b, chLatin_c, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
112 static const XMLCh storeAttributes[] =
113 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
// NOTE(review): cleanupInterval is already defined above; this second
// definition is a redefinition unless a conditional not visible here
// separates the two - confirm and keep only one.
115 static const XMLCh cleanupInterval[] = UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
// Wraps an existing connection handle so it is released when the wrapper
// goes out of scope.
// NOTE(review): no copy protection is visible here; copying the wrapper
// would release the same handle twice - confirm against the full declaration.
121 ODBCConn(SQLHDBC conn) : handle(conn) {}
// Frees the wrapped connection handle.
122 ~ODBCConn() {SQLFreeHandle(SQL_HANDLE_DBC,handle);}
// Lets the wrapper be passed wherever a raw SQLHDBC is expected.
123 operator SQLHDBC() {return handle;}
// Shared base for the ODBC-backed plugins: owns the ODBC environment handle
// and connection string, and provides connection and diagnostic helpers.
127 class ODBCBase : public virtual saml::IPlugIn
// Sets up ODBC from the plugin's configuration element.
130 ODBCBase(const DOMElement* e);
138 const DOMElement* m_root; // can only use this during initialization
141 static SQLHENV m_henv; // single handle for both plugins
142 bool m_bInitializedODBC; // tracks which class handled the process
// Most recent connection string taken from configuration; static, so it is
// shared by every instance.
143 static const char* p_connstring;
// Reads the (major, minor) schema version over the given connection.
145 pair<int,int> getVersion(SQLHDBC);
// Writes the ODBC diagnostic records attached to a handle to the error log.
146 void log_error(SQLHANDLE handle, SQLSMALLINT htype);
// Storage for the static members; both start out unset. The constructor
// allocates the environment when m_henv is still SQL_NULL_HANDLE.
149 SQLHENV ODBCBase::m_henv = SQL_NULL_HANDLE;
150 const char* ODBCBase::p_connstring = NULL;
// Initializes the shared ODBC environment on first use, resolves the
// connection string from configuration, then opens a connection to verify
// the database schema version.
// Throws ConfigurationException when the environment cannot be allocated or
// the ConnectionString element is required but missing, and SAMLException
// when the schema major version differs from PLUGIN_VER_MAJOR.
152 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
155 xmltooling::NDC ndc("ODBCBase");
157 log = &(Category::getInstance("shibtarget.ODBC"));
// Only the instance that finds the static handle unset allocates it.
159 if (m_henv == SQL_NULL_HANDLE) {
160 // Enable connection pooling.
161 SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
163 // Allocate the environment.
164 if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
165 throw ConfigurationException("ODBC failed to initialize.");
// Request ODBC 3.x behavior from the driver manager.
168 SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
170 log->info("ODBC initialized");
// Remember that this instance allocated the environment so its destructor
// frees it.
171 m_bInitializedODBC = true;
174 // Grab connection string from the configuration.
175 e=XMLHelper::getFirstChildElement(e,ConnectionString);
176 if (!e || !e->hasChildNodes()) {
// NOTE(review): presumably this throw is guarded by a check that no
// connection string was stored by an earlier instance; the guard is not
// visible here.
179 throw ConfigurationException("ODBC cache requires ConnectionString element in configuration.");
// Fall back to the string stored by an earlier instance.
181 m_connstring=p_connstring;
184 xmltooling::auto_ptr_char arg(e->getFirstChild()->getNodeValue());
185 m_connstring=arg.get();
// NOTE(review): p_connstring is static but points into this instance's
// m_connstring buffer; it looks like it would dangle once this object is
// destroyed - confirm the lifetime.
186 p_connstring=m_connstring.c_str();
189 // Connect and check version.
190 SQLHDBC conn=getHDBC();
// NOTE(review): getVersion() throws on failure, in which case the
// SQLFreeHandle below is skipped and conn is not released.
191 pair<int,int> v=getVersion(conn);
192 SQLFreeHandle(SQL_HANDLE_DBC,conn);
194 // Make sure we've got the right version.
195 if (v.first != PLUGIN_VER_MAJOR) {
197 log->crit("unknown database version: %d.%d", v.first, v.second);
198 throw SAMLException("Unknown cache database version.");
// Frees the ODBC environment if this instance allocated it, then clears the
// instance flag and the shared static handle.
// NOTE(review): m_henv is reset even when this instance did not allocate the
// environment, so destroying a non-owning instance clears the handle another
// instance may still rely on - confirm whether that is intended.
202 ODBCBase::~ODBCBase()
205 if (m_bInitializedODBC)
206 SQLFreeHandle(SQL_HANDLE_ENV,m_henv);
207 m_bInitializedODBC=false;
208 m_henv = SQL_NULL_HANDLE;
212 void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
222 ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
223 if (SQL_SUCCEEDED(ret))
224 log->error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
225 } while(SQL_SUCCEEDED(ret));
228 SQLHDBC ODBCBase::getHDBC()
231 xmltooling::NDC ndc("getMYSQL");
236 SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
237 if (!SQL_SUCCEEDED(sr)) {
238 log->error("failed to allocate connection handle");
239 log_error(m_henv, SQL_HANDLE_ENV);
240 throw SAMLException("ODBCBase::getHDBC failed to allocate connection handle");
243 sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
244 if (!SQL_SUCCEEDED(sr)) {
245 log->error("failed to connect to database");
246 log_error(handle, SQL_HANDLE_DBC);
247 throw SAMLException("ODBCBase::getHDBC failed to connect to database");
253 pair<int,int> ODBCBase::getVersion(SQLHDBC conn)
255 // Grab the version number from the database.
257 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
259 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
260 if (!SQL_SUCCEEDED(sr)) {
261 log->error("failed to read version from database");
262 log_error(hstmt, SQL_HANDLE_STMT);
263 throw SAMLException("ODBCBase::getVersion failed to read version from database");
268 SQLBindCol(hstmt,1,SQL_C_SLONG,&major,0,NULL);
269 SQLBindCol(hstmt,2,SQL_C_SLONG,&minor,0,NULL);
271 if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
272 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
273 return pair<int,int>(major,minor);
276 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
277 log->error("no rows returned in version query");
278 throw SAMLException("ODBCBase::getVersion failed to read version from database");
282 // ------------------------------------------------------------
284 // ODBC Storage Service class
// StorageService implementation that keeps records in two database tables
// (one for short strings, one for text) reached through ODBC. The public
// methods forward to the private row-level helpers with the matching table.
286 class ODBCStorageService : public ODBCBase, public StorageService
// NOTE(review): the members are named stringTable/textTable here, but the
// methods below refer to string_table/text_table - the names need to agree.
288 string stringTable = STRING_TABLE;
289 string textTable = TEXT_TABLE;
292 ODBCStorageService(const DOMElement* e);
293 virtual ~ODBCStorageService();
// String-table operations.
295 void createString(const char* context, const char* key, const char* value, time_t expiration) {
296 return createRow(string_table, context, key, value, expiration);
// NOTE(review): the parameters are pvalue/pexpiration, but value/expiration
// are passed on; those names are not declared in this scope.
298 bool readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL) {
299 return readRow(string_table, context, key, value, expiration, COLSIZE_STRING_VALUE);
301 bool updateString(const char* context, const char* key, const char* value=NULL, time_t expiration=0) {
302 return updateRow(string_table, context, key, value, expiration);
// NOTE(review): deleteRow is declared below with three parameters; value and
// expiration are neither declared here nor accepted by it.
304 bool deleteString(const char* context, const char* key) {
305 return deleteRow(string_table, context, key, value, expiration);
// Text-table operations; same shape as the string-table ones above.
308 void createText(const char* context, const char* key, const char* value, time_t expiration) {
309 return createRow(text_table, context, key, value, expiration);
// A maxsize of 0 is passed for the text table.
311 bool readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL) {
312 return readRow(text_table, context, key, value, expiration, 0);
314 bool updateText(const char* context, const char* key, const char* value=NULL, time_t expiration=0) {
315 return updateRow(text_table, context, key, value, expiration);
317 bool deleteText(const char* context, const char* key) {
318 return deleteRow(text_table, context, key, value, expiration);
// NOTE(review): this calls reap() with two arguments, but the only
// two-argument helper declared below is reapRows() - presumably that is the
// intended target.
321 void reap(const char* context) {
322 reap(string_table, context);
323 reap(text_table, context);
// Removes every record of a context from both tables.
325 void deleteContext(const char* context) {
326 deleteCtx(string_table, context);
327 deleteCtx(text_table, context);
// Row-level helpers shared by the string and text variants.
333 void createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
// NOTE(review): declared with pointer out-parameters, while the definition
// later in this file takes string& and time_t& - the two must match.
334 bool readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int maxsize);
335 bool updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
336 bool deleteRow(const char *table, const char* context, const char* key);
338 void reapRows(const char* table, const char* context);
339 void deleteCtx(const char* table, const char* context);
// Cleanup-thread plumbing: the constructor creates both objects and the
// destructor signals and joins them.
341 xmltooling::CondWait* shutdown_wait;
343 xmltooling::Thread* cleanup_thread;
// Thread entry point handed to Thread::create by the constructor.
345 static void* cleanup_fcn(void*);
// NOTE(review): shutdown_wait and cleanup_thread are declared a second time
// here; duplicate member declarations will not compile.
348 CondWait* shutdown_wait;
349 Thread* cleanup_thread;
// Seconds between cleanup passes; the constructor defaults it to 300 when
// it is zero.
351 int m_cleanupInterval;
354 StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
356 return new ODBCStorageService(e);
359 // convert SQL timestamp to time_t
360 time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
364 t.tm_sec=expires.second;
365 t.tm_min=expires.minute;
366 t.tm_hour=expires.hour;
367 t.tm_mday=expires.day;
368 t.tm_mon=expires.month-1;
369 t.tm_year=expires.year-1900;
371 #if defined(HAVE_TIMEGM)
374 ret = mktime(&t) - timezone;
// convert time_t to SQL string
//
// t    time to format, rendered in UTC
// ret  output buffer of at least 32 bytes; receives an ODBC timestamp escape
//      of the form {ts 'YYYY-MM-DD HH:MM:SS'}
//
// If the time cannot be broken down, ret is set to the empty string.
void timestampFromTime(time_t t, char* ret)
{
#ifdef HAVE_GMTIME_R
    struct tm res;
    struct tm* ptime=gmtime_r(&t,&res);
#else
    struct tm* ptime=gmtime(&t);
#endif
    if (!ptime) {
        ret[0] = '\0';
        return;
    }
    strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
}
// make a string safe for SQL command
// result to be free'd only if it isn't the input
//
// src  NUL-terminated text destined for a quoted SQL literal; may be NULL
//
// Returns src itself when it contains no single quote or backslash.
// Otherwise returns a malloc'd copy in which each of those characters is
// preceded by a backslash; release it with freeSafeSQL(). Returns NULL when
// src is NULL or the copy cannot be allocated.
//
// NOTE(review): backslash escaping is not honored by every SQL dialect;
// bound parameters (SQLPrepare/SQLBindParameter) would avoid the question.
char *makeSafeSQL(const char *src)
{
    if (!src) return NULL;

    // see if any conversion needed
    size_t nc = 0;   // characters in src
    size_t ns = 0;   // characters that need an escape
    for (const char* p = src; *p; ++p, ++nc) {
        if (*p=='\'' || *p=='\\') ++ns;
    }
    if (ns==0) return const_cast<char*>(src);

    // one extra byte per escaped character, plus the terminator
    char* safe = static_cast<char*>(malloc(nc + ns + 1));
    if (!safe) return NULL;

    char* out = safe;
    for (; *src; ++src) {
        if (*src=='\'' || *src=='\\') *out++ = '\\';
        *out++ = *src;
    }
    *out = '\0';
    return safe;
}
// Releases a string obtained from makeSafeSQL().
//
// safe  the pointer makeSafeSQL() returned
// src   the original input that was passed to makeSafeSQL()
//
// Nothing is freed when both pointers are the same, because in that case no
// copy was made.
void freeSafeSQL(char *safe, const char *src)
{
    const bool borrowed = (safe == src);
    if (borrowed)
        return;
    free(safe);
}
// Reads the cleanup interval from configuration, creates the
// synchronization objects and starts the background cleanup thread.
421 ODBCStorageService::ODBCStorageService(const DOMElement* e):
428 xmltooling::NDC ndc("ODBCStorageService");
430 log = &(Category::getInstance("shibtarget.StorageService.ODBC"));
// Optional cleanupInterval attribute on the configuration element.
432 const XMLCh* tag=e ? e->getAttributeNS(NULL,cleanupInterval) : NULL;
434 m_cleanupInterval = XMLString::parseInt(tag);
// Fall back to 300 (seconds, per the cleanup thread's log message) when the
// interval is zero.
436 if (!m_cleanupInterval) m_cleanupInterval=300;
438 contextLock = Mutex::create();
439 shutdown_wait = CondWait::create();
441 // Initialize the cleanup thread
// The thread receives this object as its argument.
442 cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
// Wakes the cleanup thread, waits for it to finish, then releases the
// condition variable.
// NOTE(review): contextLock is created in the constructor, but no matching
// delete is visible here - confirm it is released.
445 ODBCStorageService::~ODBCStorageService()
448 shutdown_wait->signal();
449 cleanup_thread->join(NULL);
451 delete shutdown_wait;
457 void ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
460 xmltooling::NDC ndc("createRow");
464 timestampFromTime(expiration, timebuf);
466 // Get statement handle.
468 ODBCConn conn(getHDBC());
469 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
471 // Prepare and exectute insert statement.
472 char *scontext = makeSafeSQL(context);
473 char *svalue = makeSafeSQL(value);
474 string q = string("INSERT ") + table + " VALUES ('" + scontext + "','" + key + "','" + svalue + "'," + timebuf + "')";
475 freeSafeSQL(scontext, context)
476 freeSafeSQL(svalue, value)
477 log->debug("SQL: %s", q.str());
479 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
480 if (!SQL_SUCCEEDED(sr)) {
481 log->error("insert record failed (t=%s, c=%s, k=%s", table, context, key);
482 log_error(hstmt, SQL_HANDLE_STMT);
485 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
// Looks up one record by context and key, checks that it has not expired and
// retrieves its expiry time and value.
// NOTE(review): this definition takes string& / time_t&, while the class
// declares readRow with string* / time_t* - the signatures must match.
490 bool ODBCStorageService::readRow(const char *table, const char* context, const char* key, string& pvalue, time_t& pexpiration, int maxsize)
493 xmltooling::NDC ndc("readRow");
496 SQLCHAR *tvalue = NULL;
497 SQL_TIMESTAMP_STRUCT expires;
500 // Get statement handle.
502 ODBCConn conn(getHDBC());
503 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
505 // Prepare and execute select statement.
// Only the context is escaped; the key is concatenated as-is.
506 char *scontext = makeSafeSQL(context);
507 string q = string("SELECT expires,value FROM ") + table +
508 " WHERE context='" + scontext + "' AND key='" + key + "'";
// NOTE(review): the next statement has no terminating semicolon, and
// std::string has no str() member (c_str() is presumably intended).
509 freeSafeSQL(scontext, context)
510 log->debug("SQL: %s", q.str());
512 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
513 if (!SQL_SUCCEEDED(sr)) {
514 log->error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
515 log_error(hstmt, SQL_HANDLE_STMT);
516 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
// Only the expires column is bound; the value column is read afterwards
// with SQLGetData.
521 SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&expires,0,NULL);
// SQL_NO_DATA means no row matched the context and key.
523 if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
524 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
528 // expire time from bound col
529 exp = timeFromTimestamp(expires);
// NOTE(review): "ezp" looks like a typo for "exp".
530 if (time(NULL)>ezp) {
531 log->debug(".. expired");
532 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
// NOTE(review): pexpiration is tested like a pointer but assigned like a
// reference - see the signature note at the top of this function.
535 if (pexpiration) pexpiration = exp;
539 // see how much text there is
543 sr = SQLGetData(hstmt, 2, SQL_C_CHAR, tvp, BUFSIZE_TEXT_BLOCK, &nch);
544 if (sr==SQL_SUCCESS || sr==SQL_SUCCESS_WITH_INFO) {
// Buffer for the value, sized from maxsize plus a terminator.
// NOTE(review): no free() of tvalue is visible in this function - confirm
// the buffer is released on every path.
549 tvalue = (SQLCHAR*) malloc(maxsize+1);
550 sr = SQLGetData(hstmt, 2, SQL_C_CHAR, tvalue, maxsize, &nch);
551 if (sr!=SQL_SUCCESS) {
552 log->error("error retriving value for (t=%s, c=%s, k=%s)", table, context, key);
553 log_error(hstmt, SQL_HANDLE_STMT);
554 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
// NOTE(review): tvalue is declared as SQLCHAR*; whether it converts to the
// char pointer std::string expects depends on the SQLCHAR typedef - confirm.
558 pvalue = string(tvalue);
561 log->debug(".. value found");
563 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
570 bool ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
573 xmltooling::NDC ndc("updateRow");
579 timestampFromTime(expiration, timebuf);
581 // Get statement handle.
583 ODBCConn conn(getHDBC());
584 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
586 // Prepare and exectute update statement.
589 if (expiration) expstr = string(",expires = '") + timebuf + "' ";
591 char *scontext = makeSafeSQL(context);
592 char *svalue = makeSafeSQL(value);
593 string q = string("UPDATE ") + table + " SET value='" + svalue + "'" + expstr +
594 " WHERE context='" + scontext + "' AND key='" + key + "' AND expires > NOW()";
595 freeSafeSQL(scontext, context)
596 freeSafeSQL(svalue, value)
597 log->debug("SQL: %s", q.str());
599 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
600 if (!SQL_SUCCEEDED(sr)) {
601 log->error("update record failed (t=%s, c=%s, k=%s", table, context, key);
602 log_error(hstmt, SQL_HANDLE_STMT);
606 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
613 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
616 xmltooling::NDC ndc("deleteRow");
621 // Get statement handle.
623 ODBCConn conn(getHDBC());
624 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
626 // Prepare and execute delete statement.
627 char *scontext = makeSafeSQL(context);
628 string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND key='" + key + "'";
629 freeSafeSQL(scontext, context)
630 log->debug("SQL: %s", q.str());
632 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
634 if (sr==SQL_NO_DATA) {
636 } else if (!SQL_SUCCEEDED(sr)) {
637 log->error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
638 log_error(hstmt, SQL_HANDLE_STMT);
642 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
647 // cleanup - delete expired entries
// Body of the background thread: logs its start, repeatedly waits on
// shutdown_wait for up to m_cleanupInterval, and logs before exiting the
// thread.
649 void ODBCStorageService::cleanup()
652 xmltooling::NDC ndc("cleanup");
// Mutex created locally for use with the timed wait below.
// NOTE(review): no matching delete is visible here - confirm it is released.
655 Mutex* mutex = xmltooling::Mutex::create();
658 int timeout_life = 0;
662 log->info("cleanup thread started... running every %d secs", m_cleanupInterval);
// Returns when the interval elapses or the destructor signals shutdown_wait.
665 shutdown_wait->timedwait(mutex, m_cleanupInterval);
672 log->info("cleanup thread exiting...");
676 xmltooling::Thread::exit(NULL);
// Static entry point of the cleanup thread; cache_p is the
// ODBCStorageService instance that the constructor passed to Thread::create.
679 void* ODBCStorageService::cleanup_fcn(void* cache_p)
681 ODBCStorageService* cache = (ODBCStorageService*)cache_p;
684 // First, let's block all signals
685 Thread::mask_all_signals();
688 // Now run the cleanup process.
694 // remove expired entries for a context
696 void ODBCStorageService::reapRows(const char *table, const char* context)
699 xmltooling::NDC ndc("reapRows");
702 // Get statement handle.
704 ODBCConn conn(getHDBC());
705 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
707 // Prepare and execute delete statement.
710 char *scontext = makeSafeSQL(context);
711 q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires<NOW()";
712 freeSafeSQL(scontext, context)
714 q = string("DELETE FROM ") + table + " WHERE expires<NOW()";
716 log->debug("SQL: %s", q.str());
718 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
720 if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
721 log->error("error expiring records (t=%s, c=%s)", table, context?context:"null");
722 log_error(hstmt, SQL_HANDLE_STMT);
725 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
730 // remove all entries for a context
732 void ODBCStorageService::deleteCtx(const char *table, const char* context)
735 xmltooling::NDC ndc("deleteCtx");
738 // Get statement handle.
740 ODBCConn conn(getHDBC());
741 SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
743 // Prepare and execute delete statement.
744 char *scontext = makeSafeSQL(context);
745 string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
746 freeSafeSQL(scontext, context)
747 log->debug("SQL: %s", q.str());
749 SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
751 if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
752 log->error("error deleting context (t=%s, c=%s)", table, context);
753 log_error(hstmt, SQL_HANDLE_STMT);
756 SQLFreeHandle(SQL_HANDLE_STMT,hstmt);