Plugin now implements backing store API.
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib/shib-threads.h>
43 #include <shib-target/shib-target.h>
44 #include <log4cpp/Category.hh>
45
46 #include <sstream>
47
48 #ifdef WIN32
49 # include <winsock.h>
50 #endif
51 #include <mysql.h>
52
53 // wanted to use MySQL codes for this, but can't seem to get back a 145
54 #define isCorrupt(s) strstr(s,"(errno: 145)")
55
56 #ifdef HAVE_LIBDMALLOCXX
57 #include <dmalloc.h>
58 #endif
59
60 using namespace std;
61 using namespace saml;
62 using namespace shibboleth;
63 using namespace shibtarget;
64 using namespace log4cpp;
65
66 #define PLUGIN_VER_MAJOR 3
67 #define PLUGIN_VER_MINOR 0
68
69 #define STATE_TABLE \
70   "CREATE TABLE state (" \
71   "cookie VARCHAR(64) PRIMARY KEY, " \
72   "application_id VARCHAR(255)," \
73   "ctime TIMESTAMP," \
74   "atime TIMESTAMP," \
75   "addr VARCHAR(128)," \
76   "major INT," \
77   "minor INT," \
78   "provider VARCHAR(256)," \
79   "subject TEXT," \
80   "authn_context TEXT," \
81   "tokens TEXT)"
82
83 #define REPLAY_TABLE \
84   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
85   "expires TIMESTAMP, " \
86   "INDEX (expires))"
87
88 static const XMLCh Argument[] =
89 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
90 static const XMLCh cleanupInterval[] =
91 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
92   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
93 };
94 static const XMLCh cacheTimeout[] =
95 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
96 static const XMLCh mysqlTimeout[] =
97 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
98 static const XMLCh storeAttributes[] =
99 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
100
101 class MySQLBase : public virtual saml::IPlugIn
102 {
103 public:
104   MySQLBase(const DOMElement* e);
105   virtual ~MySQLBase();
106
107   MYSQL* getMYSQL();
108   bool repairTable(MYSQL*&, const char* table);
109
110   log4cpp::Category* log;
111
112 protected:
113   ThreadKey* m_mysql;
114   const DOMElement* m_root; // can only use this during initialization
115
116   bool initialized;
117
118   void createDatabase(MYSQL*, int major, int minor);
119   void upgradeDatabase(MYSQL*);
120   pair<int,int> getVersion(MYSQL*);
121 };
122
123 // Forward declarations
124 static void mysqlInit(const DOMElement* e, Category& log);
125
126 extern "C" void shib_mysql_destroy_handle(void* data)
127 {
128   MYSQL* mysql = (MYSQL*) data;
129   if (mysql) mysql_close(mysql);
130 }
131
132 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
133 {
134 #ifdef _DEBUG
135   saml::NDC ndc("MySQLBase");
136 #endif
137   log = &(Category::getInstance("shibmysql.MySQLBase"));
138
139   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
140
141   initialized = false;
142   mysqlInit(e,*log);
143   getMYSQL();
144   initialized = true;
145 }
146
147 MySQLBase::~MySQLBase()
148 {
149   delete m_mysql;
150 }
151
152 MYSQL* MySQLBase::getMYSQL()
153 {
154 #ifdef _DEBUG
155     saml::NDC ndc("getMYSQL");
156 #endif
157
158     // Do we already have a handle?
159     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
160     if (mysql)
161         return mysql;
162
163     // Connect to the database
164     mysql = mysql_init(NULL);
165     if (!mysql) {
166         log->error("mysql_init failed");
167         mysql_close(mysql);
168         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
169     }
170
171     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
172         if (initialized) {
173             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
174             mysql_close(mysql);
175             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
176         }
177         else {
178             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
179
180             // This will throw an exception if it fails.
181             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
182         }
183     }
184
185     pair<int,int> v=getVersion (mysql);
186
187     // Make sure we've got the right version
188     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
189    
190         // If we're capable, try upgrading on the fly...
191         if (v.first == 0  || v.first == 1 || v.first == 2) {
192             if (mysql_query(mysql, "DROP TABLE state")) {
193                 log->error("error dropping old session state table: %s", mysql_error(mysql));
194             }
195             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
196                 log->error("error dropping old session state table: %s", mysql_error(mysql));
197             }
198             upgradeDatabase(mysql);
199         }
200         else {
201             mysql_close(mysql);
202             log->crit("Unknown database version: %d.%d", v.first, v.second);
203             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
204         }
205     }
206
207     // We're all set.. Save off the handle for this thread.
208     m_mysql->setData(mysql);
209     return mysql;
210 }
211
212 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
213 {
214     string q = string("REPAIR TABLE ") + table;
215     if (mysql_query(mysql, q.c_str())) {
216         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
217         return false;
218     }
219
220     // seems we have to recycle the connection to get the thread to keep working
221     // other threads seem to be ok, but we should monitor that
222     mysql_close(mysql);
223     m_mysql->setData(NULL);
224     mysql=getMYSQL();
225     return true;
226 }
227
228 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
229 {
230   log->info("creating database");
231
232   MYSQL* ms = NULL;
233   try {
234     ms = mysql_init(NULL);
235     if (!ms) {
236       log->crit("mysql_init failed");
237       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
238     }
239
240     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
241       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
242       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
243     }
244
245     if (mysql_query(ms, "CREATE DATABASE shibd")) {
246       log->crit("cannot create shibd database: %s", mysql_error(ms));
247       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
248     }
249
250     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
251       log->crit("cannot open shibd database");
252       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
253     }
254
255     mysql_close(ms);
256     
257   }
258   catch (SAMLException&) {
259     if (ms)
260       mysql_close(ms);
261     mysql_close(mysql);
262     throw;
263   }
264
265   // Now create the tables if they don't exist
266   log->info("Creating database tables");
267
268   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
269     log->error ("error creating version: %s", mysql_error(mysql));
270     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
271   }
272
273   if (mysql_query(mysql,STATE_TABLE)) {
274     log->error ("error creating state table: %s", mysql_error(mysql));
275     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
276   }
277
278   if (mysql_query(mysql,REPLAY_TABLE)) {
279     log->error ("error creating replay table: %s", mysql_error(mysql));
280     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
281   }
282
283   ostringstream q;
284   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
285   if (mysql_query(mysql, q.str().c_str())) {
286     log->error ("error setting version: %s", mysql_error(mysql));
287     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
288   }
289 }
290
291 void MySQLBase::upgradeDatabase(MYSQL* mysql)
292 {
293     if (mysql_query(mysql,STATE_TABLE)) {
294         log->error ("error creating state table: %s", mysql_error(mysql));
295         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
296     }
297
298     if (mysql_query(mysql,REPLAY_TABLE)) {
299         log->error ("error creating replay table: %s", mysql_error(mysql));
300         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
301     }
302
303     ostringstream q;
304     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
305     if (mysql_query(mysql, q.str().c_str())) {
306         log->error ("error updating version: %s", mysql_error(mysql));
307         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
308     }
309 }
310
311 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
312 {
313     // grab the version number from the database
314     if (mysql_query(mysql, "SELECT * FROM version"))
315         log->error ("Error reading version: %s", mysql_error(mysql));
316
317     MYSQL_RES* rows = mysql_store_result(mysql);
318     if (rows) {
319         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
320           MYSQL_ROW row = mysql_fetch_row(rows);
321           int major = row[0] ? atoi(row[0]) : -1;
322           int minor = row[1] ? atoi(row[1]) : -1;
323           log->debug("opening database version %d.%d", major, minor);
324           mysql_free_result(rows);
325           return make_pair(major,minor);
326         } else {
327             // Wrong number of rows or wrong number of fields...
328             log->crit("Houston, we've got a problem with the database...");
329             mysql_free_result (rows);
330             throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
331         }
332     }
333     log->crit("MySQL Read Failed in version verification");
334     throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
335 }
336
337 static void mysqlInit(const DOMElement* e, Category& log)
338 {
339     static bool done = false;
340     if (done) {
341         log.info("MySQL embedded server already initialized");
342         return;
343     }
344     log.info("initializing MySQL embedded server");
345
346     // Setup the argument array
347     vector<string> arg_array;
348     arg_array.push_back("shibboleth");
349
350     // grab any MySQL parameters from the config file
351     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
352     while (e) {
353         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
354         if (arg.get())
355             arg_array.push_back(arg.get());
356         e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
357     }
358
359     // Compute the argument array
360     vector<string>::size_type arg_count = arg_array.size();
361     const char** args=new const char*[arg_count];
362     for (vector<string>::size_type i = 0; i < arg_count; i++)
363         args[i] = arg_array[i].c_str();
364
365     // Initialize MySQL with the arguments
366     mysql_server_init(arg_count, (char **)args, NULL);
367
368     delete[] args;
369     done = true;
370 }  
371
372 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
373 {
374 public:
375     ShibMySQLCCache(const DOMElement* e);
376     virtual ~ShibMySQLCCache();
377
378     // Delegate all the ISessionCache methods.
379     string insert(
380         const IApplication* application,
381         const IEntityDescriptor* source,
382         const char* client_addr,
383         const SAMLSubject* subject,
384         const char* authnContext,
385         const SAMLResponse* tokens
386         )
387     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
388     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
389     { return m_cache->find(key,application,client_addr); }
390     void remove(const char* key, const IApplication* application, const char* client_addr)
391     { m_cache->remove(key,application,client_addr); }
392
393     bool setBackingStore(ISessionCacheStore*) {return false;}
394
395     // Store methods handle the database work
396     HRESULT onCreate(
397         const char* key,
398         const IApplication* application,
399         const ISessionCacheEntry* entry,
400         int majorVersion,
401         int minorVersion,
402         time_t created
403         );
404     HRESULT onRead(
405         const char* key,
406         string& applicationId,
407         string& clientAddress,
408         string& providerId,
409         string& subject,
410         string& authnContext,
411         string& tokens,
412         int& majorVersion,
413         int& minorVersion,
414         time_t& created,
415         time_t& accessed
416         );
417     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t lastAccess=0);
418     HRESULT onDelete(const char* key);
419
420     void cleanup();
421
422 private:
423     bool m_storeAttributes;
424     ISessionCache* m_cache;
425     CondWait* shutdown_wait;
426     bool shutdown;
427     Thread* cleanup_thread;
428
429     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
430 };
431
432 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
433 {
434 #ifdef _DEBUG
435     saml::NDC ndc("ShibMySQLCCache");
436 #endif
437     log = &(Category::getInstance("shibmysql.SessionCache"));
438
439     m_cache = dynamic_cast<ISessionCache*>(
440         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
441     );
442     if (!m_cache->setBackingStore(this)) {
443         delete m_cache;
444         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
445     }
446     
447     shutdown_wait = CondWait::create();
448     shutdown = false;
449
450     // Load our configuration details...
451     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
452     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
453         m_storeAttributes=true;
454
455     // Initialize the cleanup thread
456     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
457 }
458
459 ShibMySQLCCache::~ShibMySQLCCache()
460 {
461     shutdown = true;
462     shutdown_wait->signal();
463     cleanup_thread->join(NULL);
464     delete m_cache;
465 }
466
467 HRESULT ShibMySQLCCache::onCreate(
468     const char* key,
469     const IApplication* application,
470     const ISessionCacheEntry* entry,
471     int majorVersion,
472     int minorVersion,
473     time_t created
474     )
475 {
476 #ifdef _DEBUG
477     saml::NDC ndc("onCreate");
478 #endif
479
480     // Get XML data from entry. Default is not to return SAML objects.
481     const char* context=entry->getAuthnContext();
482     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
483     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
484
485     ostringstream q;
486     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
487     if (created==0)
488         q << "NOW(),NOW(),'";
489     else
490         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
491     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
492         << subject.first << "','" << context << "',";
493
494     if (m_storeAttributes && tokens.first)
495         q << "'" << tokens.first << "')";
496     else
497         q << "null)";
498
499     if (log->isDebugEnabled())
500         log->debug("SQL insert: %s", q.str().c_str());
501
502     // then add it to the database
503     MYSQL* mysql = getMYSQL();
504     if (mysql_query(mysql, q.str().c_str())) {
505         const char* err=mysql_error(mysql);
506         log->error("error inserting %s: %s", key, err);
507         if (isCorrupt(err) && repairTable(mysql,"state")) {
508             // Try again...
509             if (mysql_query(mysql, q.str().c_str())) {
510                 log->error("error inserting %s: %s", key, mysql_error(mysql));
511                 return E_FAIL;
512             }
513         }
514         else
515             throw E_FAIL;
516     }
517
518     return NOERROR;
519 }
520
521 HRESULT ShibMySQLCCache::onRead(
522     const char* key,
523     string& applicationId,
524     string& clientAddress,
525     string& providerId,
526     string& subject,
527     string& authnContext,
528     string& tokens,
529     int& majorVersion,
530     int& minorVersion,
531     time_t& created,
532     time_t& accessed
533     )
534 {
535 #ifdef _DEBUG
536     saml::NDC ndc("onRead");
537 #endif
538
539     log->debug("searching MySQL database...");
540
541     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
542
543     MYSQL* mysql = getMYSQL();
544     if (mysql_query(mysql, q.c_str())) {
545         const char* err=mysql_error(mysql);
546         log->error("error searching for %s: %s", key, err);
547         if (isCorrupt(err) && repairTable(mysql,"state")) {
548             if (mysql_query(mysql, q.c_str()))
549                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
550         }
551     }
552
553     MYSQL_RES* rows = mysql_store_result(mysql);
554
555     // Nope, doesn't exist.
556     if (!rows)
557         return S_FALSE;
558
559     // Make sure we got 1 and only 1 rows.
560     if (mysql_num_rows(rows) != 1) {
561         log->error("Database select returned wrong number of rows: %d", mysql_num_rows(rows));
562         mysql_free_result(rows);
563         return S_FALSE;
564     }
565
566     log->debug("match found, tranfering data back into memory...");
567     
568     /* Columns in query:
569         0: application_id
570         1: ctime
571         2: atime
572         3: address
573         4: major
574         5: minor
575         6: provider
576         7: subject
577         8: authncontext
578         9: tokens
579      */
580
581     MYSQL_ROW row = mysql_fetch_row(rows);
582     applicationId=row[0];
583     created=atoi(row[1]);
584     accessed=atoi(row[2]);
585     clientAddress=row[3];
586     majorVersion=atoi(row[4]);
587     minorVersion=atoi(row[5]);
588     providerId=row[6];
589     subject=row[7];
590     authnContext=row[8];
591     if (row[9])
592         tokens=row[9];
593
594     // Free the results.
595     mysql_free_result(rows);
596
597     return NOERROR;
598 }
599
600 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
601 {
602 #ifdef _DEBUG
603     saml::NDC ndc("onUpdate");
604 #endif
605
606     ostringstream q;
607     if (tokens) {
608         q << "UPDATE state SET tokens=";
609         if (*tokens)
610             q << "'" << tokens << "'";
611         else
612             q << "null";
613     }
614     else if (lastAccess>0)
615         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
616     else {
617         log->warn("onUpdate called with nothing to do!");
618         return S_FALSE;
619     }
620  
621     q << " WHERE cookie='" << key << "'";
622
623     MYSQL* mysql = getMYSQL();
624     if (mysql_query(mysql, q.str().c_str())) {
625         const char* err=mysql_error(mysql);
626         log->error("error updating %s: %s", key, err);
627         if (isCorrupt(err) && repairTable(mysql,"state")) {
628             // Try again...
629             if (mysql_query(mysql, q.str().c_str())) {
630                 log->error("error updating %s: %s", key, mysql_error(mysql));
631                 return E_FAIL;
632             }
633         }
634         else
635             return E_FAIL;
636     }
637
638     return NOERROR;
639 }
640
641 HRESULT ShibMySQLCCache::onDelete(const char* key)
642 {
643 #ifdef _DEBUG
644     saml::NDC ndc("onDelete");
645 #endif
646
647     // Remove from the database
648     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
649     MYSQL* mysql = getMYSQL();
650     if (mysql_query(mysql, q.c_str())) {
651         const char* err=mysql_error(mysql);
652         log->error("error deleting entry %s: %s", key, err);
653         if (isCorrupt(err) && repairTable(mysql,"state")) {
654             // Try again...
655             if (mysql_query(mysql, q.c_str())) {
656                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
657                 return E_FAIL;
658             }
659         }
660         else
661             return E_FAIL;
662     }
663
664     return NOERROR;
665 }
666
667 void ShibMySQLCCache::cleanup()
668 {
669 #ifdef _DEBUG
670   saml::NDC ndc("cleanup");
671 #endif
672
673   Mutex* mutex = Mutex::create();
674
675   int rerun_timer = 0;
676   int timeout_life = 0;
677
678   // Load our configuration details...
679   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
680   if (tag && *tag)
681     rerun_timer = XMLString::parseInt(tag);
682
683   // search for 'mysql-cache-timeout' and then the regular cache timeout
684   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
685   if (tag && *tag)
686     timeout_life = XMLString::parseInt(tag);
687   else {
688       tag=m_root->getAttributeNS(NULL,cacheTimeout);
689       if (tag && *tag)
690         timeout_life = XMLString::parseInt(tag);
691   }
692   
693   if (rerun_timer <= 0)
694     rerun_timer = 300;          // rerun every 5 minutes
695
696   if (timeout_life <= 0)
697     timeout_life = 28800;       // timeout after 8 hours
698
699   mutex->lock();
700
701   MYSQL* mysql = getMYSQL();
702
703   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
704
705   while (shutdown == false) {
706     shutdown_wait->timedwait(mutex, rerun_timer);
707
708     if (shutdown == true)
709       break;
710
711     // Find all the entries in the database that haven't been used
712     // recently In particular, find all entries that have not been
713     // accessed in 'timeout_life' seconds.
714     ostringstream q;
715     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
716
717     if (mysql_query(mysql, q.str().c_str())) {
718       const char* err=mysql_error(mysql);
719       log->error("error purging old records: %s", err);
720         if (isCorrupt(err) && repairTable(mysql,"state")) {
721           if (mysql_query(mysql, q.str().c_str()))
722             log->error("error re-purging old records: %s", mysql_error(mysql));
723         }
724     }
725   }
726
727   log->info("cleanup thread exiting...");
728
729   mutex->unlock();
730   delete mutex;
731   Thread::exit(NULL);
732 }
733
734 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
735 {
736   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
737
738   // First, let's block all signals
739   Thread::mask_all_signals();
740
741   // Now run the cleanup process.
742   cache->cleanup();
743   return NULL;
744 }
745
746 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
747 {
748 public:
749   MySQLReplayCache(const DOMElement* e);
750   virtual ~MySQLReplayCache() {}
751
752   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
753   bool check(const char* str, time_t expires);
754 };
755
756 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
757 {
758 #ifdef _DEBUG
759   saml::NDC ndc("MySQLReplayCache");
760 #endif
761
762   log = &(Category::getInstance("shibmysql.ReplayCache"));
763 }
764
765 bool MySQLReplayCache::check(const char* str, time_t expires)
766 {
767 #ifdef _DEBUG
768     saml::NDC ndc("check");
769 #endif
770   
771     // Remove expired entries
772     string q = string("DELETE FROM replay WHERE expires < NOW()");
773     MYSQL* mysql = getMYSQL();
774     if (mysql_query(mysql, q.c_str())) {
775         const char* err=mysql_error(mysql);
776         log->error("Error deleting expired entries: %s", err);
777         if (isCorrupt(err) && repairTable(mysql,"replay")) {
778             // Try again...
779             if (mysql_query(mysql, q.c_str()))
780                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
781         }
782     }
783   
784     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
785     if (mysql_query(mysql, q2.c_str())) {
786         const char* err=mysql_error(mysql);
787         log->error("Error searching for %s: %s", str, err);
788         if (isCorrupt(err) && repairTable(mysql,"replay")) {
789             if (mysql_query(mysql, q2.c_str())) {
790                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
791                 throw SAMLException("Replay cache failed, please inform application support staff.");
792             }
793         }
794         else
795             throw SAMLException("Replay cache failed, please inform application support staff.");
796     }
797
798     // Did we find it?
799     MYSQL_RES* rows = mysql_store_result(mysql);
800     if (rows && mysql_num_rows(rows)>0) {
801       mysql_free_result(rows);
802       return false;
803     }
804
805     ostringstream q3;
806     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
807
808     // then add it to the database
809     if (mysql_query(mysql, q3.str().c_str())) {
810         const char* err=mysql_error(mysql);
811         log->error("Error inserting %s: %s", str, err);
812         if (isCorrupt(err) && repairTable(mysql,"state")) {
813             // Try again...
814             if (mysql_query(mysql, q3.str().c_str())) {
815                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
816                 throw SAMLException("Replay cache failed, please inform application support staff.");
817             }
818         }
819         else
820             throw SAMLException("Replay cache failed, please inform application support staff.");
821     }
822     
823     return true;
824 }
825
826 /*************************************************************************
827  * The registration functions here...
828  */
829
830 IPlugIn* new_mysql_ccache(const DOMElement* e)
831 {
832     return new ShibMySQLCCache(e);
833 }
834
835 IPlugIn* new_mysql_replay(const DOMElement* e)
836 {
837     return new MySQLReplayCache(e);
838 }
839
840 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
841 {
842     // register this ccache type
843     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLReplayCacheType, &new_mysql_replay);
844     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLSessionCacheType, &new_mysql_ccache);
845     return 0;
846 }
847
848 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
849 {
850     // Shutdown MySQL
851     mysql_server_end();
852     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
853     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
854 }