Old config class ported, all config files now loading with new parser.
[shibboleth/cpp-sp.git] / odbc_ccache / odbc-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * odbc-ccache.cpp - Shibboleth Credential Cache using ODBC
19  *
20  * Scott Cantor <cantor.2@osu.edu>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define _CRT_NONSTDC_NO_DEPRECATE 1
34 # define _CRT_SECURE_NO_DEPRECATE 1
35 # define NOMINMAX
36 # define SHIBODBC_EXPORTS __declspec(dllexport)
37 #else
38 # define SHIBODBC_EXPORTS
39 #endif
40
41 #include <shib-target/shib-target.h>
42 #include <shibsp/exceptions.h>
43 #include <log4cpp/Category.hh>
44 #include <xmltooling/util/NDC.h>
45 #include <xmltooling/util/Threads.h>
46
47 #include <ctime>
48 #include <algorithm>
49 #include <sstream>
50
51 #include <sql.h>
52 #include <sqlext.h>
53
54 #ifdef HAVE_LIBDMALLOCXX
55 #include <dmalloc.h>
56 #endif
57
58 using namespace shibsp;
59 using namespace shibtarget;
60 using namespace opensaml::saml2md;
61 using namespace saml;
62 using namespace xmltooling;
63 using namespace log4cpp;
64 using namespace std;
65
66 #define PLUGIN_VER_MAJOR 3
67 #define PLUGIN_VER_MINOR 0
68
69 #define COLSIZE_KEY 64
70 #define COLSIZE_APPLICATION_ID 256
71 #define COLSIZE_ADDRESS 128
72 #define COLSIZE_PROVIDER_ID 256
73 #define LONGDATA_BUFLEN 32768
74
75 /*
76   CREATE TABLE state (
77       cookie VARCHAR(64) PRIMARY KEY,
78       application_id VARCHAR(256),
79       ctime TIMESTAMP,
80       atime TIMESTAMP,
81       addr VARCHAR(128),
82       major INT,
83       minor INT,
84       provider VARCHAR(256),
85       subject TEXT,
86       authn_context TEXT,
87       tokens TEXT
88       )
89 */
90
91 #define REPLAY_TABLE \
92   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
93   "expires TIMESTAMP, " \
94   "INDEX (expires))"
95
96 static const XMLCh ConnectionString[] =
97 { chLatin_C, chLatin_o, chLatin_n, chLatin_n, chLatin_e, chLatin_c, chLatin_t, chLatin_i, chLatin_o, chLatin_n,
98   chLatin_S, chLatin_t, chLatin_r, chLatin_i, chLatin_n, chLatin_g, chNull
99 };
100 static const XMLCh cleanupInterval[] =
101 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
102   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
103 };
104 static const XMLCh cacheTimeout[] =
105 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
106 static const XMLCh odbcTimeout[] =
107 { chLatin_o, chLatin_d, chLatin_b, chLatin_c, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
108 static const XMLCh storeAttributes[] =
109 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
110
111 struct ODBCConn {
112     ODBCConn(SQLHDBC conn) : handle(conn) {}
113     ~ODBCConn() {SQLFreeHandle(SQL_HANDLE_DBC,handle);}
114     operator SQLHDBC() {return handle;}
115     SQLHDBC handle;
116 };
117
118 class ODBCBase : public virtual saml::IPlugIn
119 {
120 public:
121     ODBCBase(const DOMElement* e);
122     virtual ~ODBCBase();
123
124     SQLHDBC getHDBC();
125
126     Category* log;
127
128 protected:
129     const DOMElement* m_root; // can only use this during initialization
130     string m_connstring;
131
132     static SQLHENV m_henv;          // single handle for both plugins
133     bool m_bInitializedODBC;        // tracks which class handled the process
134     static const char* p_connstring;
135
136     pair<int,int> getVersion(SQLHDBC);
137     void log_error(SQLHANDLE handle, SQLSMALLINT htype);
138 };
139
140 SQLHENV ODBCBase::m_henv = SQL_NULL_HANDLE;
141 const char* ODBCBase::p_connstring = NULL;
142
143 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
144 {
145 #ifdef _DEBUG
146     xmltooling::NDC ndc("ODBCBase");
147 #endif
148     log = &(Category::getInstance("shibtarget.ODBC"));
149
150     if (m_henv == SQL_NULL_HANDLE) {
151         // Enable connection pooling.
152         SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
153
154         // Allocate the environment.
155         if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
156             throw ConfigurationException("ODBC failed to initialize.");
157
158         // Specify ODBC 3.x
159         SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
160
161         log->info("ODBC initialized");
162         m_bInitializedODBC = true;
163     }
164
165     // Grab connection string from the configuration.
166     e=XMLHelper::getFirstChildElement(e,ConnectionString);
167     if (!e || !e->hasChildNodes()) {
168         if (!p_connstring) {
169             this->~ODBCBase();
170             throw ConfigurationException("ODBC cache requires ConnectionString element in configuration.");
171         }
172         m_connstring=p_connstring;
173     }
174     else {
175         xmltooling::auto_ptr_char arg(e->getFirstChild()->getNodeValue());
176         m_connstring=arg.get();
177         p_connstring=m_connstring.c_str();
178     }
179
180     // Connect and check version.
181     SQLHDBC conn=getHDBC();
182     pair<int,int> v=getVersion(conn);
183     SQLFreeHandle(SQL_HANDLE_DBC,conn);
184
185     // Make sure we've got the right version.
186     if (v.first != PLUGIN_VER_MAJOR) {
187         this->~ODBCBase();
188         log->crit("unknown database version: %d.%d", v.first, v.second);
189         throw SAMLException("Unknown cache database version.");
190     }
191 }
192
193 ODBCBase::~ODBCBase()
194 {
195     //delete m_mysql;
196     if (m_bInitializedODBC)
197         SQLFreeHandle(SQL_HANDLE_ENV,m_henv);
198     m_bInitializedODBC=false;
199     m_henv = SQL_NULL_HANDLE;
200     p_connstring=NULL;
201 }
202
203 void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
204 {
205     SQLSMALLINT  i = 0;
206     SQLINTEGER   native;
207     SQLCHAR      state[7];
208     SQLCHAR      text[256];
209     SQLSMALLINT  len;
210     SQLRETURN    ret;
211
212     do {
213         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
214         if (SQL_SUCCEEDED(ret))
215             log->error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
216     } while(SQL_SUCCEEDED(ret));
217 }
218
219 SQLHDBC ODBCBase::getHDBC()
220 {
221 #ifdef _DEBUG
222     xmltooling::NDC ndc("getMYSQL");
223 #endif
224
225     // Get a handle.
226     SQLHDBC handle;
227     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
228     if (!SQL_SUCCEEDED(sr)) {
229         log->error("failed to allocate connection handle");
230         log_error(m_henv, SQL_HANDLE_ENV);
231         throw SAMLException("ODBCBase::getHDBC failed to allocate connection handle");
232     }
233
234     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
235     if (!SQL_SUCCEEDED(sr)) {
236         log->error("failed to connect to database");
237         log_error(handle, SQL_HANDLE_DBC);
238         throw SAMLException("ODBCBase::getHDBC failed to connect to database");
239     }
240
241     return handle;
242 }
243
244 pair<int,int> ODBCBase::getVersion(SQLHDBC conn)
245 {
246     // Grab the version number from the database.
247     SQLHSTMT hstmt;
248     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
249     
250     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
251     if (!SQL_SUCCEEDED(sr)) {
252         log->error("failed to read version from database");
253         log_error(hstmt, SQL_HANDLE_STMT);
254         throw SAMLException("ODBCBase::getVersion failed to read version from database");
255     }
256
257     SQLINTEGER major;
258     SQLINTEGER minor;
259     SQLBindCol(hstmt,1,SQL_C_SLONG,&major,0,NULL);
260     SQLBindCol(hstmt,2,SQL_C_SLONG,&minor,0,NULL);
261
262     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
263         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
264         return pair<int,int>(major,minor);
265     }
266
267     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
268     log->error("no rows returned in version query");
269     throw SAMLException("ODBCBase::getVersion failed to read version from database");
270 }
271
272 class ODBCCCache : public ODBCBase, virtual public ISessionCache, virtual public ISessionCacheStore
273 {
274 public:
275     ODBCCCache(const DOMElement* e);
276     virtual ~ODBCCCache();
277
278     // Delegate all the ISessionCache methods.
279     string insert(
280         const IApplication* application,
281         const RoleDescriptor* role,
282         const char* client_addr,
283         const SAMLSubject* subject,
284         const char* authnContext,
285         const SAMLResponse* tokens
286         )
287     { return m_cache->insert(application,role,client_addr,subject,authnContext,tokens); }
288     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
289     { return m_cache->find(key,application,client_addr); }
290     void remove(const char* key, const IApplication* application, const char* client_addr)
291     { m_cache->remove(key,application,client_addr); }
292
293     bool setBackingStore(ISessionCacheStore*) {return false;}
294
295     // Store methods handle the database work
296     HRESULT onCreate(
297         const char* key,
298         const IApplication* application,
299         const ISessionCacheEntry* entry,
300         int majorVersion,
301         int minorVersion,
302         time_t created
303         );
304     HRESULT onRead(
305         const char* key,
306         string& applicationId,
307         string& clientAddress,
308         string& providerId,
309         string& subject,
310         string& authnContext,
311         string& tokens,
312         int& majorVersion,
313         int& minorVersion,
314         time_t& created,
315         time_t& accessed
316         );
317     HRESULT onRead(const char* key, time_t& accessed);
318     HRESULT onRead(const char* key, string& tokens);
319     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
320     HRESULT onDelete(const char* key);
321
322     void cleanup();
323
324 private:
325     bool m_storeAttributes;
326     ISessionCache* m_cache;
327     xmltooling::CondWait* shutdown_wait;
328     bool shutdown;
329     xmltooling::Thread* cleanup_thread;
330
331     static void* cleanup_fcn(void*); // XXX Assumed an ODBCCCache
332 };
333
334 ODBCCCache::ODBCCCache(const DOMElement* e) : ODBCBase(e), m_storeAttributes(false)
335 {
336 #ifdef _DEBUG
337     xmltooling::NDC ndc("ODBCCCache");
338 #endif
339     log = &(Category::getInstance("shibtarget.SessionCache.ODBC"));
340
341     m_cache = dynamic_cast<ISessionCache*>(
342         SAMLConfig::getConfig().getPlugMgr().newPlugin(MEMORY_SESSIONCACHE, m_root)
343     );
344     if (!m_cache->setBackingStore(this)) {
345         delete m_cache;
346         throw SAMLException("Unable to register ODBC cache plugin as a cache store.");
347     }
348     
349     shutdown_wait = CondWait::create();
350     shutdown = false;
351
352     // Load our configuration details...
353     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
354     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
355         m_storeAttributes=true;
356
357     // Initialize the cleanup thread
358     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
359 }
360
361 ODBCCCache::~ODBCCCache()
362 {
363     shutdown = true;
364     shutdown_wait->signal();
365     cleanup_thread->join(NULL);
366     delete m_cache;
367 }
368
369 void appendXML(ostream& os, const char* str)
370 {
371     const char* pos=strchr(str,'\'');
372     while (pos) {
373         os.write(str,pos-str);
374         os << "''";
375         str=pos+1;
376         pos=strchr(str,'\'');
377     }
378     os << str;
379 }
380
381 HRESULT ODBCCCache::onCreate(
382     const char* key,
383     const IApplication* application,
384     const ISessionCacheEntry* entry,
385     int majorVersion,
386     int minorVersion,
387     time_t created
388     )
389 {
390 #ifdef _DEBUG
391     xmltooling::NDC ndc("onCreate");
392 #endif
393
394     // Get XML data from entry. Default is not to return SAML objects.
395     const char* context=entry->getAuthnContext();
396     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
397     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
398
399     // Stringify timestamp.
400     if (created==0)
401         created=time(NULL);
402 #ifndef HAVE_GMTIME_R
403     struct tm* ptime=gmtime(&created);
404 #else
405     struct tm res;
406     struct tm* ptime=gmtime_r(&created,&res);
407 #endif
408     char timebuf[32];
409     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
410
411     // Prepare insert statement.
412     ostringstream q;
413     q << "INSERT state VALUES ('" << key << "','" << application->getId() << "'," << timebuf << "," << timebuf
414         << ",'" << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId()
415         << "','";
416     appendXML(q,subject.first);
417     q << "','";
418     appendXML(q,context);
419     q << "',";
420     if (m_storeAttributes && tokens.first) {
421         q << "'";
422         appendXML(q,tokens.first);
423         q << "')";
424     }
425     else
426         q << "null)";
427     if (log->isDebugEnabled())
428         log->debug("SQL insert: %s", q.str().c_str());
429
430     // Get statement handle.
431     SQLHSTMT hstmt;
432     ODBCConn conn(getHDBC());
433     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
434
435     // Execute statement.
436     HRESULT hr=NOERROR;
437     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
438     if (!SQL_SUCCEEDED(sr)) {
439         log->error("failed to insert record into database");
440         log_error(hstmt, SQL_HANDLE_STMT);
441         hr=E_FAIL;
442     }
443
444     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
445     return hr;
446 }
447
448 HRESULT ODBCCCache::onRead(
449     const char* key,
450     string& applicationId,
451     string& clientAddress,
452     string& providerId,
453     string& subject,
454     string& authnContext,
455     string& tokens,
456     int& majorVersion,
457     int& minorVersion,
458     time_t& created,
459     time_t& accessed
460     )
461 {
462 #ifdef _DEBUG
463     xmltooling::NDC ndc("onRead");
464 #endif
465
466     log->debug("searching database...");
467
468     SQLHSTMT hstmt;
469     ODBCConn conn(getHDBC());
470     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
471
472     string q = string("SELECT application_id,ctime,atime,addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "'";
473     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
474     if (!SQL_SUCCEEDED(sr)) {
475         log->error("error searching for (%s)",key);
476         log_error(hstmt, SQL_HANDLE_STMT);
477         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
478         return E_FAIL;
479     }
480
481     SQLINTEGER major,minor;
482     SQL_TIMESTAMP_STRUCT atime,ctime;
483     SQLCHAR application_id[COLSIZE_APPLICATION_ID+1];
484     SQLCHAR addr[COLSIZE_ADDRESS+1];
485     SQLCHAR provider_id[COLSIZE_PROVIDER_ID+1];
486
487     // Bind simple output columns.
488     SQLBindCol(hstmt,1,SQL_C_CHAR,application_id,sizeof(application_id),NULL);
489     SQLBindCol(hstmt,2,SQL_C_TYPE_TIMESTAMP,&ctime,0,NULL);
490     SQLBindCol(hstmt,3,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
491     SQLBindCol(hstmt,4,SQL_C_CHAR,addr,sizeof(addr),NULL);
492     SQLBindCol(hstmt,5,SQL_C_SLONG,&major,0,NULL);
493     SQLBindCol(hstmt,6,SQL_C_SLONG,&minor,0,NULL);
494     SQLBindCol(hstmt,7,SQL_C_CHAR,provider_id,sizeof(provider_id),NULL);
495
496     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
497         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
498         return S_FALSE;
499     }
500
501     log->debug("session found, tranfering data back into memory");
502
503     // Copy back simple data.
504     applicationId = (char*)application_id;
505     clientAddress = (char*)addr;
506     majorVersion = major;
507     minorVersion = minor;
508     providerId = (char*)provider_id;
509
510     struct tm t;
511     t.tm_sec=ctime.second;
512     t.tm_min=ctime.minute;
513     t.tm_hour=ctime.hour;
514     t.tm_mday=ctime.day;
515     t.tm_mon=ctime.month-1;
516     t.tm_year=ctime.year-1900;
517     t.tm_isdst=0;
518 #if defined(HAVE_TIMEGM)
519     created=timegm(&t);
520 #else
521     // Windows, and hopefully most others...?
522     created = mktime(&t) - timezone;
523 #endif
524     t.tm_sec=atime.second;
525     t.tm_min=atime.minute;
526     t.tm_hour=atime.hour;
527     t.tm_mday=atime.day;
528     t.tm_mon=atime.month-1;
529     t.tm_year=atime.year-1900;
530     t.tm_isdst=0;
531 #if defined(HAVE_TIMEGM)
532     accessed=timegm(&t);
533 #else
534     // Windows, and hopefully most others...?
535     accessed = mktime(&t) - timezone;
536 #endif
537
538     // Extract text data.
539     string* ptrs[] = {&subject, &authnContext, &tokens};
540     HRESULT hr=NOERROR;
541     SQLINTEGER len;
542     SQLCHAR buf[LONGDATA_BUFLEN];
543     for (int i=0; i<3; i++) {
544         while ((sr=SQLGetData(hstmt,i+8,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
545             if (!SUCCEEDED(sr)) {
546                 log->error("error while reading text field from result set");
547                 log_error(hstmt, SQL_HANDLE_STMT);
548                 hr=E_FAIL;
549                 break;
550             }
551             ptrs[i]->append((char*)buf);
552         }
553     }
554
555     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
556     return hr;
557 }
558
559 HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
560 {
561 #ifdef _DEBUG
562     xmltooling::NDC ndc("onRead");
563 #endif
564
565     log->debug("reading last access time from database");
566
567     SQLHSTMT hstmt;
568     ODBCConn conn(getHDBC());
569     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
570     
571     string q = string("SELECT atime FROM state WHERE cookie='") + key + "'";
572     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
573     if (!SQL_SUCCEEDED(sr)) {
574         log->error("error searching for (%s)",key);
575         log_error(hstmt, SQL_HANDLE_STMT);
576         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
577         return E_FAIL;
578     }
579
580     SQL_TIMESTAMP_STRUCT atime;
581     SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
582
583     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
584         log->warn("session expected, but not found in database");
585         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
586         return S_FALSE;
587     }
588
589     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
590
591     struct tm t;
592     t.tm_sec=atime.second;
593     t.tm_min=atime.minute;
594     t.tm_hour=atime.hour;
595     t.tm_mday=atime.day;
596     t.tm_mon=atime.month-1;
597     t.tm_year=atime.year-1900;
598     t.tm_isdst=0;
599 #if defined(HAVE_TIMEGM)
600     accessed=timegm(&t);
601 #else
602     // Windows, and hopefully most others...?
603     accessed = mktime(&t) - timezone;
604 #endif
605     return NOERROR;
606 }
607
608 HRESULT ODBCCCache::onRead(const char* key, string& tokens)
609 {
610 #ifdef _DEBUG
611     xmltooling::NDC ndc("onRead");
612 #endif
613
614     if (!m_storeAttributes)
615         return S_FALSE;
616
617     log->debug("reading cached tokens from database");
618
619     SQLHSTMT hstmt;
620     ODBCConn conn(getHDBC());
621     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
622     
623     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "'";
624     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
625     if (!SQL_SUCCEEDED(sr)) {
626         log->error("error searching for (%s)",key);
627         log_error(hstmt, SQL_HANDLE_STMT);
628         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
629         return E_FAIL;
630     }
631
632     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
633         log->warn("session expected, but not found in database");
634         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
635         return S_FALSE;
636     }
637
638     HRESULT hr=NOERROR;
639     SQLINTEGER len;
640     SQLCHAR buf[LONGDATA_BUFLEN];
641     while ((sr=SQLGetData(hstmt,1,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
642         if (!SUCCEEDED(sr)) {
643             log->error("error while reading text field from result set");
644             log_error(hstmt, SQL_HANDLE_STMT);
645             hr=E_FAIL;
646             break;
647         }
648         tokens += (char*)buf;
649     }
650
651     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
652     return hr;
653 }
654
655 HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
656 {
657 #ifdef _DEBUG
658     xmltooling::NDC ndc("onUpdate");
659 #endif
660
661     ostringstream q;
662
663     if (lastAccess>0) {
664 #ifndef HAVE_GMTIME_R
665         struct tm* ptime=gmtime(&lastAccess);
666 #else
667         struct tm res;
668         struct tm* ptime=gmtime_r(&lastAccess,&res);
669 #endif
670         char timebuf[32];
671         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
672
673         q << "UPDATE state SET atime=" << timebuf << " WHERE cookie='" << key << "'";
674     }
675     else if (tokens) {
676         if (!m_storeAttributes)
677             return S_FALSE;
678         q << "UPDATE state SET tokens=";
679         if (*tokens) {
680             q << "'";
681             appendXML(q,tokens);
682             q << "' ";
683         }
684         else
685             q << "null ";
686         q << "WHERE cookie='" << key << "'";
687     }
688     else {
689         log->warn("onUpdate called with nothing to do!");
690         return S_FALSE;
691     }
692  
693     HRESULT hr=NOERROR;
694     SQLHSTMT hstmt;
695     ODBCConn conn(getHDBC());
696     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
697     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
698     if (sr==SQL_NO_DATA)
699         hr=S_FALSE;
700     else if (!SQL_SUCCEEDED(sr)) {
701         log->error("error updating record (key=%s)", key);
702         log_error(hstmt, SQL_HANDLE_STMT);
703         hr=E_FAIL;
704     }
705
706     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
707     return hr;
708 }
709
710 HRESULT ODBCCCache::onDelete(const char* key)
711 {
712 #ifdef _DEBUG
713     xmltooling::NDC ndc("onDelete");
714 #endif
715
716     SQLHSTMT hstmt;
717     ODBCConn conn(getHDBC());
718     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
719     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
720     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
721  
722     HRESULT hr=NOERROR;
723     if (sr==SQL_NO_DATA)
724         hr=S_FALSE;
725     else if (!SQL_SUCCEEDED(sr)) {
726         log->error("error deleting record (key=%s)", key);
727         log_error(hstmt, SQL_HANDLE_STMT);
728         hr=E_FAIL;
729     }
730
731     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
732     return hr;
733 }
734
735 void ODBCCCache::cleanup()
736 {
737 #ifdef _DEBUG
738     xmltooling::NDC ndc("cleanup");
739 #endif
740
741     Mutex* mutex = xmltooling::Mutex::create();
742
743     int rerun_timer = 0;
744     int timeout_life = 0;
745
746     // Load our configuration details...
747     const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
748     if (tag && *tag)
749         rerun_timer = XMLString::parseInt(tag);
750
751     // search for 'mysql-cache-timeout' and then the regular cache timeout
752     tag=m_root->getAttributeNS(NULL,odbcTimeout);
753     if (tag && *tag)
754         timeout_life = XMLString::parseInt(tag);
755     else {
756         tag=m_root->getAttributeNS(NULL,cacheTimeout);
757         if (tag && *tag)
758             timeout_life = XMLString::parseInt(tag);
759     }
760   
761     if (rerun_timer <= 0)
762         rerun_timer = 300;              // rerun every 5 minutes
763
764     if (timeout_life <= 0)
765         timeout_life = 28800;   // timeout after 8 hours
766     
767     mutex->lock();
768
769     log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
770
771     while (shutdown == false) {
772         shutdown_wait->timedwait(mutex, rerun_timer);
773
774         if (shutdown == true)
775             break;
776
777         // Find all the entries in the database that haven't been used
778         // recently In particular, find all entries that have not been
779         // accessed in 'timeout_life' seconds.
780
781         time_t stale=time(NULL)-timeout_life;
782 #ifndef HAVE_GMTIME_R
783         struct tm* ptime=gmtime(&stale);
784 #else
785         struct tm res;
786         struct tm* ptime=gmtime_r(&stale,&res);
787 #endif
788         char timebuf[32];
789         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
790
791         string q = string("DELETE state WHERE atime < ") +  timebuf;
792
793         SQLHSTMT hstmt;
794         ODBCConn conn(getHDBC());
795         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
796         SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
797         if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
798             log->error("error purging old records");
799             log_error(hstmt, SQL_HANDLE_STMT);
800         }
801
802         SQLINTEGER rowcount=0;
803         sr=SQLRowCount(hstmt,&rowcount);
804         if (SQL_SUCCEEDED(sr) && rowcount > 0)
805             log->info("purging %d old sessions",rowcount);
806
807         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
808      }
809
810     log->info("cleanup thread exiting...");
811
812     mutex->unlock();
813     delete mutex;
814     xmltooling::Thread::exit(NULL);
815 }
816
817 void* ODBCCCache::cleanup_fcn(void* cache_p)
818 {
819   ODBCCCache* cache = (ODBCCCache*)cache_p;
820
821 #ifndef WIN32
822   // First, let's block all signals
823   Thread::mask_all_signals();
824 #endif
825
826   // Now run the cleanup process.
827   cache->cleanup();
828   return NULL;
829 }
830
831
832 class ODBCReplayCache : public ODBCBase, virtual public IReplayCache
833 {
834 public:
835   ODBCReplayCache(const DOMElement* e);
836   virtual ~ODBCReplayCache() {}
837
838   bool check(const XMLCh* str, time_t expires) {xmltooling::auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
839   bool check(const char* str, time_t expires);
840 };
841
842 ODBCReplayCache::ODBCReplayCache(const DOMElement* e) : ODBCBase(e)
843 {
844 #ifdef _DEBUG
845     saml::NDC ndc("ODBCReplayCache");
846 #endif
847     log = &(Category::getInstance("shibtarget.ReplayCache.ODBC"));
848 }
849
850 bool ODBCReplayCache::check(const char* str, time_t expires)
851 {
852 #ifdef _DEBUG
853     saml::NDC ndc("check");
854 #endif
855   
856     time_t now=time(NULL);
857 #ifndef HAVE_GMTIME_R
858     struct tm* ptime=gmtime(&now);
859 #else
860     struct tm res;
861     struct tm* ptime=gmtime_r(&now,&res);
862 #endif
863     char timebuf[32];
864     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
865
866     // Remove expired entries.
867     SQLHSTMT hstmt;
868     ODBCConn conn(getHDBC());
869     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
870     string q = string("DELETE FROM replay WHERE expires < ") + timebuf;
871     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
872     if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
873         log->error("error purging old replay cache entries");
874         log_error(hstmt, SQL_HANDLE_STMT);
875     }
876     SQLCloseCursor(hstmt);
877   
878     // Look for a replay.
879     q = string("SELECT id FROM replay WHERE id='") + str + "'";
880     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
881     if (!SQL_SUCCEEDED(sr)) {
882         log->error("error searching replay cache");
883         log_error(hstmt, SQL_HANDLE_STMT);
884         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
885         throw SAMLException("Replay cache failed, please inform application support staff.");
886     }
887
888     // If we got a record, we return false.
889     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
890         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
891         return false;
892     }
893     SQLCloseCursor(hstmt);
894     
895 #ifndef HAVE_GMTIME_R
896     ptime=gmtime(&expires);
897 #else
898     ptime=gmtime_r(&expires,&res);
899 #endif
900     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
901
902     // Add it to the database.
903     q = string("INSERT replay VALUES('") + str + "'," + timebuf + ")";
904     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
905     if (!SQL_SUCCEEDED(sr)) {
906         log->error("error inserting replay cache entry", str);
907         log_error(hstmt, SQL_HANDLE_STMT);
908         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
909         throw SAMLException("Replay cache failed, please inform application support staff.");
910     }
911
912     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
913     return true;
914 }
915
916
917 // Factories
918
919 IPlugIn* new_odbc_ccache(const DOMElement* e)
920 {
921     return new ODBCCCache(e);
922 }
923
924 IPlugIn* new_odbc_replay(const DOMElement* e)
925 {
926     return new ODBCReplayCache(e);
927 }
928
929
930 extern "C" int SHIBODBC_EXPORTS saml_extension_init(void*)
931 {
932     // register this ccache type
933     SAMLConfig::getConfig().getPlugMgr().regFactory(ODBC_REPLAYCACHE, &new_odbc_replay);
934     SAMLConfig::getConfig().getPlugMgr().regFactory(ODBC_SESSIONCACHE, &new_odbc_ccache);
935     return 0;
936 }
937
938 extern "C" void SHIBODBC_EXPORTS saml_extension_term()
939 {
940     SAMLConfig::getConfig().getPlugMgr().unregFactory(ODBC_SESSIONCACHE);
941     SAMLConfig::getConfig().getPlugMgr().unregFactory(ODBC_REPLAYCACHE);
942 }