Switch to auto-commit for all the non-update transactions.
[shibboleth/sp.git] / odbc-store / odbc-store.cpp
1 /*
2  *  Copyright 2001-2007 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /**
18  * odbc-store.cpp
19  *
20  * Storage Service using ODBC
21  */
22
23 #if defined (_MSC_VER) || defined(__BORLANDC__)
24 # include "config_win32.h"
25 #else
26 # include "config.h"
27 #endif
28
29 #ifdef WIN32
30 # define _CRT_NONSTDC_NO_DEPRECATE 1
31 # define _CRT_SECURE_NO_DEPRECATE 1
32 #endif
33
34 #ifdef WIN32
35 # define ODBCSTORE_EXPORTS __declspec(dllexport)
36 #else
37 # define ODBCSTORE_EXPORTS
38 #endif
39
40 #include <xercesc/util/XMLUniDefs.hpp>
41 #include <xmltooling/logging.h>
42 #include <xmltooling/XMLToolingConfig.h>
43 #include <xmltooling/util/NDC.h>
44 #include <xmltooling/util/StorageService.h>
45 #include <xmltooling/util/Threads.h>
46 #include <xmltooling/util/XMLHelper.h>
47
48 #include <sql.h>
49 #include <sqlext.h>
50
51 using namespace xmltooling::logging;
52 using namespace xmltooling;
53 using namespace xercesc;
54 using namespace std;
55
56 #define PLUGIN_VER_MAJOR 1
57 #define PLUGIN_VER_MINOR 0
58
59 #define LONGDATA_BUFLEN 16384
60
61 #define COLSIZE_CONTEXT 255
62 #define COLSIZE_ID 255
63 #define COLSIZE_STRING_VALUE 255
64
65 #define STRING_TABLE "strings"
66 #define TEXT_TABLE "texts"
67
68 /* table definitions
69 CREATE TABLE version (
70     major tinyint NOT NULL,
71     minor tinyint NOT NULL
72     )
73
74 CREATE TABLE strings (
75     context varchar(255) not null,
76     id varchar(255) not null,
77     expires datetime not null,
78     version smallint not null,
79     value varchar(255) not null,
80     PRIMARY KEY (context, id)
81     )
82
83 CREATE TABLE texts (
84     context varchar(255) not null,
85     id varchar(255) not null,
86     expires datetime not null,
87     version smallint not null,
88     value text not null,
89     PRIMARY KEY (context, id)
90     )
91 */
92
93 namespace {
94     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
95     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
96
97     // RAII for ODBC handles
98     struct ODBCConn {
99         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
100         ~ODBCConn() {
101             SQLRETURN sr = SQL_SUCCESS;
102             if (!autoCommit)
103                 sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, NULL);
104             SQLDisconnect(handle);
105             SQLFreeHandle(SQL_HANDLE_DBC,handle);
106             if (!SQL_SUCCEEDED(sr))
107                 throw IOException("Failed to commit connection and return to auto-commit mode.");
108         }
109         operator SQLHDBC() {return handle;}
110         SQLHDBC handle;
111         bool autoCommit;
112     };
113
114     class ODBCStorageService : public StorageService
115     {
116     public:
117         ODBCStorageService(const DOMElement* e);
118         virtual ~ODBCStorageService();
119
120         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
121             return createRow(STRING_TABLE, context, key, value, expiration);
122         }
123         int readString(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
124             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
125         }
126         int updateString(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
127             return updateRow(STRING_TABLE, context, key, value, expiration, version);
128         }
129         bool deleteString(const char* context, const char* key) {
130             return deleteRow(STRING_TABLE, context, key);
131         }
132
133         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
134             return createRow(TEXT_TABLE, context, key, value, expiration);
135         }
136         int readText(const char* context, const char* key, string* pvalue=NULL, time_t* pexpiration=NULL, int version=0) {
137             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
138         }
139         int updateText(const char* context, const char* key, const char* value=NULL, time_t expiration=0, int version=0) {
140             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
141         }
142         bool deleteText(const char* context, const char* key) {
143             return deleteRow(TEXT_TABLE, context, key);
144         }
145
146         void reap(const char* context) {
147             reap(STRING_TABLE, context);
148             reap(TEXT_TABLE, context);
149         }
150
151         void updateContext(const char* context, time_t expiration) {
152             updateContext(STRING_TABLE, context, expiration);
153             updateContext(TEXT_TABLE, context, expiration);
154         }
155
156         void deleteContext(const char* context) {
157             deleteContext(STRING_TABLE, context);
158             deleteContext(TEXT_TABLE, context);
159         }
160          
161
162     private:
163         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
164         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
165         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
166         bool deleteRow(const char *table, const char* context, const char* key);
167
168         void reap(const char* table, const char* context);
169         void updateContext(const char* table, const char* context, time_t expiration);
170         void deleteContext(const char* table, const char* context);
171
172         SQLHDBC getHDBC();
173         SQLHSTMT getHSTMT(SQLHDBC);
174         pair<int,int> getVersion(SQLHDBC);
175         bool log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=NULL);
176
177         static void* cleanup_fn(void*); 
178         void cleanup();
179
180         Category& m_log;
181         int m_cleanupInterval;
182         CondWait* shutdown_wait;
183         Thread* cleanup_thread;
184         bool shutdown;
185
186         SQLHENV m_henv;
187         string m_connstring;
188     };
189
190     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
191     {
192         return new ODBCStorageService(e);
193     }
194
195     // convert SQL timestamp to time_t 
196     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
197     {
198         time_t ret;
199         struct tm t;
200         t.tm_sec=expires.second;
201         t.tm_min=expires.minute;
202         t.tm_hour=expires.hour;
203         t.tm_mday=expires.day;
204         t.tm_mon=expires.month-1;
205         t.tm_year=expires.year-1900;
206         t.tm_isdst=0;
207 #if defined(HAVE_TIMEGM)
208         ret = timegm(&t);
209 #else
210         ret = mktime(&t) - timezone;
211 #endif
212         return (ret);
213     }
214
215     // conver time_t to SQL string
216     void timestampFromTime(time_t t, char* ret)
217     {
218 #ifdef HAVE_GMTIME_R
219         struct tm res;
220         struct tm* ptime=gmtime_r(&t,&res);
221 #else
222         struct tm* ptime=gmtime(&t);
223 #endif
224         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
225     }
226
227     // make a string safe for SQL command
228     // result to be free'd only if it isn't the input
229     static char *makeSafeSQL(const char *src)
230     {
231        int ns = 0;
232        int nc = 0;
233        char *s;
234     
235        // see if any conversion needed
236        for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
237        if (ns==0) return ((char*)src);
238     
239        char *safe = new char[(nc+2*ns+1)];
240        for (s=safe; *src; src++) {
241            if (*src=='\'') *s++ = '\'';
242            *s++ = (char)*src;
243        }
244        *s = '\0';
245        return (safe);
246     }
247
248     void freeSafeSQL(char *safe, const char *src)
249     {
250         if (safe!=src)
251             delete[](safe);
252     }
253 };
254
255 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
256    m_cleanupInterval(900), shutdown_wait(NULL), cleanup_thread(NULL), shutdown(false), m_henv(SQL_NULL_HANDLE)
257 {
258 #ifdef _DEBUG
259     xmltooling::NDC ndc("ODBCStorageService");
260 #endif
261
262     const XMLCh* tag=e ? e->getAttributeNS(NULL,cleanupInterval) : NULL;
263     if (tag && *tag)
264         m_cleanupInterval = XMLString::parseInt(tag);
265     if (!m_cleanupInterval)
266         m_cleanupInterval = 900;
267
268     if (m_henv == SQL_NULL_HANDLE) {
269         // Enable connection pooling.
270         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
271
272         // Allocate the environment.
273         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
274             throw XMLToolingException("ODBC failed to initialize.");
275
276         // Specify ODBC 3.x
277         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
278
279         m_log.info("ODBC initialized");
280     }
281
282     // Grab connection string from the configuration.
283     e = e ? XMLHelper::getFirstChildElement(e,ConnectionString) : NULL;
284     if (!e || !e->hasChildNodes()) {
285         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
286         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
287     }
288     auto_ptr_char arg(e->getFirstChild()->getNodeValue());
289     m_connstring=arg.get();
290
291     // Connect and check version.
292     ODBCConn conn(getHDBC());
293     pair<int,int> v=getVersion(conn);
294
295     // Make sure we've got the right version.
296     if (v.first != PLUGIN_VER_MAJOR) {
297         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
298         m_log.crit("unknown database version: %d.%d", v.first, v.second);
299         throw XMLToolingException("Unknown database version for ODBC StorageService.");
300     }
301
302     // Initialize the cleanup thread
303     shutdown_wait = CondWait::create();
304     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
305 }
306
307 ODBCStorageService::~ODBCStorageService()
308 {
309     shutdown = true;
310     shutdown_wait->signal();
311     cleanup_thread->join(NULL);
312     delete shutdown_wait;
313     if (m_henv != SQL_NULL_HANDLE)
314         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
315 }
316
317 bool ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
318 {
319     SQLSMALLINT  i = 0;
320     SQLINTEGER   native;
321     SQLCHAR      state[7];
322     SQLCHAR      text[256];
323     SQLSMALLINT  len;
324     SQLRETURN    ret;
325
326     bool res = false;
327     do {
328         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
329         if (SQL_SUCCEEDED(ret)) {
330             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
331             if (checkfor && !strcmp(checkfor, (const char*)state))
332                 res = true;
333         }
334     } while(SQL_SUCCEEDED(ret));
335     return res;
336 }
337
338 SQLHDBC ODBCStorageService::getHDBC()
339 {
340 #ifdef _DEBUG
341     xmltooling::NDC ndc("getHDBC");
342 #endif
343
344     // Get a handle.
345     SQLHDBC handle;
346     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
347     if (!SQL_SUCCEEDED(sr)) {
348         m_log.error("failed to allocate connection handle");
349         log_error(m_henv, SQL_HANDLE_ENV);
350         throw IOException("ODBC StorageService failed to allocate a connection handle.");
351     }
352
353     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
354     if (!SQL_SUCCEEDED(sr)) {
355         m_log.error("failed to connect to database");
356         log_error(handle, SQL_HANDLE_DBC);
357         throw IOException("ODBC StorageService failed to connect to database.");
358     }
359
360     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)SQL_TXN_SERIALIZABLE, NULL);
361     if (!SQL_SUCCEEDED(sr))
362         throw IOException("ODBC StorageService failed to enable transaction isolation.");
363
364     return handle;
365 }
366
367 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
368 {
369     SQLHSTMT hstmt;
370     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
371     if (!SQL_SUCCEEDED(sr)) {
372         m_log.error("failed to allocate statement handle");
373         log_error(conn, SQL_HANDLE_DBC);
374         throw IOException("ODBC StorageService failed to allocate a statement handle.");
375     }
376     return hstmt;
377 }
378
379 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
380 {
381     // Grab the version number from the database.
382     SQLHSTMT stmt = getHSTMT(conn);
383     
384     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
385     if (!SQL_SUCCEEDED(sr)) {
386         m_log.error("failed to read version from database");
387         log_error(stmt, SQL_HANDLE_STMT);
388         throw IOException("ODBC StorageService failed to read version from database.");
389     }
390
391     SQLINTEGER major;
392     SQLINTEGER minor;
393     SQLBindCol(stmt,1,SQL_C_SLONG,&major,0,NULL);
394     SQLBindCol(stmt,2,SQL_C_SLONG,&minor,0,NULL);
395
396     if ((sr=SQLFetch(stmt)) != SQL_NO_DATA)
397         return pair<int,int>(major,minor);
398
399     m_log.error("no rows returned in version query");
400     throw IOException("ODBC StorageService failed to read version from database.");
401 }
402
403 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
404 {
405 #ifdef _DEBUG
406     xmltooling::NDC ndc("createRow");
407 #endif
408
409     char timebuf[32];
410     timestampFromTime(expiration, timebuf);
411
412     // Get statement handle.
413     ODBCConn conn(getHDBC());
414     SQLHSTMT stmt = getHSTMT(conn);
415
416     // Prepare and exectute insert statement.
417     //char *scontext = makeSafeSQL(context);
418     //char *skey = makeSafeSQL(key);
419     //char *svalue = makeSafeSQL(value);
420     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
421
422     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
423     if (!SQL_SUCCEEDED(sr)) {
424         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
425         log_error(stmt, SQL_HANDLE_STMT);
426         throw IOException("ODBC StorageService failed to insert record.");
427     }
428     m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
429
430     SQLINTEGER b_ind = SQL_NTS;
431     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
432     if (!SQL_SUCCEEDED(sr)) {
433         m_log.error("SQLBindParam failed (context = %s)", context);
434         log_error(stmt, SQL_HANDLE_STMT);
435         throw IOException("ODBC StorageService failed to insert record.");
436     }
437     m_log.debug("SQLBindParam succeded (context = %s)", context);
438
439     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
440     if (!SQL_SUCCEEDED(sr)) {
441         m_log.error("SQLBindParam failed (key = %s)", key);
442         log_error(stmt, SQL_HANDLE_STMT);
443         throw IOException("ODBC StorageService failed to insert record.");
444     }
445     m_log.debug("SQLBindParam succeded (key = %s)", key);
446
447     if (strcmp(table, TEXT_TABLE)==0)
448         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
449     else
450         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
451     if (!SQL_SUCCEEDED(sr)) {
452         m_log.error("SQLBindParam failed (value = %s)", value);
453         log_error(stmt, SQL_HANDLE_STMT);
454         throw IOException("ODBC StorageService failed to insert record.");
455     }
456     m_log.debug("SQLBindParam succeded (value = %s)", value);
457     
458     //freeSafeSQL(scontext, context);
459     //freeSafeSQL(skey, key);
460     //freeSafeSQL(svalue, value);
461     //m_log.debug("SQL: %s", q.c_str());
462
463     sr=SQLExecute(stmt);
464     if (!SQL_SUCCEEDED(sr)) {
465         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
466         if (log_error(stmt, SQL_HANDLE_STMT, "23000"))
467             return false;   // supposedly integrity violation?
468         throw IOException("ODBC StorageService failed to insert record.");
469     }
470
471     m_log.debug("SQLExecute of insert succeeded");
472     return true;
473 }
474
475 int ODBCStorageService::readRow(
476     const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
477     )
478 {
479 #ifdef _DEBUG
480     xmltooling::NDC ndc("readRow");
481 #endif
482
483     // Get statement handle.
484     ODBCConn conn(getHDBC());
485     SQLHSTMT stmt = getHSTMT(conn);
486
487     // Prepare and exectute select statement.
488     char timebuf[32];
489     timestampFromTime(time(NULL), timebuf);
490     char *scontext = makeSafeSQL(context);
491     char *skey = makeSafeSQL(key);
492     ostringstream q;
493     q << "SELECT version";
494     if (pexpiration)
495         q << ",expires";
496     if (pvalue)
497         q << ",CASE version WHEN " << version << " THEN NULL ELSE value END";
498     q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
499     freeSafeSQL(scontext, context);
500     freeSafeSQL(skey, key);
501     if (m_log.isDebugEnabled())
502         m_log.debug("SQL: %s", q.str().c_str());
503
504     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
505     if (!SQL_SUCCEEDED(sr)) {
506         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
507         log_error(stmt, SQL_HANDLE_STMT);
508         throw IOException("ODBC StorageService search failed.");
509     }
510
511     SQLSMALLINT ver;
512     SQL_TIMESTAMP_STRUCT expiration;
513
514     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
515     if (pexpiration)
516         SQLBindCol(stmt,2,SQL_C_TYPE_TIMESTAMP,&expiration,0,NULL);
517
518     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA)
519         return 0;
520
521     if (pexpiration)
522         *pexpiration = timeFromTimestamp(expiration);
523
524     if (version == ver)
525         return version; // nothing's changed, so just echo back the version
526
527     if (pvalue) {
528         SQLINTEGER len;
529         SQLCHAR buf[LONGDATA_BUFLEN];
530         while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
531             if (!SQL_SUCCEEDED(sr)) {
532                 m_log.error("error while reading text field from result set");
533                 log_error(stmt, SQL_HANDLE_STMT);
534                 throw IOException("ODBC StorageService search failed to read data from result set.");
535             }
536             pvalue->append((char*)buf);
537         }
538     }
539     
540     return ver;
541 }
542
543 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
544 {
545 #ifdef _DEBUG
546     xmltooling::NDC ndc("updateRow");
547 #endif
548
549     if (!value && !expiration)
550         throw IOException("ODBC StorageService given invalid update instructions.");
551
552     // Get statement handle. Disable auto-commit mode to wrap select + update.
553     ODBCConn conn(getHDBC());
554     SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, NULL);
555     if (!SQL_SUCCEEDED(sr))
556         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
557     conn.autoCommit = false;
558     SQLHSTMT stmt = getHSTMT(conn);
559
560     // First, fetch the current version for later, which also ensures the record still exists.
561     char timebuf[32];
562     timestampFromTime(time(NULL), timebuf);
563     char *scontext = makeSafeSQL(context);
564     char *skey = makeSafeSQL(key);
565     string q("SELECT version FROM ");
566     q = q + table + " WHERE context='" + scontext + "' AND id='" + key + "' AND expires > " + timebuf;
567
568     m_log.debug("SQL: %s", q.c_str());
569
570     sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
571     if (!SQL_SUCCEEDED(sr)) {
572         freeSafeSQL(scontext, context);
573         freeSafeSQL(skey, key);
574         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
575         log_error(stmt, SQL_HANDLE_STMT);
576         throw IOException("ODBC StorageService search failed.");
577     }
578
579     SQLSMALLINT ver;
580     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,NULL);
581     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA) {
582         freeSafeSQL(scontext, context);
583         freeSafeSQL(skey, key);
584         return 0;
585     }
586
587     // Check version?
588     if (version > 0 && version != ver) {
589         freeSafeSQL(scontext, context);
590         freeSafeSQL(skey, key);
591         return -1;
592     }
593
594     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
595     stmt = getHSTMT(conn);
596
597     // Prepare and exectute update statement.
598     q = string("UPDATE ") + table + " SET ";
599
600     if (value)
601         q = q + "value=?, version=version+1";
602
603     if (expiration) {
604         timestampFromTime(expiration, timebuf);
605         if (value)
606             q += ',';
607         q = q + "expires = " + timebuf;
608     }
609
610     q = q + " WHERE context='" + scontext + "' AND id='" + key + "'";
611     freeSafeSQL(scontext, context);
612     freeSafeSQL(skey, key);
613
614     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
615     if (!SQL_SUCCEEDED(sr)) {
616         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
617         log_error(stmt, SQL_HANDLE_STMT);
618         throw IOException("ODBC StorageService failed to update record.");
619     }
620     m_log.debug("SQLPrepare succeded. SQL: %s", q.c_str());
621
622     SQLINTEGER b_ind = SQL_NTS;
623     if (value) {
624         if (strcmp(table, TEXT_TABLE)==0)
625             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
626         else
627             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
628         if (!SQL_SUCCEEDED(sr)) {
629             m_log.error("SQLBindParam failed (context = %s)", context);
630             log_error(stmt, SQL_HANDLE_STMT);
631             throw IOException("ODBC StorageService failed to update record.");
632         }
633         m_log.debug("SQLBindParam succeded (context = %s)", context);
634     }
635
636     sr=SQLExecute(stmt);
637     if (sr==SQL_NO_DATA)
638         return 0;   // went missing?
639     else if (!SQL_SUCCEEDED(sr)) {
640         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
641         log_error(stmt, SQL_HANDLE_STMT);
642         throw IOException("ODBC StorageService failed to update record.");
643     }
644
645     m_log.debug("SQLExecute of update succeeded");
646     return ver + 1;
647 }
648
649 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
650 {
651 #ifdef _DEBUG
652     xmltooling::NDC ndc("deleteRow");
653 #endif
654
655     // Get statement handle.
656     ODBCConn conn(getHDBC());
657     SQLHSTMT stmt = getHSTMT(conn);
658
659     // Prepare and execute delete statement.
660     char *scontext = makeSafeSQL(context);
661     char *skey = makeSafeSQL(key);
662     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
663     freeSafeSQL(scontext, context);
664     freeSafeSQL(skey, key);
665     m_log.debug("SQL: %s", q.c_str());
666
667     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
668      if (sr==SQL_NO_DATA)
669         return false;
670     else if (!SQL_SUCCEEDED(sr)) {
671         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
672         log_error(stmt, SQL_HANDLE_STMT);
673         throw IOException("ODBC StorageService failed to delete record.");
674     }
675
676     return true;
677 }
678
679
680 void ODBCStorageService::cleanup()
681 {
682 #ifdef _DEBUG
683     xmltooling::NDC ndc("cleanup");
684 #endif
685
686     Mutex* mutex = Mutex::create();
687
688     mutex->lock();
689
690     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
691
692     while (!shutdown) {
693         shutdown_wait->timedwait(mutex, m_cleanupInterval);
694         if (shutdown)
695             break;
696         try {
697             reap(NULL);
698         }
699         catch (exception& ex) {
700             m_log.error("cleanup thread swallowed exception: %s", ex.what());
701         }
702     }
703
704     m_log.info("cleanup thread exiting...");
705
706     mutex->unlock();
707     delete mutex;
708     Thread::exit(NULL);
709 }
710
711 void* ODBCStorageService::cleanup_fn(void* cache_p)
712 {
713   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
714
715 #ifndef WIN32
716   // First, let's block all signals
717   Thread::mask_all_signals();
718 #endif
719
720   // Now run the cleanup process.
721   cache->cleanup();
722   return NULL;
723 }
724
725 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
726 {
727 #ifdef _DEBUG
728     xmltooling::NDC ndc("updateContext");
729 #endif
730
731     // Get statement handle.
732     ODBCConn conn(getHDBC());
733     SQLHSTMT stmt = getHSTMT(conn);
734
735     char timebuf[32];
736     timestampFromTime(expiration, timebuf);
737
738     char nowbuf[32];
739     timestampFromTime(time(NULL), nowbuf);
740
741     char *scontext = makeSafeSQL(context);
742     string q("UPDATE ");
743     q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
744     freeSafeSQL(scontext, context);
745
746     m_log.debug("SQL: %s", q.c_str());
747
748     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
749     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
750         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
751         log_error(stmt, SQL_HANDLE_STMT);
752         throw IOException("ODBC StorageService failed to update context expiration.");
753     }
754 }
755
756 void ODBCStorageService::reap(const char *table, const char* context)
757 {
758 #ifdef _DEBUG
759     xmltooling::NDC ndc("reap");
760 #endif
761
762     // Get statement handle.
763     ODBCConn conn(getHDBC());
764     SQLHSTMT stmt = getHSTMT(conn);
765
766     // Prepare and execute delete statement.
767     char nowbuf[32];
768     timestampFromTime(time(NULL), nowbuf);
769     string q;
770     if (context) {
771         char *scontext = makeSafeSQL(context);
772         q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
773         freeSafeSQL(scontext, context);
774     }
775     else {
776         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
777     }
778     m_log.debug("SQL: %s", q.c_str());
779
780     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
781     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
782         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
783         log_error(stmt, SQL_HANDLE_STMT);
784         throw IOException("ODBC StorageService failed to purge expired records.");
785     }
786 }
787
788 void ODBCStorageService::deleteContext(const char *table, const char* context)
789 {
790 #ifdef _DEBUG
791     xmltooling::NDC ndc("deleteContext");
792 #endif
793
794     // Get statement handle.
795     ODBCConn conn(getHDBC());
796     SQLHSTMT stmt = getHSTMT(conn);
797
798     // Prepare and execute delete statement.
799     char *scontext = makeSafeSQL(context);
800     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
801     freeSafeSQL(scontext, context);
802     m_log.debug("SQL: %s", q.c_str());
803
804     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
805     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
806         m_log.error("error deleting context (t=%s, c=%s)", table, context);
807         log_error(stmt, SQL_HANDLE_STMT);
808         throw IOException("ODBC StorageService failed to delete context.");
809     }
810 }
811
812 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
813 {
814     // Register this SS type
815     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
816     return 0;
817 }
818
819 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
820 {
821     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
822 }