4357276144d80516488ba0fe362d2686c08380b8
[shibboleth/cpp-sp.git] / odbc-store / odbc-store.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * odbc-store.cpp
23  *
24  * Storage Service using ODBC.
25  */
26
27 #if defined (_MSC_VER) || defined(__BORLANDC__)
28 # include "config_win32.h"
29 #else
30 # include "config.h"
31 #endif
32
33 #ifdef WIN32
34 # define _CRT_NONSTDC_NO_DEPRECATE 1
35 # define _CRT_SECURE_NO_DEPRECATE 1
36 #endif
37
38 #ifdef WIN32
39 # define ODBCSTORE_EXPORTS __declspec(dllexport)
40 #else
41 # define ODBCSTORE_EXPORTS
42 #endif
43
44 #include <xmltooling/logging.h>
45 #include <xmltooling/unicode.h>
46 #include <xmltooling/XMLToolingConfig.h>
47 #include <xmltooling/util/NDC.h>
48 #include <xmltooling/util/StorageService.h>
49 #include <xmltooling/util/Threads.h>
50 #include <xmltooling/util/XMLHelper.h>
51 #include <xercesc/util/XMLUniDefs.hpp>
52
53 #include <sql.h>
54 #include <sqlext.h>
55
56 #include <boost/lexical_cast.hpp>
57 #include <boost/algorithm/string.hpp>
58
59 using namespace xmltooling::logging;
60 using namespace xmltooling;
61 using namespace xercesc;
62 using namespace boost;
63 using namespace std;
64
65 #define PLUGIN_VER_MAJOR 1
66 #define PLUGIN_VER_MINOR 0
67
68 #define LONGDATA_BUFLEN 16384
69
70 #define COLSIZE_CONTEXT 255
71 #define COLSIZE_ID 255
72 #define COLSIZE_STRING_VALUE 255
73
74 #define STRING_TABLE "strings"
75 #define TEXT_TABLE "texts"
76
77 /* table definitions
78 CREATE TABLE version (
79     major int NOT nullptr,
80     minor int NOT nullptr
81     )
82
83 CREATE TABLE strings (
84     context varchar(255) not null,
85     id varchar(255) not null,
86     expires datetime not null,
87     version smallint not null,
88     value varchar(255) not null,
89     PRIMARY KEY (context, id)
90     )
91
92 CREATE TABLE texts (
93     context varchar(255) not null,
94     id varchar(255) not null,
95     expires datetime not null,
96     version smallint not null,
97     value text not null,
98     PRIMARY KEY (context, id)
99     )
100 */
101
102 namespace {
103     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
104     static const XMLCh isolationLevel[] =   UNICODE_LITERAL_14(i,s,o,l,a,t,i,o,n,L,e,v,e,l);
105     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
106     static const XMLCh RetryOnError[] =     UNICODE_LITERAL_12(R,e,t,r,y,O,n,E,r,r,o,r);
107     static const XMLCh contextSize[] =      UNICODE_LITERAL_11(c,o,n,t,e,x,t,S,i,z,e);
108     static const XMLCh keySize[] =          UNICODE_LITERAL_7(k,e,y,S,i,z,e);
109     static const XMLCh stringSize[] =       UNICODE_LITERAL_10(s,t,r,i,n,g,S,i,z,e);
110
111     // RAII for ODBC handles
112     struct ODBCConn {
113         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
114         ~ODBCConn() {
115             if (handle != SQL_NULL_HDBC) {
116                 SQLRETURN sr = SQL_SUCCESS;
117                 if (!autoCommit)
118                     sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0);
119                 SQLDisconnect(handle);
120                 SQLFreeHandle(SQL_HANDLE_DBC, handle);
121                 if (!SQL_SUCCEEDED(sr))
122                     throw IOException("Failed to commit connection and return to auto-commit mode.");
123             }
124         }
125         operator SQLHDBC() {return handle;}
126         SQLHDBC handle;
127         bool autoCommit;
128     };
129
130     class ODBCStorageService : public StorageService
131     {
132     public:
133         ODBCStorageService(const DOMElement* e);
134         virtual ~ODBCStorageService();
135
136         const Capabilities& getCapabilities() const {
137             return m_caps;
138         }
139
140         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
141             return createRow(STRING_TABLE, context, key, value, expiration);
142         }
143         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
144             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version);
145         }
146         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
147             return updateRow(STRING_TABLE, context, key, value, expiration, version);
148         }
149         bool deleteString(const char* context, const char* key) {
150             return deleteRow(STRING_TABLE, context, key);
151         }
152
153         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
154             return createRow(TEXT_TABLE, context, key, value, expiration);
155         }
156         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
157             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version);
158         }
159         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
160             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
161         }
162         bool deleteText(const char* context, const char* key) {
163             return deleteRow(TEXT_TABLE, context, key);
164         }
165
166         void reap(const char* context) {
167             reap(STRING_TABLE, context);
168             reap(TEXT_TABLE, context);
169         }
170
171         void updateContext(const char* context, time_t expiration) {
172             updateContext(STRING_TABLE, context, expiration);
173             updateContext(TEXT_TABLE, context, expiration);
174         }
175
176         void deleteContext(const char* context) {
177             deleteContext(STRING_TABLE, context);
178             deleteContext(TEXT_TABLE, context);
179         }
180          
181
182     private:
183         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
184         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version);
185         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
186         bool deleteRow(const char *table, const char* context, const char* key);
187
188         void reap(const char* table, const char* context);
189         void updateContext(const char* table, const char* context, time_t expiration);
190         void deleteContext(const char* table, const char* context);
191
192         SQLHDBC getHDBC();
193         SQLHSTMT getHSTMT(SQLHDBC);
194         pair<SQLINTEGER,SQLINTEGER> getVersion(SQLHDBC);
195         pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=nullptr);
196
197         static void* cleanup_fn(void*); 
198         void cleanup();
199
200         Category& m_log;
201         Capabilities m_caps;
202         int m_cleanupInterval;
203         scoped_ptr<CondWait> shutdown_wait;
204         Thread* cleanup_thread;
205         bool shutdown;
206
207         SQLHENV m_henv;
208         string m_connstring;
209         long m_isolation;
210         vector<SQLINTEGER> m_retries;
211     };
212
213     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
214     {
215         return new ODBCStorageService(e);
216     }
217
218     // convert SQL timestamp to time_t 
219     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
220     {
221         time_t ret;
222         struct tm t;
223         t.tm_sec=expires.second;
224         t.tm_min=expires.minute;
225         t.tm_hour=expires.hour;
226         t.tm_mday=expires.day;
227         t.tm_mon=expires.month-1;
228         t.tm_year=expires.year-1900;
229         t.tm_isdst=0;
230 #if defined(HAVE_TIMEGM)
231         ret = timegm(&t);
232 #else
233         ret = mktime(&t) - timezone;
234 #endif
235         return (ret);
236     }
237
238     // conver time_t to SQL string
239     void timestampFromTime(time_t t, char* ret)
240     {
241 #ifdef HAVE_GMTIME_R
242         struct tm res;
243         struct tm* ptime=gmtime_r(&t,&res);
244 #else
245         struct tm* ptime=gmtime(&t);
246 #endif
247         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
248     }
249
250     class SQLString {
251         const char* m_src;
252         string m_copy;
253     public:
254         SQLString(const char* src) : m_src(src) {
255             if (strchr(src, '\'')) {
256                 m_copy = src;
257                 replace_all(m_copy, "'", "''");
258             }
259         }
260
261         operator const char*() const {
262             return tostr();
263         }
264
265         const char* tostr() const {
266             return m_copy.empty() ? m_src : m_copy.c_str();
267         }
268     };
269 };
270
271 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
272     m_caps(XMLHelper::getAttrInt(e, 255, contextSize), XMLHelper::getAttrInt(e, 255, keySize), XMLHelper::getAttrInt(e, 255, stringSize)),
273     m_cleanupInterval(XMLHelper::getAttrInt(e, 900, cleanupInterval)),
274     cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HENV), m_isolation(SQL_TXN_SERIALIZABLE)
275 {
276 #ifdef _DEBUG
277     xmltooling::NDC ndc("ODBCStorageService");
278 #endif
279     string iso(XMLHelper::getAttrString(e, "SERIALIZABLE", isolationLevel));
280     if (iso == "SERIALIZABLE")
281         m_isolation = SQL_TXN_SERIALIZABLE;
282     else if (iso == "REPEATABLE_READ")
283         m_isolation = SQL_TXN_REPEATABLE_READ;
284     else if (iso == "READ_COMMITTED")
285         m_isolation = SQL_TXN_READ_COMMITTED;
286     else if (iso == "READ_UNCOMMITTED")
287         m_isolation = SQL_TXN_READ_UNCOMMITTED;
288     else
289         throw XMLToolingException("Unknown transaction isolationLevel property.");
290
291     if (m_henv == SQL_NULL_HENV) {
292         // Enable connection pooling.
293         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
294
295         // Allocate the environment.
296         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
297             throw XMLToolingException("ODBC failed to initialize.");
298
299         // Specify ODBC 3.x
300         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
301
302         m_log.info("ODBC initialized");
303     }
304
305     // Grab connection string from the configuration.
306     e = e ? XMLHelper::getFirstChildElement(e, ConnectionString) : nullptr;
307     auto_ptr_char arg(e ? e->getTextContent() : nullptr);
308     if (!arg.get() || !*arg.get()) {
309         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
310         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
311     }
312     m_connstring = arg.get();
313
314     // Connect and check version.
315     ODBCConn conn(getHDBC());
316     pair<SQLINTEGER,SQLINTEGER> v = getVersion(conn);
317
318     // Make sure we've got the right version.
319     if (v.first != PLUGIN_VER_MAJOR) {
320         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
321         m_log.crit("unknown database version: %d.%d", v.first, v.second);
322         throw XMLToolingException("Unknown database version for ODBC StorageService.");
323     }
324
325     // Load any retry errors to check.
326     e = XMLHelper::getNextSiblingElement(e, RetryOnError);
327     while (e) {
328         if (e->hasChildNodes()) {
329             m_retries.push_back(XMLString::parseInt(e->getTextContent()));
330             m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
331         }
332         e = XMLHelper::getNextSiblingElement(e, RetryOnError);
333     }
334
335     // Initialize the cleanup thread
336     shutdown_wait.reset(CondWait::create());
337     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
338 }
339
340 ODBCStorageService::~ODBCStorageService()
341 {
342     shutdown = true;
343     shutdown_wait->signal();
344     cleanup_thread->join(nullptr);
345     if (m_henv != SQL_NULL_HANDLE)
346         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
347 }
348
349 pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
350 {
351     SQLSMALLINT  i = 0;
352     SQLINTEGER   native;
353     SQLCHAR      state[7];
354     SQLCHAR      text[256];
355     SQLSMALLINT  len;
356     SQLRETURN    ret;
357
358     pair<bool,bool> res = make_pair(false,false);
359     do {
360         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
361         if (SQL_SUCCEEDED(ret)) {
362             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
363             for (vector<SQLINTEGER>::const_iterator n = m_retries.begin(); !res.first && n != m_retries.end(); ++n)
364                 res.first = (*n == native);
365             if (checkfor && !strcmp(checkfor, (const char*)state))
366                 res.second = true;
367         }
368     } while(SQL_SUCCEEDED(ret));
369     return res;
370 }
371
372 SQLHDBC ODBCStorageService::getHDBC()
373 {
374 #ifdef _DEBUG
375     xmltooling::NDC ndc("getHDBC");
376 #endif
377
378     // Get a handle.
379     SQLHDBC handle = SQL_NULL_HDBC;
380     SQLRETURN sr = SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
381     if (!SQL_SUCCEEDED(sr) || handle == SQL_NULL_HDBC) {
382         m_log.error("failed to allocate connection handle");
383         log_error(m_henv, SQL_HANDLE_ENV);
384         throw IOException("ODBC StorageService failed to allocate a connection handle.");
385     }
386
387     sr = SQLDriverConnect(handle,nullptr,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),nullptr,0,nullptr,SQL_DRIVER_NOPROMPT);
388     if (!SQL_SUCCEEDED(sr)) {
389         m_log.error("failed to connect to database");
390         log_error(handle, SQL_HANDLE_DBC);
391         SQLFreeHandle(SQL_HANDLE_DBC, handle);
392         throw IOException("ODBC StorageService failed to connect to database.");
393     }
394
395     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, 0);
396     if (!SQL_SUCCEEDED(sr)) {
397         SQLDisconnect(handle);
398         SQLFreeHandle(SQL_HANDLE_DBC, handle);
399         throw IOException("ODBC StorageService failed to set transaction isolation level.");
400     }
401
402     return handle;
403 }
404
405 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
406 {
407     SQLHSTMT hstmt = SQL_NULL_HSTMT;
408     SQLRETURN sr = SQLAllocHandle(SQL_HANDLE_STMT, conn, &hstmt);
409     if (!SQL_SUCCEEDED(sr) || hstmt == SQL_NULL_HSTMT) {
410         m_log.error("failed to allocate statement handle");
411         log_error(conn, SQL_HANDLE_DBC);
412         throw IOException("ODBC StorageService failed to allocate a statement handle.");
413     }
414     return hstmt;
415 }
416
417 pair<SQLINTEGER,SQLINTEGER> ODBCStorageService::getVersion(SQLHDBC conn)
418 {
419     // Grab the version number from the database.
420     SQLHSTMT stmt = getHSTMT(conn);
421     
422     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
423     if (!SQL_SUCCEEDED(sr)) {
424         m_log.error("failed to read version from database");
425         log_error(stmt, SQL_HANDLE_STMT);
426         throw IOException("ODBC StorageService failed to read version from database.");
427     }
428
429     SQLINTEGER major;
430     SQLINTEGER minor;
431     SQLBindCol(stmt, 1, SQL_C_SLONG, &major, 0, nullptr);
432     SQLBindCol(stmt, 2, SQL_C_SLONG, &minor, 0, nullptr);
433
434     if ((sr = SQLFetch(stmt)) != SQL_NO_DATA)
435         return make_pair(major,minor);
436
437     m_log.error("no rows returned in version query");
438     throw IOException("ODBC StorageService failed to read version from database.");
439 }
440
441 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
442 {
443 #ifdef _DEBUG
444     xmltooling::NDC ndc("createRow");
445 #endif
446
447     char timebuf[32];
448     timestampFromTime(expiration, timebuf);
449
450     // Get statement handle.
451     ODBCConn conn(getHDBC());
452     SQLHSTMT stmt = getHSTMT(conn);
453
454     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
455
456     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
457     if (!SQL_SUCCEEDED(sr)) {
458         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
459         log_error(stmt, SQL_HANDLE_STMT);
460         throw IOException("ODBC StorageService failed to insert record.");
461     }
462     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
463
464     SQLLEN b_ind = SQL_NTS;
465     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
466     if (!SQL_SUCCEEDED(sr)) {
467         m_log.error("SQLBindParam failed (context = %s)", context);
468         log_error(stmt, SQL_HANDLE_STMT);
469         throw IOException("ODBC StorageService failed to insert record.");
470     }
471     m_log.debug("SQLBindParam succeeded (context = %s)", context);
472
473     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
474     if (!SQL_SUCCEEDED(sr)) {
475         m_log.error("SQLBindParam failed (key = %s)", key);
476         log_error(stmt, SQL_HANDLE_STMT);
477         throw IOException("ODBC StorageService failed to insert record.");
478     }
479     m_log.debug("SQLBindParam succeeded (key = %s)", key);
480
481     if (strcmp(table, TEXT_TABLE)==0)
482         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
483     else
484         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
485     if (!SQL_SUCCEEDED(sr)) {
486         m_log.error("SQLBindParam failed (value = %s)", value);
487         log_error(stmt, SQL_HANDLE_STMT);
488         throw IOException("ODBC StorageService failed to insert record.");
489     }
490     m_log.debug("SQLBindParam succeeded (value = %s)", value);
491     
492     int attempts = 3;
493     pair<bool,bool> logres;
494     do {
495         logres = make_pair(false,false);
496         attempts--;
497         sr = SQLExecute(stmt);
498         if (SQL_SUCCEEDED(sr)) {
499             m_log.debug("SQLExecute of insert succeeded");
500             return true;
501         }
502         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
503         logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
504         if (logres.second)
505             return false;   // supposedly integrity violation?
506     } while (attempts && logres.first);
507
508     throw IOException("ODBC StorageService failed to insert record.");
509 }
510
511 int ODBCStorageService::readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version)
512 {
513 #ifdef _DEBUG
514     xmltooling::NDC ndc("readRow");
515 #endif
516
517     // Get statement handle.
518     ODBCConn conn(getHDBC());
519     SQLHSTMT stmt = getHSTMT(conn);
520
521     // Prepare and exectute select statement.
522     char timebuf[32];
523     timestampFromTime(time(nullptr), timebuf);
524     SQLString scontext(context);
525     SQLString skey(key);
526     string q("SELECT version");
527     if (pexpiration)
528         q += ",expires";
529     if (pvalue) {
530         pvalue->erase();
531         q = q + ",CASE version WHEN " + lexical_cast<string>(version) + " THEN null ELSE value END";
532     }
533     q = q + " FROM " + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
534     if (m_log.isDebugEnabled())
535         m_log.debug("SQL: %s", q.c_str());
536
537     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
538     if (!SQL_SUCCEEDED(sr)) {
539         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
540         log_error(stmt, SQL_HANDLE_STMT);
541         throw IOException("ODBC StorageService search failed.");
542     }
543
544     SQLSMALLINT ver;
545     SQL_TIMESTAMP_STRUCT expiration;
546
547     SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
548     if (pexpiration)
549         SQLBindCol(stmt, 2, SQL_C_TYPE_TIMESTAMP, &expiration, 0, nullptr);
550
551     if ((sr = SQLFetch(stmt)) == SQL_NO_DATA)
552         return 0;
553
554     if (pexpiration)
555         *pexpiration = timeFromTimestamp(expiration);
556
557     if (version == ver)
558         return version; // nothing's changed, so just echo back the version
559
560     if (pvalue) {
561         SQLLEN len;
562         SQLCHAR buf[LONGDATA_BUFLEN];
563         while ((sr = SQLGetData(stmt, (pexpiration ? 3 : 2), SQL_C_CHAR, buf, sizeof(buf), &len)) != SQL_NO_DATA) {
564             if (!SQL_SUCCEEDED(sr)) {
565                 m_log.error("error while reading text field from result set");
566                 log_error(stmt, SQL_HANDLE_STMT);
567                 throw IOException("ODBC StorageService search failed to read data from result set.");
568             }
569             pvalue->append((char*)buf);
570         }
571     }
572     
573     return ver;
574 }
575
576 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
577 {
578 #ifdef _DEBUG
579     xmltooling::NDC ndc("updateRow");
580 #endif
581
582     if (!value && !expiration)
583         throw IOException("ODBC StorageService given invalid update instructions.");
584
585     // Get statement handle. Disable auto-commit mode to wrap select + update.
586     ODBCConn conn(getHDBC());
587     SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, 0);
588     if (!SQL_SUCCEEDED(sr))
589         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
590     conn.autoCommit = false;
591     SQLHSTMT stmt = getHSTMT(conn);
592
593     // First, fetch the current version for later, which also ensures the record still exists.
594     char timebuf[32];
595     timestampFromTime(time(nullptr), timebuf);
596     SQLString scontext(context);
597     SQLString skey(key);
598     string q("SELECT version FROM ");
599     q = q + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "' AND expires > " + timebuf;
600
601     m_log.debug("SQL: %s", q.c_str());
602
603     sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
604     if (!SQL_SUCCEEDED(sr)) {
605         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
606         log_error(stmt, SQL_HANDLE_STMT);
607         throw IOException("ODBC StorageService search failed.");
608     }
609
610     SQLSMALLINT ver;
611     SQLBindCol(stmt, 1, SQL_C_SSHORT, &ver, 0, nullptr);
612     if ((sr = SQLFetch(stmt)) == SQL_NO_DATA) {
613         return 0;
614     }
615
616     // Check version?
617     if (version > 0 && version != ver) {
618         return -1;
619     }
620
621     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
622     stmt = getHSTMT(conn);
623
624     // Prepare and exectute update statement.
625     q = string("UPDATE ") + table + " SET ";
626
627     if (value)
628         q = q + "value=?, version=version+1";
629
630     if (expiration) {
631         timestampFromTime(expiration, timebuf);
632         if (value)
633             q += ',';
634         q = q + "expires = " + timebuf;
635     }
636
637     q = q + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "'";
638
639     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
640     if (!SQL_SUCCEEDED(sr)) {
641         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
642         log_error(stmt, SQL_HANDLE_STMT);
643         throw IOException("ODBC StorageService failed to update record.");
644     }
645     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
646
647     SQLLEN b_ind = SQL_NTS;
648     if (value) {
649         if (strcmp(table, TEXT_TABLE)==0)
650             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
651         else
652             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
653         if (!SQL_SUCCEEDED(sr)) {
654             m_log.error("SQLBindParam failed (context = %s)", context);
655             log_error(stmt, SQL_HANDLE_STMT);
656             throw IOException("ODBC StorageService failed to update record.");
657         }
658         m_log.debug("SQLBindParam succeeded (context = %s)", context);
659     }
660
661     int attempts = 3;
662     pair<bool,bool> logres;
663     do {
664         logres = make_pair(false,false);
665         attempts--;
666         sr = SQLExecute(stmt);
667         if (sr == SQL_NO_DATA)
668             return 0;   // went missing?
669         else if (SQL_SUCCEEDED(sr)) {
670             m_log.debug("SQLExecute of update succeeded");
671             return ver + 1;
672         }
673
674         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
675         logres = log_error(stmt, SQL_HANDLE_STMT);
676     } while (attempts && logres.first);
677
678     throw IOException("ODBC StorageService failed to update record.");
679 }
680
681 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
682 {
683 #ifdef _DEBUG
684     xmltooling::NDC ndc("deleteRow");
685 #endif
686
687     // Get statement handle.
688     ODBCConn conn(getHDBC());
689     SQLHSTMT stmt = getHSTMT(conn);
690
691     // Prepare and execute delete statement.
692     SQLString scontext(context);
693     SQLString skey(key);
694     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "' AND id='" + skey.tostr() + "'";
695     m_log.debug("SQL: %s", q.c_str());
696
697     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
698      if (sr == SQL_NO_DATA)
699         return false;
700     else if (!SQL_SUCCEEDED(sr)) {
701         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
702         log_error(stmt, SQL_HANDLE_STMT);
703         throw IOException("ODBC StorageService failed to delete record.");
704     }
705
706     return true;
707 }
708
709
710 void ODBCStorageService::cleanup()
711 {
712 #ifdef _DEBUG
713     xmltooling::NDC ndc("cleanup");
714 #endif
715
716     scoped_ptr<Mutex> mutex(Mutex::create());
717
718     mutex->lock();
719
720     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
721
722     while (!shutdown) {
723         shutdown_wait->timedwait(mutex.get(), m_cleanupInterval);
724         if (shutdown)
725             break;
726         try {
727             reap(nullptr);
728         }
729         catch (std::exception& ex) {
730             m_log.error("cleanup thread swallowed exception: %s", ex.what());
731         }
732     }
733
734     m_log.info("cleanup thread exiting...");
735
736     mutex->unlock();
737     Thread::exit(nullptr);
738 }
739
740 void* ODBCStorageService::cleanup_fn(void* cache_p)
741 {
742   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
743
744 #ifndef WIN32
745   // First, let's block all signals
746   Thread::mask_all_signals();
747 #endif
748
749   // Now run the cleanup process.
750   cache->cleanup();
751   return nullptr;
752 }
753
754 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
755 {
756 #ifdef _DEBUG
757     xmltooling::NDC ndc("updateContext");
758 #endif
759
760     // Get statement handle.
761     ODBCConn conn(getHDBC());
762     SQLHSTMT stmt = getHSTMT(conn);
763
764     char timebuf[32];
765     timestampFromTime(expiration, timebuf);
766
767     char nowbuf[32];
768     timestampFromTime(time(nullptr), nowbuf);
769
770     SQLString scontext(context);
771     string q = string("UPDATE ") + table + " SET expires = " + timebuf + " WHERE context='" + scontext.tostr() + "' AND expires > " + nowbuf;
772
773     m_log.debug("SQL: %s", q.c_str());
774
775     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
776     if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
777         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
778         log_error(stmt, SQL_HANDLE_STMT);
779         throw IOException("ODBC StorageService failed to update context expiration.");
780     }
781 }
782
783 void ODBCStorageService::reap(const char *table, const char* context)
784 {
785 #ifdef _DEBUG
786     xmltooling::NDC ndc("reap");
787 #endif
788
789     // Get statement handle.
790     ODBCConn conn(getHDBC());
791     SQLHSTMT stmt = getHSTMT(conn);
792
793     // Prepare and execute delete statement.
794     char nowbuf[32];
795     timestampFromTime(time(nullptr), nowbuf);
796     string q;
797     if (context) {
798         SQLString scontext(context);
799         q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "' AND expires <= " + nowbuf;
800     }
801     else {
802         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
803     }
804     m_log.debug("SQL: %s", q.c_str());
805
806     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
807     if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
808         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
809         log_error(stmt, SQL_HANDLE_STMT);
810         throw IOException("ODBC StorageService failed to purge expired records.");
811     }
812 }
813
814 void ODBCStorageService::deleteContext(const char *table, const char* context)
815 {
816 #ifdef _DEBUG
817     xmltooling::NDC ndc("deleteContext");
818 #endif
819
820     // Get statement handle.
821     ODBCConn conn(getHDBC());
822     SQLHSTMT stmt = getHSTMT(conn);
823
824     // Prepare and execute delete statement.
825     SQLString scontext(context);
826     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext.tostr() + "'";
827     m_log.debug("SQL: %s", q.c_str());
828
829     SQLRETURN sr = SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
830     if ((sr != SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
831         m_log.error("error deleting context (t=%s, c=%s)", table, context);
832         log_error(stmt, SQL_HANDLE_STMT);
833         throw IOException("ODBC StorageService failed to delete context.");
834     }
835 }
836
837 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
838 {
839     // Register this SS type
840     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
841     return 0;
842 }
843
844 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
845 {
846     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
847 }