Add retry feature for insert and update operations.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * odbc-store.cpp
19  *
20  * Storage Service using ODBC
21  */
22
23 #if defined (_MSC_VER) || defined(__BORLANDC__)
24 # include "config_win32.h"
25 #else
26 # include "config.h"
27 #endif
28
29 #ifdef WIN32
30 # define _CRT_NONSTDC_NO_DEPRECATE 1
31 # define _CRT_SECURE_NO_DEPRECATE 1
32 #endif
33
34 #ifdef WIN32
35 # define ODBCSTORE_EXPORTS __declspec(dllexport)
36 #else
37 # define ODBCSTORE_EXPORTS
38 #endif
39
40 #include <xercesc/util/XMLUniDefs.hpp>
41 #include <xmltooling/logging.h>
42 #include <xmltooling/XMLToolingConfig.h>
43 #include <xmltooling/util/NDC.h>
44 #include <xmltooling/util/StorageService.h>
45 #include <xmltooling/util/Threads.h>
46 #include <xmltooling/util/XMLHelper.h>
47
48 #include <sql.h>
49 #include <sqlext.h>
50
51 using namespace xmltooling::logging;
52 using namespace xmltooling;
53 using namespace xercesc;
54 using namespace std;
55
56 #define PLUGIN_VER_MAJOR 1
57 #define PLUGIN_VER_MINOR 0
58
59 #define LONGDATA_BUFLEN 16384
60
61 #define COLSIZE_CONTEXT 255
62 #define COLSIZE_ID 255
63 #define COLSIZE_STRING_VALUE 255
64
65 #define STRING_TABLE "strings"
66 #define TEXT_TABLE "texts"
67
68 /* table definitions
69 CREATE TABLE version (
70     major tinyint NOT NULL,
71     minor tinyint NOT NULL
72     )
73
74 CREATE TABLE strings (
75     context varchar(255) not null,
76     id varchar(255) not null,
77     expires datetime not null,
78     version smallint not null,
79     value varchar(255) not null,
80     PRIMARY KEY (context, id)
81     )
82
83 CREATE TABLE texts (
84     context varchar(255) not null,
85     id varchar(255) not null,
86     expires datetime not null,
87     version smallint not null,
88     value text not null,
89     PRIMARY KEY (context, id)
90     )
91 */
92
93 namespace {
94     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
95     static const XMLCh isolationLevel[] =   UNICODE_LITERAL_14(i,s,o,l,a,t,i,o,n,L,e,v,e,l);
96     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
97     static const XMLCh RetryOnError[] =     UNICODE_LITERAL_12(R,e,t,r,y,O,n,E,r,r,o,r);
98
99     // RAII for ODBC handles
100     struct ODBCConn {
101         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
102         ~ODBCConn() {
103             SQLRETURN sr = SQL_SUCCESS;
104             if (!autoCommit)
105                 sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, NULL);
106             SQLDisconnect(handle);
107             SQLFreeHandle(SQL_HANDLE_DBC,handle);
108             if (!SQL_SUCCEEDED(sr))
109                 throw IOException("Failed to commit connection and return to auto-commit mode.");
110         }
111         operator SQLHDBC() {return handle;}
112         SQLHDBC handle;
113         bool autoCommit;
114     };
115
116     class ODBCStorageService : public StorageService
117     {
118     public:
119         ODBCStorageService(const DOMElement* e);
120         virtual ~ODBCStorageService();
121
122         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
123             return createRow(STRING_TABLE, context, key, value, expiration);
124         }
125         int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
126             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
127         }
128         int updateString(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
129             return updateRow(STRING_TABLE, context, key, value, expiration, version);
130         }
131         bool deleteString(const char* context, const char* key) {
132             return deleteRow(STRING_TABLE, context, key);
133         }
134
135         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
136             return createRow(TEXT_TABLE, context, key, value, expiration);
137         }
138         int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
139             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
140         }
141         int updateText(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
142             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
143         }
144         bool deleteText(const char* context, const char* key) {
145             return deleteRow(TEXT_TABLE, context, key);
146         }
147
148         void reap(const char* context) {
149             reap(STRING_TABLE, context);
150             reap(TEXT_TABLE, context);
151         }
152
153         void updateContext(const char* context, time_t expiration) {
154             updateContext(STRING_TABLE, context, expiration);
155             updateContext(TEXT_TABLE, context, expiration);
156         }
157
158         void deleteContext(const char* context) {
159             deleteContext(STRING_TABLE, context);
160             deleteContext(TEXT_TABLE, context);
161         }
162          
163
164     private:
165         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
166         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
167         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
168         bool deleteRow(const char *table, const char* context, const char* key);
169
170         void reap(const char* table, const char* context);
171         void updateContext(const char* table, const char* context, time_t expiration);
172         void deleteContext(const char* table, const char* context);
173
174         SQLHDBC getHDBC();
175         SQLHSTMT getHSTMT(SQLHDBC);
176         pair<int,int> getVersion(SQLHDBC);
177         pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL);
178
179         static void* cleanup_fn(void*); 
180         void cleanup();
181
182         Category& m_log;
183         int m_cleanupInterval;
184         CondWait* shutdown_wait;
185         Thread* cleanup_thread;
186         bool shutdown;
187
188         SQLHENV m_henv;
189         string m_connstring;
190         long m_isolation;
191         vector<SQLINTEGER> m_retries;
192     };
193
194     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
195     {
196         return new ODBCStorageService(e);
197     }
198
199     // convert SQL timestamp to time_t 
200     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
201     {
202         time_t ret;
203         struct tm t;
204         t.tm_sec=expires.second;
205         t.tm_min=expires.minute;
206         t.tm_hour=expires.hour;
207         t.tm_mday=expires.day;
208         t.tm_mon=expires.month-1;
209         t.tm_year=expires.year-1900;
210         t.tm_isdst=0;
211 #if defined(HAVE_TIMEGM)
212         ret = timegm(&t);
213 #else
214         ret = mktime(&t) - timezone;
215 #endif
216         return (ret);
217     }
218
219     // conver time_t to SQL string
220     void timestampFromTime(time_t t, char* ret)
221     {
222 #ifdef HAVE_GMTIME_R
223         struct tm res;
224         struct tm* ptime=gmtime_r(&t,&res);
225 #else
226         struct tm* ptime=gmtime(&t);
227 #endif
228         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
229     }
230
231     // make a string safe for SQL command
232     // result to be free'd only if it isn't the input
233     static char *makeSafeSQL(const char *src)
234     {
235        int ns = 0;
236        int nc = 0;
237        char *s;
238     
239        // see if any conversion needed
240        for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
241        if (ns==0) return ((char*)src);
242     
243        char *safe = new char[(nc+2*ns+1)];
244        for (s=safe; *src; src++) {
245            if (*src=='\'') *s++ = '\'';
246            *s++ = (char)*src;
247        }
248        *s = '\0';
249        return (safe);
250     }
251
252     void freeSafeSQL(char *safe, const char *src)
253     {
254         if (safe!=src)
255             delete[](safe);
256     }
257 };
258
259 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
260    m_cleanupInterval(900), shutdown_wait(NULL), cleanup_thread(NULL), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
261 {
262 #ifdef _DEBUG
263     xmltooling::NDC ndc("ODBCStorageService");
264 #endif
265
266     const XMLCh* tag=e ? e->getAttributeNS(NULL,cleanupInterval) : NULL;
267     if (tag && *tag)
268         m_cleanupInterval = XMLString::parseInt(tag);
269     if (!m_cleanupInterval)
270         m_cleanupInterval = 900;
271
272     auto_ptr_char iso(e ? e->getAttributeNS(NULL,isolationLevel) : NULL);
273     if (iso.get() && *iso.get()) {
274         if (!strcmp(iso.get(),"SERIALIZABLE"))
275             m_isolation = SQL_TXN_SERIALIZABLE;
276         else if (!strcmp(iso.get(),"REPEATABLE_READ"))
277             m_isolation = SQL_TXN_REPEATABLE_READ;
278         else if (!strcmp(iso.get(),"READ_COMMITTED"))
279             m_isolation = SQL_TXN_READ_COMMITTED;
280         else if (!strcmp(iso.get(),"READ_UNCOMMITTED"))
281             m_isolation = SQL_TXN_READ_UNCOMMITTED;
282         else
283             throw XMLToolingException("Unknown transaction isolationLevel property.");
284     }
285
286     if (m_henv == SQL_NULL_HANDLE) {
287         // Enable connection pooling.
288         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
289
290         // Allocate the environment.
291         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
292             throw XMLToolingException("ODBC failed to initialize.");
293
294         // Specify ODBC 3.x
295         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
296
297         m_log.info("ODBC initialized");
298     }
299
300     // Grab connection string from the configuration.
301     e = e ? XMLHelper::getFirstChildElement(e,ConnectionString) : NULL;
302     if (!e || !e->hasChildNodes()) {
303         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
304         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
305     }
306     auto_ptr_char arg(e->getFirstChild()->getNodeValue());
307     m_connstring=arg.get();
308
309     // Connect and check version.
310     ODBCConn conn(getHDBC());
311     pair<int,int> v=getVersion(conn);
312
313     // Make sure we've got the right version.
314     if (v.first != PLUGIN_VER_MAJOR) {
315         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
316         m_log.crit("unknown database version: %d.%d", v.first, v.second);
317         throw XMLToolingException("Unknown database version for ODBC StorageService.");
318     }
319
320     // Load any retry errors to check.
321     e = XMLHelper::getNextSiblingElement(e,RetryOnError);
322     while (e) {
323         if (e->hasChildNodes()) {
324             m_retries.push_back(XMLString::parseInt(e->getFirstChild()->getNodeValue()));
325             m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
326         }
327         e = XMLHelper::getNextSiblingElement(e,RetryOnError);
328     }
329
330     // Initialize the cleanup thread
331     shutdown_wait = CondWait::create();
332     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
333 }
334
335 ODBCStorageService::~ODBCStorageService()
336 {
337     shutdown = true;
338     shutdown_wait->signal();
339     cleanup_thread->join(NULL);
340     delete shutdown_wait;
341     if (m_henv != SQL_NULL_HANDLE)
342         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
343 }
344
345 pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
346 {
347     SQLSMALLINT  i = 0;
348     SQLINTEGER   native;
349     SQLCHAR      state[7];
350     SQLCHAR      text[256];
351     SQLSMALLINT  len;
352     SQLRETURN    ret;
353
354     pair<bool,bool> res = make_pair(false,false);
355     do {
356         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
357         if (SQL_SUCCEEDED(ret)) {
358             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
359             for (vector<SQLINTEGER>::const_iterator n = m_retries.begin(); !res.first && n != m_retries.end(); ++n)
360                 res.first = (*n == native);
361             if (checkfor && !strcmp(checkfor, (const char*)state))
362                 res.second = true;
363         }
364     } while(SQL_SUCCEEDED(ret));
365     return res;
366 }
367
368 SQLHDBC ODBCStorageService::getHDBC()
369 {
370 #ifdef _DEBUG
371     xmltooling::NDC ndc("getHDBC");
372 #endif
373
374     // Get a handle.
375     SQLHDBC handle;
376     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
377     if (!SQL_SUCCEEDED(sr)) {
378         m_log.error("failed to allocate connection handle");
379         log_error(m_henv, SQL_HANDLE_ENV);
380         throw IOException("ODBC StorageService failed to allocate a connection handle.");
381     }
382
383     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
384     if (!SQL_SUCCEEDED(sr)) {
385         m_log.error("failed to connect to database");
386         log_error(handle, SQL_HANDLE_DBC);
387         throw IOException("ODBC StorageService failed to connect to database.");
388     }
389
390     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, NULL);
391     if (!SQL_SUCCEEDED(sr))
392         throw IOException("ODBC StorageService failed to set transaction isolation level.");
393
394     return handle;
395 }
396
397 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
398 {
399     SQLHSTMT hstmt;
400     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
401     if (!SQL_SUCCEEDED(sr)) {
402         m_log.error("failed to allocate statement handle");
403         log_error(conn, SQL_HANDLE_DBC);
404         throw IOException("ODBC StorageService failed to allocate a statement handle.");
405     }
406     return hstmt;
407 }
408
409 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
410 {
411     // Grab the version number from the database.
412     SQLHSTMT stmt = getHSTMT(conn);
413     
414     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
415     if (!SQL_SUCCEEDED(sr)) {
416         m_log.error("failed to read version from database");
417         log_error(stmt, SQL_HANDLE_STMT);
418         throw IOException("ODBC StorageService failed to read version from database.");
419     }
420
421     SQLINTEGER major;
422     SQLINTEGER minor;
423     SQLBindCol(stmt,1,SQL_C_SLONG,&major,0,NULL);
424     SQLBindCol(stmt,2,SQL_C_SLONG,&minor,0,NULL);
425
426     if ((sr=SQLFetch(stmt)) != SQL_NO_DATA)
427         return pair<int,int>(major,minor);
428
429     m_log.error("no rows returned in version query");
430     throw IOException("ODBC StorageService failed to read version from database.");
431 }
432
433 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
434 {
435 #ifdef _DEBUG
436     xmltooling::NDC ndc("createRow");
437 #endif
438
439     char timebuf[32];
440     timestampFromTime(expiration, timebuf);
441
442     // Get statement handle.
443     ODBCConn conn(getHDBC());
444     SQLHSTMT stmt = getHSTMT(conn);
445
446     // Prepare and exectute insert statement.
447     //char *scontext = makeSafeSQL(context);
448     //char *skey = makeSafeSQL(key);
449     //char *svalue = makeSafeSQL(value);
450     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
451
452     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
453     if (!SQL_SUCCEEDED(sr)) {
454         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
455         log_error(stmt, SQL_HANDLE_STMT);
456         throw IOException("ODBC StorageService failed to insert record.");
457     }
458     m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
459
460     SQLINTEGER b_ind = SQL_NTS;
461     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
462     if (!SQL_SUCCEEDED(sr)) {
463         m_log.error("SQLBindParam failed (context = %s)", context);
464         log_error(stmt, SQL_HANDLE_STMT);
465         throw IOException("ODBC StorageService failed to insert record.");
466     }
467     m_log.debug("SQLBindParam succeded (context = %s)", context);
468
469     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
470     if (!SQL_SUCCEEDED(sr)) {
471         m_log.error("SQLBindParam failed (key = %s)", key);
472         log_error(stmt, SQL_HANDLE_STMT);
473         throw IOException("ODBC StorageService failed to insert record.");
474     }
475     m_log.debug("SQLBindParam succeded (key = %s)", key);
476
477     if (strcmp(table, TEXT_TABLE)==0)
478         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
479     else
480         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
481     if (!SQL_SUCCEEDED(sr)) {
482         m_log.error("SQLBindParam failed (value = %s)", value);
483         log_error(stmt, SQL_HANDLE_STMT);
484         throw IOException("ODBC StorageService failed to insert record.");
485     }
486     m_log.debug("SQLBindParam succeded (value = %s)", value);
487     
488     //freeSafeSQL(scontext, context);
489     //freeSafeSQL(skey, key);
490     //freeSafeSQL(svalue, value);
491     //m_log.debug("SQL: %s", q.c_str());
492
493     int attempts = 3;
494     pair<bool,bool> logres;
495     do {
496         logres = make_pair(false,false);
497         attempts--;
498         sr=SQLExecute(stmt);
499         if (SQL_SUCCEEDED(sr)) {
500             m_log.debug("SQLExecute of insert succeeded");
501             return true;
502         }
503         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
504         logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
505         if (logres.second)
506             return false;   // supposedly integrity violation?
507     } while (attempts && logres.first);
508
509     throw IOException("ODBC StorageService failed to insert record.");
510 }
511
512 int ODBCStorageService::readRow(
513     const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
514     )
515 {
516 #ifdef _DEBUG
517     xmltooling::NDC ndc("readRow");
518 #endif
519
520     // Get statement handle.
521     ODBCConn conn(getHDBC());
522     SQLHSTMT stmt = getHSTMT(conn);
523
524     // Prepare and exectute select statement.
525     char timebuf[32];
526     timestampFromTime(time(NULL), timebuf);
527     char *scontext = makeSafeSQL(context);
528     char *skey = makeSafeSQL(key);
529     ostringstream q;
530     q << "SELECT version";
531     if (pexpiration)
532         q << ",expires";
533     if (pvalue)
534         q << ",CASE version WHEN " << version << " THEN NULL ELSE value END";
535     q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
536     freeSafeSQL(scontext, context);
537     freeSafeSQL(skey, key);
538     if (m_log.isDebugEnabled())
539         m_log.debug("SQL: %s", q.str().c_str());
540
541     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
542     if (!SQL_SUCCEEDED(sr)) {
543         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
544         log_error(stmt, SQL_HANDLE_STMT);
545         throw IOException("ODBC StorageService search failed.");
546     }
547
548     SQLSMALLINT ver;
549     SQL_TIMESTAMP_STRUCT expiration;
550
551     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
552     if (pexpiration)
553         SQLBindCol(stmt,2,SQL_C_TYPE_TIMESTAMP,&expiration,0,NULL);
554
555     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA)
556         return 0;
557
558     if (pexpiration)
559         *pexpiration = timeFromTimestamp(expiration);
560
561     if (version == ver)
562         return version; // nothing's changed, so just echo back the version
563
564     if (pvalue) {
565         SQLINTEGER len;
566         SQLCHAR buf[LONGDATA_BUFLEN];
567         while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
568             if (!SQL_SUCCEEDED(sr)) {
569                 m_log.error("error while reading text field from result set");
570                 log_error(stmt, SQL_HANDLE_STMT);
571                 throw IOException("ODBC StorageService search failed to read data from result set.");
572             }
573             pvalue->append((char*)buf);
574         }
575     }
576     
577     return ver;
578 }
579
580 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
581 {
582 #ifdef _DEBUG
583     xmltooling::NDC ndc("updateRow");
584 #endif
585
586     if (!value && !expiration)
587         throw IOException("ODBC StorageService given invalid update instructions.");
588
589     // Get statement handle. Disable auto-commit mode to wrap select + update.
590     ODBCConn conn(getHDBC());
591     SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
592     if (!SQL_SUCCEEDED(sr))
593         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
594     conn.autoCommit = false;
595     SQLHSTMT stmt = getHSTMT(conn);
596
597     // First, fetch the current version for later, which also ensures the record still exists.
598     char timebuf[32];
599     timestampFromTime(time(NULL), timebuf);
600     char *scontext = makeSafeSQL(context);
601     char *skey = makeSafeSQL(key);
602     string q("SELECT version FROM ");
603     q = q + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf;
604
605     m_log.debug("SQL: %s", q.c_str());
606
607     sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
608     if (!SQL_SUCCEEDED(sr)) {
609         freeSafeSQL(scontext, context);
610         freeSafeSQL(skey, key);
611         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
612         log_error(stmt, SQL_HANDLE_STMT);
613         throw IOException("ODBC StorageService search failed.");
614     }
615
616     SQLSMALLINT ver;
617     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
618     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA) {
619         freeSafeSQL(scontext, context);
620         freeSafeSQL(skey, key);
621         return 0;
622     }
623
624     // Check version?
625     if (version > 0 && version != ver) {
626         freeSafeSQL(scontext, context);
627         freeSafeSQL(skey, key);
628         return -1;
629     }
630
631     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
632     stmt = getHSTMT(conn);
633
634     // Prepare and exectute update statement.
635     q = string("UPDATE ") + table + " SET ";
636
637     if (value)
638         q = q + "value=?, version=version+1";
639
640     if (expiration) {
641         timestampFromTime(expiration, timebuf);
642         if (value)
643             q += ',';
644         q = q + "expires = " + timebuf;
645     }
646
647     q = q + " WHERE context='" + scontext + "' AND id='" + skey + "'";
648     freeSafeSQL(scontext, context);
649     freeSafeSQL(skey, key);
650
651     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
652     if (!SQL_SUCCEEDED(sr)) {
653         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
654         log_error(stmt, SQL_HANDLE_STMT);
655         throw IOException("ODBC StorageService failed to update record.");
656     }
657     m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
658
659     SQLINTEGER b_ind = SQL_NTS;
660     if (value) {
661         if (strcmp(table, TEXT_TABLE)==0)
662             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
663         else
664             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
665         if (!SQL_SUCCEEDED(sr)) {
666             m_log.error("SQLBindParam failed (context = %s)", context);
667             log_error(stmt, SQL_HANDLE_STMT);
668             throw IOException("ODBC StorageService failed to update record.");
669         }
670         m_log.debug("SQLBindParam succeded (context = %s)", context);
671     }
672
673     int attempts = 3;
674     pair<bool,bool> logres;
675     do {
676         logres = make_pair(false,false);
677         attempts--;
678         sr=SQLExecute(stmt);
679         if (sr==SQL_NO_DATA)
680             return 0;   // went missing?
681         else if (SQL_SUCCEEDED(sr)) {
682             m_log.debug("SQLExecute of update succeeded");
683             return ver + 1;
684         }
685
686         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
687         logres = log_error(stmt, SQL_HANDLE_STMT);
688     } while (attempts && logres.first);
689
690     throw IOException("ODBC StorageService failed to update record.");
691 }
692
693 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
694 {
695 #ifdef _DEBUG
696     xmltooling::NDC ndc("deleteRow");
697 #endif
698
699     // Get statement handle.
700     ODBCConn conn(getHDBC());
701     SQLHSTMT stmt = getHSTMT(conn);
702
703     // Prepare and execute delete statement.
704     char *scontext = makeSafeSQL(context);
705     char *skey = makeSafeSQL(key);
706     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
707     freeSafeSQL(scontext, context);
708     freeSafeSQL(skey, key);
709     m_log.debug("SQL: %s", q.c_str());
710
711     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
712      if (sr==SQL_NO_DATA)
713         return false;
714     else if (!SQL_SUCCEEDED(sr)) {
715         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
716         log_error(stmt, SQL_HANDLE_STMT);
717         throw IOException("ODBC StorageService failed to delete record.");
718     }
719
720     return true;
721 }
722
723
724 void ODBCStorageService::cleanup()
725 {
726 #ifdef _DEBUG
727     xmltooling::NDC ndc("cleanup");
728 #endif
729
730     Mutex* mutex = Mutex::create();
731
732     mutex->lock();
733
734     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
735
736     while (!shutdown) {
737         shutdown_wait->timedwait(mutex, m_cleanupInterval);
738         if (shutdown)
739             break;
740         try {
741             reap(NULL);
742         }
743         catch (exception& ex) {
744             m_log.error("cleanup thread swallowed exception: %s", ex.what());
745         }
746     }
747
748     m_log.info("cleanup thread exiting...");
749
750     mutex->unlock();
751     delete mutex;
752     Thread::exit(NULL);
753 }
754
755 void* ODBCStorageService::cleanup_fn(void* cache_p)
756 {
757   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
758
759 #ifndef WIN32
760   // First, let's block all signals
761   Thread::mask_all_signals();
762 #endif
763
764   // Now run the cleanup process.
765   cache->cleanup();
766   return NULL;
767 }
768
769 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
770 {
771 #ifdef _DEBUG
772     xmltooling::NDC ndc("updateContext");
773 #endif
774
775     // Get statement handle.
776     ODBCConn conn(getHDBC());
777     SQLHSTMT stmt = getHSTMT(conn);
778
779     char timebuf[32];
780     timestampFromTime(expiration, timebuf);
781
782     char nowbuf[32];
783     timestampFromTime(time(NULL), nowbuf);
784
785     char *scontext = makeSafeSQL(context);
786     string q("UPDATE ");
787     q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
788     freeSafeSQL(scontext, context);
789
790     m_log.debug("SQL: %s", q.c_str());
791
792     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
793     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
794         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
795         log_error(stmt, SQL_HANDLE_STMT);
796         throw IOException("ODBC StorageService failed to update context expiration.");
797     }
798 }
799
800 void ODBCStorageService::reap(const char *table, const char* context)
801 {
802 #ifdef _DEBUG
803     xmltooling::NDC ndc("reap");
804 #endif
805
806     // Get statement handle.
807     ODBCConn conn(getHDBC());
808     SQLHSTMT stmt = getHSTMT(conn);
809
810     // Prepare and execute delete statement.
811     char nowbuf[32];
812     timestampFromTime(time(NULL), nowbuf);
813     string q;
814     if (context) {
815         char *scontext = makeSafeSQL(context);
816         q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
817         freeSafeSQL(scontext, context);
818     }
819     else {
820         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
821     }
822     m_log.debug("SQL: %s", q.c_str());
823
824     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
825     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
826         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
827         log_error(stmt, SQL_HANDLE_STMT);
828         throw IOException("ODBC StorageService failed to purge expired records.");
829     }
830 }
831
832 void ODBCStorageService::deleteContext(const char *table, const char* context)
833 {
834 #ifdef _DEBUG
835     xmltooling::NDC ndc("deleteContext");
836 #endif
837
838     // Get statement handle.
839     ODBCConn conn(getHDBC());
840     SQLHSTMT stmt = getHSTMT(conn);
841
842     // Prepare and execute delete statement.
843     char *scontext = makeSafeSQL(context);
844     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
845     freeSafeSQL(scontext, context);
846     m_log.debug("SQL: %s", q.c_str());
847
848     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
849     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
850         m_log.error("error deleting context (t=%s, c=%s)", table, context);
851         log_error(stmt, SQL_HANDLE_STMT);
852         throw IOException("ODBC StorageService failed to delete context.");
853     }
854 }
855
856 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
857 {
858     // Register this SS type
859     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
860     return 0;
861 }
862
863 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
864 {
865     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
866 }