Starting to refactor session cache, eliminated IConfig class.
[shibboleth/cpp-sp.git] / odbc_ccache / odbc-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * odbc-ccache.cpp - Shibboleth Credential Cache using ODBC
19  *
20  * Scott Cantor <cantor.2@osu.edu>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define _CRT_NONSTDC_NO_DEPRECATE 1
34 # define _CRT_SECURE_NO_DEPRECATE 1
35 # define NOMINMAX
36 # define SHIBODBC_EXPORTS __declspec(dllexport)
37 #else
38 # define SHIBODBC_EXPORTS
39 #endif
40
41 #include <shib-target/shib-target.h>
42 #include <shibsp/exceptions.h>
43 #include <shibsp/SPConfig.h>
44 #include <log4cpp/Category.hh>
45 #include <xmltooling/util/NDC.h>
46 #include <xmltooling/util/Threads.h>
47
48 #include <ctime>
49 #include <algorithm>
50 #include <sstream>
51
52 #include <sql.h>
53 #include <sqlext.h>
54
55 #ifdef HAVE_LIBDMALLOCXX
56 #include <dmalloc.h>
57 #endif
58
59 using namespace shibsp;
60 using namespace shibtarget;
61 using namespace opensaml::saml2md;
62 using namespace saml;
63 using namespace xmltooling;
64 using namespace log4cpp;
65 using namespace std;
66
67 #define PLUGIN_VER_MAJOR 3
68 #define PLUGIN_VER_MINOR 0
69
70 #define COLSIZE_KEY 64
71 #define COLSIZE_APPLICATION_ID 256
72 #define COLSIZE_ADDRESS 128
73 #define COLSIZE_PROVIDER_ID 256
74 #define LONGDATA_BUFLEN 32768
75
76 /*
77   CREATE TABLE state (
78       cookie VARCHAR(64) PRIMARY KEY,
79       application_id VARCHAR(256),
80       ctime TIMESTAMP,
81       atime TIMESTAMP,
82       addr VARCHAR(128),
83       major INT,
84       minor INT,
85       provider VARCHAR(256),
86       subject TEXT,
87       authn_context TEXT,
88       tokens TEXT
89       )
90 */
91
92 #define REPLAY_TABLE \
93   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
94   "expires TIMESTAMP, " \
95   "INDEX (expires))"
96
97 static const XMLCh ConnectionString[] =
98 { chLatin_C, chLatin_o, chLatin_n, chLatin_n, chLatin_e, chLatin_c, chLatin_t, chLatin_i, chLatin_o, chLatin_n,
99   chLatin_S, chLatin_t, chLatin_r, chLatin_i, chLatin_n, chLatin_g, chNull
100 };
101 static const XMLCh cleanupInterval[] =
102 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
103   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
104 };
105 static const XMLCh cacheTimeout[] =
106 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
107 static const XMLCh odbcTimeout[] =
108 { chLatin_o, chLatin_d, chLatin_b, chLatin_c, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
109 static const XMLCh storeAttributes[] =
110 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
111
112 struct ODBCConn {
113     ODBCConn(SQLHDBC conn) : handle(conn) {}
114     ~ODBCConn() {SQLFreeHandle(SQL_HANDLE_DBC,handle);}
115     operator SQLHDBC() {return handle;}
116     SQLHDBC handle;
117 };
118
119 class ODBCBase : public virtual saml::IPlugIn
120 {
121 public:
122     ODBCBase(const DOMElement* e);
123     virtual ~ODBCBase();
124
125     SQLHDBC getHDBC();
126
127     Category* log;
128
129 protected:
130     const DOMElement* m_root; // can only use this during initialization
131     string m_connstring;
132
133     static SQLHENV m_henv;          // single handle for both plugins
134     bool m_bInitializedODBC;        // tracks which class handled the process
135     static const char* p_connstring;
136
137     pair<int,int> getVersion(SQLHDBC);
138     void log_error(SQLHANDLE handle, SQLSMALLINT htype);
139 };
140
141 SQLHENV ODBCBase::m_henv = SQL_NULL_HANDLE;
142 const char* ODBCBase::p_connstring = NULL;
143
144 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
145 {
146 #ifdef _DEBUG
147     xmltooling::NDC ndc("ODBCBase");
148 #endif
149     log = &(Category::getInstance("shibtarget.ODBC"));
150
151     if (m_henv == SQL_NULL_HANDLE) {
152         // Enable connection pooling.
153         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
154
155         // Allocate the environment.
156         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
157             throw ConfigurationException("ODBC failed to initialize.");
158
159         // Specify ODBC 3.x
160         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
161
162         log->info("ODBC initialized");
163         m_bInitializedODBC = true;
164     }
165
166     // Grab connection string from the configuration.
167     e=XMLHelper::getFirstChildElement(e,ConnectionString);
168     if (!e || !e->hasChildNodes()) {
169         if (!p_connstring) {
170             this->~ODBCBase();
171             throw ConfigurationException("ODBC cache requires ConnectionString element in configuration.");
172         }
173         m_connstring=p_connstring;
174     }
175     else {
176         xmltooling::auto_ptr_char arg(e->getFirstChild()->getNodeValue());
177         m_connstring=arg.get();
178         p_connstring=m_connstring.c_str();
179     }
180
181     // Connect and check version.
182     SQLHDBC conn=getHDBC();
183     pair<int,int> v=getVersion(conn);
184     SQLFreeHandle(SQL_HANDLE_DBC,conn);
185
186     // Make sure we've got the right version.
187     if (v.first != PLUGIN_VER_MAJOR) {
188         this->~ODBCBase();
189         log->crit("unknown database version: %d.%d", v.first, v.second);
190         throw SAMLException("Unknown cache database version.");
191     }
192 }
193
194 ODBCBase::~ODBCBase()
195 {
196     //delete m_mysql;
197     if (m_bInitializedODBC)
198         SQLFreeHandle(SQL_HANDLE_ENV,m_henv);
199     m_bInitializedODBC=false;
200     m_henv = SQL_NULL_HANDLE;
201     p_connstring=NULL;
202 }
203
204 void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
205 {
206     SQLSMALLINT  i = 0;
207     SQLINTEGER   native;
208     SQLCHAR      state[7];
209     SQLCHAR      text[256];
210     SQLSMALLINT  len;
211     SQLRETURN    ret;
212
213     do {
214         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
215         if (SQL_SUCCEEDED(ret))
216             log->error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
217     } while(SQL_SUCCEEDED(ret));
218 }
219
220 SQLHDBC ODBCBase::getHDBC()
221 {
222 #ifdef _DEBUG
223     xmltooling::NDC ndc("getMYSQL");
224 #endif
225
226     // Get a handle.
227     SQLHDBC handle;
228     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
229     if (!SQL_SUCCEEDED(sr)) {
230         log->error("failed to allocate connection handle");
231         log_error(m_henv, SQL_HANDLE_ENV);
232         throw SAMLException("ODBCBase::getHDBC failed to allocate connection handle");
233     }
234
235     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
236     if (!SQL_SUCCEEDED(sr)) {
237         log->error("failed to connect to database");
238         log_error(handle, SQL_HANDLE_DBC);
239         throw SAMLException("ODBCBase::getHDBC failed to connect to database");
240     }
241
242     return handle;
243 }
244
245 pair<int,int> ODBCBase::getVersion(SQLHDBC conn)
246 {
247     // Grab the version number from the database.
248     SQLHSTMT hstmt;
249     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
250     
251     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
252     if (!SQL_SUCCEEDED(sr)) {
253         log->error("failed to read version from database");
254         log_error(hstmt, SQL_HANDLE_STMT);
255         throw SAMLException("ODBCBase::getVersion failed to read version from database");
256     }
257
258     SQLINTEGER major;
259     SQLINTEGER minor;
260     SQLBindCol(hstmt,1,SQL_C_SLONG,&major,0,NULL);
261     SQLBindCol(hstmt,2,SQL_C_SLONG,&minor,0,NULL);
262
263     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
264         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
265         return pair<int,int>(major,minor);
266     }
267
268     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
269     log->error("no rows returned in version query");
270     throw SAMLException("ODBCBase::getVersion failed to read version from database");
271 }
272
273 class ODBCCCache : public ODBCBase, virtual public ISessionCache, virtual public ISessionCacheStore
274 {
275 public:
276     ODBCCCache(const DOMElement* e);
277     virtual ~ODBCCCache();
278
279     // Delegate all the ISessionCache methods.
280     string insert(
281         const IApplication* application,
282         const RoleDescriptor* role,
283         const char* client_addr,
284         const SAMLSubject* subject,
285         const char* authnContext,
286         const SAMLResponse* tokens
287         )
288     { return m_cache->insert(application,role,client_addr,subject,authnContext,tokens); }
289     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
290     { return m_cache->find(key,application,client_addr); }
291     void remove(const char* key, const IApplication* application, const char* client_addr)
292     { m_cache->remove(key,application,client_addr); }
293
294     bool setBackingStore(ISessionCacheStore*) {return false;}
295
296     // Store methods handle the database work
297     HRESULT onCreate(
298         const char* key,
299         const IApplication* application,
300         const ISessionCacheEntry* entry,
301         int majorVersion,
302         int minorVersion,
303         time_t created
304         );
305     HRESULT onRead(
306         const char* key,
307         string& applicationId,
308         string& clientAddress,
309         string& providerId,
310         string& subject,
311         string& authnContext,
312         string& tokens,
313         int& majorVersion,
314         int& minorVersion,
315         time_t& created,
316         time_t& accessed
317         );
318     HRESULT onRead(const char* key, time_t& accessed);
319     HRESULT onRead(const char* key, string& tokens);
320     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
321     HRESULT onDelete(const char* key);
322
323     void cleanup();
324
325 private:
326     bool m_storeAttributes;
327     ISessionCache* m_cache;
328     xmltooling::CondWait* shutdown_wait;
329     bool shutdown;
330     xmltooling::Thread* cleanup_thread;
331
332     static void* cleanup_fcn(void*); // XXX Assumed an ODBCCCache
333 };
334
335 ODBCCCache::ODBCCCache(const DOMElement* e) : ODBCBase(e), m_storeAttributes(false)
336 {
337 #ifdef _DEBUG
338     xmltooling::NDC ndc("ODBCCCache");
339 #endif
340     log = &(Category::getInstance("shibtarget.SessionCache.ODBC"));
341
342     m_cache = dynamic_cast<ISessionCache*>(
343         SAMLConfig::getConfig().getPlugMgr().newPlugin(MEMORY_SESSIONCACHE, m_root)
344     );
345     if (!m_cache->setBackingStore(this)) {
346         delete m_cache;
347         throw SAMLException("Unable to register ODBC cache plugin as a cache store.");
348     }
349     
350     shutdown_wait = CondWait::create();
351     shutdown = false;
352
353     // Load our configuration details...
354     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
355     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
356         m_storeAttributes=true;
357
358     // Initialize the cleanup thread
359     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
360 }
361
362 ODBCCCache::~ODBCCCache()
363 {
364     shutdown = true;
365     shutdown_wait->signal();
366     cleanup_thread->join(NULL);
367     delete m_cache;
368 }
369
370 void appendXML(ostream& os, const char* str)
371 {
372     const char* pos=strchr(str,'\'');
373     while (pos) {
374         os.write(str,pos-str);
375         os << "''";
376         str=pos+1;
377         pos=strchr(str,'\'');
378     }
379     os << str;
380 }
381
382 HRESULT ODBCCCache::onCreate(
383     const char* key,
384     const IApplication* application,
385     const ISessionCacheEntry* entry,
386     int majorVersion,
387     int minorVersion,
388     time_t created
389     )
390 {
391 #ifdef _DEBUG
392     xmltooling::NDC ndc("onCreate");
393 #endif
394
395     // Get XML data from entry. Default is not to return SAML objects.
396     const char* context=entry->getAuthnContext();
397     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
398     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
399
400     // Stringify timestamp.
401     if (created==0)
402         created=time(NULL);
403 #ifndef HAVE_GMTIME_R
404     struct tm* ptime=gmtime(&created);
405 #else
406     struct tm res;
407     struct tm* ptime=gmtime_r(&created,&res);
408 #endif
409     char timebuf[32];
410     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
411
412     // Prepare insert statement.
413     ostringstream q;
414     q << "INSERT state VALUES ('" << key << "','" << application->getId() << "'," << timebuf << "," << timebuf
415         << ",'" << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId()
416         << "','";
417     appendXML(q,subject.first);
418     q << "','";
419     appendXML(q,context);
420     q << "',";
421     if (m_storeAttributes && tokens.first) {
422         q << "'";
423         appendXML(q,tokens.first);
424         q << "')";
425     }
426     else
427         q << "null)";
428     if (log->isDebugEnabled())
429         log->debug("SQL insert: %s", q.str().c_str());
430
431     // Get statement handle.
432     SQLHSTMT hstmt;
433     ODBCConn conn(getHDBC());
434     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
435
436     // Execute statement.
437     HRESULT hr=NOERROR;
438     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
439     if (!SQL_SUCCEEDED(sr)) {
440         log->error("failed to insert record into database");
441         log_error(hstmt, SQL_HANDLE_STMT);
442         hr=E_FAIL;
443     }
444
445     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
446     return hr;
447 }
448
449 HRESULT ODBCCCache::onRead(
450     const char* key,
451     string& applicationId,
452     string& clientAddress,
453     string& providerId,
454     string& subject,
455     string& authnContext,
456     string& tokens,
457     int& majorVersion,
458     int& minorVersion,
459     time_t& created,
460     time_t& accessed
461     )
462 {
463 #ifdef _DEBUG
464     xmltooling::NDC ndc("onRead");
465 #endif
466
467     log->debug("searching database...");
468
469     SQLHSTMT hstmt;
470     ODBCConn conn(getHDBC());
471     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
472
473     string q = string("SELECT application_id,ctime,atime,addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "'";
474     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
475     if (!SQL_SUCCEEDED(sr)) {
476         log->error("error searching for (%s)",key);
477         log_error(hstmt, SQL_HANDLE_STMT);
478         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
479         return E_FAIL;
480     }
481
482     SQLINTEGER major,minor;
483     SQL_TIMESTAMP_STRUCT atime,ctime;
484     SQLCHAR application_id[COLSIZE_APPLICATION_ID+1];
485     SQLCHAR addr[COLSIZE_ADDRESS+1];
486     SQLCHAR provider_id[COLSIZE_PROVIDER_ID+1];
487
488     // Bind simple output columns.
489     SQLBindCol(hstmt,1,SQL_C_CHAR,application_id,sizeof(application_id),NULL);
490     SQLBindCol(hstmt,2,SQL_C_TYPE_TIMESTAMP,&ctime,0,NULL);
491     SQLBindCol(hstmt,3,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
492     SQLBindCol(hstmt,4,SQL_C_CHAR,addr,sizeof(addr),NULL);
493     SQLBindCol(hstmt,5,SQL_C_SLONG,&major,0,NULL);
494     SQLBindCol(hstmt,6,SQL_C_SLONG,&minor,0,NULL);
495     SQLBindCol(hstmt,7,SQL_C_CHAR,provider_id,sizeof(provider_id),NULL);
496
497     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
498         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
499         return S_FALSE;
500     }
501
502     log->debug("session found, tranfering data back into memory");
503
504     // Copy back simple data.
505     applicationId = (char*)application_id;
506     clientAddress = (char*)addr;
507     majorVersion = major;
508     minorVersion = minor;
509     providerId = (char*)provider_id;
510
511     struct tm t;
512     t.tm_sec=ctime.second;
513     t.tm_min=ctime.minute;
514     t.tm_hour=ctime.hour;
515     t.tm_mday=ctime.day;
516     t.tm_mon=ctime.month-1;
517     t.tm_year=ctime.year-1900;
518     t.tm_isdst=0;
519 #if defined(HAVE_TIMEGM)
520     created=timegm(&t);
521 #else
522     // Windows, and hopefully most others...?
523     created = mktime(&t) - timezone;
524 #endif
525     t.tm_sec=atime.second;
526     t.tm_min=atime.minute;
527     t.tm_hour=atime.hour;
528     t.tm_mday=atime.day;
529     t.tm_mon=atime.month-1;
530     t.tm_year=atime.year-1900;
531     t.tm_isdst=0;
532 #if defined(HAVE_TIMEGM)
533     accessed=timegm(&t);
534 #else
535     // Windows, and hopefully most others...?
536     accessed = mktime(&t) - timezone;
537 #endif
538
539     // Extract text data.
540     string* ptrs[] = {&subject, &authnContext, &tokens};
541     HRESULT hr=NOERROR;
542     SQLINTEGER len;
543     SQLCHAR buf[LONGDATA_BUFLEN];
544     for (int i=0; i<3; i++) {
545         while ((sr=SQLGetData(hstmt,i+8,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
546             if (!SUCCEEDED(sr)) {
547                 log->error("error while reading text field from result set");
548                 log_error(hstmt, SQL_HANDLE_STMT);
549                 hr=E_FAIL;
550                 break;
551             }
552             ptrs[i]->append((char*)buf);
553         }
554     }
555
556     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
557     return hr;
558 }
559
560 HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
561 {
562 #ifdef _DEBUG
563     xmltooling::NDC ndc("onRead");
564 #endif
565
566     log->debug("reading last access time from database");
567
568     SQLHSTMT hstmt;
569     ODBCConn conn(getHDBC());
570     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
571     
572     string q = string("SELECT atime FROM state WHERE cookie='") + key + "'";
573     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
574     if (!SQL_SUCCEEDED(sr)) {
575         log->error("error searching for (%s)",key);
576         log_error(hstmt, SQL_HANDLE_STMT);
577         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
578         return E_FAIL;
579     }
580
581     SQL_TIMESTAMP_STRUCT atime;
582     SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
583
584     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
585         log->warn("session expected, but not found in database");
586         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
587         return S_FALSE;
588     }
589
590     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
591
592     struct tm t;
593     t.tm_sec=atime.second;
594     t.tm_min=atime.minute;
595     t.tm_hour=atime.hour;
596     t.tm_mday=atime.day;
597     t.tm_mon=atime.month-1;
598     t.tm_year=atime.year-1900;
599     t.tm_isdst=0;
600 #if defined(HAVE_TIMEGM)
601     accessed=timegm(&t);
602 #else
603     // Windows, and hopefully most others...?
604     accessed = mktime(&t) - timezone;
605 #endif
606     return NOERROR;
607 }
608
609 HRESULT ODBCCCache::onRead(const char* key, string& tokens)
610 {
611 #ifdef _DEBUG
612     xmltooling::NDC ndc("onRead");
613 #endif
614
615     if (!m_storeAttributes)
616         return S_FALSE;
617
618     log->debug("reading cached tokens from database");
619
620     SQLHSTMT hstmt;
621     ODBCConn conn(getHDBC());
622     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
623     
624     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "'";
625     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
626     if (!SQL_SUCCEEDED(sr)) {
627         log->error("error searching for (%s)",key);
628         log_error(hstmt, SQL_HANDLE_STMT);
629         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
630         return E_FAIL;
631     }
632
633     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
634         log->warn("session expected, but not found in database");
635         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
636         return S_FALSE;
637     }
638
639     HRESULT hr=NOERROR;
640     SQLINTEGER len;
641     SQLCHAR buf[LONGDATA_BUFLEN];
642     while ((sr=SQLGetData(hstmt,1,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
643         if (!SUCCEEDED(sr)) {
644             log->error("error while reading text field from result set");
645             log_error(hstmt, SQL_HANDLE_STMT);
646             hr=E_FAIL;
647             break;
648         }
649         tokens += (char*)buf;
650     }
651
652     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
653     return hr;
654 }
655
656 HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
657 {
658 #ifdef _DEBUG
659     xmltooling::NDC ndc("onUpdate");
660 #endif
661
662     ostringstream q;
663
664     if (lastAccess>0) {
665 #ifndef HAVE_GMTIME_R
666         struct tm* ptime=gmtime(&lastAccess);
667 #else
668         struct tm res;
669         struct tm* ptime=gmtime_r(&lastAccess,&res);
670 #endif
671         char timebuf[32];
672         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
673
674         q << "UPDATE state SET atime=" << timebuf << " WHERE cookie='" << key << "'";
675     }
676     else if (tokens) {
677         if (!m_storeAttributes)
678             return S_FALSE;
679         q << "UPDATE state SET tokens=";
680         if (*tokens) {
681             q << "'";
682             appendXML(q,tokens);
683             q << "' ";
684         }
685         else
686             q << "null ";
687         q << "WHERE cookie='" << key << "'";
688     }
689     else {
690         log->warn("onUpdate called with nothing to do!");
691         return S_FALSE;
692     }
693  
694     HRESULT hr=NOERROR;
695     SQLHSTMT hstmt;
696     ODBCConn conn(getHDBC());
697     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
698     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
699     if (sr==SQL_NO_DATA)
700         hr=S_FALSE;
701     else if (!SQL_SUCCEEDED(sr)) {
702         log->error("error updating record (key=%s)", key);
703         log_error(hstmt, SQL_HANDLE_STMT);
704         hr=E_FAIL;
705     }
706
707     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
708     return hr;
709 }
710
711 HRESULT ODBCCCache::onDelete(const char* key)
712 {
713 #ifdef _DEBUG
714     xmltooling::NDC ndc("onDelete");
715 #endif
716
717     SQLHSTMT hstmt;
718     ODBCConn conn(getHDBC());
719     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
720     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
721     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
722  
723     HRESULT hr=NOERROR;
724     if (sr==SQL_NO_DATA)
725         hr=S_FALSE;
726     else if (!SQL_SUCCEEDED(sr)) {
727         log->error("error deleting record (key=%s)", key);
728         log_error(hstmt, SQL_HANDLE_STMT);
729         hr=E_FAIL;
730     }
731
732     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
733     return hr;
734 }
735
736 void ODBCCCache::cleanup()
737 {
738 #ifdef _DEBUG
739     xmltooling::NDC ndc("cleanup");
740 #endif
741
742     Mutex* mutex = xmltooling::Mutex::create();
743
744     int rerun_timer = 0;
745     int timeout_life = 0;
746
747     // Load our configuration details...
748     const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
749     if (tag && *tag)
750         rerun_timer = XMLString::parseInt(tag);
751
752     // search for 'mysql-cache-timeout' and then the regular cache timeout
753     tag=m_root->getAttributeNS(NULL,odbcTimeout);
754     if (tag && *tag)
755         timeout_life = XMLString::parseInt(tag);
756     else {
757         tag=m_root->getAttributeNS(NULL,cacheTimeout);
758         if (tag && *tag)
759             timeout_life = XMLString::parseInt(tag);
760     }
761   
762     if (rerun_timer <= 0)
763         rerun_timer = 300;              // rerun every 5 minutes
764
765     if (timeout_life <= 0)
766         timeout_life = 28800;   // timeout after 8 hours
767     
768     mutex->lock();
769
770     log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
771
772     while (shutdown == false) {
773         shutdown_wait->timedwait(mutex, rerun_timer);
774
775         if (shutdown == true)
776             break;
777
778         // Find all the entries in the database that haven't been used
779         // recently In particular, find all entries that have not been
780         // accessed in 'timeout_life' seconds.
781
782         time_t stale=time(NULL)-timeout_life;
783 #ifndef HAVE_GMTIME_R
784         struct tm* ptime=gmtime(&stale);
785 #else
786         struct tm res;
787         struct tm* ptime=gmtime_r(&stale,&res);
788 #endif
789         char timebuf[32];
790         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
791
792         string q = string("DELETE state WHERE atime < ") +  timebuf;
793
794         SQLHSTMT hstmt;
795         ODBCConn conn(getHDBC());
796         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
797         SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
798         if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
799             log->error("error purging old records");
800             log_error(hstmt, SQL_HANDLE_STMT);
801         }
802
803         SQLINTEGER rowcount=0;
804         sr=SQLRowCount(hstmt,&rowcount);
805         if (SQL_SUCCEEDED(sr) && rowcount > 0)
806             log->info("purging %d old sessions",rowcount);
807
808         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
809      }
810
811     log->info("cleanup thread exiting...");
812
813     mutex->unlock();
814     delete mutex;
815     xmltooling::Thread::exit(NULL);
816 }
817
818 void* ODBCCCache::cleanup_fcn(void* cache_p)
819 {
820   ODBCCCache* cache = (ODBCCCache*)cache_p;
821
822 #ifndef WIN32
823   // First, let's block all signals
824   Thread::mask_all_signals();
825 #endif
826
827   // Now run the cleanup process.
828   cache->cleanup();
829   return NULL;
830 }
831
832
833 class ODBCReplayCache : public ODBCBase, virtual public IReplayCache
834 {
835 public:
836   ODBCReplayCache(const DOMElement* e);
837   virtual ~ODBCReplayCache() {}
838
839   bool check(const XMLCh* str, time_t expires) {xmltooling::auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
840   bool check(const char* str, time_t expires);
841 };
842
843 ODBCReplayCache::ODBCReplayCache(const DOMElement* e) : ODBCBase(e)
844 {
845 #ifdef _DEBUG
846     saml::NDC ndc("ODBCReplayCache");
847 #endif
848     log = &(Category::getInstance("shibtarget.ReplayCache.ODBC"));
849 }
850
851 bool ODBCReplayCache::check(const char* str, time_t expires)
852 {
853 #ifdef _DEBUG
854     saml::NDC ndc("check");
855 #endif
856   
857     time_t now=time(NULL);
858 #ifndef HAVE_GMTIME_R
859     struct tm* ptime=gmtime(&now);
860 #else
861     struct tm res;
862     struct tm* ptime=gmtime_r(&now,&res);
863 #endif
864     char timebuf[32];
865     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
866
867     // Remove expired entries.
868     SQLHSTMT hstmt;
869     ODBCConn conn(getHDBC());
870     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
871     string q = string("DELETE FROM replay WHERE expires < ") + timebuf;
872     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
873     if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
874         log->error("error purging old replay cache entries");
875         log_error(hstmt, SQL_HANDLE_STMT);
876     }
877     SQLCloseCursor(hstmt);
878   
879     // Look for a replay.
880     q = string("SELECT id FROM replay WHERE id='") + str + "'";
881     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
882     if (!SQL_SUCCEEDED(sr)) {
883         log->error("error searching replay cache");
884         log_error(hstmt, SQL_HANDLE_STMT);
885         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
886         throw SAMLException("Replay cache failed, please inform application support staff.");
887     }
888
889     // If we got a record, we return false.
890     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
891         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
892         return false;
893     }
894     SQLCloseCursor(hstmt);
895     
896 #ifndef HAVE_GMTIME_R
897     ptime=gmtime(&expires);
898 #else
899     ptime=gmtime_r(&expires,&res);
900 #endif
901     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
902
903     // Add it to the database.
904     q = string("INSERT replay VALUES('") + str + "'," + timebuf + ")";
905     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
906     if (!SQL_SUCCEEDED(sr)) {
907         log->error("error inserting replay cache entry", str);
908         log_error(hstmt, SQL_HANDLE_STMT);
909         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
910         throw SAMLException("Replay cache failed, please inform application support staff.");
911     }
912
913     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
914     return true;
915 }
916
917
918 // Factories
919
920 SessionCache* new_odbc_ccache(const DOMElement* const & e)
921 {
922     return new ODBCCCache(e);
923 }
924
925 IPlugIn* new_odbc_replay(const DOMElement* e)
926 {
927     return new ODBCReplayCache(e);
928 }
929
930
931 extern "C" int SHIBODBC_EXPORTS saml_extension_init(void*)
932 {
933     // register this ccache type
934     SAMLConfig::getConfig().getPlugMgr().regFactory(ODBC_REPLAYCACHE, &new_odbc_replay);
935     SPConfig::getConfig().SessionCacheManager.registerFactory(ODBC_SESSIONCACHE, &new_odbc_ccache);
936     return 0;
937 }
938
939 extern "C" void SHIBODBC_EXPORTS saml_extension_term()
940 {
941     SAMLConfig::getConfig().getPlugMgr().unregFactory(ODBC_REPLAYCACHE);
942 }