Handle BinaryAttribute (include the header, put it in correct folder)
[shibboleth/cpp-sp.git] / odbc-store / odbc-store.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * odbc-store.cpp
23  *
24  * Storage Service using ODBC.
25  */
26
27 #if defined (_MSC_VER) || defined(__BORLANDC__)
28 # include "config_win32.h"
29 #else
30 # include "config.h"
31 #endif
32
33 #ifdef WIN32
34 # define _CRT_NONSTDC_NO_DEPRECATE 1
35 # define _CRT_SECURE_NO_DEPRECATE 1
36 #endif
37
38 #ifdef WIN32
39 # define ODBCSTORE_EXPORTS __declspec(dllexport)
40 #else
41 # define ODBCSTORE_EXPORTS
42 #endif
43
44 #include <xercesc/util/XMLUniDefs.hpp>
45 #include <xmltooling/logging.h>
46 #include <xmltooling/unicode.h>
47 #include <xmltooling/XMLToolingConfig.h>
48 #include <xmltooling/util/NDC.h>
49 #include <xmltooling/util/StorageService.h>
50 #include <xmltooling/util/Threads.h>
51 #include <xmltooling/util/XMLHelper.h>
52
53 #include <sql.h>
54 #include <sqlext.h>
55
56 using namespace xmltooling::logging;
57 using namespace xmltooling;
58 using namespace xercesc;
59 using namespace std;
60
61 #define PLUGIN_VER_MAJOR 1
62 #define PLUGIN_VER_MINOR 0
63
64 #define LONGDATA_BUFLEN 16384
65
66 #define COLSIZE_CONTEXT 255
67 #define COLSIZE_ID 255
68 #define COLSIZE_STRING_VALUE 255
69
70 #define STRING_TABLE "strings"
71 #define TEXT_TABLE "texts"
72
73 /* table definitions
74 CREATE TABLE version (
75     major int NOT nullptr,
76     minor int NOT nullptr
77     )
78
79 CREATE TABLE strings (
80     context varchar(255) not null,
81     id varchar(255) not null,
82     expires datetime not null,
83     version smallint not null,
84     value varchar(255) not null,
85     PRIMARY KEY (context, id)
86     )
87
88 CREATE TABLE texts (
89     context varchar(255) not null,
90     id varchar(255) not null,
91     expires datetime not null,
92     version smallint not null,
93     value text not null,
94     PRIMARY KEY (context, id)
95     )
96 */
97
98 namespace {
99     static const XMLCh cleanupInterval[] =  UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
100     static const XMLCh isolationLevel[] =   UNICODE_LITERAL_14(i,s,o,l,a,t,i,o,n,L,e,v,e,l);
101     static const XMLCh ConnectionString[] = UNICODE_LITERAL_16(C,o,n,n,e,c,t,i,o,n,S,t,r,i,n,g);
102     static const XMLCh RetryOnError[] =     UNICODE_LITERAL_12(R,e,t,r,y,O,n,E,r,r,o,r);
103
104     // RAII for ODBC handles
105     struct ODBCConn {
106         ODBCConn(SQLHDBC conn) : handle(conn), autoCommit(true) {}
107         ~ODBCConn() {
108             SQLRETURN sr = SQL_SUCCESS;
109             if (!autoCommit)
110                 sr = SQLSetConnectAttr(handle, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER)SQL_AUTOCOMMIT_ON, 0);
111             SQLDisconnect(handle);
112             SQLFreeHandle(SQL_HANDLE_DBC,handle);
113             if (!SQL_SUCCEEDED(sr))
114                 throw IOException("Failed to commit connection and return to auto-commit mode.");
115         }
116         operator SQLHDBC() {return handle;}
117         SQLHDBC handle;
118         bool autoCommit;
119     };
120
121     class ODBCStorageService : public StorageService
122     {
123     public:
124         ODBCStorageService(const DOMElement* e);
125         virtual ~ODBCStorageService();
126
127         bool createString(const char* context, const char* key, const char* value, time_t expiration) {
128             return createRow(STRING_TABLE, context, key, value, expiration);
129         }
130         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
131             return readRow(STRING_TABLE, context, key, pvalue, pexpiration, version, false);
132         }
133         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
134             return updateRow(STRING_TABLE, context, key, value, expiration, version);
135         }
136         bool deleteString(const char* context, const char* key) {
137             return deleteRow(STRING_TABLE, context, key);
138         }
139
140         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
141             return createRow(TEXT_TABLE, context, key, value, expiration);
142         }
143         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
144             return readRow(TEXT_TABLE, context, key, pvalue, pexpiration, version, true);
145         }
146         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
147             return updateRow(TEXT_TABLE, context, key, value, expiration, version);
148         }
149         bool deleteText(const char* context, const char* key) {
150             return deleteRow(TEXT_TABLE, context, key);
151         }
152
153         void reap(const char* context) {
154             reap(STRING_TABLE, context);
155             reap(TEXT_TABLE, context);
156         }
157
158         void updateContext(const char* context, time_t expiration) {
159             updateContext(STRING_TABLE, context, expiration);
160             updateContext(TEXT_TABLE, context, expiration);
161         }
162
163         void deleteContext(const char* context) {
164             deleteContext(STRING_TABLE, context);
165             deleteContext(TEXT_TABLE, context);
166         }
167          
168
169     private:
170         bool createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration);
171         int readRow(const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text);
172         int updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version);
173         bool deleteRow(const char *table, const char* context, const char* key);
174
175         void reap(const char* table, const char* context);
176         void updateContext(const char* table, const char* context, time_t expiration);
177         void deleteContext(const char* table, const char* context);
178
179         SQLHDBC getHDBC();
180         SQLHSTMT getHSTMT(SQLHDBC);
181         pair<int,int> getVersion(SQLHDBC);
182         pair<bool,bool> log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor=nullptr);
183
184         static void* cleanup_fn(void*); 
185         void cleanup();
186
187         Category& m_log;
188         int m_cleanupInterval;
189         CondWait* shutdown_wait;
190         Thread* cleanup_thread;
191         bool shutdown;
192
193         SQLHENV m_henv;
194         string m_connstring;
195         long m_isolation;
196         vector<SQLINTEGER> m_retries;
197     };
198
199     StorageService* ODBCStorageServiceFactory(const DOMElement* const & e)
200     {
201         return new ODBCStorageService(e);
202     }
203
204     // convert SQL timestamp to time_t 
205     time_t timeFromTimestamp(SQL_TIMESTAMP_STRUCT expires)
206     {
207         time_t ret;
208         struct tm t;
209         t.tm_sec=expires.second;
210         t.tm_min=expires.minute;
211         t.tm_hour=expires.hour;
212         t.tm_mday=expires.day;
213         t.tm_mon=expires.month-1;
214         t.tm_year=expires.year-1900;
215         t.tm_isdst=0;
216 #if defined(HAVE_TIMEGM)
217         ret = timegm(&t);
218 #else
219         ret = mktime(&t) - timezone;
220 #endif
221         return (ret);
222     }
223
224     // conver time_t to SQL string
225     void timestampFromTime(time_t t, char* ret)
226     {
227 #ifdef HAVE_GMTIME_R
228         struct tm res;
229         struct tm* ptime=gmtime_r(&t,&res);
230 #else
231         struct tm* ptime=gmtime(&t);
232 #endif
233         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
234     }
235
236     // make a string safe for SQL command
237     // result to be free'd only if it isn't the input
238     static char *makeSafeSQL(const char *src)
239     {
240        int ns = 0;
241        int nc = 0;
242        char *s;
243     
244        // see if any conversion needed
245        for (s=(char*)src; *s; nc++,s++) if (*s=='\'') ns++;
246        if (ns==0) return ((char*)src);
247     
248        char *safe = new char[(nc+2*ns+1)];
249        for (s=safe; *src; src++) {
250            if (*src=='\'') *s++ = '\'';
251            *s++ = (char)*src;
252        }
253        *s = '\0';
254        return (safe);
255     }
256
257     void freeSafeSQL(char *safe, const char *src)
258     {
259         if (safe!=src)
260             delete[](safe);
261     }
262 };
263
264 ODBCStorageService::ODBCStorageService(const DOMElement* e) : m_log(Category::getInstance("XMLTooling.StorageService")),
265    m_cleanupInterval(900), shutdown_wait(nullptr), cleanup_thread(nullptr), shutdown(false), m_henv(SQL_NULL_HANDLE), m_isolation(SQL_TXN_SERIALIZABLE)
266 {
267 #ifdef _DEBUG
268     xmltooling::NDC ndc("ODBCStorageService");
269 #endif
270
271     const XMLCh* tag=e ? e->getAttributeNS(nullptr,cleanupInterval) : nullptr;
272     if (tag && *tag)
273         m_cleanupInterval = XMLString::parseInt(tag);
274     if (!m_cleanupInterval)
275         m_cleanupInterval = 900;
276
277     auto_ptr_char iso(e ? e->getAttributeNS(nullptr,isolationLevel) : nullptr);
278     if (iso.get() && *iso.get()) {
279         if (!strcmp(iso.get(),"SERIALIZABLE"))
280             m_isolation = SQL_TXN_SERIALIZABLE;
281         else if (!strcmp(iso.get(),"REPEATABLE_READ"))
282             m_isolation = SQL_TXN_REPEATABLE_READ;
283         else if (!strcmp(iso.get(),"READ_COMMITTED"))
284             m_isolation = SQL_TXN_READ_COMMITTED;
285         else if (!strcmp(iso.get(),"READ_UNCOMMITTED"))
286             m_isolation = SQL_TXN_READ_UNCOMMITTED;
287         else
288             throw XMLToolingException("Unknown transaction isolationLevel property.");
289     }
290
291     if (m_henv == SQL_NULL_HANDLE) {
292         // Enable connection pooling.
293         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
294
295         // Allocate the environment.
296         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
297             throw XMLToolingException("ODBC failed to initialize.");
298
299         // Specify ODBC 3.x
300         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
301
302         m_log.info("ODBC initialized");
303     }
304
305     // Grab connection string from the configuration.
306     e = e ? XMLHelper::getFirstChildElement(e,ConnectionString) : nullptr;
307     if (!e || !e->hasChildNodes()) {
308         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
309         throw XMLToolingException("ODBC StorageService requires ConnectionString element in configuration.");
310     }
311     auto_ptr_char arg(e->getFirstChild()->getNodeValue());
312     m_connstring=arg.get();
313
314     // Connect and check version.
315     ODBCConn conn(getHDBC());
316     pair<int,int> v=getVersion(conn);
317
318     // Make sure we've got the right version.
319     if (v.first != PLUGIN_VER_MAJOR) {
320         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
321         m_log.crit("unknown database version: %d.%d", v.first, v.second);
322         throw XMLToolingException("Unknown database version for ODBC StorageService.");
323     }
324
325     // Load any retry errors to check.
326     e = XMLHelper::getNextSiblingElement(e,RetryOnError);
327     while (e) {
328         if (e->hasChildNodes()) {
329             m_retries.push_back(XMLString::parseInt(e->getFirstChild()->getNodeValue()));
330             m_log.info("will retry operations when native ODBC error (%ld) is returned", m_retries.back());
331         }
332         e = XMLHelper::getNextSiblingElement(e,RetryOnError);
333     }
334
335     // Initialize the cleanup thread
336     shutdown_wait = CondWait::create();
337     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
338 }
339
340 ODBCStorageService::~ODBCStorageService()
341 {
342     shutdown = true;
343     shutdown_wait->signal();
344     cleanup_thread->join(nullptr);
345     delete shutdown_wait;
346     if (m_henv != SQL_NULL_HANDLE)
347         SQLFreeHandle(SQL_HANDLE_ENV, m_henv);
348 }
349
350 pair<bool,bool> ODBCStorageService::log_error(SQLHANDLE handle, SQLSMALLINT htype, const char* checkfor)
351 {
352     SQLSMALLINT  i = 0;
353     SQLINTEGER   native;
354     SQLCHAR      state[7];
355     SQLCHAR      text[256];
356     SQLSMALLINT  len;
357     SQLRETURN    ret;
358
359     pair<bool,bool> res = make_pair(false,false);
360     do {
361         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
362         if (SQL_SUCCEEDED(ret)) {
363             m_log.error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
364             for (vector<SQLINTEGER>::const_iterator n = m_retries.begin(); !res.first && n != m_retries.end(); ++n)
365                 res.first = (*n == native);
366             if (checkfor && !strcmp(checkfor, (const char*)state))
367                 res.second = true;
368         }
369     } while(SQL_SUCCEEDED(ret));
370     return res;
371 }
372
373 SQLHDBC ODBCStorageService::getHDBC()
374 {
375 #ifdef _DEBUG
376     xmltooling::NDC ndc("getHDBC");
377 #endif
378
379     // Get a handle.
380     SQLHDBC handle;
381     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
382     if (!SQL_SUCCEEDED(sr)) {
383         m_log.error("failed to allocate connection handle");
384         log_error(m_henv, SQL_HANDLE_ENV);
385         throw IOException("ODBC StorageService failed to allocate a connection handle.");
386     }
387
388     sr=SQLDriverConnect(handle,nullptr,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),nullptr,0,nullptr,SQL_DRIVER_NOPROMPT);
389     if (!SQL_SUCCEEDED(sr)) {
390         m_log.error("failed to connect to database");
391         log_error(handle, SQL_HANDLE_DBC);
392         throw IOException("ODBC StorageService failed to connect to database.");
393     }
394
395     sr = SQLSetConnectAttr(handle, SQL_ATTR_TXN_ISOLATION, (SQLPOINTER)m_isolation, 0);
396     if (!SQL_SUCCEEDED(sr))
397         throw IOException("ODBC StorageService failed to set transaction isolation level.");
398
399     return handle;
400 }
401
402 SQLHSTMT ODBCStorageService::getHSTMT(SQLHDBC conn)
403 {
404     SQLHSTMT hstmt;
405     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
406     if (!SQL_SUCCEEDED(sr)) {
407         m_log.error("failed to allocate statement handle");
408         log_error(conn, SQL_HANDLE_DBC);
409         throw IOException("ODBC StorageService failed to allocate a statement handle.");
410     }
411     return hstmt;
412 }
413
414 pair<int,int> ODBCStorageService::getVersion(SQLHDBC conn)
415 {
416     // Grab the version number from the database.
417     SQLHSTMT stmt = getHSTMT(conn);
418     
419     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
420     if (!SQL_SUCCEEDED(sr)) {
421         m_log.error("failed to read version from database");
422         log_error(stmt, SQL_HANDLE_STMT);
423         throw IOException("ODBC StorageService failed to read version from database.");
424     }
425
426     SQLINTEGER major;
427     SQLINTEGER minor;
428     SQLBindCol(stmt,1,SQL_C_SLONG,&major,0,nullptr);
429     SQLBindCol(stmt,2,SQL_C_SLONG,&minor,0,nullptr);
430
431     if ((sr=SQLFetch(stmt)) != SQL_NO_DATA)
432         return pair<int,int>(major,minor);
433
434     m_log.error("no rows returned in version query");
435     throw IOException("ODBC StorageService failed to read version from database.");
436 }
437
438 bool ODBCStorageService::createRow(const char* table, const char* context, const char* key, const char* value, time_t expiration)
439 {
440 #ifdef _DEBUG
441     xmltooling::NDC ndc("createRow");
442 #endif
443
444     char timebuf[32];
445     timestampFromTime(expiration, timebuf);
446
447     // Get statement handle.
448     ODBCConn conn(getHDBC());
449     SQLHSTMT stmt = getHSTMT(conn);
450
451     // Prepare and exectute insert statement.
452     //char *scontext = makeSafeSQL(context);
453     //char *skey = makeSafeSQL(key);
454     //char *svalue = makeSafeSQL(value);
455     string q  = string("INSERT INTO ") + table + " VALUES (?,?," + timebuf + ",1,?)";
456
457     SQLRETURN sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
458     if (!SQL_SUCCEEDED(sr)) {
459         m_log.error("SQLPrepare failed (t=%s, c=%s, k=%s)", table, context, key);
460         log_error(stmt, SQL_HANDLE_STMT);
461         throw IOException("ODBC StorageService failed to insert record.");
462     }
463     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
464
465     SQLLEN b_ind = SQL_NTS;
466     sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(context), &b_ind);
467     if (!SQL_SUCCEEDED(sr)) {
468         m_log.error("SQLBindParam failed (context = %s)", context);
469         log_error(stmt, SQL_HANDLE_STMT);
470         throw IOException("ODBC StorageService failed to insert record.");
471     }
472     m_log.debug("SQLBindParam succeeded (context = %s)", context);
473
474     sr = SQLBindParam(stmt, 2, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(key), &b_ind);
475     if (!SQL_SUCCEEDED(sr)) {
476         m_log.error("SQLBindParam failed (key = %s)", key);
477         log_error(stmt, SQL_HANDLE_STMT);
478         throw IOException("ODBC StorageService failed to insert record.");
479     }
480     m_log.debug("SQLBindParam succeeded (key = %s)", key);
481
482     if (strcmp(table, TEXT_TABLE)==0)
483         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
484     else
485         sr = SQLBindParam(stmt, 3, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
486     if (!SQL_SUCCEEDED(sr)) {
487         m_log.error("SQLBindParam failed (value = %s)", value);
488         log_error(stmt, SQL_HANDLE_STMT);
489         throw IOException("ODBC StorageService failed to insert record.");
490     }
491     m_log.debug("SQLBindParam succeeded (value = %s)", value);
492     
493     //freeSafeSQL(scontext, context);
494     //freeSafeSQL(skey, key);
495     //freeSafeSQL(svalue, value);
496     //m_log.debug("SQL: %s", q.c_str());
497
498     int attempts = 3;
499     pair<bool,bool> logres;
500     do {
501         logres = make_pair(false,false);
502         attempts--;
503         sr=SQLExecute(stmt);
504         if (SQL_SUCCEEDED(sr)) {
505             m_log.debug("SQLExecute of insert succeeded");
506             return true;
507         }
508         m_log.error("insert record failed (t=%s, c=%s, k=%s)", table, context, key);
509         logres = log_error(stmt, SQL_HANDLE_STMT, "23000");
510         if (logres.second)
511             return false;   // supposedly integrity violation?
512     } while (attempts && logres.first);
513
514     throw IOException("ODBC StorageService failed to insert record.");
515 }
516
517 int ODBCStorageService::readRow(
518     const char *table, const char* context, const char* key, string* pvalue, time_t* pexpiration, int version, bool text
519     )
520 {
521 #ifdef _DEBUG
522     xmltooling::NDC ndc("readRow");
523 #endif
524
525     // Get statement handle.
526     ODBCConn conn(getHDBC());
527     SQLHSTMT stmt = getHSTMT(conn);
528
529     // Prepare and exectute select statement.
530     char timebuf[32];
531     timestampFromTime(time(nullptr), timebuf);
532     char *scontext = makeSafeSQL(context);
533     char *skey = makeSafeSQL(key);
534     ostringstream q;
535     q << "SELECT version";
536     if (pexpiration)
537         q << ",expires";
538     if (pvalue)
539         q << ",CASE version WHEN " << version << " THEN null ELSE value END";
540     q << " FROM " << table << " WHERE context='" << scontext << "' AND id='" << skey << "' AND expires > " << timebuf;
541     freeSafeSQL(scontext, context);
542     freeSafeSQL(skey, key);
543     if (m_log.isDebugEnabled())
544         m_log.debug("SQL: %s", q.str().c_str());
545
546     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
547     if (!SQL_SUCCEEDED(sr)) {
548         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
549         log_error(stmt, SQL_HANDLE_STMT);
550         throw IOException("ODBC StorageService search failed.");
551     }
552
553     SQLSMALLINT ver;
554     SQL_TIMESTAMP_STRUCT expiration;
555
556     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,nullptr);
557     if (pexpiration)
558         SQLBindCol(stmt,2,SQL_C_TYPE_TIMESTAMP,&expiration,0,nullptr);
559
560     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA)
561         return 0;
562
563     if (pexpiration)
564         *pexpiration = timeFromTimestamp(expiration);
565
566     if (version == ver)
567         return version; // nothing's changed, so just echo back the version
568
569     if (pvalue) {
570         SQLLEN len;
571         SQLCHAR buf[LONGDATA_BUFLEN];
572         while ((sr=SQLGetData(stmt,pexpiration ? 3 : 2,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
573             if (!SQL_SUCCEEDED(sr)) {
574                 m_log.error("error while reading text field from result set");
575                 log_error(stmt, SQL_HANDLE_STMT);
576                 throw IOException("ODBC StorageService search failed to read data from result set.");
577             }
578             pvalue->append((char*)buf);
579         }
580     }
581     
582     return ver;
583 }
584
585 int ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration, int version)
586 {
587 #ifdef _DEBUG
588     xmltooling::NDC ndc("updateRow");
589 #endif
590
591     if (!value && !expiration)
592         throw IOException("ODBC StorageService given invalid update instructions.");
593
594     // Get statement handle. Disable auto-commit mode to wrap select + update.
595     ODBCConn conn(getHDBC());
596     SQLRETURN sr = SQLSetConnectAttr(conn, SQL_ATTR_AUTOCOMMIT, SQL_AUTOCOMMIT_OFF, 0);
597     if (!SQL_SUCCEEDED(sr))
598         throw IOException("ODBC StorageService failed to disable auto-commit mode.");
599     conn.autoCommit = false;
600     SQLHSTMT stmt = getHSTMT(conn);
601
602     // First, fetch the current version for later, which also ensures the record still exists.
603     char timebuf[32];
604     timestampFromTime(time(nullptr), timebuf);
605     char *scontext = makeSafeSQL(context);
606     char *skey = makeSafeSQL(key);
607     string q("SELECT version FROM ");
608     q = q + table + " WHERE context='" + scontext + "' AND id='" + skey + "' AND expires > " + timebuf;
609
610     m_log.debug("SQL: %s", q.c_str());
611
612     sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
613     if (!SQL_SUCCEEDED(sr)) {
614         freeSafeSQL(scontext, context);
615         freeSafeSQL(skey, key);
616         m_log.error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
617         log_error(stmt, SQL_HANDLE_STMT);
618         throw IOException("ODBC StorageService search failed.");
619     }
620
621     SQLSMALLINT ver;
622     SQLBindCol(stmt,1,SQL_C_SSHORT,&ver,0,nullptr);
623     if ((sr=SQLFetch(stmt)) == SQL_NO_DATA) {
624         freeSafeSQL(scontext, context);
625         freeSafeSQL(skey, key);
626         return 0;
627     }
628
629     // Check version?
630     if (version > 0 && version != ver) {
631         freeSafeSQL(scontext, context);
632         freeSafeSQL(skey, key);
633         return -1;
634     }
635
636     SQLFreeHandle(SQL_HANDLE_STMT, stmt);
637     stmt = getHSTMT(conn);
638
639     // Prepare and exectute update statement.
640     q = string("UPDATE ") + table + " SET ";
641
642     if (value)
643         q = q + "value=?, version=version+1";
644
645     if (expiration) {
646         timestampFromTime(expiration, timebuf);
647         if (value)
648             q += ',';
649         q = q + "expires = " + timebuf;
650     }
651
652     q = q + " WHERE context='" + scontext + "' AND id='" + skey + "'";
653     freeSafeSQL(scontext, context);
654     freeSafeSQL(skey, key);
655
656     sr = SQLPrepare(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
657     if (!SQL_SUCCEEDED(sr)) {
658         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
659         log_error(stmt, SQL_HANDLE_STMT);
660         throw IOException("ODBC StorageService failed to update record.");
661     }
662     m_log.debug("SQLPrepare succeeded. SQL: %s", q.c_str());
663
664     SQLLEN b_ind = SQL_NTS;
665     if (value) {
666         if (strcmp(table, TEXT_TABLE)==0)
667             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_LONGVARCHAR, strlen(value), 0, const_cast<char*>(value), &b_ind);
668         else
669             sr = SQLBindParam(stmt, 1, SQL_C_CHAR, SQL_VARCHAR, 255, 0, const_cast<char*>(value), &b_ind);
670         if (!SQL_SUCCEEDED(sr)) {
671             m_log.error("SQLBindParam failed (context = %s)", context);
672             log_error(stmt, SQL_HANDLE_STMT);
673             throw IOException("ODBC StorageService failed to update record.");
674         }
675         m_log.debug("SQLBindParam succeeded (context = %s)", context);
676     }
677
678     int attempts = 3;
679     pair<bool,bool> logres;
680     do {
681         logres = make_pair(false,false);
682         attempts--;
683         sr=SQLExecute(stmt);
684         if (sr==SQL_NO_DATA)
685             return 0;   // went missing?
686         else if (SQL_SUCCEEDED(sr)) {
687             m_log.debug("SQLExecute of update succeeded");
688             return ver + 1;
689         }
690
691         m_log.error("update of record failed (t=%s, c=%s, k=%s", table, context, key);
692         logres = log_error(stmt, SQL_HANDLE_STMT);
693     } while (attempts && logres.first);
694
695     throw IOException("ODBC StorageService failed to update record.");
696 }
697
698 bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
699 {
700 #ifdef _DEBUG
701     xmltooling::NDC ndc("deleteRow");
702 #endif
703
704     // Get statement handle.
705     ODBCConn conn(getHDBC());
706     SQLHSTMT stmt = getHSTMT(conn);
707
708     // Prepare and execute delete statement.
709     char *scontext = makeSafeSQL(context);
710     char *skey = makeSafeSQL(key);
711     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND id='" + skey + "'";
712     freeSafeSQL(scontext, context);
713     freeSafeSQL(skey, key);
714     m_log.debug("SQL: %s", q.c_str());
715
716     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
717      if (sr==SQL_NO_DATA)
718         return false;
719     else if (!SQL_SUCCEEDED(sr)) {
720         m_log.error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
721         log_error(stmt, SQL_HANDLE_STMT);
722         throw IOException("ODBC StorageService failed to delete record.");
723     }
724
725     return true;
726 }
727
728
729 void ODBCStorageService::cleanup()
730 {
731 #ifdef _DEBUG
732     xmltooling::NDC ndc("cleanup");
733 #endif
734
735     Mutex* mutex = Mutex::create();
736
737     mutex->lock();
738
739     m_log.info("cleanup thread started... running every %d secs", m_cleanupInterval);
740
741     while (!shutdown) {
742         shutdown_wait->timedwait(mutex, m_cleanupInterval);
743         if (shutdown)
744             break;
745         try {
746             reap(nullptr);
747         }
748         catch (exception& ex) {
749             m_log.error("cleanup thread swallowed exception: %s", ex.what());
750         }
751     }
752
753     m_log.info("cleanup thread exiting...");
754
755     mutex->unlock();
756     delete mutex;
757     Thread::exit(nullptr);
758 }
759
760 void* ODBCStorageService::cleanup_fn(void* cache_p)
761 {
762   ODBCStorageService* cache = (ODBCStorageService*)cache_p;
763
764 #ifndef WIN32
765   // First, let's block all signals
766   Thread::mask_all_signals();
767 #endif
768
769   // Now run the cleanup process.
770   cache->cleanup();
771   return nullptr;
772 }
773
774 void ODBCStorageService::updateContext(const char *table, const char* context, time_t expiration)
775 {
776 #ifdef _DEBUG
777     xmltooling::NDC ndc("updateContext");
778 #endif
779
780     // Get statement handle.
781     ODBCConn conn(getHDBC());
782     SQLHSTMT stmt = getHSTMT(conn);
783
784     char timebuf[32];
785     timestampFromTime(expiration, timebuf);
786
787     char nowbuf[32];
788     timestampFromTime(time(nullptr), nowbuf);
789
790     char *scontext = makeSafeSQL(context);
791     string q("UPDATE ");
792     q = q + table + " SET expires = " + timebuf + " WHERE context='" + scontext + "' AND expires > " + nowbuf;
793     freeSafeSQL(scontext, context);
794
795     m_log.debug("SQL: %s", q.c_str());
796
797     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
798     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
799         m_log.error("error updating records (t=%s, c=%s)", table, context ? context : "all");
800         log_error(stmt, SQL_HANDLE_STMT);
801         throw IOException("ODBC StorageService failed to update context expiration.");
802     }
803 }
804
805 void ODBCStorageService::reap(const char *table, const char* context)
806 {
807 #ifdef _DEBUG
808     xmltooling::NDC ndc("reap");
809 #endif
810
811     // Get statement handle.
812     ODBCConn conn(getHDBC());
813     SQLHSTMT stmt = getHSTMT(conn);
814
815     // Prepare and execute delete statement.
816     char nowbuf[32];
817     timestampFromTime(time(nullptr), nowbuf);
818     string q;
819     if (context) {
820         char *scontext = makeSafeSQL(context);
821         q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires <= " + nowbuf;
822         freeSafeSQL(scontext, context);
823     }
824     else {
825         q = string("DELETE FROM ") + table + " WHERE expires <= " + nowbuf;
826     }
827     m_log.debug("SQL: %s", q.c_str());
828
829     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
830     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
831         m_log.error("error expiring records (t=%s, c=%s)", table, context ? context : "all");
832         log_error(stmt, SQL_HANDLE_STMT);
833         throw IOException("ODBC StorageService failed to purge expired records.");
834     }
835 }
836
837 void ODBCStorageService::deleteContext(const char *table, const char* context)
838 {
839 #ifdef _DEBUG
840     xmltooling::NDC ndc("deleteContext");
841 #endif
842
843     // Get statement handle.
844     ODBCConn conn(getHDBC());
845     SQLHSTMT stmt = getHSTMT(conn);
846
847     // Prepare and execute delete statement.
848     char *scontext = makeSafeSQL(context);
849     string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
850     freeSafeSQL(scontext, context);
851     m_log.debug("SQL: %s", q.c_str());
852
853     SQLRETURN sr=SQLExecDirect(stmt, (SQLCHAR*)q.c_str(), SQL_NTS);
854     if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
855         m_log.error("error deleting context (t=%s, c=%s)", table, context);
856         log_error(stmt, SQL_HANDLE_STMT);
857         throw IOException("ODBC StorageService failed to delete context.");
858     }
859 }
860
861 extern "C" int ODBCSTORE_EXPORTS xmltooling_extension_init(void*)
862 {
863     // Register this SS type
864     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("ODBC", ODBCStorageServiceFactory);
865     return 0;
866 }
867
868 extern "C" void ODBCSTORE_EXPORTS xmltooling_extension_term()
869 {
870     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("ODBC");
871 }