Boost changes to rest of project
[shibboleth/cpp-sp.git] / odbc-store / odbc-store.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * odbc-store.cpp
23  *
24  * Storage Service using ODBC.
25  */
26
27 #if defined (_MSC_VER) || defined(__BORLANDC__)
28 # include "config_win32.h"
29 #else
30 # include "config.h"
31 #endif
32
33 #ifdef WIN32
34 # define _CRT_NONSTDC_NO_DEPRECATE 1
35 # define _CRT_SECURE_NO_DEPRECATE 1
36 #endif
37
38 #ifdef WIN32
39 # define ODBCSTORE_EXPORTS __declspec(dllexport)
40 #else
41 # define ODBCSTORE_EXPORTS
42 #endif
43
44 #include <xmltooling/logging.h>
45 #include <xmltooling/unicode.h>
46 #include <xmltooling/XMLToolingConfig.h>
47 #include <xmltooling/util/NDC.h>
48 #include <xmltooling/util/StorageService.h>
49 #include <xmltooling/util/Threads.h>
50 #include <xmltooling/util/XMLHelper.h>
51 #include <xercesc/util/XMLUniDefs.hpp>
52
53 #include <sql.h>
54 #include <sqlext.h>
55
56 #include <boost/lexical_cast.hpp>
57 #include <boost/scoped_ptr.hpp>
58 #include <boost/algorithm/string.hpp>
59
60 using namespace xmltooling::logging;
61 using namespace xmltooling;
62 using namespace xercesc;
63 using namespace boost;
64 using namespace std;
65
66 #define PLUGIN_VER_MAJOR 1
67 #define PLUGIN_VER_MINOR 0
68
69 #define LONGDATA_BUFLEN 16384
70
71 #define COLSIZE_CONTEXT 255
72 #define COLSIZE_ID 255
73 #define COLSIZE_STRING_VALUE 255
74
75 #define STRING_TABLE "strings"
76 #define TEXT_TABLE "texts"
77
78 /* table definitions
79 CREATE TABLE version (
80     major int NOT nullptr,
81     minor int NOT nullptr
82     )
83
84 CREATE TABLE strings (
85     context varchar(255) not null,
86     id varchar(255) not null,
87     expires datetime not null,
88     version smallint not null,
89     value varchar(255) not null,
90     PRIMARY KEY (context, id)
91     )
92
93 CREATE TABLE texts (
94     context varchar(255) not null,
95     id varchar(255) not null,
96     expires datetime not null,
97     version smallint not null,
98     value text not null,
99     PRIMARY KEY (context, id)
100     )
101 */
102
103 namespace {
104     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
105     static const XMLCh isolationLevel[] =   UNICODE_LITERAL_14(i,s,o,l,a,t,i,o,n,L,e,v,e,l);
106     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
107     static const XMLCh RetryOnError[] =     UNICODE_LITERAL_12(R,e,t,r,y,O,n,E,r,r,o,r);
108     static const XMLCh contextSize[] =      UNICODE_LITERAL_11(c,o,n,t,e,x,t,S,i,z,e);
109     static const XMLCh keySize[] =          UNICODE_LITERAL_7(k,e,y,S,i,z,e);
110     static const XMLCh stringSize[] =       UNICODE_LITERAL_10(s,t,r,i,n,g,S,i,z,e);
111
112     // RAII for ODBC handles
113     struct ODBCConn {
114         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
115         ~ODBCConn() {
116             SQLRETURN sr = SQL_SUCCESS;
117             if (!autoCommit)
118                 sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0);
119             SQLDisconnect(handle);
120             SQLFreeHandle(SQL_HANDLE_DBC,handle);
121             if (!SQL_SUCCEEDED(sr))
122                 throw IOException("Failed to commit connection and return to auto-commit mode.");
123         }
124         operator SQLHDBC() {return handle;}
125         SQLHDBC handle;
126         bool autoCommit;
127     };
128
129     class ODBCStorageService : public StorageService
130     {
131     public:
132         ODBCStorageService(const DOMElement* e);
133         virtual ~ODBCStorageService();
134
135         const Capabilities& getCapabilities() const {
136             return m_caps;
137         }
138
139         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
140             return createRow(STRING_TABLE, context, key, value, expiration);
141         }
142         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
143             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
144         }
145         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
146             return updateRow(STRING_TABLE, context, key, value, expiration, version);
147         }
148         bool deleteString(const char* context, const char* key) {
149             return deleteRow(STRING_TABLE, context, key);
150         }
151
152         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
153             return createRow(TEXT_TABLE, context, key, value, expiration);
154         }
155         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
156             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
157         }
158         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
159             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
160         }
161         bool deleteText(const char* context, const char* key) {
162             return deleteRow(TEXT_TABLE, context, key);
163         }
164
165         void reap(const char* context) {
166             reap(STRING_TABLE, context);
167             reap(TEXT_TABLE, context);
168         }
169
170         void updateContext(const char* context, time_t expiration) {
171             updateContext(STRING_TABLE, context, expiration);
172             updateContext(TEXT_TABLE, context, expiration);
173         }
174
175         void deleteContext(const char* context) {
176             deleteContext(STRING_TABLE, context);
177             deleteContext(TEXT_TABLE, context);
178         }
179          
180
181     private:
182         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
183         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
184         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
185         bool deleteRow(const char *table, const char* context, const char* key);
186
187         void reap(const char* table, const char* context);
188         void updateContext(const char* table, const char* context, time_t expiration);
189         void deleteContext(const char* table, const char* context);
190
191         SQLHDBC getHDBC();
192         SQLHSTMT getHSTMT(SQLHDBC);
193         pair<SQLINTEGER,SQLINTEGER> getVersion(SQLHDBC);
194         pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=nullptr);
195
196         static void* cleanup_fn(void*); 
197         void cleanup();
198
199         Category& m_log;
200         Capabilities m_caps;
201         int m_cleanupInterval;
202         scoped_ptr<CondWait> shutdown_wait;
203         Thread* cleanup_thread;
204         bool shutdown;
205
206         SQLHENV m_henv;
207         string m_connstring;
208         long m_isolation;
209         vector<SQLINTEGER> m_retries;
210     };
211
212     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
213     {
214         return new ODBCStorageService(e);
215     }
216
217     // convert SQL timestamp to time_t 
218     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
219     {
220         time_t ret;
221         struct tm t;
222         t.tm_sec=expires.second;
223         t.tm_min=expires.minute;
224         t.tm_hour=expires.hour;
225         t.tm_mday=expires.day;
226         t.tm_mon=expires.month-1;
227         t.tm_year=expires.year-1900;
228         t.tm_isdst=0;
229 #if defined(HAVE_TIMEGM)
230         ret = timegm(&t);
231 #else
232         ret = mktime(&t) - timezone;
233 #endif
234         return (ret);
235     }
236
237     // conver time_t to SQL string
238     void timestampFromTime(time_t t, char* ret)
239     {
240 #ifdef HAVE_GMTIME_R
241         struct tm res;
242         struct tm* ptime=gmtime_r(&t,&res);
243 #else
244         struct tm* ptime=gmtime(&t);
245 #endif
246         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
247     }
248
249     class SQLString {
250         const char* m_src;
251         string m_copy;
252     public:
253         SQLString(const char* src) : m_src(src) {
254             if (strchr(src, '\'')) {
255                 m_copy = src;
256                 replace_all(m_copy, "'", "''");
257             }
258         }
259
260         operator const char*() const {
261             return tostr();
262         }
263
264         const char* tostr() const {
265             return m_copy.empty() ? m_src : m_copy.c_str();
266         }
267     };
268 };
269
270 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
271     m_caps(XMLHelper::getAttrInt(e, 255, contextSize), XMLHelper::getAttrInt(e, 255, keySize), XMLHelper::getAttrInt(e, 255, stringSize)),
272     m_cleanupInterval(XMLHelper::getAttrInt(e, 900, cleanupInterval)),
273     cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
274 {
275 #ifdef _DEBUG
276     xmltooling::NDC ndc("ODBCStorageService");
277 #endif
278     string iso(XMLHelper::getAttrString(e, "SERIALIZABLE", isolationLevel));
279     if (iso == "SERIALIZABLE")
280         m_isolation = SQL_TXN_SERIALIZABLE;
281     else if (iso == "REPEATABLE_READ")
282         m_isolation = SQL_TXN_REPEATABLE_READ;
283     else if (iso == "READ_COMMITTED")
284         m_isolation = SQL_TXN_READ_COMMITTED;
285     else if (iso == "READ_UNCOMMITTED")
286         m_isolation = SQL_TXN_READ_UNCOMMITTED;
287     else
288         throw XMLToolingException("Unknown transaction isolationLevel property.");
289
290     if (m_henv == SQL_NULL_HANDLE) {
291         // Enable connection pooling.
292         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
293
294         // Allocate the environment.
295         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
296             throw XMLToolingException("ODBC failed to initialize.");
297
298         // Specify ODBC 3.x
299         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
300
301         m_log.info("ODBC initialized");
302     }
303
304     // Grab connection string from the configuration.
305     e = e ? XMLHelper::getFirstChildElement(e, ConnectionString) : nullptr;
306     auto_ptr_char arg(e ? e->getTextContent() : nullptr);
307     if (!arg.get() || !*arg.get()) {
308         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
309         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
310     }
311     m_connstring = arg.get();
312
313     // Connect and check version.
314     ODBCConn conn(getHDBC());
315     pair<SQLINTEGER,SQLINTEGER> v = getVersion(conn);
316
317     // Make sure we've got the right version.
318     if (v.first != PLUGIN_VER_MAJOR) {
319         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
320         m_log.crit("unknown database version: %d.%d", v.first, v.second);
321         throw XMLToolingException("Unknown database version for ODBC StorageService.");
322     }
323
324     // Load any retry errors to check.
325     e = XMLHelper::getNextSiblingElement(e, RetryOnError);
326     while (e) {
327         if (e->hasChildNodes()) {
328             m_retries.push_back(XMLString::parseInt(e->getTextContent()));
329             m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
330         }
331         e = XMLHelper::getNextSiblingElement(e, RetryOnError);
332     }
333
334     // Initialize the cleanup thread
335     shutdown_wait.reset(CondWait::create());
336     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
337 }
338
339 ODBCStorageService::~ODBCStorageService()
340 {
341     shutdown = true;
342     shutdown_wait->signal();
343     cleanup_thread->join(nullptr);
344     if (m_henv != SQL_NULL_HANDLE)
345         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
346 }
347
348 pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
349 {
350     SQLSMALLINT  i = 0;
351     SQLINTEGER   native;
352     SQLCHAR      state[7];
353     SQLCHAR      text[256];
354     SQLSMALLINT  len;
355     SQLRETURN    ret;
356
357     pair<bool,bool> res = make_pair(false,false);
358     do {
359         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
360         if (SQL_SUCCEEDED(ret)) {
361             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
362             for (vector<SQLINTEGER>::const_iterator n = m_retries.begin(); !res.first && n != m_retries.end(); ++n)
363                 res.first = (*n == native);
364             if (checkfor && !strcmp(checkfor, (const char*)state))
365                 res.second = true;
366         }
367     } while(SQL_SUCCEEDED(ret));
368     return res;
369 }
370
371 SQLHDBC ODBCStorageService::getHDBC()
372 {
373 #ifdef _DEBUG
374     xmltooling::NDC ndc("getHDBC");
375 #endif
376
377     // Get a handle.
378     SQLHDBC handle;
379     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
380     if (!SQL_SUCCEEDED(sr)) {
381         m_log.error("failed to allocate connection handle");
382         log_error(m_henv, SQL_HANDLE_ENV);
383         throw IOException("ODBC StorageService failed to allocate a connection handle.");
384     }
385
386     sr=SQLDriverConnect(handle,nullptr,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),nullptr,0,nullptr,SQL_DRIVER_NOPROMPT);
387     if (!SQL_SUCCEEDED(sr)) {
388         m_log.error("failed to connect to database");
389         log_error(handle, SQL_HANDLE_DBC);
390         throw IOException("ODBC StorageService failed to connect to database.");
391     }
392
393     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, 0);
394     if (!SQL_SUCCEEDED(sr))
395         throw IOException("ODBC StorageService failed to set transaction isolation level.");
396
397     return handle;
398 }
399
400 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
401 {
402     SQLHSTMT hstmt;
403     SQLRETURN sr = SQLAllocHandle(SQL_HANDLE_STMT, conn, &hstmt);
404     if (!SQL_SUCCEEDED(sr)) {
405         m_log.error("failed to allocate statement handle");
406         log_error(conn, SQL_HANDLE_DBC);
407         throw IOException("ODBC StorageService failed to allocate a statement handle.");
408     }
409     return hstmt;
410 }
411
412 pair<SQLINTEGER,SQLINTEGER> ODBCStorageService::getVersion(SQLHDBC conn)
413 {
414     // Grab the version number from the database.
415     SQLHSTMT stmt = getHSTMT(conn);
416     
417     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
418     if (!SQL_SUCCEEDED(sr)) {
419         m_log.error("failed to read version from database");
420         log_error(stmt, SQL_HANDLE_STMT);
421         throw IOException("ODBC StorageService failed to read version from database.");
422     }
423
424     SQLINTEGER major;
425     SQLINTEGER minor;
426     SQLBindCol(stmt, 1, SQL_C_SLONG, &major, 0, nullptr);
427     SQLBindCol(stmt, 2, SQL_C_SLONG, &minor, 0, nullptr);
428
429     if ((sr = SQLFetch(stmt)) != SQL_NO_DATA)
430         return make_pair(major,minor);
431
432     m_log.error("no rows returned in version query");
433     throw IOException("ODBC StorageService failed to read version from database.");
434 }
435
436 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
437 {
438 #ifdef _DEBUG
439     xmltooling::NDC ndc("createRow");
440 #endif
441
442     char timebuf[32];
443     timestampFromTime(expiration, timebuf);
444
445     // Get statement handle.
446     ODBCConn conn(getHDBC());
447     SQLHSTMT stmt = getHSTMT(conn);
448
449     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
450
451     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
452     if (!SQL_SUCCEEDED(sr)) {
453         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
454         log_error(stmt, SQL_HANDLE_STMT);
455         throw IOException("ODBC StorageService failed to insert record.");
456     }
457     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
458
459     SQLLEN b_ind = SQL_NTS;
460     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
461     if (!SQL_SUCCEEDED(sr)) {
462         m_log.error("SQLBindParam failed (context = %s)", context);
463         log_error(stmt, SQL_HANDLE_STMT);
464         throw IOException("ODBC StorageService failed to insert record.");
465     }
466     m_log.debug("SQLBindParam succeeded (context = %s)", context);
467
468     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
469     if (!SQL_SUCCEEDED(sr)) {
470         m_log.error("SQLBindParam failed (key = %s)", key);
471         log_error(stmt, SQL_HANDLE_STMT);
472         throw IOException("ODBC StorageService failed to insert record.");
473     }
474     m_log.debug("SQLBindParam succeeded (key = %s)", key);
475
476     if (strcmp(table, TEXT_TABLE)==0)
477         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
478     else
479         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
480     if (!SQL_SUCCEEDED(sr)) {
481         m_log.error("SQLBindParam failed (value = %s)", value);
482         log_error(stmt, SQL_HANDLE_STMT);
483         throw IOException("ODBC StorageService failed to insert record.");
484     }
485     m_log.debug("SQLBindParam succeeded (value = %s)", value);
486     
487     int attempts = 3;
488     pair<bool,bool> logres;
489     do {
490         logres = make_pair(false,false);
491         attempts--;
492         sr = SQLExecute(stmt);
493         if (SQL_SUCCEEDED(sr)) {
494             m_log.debug("SQLExecute of insert succeeded");
495             return true;
496         }
497         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
498         logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
499         if (logres.second)
500             return false;   // supposedly integrity violation?
501     } while (attempts && logres.first);
502
503     throw IOException("ODBC StorageService failed to insert record.");
504 }
505
506 int ODBCStorageService::readRow(
507     const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
508     )
509 {
510 #ifdef _DEBUG
511     xmltooling::NDC ndc("readRow");
512 #endif
513
514     // Get statement handle.
515     ODBCConn conn(getHDBC());
516     SQLHSTMT stmt = getHSTMT(conn);
517
518     // Prepare and exectute select statement.
519     char timebuf[32];
520     timestampFromTime(time(nullptr), timebuf);
521     SQLString scontext(context);
522     SQLString skey(key);
523     string q("SELECT version");
524     if (pexpiration)
525         q += ",expires";
526     if (pvalue)
527         q = q + ",CASE version WHEN " + lexical_cast<string>(version) + " THEN null ELSE value END";
528     q = q + " FROM " + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
529     if (m_log.isDebugEnabled())
530         m_log.debug("SQL: %s", q.c_str());
531
532     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
533     if (!SQL_SUCCEEDED(sr)) {
534         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
535         log_error(stmt, SQL_HANDLE_STMT);
536         throw IOException("ODBC StorageService search failed.");
537     }
538
539     SQLSMALLINT ver;
540     SQL_TIMESTAMP_STRUCT expiration;
541
542     SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
543     if (pexpiration)
544         SQLBindCol(stmt, 2, SQL_C_TYPE_TIMESTAMP, &expiration, 0, nullptr);
545
546     if ((sr = SQLFetch(stmt)) == SQL_NO_DATA)
547         return 0;
548
549     if (pexpiration)
550         *pexpiration = timeFromTimestamp(expiration);
551
552     if (version == ver)
553         return version; // nothing's changed, so just echo back the version
554
555     if (pvalue) {
556         SQLLEN len;
557         SQLCHAR buf[LONGDATA_BUFLEN];
558         while ((sr = SQLGetData(stmt, (pexpiration ? 3 : 2), SQL_C_CHAR, buf, sizeof(buf), &len)) != SQL_NO_DATA) {
559             if (!SQL_SUCCEEDED(sr)) {
560                 m_log.error("error while reading text field from result set");
561                 log_error(stmt, SQL_HANDLE_STMT);
562                 throw IOException("ODBC StorageService search failed to read data from result set.");
563             }
564             pvalue->append((char*)buf);
565         }
566     }
567     
568     return ver;
569 }
570
571 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
572 {
573 #ifdef _DEBUG
574     xmltooling::NDC ndc("updateRow");
575 #endif
576
577     if (!value && !expiration)
578         throw IOException("ODBC StorageService given invalid update instructions.");
579
580     // Get statement handle. Disable auto-commit mode to wrap select + update.
581     ODBCConn conn(getHDBC());
582     SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, 0);
583     if (!SQL_SUCCEEDED(sr))
584         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
585     conn.autoCommit = false;
586     SQLHSTMT stmt = getHSTMT(conn);
587
588     // First, fetch the current version for later, which also ensures the record still exists.
589     char timebuf[32];
590     timestampFromTime(time(nullptr), timebuf);
591     SQLString scontext(context);
592     SQLString skey(key);
593     string q("SELECT version FROM ");
594     q = q + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
595
596     m_log.debug("SQL: %s", q.c_str());
597
598     sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
599     if (!SQL_SUCCEEDED(sr)) {
600         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
601         log_error(stmt, SQL_HANDLE_STMT);
602         throw IOException("ODBC StorageService search failed.");
603     }
604
605     SQLSMALLINT ver;
606     SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
607     if ((sr = SQLFetch(stmt)) == SQL_NO_DATA) {
608         return 0;
609     }
610
611     // Check version?
612     if (version > 0 && version != ver) {
613         return -1;
614     }
615
616     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
617     stmt = getHSTMT(conn);
618
619     // Prepare and exectute update statement.
620     q = string("UPDATE ") + table + " SET ";
621
622     if (value)
623         q = q + "value=?, version=version+1";
624
625     if (expiration) {
626         timestampFromTime(expiration, timebuf);
627         if (value)
628             q += ',';
629         q = q + "expires = " + timebuf;
630     }
631
632     q = q + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "'";
633
634     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
635     if (!SQL_SUCCEEDED(sr)) {
636         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
637         log_error(stmt, SQL_HANDLE_STMT);
638         throw IOException("ODBC StorageService failed to update record.");
639     }
640     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
641
642     SQLLEN b_ind = SQL_NTS;
643     if (value) {
644         if (strcmp(table, TEXT_TABLE)==0)
645             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
646         else
647             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
648         if (!SQL_SUCCEEDED(sr)) {
649             m_log.error("SQLBindParam failed (context = %s)", context);
650             log_error(stmt, SQL_HANDLE_STMT);
651             throw IOException("ODBC StorageService failed to update record.");
652         }
653         m_log.debug("SQLBindParam succeeded (context = %s)", context);
654     }
655
656     int attempts = 3;
657     pair<bool,bool> logres;
658     do {
659         logres = make_pair(false,false);
660         attempts--;
661         sr = SQLExecute(stmt);
662         if (sr == SQL_NO_DATA)
663             return 0;   // went missing?
664         else if (SQL_SUCCEEDED(sr)) {
665             m_log.debug("SQLExecute of update succeeded");
666             return ver + 1;
667         }
668
669         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
670         logres = log_error(stmt, SQL_HANDLE_STMT);
671     } while (attempts && logres.first);
672
673     throw IOException("ODBC StorageService failed to update record.");
674 }
675
676 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
677 {
678 #ifdef _DEBUG
679     xmltooling::NDC ndc("deleteRow");
680 #endif
681
682     // Get statement handle.
683     ODBCConn conn(getHDBC());
684     SQLHSTMT stmt = getHSTMT(conn);
685
686     // Prepare and execute delete statement.
687     SQLString scontext(context);
688     SQLString skey(key);
689     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "'";
690     m_log.debug("SQL: %s", q.c_str());
691
692     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
693      if (sr == SQL_NO_DATA)
694         return false;
695     else if (!SQL_SUCCEEDED(sr)) {
696         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
697         log_error(stmt, SQL_HANDLE_STMT);
698         throw IOException("ODBC StorageService failed to delete record.");
699     }
700
701     return true;
702 }
703
704
705 void ODBCStorageService::cleanup()
706 {
707 #ifdef _DEBUG
708     xmltooling::NDC ndc("cleanup");
709 #endif
710
711     scoped_ptr<Mutex> mutex(Mutex::create());
712
713     mutex->lock();
714
715     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
716
717     while (!shutdown) {
718         shutdown_wait->timedwait(mutex.get(), m_cleanupInterval);
719         if (shutdown)
720             break;
721         try {
722             reap(nullptr);
723         }
724         catch (std::exception& ex) {
725             m_log.error("cleanup thread swallowed exception: %s", ex.what());
726         }
727     }
728
729     m_log.info("cleanup thread exiting...");
730
731     mutex->unlock();
732     Thread::exit(nullptr);
733 }
734
735 void* ODBCStorageService::cleanup_fn(void* cache_p)
736 {
737   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
738
739 #ifndef WIN32
740   // First, let's block all signals
741   Thread::mask_all_signals();
742 #endif
743
744   // Now run the cleanup process.
745   cache->cleanup();
746   return nullptr;
747 }
748
749 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
750 {
751 #ifdef _DEBUG
752     xmltooling::NDC ndc("updateContext");
753 #endif
754
755     // Get statement handle.
756     ODBCConn conn(getHDBC());
757     SQLHSTMT stmt = getHSTMT(conn);
758
759     char timebuf[32];
760     timestampFromTime(expiration, timebuf);
761
762     char nowbuf[32];
763     timestampFromTime(time(nullptr), nowbuf);
764
765     SQLString scontext(context);
766     string q = string("UPDATE ") + table + " SET expires = " + timebuf + " WHERE context='" + scontext.tostr() + "' AND expires > " + nowbuf;
767
768     m_log.debug("SQL: %s", q.c_str());
769
770     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
771     if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
772         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
773         log_error(stmt, SQL_HANDLE_STMT);
774         throw IOException("ODBC StorageService failed to update context expiration.");
775     }
776 }
777
778 void ODBCStorageService::reap(const char *table, const char* context)
779 {
780 #ifdef _DEBUG
781     xmltooling::NDC ndc("reap");
782 #endif
783
784     // Get statement handle.
785     ODBCConn conn(getHDBC());
786     SQLHSTMT stmt = getHSTMT(conn);
787
788     // Prepare and execute delete statement.
789     char nowbuf[32];
790     timestampFromTime(time(nullptr), nowbuf);
791     string q;
792     if (context) {
793         SQLString scontext(context);
794         q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "' AND expires <= " + nowbuf;
795     }
796     else {
797         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
798     }
799     m_log.debug("SQL: %s", q.c_str());
800
801     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
802     if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
803         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
804         log_error(stmt, SQL_HANDLE_STMT);
805         throw IOException("ODBC StorageService failed to purge expired records.");
806     }
807 }
808
809 void ODBCStorageService::deleteContext(const char *table, const char* context)
810 {
811 #ifdef _DEBUG
812     xmltooling::NDC ndc("deleteContext");
813 #endif
814
815     // Get statement handle.
816     ODBCConn conn(getHDBC());
817     SQLHSTMT stmt = getHSTMT(conn);
818
819     // Prepare and execute delete statement.
820     SQLString scontext(context);
821     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "'";
822     m_log.debug("SQL: %s", q.c_str());
823
824     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
825     if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
826         m_log.error("error deleting context (t=%s, c=%s)", table, context);
827         log_error(stmt, SQL_HANDLE_STMT);
828         throw IOException("ODBC StorageService failed to delete context.");
829     }
830 }
831
832 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
833 {
834     // Register this SS type
835     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
836     return 0;
837 }
838
839 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
840 {
841     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
842 }