ff5aec8ef48fd7f88fe4cc484e2f4f6dc03c9ea5
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib-target/shib-target.h>
43
44 #include <log4cpp/Category.hh>
45 #include <xmltooling/util/NDC.h>
46 #include <xmltooling/util/Threads.h>
47 #include <xmltooling/util/XMLHelper.h>
48 using xmltooling::XMLHelper;
49
50 #include <sstream>
51
52 #ifdef WIN32
53 # include <winsock.h>
54 #endif
55 #include <mysql.h>
56
57 // wanted to use MySQL codes for this, but can't seem to get back a 145
58 #define isCorrupt(s) strstr(s,"(errno: 145)")
59
60 #ifdef HAVE_LIBDMALLOCXX
61 #include <dmalloc.h>
62 #endif
63
64 using namespace shibtarget;
65 using namespace opensaml::saml2md;
66 using namespace saml;
67 using namespace log4cpp;
68 using namespace std;
69
70 #define PLUGIN_VER_MAJOR 3
71 #define PLUGIN_VER_MINOR 0
72
73 #define STATE_TABLE \
74   "CREATE TABLE state (" \
75   "cookie VARCHAR(64) PRIMARY KEY, " \
76   "application_id VARCHAR(255)," \
77   "ctime TIMESTAMP," \
78   "atime TIMESTAMP," \
79   "addr VARCHAR(128)," \
80   "major INT," \
81   "minor INT," \
82   "provider VARCHAR(256)," \
83   "subject TEXT," \
84   "authn_context TEXT," \
85   "tokens TEXT)"
86
87 #define REPLAY_TABLE \
88   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
89   "expires TIMESTAMP, " \
90   "INDEX (expires))"
91
92 static const XMLCh Argument[] =
93 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
94 static const XMLCh cleanupInterval[] =
95 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
96   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
97 };
98 static const XMLCh cacheTimeout[] =
99 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
100 static const XMLCh mysqlTimeout[] =
101 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
102 static const XMLCh storeAttributes[] =
103 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
104
105 static bool g_MySQLInitialized = false;
106
107 class MySQLBase : public virtual saml::IPlugIn
108 {
109 public:
110   MySQLBase(const DOMElement* e);
111   virtual ~MySQLBase();
112
113   MYSQL* getMYSQL();
114   bool repairTable(MYSQL*&, const char* table);
115
116   log4cpp::Category* log;
117
118 protected:
119     xmltooling::ThreadKey* m_mysql;
120   const DOMElement* m_root; // can only use this during initialization
121
122   bool initialized;
123   bool handleShutdown;
124
125   void createDatabase(MYSQL*, int major, int minor);
126   void upgradeDatabase(MYSQL*);
127   pair<int,int> getVersion(MYSQL*);
128 };
129
130 // Forward declarations
131 static void mysqlInit(const DOMElement* e, Category& log);
132
133 extern "C" void shib_mysql_destroy_handle(void* data)
134 {
135   MYSQL* mysql = (MYSQL*) data;
136   if (mysql) mysql_close(mysql);
137 }
138
139 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
140 {
141 #ifdef _DEBUG
142   xmltooling::NDC ndc("MySQLBase");
143 #endif
144   log = &(Category::getInstance("shibtarget.SessionCache.MySQL"));
145
146   m_mysql = xmltooling::ThreadKey::create(&shib_mysql_destroy_handle);
147
148   initialized = false;
149   mysqlInit(e,*log);
150   getMYSQL();
151   initialized = true;
152 }
153
154 MySQLBase::~MySQLBase()
155 {
156   delete m_mysql;
157 }
158
159 MYSQL* MySQLBase::getMYSQL()
160 {
161 #ifdef _DEBUG
162     xmltooling::NDC ndc("getMYSQL");
163 #endif
164
165     // Do we already have a handle?
166     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
167     if (mysql)
168         return mysql;
169
170     // Connect to the database
171     mysql = mysql_init(NULL);
172     if (!mysql) {
173         log->error("mysql_init failed");
174         mysql_close(mysql);
175         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
176     }
177
178     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
179         if (initialized) {
180             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
181             mysql_close(mysql);
182             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
183         }
184         else {
185             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
186
187             // This will throw an exception if it fails.
188             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
189         }
190     }
191
192     pair<int,int> v=getVersion (mysql);
193
194     // Make sure we've got the right version
195     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
196    
197         // If we're capable, try upgrading on the fly...
198         if (v.first == 0  || v.first == 1 || v.first == 2) {
199             if (mysql_query(mysql, "DROP TABLE state")) {
200                 log->error("error dropping old session state table: %s", mysql_error(mysql));
201             }
202             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
203                 log->error("error dropping old session state table: %s", mysql_error(mysql));
204             }
205             upgradeDatabase(mysql);
206         }
207         else {
208             mysql_close(mysql);
209             log->crit("Unknown database version: %d.%d", v.first, v.second);
210             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
211         }
212     }
213
214     // We're all set.. Save off the handle for this thread.
215     m_mysql->setData(mysql);
216     return mysql;
217 }
218
219 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
220 {
221     string q = string("REPAIR TABLE ") + table;
222     if (mysql_query(mysql, q.c_str())) {
223         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
224         return false;
225     }
226
227     // seems we have to recycle the connection to get the thread to keep working
228     // other threads seem to be ok, but we should monitor that
229     mysql_close(mysql);
230     m_mysql->setData(NULL);
231     mysql=getMYSQL();
232     return true;
233 }
234
235 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
236 {
237   log->info("creating database");
238
239   MYSQL* ms = NULL;
240   try {
241     ms = mysql_init(NULL);
242     if (!ms) {
243       log->crit("mysql_init failed");
244       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
245     }
246
247     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
248       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
249       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
250     }
251
252     if (mysql_query(ms, "CREATE DATABASE shibd")) {
253       log->crit("cannot create shibd database: %s", mysql_error(ms));
254       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
255     }
256
257     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
258       log->crit("cannot open shibd database");
259       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
260     }
261
262     mysql_close(ms);
263     
264   }
265   catch (SAMLException&) {
266     if (ms)
267       mysql_close(ms);
268     mysql_close(mysql);
269     throw;
270   }
271
272   // Now create the tables if they don't exist
273   log->info("Creating database tables");
274
275   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
276     log->error ("error creating version: %s", mysql_error(mysql));
277     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
278   }
279
280   if (mysql_query(mysql,STATE_TABLE)) {
281     log->error ("error creating state table: %s", mysql_error(mysql));
282     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
283   }
284
285   if (mysql_query(mysql,REPLAY_TABLE)) {
286     log->error ("error creating replay table: %s", mysql_error(mysql));
287     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
288   }
289
290   ostringstream q;
291   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
292   if (mysql_query(mysql, q.str().c_str())) {
293     log->error ("error setting version: %s", mysql_error(mysql));
294     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
295   }
296 }
297
298 void MySQLBase::upgradeDatabase(MYSQL* mysql)
299 {
300     if (mysql_query(mysql,STATE_TABLE)) {
301         log->error ("error creating state table: %s", mysql_error(mysql));
302         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
303     }
304
305     if (mysql_query(mysql,REPLAY_TABLE)) {
306         log->error ("error creating replay table: %s", mysql_error(mysql));
307         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
308     }
309
310     ostringstream q;
311     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
312     if (mysql_query(mysql, q.str().c_str())) {
313         log->error ("error updating version: %s", mysql_error(mysql));
314         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
315     }
316 }
317
318 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
319 {
320     // grab the version number from the database
321     if (mysql_query(mysql, "SELECT * FROM version")) {
322         log->error("error reading version: %s", mysql_error(mysql));
323         throw SAMLException("MySQLBase::getVersion(): error reading version");
324     }
325
326     MYSQL_RES* rows = mysql_store_result(mysql);
327     if (rows) {
328         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
329           MYSQL_ROW row = mysql_fetch_row(rows);
330           int major = row[0] ? atoi(row[0]) : -1;
331           int minor = row[1] ? atoi(row[1]) : -1;
332           log->debug("opening database version %d.%d", major, minor);
333           mysql_free_result(rows);
334           return make_pair(major,minor);
335         }
336         else {
337             // Wrong number of rows or wrong number of fields...
338             log->crit("Houston, we've got a problem with the database...");
339             mysql_free_result(rows);
340             throw SAMLException("MySQLBase::getVersion(): version verification failed");
341         }
342     }
343     log->crit("MySQL Read Failed in version verification");
344     throw SAMLException("MySQLBase::getVersion(): error reading version");
345 }
346
347 static void mysqlInit(const DOMElement* e, Category& log)
348 {
349     if (g_MySQLInitialized) {
350         log.info("MySQL embedded server already initialized");
351         return;
352     }
353     log.info("initializing MySQL embedded server");
354
355     // Setup the argument array
356     vector<string> arg_array;
357     arg_array.push_back("shibboleth");
358
359     // grab any MySQL parameters from the config file
360     e=XMLHelper::getFirstChildElement(e,Argument);
361     while (e) {
362         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
363         if (arg.get())
364             arg_array.push_back(arg.get());
365         e=XMLHelper::getNextSiblingElement(e,Argument);
366     }
367
368     // Compute the argument array
369     vector<string>::size_type arg_count = arg_array.size();
370     const char** args=new const char*[arg_count];
371     for (vector<string>::size_type i = 0; i < arg_count; i++)
372         args[i] = arg_array[i].c_str();
373
374     // Initialize MySQL with the arguments
375     mysql_server_init(arg_count, (char **)args, NULL);
376
377     delete[] args;
378     g_MySQLInitialized = true;
379 }  
380
381 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
382 {
383 public:
384     ShibMySQLCCache(const DOMElement* e);
385     virtual ~ShibMySQLCCache();
386
387     // Delegate all the ISessionCache methods.
388     string insert(
389         const IApplication* application,
390         const RoleDescriptor* role,
391         const char* client_addr,
392         const SAMLSubject* subject,
393         const char* authnContext,
394         const SAMLResponse* tokens
395         )
396     { return m_cache->insert(application,role,client_addr,subject,authnContext,tokens); }
397     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
398     { return m_cache->find(key,application,client_addr); }
399     void remove(const char* key, const IApplication* application, const char* client_addr)
400     { m_cache->remove(key,application,client_addr); }
401
402     bool setBackingStore(ISessionCacheStore*) {return false;}
403
404     // Store methods handle the database work
405     HRESULT onCreate(
406         const char* key,
407         const IApplication* application,
408         const ISessionCacheEntry* entry,
409         int majorVersion,
410         int minorVersion,
411         time_t created
412         );
413     HRESULT onRead(
414         const char* key,
415         string& applicationId,
416         string& clientAddress,
417         string& providerId,
418         string& subject,
419         string& authnContext,
420         string& tokens,
421         int& majorVersion,
422         int& minorVersion,
423         time_t& created,
424         time_t& accessed
425         );
426     HRESULT onRead(const char* key, time_t& accessed);
427     HRESULT onRead(const char* key, string& tokens);
428     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
429     HRESULT onDelete(const char* key);
430
431     void cleanup();
432
433 private:
434     bool m_storeAttributes;
435     ISessionCache* m_cache;
436     xmltooling::CondWait* shutdown_wait;
437     bool shutdown;
438     xmltooling::Thread* cleanup_thread;
439
440     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
441 };
442
443 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
444 {
445 #ifdef _DEBUG
446     xmltooling::NDC ndc("ShibMySQLCCache");
447 #endif
448
449     m_cache = dynamic_cast<ISessionCache*>(
450         SAMLConfig::getConfig().getPlugMgr().newPlugin(MEMORY_SESSIONCACHE, e)
451     );
452     if (!m_cache->setBackingStore(this)) {
453         delete m_cache;
454         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
455     }
456     
457     shutdown_wait = xmltooling::CondWait::create();
458     shutdown = false;
459
460     // Load our configuration details...
461     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
462     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
463         m_storeAttributes=true;
464
465     // Initialize the cleanup thread
466     cleanup_thread = xmltooling::Thread::create(&cleanup_fcn, (void*)this);
467 }
468
469 ShibMySQLCCache::~ShibMySQLCCache()
470 {
471     shutdown = true;
472     shutdown_wait->signal();
473     cleanup_thread->join(NULL);
474     delete m_cache;
475 }
476
477 HRESULT ShibMySQLCCache::onCreate(
478     const char* key,
479     const IApplication* application,
480     const ISessionCacheEntry* entry,
481     int majorVersion,
482     int minorVersion,
483     time_t created
484     )
485 {
486 #ifdef _DEBUG
487     xmltooling::NDC ndc("onCreate");
488 #endif
489
490     // Get XML data from entry. Default is not to return SAML objects.
491     const char* context=entry->getAuthnContext();
492     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
493     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
494
495     ostringstream q;
496     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
497     if (created==0)
498         q << "NOW(),NOW(),'";
499     else
500         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
501     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
502         << subject.first << "','" << context << "',";
503
504     if (m_storeAttributes && tokens.first)
505         q << "'" << tokens.first << "')";
506     else
507         q << "null)";
508
509     if (log->isDebugEnabled())
510         log->debug("SQL insert: %s", q.str().c_str());
511
512     // then add it to the database
513     MYSQL* mysql = getMYSQL();
514     if (mysql_query(mysql, q.str().c_str())) {
515         const char* err=mysql_error(mysql);
516         log->error("error inserting %s: %s", key, err);
517         if (isCorrupt(err) && repairTable(mysql,"state")) {
518             // Try again...
519             if (mysql_query(mysql, q.str().c_str())) {
520                 log->error("error inserting %s: %s", key, mysql_error(mysql));
521                 return E_FAIL;
522             }
523         }
524         else
525             throw E_FAIL;
526     }
527
528     return NOERROR;
529 }
530
531 HRESULT ShibMySQLCCache::onRead(
532     const char* key,
533     string& applicationId,
534     string& clientAddress,
535     string& providerId,
536     string& subject,
537     string& authnContext,
538     string& tokens,
539     int& majorVersion,
540     int& minorVersion,
541     time_t& created,
542     time_t& accessed
543     )
544 {
545 #ifdef _DEBUG
546     xmltooling::NDC ndc("onRead");
547 #endif
548
549     log->debug("searching MySQL database...");
550
551     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
552
553     MYSQL* mysql = getMYSQL();
554     if (mysql_query(mysql, q.c_str())) {
555         const char* err=mysql_error(mysql);
556         log->error("error searching for %s: %s", key, err);
557         if (isCorrupt(err) && repairTable(mysql,"state")) {
558             if (mysql_query(mysql, q.c_str()))
559                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
560         }
561     }
562
563     MYSQL_RES* rows = mysql_store_result(mysql);
564
565     // Nope, doesn't exist.
566     if (!rows || mysql_num_rows(rows)==0) {
567         log->debug("not found in database");
568         if (rows)
569             mysql_free_result(rows);
570         return S_FALSE;
571     }
572
573     // Make sure we got 1 and only 1 row.
574     if (mysql_num_rows(rows) > 1) {
575         log->error("database select returned %d rows!", mysql_num_rows(rows));
576         mysql_free_result(rows);
577         return E_FAIL;
578     }
579
580     log->debug("session found, tranfering data back into memory");
581     
582     /* Columns in query:
583         0: application_id
584         1: ctime
585         2: atime
586         3: address
587         4: major
588         5: minor
589         6: provider
590         7: subject
591         8: authncontext
592         9: tokens
593      */
594
595     MYSQL_ROW row = mysql_fetch_row(rows);
596     applicationId=row[0];
597     created=atoi(row[1]);
598     accessed=atoi(row[2]);
599     clientAddress=row[3];
600     majorVersion=atoi(row[4]);
601     minorVersion=atoi(row[5]);
602     providerId=row[6];
603     subject=row[7];
604     authnContext=row[8];
605     if (row[9])
606         tokens=row[9];
607
608     // Free the results.
609     mysql_free_result(rows);
610
611     return NOERROR;
612 }
613
614 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
615 {
616 #ifdef _DEBUG
617     xmltooling::NDC ndc("onRead");
618 #endif
619
620     log->debug("reading last access time from MySQL database");
621
622     string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
623
624     MYSQL* mysql = getMYSQL();
625     if (mysql_query(mysql, q.c_str())) {
626         const char* err=mysql_error(mysql);
627         log->error("error searching for %s: %s", key, err);
628         if (isCorrupt(err) && repairTable(mysql,"state")) {
629             if (mysql_query(mysql, q.c_str()))
630                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
631         }
632     }
633
634     MYSQL_RES* rows = mysql_store_result(mysql);
635
636     // Nope, doesn't exist.
637     if (!rows || mysql_num_rows(rows)==0) {
638         log->warn("session expected, but not found in database");
639         if (rows)
640             mysql_free_result(rows);
641         return S_FALSE;
642     }
643
644     // Make sure we got 1 and only 1 row.
645     if (mysql_num_rows(rows) != 1) {
646         log->error("database select returned %d rows!", mysql_num_rows(rows));
647         mysql_free_result(rows);
648         return E_FAIL;
649     }
650
651     MYSQL_ROW row = mysql_fetch_row(rows);
652     accessed=atoi(row[0]);
653
654     // Free the results.
655     mysql_free_result(rows);
656
657     return NOERROR;
658 }
659
660 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
661 {
662 #ifdef _DEBUG
663     xmltooling::NDC ndc("onRead");
664 #endif
665
666     if (!m_storeAttributes)
667         return S_FALSE;
668
669     log->debug("reading cached tokens from MySQL database");
670
671     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
672
673     MYSQL* mysql = getMYSQL();
674     if (mysql_query(mysql, q.c_str())) {
675         const char* err=mysql_error(mysql);
676         log->error("error searching for %s: %s", key, err);
677         if (isCorrupt(err) && repairTable(mysql,"state")) {
678             if (mysql_query(mysql, q.c_str()))
679                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
680         }
681     }
682
683     MYSQL_RES* rows = mysql_store_result(mysql);
684
685     // Nope, doesn't exist.
686     if (!rows || mysql_num_rows(rows)==0) {
687         log->warn("session expected, but not found in database");
688         if (rows)
689             mysql_free_result(rows);
690         return S_FALSE;
691     }
692
693     // Make sure we got 1 and only 1 row.
694     if (mysql_num_rows(rows) != 1) {
695         log->error("database select returned %d rows!", mysql_num_rows(rows));
696         mysql_free_result(rows);
697         return E_FAIL;
698     }
699
700     MYSQL_ROW row = mysql_fetch_row(rows);
701     if (row[0])
702         tokens=row[0];
703
704     // Free the results.
705     mysql_free_result(rows);
706
707     return NOERROR;
708 }
709
710 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
711 {
712 #ifdef _DEBUG
713     xmltooling::NDC ndc("onUpdate");
714 #endif
715
716     ostringstream q;
717     if (lastAccess>0)
718         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
719     else if (tokens) {
720         if (!m_storeAttributes)
721             return S_FALSE;
722         q << "UPDATE state SET tokens=";
723         if (*tokens)
724             q << "'" << tokens << "'";
725         else
726             q << "null";
727     }
728     else {
729         log->warn("onUpdate called with nothing to do!");
730         return S_FALSE;
731     }
732  
733     q << " WHERE cookie='" << key << "'";
734
735     MYSQL* mysql = getMYSQL();
736     if (mysql_query(mysql, q.str().c_str())) {
737         const char* err=mysql_error(mysql);
738         log->error("error updating %s: %s", key, err);
739         if (isCorrupt(err) && repairTable(mysql,"state")) {
740             // Try again...
741             if (mysql_query(mysql, q.str().c_str())) {
742                 log->error("error updating %s: %s", key, mysql_error(mysql));
743                 return E_FAIL;
744             }
745         }
746         else
747             return E_FAIL;
748     }
749
750     return NOERROR;
751 }
752
753 HRESULT ShibMySQLCCache::onDelete(const char* key)
754 {
755 #ifdef _DEBUG
756     xmltooling::NDC ndc("onDelete");
757 #endif
758
759     // Remove from the database
760     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
761     MYSQL* mysql = getMYSQL();
762     if (mysql_query(mysql, q.c_str())) {
763         const char* err=mysql_error(mysql);
764         log->error("error deleting entry %s: %s", key, err);
765         if (isCorrupt(err) && repairTable(mysql,"state")) {
766             // Try again...
767             if (mysql_query(mysql, q.c_str())) {
768                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
769                 return E_FAIL;
770             }
771         }
772         else
773             return E_FAIL;
774     }
775
776     return NOERROR;
777 }
778
779 void ShibMySQLCCache::cleanup()
780 {
781 #ifdef _DEBUG
782   xmltooling::NDC ndc("cleanup");
783 #endif
784
785   xmltooling::Mutex* mutex = xmltooling::Mutex::create();
786
787   int rerun_timer = 0;
788   int timeout_life = 0;
789
790   // Load our configuration details...
791   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
792   if (tag && *tag)
793     rerun_timer = XMLString::parseInt(tag);
794
795   // search for 'mysql-cache-timeout' and then the regular cache timeout
796   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
797   if (tag && *tag)
798     timeout_life = XMLString::parseInt(tag);
799   else {
800       tag=m_root->getAttributeNS(NULL,cacheTimeout);
801       if (tag && *tag)
802         timeout_life = XMLString::parseInt(tag);
803   }
804   
805   if (rerun_timer <= 0)
806     rerun_timer = 300;          // rerun every 5 minutes
807
808   if (timeout_life <= 0)
809     timeout_life = 28800;       // timeout after 8 hours
810
811   mutex->lock();
812
813   MYSQL* mysql = getMYSQL();
814
815   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
816
817   while (shutdown == false) {
818     shutdown_wait->timedwait(mutex, rerun_timer);
819
820     if (shutdown == true)
821       break;
822
823     // Find all the entries in the database that haven't been used
824     // recently In particular, find all entries that have not been
825     // accessed in 'timeout_life' seconds.
826     ostringstream q;
827     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
828
829     if (mysql_query(mysql, q.str().c_str())) {
830       const char* err=mysql_error(mysql);
831       log->error("error purging old records: %s", err);
832         if (isCorrupt(err) && repairTable(mysql,"state")) {
833           if (mysql_query(mysql, q.str().c_str()))
834             log->error("error re-purging old records: %s", mysql_error(mysql));
835         }
836     }
837   }
838
839   log->info("cleanup thread exiting...");
840
841   mutex->unlock();
842   delete mutex;
843   xmltooling::Thread::exit(NULL);
844 }
845
846 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
847 {
848   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
849
850 #ifndef WIN32
851   // First, let'block all signals
852   xmltooling::Thread::mask_all_signals();
853 #endif
854
855   // Now run the cleanup process.
856   cache->cleanup();
857   return NULL;
858 }
859
860 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
861 {
862 public:
863   MySQLReplayCache(const DOMElement* e);
864   virtual ~MySQLReplayCache() {}
865
866   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
867   bool check(const char* str, time_t expires);
868 };
869
870 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e) {}
871
872 bool MySQLReplayCache::check(const char* str, time_t expires)
873 {
874 #ifdef _DEBUG
875     xmltooling::NDC ndc("check");
876 #endif
877   
878     // Remove expired entries
879     string q = string("DELETE FROM replay WHERE expires < NOW()");
880     MYSQL* mysql = getMYSQL();
881     if (mysql_query(mysql, q.c_str())) {
882         const char* err=mysql_error(mysql);
883         log->error("Error deleting expired entries: %s", err);
884         if (isCorrupt(err) && repairTable(mysql,"replay")) {
885             // Try again...
886             if (mysql_query(mysql, q.c_str()))
887                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
888         }
889     }
890   
891     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
892     if (mysql_query(mysql, q2.c_str())) {
893         const char* err=mysql_error(mysql);
894         log->error("Error searching for %s: %s", str, err);
895         if (isCorrupt(err) && repairTable(mysql,"replay")) {
896             if (mysql_query(mysql, q2.c_str())) {
897                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
898                 throw SAMLException("Replay cache failed, please inform application support staff.");
899             }
900         }
901         else
902             throw SAMLException("Replay cache failed, please inform application support staff.");
903     }
904
905     // Did we find it?
906     MYSQL_RES* rows = mysql_store_result(mysql);
907     if (rows && mysql_num_rows(rows)>0) {
908       mysql_free_result(rows);
909       return false;
910     }
911
912     ostringstream q3;
913     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
914
915     // then add it to the database
916     if (mysql_query(mysql, q3.str().c_str())) {
917         const char* err=mysql_error(mysql);
918         log->error("Error inserting %s: %s", str, err);
919         if (isCorrupt(err) && repairTable(mysql,"state")) {
920             // Try again...
921             if (mysql_query(mysql, q3.str().c_str())) {
922                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
923                 throw SAMLException("Replay cache failed, please inform application support staff.");
924             }
925         }
926         else
927             throw SAMLException("Replay cache failed, please inform application support staff.");
928     }
929     
930     return true;
931 }
932
933 /*************************************************************************
934  * The registration functions here...
935  */
936
937 IPlugIn* new_mysql_ccache(const DOMElement* e)
938 {
939     return new ShibMySQLCCache(e);
940 }
941
942 IPlugIn* new_mysql_replay(const DOMElement* e)
943 {
944     return new MySQLReplayCache(e);
945 }
946
947 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
948 {
949     // register this ccache type
950     SAMLConfig::getConfig().getPlugMgr().regFactory(MYSQL_REPLAYCACHE, &new_mysql_replay);
951     SAMLConfig::getConfig().getPlugMgr().regFactory(MYSQL_SESSIONCACHE, &new_mysql_ccache);
952     return 0;
953 }
954
955 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
956 {
957     // Shutdown MySQL
958     if (g_MySQLInitialized)
959         mysql_server_end();
960     SAMLConfig::getConfig().getPlugMgr().unregFactory(MYSQL_REPLAYCACHE);
961     SAMLConfig::getConfig().getPlugMgr().unregFactory(MYSQL_SESSIONCACHE);
962 }