546eefc66218ed3a6eb5fa3f52f245f423627468
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
1 /*
2  *  Copyright 2001-2009 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * odbc-store.cpp
19  *
20  * Storage Service using ODBC.
21  */
22
23 #if defined (_MSC_VER) || defined(__BORLANDC__)
24 # include "config_win32.h"
25 #else
26 # include "config.h"
27 #endif
28
29 #ifdef WIN32
30 # define _CRT_NONSTDC_NO_DEPRECATE 1
31 # define _CRT_SECURE_NO_DEPRECATE 1
32 #endif
33
34 #ifdef WIN32
35 # define ODBCSTORE_EXPORTS __declspec(dllexport)
36 #else
37 # define ODBCSTORE_EXPORTS
38 #endif
39
40 #include <xercesc/util/XMLUniDefs.hpp>
41 #include <xmltooling/logging.h>
42 #include <xmltooling/unicode.h>
43 #include <xmltooling/XMLToolingConfig.h>
44 #include <xmltooling/util/NDC.h>
45 #include <xmltooling/util/StorageService.h>
46 #include <xmltooling/util/Threads.h>
47 #include <xmltooling/util/XMLHelper.h>
48
49 #include <sql.h>
50 #include <sqlext.h>
51
52 using namespace xmltooling::logging;
53 using namespace xmltooling;
54 using namespace xercesc;
55 using namespace std;
56
57 #define PLUGIN_VER_MAJOR 1
58 #define PLUGIN_VER_MINOR 0
59
60 #define LONGDATA_BUFLEN 16384
61
62 #define COLSIZE_CONTEXT 255
63 #define COLSIZE_ID 255
64 #define COLSIZE_STRING_VALUE 255
65
66 #define STRING_TABLE "strings"
67 #define TEXT_TABLE "texts"
68
69 /* table definitions
70 CREATE TABLE version (
71     major int NOT NULL,
72     minor int NOT NULL
73     )
74
75 CREATE TABLE strings (
76     context varchar(255) not null,
77     id varchar(255) not null,
78     expires datetime not null,
79     version smallint not null,
80     value varchar(255) not null,
81     PRIMARY KEY (context, id)
82     )
83
84 CREATE TABLE texts (
85     context varchar(255) not null,
86     id varchar(255) not null,
87     expires datetime not null,
88     version smallint not null,
89     value text not null,
90     PRIMARY KEY (context, id)
91     )
92 */
93
94 namespace {
95     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
96     static const XMLCh isolationLevel[] =   UNICODE_LITERAL_14(i,s,o,l,a,t,i,o,n,L,e,v,e,l);
97     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
98     static const XMLCh RetryOnError[] =     UNICODE_LITERAL_12(R,e,t,r,y,O,n,E,r,r,o,r);
99
100     // RAII for ODBC handles
101     struct ODBCConn {
102         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
103         ~ODBCConn() {
104             SQLRETURN sr = SQL_SUCCESS;
105             if (!autoCommit)
106                 sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, NULL);
107             SQLDisconnect(handle);
108             SQLFreeHandle(SQL_HANDLE_DBC,handle);
109             if (!SQL_SUCCEEDED(sr))
110                 throw IOException("Failed to commit connection and return to auto-commit mode.");
111         }
112         operator SQLHDBC() {return handle;}
113         SQLHDBC handle;
114         bool autoCommit;
115     };
116
117     class ODBCStorageService : public StorageService
118     {
119     public:
120         ODBCStorageService(const DOMElement* e);
121         virtual ~ODBCStorageService();
122
123         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
124             return createRow(STRING_TABLE, context, key, value, expiration);
125         }
126         int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
127             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
128         }
129         int updateString(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
130             return updateRow(STRING_TABLE, context, key, value, expiration, version);
131         }
132         bool deleteString(const char* context, const char* key) {
133             return deleteRow(STRING_TABLE, context, key);
134         }
135
136         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
137             return createRow(TEXT_TABLE, context, key, value, expiration);
138         }
139         int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
140             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
141         }
142         int updateText(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
143             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
144         }
145         bool deleteText(const char* context, const char* key) {
146             return deleteRow(TEXT_TABLE, context, key);
147         }
148
149         void reap(const char* context) {
150             reap(STRING_TABLE, context);
151             reap(TEXT_TABLE, context);
152         }
153
154         void updateContext(const char* context, time_t expiration) {
155             updateContext(STRING_TABLE, context, expiration);
156             updateContext(TEXT_TABLE, context, expiration);
157         }
158
159         void deleteContext(const char* context) {
160             deleteContext(STRING_TABLE, context);
161             deleteContext(TEXT_TABLE, context);
162         }
163          
164
165     private:
166         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
167         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
168         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
169         bool deleteRow(const char *table, const char* context, const char* key);
170
171         void reap(const char* table, const char* context);
172         void updateContext(const char* table, const char* context, time_t expiration);
173         void deleteContext(const char* table, const char* context);
174
175         SQLHDBC getHDBC();
176         SQLHSTMT getHSTMT(SQLHDBC);
177         pair<int,int> getVersion(SQLHDBC);
178         pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL);
179
180         static void* cleanup_fn(void*); 
181         void cleanup();
182
183         Category& m_log;
184         int m_cleanupInterval;
185         CondWait* shutdown_wait;
186         Thread* cleanup_thread;
187         bool shutdown;
188
189         SQLHENV m_henv;
190         string m_connstring;
191         long m_isolation;
192         vector<SQLINTEGER> m_retries;
193     };
194
195     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
196     {
197         return new ODBCStorageService(e);
198     }
199
200     // convert SQL timestamp to time_t 
201     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
202     {
203         time_t ret;
204         struct tm t;
205         t.tm_sec=expires.second;
206         t.tm_min=expires.minute;
207         t.tm_hour=expires.hour;
208         t.tm_mday=expires.day;
209         t.tm_mon=expires.month-1;
210         t.tm_year=expires.year-1900;
211         t.tm_isdst=0;
212 #if defined(HAVE_TIMEGM)
213         ret = timegm(&t);
214 #else
215         ret = mktime(&t) - timezone;
216 #endif
217         return (ret);
218     }
219
220     // conver time_t to SQL string
221     void timestampFromTime(time_t t, char* ret)
222     {
223 #ifdef HAVE_GMTIME_R
224         struct tm res;
225         struct tm* ptime=gmtime_r(&t,&res);
226 #else
227         struct tm* ptime=gmtime(&t);
228 #endif
229         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
230     }
231
232     // make a string safe for SQL command
233     // result to be free'd only if it isn't the input
234     static char *makeSafeSQL(const char *src)
235     {
236        int ns = 0;
237        int nc = 0;
238        char *s;
239     
240        // see if any conversion needed
241        for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
242        if (ns==0) return ((char*)src);
243     
244        char *safe = new char[(nc+2*ns+1)];
245        for (s=safe; *src; src++) {
246            if (*src=='\'') *s++ = '\'';
247            *s++ = (char)*src;
248        }
249        *s = '\0';
250        return (safe);
251     }
252
253     void freeSafeSQL(char *safe, const char *src)
254     {
255         if (safe!=src)
256             delete[](safe);
257     }
258 };
259
260 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
261    m_cleanupInterval(900), shutdown_wait(NULL), cleanup_thread(NULL), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
262 {
263 #ifdef _DEBUG
264     xmltooling::NDC ndc("ODBCStorageService");
265 #endif
266
267     const XMLCh* tag=e ? e->getAttributeNS(NULL,cleanupInterval) : NULL;
268     if (tag && *tag)
269         m_cleanupInterval = XMLString::parseInt(tag);
270     if (!m_cleanupInterval)
271         m_cleanupInterval = 900;
272
273     auto_ptr_char iso(e ? e->getAttributeNS(NULL,isolationLevel) : NULL);
274     if (iso.get() && *iso.get()) {
275         if (!strcmp(iso.get(),"SERIALIZABLE"))
276             m_isolation = SQL_TXN_SERIALIZABLE;
277         else if (!strcmp(iso.get(),"REPEATABLE_READ"))
278             m_isolation = SQL_TXN_REPEATABLE_READ;
279         else if (!strcmp(iso.get(),"READ_COMMITTED"))
280             m_isolation = SQL_TXN_READ_COMMITTED;
281         else if (!strcmp(iso.get(),"READ_UNCOMMITTED"))
282             m_isolation = SQL_TXN_READ_UNCOMMITTED;
283         else
284             throw XMLToolingException("Unknown transaction isolationLevel property.");
285     }
286
287     if (m_henv == SQL_NULL_HANDLE) {
288         // Enable connection pooling.
289         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
290
291         // Allocate the environment.
292         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
293             throw XMLToolingException("ODBC failed to initialize.");
294
295         // Specify ODBC 3.x
296         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
297
298         m_log.info("ODBC initialized");
299     }
300
301     // Grab connection string from the configuration.
302     e = e ? XMLHelper::getFirstChildElement(e,ConnectionString) : NULL;
303     if (!e || !e->hasChildNodes()) {
304         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
305         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
306     }
307     auto_ptr_char arg(e->getFirstChild()->getNodeValue());
308     m_connstring=arg.get();
309
310     // Connect and check version.
311     ODBCConn conn(getHDBC());
312     pair<int,int> v=getVersion(conn);
313
314     // Make sure we've got the right version.
315     if (v.first != PLUGIN_VER_MAJOR) {
316         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
317         m_log.crit("unknown database version: %d.%d", v.first, v.second);
318         throw XMLToolingException("Unknown database version for ODBC StorageService.");
319     }
320
321     // Load any retry errors to check.
322     e = XMLHelper::getNextSiblingElement(e,RetryOnError);
323     while (e) {
324         if (e->hasChildNodes()) {
325             m_retries.push_back(XMLString::parseInt(e->getFirstChild()->getNodeValue()));
326             m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
327         }
328         e = XMLHelper::getNextSiblingElement(e,RetryOnError);
329     }
330
331     // Initialize the cleanup thread
332     shutdown_wait = CondWait::create();
333     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
334 }
335
336 ODBCStorageService::~ODBCStorageService()
337 {
338     shutdown = true;
339     shutdown_wait->signal();
340     cleanup_thread->join(NULL);
341     delete shutdown_wait;
342     if (m_henv != SQL_NULL_HANDLE)
343         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
344 }
345
346 pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
347 {
348     SQLSMALLINT  i = 0;
349     SQLINTEGER   native;
350     SQLCHAR      state[7];
351     SQLCHAR      text[256];
352     SQLSMALLINT  len;
353     SQLRETURN    ret;
354
355     pair<bool,bool> res = make_pair(false,false);
356     do {
357         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
358         if (SQL_SUCCEEDED(ret)) {
359             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
360             for (vector<SQLINTEGER>::const_iterator n = m_retries.begin(); !res.first && n != m_retries.end(); ++n)
361                 res.first = (*n == native);
362             if (checkfor && !strcmp(checkfor, (const char*)state))
363                 res.second = true;
364         }
365     } while(SQL_SUCCEEDED(ret));
366     return res;
367 }
368
369 SQLHDBC ODBCStorageService::getHDBC()
370 {
371 #ifdef _DEBUG
372     xmltooling::NDC ndc("getHDBC");
373 #endif
374
375     // Get a handle.
376     SQLHDBC handle;
377     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
378     if (!SQL_SUCCEEDED(sr)) {
379         m_log.error("failed to allocate connection handle");
380         log_error(m_henv, SQL_HANDLE_ENV);
381         throw IOException("ODBC StorageService failed to allocate a connection handle.");
382     }
383
384     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
385     if (!SQL_SUCCEEDED(sr)) {
386         m_log.error("failed to connect to database");
387         log_error(handle, SQL_HANDLE_DBC);
388         throw IOException("ODBC StorageService failed to connect to database.");
389     }
390
391     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, NULL);
392     if (!SQL_SUCCEEDED(sr))
393         throw IOException("ODBC StorageService failed to set transaction isolation level.");
394
395     return handle;
396 }
397
398 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
399 {
400     SQLHSTMT hstmt;
401     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
402     if (!SQL_SUCCEEDED(sr)) {
403         m_log.error("failed to allocate statement handle");
404         log_error(conn, SQL_HANDLE_DBC);
405         throw IOException("ODBC StorageService failed to allocate a statement handle.");
406     }
407     return hstmt;
408 }
409
410 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
411 {
412     // Grab the version number from the database.
413     SQLHSTMT stmt = getHSTMT(conn);
414     
415     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
416     if (!SQL_SUCCEEDED(sr)) {
417         m_log.error("failed to read version from database");
418         log_error(stmt, SQL_HANDLE_STMT);
419         throw IOException("ODBC StorageService failed to read version from database.");
420     }
421
422     SQLINTEGER major;
423     SQLINTEGER minor;
424     SQLBindCol(stmt,1,SQL_C_SLONG,&major,0,NULL);
425     SQLBindCol(stmt,2,SQL_C_SLONG,&minor,0,NULL);
426
427     if ((sr=SQLFetch(stmt)) != SQL_NO_DATA)
428         return pair<int,int>(major,minor);
429
430     m_log.error("no rows returned in version query");
431     throw IOException("ODBC StorageService failed to read version from database.");
432 }
433
434 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
435 {
436 #ifdef _DEBUG
437     xmltooling::NDC ndc("createRow");
438 #endif
439
440     char timebuf[32];
441     timestampFromTime(expiration, timebuf);
442
443     // Get statement handle.
444     ODBCConn conn(getHDBC());
445     SQLHSTMT stmt = getHSTMT(conn);
446
447     // Prepare and exectute insert statement.
448     //char *scontext = makeSafeSQL(context);
449     //char *skey = makeSafeSQL(key);
450     //char *svalue = makeSafeSQL(value);
451     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
452
453     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
454     if (!SQL_SUCCEEDED(sr)) {
455         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
456         log_error(stmt, SQL_HANDLE_STMT);
457         throw IOException("ODBC StorageService failed to insert record.");
458     }
459     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
460
461     SQLLEN b_ind = SQL_NTS;
462     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
463     if (!SQL_SUCCEEDED(sr)) {
464         m_log.error("SQLBindParam failed (context = %s)", context);
465         log_error(stmt, SQL_HANDLE_STMT);
466         throw IOException("ODBC StorageService failed to insert record.");
467     }
468     m_log.debug("SQLBindParam succeeded (context = %s)", context);
469
470     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
471     if (!SQL_SUCCEEDED(sr)) {
472         m_log.error("SQLBindParam failed (key = %s)", key);
473         log_error(stmt, SQL_HANDLE_STMT);
474         throw IOException("ODBC StorageService failed to insert record.");
475     }
476     m_log.debug("SQLBindParam succeeded (key = %s)", key);
477
478     if (strcmp(table, TEXT_TABLE)==0)
479         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
480     else
481         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
482     if (!SQL_SUCCEEDED(sr)) {
483         m_log.error("SQLBindParam failed (value = %s)", value);
484         log_error(stmt, SQL_HANDLE_STMT);
485         throw IOException("ODBC StorageService failed to insert record.");
486     }
487     m_log.debug("SQLBindParam succeeded (value = %s)", value);
488     
489     //freeSafeSQL(scontext, context);
490     //freeSafeSQL(skey, key);
491     //freeSafeSQL(svalue, value);
492     //m_log.debug("SQL: %s", q.c_str());
493
494     int attempts = 3;
495     pair<bool,bool> logres;
496     do {
497         logres = make_pair(false,false);
498         attempts--;
499         sr=SQLExecute(stmt);
500         if (SQL_SUCCEEDED(sr)) {
501             m_log.debug("SQLExecute of insert succeeded");
502             return true;
503         }
504         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
505         logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
506         if (logres.second)
507             return false;   // supposedly integrity violation?
508     } while (attempts && logres.first);
509
510     throw IOException("ODBC StorageService failed to insert record.");
511 }
512
513 int ODBCStorageService::readRow(
514     const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
515     )
516 {
517 #ifdef _DEBUG
518     xmltooling::NDC ndc("readRow");
519 #endif
520
521     // Get statement handle.
522     ODBCConn conn(getHDBC());
523     SQLHSTMT stmt = getHSTMT(conn);
524
525     // Prepare and exectute select statement.
526     char timebuf[32];
527     timestampFromTime(time(NULL), timebuf);
528     char *scontext = makeSafeSQL(context);
529     char *skey = makeSafeSQL(key);
530     ostringstream q;
531     q << "SELECT version";
532     if (pexpiration)
533         q << ",expires";
534     if (pvalue)
535         q << ",CASE version WHEN " << version << " THEN NULL ELSE value END";
536     q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
537     freeSafeSQL(scontext, context);
538     freeSafeSQL(skey, key);
539     if (m_log.isDebugEnabled())
540         m_log.debug("SQL: %s", q.str().c_str());
541
542     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
543     if (!SQL_SUCCEEDED(sr)) {
544         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
545         log_error(stmt, SQL_HANDLE_STMT);
546         throw IOException("ODBC StorageService search failed.");
547     }
548
549     SQLSMALLINT ver;
550     SQL_TIMESTAMP_STRUCT expiration;
551
552     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
553     if (pexpiration)
554         SQLBindCol(stmt,2,SQL_C_TYPE_TIMESTAMP,&expiration,0,NULL);
555
556     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA)
557         return 0;
558
559     if (pexpiration)
560         *pexpiration = timeFromTimestamp(expiration);
561
562     if (version == ver)
563         return version; // nothing's changed, so just echo back the version
564
565     if (pvalue) {
566         SQLLEN len;
567         SQLCHAR buf[LONGDATA_BUFLEN];
568         while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
569             if (!SQL_SUCCEEDED(sr)) {
570                 m_log.error("error while reading text field from result set");
571                 log_error(stmt, SQL_HANDLE_STMT);
572                 throw IOException("ODBC StorageService search failed to read data from result set.");
573             }
574             pvalue->append((char*)buf);
575         }
576     }
577     
578     return ver;
579 }
580
581 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
582 {
583 #ifdef _DEBUG
584     xmltooling::NDC ndc("updateRow");
585 #endif
586
587     if (!value && !expiration)
588         throw IOException("ODBC StorageService given invalid update instructions.");
589
590     // Get statement handle. Disable auto-commit mode to wrap select + update.
591     ODBCConn conn(getHDBC());
592     SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
593     if (!SQL_SUCCEEDED(sr))
594         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
595     conn.autoCommit = false;
596     SQLHSTMT stmt = getHSTMT(conn);
597
598     // First, fetch the current version for later, which also ensures the record still exists.
599     char timebuf[32];
600     timestampFromTime(time(NULL), timebuf);
601     char *scontext = makeSafeSQL(context);
602     char *skey = makeSafeSQL(key);
603     string q("SELECT version FROM ");
604     q = q + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf;
605
606     m_log.debug("SQL: %s", q.c_str());
607
608     sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
609     if (!SQL_SUCCEEDED(sr)) {
610         freeSafeSQL(scontext, context);
611         freeSafeSQL(skey, key);
612         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
613         log_error(stmt, SQL_HANDLE_STMT);
614         throw IOException("ODBC StorageService search failed.");
615     }
616
617     SQLSMALLINT ver;
618     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
619     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA) {
620         freeSafeSQL(scontext, context);
621         freeSafeSQL(skey, key);
622         return 0;
623     }
624
625     // Check version?
626     if (version > 0 && version != ver) {
627         freeSafeSQL(scontext, context);
628         freeSafeSQL(skey, key);
629         return -1;
630     }
631
632     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
633     stmt = getHSTMT(conn);
634
635     // Prepare and exectute update statement.
636     q = string("UPDATE ") + table + " SET ";
637
638     if (value)
639         q = q + "value=?, version=version+1";
640
641     if (expiration) {
642         timestampFromTime(expiration, timebuf);
643         if (value)
644             q += ',';
645         q = q + "expires = " + timebuf;
646     }
647
648     q = q + " WHERE context='" + scontext + "' AND id='" + skey + "'";
649     freeSafeSQL(scontext, context);
650     freeSafeSQL(skey, key);
651
652     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
653     if (!SQL_SUCCEEDED(sr)) {
654         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
655         log_error(stmt, SQL_HANDLE_STMT);
656         throw IOException("ODBC StorageService failed to update record.");
657     }
658     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
659
660     SQLLEN b_ind = SQL_NTS;
661     if (value) {
662         if (strcmp(table, TEXT_TABLE)==0)
663             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
664         else
665             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
666         if (!SQL_SUCCEEDED(sr)) {
667             m_log.error("SQLBindParam failed (context = %s)", context);
668             log_error(stmt, SQL_HANDLE_STMT);
669             throw IOException("ODBC StorageService failed to update record.");
670         }
671         m_log.debug("SQLBindParam succeeded (context = %s)", context);
672     }
673
674     int attempts = 3;
675     pair<bool,bool> logres;
676     do {
677         logres = make_pair(false,false);
678         attempts--;
679         sr=SQLExecute(stmt);
680         if (sr==SQL_NO_DATA)
681             return 0;   // went missing?
682         else if (SQL_SUCCEEDED(sr)) {
683             m_log.debug("SQLExecute of update succeeded");
684             return ver + 1;
685         }
686
687         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
688         logres = log_error(stmt, SQL_HANDLE_STMT);
689     } while (attempts && logres.first);
690
691     throw IOException("ODBC StorageService failed to update record.");
692 }
693
694 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
695 {
696 #ifdef _DEBUG
697     xmltooling::NDC ndc("deleteRow");
698 #endif
699
700     // Get statement handle.
701     ODBCConn conn(getHDBC());
702     SQLHSTMT stmt = getHSTMT(conn);
703
704     // Prepare and execute delete statement.
705     char *scontext = makeSafeSQL(context);
706     char *skey = makeSafeSQL(key);
707     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
708     freeSafeSQL(scontext, context);
709     freeSafeSQL(skey, key);
710     m_log.debug("SQL: %s", q.c_str());
711
712     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
713      if (sr==SQL_NO_DATA)
714         return false;
715     else if (!SQL_SUCCEEDED(sr)) {
716         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
717         log_error(stmt, SQL_HANDLE_STMT);
718         throw IOException("ODBC StorageService failed to delete record.");
719     }
720
721     return true;
722 }
723
724
725 void ODBCStorageService::cleanup()
726 {
727 #ifdef _DEBUG
728     xmltooling::NDC ndc("cleanup");
729 #endif
730
731     Mutex* mutex = Mutex::create();
732
733     mutex->lock();
734
735     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
736
737     while (!shutdown) {
738         shutdown_wait->timedwait(mutex, m_cleanupInterval);
739         if (shutdown)
740             break;
741         try {
742             reap(NULL);
743         }
744         catch (exception& ex) {
745             m_log.error("cleanup thread swallowed exception: %s", ex.what());
746         }
747     }
748
749     m_log.info("cleanup thread exiting...");
750
751     mutex->unlock();
752     delete mutex;
753     Thread::exit(NULL);
754 }
755
756 void* ODBCStorageService::cleanup_fn(void* cache_p)
757 {
758   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
759
760 #ifndef WIN32
761   // First, let's block all signals
762   Thread::mask_all_signals();
763 #endif
764
765   // Now run the cleanup process.
766   cache->cleanup();
767   return NULL;
768 }
769
770 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
771 {
772 #ifdef _DEBUG
773     xmltooling::NDC ndc("updateContext");
774 #endif
775
776     // Get statement handle.
777     ODBCConn conn(getHDBC());
778     SQLHSTMT stmt = getHSTMT(conn);
779
780     char timebuf[32];
781     timestampFromTime(expiration, timebuf);
782
783     char nowbuf[32];
784     timestampFromTime(time(NULL), nowbuf);
785
786     char *scontext = makeSafeSQL(context);
787     string q("UPDATE ");
788     q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
789     freeSafeSQL(scontext, context);
790
791     m_log.debug("SQL: %s", q.c_str());
792
793     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
794     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
795         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
796         log_error(stmt, SQL_HANDLE_STMT);
797         throw IOException("ODBC StorageService failed to update context expiration.");
798     }
799 }
800
801 void ODBCStorageService::reap(const char *table, const char* context)
802 {
803 #ifdef _DEBUG
804     xmltooling::NDC ndc("reap");
805 #endif
806
807     // Get statement handle.
808     ODBCConn conn(getHDBC());
809     SQLHSTMT stmt = getHSTMT(conn);
810
811     // Prepare and execute delete statement.
812     char nowbuf[32];
813     timestampFromTime(time(NULL), nowbuf);
814     string q;
815     if (context) {
816         char *scontext = makeSafeSQL(context);
817         q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
818         freeSafeSQL(scontext, context);
819     }
820     else {
821         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
822     }
823     m_log.debug("SQL: %s", q.c_str());
824
825     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
826     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
827         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
828         log_error(stmt, SQL_HANDLE_STMT);
829         throw IOException("ODBC StorageService failed to purge expired records.");
830     }
831 }
832
833 void ODBCStorageService::deleteContext(const char *table, const char* context)
834 {
835 #ifdef _DEBUG
836     xmltooling::NDC ndc("deleteContext");
837 #endif
838
839     // Get statement handle.
840     ODBCConn conn(getHDBC());
841     SQLHSTMT stmt = getHSTMT(conn);
842
843     // Prepare and execute delete statement.
844     char *scontext = makeSafeSQL(context);
845     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
846     freeSafeSQL(scontext, context);
847     m_log.debug("SQL: %s", q.c_str());
848
849     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
850     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
851         m_log.error("error deleting context (t=%s, c=%s)", table, context);
852         log_error(stmt, SQL_HANDLE_STMT);
853         throw IOException("ODBC StorageService failed to delete context.");
854     }
855 }
856
857 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
858 {
859     // Register this SS type
860     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
861     return 0;
862 }
863
864 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
865 {
866     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
867 }