350bb27bdd469e7ecefb0e14e464af88f58c98cc
[shibboleth/cpp-sp.git] / odbc_ccache / odbc-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * odbc-ccache.cpp - Shibboleth Credential Cache using ODBC
19  *
20  * Scott Cantor <cantor.2@osu.edu>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define _CRT_NONSTDC_NO_DEPRECATE 1
34 # define _CRT_SECURE_NO_DEPRECATE 1
35 # define NOMINMAX
36 # define SHIBODBC_EXPORTS __declspec(dllexport)
37 #else
38 # define SHIBODBC_EXPORTS
39 #endif
40
41 #include <shib-target/shib-target.h>
42 #include <log4cpp/Category.hh>
43 #include <xmltooling/util/NDC.h>
44
45 #include <ctime>
46 #include <algorithm>
47 #include <sstream>
48
49 #include <sql.h>
50 #include <sqlext.h>
51
52 #ifdef HAVE_LIBDMALLOCXX
53 #include <dmalloc.h>
54 #endif
55
56 using namespace shibtarget;
57 using namespace shibboleth;
58 using namespace saml;
59 using namespace xmltooling;
60 using namespace log4cpp;
61 using namespace std;
62
63 #define PLUGIN_VER_MAJOR 3
64 #define PLUGIN_VER_MINOR 0
65
66 #define COLSIZE_KEY 64
67 #define COLSIZE_APPLICATION_ID 256
68 #define COLSIZE_ADDRESS 128
69 #define COLSIZE_PROVIDER_ID 256
70 #define LONGDATA_BUFLEN 32768
71
72 /*
73   CREATE TABLE state (
74       cookie VARCHAR(64) PRIMARY KEY,
75       application_id VARCHAR(256),
76       ctime TIMESTAMP,
77       atime TIMESTAMP,
78       addr VARCHAR(128),
79       major INT,
80       minor INT,
81       provider VARCHAR(256),
82       subject TEXT,
83       authn_context TEXT,
84       tokens TEXT
85       )
86 */
87
88 #define REPLAY_TABLE \
89   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
90   "expires TIMESTAMP, " \
91   "INDEX (expires))"
92
93 static const XMLCh ConnectionString[] =
94 { chLatin_C, chLatin_o, chLatin_n, chLatin_n, chLatin_e, chLatin_c, chLatin_t, chLatin_i, chLatin_o, chLatin_n,
95   chLatin_S, chLatin_t, chLatin_r, chLatin_i, chLatin_n, chLatin_g, chNull
96 };
97 static const XMLCh cleanupInterval[] =
98 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
99   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
100 };
101 static const XMLCh cacheTimeout[] =
102 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
103 static const XMLCh odbcTimeout[] =
104 { chLatin_o, chLatin_d, chLatin_b, chLatin_c, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
105 static const XMLCh storeAttributes[] =
106 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
107
108 struct ODBCConn {
109     ODBCConn(SQLHDBC conn) : handle(conn) {}
110     ~ODBCConn() {SQLFreeHandle(SQL_HANDLE_DBC,handle);}
111     operator SQLHDBC() {return handle;}
112     SQLHDBC handle;
113 };
114
115 class ODBCBase : public virtual saml::IPlugIn
116 {
117 public:
118     ODBCBase(const DOMElement* e);
119     virtual ~ODBCBase();
120
121     SQLHDBC getHDBC();
122
123     Category* log;
124
125 protected:
126     const DOMElement* m_root; // can only use this during initialization
127     string m_connstring;
128
129     static SQLHENV m_henv;          // single handle for both plugins
130     bool m_bInitializedODBC;        // tracks which class handled the process
131     static const char* p_connstring;
132
133     pair<int,int> getVersion(SQLHDBC);
134     void log_error(SQLHANDLE handle, SQLSMALLINT htype);
135 };
136
137 SQLHENV ODBCBase::m_henv = SQL_NULL_HANDLE;
138 const char* ODBCBase::p_connstring = NULL;
139
140 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
141 {
142 #ifdef _DEBUG
143     xmltooling::NDC ndc("ODBCBase");
144 #endif
145     log = &(Category::getInstance("shibtarget.ODBC"));
146
147     if (m_henv == SQL_NULL_HANDLE) {
148         // Enable connection pooling.
149         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
150
151         // Allocate the environment.
152         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
153             throw ConfigurationException("ODBC failed to initialize.");
154
155         // Specify ODBC 3.x
156         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
157
158         log->info("ODBC initialized");
159         m_bInitializedODBC = true;
160     }
161
162     // Grab connection string from the configuration.
163     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,ConnectionString);
164     if (!e || !e->hasChildNodes()) {
165         if (!p_connstring) {
166             this->~ODBCBase();
167             throw ConfigurationException("ODBC cache requires ConnectionString element in configuration.");
168         }
169         m_connstring=p_connstring;
170     }
171     else {
172         xmltooling::auto_ptr_char arg(e->getFirstChild()->getNodeValue());
173         m_connstring=arg.get();
174         p_connstring=m_connstring.c_str();
175     }
176
177     // Connect and check version.
178     SQLHDBC conn=getHDBC();
179     pair<int,int> v=getVersion(conn);
180     SQLFreeHandle(SQL_HANDLE_DBC,conn);
181
182     // Make sure we've got the right version.
183     if (v.first != PLUGIN_VER_MAJOR) {
184         this->~ODBCBase();
185         log->crit("unknown database version: %d.%d", v.first, v.second);
186         throw SAMLException("Unknown cache database version.");
187     }
188 }
189
190 ODBCBase::~ODBCBase()
191 {
192     //delete m_mysql;
193     if (m_bInitializedODBC)
194         SQLFreeHandle(SQL_HANDLE_ENV,m_henv);
195     m_bInitializedODBC=false;
196     m_henv = SQL_NULL_HANDLE;
197     p_connstring=NULL;
198 }
199
200 void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
201 {
202     SQLSMALLINT  i = 0;
203     SQLINTEGER   native;
204     SQLCHAR      state[7];
205     SQLCHAR      text[256];
206     SQLSMALLINT  len;
207     SQLRETURN    ret;
208
209     do {
210         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
211         if (SQL_SUCCEEDED(ret))
212             log->error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
213     } while(SQL_SUCCEEDED(ret));
214 }
215
216 SQLHDBC ODBCBase::getHDBC()
217 {
218 #ifdef _DEBUG
219     xmltooling::NDC ndc("getMYSQL");
220 #endif
221
222     // Get a handle.
223     SQLHDBC handle;
224     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
225     if (!SQL_SUCCEEDED(sr)) {
226         log->error("failed to allocate connection handle");
227         log_error(m_henv, SQL_HANDLE_ENV);
228         throw SAMLException("ODBCBase::getHDBC failed to allocate connection handle");
229     }
230
231     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
232     if (!SQL_SUCCEEDED(sr)) {
233         log->error("failed to connect to database");
234         log_error(handle, SQL_HANDLE_DBC);
235         throw SAMLException("ODBCBase::getHDBC failed to connect to database");
236     }
237
238     return handle;
239 }
240
241 pair<int,int> ODBCBase::getVersion(SQLHDBC conn)
242 {
243     // Grab the version number from the database.
244     SQLHSTMT hstmt;
245     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
246     
247     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
248     if (!SQL_SUCCEEDED(sr)) {
249         log->error("failed to read version from database");
250         log_error(hstmt, SQL_HANDLE_STMT);
251         throw SAMLException("ODBCBase::getVersion failed to read version from database");
252     }
253
254     SQLINTEGER major;
255     SQLINTEGER minor;
256     SQLBindCol(hstmt,1,SQL_C_SLONG,&major,0,NULL);
257     SQLBindCol(hstmt,2,SQL_C_SLONG,&minor,0,NULL);
258
259     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
260         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
261         return pair<int,int>(major,minor);
262     }
263
264     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
265     log->error("no rows returned in version query");
266     throw SAMLException("ODBCBase::getVersion failed to read version from database");
267 }
268
269 class ODBCCCache : public ODBCBase, virtual public ISessionCache, virtual public ISessionCacheStore
270 {
271 public:
272     ODBCCCache(const DOMElement* e);
273     virtual ~ODBCCCache();
274
275     // Delegate all the ISessionCache methods.
276     string insert(
277         const IApplication* application,
278         const IEntityDescriptor* source,
279         const char* client_addr,
280         const SAMLSubject* subject,
281         const char* authnContext,
282         const SAMLResponse* tokens
283         )
284     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
285     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
286     { return m_cache->find(key,application,client_addr); }
287     void remove(const char* key, const IApplication* application, const char* client_addr)
288     { m_cache->remove(key,application,client_addr); }
289
290     bool setBackingStore(ISessionCacheStore*) {return false;}
291
292     // Store methods handle the database work
293     HRESULT onCreate(
294         const char* key,
295         const IApplication* application,
296         const ISessionCacheEntry* entry,
297         int majorVersion,
298         int minorVersion,
299         time_t created
300         );
301     HRESULT onRead(
302         const char* key,
303         string& applicationId,
304         string& clientAddress,
305         string& providerId,
306         string& subject,
307         string& authnContext,
308         string& tokens,
309         int& majorVersion,
310         int& minorVersion,
311         time_t& created,
312         time_t& accessed
313         );
314     HRESULT onRead(const char* key, time_t& accessed);
315     HRESULT onRead(const char* key, string& tokens);
316     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
317     HRESULT onDelete(const char* key);
318
319     void cleanup();
320
321 private:
322     bool m_storeAttributes;
323     ISessionCache* m_cache;
324     xmltooling::CondWait* shutdown_wait;
325     bool shutdown;
326     xmltooling::Thread* cleanup_thread;
327
328     static void* cleanup_fcn(void*); // XXX Assumed an ODBCCCache
329 };
330
331 ODBCCCache::ODBCCCache(const DOMElement* e) : ODBCBase(e), m_storeAttributes(false)
332 {
333 #ifdef _DEBUG
334     xmltooling::NDC ndc("ODBCCCache");
335 #endif
336     log = &(Category::getInstance("shibtarget.SessionCache.ODBC"));
337
338     m_cache = dynamic_cast<ISessionCache*>(
339         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, m_root)
340     );
341     if (!m_cache->setBackingStore(this)) {
342         delete m_cache;
343         throw SAMLException("Unable to register ODBC cache plugin as a cache store.");
344     }
345     
346     shutdown_wait = CondWait::create();
347     shutdown = false;
348
349     // Load our configuration details...
350     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
351     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
352         m_storeAttributes=true;
353
354     // Initialize the cleanup thread
355     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
356 }
357
358 ODBCCCache::~ODBCCCache()
359 {
360     shutdown = true;
361     shutdown_wait->signal();
362     cleanup_thread->join(NULL);
363     delete m_cache;
364 }
365
366 void appendXML(ostream& os, const char* str)
367 {
368     const char* pos=strchr(str,'\'');
369     while (pos) {
370         os.write(str,pos-str);
371         os << "''";
372         str=pos+1;
373         pos=strchr(str,'\'');
374     }
375     os << str;
376 }
377
378 HRESULT ODBCCCache::onCreate(
379     const char* key,
380     const IApplication* application,
381     const ISessionCacheEntry* entry,
382     int majorVersion,
383     int minorVersion,
384     time_t created
385     )
386 {
387 #ifdef _DEBUG
388     xmltooling::NDC ndc("onCreate");
389 #endif
390
391     // Get XML data from entry. Default is not to return SAML objects.
392     const char* context=entry->getAuthnContext();
393     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
394     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
395
396     // Stringify timestamp.
397     if (created==0)
398         created=time(NULL);
399 #ifndef HAVE_GMTIME_R
400     struct tm* ptime=gmtime(&created);
401 #else
402     struct tm res;
403     struct tm* ptime=gmtime_r(&created,&res);
404 #endif
405     char timebuf[32];
406     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
407
408     // Prepare insert statement.
409     ostringstream q;
410     q << "INSERT state VALUES ('" << key << "','" << application->getId() << "'," << timebuf << "," << timebuf
411         << ",'" << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId()
412         << "','";
413     appendXML(q,subject.first);
414     q << "','";
415     appendXML(q,context);
416     q << "',";
417     if (m_storeAttributes && tokens.first) {
418         q << "'";
419         appendXML(q,tokens.first);
420         q << "')";
421     }
422     else
423         q << "null)";
424     if (log->isDebugEnabled())
425         log->debug("SQL insert: %s", q.str().c_str());
426
427     // Get statement handle.
428     SQLHSTMT hstmt;
429     ODBCConn conn(getHDBC());
430     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
431
432     // Execute statement.
433     HRESULT hr=NOERROR;
434     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
435     if (!SQL_SUCCEEDED(sr)) {
436         log->error("failed to insert record into database");
437         log_error(hstmt, SQL_HANDLE_STMT);
438         hr=E_FAIL;
439     }
440
441     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
442     return hr;
443 }
444
445 HRESULT ODBCCCache::onRead(
446     const char* key,
447     string& applicationId,
448     string& clientAddress,
449     string& providerId,
450     string& subject,
451     string& authnContext,
452     string& tokens,
453     int& majorVersion,
454     int& minorVersion,
455     time_t& created,
456     time_t& accessed
457     )
458 {
459 #ifdef _DEBUG
460     xmltooling::NDC ndc("onRead");
461 #endif
462
463     log->debug("searching database...");
464
465     SQLHSTMT hstmt;
466     ODBCConn conn(getHDBC());
467     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
468
469     string q = string("SELECT application_id,ctime,atime,addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "'";
470     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
471     if (!SQL_SUCCEEDED(sr)) {
472         log->error("error searching for (%s)",key);
473         log_error(hstmt, SQL_HANDLE_STMT);
474         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
475         return E_FAIL;
476     }
477
478     SQLINTEGER major,minor;
479     SQL_TIMESTAMP_STRUCT atime,ctime;
480     SQLCHAR application_id[COLSIZE_APPLICATION_ID+1];
481     SQLCHAR addr[COLSIZE_ADDRESS+1];
482     SQLCHAR provider_id[COLSIZE_PROVIDER_ID+1];
483
484     // Bind simple output columns.
485     SQLBindCol(hstmt,1,SQL_C_CHAR,application_id,sizeof(application_id),NULL);
486     SQLBindCol(hstmt,2,SQL_C_TYPE_TIMESTAMP,&ctime,0,NULL);
487     SQLBindCol(hstmt,3,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
488     SQLBindCol(hstmt,4,SQL_C_CHAR,addr,sizeof(addr),NULL);
489     SQLBindCol(hstmt,5,SQL_C_SLONG,&major,0,NULL);
490     SQLBindCol(hstmt,6,SQL_C_SLONG,&minor,0,NULL);
491     SQLBindCol(hstmt,7,SQL_C_CHAR,provider_id,sizeof(provider_id),NULL);
492
493     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
494         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
495         return S_FALSE;
496     }
497
498     log->debug("session found, tranfering data back into memory");
499
500     // Copy back simple data.
501     applicationId = (char*)application_id;
502     clientAddress = (char*)addr;
503     majorVersion = major;
504     minorVersion = minor;
505     providerId = (char*)provider_id;
506
507     struct tm t;
508     t.tm_sec=ctime.second;
509     t.tm_min=ctime.minute;
510     t.tm_hour=ctime.hour;
511     t.tm_mday=ctime.day;
512     t.tm_mon=ctime.month-1;
513     t.tm_year=ctime.year-1900;
514     t.tm_isdst=0;
515 #if defined(HAVE_TIMEGM)
516     created=timegm(&t);
517 #else
518     // Windows, and hopefully most others...?
519     created = mktime(&t) - timezone;
520 #endif
521     t.tm_sec=atime.second;
522     t.tm_min=atime.minute;
523     t.tm_hour=atime.hour;
524     t.tm_mday=atime.day;
525     t.tm_mon=atime.month-1;
526     t.tm_year=atime.year-1900;
527     t.tm_isdst=0;
528 #if defined(HAVE_TIMEGM)
529     accessed=timegm(&t);
530 #else
531     // Windows, and hopefully most others...?
532     accessed = mktime(&t) - timezone;
533 #endif
534
535     // Extract text data.
536     string* ptrs[] = {&subject, &authnContext, &tokens};
537     HRESULT hr=NOERROR;
538     SQLINTEGER len;
539     SQLCHAR buf[LONGDATA_BUFLEN];
540     for (int i=0; i<3; i++) {
541         while ((sr=SQLGetData(hstmt,i+8,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
542             if (!SUCCEEDED(sr)) {
543                 log->error("error while reading text field from result set");
544                 log_error(hstmt, SQL_HANDLE_STMT);
545                 hr=E_FAIL;
546                 break;
547             }
548             ptrs[i]->append((char*)buf);
549         }
550     }
551
552     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
553     return hr;
554 }
555
556 HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
557 {
558 #ifdef _DEBUG
559     xmltooling::NDC ndc("onRead");
560 #endif
561
562     log->debug("reading last access time from database");
563
564     SQLHSTMT hstmt;
565     ODBCConn conn(getHDBC());
566     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
567     
568     string q = string("SELECT atime FROM state WHERE cookie='") + key + "'";
569     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
570     if (!SQL_SUCCEEDED(sr)) {
571         log->error("error searching for (%s)",key);
572         log_error(hstmt, SQL_HANDLE_STMT);
573         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
574         return E_FAIL;
575     }
576
577     SQL_TIMESTAMP_STRUCT atime;
578     SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
579
580     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
581         log->warn("session expected, but not found in database");
582         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
583         return S_FALSE;
584     }
585
586     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
587
588     struct tm t;
589     t.tm_sec=atime.second;
590     t.tm_min=atime.minute;
591     t.tm_hour=atime.hour;
592     t.tm_mday=atime.day;
593     t.tm_mon=atime.month-1;
594     t.tm_year=atime.year-1900;
595     t.tm_isdst=0;
596 #if defined(HAVE_TIMEGM)
597     accessed=timegm(&t);
598 #else
599     // Windows, and hopefully most others...?
600     accessed = mktime(&t) - timezone;
601 #endif
602     return NOERROR;
603 }
604
605 HRESULT ODBCCCache::onRead(const char* key, string& tokens)
606 {
607 #ifdef _DEBUG
608     xmltooling::NDC ndc("onRead");
609 #endif
610
611     if (!m_storeAttributes)
612         return S_FALSE;
613
614     log->debug("reading cached tokens from database");
615
616     SQLHSTMT hstmt;
617     ODBCConn conn(getHDBC());
618     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
619     
620     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "'";
621     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
622     if (!SQL_SUCCEEDED(sr)) {
623         log->error("error searching for (%s)",key);
624         log_error(hstmt, SQL_HANDLE_STMT);
625         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
626         return E_FAIL;
627     }
628
629     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
630         log->warn("session expected, but not found in database");
631         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
632         return S_FALSE;
633     }
634
635     HRESULT hr=NOERROR;
636     SQLINTEGER len;
637     SQLCHAR buf[LONGDATA_BUFLEN];
638     while ((sr=SQLGetData(hstmt,1,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
639         if (!SUCCEEDED(sr)) {
640             log->error("error while reading text field from result set");
641             log_error(hstmt, SQL_HANDLE_STMT);
642             hr=E_FAIL;
643             break;
644         }
645         tokens += (char*)buf;
646     }
647
648     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
649     return hr;
650 }
651
652 HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
653 {
654 #ifdef _DEBUG
655     xmltooling::NDC ndc("onUpdate");
656 #endif
657
658     ostringstream q;
659
660     if (lastAccess>0) {
661 #ifndef HAVE_GMTIME_R
662         struct tm* ptime=gmtime(&lastAccess);
663 #else
664         struct tm res;
665         struct tm* ptime=gmtime_r(&lastAccess,&res);
666 #endif
667         char timebuf[32];
668         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
669
670         q << "UPDATE state SET atime=" << timebuf << " WHERE cookie='" << key << "'";
671     }
672     else if (tokens) {
673         if (!m_storeAttributes)
674             return S_FALSE;
675         q << "UPDATE state SET tokens=";
676         if (*tokens) {
677             q << "'";
678             appendXML(q,tokens);
679             q << "' ";
680         }
681         else
682             q << "null ";
683         q << "WHERE cookie='" << key << "'";
684     }
685     else {
686         log->warn("onUpdate called with nothing to do!");
687         return S_FALSE;
688     }
689  
690     HRESULT hr=NOERROR;
691     SQLHSTMT hstmt;
692     ODBCConn conn(getHDBC());
693     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
694     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
695     if (sr==SQL_NO_DATA)
696         hr=S_FALSE;
697     else if (!SQL_SUCCEEDED(sr)) {
698         log->error("error updating record (key=%s)", key);
699         log_error(hstmt, SQL_HANDLE_STMT);
700         hr=E_FAIL;
701     }
702
703     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
704     return hr;
705 }
706
707 HRESULT ODBCCCache::onDelete(const char* key)
708 {
709 #ifdef _DEBUG
710     xmltooling::NDC ndc("onDelete");
711 #endif
712
713     SQLHSTMT hstmt;
714     ODBCConn conn(getHDBC());
715     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
716     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
717     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
718  
719     HRESULT hr=NOERROR;
720     if (sr==SQL_NO_DATA)
721         hr=S_FALSE;
722     else if (!SQL_SUCCEEDED(sr)) {
723         log->error("error deleting record (key=%s)", key);
724         log_error(hstmt, SQL_HANDLE_STMT);
725         hr=E_FAIL;
726     }
727
728     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
729     return hr;
730 }
731
732 void ODBCCCache::cleanup()
733 {
734 #ifdef _DEBUG
735     xmltooling::NDC ndc("cleanup");
736 #endif
737
738     Mutex* mutex = xmltooling::Mutex::create();
739
740     int rerun_timer = 0;
741     int timeout_life = 0;
742
743     // Load our configuration details...
744     const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
745     if (tag && *tag)
746         rerun_timer = XMLString::parseInt(tag);
747
748     // search for 'mysql-cache-timeout' and then the regular cache timeout
749     tag=m_root->getAttributeNS(NULL,odbcTimeout);
750     if (tag && *tag)
751         timeout_life = XMLString::parseInt(tag);
752     else {
753         tag=m_root->getAttributeNS(NULL,cacheTimeout);
754         if (tag && *tag)
755             timeout_life = XMLString::parseInt(tag);
756     }
757   
758     if (rerun_timer <= 0)
759         rerun_timer = 300;              // rerun every 5 minutes
760
761     if (timeout_life <= 0)
762         timeout_life = 28800;   // timeout after 8 hours
763     
764     mutex->lock();
765
766     log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
767
768     while (shutdown == false) {
769         shutdown_wait->timedwait(mutex, rerun_timer);
770
771         if (shutdown == true)
772             break;
773
774         // Find all the entries in the database that haven't been used
775         // recently In particular, find all entries that have not been
776         // accessed in 'timeout_life' seconds.
777
778         time_t stale=time(NULL)-timeout_life;
779 #ifndef HAVE_GMTIME_R
780         struct tm* ptime=gmtime(&stale);
781 #else
782         struct tm res;
783         struct tm* ptime=gmtime_r(&stale,&res);
784 #endif
785         char timebuf[32];
786         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
787
788         string q = string("DELETE state WHERE atime < ") +  timebuf;
789
790         SQLHSTMT hstmt;
791         ODBCConn conn(getHDBC());
792         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
793         SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
794         if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
795             log->error("error purging old records");
796             log_error(hstmt, SQL_HANDLE_STMT);
797         }
798
799         SQLINTEGER rowcount=0;
800         sr=SQLRowCount(hstmt,&rowcount);
801         if (SQL_SUCCEEDED(sr) && rowcount > 0)
802             log->info("purging %d old sessions",rowcount);
803
804         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
805      }
806
807     log->info("cleanup thread exiting...");
808
809     mutex->unlock();
810     delete mutex;
811     xmltooling::Thread::exit(NULL);
812 }
813
814 void* ODBCCCache::cleanup_fcn(void* cache_p)
815 {
816   ODBCCCache* cache = (ODBCCCache*)cache_p;
817
818 #ifndef WIN32
819   // First, let's block all signals
820   Thread::mask_all_signals();
821 #endif
822
823   // Now run the cleanup process.
824   cache->cleanup();
825   return NULL;
826 }
827
828
829 class ODBCReplayCache : public ODBCBase, virtual public IReplayCache
830 {
831 public:
832   ODBCReplayCache(const DOMElement* e);
833   virtual ~ODBCReplayCache() {}
834
835   bool check(const XMLCh* str, time_t expires) {xmltooling::auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
836   bool check(const char* str, time_t expires);
837 };
838
839 ODBCReplayCache::ODBCReplayCache(const DOMElement* e) : ODBCBase(e)
840 {
841 #ifdef _DEBUG
842     saml::NDC ndc("ODBCReplayCache");
843 #endif
844     log = &(Category::getInstance("shibtarget.ReplayCache.ODBC"));
845 }
846
847 bool ODBCReplayCache::check(const char* str, time_t expires)
848 {
849 #ifdef _DEBUG
850     saml::NDC ndc("check");
851 #endif
852   
853     time_t now=time(NULL);
854 #ifndef HAVE_GMTIME_R
855     struct tm* ptime=gmtime(&now);
856 #else
857     struct tm res;
858     struct tm* ptime=gmtime_r(&now,&res);
859 #endif
860     char timebuf[32];
861     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
862
863     // Remove expired entries.
864     SQLHSTMT hstmt;
865     ODBCConn conn(getHDBC());
866     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
867     string q = string("DELETE FROM replay WHERE expires < ") + timebuf;
868     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
869     if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
870         log->error("error purging old replay cache entries");
871         log_error(hstmt, SQL_HANDLE_STMT);
872     }
873     SQLCloseCursor(hstmt);
874   
875     // Look for a replay.
876     q = string("SELECT id FROM replay WHERE id='") + str + "'";
877     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
878     if (!SQL_SUCCEEDED(sr)) {
879         log->error("error searching replay cache");
880         log_error(hstmt, SQL_HANDLE_STMT);
881         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
882         throw SAMLException("Replay cache failed, please inform application support staff.");
883     }
884
885     // If we got a record, we return false.
886     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
887         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
888         return false;
889     }
890     SQLCloseCursor(hstmt);
891     
892 #ifndef HAVE_GMTIME_R
893     ptime=gmtime(&expires);
894 #else
895     ptime=gmtime_r(&expires,&res);
896 #endif
897     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
898
899     // Add it to the database.
900     q = string("INSERT replay VALUES('") + str + "'," + timebuf + ")";
901     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
902     if (!SQL_SUCCEEDED(sr)) {
903         log->error("error inserting replay cache entry", str);
904         log_error(hstmt, SQL_HANDLE_STMT);
905         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
906         throw SAMLException("Replay cache failed, please inform application support staff.");
907     }
908
909     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
910     return true;
911 }
912
913
914 // Factories
915
916 IPlugIn* new_odbc_ccache(const DOMElement* e)
917 {
918     return new ODBCCCache(e);
919 }
920
921 IPlugIn* new_odbc_replay(const DOMElement* e)
922 {
923     return new ODBCReplayCache(e);
924 }
925
926
927 extern "C" int SHIBODBC_EXPORTS saml_extension_init(void*)
928 {
929     // register this ccache type
930     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::ODBCReplayCacheType, &new_odbc_replay);
931     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::ODBCSessionCacheType, &new_odbc_ccache);
932     return 0;
933 }
934
935 extern "C" void SHIBODBC_EXPORTS saml_extension_term()
936 {
937     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::ODBCSessionCacheType);
938     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::ODBCReplayCacheType);
939 }