Starting to refactor session cache, eliminated IConfig class.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib-target/shib-target.h>
43
44 #include <log4cpp/Category.hh>
45 #include <xmltooling/util/NDC.h>
46 #include <xmltooling/util/Threads.h>
47 #include <xmltooling/util/XMLHelper.h>
48 #include <shibsp/SPConfig.h>
49 using xmltooling::XMLHelper;
50
51 #include <sstream>
52
53 #ifdef WIN32
54 # include <winsock.h>
55 #endif
56 #include <mysql.h>
57
58 // wanted to use MySQL codes for this, but can't seem to get back a 145
59 #define isCorrupt(s) strstr(s,"(errno: 145)")
60
61 #ifdef HAVE_LIBDMALLOCXX
62 #include <dmalloc.h>
63 #endif
64
65 using namespace shibsp;
66 using namespace shibtarget;
67 using namespace opensaml::saml2md;
68 using namespace saml;
69 using namespace log4cpp;
70 using namespace std;
71
72 #define PLUGIN_VER_MAJOR 3
73 #define PLUGIN_VER_MINOR 0
74
75 #define STATE_TABLE \
76   "CREATE TABLE state (" \
77   "cookie VARCHAR(64) PRIMARY KEY, " \
78   "application_id VARCHAR(255)," \
79   "ctime TIMESTAMP," \
80   "atime TIMESTAMP," \
81   "addr VARCHAR(128)," \
82   "major INT," \
83   "minor INT," \
84   "provider VARCHAR(256)," \
85   "subject TEXT," \
86   "authn_context TEXT," \
87   "tokens TEXT)"
88
89 #define REPLAY_TABLE \
90   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
91   "expires TIMESTAMP, " \
92   "INDEX (expires))"
93
94 static const XMLCh Argument[] =
95 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
96 static const XMLCh cleanupInterval[] =
97 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
98   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
99 };
100 static const XMLCh cacheTimeout[] =
101 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
102 static const XMLCh mysqlTimeout[] =
103 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
104 static const XMLCh storeAttributes[] =
105 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
106
107 static bool g_MySQLInitialized = false;
108
109 class MySQLBase : public virtual saml::IPlugIn
110 {
111 public:
112   MySQLBase(const DOMElement* e);
113   virtual ~MySQLBase();
114
115   MYSQL* getMYSQL();
116   bool repairTable(MYSQL*&, const char* table);
117
118   log4cpp::Category* log;
119
120 protected:
121     xmltooling::ThreadKey* m_mysql;
122   const DOMElement* m_root; // can only use this during initialization
123
124   bool initialized;
125   bool handleShutdown;
126
127   void createDatabase(MYSQL*, int major, int minor);
128   void upgradeDatabase(MYSQL*);
129   pair<int,int> getVersion(MYSQL*);
130 };
131
132 // Forward declarations
133 static void mysqlInit(const DOMElement* e, Category& log);
134
135 extern "C" void shib_mysql_destroy_handle(void* data)
136 {
137   MYSQL* mysql = (MYSQL*) data;
138   if (mysql) mysql_close(mysql);
139 }
140
141 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
142 {
143 #ifdef _DEBUG
144   xmltooling::NDC ndc("MySQLBase");
145 #endif
146   log = &(Category::getInstance("shibtarget.SessionCache.MySQL"));
147
148   m_mysql = xmltooling::ThreadKey::create(&shib_mysql_destroy_handle);
149
150   initialized = false;
151   mysqlInit(e,*log);
152   getMYSQL();
153   initialized = true;
154 }
155
156 MySQLBase::~MySQLBase()
157 {
158   delete m_mysql;
159 }
160
161 MYSQL* MySQLBase::getMYSQL()
162 {
163 #ifdef _DEBUG
164     xmltooling::NDC ndc("getMYSQL");
165 #endif
166
167     // Do we already have a handle?
168     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
169     if (mysql)
170         return mysql;
171
172     // Connect to the database
173     mysql = mysql_init(NULL);
174     if (!mysql) {
175         log->error("mysql_init failed");
176         mysql_close(mysql);
177         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
178     }
179
180     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
181         if (initialized) {
182             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
183             mysql_close(mysql);
184             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
185         }
186         else {
187             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
188
189             // This will throw an exception if it fails.
190             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
191         }
192     }
193
194     pair<int,int> v=getVersion (mysql);
195
196     // Make sure we've got the right version
197     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
198    
199         // If we're capable, try upgrading on the fly...
200         if (v.first == 0  || v.first == 1 || v.first == 2) {
201             if (mysql_query(mysql, "DROP TABLE state")) {
202                 log->error("error dropping old session state table: %s", mysql_error(mysql));
203             }
204             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
205                 log->error("error dropping old session state table: %s", mysql_error(mysql));
206             }
207             upgradeDatabase(mysql);
208         }
209         else {
210             mysql_close(mysql);
211             log->crit("Unknown database version: %d.%d", v.first, v.second);
212             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
213         }
214     }
215
216     // We're all set.. Save off the handle for this thread.
217     m_mysql->setData(mysql);
218     return mysql;
219 }
220
221 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
222 {
223     string q = string("REPAIR TABLE ") + table;
224     if (mysql_query(mysql, q.c_str())) {
225         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
226         return false;
227     }
228
229     // seems we have to recycle the connection to get the thread to keep working
230     // other threads seem to be ok, but we should monitor that
231     mysql_close(mysql);
232     m_mysql->setData(NULL);
233     mysql=getMYSQL();
234     return true;
235 }
236
237 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
238 {
239   log->info("creating database");
240
241   MYSQL* ms = NULL;
242   try {
243     ms = mysql_init(NULL);
244     if (!ms) {
245       log->crit("mysql_init failed");
246       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
247     }
248
249     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
250       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
251       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
252     }
253
254     if (mysql_query(ms, "CREATE DATABASE shibd")) {
255       log->crit("cannot create shibd database: %s", mysql_error(ms));
256       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
257     }
258
259     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
260       log->crit("cannot open shibd database");
261       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
262     }
263
264     mysql_close(ms);
265     
266   }
267   catch (SAMLException&) {
268     if (ms)
269       mysql_close(ms);
270     mysql_close(mysql);
271     throw;
272   }
273
274   // Now create the tables if they don't exist
275   log->info("Creating database tables");
276
277   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
278     log->error ("error creating version: %s", mysql_error(mysql));
279     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
280   }
281
282   if (mysql_query(mysql,STATE_TABLE)) {
283     log->error ("error creating state table: %s", mysql_error(mysql));
284     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
285   }
286
287   if (mysql_query(mysql,REPLAY_TABLE)) {
288     log->error ("error creating replay table: %s", mysql_error(mysql));
289     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
290   }
291
292   ostringstream q;
293   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
294   if (mysql_query(mysql, q.str().c_str())) {
295     log->error ("error setting version: %s", mysql_error(mysql));
296     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
297   }
298 }
299
300 void MySQLBase::upgradeDatabase(MYSQL* mysql)
301 {
302     if (mysql_query(mysql,STATE_TABLE)) {
303         log->error ("error creating state table: %s", mysql_error(mysql));
304         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
305     }
306
307     if (mysql_query(mysql,REPLAY_TABLE)) {
308         log->error ("error creating replay table: %s", mysql_error(mysql));
309         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
310     }
311
312     ostringstream q;
313     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
314     if (mysql_query(mysql, q.str().c_str())) {
315         log->error ("error updating version: %s", mysql_error(mysql));
316         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
317     }
318 }
319
320 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
321 {
322     // grab the version number from the database
323     if (mysql_query(mysql, "SELECT * FROM version")) {
324         log->error("error reading version: %s", mysql_error(mysql));
325         throw SAMLException("MySQLBase::getVersion(): error reading version");
326     }
327
328     MYSQL_RES* rows = mysql_store_result(mysql);
329     if (rows) {
330         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
331           MYSQL_ROW row = mysql_fetch_row(rows);
332           int major = row[0] ? atoi(row[0]) : -1;
333           int minor = row[1] ? atoi(row[1]) : -1;
334           log->debug("opening database version %d.%d", major, minor);
335           mysql_free_result(rows);
336           return make_pair(major,minor);
337         }
338         else {
339             // Wrong number of rows or wrong number of fields...
340             log->crit("Houston, we've got a problem with the database...");
341             mysql_free_result(rows);
342             throw SAMLException("MySQLBase::getVersion(): version verification failed");
343         }
344     }
345     log->crit("MySQL Read Failed in version verification");
346     throw SAMLException("MySQLBase::getVersion(): error reading version");
347 }
348
349 static void mysqlInit(const DOMElement* e, Category& log)
350 {
351     if (g_MySQLInitialized) {
352         log.info("MySQL embedded server already initialized");
353         return;
354     }
355     log.info("initializing MySQL embedded server");
356
357     // Setup the argument array
358     vector<string> arg_array;
359     arg_array.push_back("shibboleth");
360
361     // grab any MySQL parameters from the config file
362     e=XMLHelper::getFirstChildElement(e,Argument);
363     while (e) {
364         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
365         if (arg.get())
366             arg_array.push_back(arg.get());
367         e=XMLHelper::getNextSiblingElement(e,Argument);
368     }
369
370     // Compute the argument array
371     vector<string>::size_type arg_count = arg_array.size();
372     const char** args=new const char*[arg_count];
373     for (vector<string>::size_type i = 0; i < arg_count; i++)
374         args[i] = arg_array[i].c_str();
375
376     // Initialize MySQL with the arguments
377     mysql_server_init(arg_count, (char **)args, NULL);
378
379     delete[] args;
380     g_MySQLInitialized = true;
381 }  
382
383 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
384 {
385 public:
386     ShibMySQLCCache(const DOMElement* e);
387     virtual ~ShibMySQLCCache();
388
389     // Delegate all the ISessionCache methods.
390     string insert(
391         const IApplication* application,
392         const RoleDescriptor* role,
393         const char* client_addr,
394         const SAMLSubject* subject,
395         const char* authnContext,
396         const SAMLResponse* tokens
397         )
398     { return m_cache->insert(application,role,client_addr,subject,authnContext,tokens); }
399     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
400     { return m_cache->find(key,application,client_addr); }
401     void remove(const char* key, const IApplication* application, const char* client_addr)
402     { m_cache->remove(key,application,client_addr); }
403
404     bool setBackingStore(ISessionCacheStore*) {return false;}
405
406     // Store methods handle the database work
407     HRESULT onCreate(
408         const char* key,
409         const IApplication* application,
410         const ISessionCacheEntry* entry,
411         int majorVersion,
412         int minorVersion,
413         time_t created
414         );
415     HRESULT onRead(
416         const char* key,
417         string& applicationId,
418         string& clientAddress,
419         string& providerId,
420         string& subject,
421         string& authnContext,
422         string& tokens,
423         int& majorVersion,
424         int& minorVersion,
425         time_t& created,
426         time_t& accessed
427         );
428     HRESULT onRead(const char* key, time_t& accessed);
429     HRESULT onRead(const char* key, string& tokens);
430     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
431     HRESULT onDelete(const char* key);
432
433     void cleanup();
434
435 private:
436     bool m_storeAttributes;
437     ISessionCache* m_cache;
438     xmltooling::CondWait* shutdown_wait;
439     bool shutdown;
440     xmltooling::Thread* cleanup_thread;
441
442     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
443 };
444
445 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
446 {
447 #ifdef _DEBUG
448     xmltooling::NDC ndc("ShibMySQLCCache");
449 #endif
450
451     m_cache = dynamic_cast<ISessionCache*>(
452         SAMLConfig::getConfig().getPlugMgr().newPlugin(MEMORY_SESSIONCACHE, e)
453     );
454     if (!m_cache->setBackingStore(this)) {
455         delete m_cache;
456         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
457     }
458     
459     shutdown_wait = xmltooling::CondWait::create();
460     shutdown = false;
461
462     // Load our configuration details...
463     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
464     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
465         m_storeAttributes=true;
466
467     // Initialize the cleanup thread
468     cleanup_thread = xmltooling::Thread::create(&cleanup_fcn, (void*)this);
469 }
470
471 ShibMySQLCCache::~ShibMySQLCCache()
472 {
473     shutdown = true;
474     shutdown_wait->signal();
475     cleanup_thread->join(NULL);
476     delete m_cache;
477 }
478
479 HRESULT ShibMySQLCCache::onCreate(
480     const char* key,
481     const IApplication* application,
482     const ISessionCacheEntry* entry,
483     int majorVersion,
484     int minorVersion,
485     time_t created
486     )
487 {
488 #ifdef _DEBUG
489     xmltooling::NDC ndc("onCreate");
490 #endif
491
492     // Get XML data from entry. Default is not to return SAML objects.
493     const char* context=entry->getAuthnContext();
494     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
495     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
496
497     ostringstream q;
498     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
499     if (created==0)
500         q << "NOW(),NOW(),'";
501     else
502         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
503     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
504         << subject.first << "','" << context << "',";
505
506     if (m_storeAttributes && tokens.first)
507         q << "'" << tokens.first << "')";
508     else
509         q << "null)";
510
511     if (log->isDebugEnabled())
512         log->debug("SQL insert: %s", q.str().c_str());
513
514     // then add it to the database
515     MYSQL* mysql = getMYSQL();
516     if (mysql_query(mysql, q.str().c_str())) {
517         const char* err=mysql_error(mysql);
518         log->error("error inserting %s: %s", key, err);
519         if (isCorrupt(err) && repairTable(mysql,"state")) {
520             // Try again...
521             if (mysql_query(mysql, q.str().c_str())) {
522                 log->error("error inserting %s: %s", key, mysql_error(mysql));
523                 return E_FAIL;
524             }
525         }
526         else
527             throw E_FAIL;
528     }
529
530     return NOERROR;
531 }
532
533 HRESULT ShibMySQLCCache::onRead(
534     const char* key,
535     string& applicationId,
536     string& clientAddress,
537     string& providerId,
538     string& subject,
539     string& authnContext,
540     string& tokens,
541     int& majorVersion,
542     int& minorVersion,
543     time_t& created,
544     time_t& accessed
545     )
546 {
547 #ifdef _DEBUG
548     xmltooling::NDC ndc("onRead");
549 #endif
550
551     log->debug("searching MySQL database...");
552
553     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
554
555     MYSQL* mysql = getMYSQL();
556     if (mysql_query(mysql, q.c_str())) {
557         const char* err=mysql_error(mysql);
558         log->error("error searching for %s: %s", key, err);
559         if (isCorrupt(err) && repairTable(mysql,"state")) {
560             if (mysql_query(mysql, q.c_str()))
561                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
562         }
563     }
564
565     MYSQL_RES* rows = mysql_store_result(mysql);
566
567     // Nope, doesn't exist.
568     if (!rows || mysql_num_rows(rows)==0) {
569         log->debug("not found in database");
570         if (rows)
571             mysql_free_result(rows);
572         return S_FALSE;
573     }
574
575     // Make sure we got 1 and only 1 row.
576     if (mysql_num_rows(rows) > 1) {
577         log->error("database select returned %d rows!", mysql_num_rows(rows));
578         mysql_free_result(rows);
579         return E_FAIL;
580     }
581
582     log->debug("session found, tranfering data back into memory");
583     
584     /* Columns in query:
585         0: application_id
586         1: ctime
587         2: atime
588         3: address
589         4: major
590         5: minor
591         6: provider
592         7: subject
593         8: authncontext
594         9: tokens
595      */
596
597     MYSQL_ROW row = mysql_fetch_row(rows);
598     applicationId=row[0];
599     created=atoi(row[1]);
600     accessed=atoi(row[2]);
601     clientAddress=row[3];
602     majorVersion=atoi(row[4]);
603     minorVersion=atoi(row[5]);
604     providerId=row[6];
605     subject=row[7];
606     authnContext=row[8];
607     if (row[9])
608         tokens=row[9];
609
610     // Free the results.
611     mysql_free_result(rows);
612
613     return NOERROR;
614 }
615
616 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
617 {
618 #ifdef _DEBUG
619     xmltooling::NDC ndc("onRead");
620 #endif
621
622     log->debug("reading last access time from MySQL database");
623
624     string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
625
626     MYSQL* mysql = getMYSQL();
627     if (mysql_query(mysql, q.c_str())) {
628         const char* err=mysql_error(mysql);
629         log->error("error searching for %s: %s", key, err);
630         if (isCorrupt(err) && repairTable(mysql,"state")) {
631             if (mysql_query(mysql, q.c_str()))
632                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
633         }
634     }
635
636     MYSQL_RES* rows = mysql_store_result(mysql);
637
638     // Nope, doesn't exist.
639     if (!rows || mysql_num_rows(rows)==0) {
640         log->warn("session expected, but not found in database");
641         if (rows)
642             mysql_free_result(rows);
643         return S_FALSE;
644     }
645
646     // Make sure we got 1 and only 1 row.
647     if (mysql_num_rows(rows) != 1) {
648         log->error("database select returned %d rows!", mysql_num_rows(rows));
649         mysql_free_result(rows);
650         return E_FAIL;
651     }
652
653     MYSQL_ROW row = mysql_fetch_row(rows);
654     accessed=atoi(row[0]);
655
656     // Free the results.
657     mysql_free_result(rows);
658
659     return NOERROR;
660 }
661
662 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
663 {
664 #ifdef _DEBUG
665     xmltooling::NDC ndc("onRead");
666 #endif
667
668     if (!m_storeAttributes)
669         return S_FALSE;
670
671     log->debug("reading cached tokens from MySQL database");
672
673     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
674
675     MYSQL* mysql = getMYSQL();
676     if (mysql_query(mysql, q.c_str())) {
677         const char* err=mysql_error(mysql);
678         log->error("error searching for %s: %s", key, err);
679         if (isCorrupt(err) && repairTable(mysql,"state")) {
680             if (mysql_query(mysql, q.c_str()))
681                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
682         }
683     }
684
685     MYSQL_RES* rows = mysql_store_result(mysql);
686
687     // Nope, doesn't exist.
688     if (!rows || mysql_num_rows(rows)==0) {
689         log->warn("session expected, but not found in database");
690         if (rows)
691             mysql_free_result(rows);
692         return S_FALSE;
693     }
694
695     // Make sure we got 1 and only 1 row.
696     if (mysql_num_rows(rows) != 1) {
697         log->error("database select returned %d rows!", mysql_num_rows(rows));
698         mysql_free_result(rows);
699         return E_FAIL;
700     }
701
702     MYSQL_ROW row = mysql_fetch_row(rows);
703     if (row[0])
704         tokens=row[0];
705
706     // Free the results.
707     mysql_free_result(rows);
708
709     return NOERROR;
710 }
711
712 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
713 {
714 #ifdef _DEBUG
715     xmltooling::NDC ndc("onUpdate");
716 #endif
717
718     ostringstream q;
719     if (lastAccess>0)
720         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
721     else if (tokens) {
722         if (!m_storeAttributes)
723             return S_FALSE;
724         q << "UPDATE state SET tokens=";
725         if (*tokens)
726             q << "'" << tokens << "'";
727         else
728             q << "null";
729     }
730     else {
731         log->warn("onUpdate called with nothing to do!");
732         return S_FALSE;
733     }
734  
735     q << " WHERE cookie='" << key << "'";
736
737     MYSQL* mysql = getMYSQL();
738     if (mysql_query(mysql, q.str().c_str())) {
739         const char* err=mysql_error(mysql);
740         log->error("error updating %s: %s", key, err);
741         if (isCorrupt(err) && repairTable(mysql,"state")) {
742             // Try again...
743             if (mysql_query(mysql, q.str().c_str())) {
744                 log->error("error updating %s: %s", key, mysql_error(mysql));
745                 return E_FAIL;
746             }
747         }
748         else
749             return E_FAIL;
750     }
751
752     return NOERROR;
753 }
754
755 HRESULT ShibMySQLCCache::onDelete(const char* key)
756 {
757 #ifdef _DEBUG
758     xmltooling::NDC ndc("onDelete");
759 #endif
760
761     // Remove from the database
762     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
763     MYSQL* mysql = getMYSQL();
764     if (mysql_query(mysql, q.c_str())) {
765         const char* err=mysql_error(mysql);
766         log->error("error deleting entry %s: %s", key, err);
767         if (isCorrupt(err) && repairTable(mysql,"state")) {
768             // Try again...
769             if (mysql_query(mysql, q.c_str())) {
770                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
771                 return E_FAIL;
772             }
773         }
774         else
775             return E_FAIL;
776     }
777
778     return NOERROR;
779 }
780
781 void ShibMySQLCCache::cleanup()
782 {
783 #ifdef _DEBUG
784   xmltooling::NDC ndc("cleanup");
785 #endif
786
787   xmltooling::Mutex* mutex = xmltooling::Mutex::create();
788
789   int rerun_timer = 0;
790   int timeout_life = 0;
791
792   // Load our configuration details...
793   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
794   if (tag && *tag)
795     rerun_timer = XMLString::parseInt(tag);
796
797   // search for 'mysql-cache-timeout' and then the regular cache timeout
798   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
799   if (tag && *tag)
800     timeout_life = XMLString::parseInt(tag);
801   else {
802       tag=m_root->getAttributeNS(NULL,cacheTimeout);
803       if (tag && *tag)
804         timeout_life = XMLString::parseInt(tag);
805   }
806   
807   if (rerun_timer <= 0)
808     rerun_timer = 300;          // rerun every 5 minutes
809
810   if (timeout_life <= 0)
811     timeout_life = 28800;       // timeout after 8 hours
812
813   mutex->lock();
814
815   MYSQL* mysql = getMYSQL();
816
817   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
818
819   while (shutdown == false) {
820     shutdown_wait->timedwait(mutex, rerun_timer);
821
822     if (shutdown == true)
823       break;
824
825     // Find all the entries in the database that haven't been used
826     // recently In particular, find all entries that have not been
827     // accessed in 'timeout_life' seconds.
828     ostringstream q;
829     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
830
831     if (mysql_query(mysql, q.str().c_str())) {
832       const char* err=mysql_error(mysql);
833       log->error("error purging old records: %s", err);
834         if (isCorrupt(err) && repairTable(mysql,"state")) {
835           if (mysql_query(mysql, q.str().c_str()))
836             log->error("error re-purging old records: %s", mysql_error(mysql));
837         }
838     }
839   }
840
841   log->info("cleanup thread exiting...");
842
843   mutex->unlock();
844   delete mutex;
845   xmltooling::Thread::exit(NULL);
846 }
847
848 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
849 {
850   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
851
852 #ifndef WIN32
853   // First, let'block all signals
854   xmltooling::Thread::mask_all_signals();
855 #endif
856
857   // Now run the cleanup process.
858   cache->cleanup();
859   return NULL;
860 }
861
862 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
863 {
864 public:
865   MySQLReplayCache(const DOMElement* e);
866   virtual ~MySQLReplayCache() {}
867
868   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
869   bool check(const char* str, time_t expires);
870 };
871
872 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e) {}
873
874 bool MySQLReplayCache::check(const char* str, time_t expires)
875 {
876 #ifdef _DEBUG
877     xmltooling::NDC ndc("check");
878 #endif
879   
880     // Remove expired entries
881     string q = string("DELETE FROM replay WHERE expires < NOW()");
882     MYSQL* mysql = getMYSQL();
883     if (mysql_query(mysql, q.c_str())) {
884         const char* err=mysql_error(mysql);
885         log->error("Error deleting expired entries: %s", err);
886         if (isCorrupt(err) && repairTable(mysql,"replay")) {
887             // Try again...
888             if (mysql_query(mysql, q.c_str()))
889                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
890         }
891     }
892   
893     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
894     if (mysql_query(mysql, q2.c_str())) {
895         const char* err=mysql_error(mysql);
896         log->error("Error searching for %s: %s", str, err);
897         if (isCorrupt(err) && repairTable(mysql,"replay")) {
898             if (mysql_query(mysql, q2.c_str())) {
899                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
900                 throw SAMLException("Replay cache failed, please inform application support staff.");
901             }
902         }
903         else
904             throw SAMLException("Replay cache failed, please inform application support staff.");
905     }
906
907     // Did we find it?
908     MYSQL_RES* rows = mysql_store_result(mysql);
909     if (rows && mysql_num_rows(rows)>0) {
910       mysql_free_result(rows);
911       return false;
912     }
913
914     ostringstream q3;
915     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
916
917     // then add it to the database
918     if (mysql_query(mysql, q3.str().c_str())) {
919         const char* err=mysql_error(mysql);
920         log->error("Error inserting %s: %s", str, err);
921         if (isCorrupt(err) && repairTable(mysql,"state")) {
922             // Try again...
923             if (mysql_query(mysql, q3.str().c_str())) {
924                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
925                 throw SAMLException("Replay cache failed, please inform application support staff.");
926             }
927         }
928         else
929             throw SAMLException("Replay cache failed, please inform application support staff.");
930     }
931     
932     return true;
933 }
934
935 /*************************************************************************
936  * The registration functions here...
937  */
938
939 SessionCache* new_mysql_ccache(const DOMElement* const & e)
940 {
941     return new ShibMySQLCCache(e);
942 }
943
944 IPlugIn* new_mysql_replay(const DOMElement* e)
945 {
946     return new MySQLReplayCache(e);
947 }
948
949 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
950 {
951     // register this ccache type
952     SAMLConfig::getConfig().getPlugMgr().regFactory(MYSQL_REPLAYCACHE, &new_mysql_replay);
953     SPConfig::getConfig().SessionCacheManager.registerFactory(MYSQL_SESSIONCACHE, &new_mysql_ccache);
954     return 0;
955 }
956
957 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
958 {
959     // Shutdown MySQL
960     if (g_MySQLInitialized)
961         mysql_server_end();
962     SAMLConfig::getConfig().getPlugMgr().unregFactory(MYSQL_REPLAYCACHE);
963 }