77a149b0a3889a54e529bf78b09a14d0b38ca185
[shibboleth/cpp-sp.git] / odbc_ccache / odbc-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * odbc-ccache.cpp - Shibboleth Credential Cache using ODBC
19  *
20  * Scott Cantor <cantor.2@osu.edu>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define _CRT_NONSTDC_NO_DEPRECATE 1
34 # define _CRT_SECURE_NO_DEPRECATE 1
35 # define NOMINMAX
36 # define SHIBODBC_EXPORTS __declspec(dllexport)
37 #else
38 # define SHIBODBC_EXPORTS
39 #endif
40
41 #include <shib/shib-threads.h>
42 #include <shib-target/shib-target.h>
43 #include <log4cpp/Category.hh>
44
45 #include <algorithm>
46 #include <sstream>
47
48 #include <sql.h>
49 #include <sqlext.h>
50
51 #ifdef HAVE_LIBDMALLOCXX
52 #include <dmalloc.h>
53 #endif
54
55 using namespace std;
56 using namespace saml;
57 using namespace shibboleth;
58 using namespace shibtarget;
59 using namespace log4cpp;
60
61 #define PLUGIN_VER_MAJOR 3
62 #define PLUGIN_VER_MINOR 0
63
64 #define COLSIZE_KEY 64
65 #define COLSIZE_APPLICATION_ID 256
66 #define COLSIZE_ADDRESS 128
67 #define COLSIZE_PROVIDER_ID 256
68 #define LONGDATA_BUFLEN 32768
69
70 /*
71   CREATE TABLE state (
72       cookie VARCHAR(64) PRIMARY KEY,
73       application_id VARCHAR(256),
74       ctime TIMESTAMP,
75       atime TIMESTAMP,
76       addr VARCHAR(128),
77       major INT,
78       minor INT,
79       provider VARCHAR(256),
80       subject TEXT,
81       authn_context TEXT,
82       tokens TEXT
83       )
84 */
85
86 #define REPLAY_TABLE \
87   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
88   "expires TIMESTAMP, " \
89   "INDEX (expires))"
90
91 static const XMLCh ConnectionString[] =
92 { chLatin_C, chLatin_o, chLatin_n, chLatin_n, chLatin_e, chLatin_c, chLatin_t, chLatin_i, chLatin_o, chLatin_n,
93   chLatin_S, chLatin_t, chLatin_r, chLatin_i, chLatin_n, chLatin_g, chNull
94 };
95 static const XMLCh cleanupInterval[] =
96 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
97   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
98 };
99 static const XMLCh cacheTimeout[] =
100 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
101 static const XMLCh odbcTimeout[] =
102 { chLatin_o, chLatin_d, chLatin_b, chLatin_c, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
103 static const XMLCh storeAttributes[] =
104 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
105
106 struct ODBCConn {
107     ODBCConn(SQLHDBC conn) : handle(conn) {}
108     ~ODBCConn() {SQLFreeHandle(SQL_HANDLE_DBC,handle);}
109     operator SQLHDBC() {return handle;}
110     SQLHDBC handle;
111 };
112
113 class ODBCBase : public virtual saml::IPlugIn
114 {
115 public:
116     ODBCBase(const DOMElement* e);
117     virtual ~ODBCBase();
118
119     SQLHDBC getHDBC();
120
121     log4cpp::Category* log;
122
123 protected:
124     //ThreadKey* m_mysql;
125     const DOMElement* m_root; // can only use this during initialization
126     string m_connstring;
127
128     static SQLHENV m_henv;          // single handle for both plugins
129     bool m_bInitializedODBC;        // tracks which class handled the process
130     static const char* p_connstring;
131
132     pair<int,int> getVersion(SQLHDBC);
133     void log_error(SQLHANDLE handle, SQLSMALLINT htype);
134 };
135
136 SQLHENV ODBCBase::m_henv = SQL_NULL_HANDLE;
137 const char* ODBCBase::p_connstring = NULL;
138
139 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
140 {
141 #ifdef _DEBUG
142     saml::NDC ndc("ODBCBase");
143 #endif
144     log = &(Category::getInstance("shibtarget.ODBC"));
145
146     if (m_henv == SQL_NULL_HANDLE) {
147         // Enable connection pooling.
148         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
149
150         // Allocate the environment.
151         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
152             throw ConfigurationException("ODBC failed to initialize.");
153
154         // Specify ODBC 3.x
155         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
156
157         log->info("ODBC initialized");
158         m_bInitializedODBC = true;
159     }
160
161     // Grab connection string from the configuration.
162     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,ConnectionString);
163     if (!e || !e->hasChildNodes()) {
164         if (!p_connstring) {
165             this->~ODBCBase();
166             throw ConfigurationException("ODBC cache requires ConnectionString element in configuration.");
167         }
168         m_connstring=p_connstring;
169     }
170     else {
171         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
172         m_connstring=arg.get();
173         p_connstring=m_connstring.c_str();
174     }
175
176     // Connect and check version.
177     SQLHDBC conn=getHDBC();
178     pair<int,int> v=getVersion(conn);
179     SQLFreeHandle(SQL_HANDLE_DBC,conn);
180
181     // Make sure we've got the right version.
182     if (v.first != PLUGIN_VER_MAJOR) {
183         this->~ODBCBase();
184         log->crit("unknown database version: %d.%d", v.first, v.second);
185         throw SAMLException("Unknown cache database version.");
186     }
187 }
188
189 ODBCBase::~ODBCBase()
190 {
191     //delete m_mysql;
192     if (m_bInitializedODBC)
193         SQLFreeHandle(SQL_HANDLE_ENV,m_henv);
194     m_bInitializedODBC=false;
195     m_henv = SQL_NULL_HANDLE;
196     p_connstring=NULL;
197 }
198
199 void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
200 {
201     SQLSMALLINT  i = 0;
202     SQLINTEGER   native;
203     SQLCHAR      state[7];
204     SQLCHAR      text[256];
205     SQLSMALLINT  len;
206     SQLRETURN    ret;
207
208     do {
209         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
210         if (SQL_SUCCEEDED(ret))
211             log->error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
212     } while(SQL_SUCCEEDED(ret));
213 }
214
215 SQLHDBC ODBCBase::getHDBC()
216 {
217 #ifdef _DEBUG
218     saml::NDC ndc("getMYSQL");
219 #endif
220
221     // Get a handle.
222     SQLHDBC handle;
223     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
224     if (!SQL_SUCCEEDED(sr)) {
225         log->error("failed to allocate connection handle");
226         log_error(m_henv, SQL_HANDLE_ENV);
227         throw SAMLException("ODBCBase::getHDBC failed to allocate connection handle");
228     }
229
230     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
231     if (!SQL_SUCCEEDED(sr)) {
232         log->error("failed to connect to database");
233         log_error(handle, SQL_HANDLE_DBC);
234         throw SAMLException("ODBCBase::getHDBC failed to connect to database");
235     }
236
237     return handle;
238 }
239
240 pair<int,int> ODBCBase::getVersion(SQLHDBC conn)
241 {
242     // Grab the version number from the database.
243     SQLHSTMT hstmt;
244     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
245     
246     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
247     if (!SQL_SUCCEEDED(sr)) {
248         log->error("failed to read version from database");
249         log_error(hstmt, SQL_HANDLE_STMT);
250         throw SAMLException("ODBCBase::getVersion failed to read version from database");
251     }
252
253     SQLINTEGER major;
254     SQLINTEGER minor;
255     SQLBindCol(hstmt,1,SQL_C_SLONG,&major,0,NULL);
256     SQLBindCol(hstmt,2,SQL_C_SLONG,&minor,0,NULL);
257
258     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
259         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
260         return pair<int,int>(major,minor);
261     }
262
263     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
264     log->error("no rows returned in version query");
265     throw SAMLException("ODBCBase::getVersion failed to read version from database");
266 }
267
268 class ODBCCCache : public ODBCBase, virtual public ISessionCache, virtual public ISessionCacheStore
269 {
270 public:
271     ODBCCCache(const DOMElement* e);
272     virtual ~ODBCCCache();
273
274     // Delegate all the ISessionCache methods.
275     string insert(
276         const IApplication* application,
277         const IEntityDescriptor* source,
278         const char* client_addr,
279         const SAMLSubject* subject,
280         const char* authnContext,
281         const SAMLResponse* tokens
282         )
283     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
284     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
285     { return m_cache->find(key,application,client_addr); }
286     void remove(const char* key, const IApplication* application, const char* client_addr)
287     { m_cache->remove(key,application,client_addr); }
288
289     bool setBackingStore(ISessionCacheStore*) {return false;}
290
291     // Store methods handle the database work
292     HRESULT onCreate(
293         const char* key,
294         const IApplication* application,
295         const ISessionCacheEntry* entry,
296         int majorVersion,
297         int minorVersion,
298         time_t created
299         );
300     HRESULT onRead(
301         const char* key,
302         string& applicationId,
303         string& clientAddress,
304         string& providerId,
305         string& subject,
306         string& authnContext,
307         string& tokens,
308         int& majorVersion,
309         int& minorVersion,
310         time_t& created,
311         time_t& accessed
312         );
313     HRESULT onRead(const char* key, time_t& accessed);
314     HRESULT onRead(const char* key, string& tokens);
315     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
316     HRESULT onDelete(const char* key);
317
318     void cleanup();
319
320 private:
321     bool m_storeAttributes;
322     ISessionCache* m_cache;
323     CondWait* shutdown_wait;
324     bool shutdown;
325     Thread* cleanup_thread;
326
327     static void* cleanup_fcn(void*); // XXX Assumed an ODBCCCache
328 };
329
330 ODBCCCache::ODBCCCache(const DOMElement* e) : ODBCBase(e), m_storeAttributes(false)
331 {
332 #ifdef _DEBUG
333     saml::NDC ndc("ODBCCCache");
334 #endif
335     log = &(Category::getInstance("shibtarget.SessionCache.ODBC"));
336
337     m_cache = dynamic_cast<ISessionCache*>(
338         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, m_root)
339     );
340     if (!m_cache->setBackingStore(this)) {
341         delete m_cache;
342         throw SAMLException("Unable to register ODBC cache plugin as a cache store.");
343     }
344     
345     shutdown_wait = CondWait::create();
346     shutdown = false;
347
348     // Load our configuration details...
349     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
350     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
351         m_storeAttributes=true;
352
353     // Initialize the cleanup thread
354     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
355 }
356
357 ODBCCCache::~ODBCCCache()
358 {
359     shutdown = true;
360     shutdown_wait->signal();
361     cleanup_thread->join(NULL);
362     delete m_cache;
363 }
364
365 void appendXML(ostream& os, const char* str)
366 {
367     const char* pos=strchr(str,'\'');
368     while (pos) {
369         os.write(str,pos-str);
370         os << "''";
371         str=pos+1;
372         pos=strchr(str,'\'');
373     }
374     os << str;
375 }
376
377 HRESULT ODBCCCache::onCreate(
378     const char* key,
379     const IApplication* application,
380     const ISessionCacheEntry* entry,
381     int majorVersion,
382     int minorVersion,
383     time_t created
384     )
385 {
386 #ifdef _DEBUG
387     saml::NDC ndc("onCreate");
388 #endif
389
390     // Get XML data from entry. Default is not to return SAML objects.
391     const char* context=entry->getAuthnContext();
392     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
393     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
394
395     // Stringify timestamp.
396     if (created==0)
397         created=time(NULL);
398 #ifndef HAVE_GMTIME_R
399     struct tm* ptime=gmtime(&created);
400 #else
401     struct tm res;
402     struct tm* ptime=gmtime_r(&created,&res);
403 #endif
404     char timebuf[32];
405     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
406
407     // Prepare insert statement.
408     ostringstream q;
409     q << "INSERT state VALUES ('" << key << "','" << application->getId() << "'," << timebuf << "," << timebuf
410         << ",'" << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId()
411         << "','";
412     appendXML(q,subject.first);
413     q << "','";
414     appendXML(q,context);
415     q << "',";
416     if (m_storeAttributes && tokens.first) {
417         q << "'";
418         appendXML(q,tokens.first);
419         q << "')";
420     }
421     else
422         q << "null)";
423     if (log->isDebugEnabled())
424         log->debug("SQL insert: %s", q.str().c_str());
425
426     // Get statement handle.
427     SQLHSTMT hstmt;
428     ODBCConn conn(getHDBC());
429     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
430
431     // Execute statement.
432     HRESULT hr=NOERROR;
433     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
434     if (!SQL_SUCCEEDED(sr)) {
435         log->error("failed to insert record into database");
436         log_error(hstmt, SQL_HANDLE_STMT);
437         hr=E_FAIL;
438     }
439
440     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
441     return hr;
442 }
443
444 HRESULT ODBCCCache::onRead(
445     const char* key,
446     string& applicationId,
447     string& clientAddress,
448     string& providerId,
449     string& subject,
450     string& authnContext,
451     string& tokens,
452     int& majorVersion,
453     int& minorVersion,
454     time_t& created,
455     time_t& accessed
456     )
457 {
458 #ifdef _DEBUG
459     saml::NDC ndc("onRead");
460 #endif
461
462     log->debug("searching database...");
463
464     SQLHSTMT hstmt;
465     ODBCConn conn(getHDBC());
466     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
467
468     string q = string("SELECT application_id,ctime,atime,addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "'";
469     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
470     if (!SQL_SUCCEEDED(sr)) {
471         log->error("error searching for (%s)",key);
472         log_error(hstmt, SQL_HANDLE_STMT);
473         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
474         return E_FAIL;
475     }
476
477     SQLINTEGER major,minor;
478     SQL_TIMESTAMP_STRUCT atime,ctime;
479     SQLCHAR application_id[COLSIZE_APPLICATION_ID+1];
480     SQLCHAR addr[COLSIZE_ADDRESS+1];
481     SQLCHAR provider_id[COLSIZE_PROVIDER_ID+1];
482
483     // Bind simple output columns.
484     SQLBindCol(hstmt,1,SQL_C_CHAR,application_id,sizeof(application_id),NULL);
485     SQLBindCol(hstmt,2,SQL_C_TYPE_TIMESTAMP,&ctime,0,NULL);
486     SQLBindCol(hstmt,3,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
487     SQLBindCol(hstmt,4,SQL_C_CHAR,addr,sizeof(addr),NULL);
488     SQLBindCol(hstmt,5,SQL_C_SLONG,&major,0,NULL);
489     SQLBindCol(hstmt,6,SQL_C_SLONG,&minor,0,NULL);
490     SQLBindCol(hstmt,7,SQL_C_CHAR,provider_id,sizeof(provider_id),NULL);
491
492     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
493         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
494         return S_FALSE;
495     }
496
497     log->debug("session found, tranfering data back into memory");
498
499     // Copy back simple data.
500     applicationId = (char*)application_id;
501     clientAddress = (char*)addr;
502     majorVersion = major;
503     minorVersion = minor;
504     providerId = (char*)provider_id;
505
506     struct tm t;
507     t.tm_sec=ctime.second;
508     t.tm_min=ctime.minute;
509     t.tm_hour=ctime.hour;
510     t.tm_mday=ctime.day;
511     t.tm_mon=ctime.month-1;
512     t.tm_year=ctime.year-1900;
513     t.tm_isdst=0;
514 #if defined(HAVE_TIMEGM)
515     created=timegm(&t);
516 #else
517     // Windows, and hopefully most others...?
518     created = mktime(&t) - timezone;
519 #endif
520     t.tm_sec=atime.second;
521     t.tm_min=atime.minute;
522     t.tm_hour=atime.hour;
523     t.tm_mday=atime.day;
524     t.tm_mon=atime.month-1;
525     t.tm_year=atime.year-1900;
526     t.tm_isdst=0;
527 #if defined(HAVE_TIMEGM)
528     accessed=timegm(&t);
529 #else
530     // Windows, and hopefully most others...?
531     accessed = mktime(&t) - timezone;
532 #endif
533
534     // Extract text data.
535     string* ptrs[] = {&subject, &authnContext, &tokens};
536     HRESULT hr=NOERROR;
537     SQLINTEGER len;
538     SQLCHAR buf[LONGDATA_BUFLEN];
539     for (int i=0; i<3; i++) {
540         while ((sr=SQLGetData(hstmt,i+8,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
541             if (!SUCCEEDED(sr)) {
542                 log->error("error while reading text field from result set");
543                 log_error(hstmt, SQL_HANDLE_STMT);
544                 hr=E_FAIL;
545                 break;
546             }
547             ptrs[i]->append((char*)buf);
548         }
549     }
550
551     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
552     return hr;
553 }
554
555 HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
556 {
557 #ifdef _DEBUG
558     saml::NDC ndc("onRead");
559 #endif
560
561     log->debug("reading last access time from database");
562
563     SQLHSTMT hstmt;
564     ODBCConn conn(getHDBC());
565     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
566     
567     string q = string("SELECT atime FROM state WHERE cookie='") + key + "'";
568     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
569     if (!SQL_SUCCEEDED(sr)) {
570         log->error("error searching for (%s)",key);
571         log_error(hstmt, SQL_HANDLE_STMT);
572         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
573         return E_FAIL;
574     }
575
576     SQL_TIMESTAMP_STRUCT atime;
577     SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
578
579     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
580         log->warn("session expected, but not found in database");
581         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
582         return S_FALSE;
583     }
584
585     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
586
587     struct tm t;
588     t.tm_sec=atime.second;
589     t.tm_min=atime.minute;
590     t.tm_hour=atime.hour;
591     t.tm_mday=atime.day;
592     t.tm_mon=atime.month-1;
593     t.tm_year=atime.year-1900;
594     t.tm_isdst=0;
595 #if defined(HAVE_TIMEGM)
596     accessed=timegm(&t);
597 #else
598     // Windows, and hopefully most others...?
599     accessed = mktime(&t) - timezone;
600 #endif
601     return NOERROR;
602 }
603
604 HRESULT ODBCCCache::onRead(const char* key, string& tokens)
605 {
606 #ifdef _DEBUG
607     saml::NDC ndc("onRead");
608 #endif
609
610     if (!m_storeAttributes)
611         return S_FALSE;
612
613     log->debug("reading cached tokens from database");
614
615     SQLHSTMT hstmt;
616     ODBCConn conn(getHDBC());
617     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
618     
619     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "'";
620     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
621     if (!SQL_SUCCEEDED(sr)) {
622         log->error("error searching for (%s)",key);
623         log_error(hstmt, SQL_HANDLE_STMT);
624         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
625         return E_FAIL;
626     }
627
628     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
629         log->warn("session expected, but not found in database");
630         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
631         return S_FALSE;
632     }
633
634     HRESULT hr=NOERROR;
635     SQLINTEGER len;
636     SQLCHAR buf[LONGDATA_BUFLEN];
637     while ((sr=SQLGetData(hstmt,1,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
638         if (!SUCCEEDED(sr)) {
639             log->error("error while reading text field from result set");
640             log_error(hstmt, SQL_HANDLE_STMT);
641             hr=E_FAIL;
642             break;
643         }
644         tokens += (char*)buf;
645     }
646
647     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
648     return hr;
649 }
650
651 HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
652 {
653 #ifdef _DEBUG
654     saml::NDC ndc("onUpdate");
655 #endif
656
657     ostringstream q;
658
659     if (lastAccess>0) {
660 #ifndef HAVE_GMTIME_R
661         struct tm* ptime=gmtime(&lastAccess);
662 #else
663         struct tm res;
664         struct tm* ptime=gmtime_r(&lastAccess,&res);
665 #endif
666         char timebuf[32];
667         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
668
669         q << "UPDATE state SET atime=" << timebuf << " WHERE cookie='" << key << "'";
670     }
671     else if (tokens) {
672         if (!m_storeAttributes)
673             return S_FALSE;
674         q << "UPDATE state SET tokens=";
675         if (*tokens) {
676             q << "'";
677             appendXML(q,tokens);
678             q << "' ";
679         }
680         else
681             q << "null ";
682         q << "WHERE cookie='" << key << "'";
683     }
684     else {
685         log->warn("onUpdate called with nothing to do!");
686         return S_FALSE;
687     }
688  
689     HRESULT hr=NOERROR;
690     SQLHSTMT hstmt;
691     ODBCConn conn(getHDBC());
692     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
693     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
694     if (sr==SQL_NO_DATA)
695         hr=S_FALSE;
696     else if (!SQL_SUCCEEDED(sr)) {
697         log->error("error updating record (key=%s)", key);
698         log_error(hstmt, SQL_HANDLE_STMT);
699         hr=E_FAIL;
700     }
701
702     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
703     return hr;
704 }
705
706 HRESULT ODBCCCache::onDelete(const char* key)
707 {
708 #ifdef _DEBUG
709     saml::NDC ndc("onDelete");
710 #endif
711
712     SQLHSTMT hstmt;
713     ODBCConn conn(getHDBC());
714     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
715     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
716     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
717  
718     HRESULT hr=NOERROR;
719     if (sr==SQL_NO_DATA)
720         hr=S_FALSE;
721     else if (!SQL_SUCCEEDED(sr)) {
722         log->error("error deleting record (key=%s)", key);
723         log_error(hstmt, SQL_HANDLE_STMT);
724         hr=E_FAIL;
725     }
726
727     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
728     return hr;
729 }
730
731 void ODBCCCache::cleanup()
732 {
733 #ifdef _DEBUG
734     saml::NDC ndc("cleanup");
735 #endif
736
737     Mutex* mutex = Mutex::create();
738
739     int rerun_timer = 0;
740     int timeout_life = 0;
741
742     // Load our configuration details...
743     const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
744     if (tag && *tag)
745         rerun_timer = XMLString::parseInt(tag);
746
747     // search for 'mysql-cache-timeout' and then the regular cache timeout
748     tag=m_root->getAttributeNS(NULL,odbcTimeout);
749     if (tag && *tag)
750         timeout_life = XMLString::parseInt(tag);
751     else {
752         tag=m_root->getAttributeNS(NULL,cacheTimeout);
753         if (tag && *tag)
754             timeout_life = XMLString::parseInt(tag);
755     }
756   
757     if (rerun_timer <= 0)
758         rerun_timer = 300;              // rerun every 5 minutes
759
760     if (timeout_life <= 0)
761         timeout_life = 28800;   // timeout after 8 hours
762     
763     mutex->lock();
764
765     log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
766
767     while (shutdown == false) {
768         shutdown_wait->timedwait(mutex, rerun_timer);
769
770         if (shutdown == true)
771             break;
772
773         // Find all the entries in the database that haven't been used
774         // recently In particular, find all entries that have not been
775         // accessed in 'timeout_life' seconds.
776
777         time_t stale=time(NULL)-timeout_life;
778 #ifndef HAVE_GMTIME_R
779         struct tm* ptime=gmtime(&stale);
780 #else
781         struct tm res;
782         struct tm* ptime=gmtime_r(&stale,&res);
783 #endif
784         char timebuf[32];
785         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
786
787         string q = string("DELETE state WHERE atime < ") +  timebuf;
788
789         SQLHSTMT hstmt;
790         ODBCConn conn(getHDBC());
791         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
792         SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
793         if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
794             log->error("error purging old records");
795             log_error(hstmt, SQL_HANDLE_STMT);
796         }
797
798         SQLINTEGER rowcount=0;
799         sr=SQLRowCount(hstmt,&rowcount);
800         if (SQL_SUCCEEDED(sr) && rowcount > 0)
801             log->info("purging %d old sessions",rowcount);
802
803         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
804      }
805
806     log->info("cleanup thread exiting...");
807
808     mutex->unlock();
809     delete mutex;
810     Thread::exit(NULL);
811 }
812
813 void* ODBCCCache::cleanup_fcn(void* cache_p)
814 {
815   ODBCCCache* cache = (ODBCCCache*)cache_p;
816
817   // First, let's block all signals
818   Thread::mask_all_signals();
819
820   // Now run the cleanup process.
821   cache->cleanup();
822   return NULL;
823 }
824
825
826 class ODBCReplayCache : public ODBCBase, virtual public IReplayCache
827 {
828 public:
829   ODBCReplayCache(const DOMElement* e);
830   virtual ~ODBCReplayCache() {}
831
832   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
833   bool check(const char* str, time_t expires);
834 };
835
836 ODBCReplayCache::ODBCReplayCache(const DOMElement* e) : ODBCBase(e)
837 {
838 #ifdef _DEBUG
839     saml::NDC ndc("ODBCReplayCache");
840 #endif
841     log = &(Category::getInstance("shibtarget.ReplayCache.ODBC"));
842 }
843
844 bool ODBCReplayCache::check(const char* str, time_t expires)
845 {
846 #ifdef _DEBUG
847     saml::NDC ndc("check");
848 #endif
849   
850     time_t now=time(NULL);
851 #ifndef HAVE_GMTIME_R
852     struct tm* ptime=gmtime(&now);
853 #else
854     struct tm res;
855     struct tm* ptime=gmtime_r(&now,&res);
856 #endif
857     char timebuf[32];
858     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
859
860     // Remove expired entries.
861     SQLHSTMT hstmt;
862     ODBCConn conn(getHDBC());
863     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
864     string q = string("DELETE FROM replay WHERE expires < ") + timebuf;
865     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
866     if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
867         log->error("error purging old replay cache entries");
868         log_error(hstmt, SQL_HANDLE_STMT);
869     }
870     SQLCloseCursor(hstmt);
871   
872     // Look for a replay.
873     q = string("SELECT id FROM replay WHERE id='") + str + "'";
874     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
875     if (!SQL_SUCCEEDED(sr)) {
876         log->error("error searching replay cache");
877         log_error(hstmt, SQL_HANDLE_STMT);
878         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
879         throw SAMLException("Replay cache failed, please inform application support staff.");
880     }
881
882     // If we got a record, we return false.
883     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
884         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
885         return false;
886     }
887     SQLCloseCursor(hstmt);
888     
889 #ifndef HAVE_GMTIME_R
890     ptime=gmtime(&expires);
891 #else
892     ptime=gmtime_r(&expires,&res);
893 #endif
894     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
895
896     // Add it to the database.
897     q = string("INSERT replay VALUES('") + str + "'," + timebuf + ")";
898     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
899     if (!SQL_SUCCEEDED(sr)) {
900         log->error("error inserting replay cache entry", str);
901         log_error(hstmt, SQL_HANDLE_STMT);
902         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
903         throw SAMLException("Replay cache failed, please inform application support staff.");
904     }
905
906     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
907     return true;
908 }
909
910
911 // Factories
912
913 IPlugIn* new_odbc_ccache(const DOMElement* e)
914 {
915     return new ODBCCCache(e);
916 }
917
918 IPlugIn* new_odbc_replay(const DOMElement* e)
919 {
920     return new ODBCReplayCache(e);
921 }
922
923
924 extern "C" int SHIBODBC_EXPORTS saml_extension_init(void*)
925 {
926     // register this ccache type
927     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::ODBCReplayCacheType, &new_odbc_replay);
928     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::ODBCSessionCacheType, &new_odbc_ccache);
929     return 0;
930 }
931
932 extern "C" void SHIBODBC_EXPORTS saml_extension_term()
933 {
934     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::ODBCSessionCacheType);
935     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::ODBCReplayCacheType);
936 }