Check for null env handle.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * odbc-store.cpp
19  *
20  * Storage Service using ODBC
21  */
22
23 #if defined (_MSC_VER) || defined(__BORLANDC__)
24 # include "config_win32.h"
25 #else
26 # include "config.h"
27 #endif
28
29 #ifdef WIN32
30 # define _CRT_NONSTDC_NO_DEPRECATE 1
31 # define _CRT_SECURE_NO_DEPRECATE 1
32 #endif
33
34 #ifdef WIN32
35 # define ODBCSTORE_EXPORTS __declspec(dllexport)
36 #else
37 # define ODBCSTORE_EXPORTS
38 #endif
39
40 #include <xercesc/util/XMLUniDefs.hpp>
41 #include <xmltooling/logging.h>
42 #include <xmltooling/XMLToolingConfig.h>
43 #include <xmltooling/util/NDC.h>
44 #include <xmltooling/util/StorageService.h>
45 #include <xmltooling/util/Threads.h>
46 #include <xmltooling/util/XMLHelper.h>
47
48 #include <sql.h>
49 #include <sqlext.h>
50
51 using namespace xmltooling::logging;
52 using namespace xmltooling;
53 using namespace xercesc;
54 using namespace std;
55
56 #define PLUGIN_VER_MAJOR 1
57 #define PLUGIN_VER_MINOR 0
58
59 #define LONGDATA_BUFLEN 16384
60
61 #define COLSIZE_CONTEXT 255
62 #define COLSIZE_ID 255
63 #define COLSIZE_STRING_VALUE 255
64
65 #define STRING_TABLE "strings"
66 #define TEXT_TABLE "texts"
67
68 /* table definitions
69 CREATE TABLE version (
70     major tinyint NOT NULL,
71     minor tinyint NOT NULL
72     )
73
74 CREATE TABLE strings (
75     context varchar(255) not null,
76     id varchar(255) not null,
77     expires datetime not null,
78     version smallint not null,
79     value varchar(255) not null,
80     PRIMARY KEY (context, id)
81     )
82
83 CREATE TABLE texts (
84     context varchar(255) not null,
85     id varchar(255) not null,
86     expires datetime not null,
87     version smallint not null,
88     value text not null,
89     PRIMARY KEY (context, id)
90     )
91 */
92
93 namespace {
94     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
95     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
96
97     // RAII for ODBC handles
98     struct ODBCConn {
99         ODBCConn(SQLHDBC conn) : handle(conn) {}
100         ~ODBCConn() {
101             SQLRETURN sr = SQLEndTran(SQL_HANDLE_DBC, handle, SQL_COMMIT);
102             SQLDisconnect(handle);
103             SQLFreeHandle(SQL_HANDLE_DBC,handle);
104             if (!SQL_SUCCEEDED(sr))
105                 throw IOException("Failed to commit connection.");
106         }
107         operator SQLHDBC() {return handle;}
108         SQLHDBC handle;
109     };
110
111     struct ODBCStatement {
112         ODBCStatement(SQLHSTMT statement) : handle(statement) {}
113         ~ODBCStatement() {SQLFreeHandle(SQL_HANDLE_STMT,handle);}
114         operator SQLHSTMT() {return handle;}
115         SQLHSTMT handle;
116     };
117
118     class ODBCStorageService : public StorageService
119     {
120     public:
121         ODBCStorageService(const DOMElement* e);
122         virtual ~ODBCStorageService();
123
124         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
125             return createRow(STRING_TABLE, context, key, value, expiration);
126         }
127         int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
128             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
129         }
130         int updateString(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
131             return updateRow(STRING_TABLE, context, key, value, expiration, version);
132         }
133         bool deleteString(const char* context, const char* key) {
134             return deleteRow(STRING_TABLE, context, key);
135         }
136
137         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
138             return createRow(TEXT_TABLE, context, key, value, expiration);
139         }
140         int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
141             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
142         }
143         int updateText(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
144             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
145         }
146         bool deleteText(const char* context, const char* key) {
147             return deleteRow(TEXT_TABLE, context, key);
148         }
149
150         void reap(const char* context) {
151             reap(STRING_TABLE, context);
152             reap(TEXT_TABLE, context);
153         }
154
155         void updateContext(const char* context, time_t expiration) {
156             updateContext(STRING_TABLE, context, expiration);
157             updateContext(TEXT_TABLE, context, expiration);
158         }
159
160         void deleteContext(const char* context) {
161             deleteContext(STRING_TABLE, context);
162             deleteContext(TEXT_TABLE, context);
163         }
164          
165
166     private:
167         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
168         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
169         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
170         bool deleteRow(const char *table, const char* context, const char* key);
171
172         void reap(const char* table, const char* context);
173         void updateContext(const char* table, const char* context, time_t expiration);
174         void deleteContext(const char* table, const char* context);
175
176         SQLHDBC getHDBC();
177         SQLHSTMT getHSTMT(SQLHDBC);
178         pair<int,int> getVersion(SQLHDBC);
179         bool log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL);
180
181         static void* cleanup_fn(void*); 
182         void cleanup();
183
184         Category& m_log;
185         int m_cleanupInterval;
186         CondWait* shutdown_wait;
187         Thread* cleanup_thread;
188         bool shutdown;
189
190         SQLHENV m_henv;
191         string m_connstring;
192     };
193
194     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
195     {
196         return new ODBCStorageService(e);
197     }
198
199     // convert SQL timestamp to time_t 
200     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
201     {
202         time_t ret;
203         struct tm t;
204         t.tm_sec=expires.second;
205         t.tm_min=expires.minute;
206         t.tm_hour=expires.hour;
207         t.tm_mday=expires.day;
208         t.tm_mon=expires.month-1;
209         t.tm_year=expires.year-1900;
210         t.tm_isdst=0;
211 #if defined(HAVE_TIMEGM)
212         ret = timegm(&t);
213 #else
214         ret = mktime(&t) - timezone;
215 #endif
216         return (ret);
217     }
218
219     // conver time_t to SQL string
220     void timestampFromTime(time_t t, char* ret)
221     {
222 #ifdef HAVE_GMTIME_R
223         struct tm res;
224         struct tm* ptime=gmtime_r(&t,&res);
225 #else
226         struct tm* ptime=gmtime(&t);
227 #endif
228         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
229     }
230
231     // make a string safe for SQL command
232     // result to be free'd only if it isn't the input
233     static char *makeSafeSQL(const char *src)
234     {
235        int ns = 0;
236        int nc = 0;
237        char *s;
238     
239        // see if any conversion needed
240        for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
241        if (ns==0) return ((char*)src);
242     
243        char *safe = new char[(nc+2*ns+1)];
244        for (s=safe; *src; src++) {
245            if (*src=='\'') *s++ = '\'';
246            *s++ = (char)*src;
247        }
248        *s = '\0';
249        return (safe);
250     }
251
252     void freeSafeSQL(char *safe, const char *src)
253     {
254         if (safe!=src)
255             delete[](safe);
256     }
257 };
258
259 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
260    m_cleanupInterval(900), shutdown_wait(NULL), cleanup_thread(NULL), shutdown(false), m_henv(SQL_NULL_HANDLE)
261 {
262 #ifdef _DEBUG
263     xmltooling::NDC ndc("ODBCStorageService");
264 #endif
265
266     const XMLCh* tag=e ? e->getAttributeNS(NULL,cleanupInterval) : NULL;
267     if (tag && *tag)
268         m_cleanupInterval = XMLString::parseInt(tag);
269     if (!m_cleanupInterval)
270         m_cleanupInterval = 900;
271
272     if (m_henv == SQL_NULL_HANDLE) {
273         // Enable connection pooling.
274         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
275
276         // Allocate the environment.
277         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
278             throw XMLToolingException("ODBC failed to initialize.");
279
280         // Specify ODBC 3.x
281         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
282
283         m_log.info("ODBC initialized");
284     }
285
286     // Grab connection string from the configuration.
287     e = e ? XMLHelper::getFirstChildElement(e,ConnectionString) : NULL;
288     if (!e || !e->hasChildNodes()) {
289         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
290         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
291     }
292     auto_ptr_char arg(e->getFirstChild()->getNodeValue());
293     m_connstring=arg.get();
294
295     // Connect and check version.
296     ODBCConn conn(getHDBC());
297     pair<int,int> v=getVersion(conn);
298
299     // Make sure we've got the right version.
300     if (v.first != PLUGIN_VER_MAJOR) {
301         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
302         m_log.crit("unknown database version: %d.%d", v.first, v.second);
303         throw XMLToolingException("Unknown database version for ODBC StorageService.");
304     }
305
306     // Initialize the cleanup thread
307     shutdown_wait = CondWait::create();
308     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
309 }
310
311 ODBCStorageService::~ODBCStorageService()
312 {
313     shutdown = true;
314     shutdown_wait->signal();
315     cleanup_thread->join(NULL);
316     delete shutdown_wait;
317     if (m_henv != SQL_NULL_HANDLE)
318         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
319 }
320
321 bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
322 {
323     SQLSMALLINT  i = 0;
324     SQLINTEGER   native;
325     SQLCHAR      state[7];
326     SQLCHAR      text[256];
327     SQLSMALLINT  len;
328     SQLRETURN    ret;
329
330     bool res = false;
331     do {
332         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
333         if (SQL_SUCCEEDED(ret)) {
334             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
335             if (checkfor && !strcmp(checkfor, (const char*)state))
336                 res = true;
337         }
338     } while(SQL_SUCCEEDED(ret));
339     return res;
340 }
341
342 SQLHDBC ODBCStorageService::getHDBC()
343 {
344 #ifdef _DEBUG
345     xmltooling::NDC ndc("getHDBC");
346 #endif
347
348     // Get a handle.
349     SQLHDBC handle;
350     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
351     if (!SQL_SUCCEEDED(sr)) {
352         m_log.error("failed to allocate connection handle");
353         log_error(m_henv, SQL_HANDLE_ENV);
354         throw IOException("ODBC StorageService failed to allocate a connection handle.");
355     }
356
357     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
358     if (!SQL_SUCCEEDED(sr)) {
359         m_log.error("failed to connect to database");
360         log_error(handle, SQL_HANDLE_DBC);
361         throw IOException("ODBC StorageService failed to connect to database.");
362     }
363
364     sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
365     if (!SQL_SUCCEEDED(sr))
366         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
367     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_SERIALIZABLE, NULL);
368     if (!SQL_SUCCEEDED(sr))
369         throw IOException("ODBC StorageService failed to enable transaction isolation.");
370
371     return handle;
372 }
373
374 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
375 {
376     SQLHSTMT hstmt;
377     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
378     if (!SQL_SUCCEEDED(sr)) {
379         m_log.error("failed to allocate statement handle");
380         log_error(conn, SQL_HANDLE_DBC);
381         throw IOException("ODBC StorageService failed to allocate a statement handle.");
382     }
383     return hstmt;
384 }
385
386 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
387 {
388     // Grab the version number from the database.
389     ODBCStatement stmt(getHSTMT(conn));
390     
391     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
392     if (!SQL_SUCCEEDED(sr)) {
393         m_log.error("failed to read version from database");
394         log_error(stmt, SQL_HANDLE_STMT);
395         throw IOException("ODBC StorageService failed to read version from database.");
396     }
397
398     SQLINTEGER major;
399     SQLINTEGER minor;
400     SQLBindCol(stmt,1,SQL_C_SLONG,&major,0,NULL);
401     SQLBindCol(stmt,2,SQL_C_SLONG,&minor,0,NULL);
402
403     if ((sr=SQLFetch(stmt)) != SQL_NO_DATA)
404         return pair<int,int>(major,minor);
405
406     m_log.error("no rows returned in version query");
407     throw IOException("ODBC StorageService failed to read version from database.");
408 }
409
410 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
411 {
412 #ifdef _DEBUG
413     xmltooling::NDC ndc("createRow");
414 #endif
415
416     char timebuf[32];
417     timestampFromTime(expiration, timebuf);
418
419     // Get statement handle.
420     ODBCConn conn(getHDBC());
421     ODBCStatement stmt(getHSTMT(conn));
422
423     // Prepare and exectute insert statement.
424     //char *scontext = makeSafeSQL(context);
425     //char *skey = makeSafeSQL(key);
426     //char *svalue = makeSafeSQL(value);
427     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
428
429     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
430     if (!SQL_SUCCEEDED(sr)) {
431         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
432         log_error(stmt, SQL_HANDLE_STMT);
433         throw IOException("ODBC StorageService failed to insert record.");
434     }
435     m_log.debug("SQLPrepare() succeded. SQL: %s", q.c_str());
436
437     SQLINTEGER b_ind = SQL_NTS;
438     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 0, 0, const_cast<char*>(context), &b_ind);
439     if (!SQL_SUCCEEDED(sr)) {
440         m_log.error("SQLBindParam failed (context = %s)", context);
441         log_error(stmt, SQL_HANDLE_STMT);
442         throw IOException("ODBC StorageService failed to insert record.");
443     }
444     m_log.debug("SQLBindParam succeded (context = %s)", context);
445
446     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 0, 0, const_cast<char*>(key), &b_ind);
447     if (!SQL_SUCCEEDED(sr)) {
448         m_log.error("SQLBindParam failed (key = %s)", key);
449         log_error(stmt, SQL_HANDLE_STMT);
450         throw IOException("ODBC StorageService failed to insert record.");
451     }
452     m_log.debug("SQLBindParam succeded (key = %s)", key);
453
454     sr = SQLBindParam(stmt, 3, SQL_C_CHAR, (strcmp(table, TEXT_TABLE)==0 ? SQL_LONGVARCHAR : SQL_VARCHAR), 0, 0, const_cast<char*>(value), &b_ind);
455     if (!SQL_SUCCEEDED(sr)) {
456         m_log.error("SQLBindParam failed (value = %s)", value);
457         log_error(stmt, SQL_HANDLE_STMT);
458         throw IOException("ODBC StorageService failed to insert record.");
459     }
460     m_log.debug("SQLBindParam succeded (value = %s)", value);
461     
462     //freeSafeSQL(scontext, context);
463     //freeSafeSQL(skey, key);
464     //freeSafeSQL(svalue, value);
465     //m_log.debug("SQL: %s", q.c_str());
466
467     sr=SQLExecute(stmt);
468     if (!SQL_SUCCEEDED(sr)) {
469         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
470         if (log_error(stmt, SQL_HANDLE_STMT, "23000"))
471             return false;   // supposedly integrity violation?
472         throw IOException("ODBC StorageService failed to insert record.");
473     }
474     return true;
475 }
476
477 int ODBCStorageService::readRow(
478     const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
479     )
480 {
481 #ifdef _DEBUG
482     xmltooling::NDC ndc("readRow");
483 #endif
484
485     // Get statement handle.
486     ODBCConn conn(getHDBC());
487     ODBCStatement stmt(getHSTMT(conn));
488
489     // Prepare and exectute select statement.
490     char timebuf[32];
491     timestampFromTime(time(NULL), timebuf);
492     char *scontext = makeSafeSQL(context);
493     char *skey = makeSafeSQL(key);
494     ostringstream q;
495     q << "SELECT version";
496     if (pexpiration)
497         q << ",expires";
498     if (pvalue)
499         q << ",CASE version WHEN " << version << " THEN NULL ELSE value END";
500     q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
501     freeSafeSQL(scontext, context);
502     freeSafeSQL(skey, key);
503     if (m_log.isDebugEnabled())
504         m_log.debug("SQL: %s", q.str().c_str());
505
506     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
507     if (!SQL_SUCCEEDED(sr)) {
508         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
509         log_error(stmt, SQL_HANDLE_STMT);
510         throw IOException("ODBC StorageService search failed.");
511     }
512
513     SQLSMALLINT ver;
514     SQL_TIMESTAMP_STRUCT expiration;
515
516     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
517     if (pexpiration)
518         SQLBindCol(stmt,2,SQL_C_TYPE_TIMESTAMP,&expiration,0,NULL);
519
520     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA)
521         return 0;
522
523     if (pexpiration)
524         *pexpiration = timeFromTimestamp(expiration);
525
526     if (version == ver)
527         return version; // nothing's changed, so just echo back the version
528
529     if (pvalue) {
530         SQLINTEGER len;
531         SQLCHAR buf[LONGDATA_BUFLEN];
532         while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
533             if (!SQL_SUCCEEDED(sr)) {
534                 m_log.error("error while reading text field from result set");
535                 log_error(stmt, SQL_HANDLE_STMT);
536                 throw IOException("ODBC StorageService search failed to read data from result set.");
537             }
538             pvalue->append((char*)buf);
539         }
540     }
541     
542     return ver;
543 }
544
545 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
546 {
547 #ifdef _DEBUG
548     xmltooling::NDC ndc("updateRow");
549 #endif
550
551     if (!value && !expiration)
552         throw IOException("ODBC StorageService given invalid update instructions.");
553
554     // Get statement handle.
555     ODBCConn conn(getHDBC());
556     ODBCStatement stmt(getHSTMT(conn));
557
558     // First, fetch the current version for later, which also ensures the record still exists.
559     char timebuf[32];
560     timestampFromTime(time(NULL), timebuf);
561     char *scontext = makeSafeSQL(context);
562     char *skey = makeSafeSQL(key);
563     string q("SELECT version FROM ");
564     q = q + table + " WHERE context='" + scontext + "' AND id='" + key + "' AND expires > " + timebuf;
565
566     m_log.debug("SQL: %s", q.c_str());
567
568     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
569     if (!SQL_SUCCEEDED(sr)) {
570         freeSafeSQL(scontext, context);
571         freeSafeSQL(skey, key);
572         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
573         log_error(stmt, SQL_HANDLE_STMT);
574         throw IOException("ODBC StorageService search failed.");
575     }
576
577     SQLSMALLINT ver;
578     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
579     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA) {
580         freeSafeSQL(scontext, context);
581         freeSafeSQL(skey, key);
582         return 0;
583     }
584
585     // Check version?
586     if (version > 0 && version != ver) {
587         freeSafeSQL(scontext, context);
588         freeSafeSQL(skey, key);
589         return -1;
590     }
591
592     // Prepare and exectute update statement.
593     q = string("UPDATE ") + table + " SET ";
594
595     if (value)
596         q = q + "value=?, version=version+1";
597
598     if (expiration) {
599         timestampFromTime(expiration, timebuf);
600         if (value)
601             q += ',';
602         q = q + "expires = " + timebuf;
603     }
604
605     q = q + " WHERE context='" + scontext + "' AND id='" + key + "'";
606     freeSafeSQL(scontext, context);
607     freeSafeSQL(skey, key);
608
609     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
610     if (!SQL_SUCCEEDED(sr)) {
611         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
612         log_error(stmt, SQL_HANDLE_STMT);
613         throw IOException("ODBC StorageService failed to update record.");
614     }
615     m_log.debug("SQLPrepare() succeded. SQL: %s", q.c_str());
616
617     SQLINTEGER b_ind = SQL_NTS;
618     if (value) {
619         sr = SQLBindParam(stmt, 1, SQL_C_CHAR, (strcmp(table, TEXT_TABLE)==0 ? SQL_LONGVARCHAR : SQL_VARCHAR), 0, 0, const_cast<char*>(value), &b_ind);
620         if (!SQL_SUCCEEDED(sr)) {
621             m_log.error("SQLBindParam failed (context = %s)", context);
622             log_error(stmt, SQL_HANDLE_STMT);
623             throw IOException("ODBC StorageService failed to update record.");
624         }
625         m_log.debug("SQLBindParam succeded (context = %s)", context);
626     }
627
628     sr=SQLExecute(stmt);
629     if (sr==SQL_NO_DATA)
630         return 0;   // went missing?
631     else if (!SQL_SUCCEEDED(sr)) {
632         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
633         log_error(stmt, SQL_HANDLE_STMT);
634         throw IOException("ODBC StorageService failed to update record.");
635     }
636
637     return ver + 1;
638 }
639
640 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
641 {
642 #ifdef _DEBUG
643     xmltooling::NDC ndc("deleteRow");
644 #endif
645
646     // Get statement handle.
647     ODBCConn conn(getHDBC());
648     ODBCStatement stmt(getHSTMT(conn));
649
650     // Prepare and execute delete statement.
651     char *scontext = makeSafeSQL(context);
652     char *skey = makeSafeSQL(key);
653     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
654     freeSafeSQL(scontext, context);
655     freeSafeSQL(skey, key);
656     m_log.debug("SQL: %s", q.c_str());
657
658     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
659      if (sr==SQL_NO_DATA)
660         return false;
661     else if (!SQL_SUCCEEDED(sr)) {
662         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
663         log_error(stmt, SQL_HANDLE_STMT);
664         throw IOException("ODBC StorageService failed to delete record.");
665     }
666
667     return true;
668 }
669
670
671 void ODBCStorageService::cleanup()
672 {
673 #ifdef _DEBUG
674     xmltooling::NDC ndc("cleanup");
675 #endif
676
677     Mutex* mutex = Mutex::create();
678
679     mutex->lock();
680
681     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
682
683     while (!shutdown) {
684         shutdown_wait->timedwait(mutex, m_cleanupInterval);
685         if (shutdown)
686             break;
687         try {
688             reap(NULL);
689         }
690         catch (exception& ex) {
691             m_log.error("cleanup thread swallowed exception: %s", ex.what());
692         }
693     }
694
695     m_log.info("cleanup thread exiting...");
696
697     mutex->unlock();
698     delete mutex;
699     Thread::exit(NULL);
700 }
701
702 void* ODBCStorageService::cleanup_fn(void* cache_p)
703 {
704   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
705
706 #ifndef WIN32
707   // First, let's block all signals
708   Thread::mask_all_signals();
709 #endif
710
711   // Now run the cleanup process.
712   cache->cleanup();
713   return NULL;
714 }
715
716 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
717 {
718 #ifdef _DEBUG
719     xmltooling::NDC ndc("updateContext");
720 #endif
721
722     // Get statement handle.
723     ODBCConn conn(getHDBC());
724     ODBCStatement stmt(getHSTMT(conn));
725
726     char timebuf[32];
727     timestampFromTime(expiration, timebuf);
728
729     char nowbuf[32];
730     timestampFromTime(time(NULL), nowbuf);
731
732     char *scontext = makeSafeSQL(context);
733     string q("UPDATE ");
734     q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
735     freeSafeSQL(scontext, context);
736
737     m_log.debug("SQL: %s", q.c_str());
738
739     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
740     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
741         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
742         log_error(stmt, SQL_HANDLE_STMT);
743         throw IOException("ODBC StorageService failed to update context expiration.");
744     }
745 }
746
747 void ODBCStorageService::reap(const char *table, const char* context)
748 {
749 #ifdef _DEBUG
750     xmltooling::NDC ndc("reap");
751 #endif
752
753     // Get statement handle.
754     ODBCConn conn(getHDBC());
755     ODBCStatement stmt(getHSTMT(conn));
756
757     // Prepare and execute delete statement.
758     char nowbuf[32];
759     timestampFromTime(time(NULL), nowbuf);
760     string q;
761     if (context) {
762         char *scontext = makeSafeSQL(context);
763         q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
764         freeSafeSQL(scontext, context);
765     }
766     else {
767         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
768     }
769     m_log.debug("SQL: %s", q.c_str());
770
771     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
772     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
773         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
774         log_error(stmt, SQL_HANDLE_STMT);
775         throw IOException("ODBC StorageService failed to purge expired records.");
776     }
777 }
778
779 void ODBCStorageService::deleteContext(const char *table, const char* context)
780 {
781 #ifdef _DEBUG
782     xmltooling::NDC ndc("deleteContext");
783 #endif
784
785     // Get statement handle.
786     ODBCConn conn(getHDBC());
787     ODBCStatement stmt(getHSTMT(conn));
788
789     // Prepare and execute delete statement.
790     char *scontext = makeSafeSQL(context);
791     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
792     freeSafeSQL(scontext, context);
793     m_log.debug("SQL: %s", q.c_str());
794
795     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
796     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
797         m_log.error("error deleting context (t=%s, c=%s)", table, context);
798         log_error(stmt, SQL_HANDLE_STMT);
799         throw IOException("ODBC StorageService failed to delete context.");
800     }
801 }
802
803 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
804 {
805     // Register this SS type
806     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
807     return 0;
808 }
809
810 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
811 {
812     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
813 }