Log adjustments.
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
3  *
4  * Created by:  Derek Atkins <derek@ihtfp.com>
5  *
6  * $Id$
7  */
8
9 /* This file is loosely based off the Shibboleth Credential Cache.
10  * This plug-in is designed as a two-layer cache.  Layer 1, the
11  * long-term cache, stores data in a MySQL embedded database.
12  *
13  * Short-term data is cached in memory as SAML objects in the layer 2
14  * cache.
15  */
16
17 // eventually we might be able to support autoconf via cygwin...
18 #if defined (_MSC_VER) || defined(__BORLANDC__)
19 # include "config_win32.h"
20 #else
21 # include "config.h"
22 #endif
23
24 #ifdef WIN32
25 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
26 #else
27 # define SHIBMYSQL_EXPORTS
28 #endif
29
30 #ifdef HAVE_UNISTD_H
31 # include <unistd.h>
32 #endif
33
34 #include <shib-target/shib-target.h>
35 #include <shib/shib-threads.h>
36 #include <log4cpp/Category.hh>
37
38 #include <sstream>
39 #include <stdexcept>
40
41 #include <mysql.h>
42
43 // wanted to use MySQL codes for this, but can't seem to get back a 145
44 #define isCorrupt(s) strstr(s,"(errno: 145)")
45
46 #ifdef HAVE_LIBDMALLOCXX
47 #include <dmalloc.h>
48 #endif
49
50 using namespace std;
51 using namespace saml;
52 using namespace shibboleth;
53 using namespace shibtarget;
54 using namespace log4cpp;
55
56 #define PLUGIN_VER_MAJOR 2
57 #define PLUGIN_VER_MINOR 0
58
59 #define STATE_TABLE \
60   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, " \
61   "application_id VARCHAR(255)," \
62   "ctime TIMESTAMP," \
63   "atime TIMESTAMP," \
64   "addr VARCHAR(128)," \
65   "profile INT," \
66   "provider VARCHAR(256)," \
67   "response_id VARCHAR(128)," \
68   "response TEXT," \
69   "statement TEXT)"
70
71 #define REPLAY_TABLE \
72   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
73   "expires TIMESTAMP, " \
74   "INDEX (expires))"
75
76 static const XMLCh Argument[] =
77 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
78 static const XMLCh cleanupInterval[] =
79 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
80   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
81 };
82 static const XMLCh cacheTimeout[] =
83 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
84 static const XMLCh mysqlTimeout[] =
85 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
86 static const XMLCh storeAttributes[] =
87 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
88
89 class MySQLBase
90 {
91 public:
92   MySQLBase(const DOMElement* e);
93   virtual ~MySQLBase();
94
95   void thread_init();
96   void thread_end() {}
97
98   MYSQL* getMYSQL() const;
99   bool repairTable(MYSQL*&, const char* table);
100
101   log4cpp::Category* log;
102
103 protected:
104   ThreadKey* m_mysql;
105   const DOMElement* m_root; // can only use this during initialization
106
107   bool initialized;
108
109   void createDatabase(MYSQL*, int major, int minor);
110   void upgradeDatabase(MYSQL*);
111   void getVersion(MYSQL*, int* major_p, int* minor_p);
112 };
113
114 // Forward declarations
115 static void mysqlInit(const DOMElement* e, Category& log);
116
117 extern "C" void shib_mysql_destroy_handle(void* data)
118 {
119   MYSQL* mysql = (MYSQL*) data;
120   if (mysql) mysql_close(mysql);
121 }
122
123 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
124 {
125 #ifdef _DEBUG
126   saml::NDC ndc("MySQLBase");
127 #endif
128   log = &(Category::getInstance("shibmysql.MySQLBase"));
129
130   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
131
132   initialized = false;
133   mysqlInit(e,*log);
134   thread_init();
135   initialized = true;
136 }
137
138 MySQLBase::~MySQLBase()
139 {
140   thread_end();
141
142   delete m_mysql;
143 }
144
145 MYSQL* MySQLBase::getMYSQL() const
146 {
147   return (MYSQL*)m_mysql->getData();
148 }
149
150 void MySQLBase::thread_init()
151 {
152 #ifdef _DEBUG
153   saml::NDC ndc("thread_init");
154 #endif
155
156   // Connect to the database
157   MYSQL* mysql = mysql_init(NULL);
158   if (!mysql) {
159     log->error("mysql_init failed");
160     mysql_close(mysql);
161     throw SAMLException("MySQLBase::thread_init(): mysql_init() failed");
162   }
163
164   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
165     if (initialized) {
166       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
167       mysql_close(mysql);
168       throw SAMLException("MySQLBase::thread_init(): mysql_real_connect() failed");
169     } else {
170       log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
171
172       // This will throw an exception if it fails.
173       createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
174     }
175   }
176
177   int major = -1, minor = -1;
178   getVersion (mysql, &major, &minor);
179
180   // Make sure we've got the right version
181   if (major != PLUGIN_VER_MAJOR || minor != PLUGIN_VER_MINOR) {
182    
183     // If we're capable, try upgrading on the fly...
184     if (major == 0  || major == 1) {
185        upgradeDatabase(mysql);
186     }
187     else {
188         mysql_close(mysql);
189         log->crit("Unknown database version: %d.%d", major, minor);
190         throw SAMLException("MySQLBase::thread_init(): Unknown database version");
191     }
192   }
193
194   // We're all set.. Save off the handle for this thread.
195   m_mysql->setData(mysql);
196 }
197
198 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
199 {
200   string q = string("REPAIR TABLE ") + table;
201   if (mysql_query(mysql, q.c_str())) {
202     log->error("Error repairing table %s: %s", table, mysql_error(mysql));
203     return false;
204   }
205
206   // seems we have to recycle the connection to get the thread to keep working
207   // other threads seem to be ok, but we should monitor that
208   mysql_close(mysql);
209   m_mysql->setData(NULL);
210   thread_init();
211   mysql=getMYSQL();
212   return true;
213 }
214
215 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
216 {
217   log->info("Creating database.");
218
219   MYSQL* ms = NULL;
220   try {
221     ms = mysql_init(NULL);
222     if (!ms) {
223       log->crit("mysql_init failed");
224       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
225     }
226
227     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
228       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
229       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
230     }
231
232     if (mysql_query(ms, "CREATE DATABASE shar")) {
233       log->crit("cannot create shar database: %s", mysql_error(ms));
234       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
235     }
236
237     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
238       log->crit("cannot open SHAR database");
239       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
240     }
241
242     mysql_close(ms);
243     
244   }
245   catch (SAMLException&) {
246     if (ms)
247       mysql_close(ms);
248     mysql_close(mysql);
249     throw;
250   }
251
252   // Now create the tables if they don't exist
253   log->info("Creating database tables");
254
255   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
256     log->error ("Error creating version: %s", mysql_error(mysql));
257     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
258   }
259
260   if (mysql_query(mysql,STATE_TABLE)) {
261     log->error ("Error creating state table: %s", mysql_error(mysql));
262     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
263   }
264
265   if (mysql_query(mysql,REPLAY_TABLE)) {
266     log->error ("Error creating replay table: %s", mysql_error(mysql));
267     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
268   }
269
270   ostringstream q;
271   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
272   if (mysql_query(mysql, q.str().c_str())) {
273     log->error ("Error setting version: %s", mysql_error(mysql));
274     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
275   }
276 }
277
278 void MySQLBase::upgradeDatabase(MYSQL* mysql)
279 {
280     if (mysql_query(mysql, "DROP TABLE state")) {
281         log->error("Error dropping old session state table: %s", mysql_error(mysql));
282     }
283
284     if (mysql_query(mysql,STATE_TABLE)) {
285         log->error ("Error creating state table: %s", mysql_error(mysql));
286         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
287     }
288
289     if (mysql_query(mysql,REPLAY_TABLE)) {
290         log->error ("Error creating replay table: %s", mysql_error(mysql));
291         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
292     }
293
294     ostringstream q;
295     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
296     if (mysql_query(mysql, q.str().c_str())) {
297         log->error ("Error updating version: %s", mysql_error(mysql));
298         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
299     }
300 }
301
302 void MySQLBase::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
303 {
304   // grab the version number from the database
305   if (mysql_query(mysql, "SELECT * FROM version"))
306     log->error ("Error reading version: %s", mysql_error(mysql));
307
308   MYSQL_RES* rows = mysql_store_result(mysql);
309   if (rows) {
310     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
311       MYSQL_ROW row = mysql_fetch_row(rows);
312
313       int major = row[0] ? atoi(row[0]) : -1;
314       int minor = row[1] ? atoi(row[1]) : -1;
315       log->debug("opening database version %d.%d", major, minor);
316       
317       mysql_free_result (rows);
318
319       *major_p = major;
320       *minor_p = minor;
321       return;
322
323     } else {
324       // Wrong number of rows or wrong number of fields...
325
326       log->crit("Houston, we've got a problem with the database...");
327       mysql_free_result (rows);
328       throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
329     }
330   }
331   log->crit("MySQL Read Failed in version verificatoin");
332   throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
333 }
334
335 static void mysqlInit(const DOMElement* e, Category& log)
336 {
337   static bool done = false;
338   if (done) {
339     log.info("MySQL embedded server already initialized");
340     return;
341   }
342   log.info("initializing MySQL embedded server");
343
344   // Setup the argument array
345   vector<string> arg_array;
346   arg_array.push_back("shibboleth");
347
348   // grab any MySQL parameters from the config file
349   e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
350   while (e) {
351       auto_ptr_char arg(e->getFirstChild()->getNodeValue());
352       if (arg.get())
353           arg_array.push_back(arg.get());
354       e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
355   }
356
357   // Compute the argument array
358   int arg_count = arg_array.size();
359   const char** args=new const char*[arg_count];
360   for (int i = 0; i < arg_count; i++)
361     args[i] = arg_array[i].c_str();
362
363   // Initialize MySQL with the arguments
364   mysql_server_init(arg_count, (char **)args, NULL);
365
366   delete[] args;
367   done = true;
368 }  
369
370 class ShibMySQLCCache;
371 class ShibMySQLCCacheEntry : public ISessionCacheEntry
372 {
373 public:
374   ShibMySQLCCacheEntry(const char* key, ISessionCacheEntry* entry, ShibMySQLCCache* cache)
375     : m_cacheEntry(entry), m_key(key), m_cache(cache), m_responseId(NULL) {}
376   ~ShibMySQLCCacheEntry() {if (m_responseId) XMLString::release(&m_responseId);}
377
378   virtual void lock() {}
379   virtual void unlock() { m_cacheEntry->unlock(); delete this; }
380   virtual bool isValid(time_t lifetime, time_t timeout) const;
381   virtual const char* getClientAddress() const { return m_cacheEntry->getClientAddress(); }
382   virtual ShibProfile getProfile() const { return m_cacheEntry->getProfile(); }
383   virtual const char* getProviderId() const { return m_cacheEntry->getProviderId(); }
384   virtual const SAMLAuthenticationStatement* getAuthnStatement() const { return m_cacheEntry->getAuthnStatement(); }
385   virtual CachedResponse getResponse();
386
387 private:
388   bool touch() const;
389
390   ShibMySQLCCache* m_cache;
391   ISessionCacheEntry* m_cacheEntry;
392   string m_key;
393   XMLCh* m_responseId;
394 };
395
396 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache
397 {
398 public:
399   ShibMySQLCCache(const DOMElement* e);
400   virtual ~ShibMySQLCCache();
401
402   virtual void thread_init() {MySQLBase::thread_init();}
403   virtual void thread_end() {MySQLBase::thread_end();}
404
405   virtual string generateKey() const {return m_cache->generateKey();}
406   virtual ISessionCacheEntry* find(const char* key, const IApplication* application);
407   virtual void insert(
408     const char* key,
409     const IApplication* application,
410     const char* client_addr,
411     ShibProfile profile,
412     const char* providerId,
413     saml::SAMLAuthenticationStatement* s,
414     saml::SAMLResponse* r=NULL,
415     const shibboleth::IRoleDescriptor* source=NULL,
416     time_t created=0,
417     time_t accessed=0
418     );
419   virtual void remove(const char* key);
420
421   virtual void cleanup();
422
423   bool m_storeAttributes;
424
425 private:
426   ISessionCache* m_cache;
427   CondWait* shutdown_wait;
428   bool shutdown;
429   Thread* cleanup_thread;
430
431   static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
432 };
433
434 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
435 {
436 #ifdef _DEBUG
437   saml::NDC ndc("ShibMySQLCCache");
438 #endif
439
440   log = &(Category::getInstance("shibmysql.SessionCache"));
441
442   shutdown_wait = CondWait::create();
443   shutdown = false;
444
445   m_cache = dynamic_cast<ISessionCache*>(
446       SAMLConfig::getConfig().getPlugMgr().newPlugin(
447         "edu.internet2.middleware.shibboleth.sp.provider.MemorySessionCacheProvider", e
448         )
449     );
450     
451   // Load our configuration details...
452   const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
453   if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
454     m_storeAttributes=true;
455
456   // Initialize the cleanup thread
457   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
458 }
459
460 ShibMySQLCCache::~ShibMySQLCCache()
461 {
462   shutdown = true;
463   shutdown_wait->signal();
464   cleanup_thread->join(NULL);
465
466   delete m_cache;
467 }
468
469 ISessionCacheEntry* ShibMySQLCCache::find(const char* key, const IApplication* application)
470 {
471 #ifdef _DEBUG
472   saml::NDC ndc("find");
473 #endif
474
475   ISessionCacheEntry* res = m_cache->find(key, application);
476   if (!res) {
477
478     log->debug("Looking in database...");
479
480     // nothing cached; see if this exists in the database
481     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,profile,provider,statement,response FROM state WHERE cookie='") + key + "' LIMIT 1";
482
483     MYSQL* mysql = getMYSQL();
484     if (mysql_query(mysql, q.c_str())) {
485       const char* err=mysql_error(mysql);
486       log->error("Error searching for %s: %s", key, err);
487       if (isCorrupt(err) && repairTable(mysql,"state")) {
488         if (mysql_query(mysql, q.c_str()))
489           log->error("Error retrying search for %s: %s", key, mysql_error(mysql));
490       }
491     }
492
493     MYSQL_RES* rows = mysql_store_result(mysql);
494
495     // Nope, doesn't exist.
496     if (!rows)
497       return NULL;
498
499     // Make sure we got 1 and only 1 rows.
500     if (mysql_num_rows(rows) != 1) {
501       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
502       mysql_free_result(rows);
503       return NULL;
504     }
505
506     log->debug("Match found.  Parsing...");
507     
508     /* Columns in query:
509         0: application_id
510         1: ctime
511         2: atime
512         3: address
513         4: profile
514         5: provider
515         6: statement
516         7: response
517      */
518
519     // Pull apart the row and process the results
520     MYSQL_ROW row = mysql_fetch_row(rows);
521     if (strcmp(application->getId(),row[0])) {
522         log->crit("An application (%s) attempted to access another application's (%s) session!", application->getId(), row[0]);
523         mysql_free_result(rows);
524         return NULL;
525     }
526
527     Metadata m(application->getMetadataProviders());
528     const IEntityDescriptor* provider=m.lookup(row[5]);
529     if (!provider) {
530         log->crit("no metadata found for identity provider (%s) responsible for the session.", row[5]);
531         mysql_free_result(rows);
532         return NULL;
533     }
534
535     SAMLAuthenticationStatement* s=NULL;
536     SAMLResponse* r=NULL;
537     ShibProfile profile=static_cast<ShibProfile>(atoi(row[4]));
538     const IRoleDescriptor* role=NULL;
539     if (profile==SAML11_POST || profile==SAML11_ARTIFACT)
540         role=provider->getIDPSSODescriptor(saml::XML::SAML11_PROTOCOL_ENUM);
541     else if (profile==SAML10_POST || profile==SAML10_ARTIFACT)
542         role=provider->getIDPSSODescriptor(saml::XML::SAML10_PROTOCOL_ENUM);
543     if (!role) {
544         log->crit(
545             "no matching IdP role for profile (%s) found for identity provider (%s) responsible for the session.", row[4], row[5]
546             );
547         mysql_free_result(rows);
548         return NULL;
549     }
550
551     // Try to parse the SAML data
552     try {
553         istringstream istr(row[6]);
554         s = new SAMLAuthenticationStatement(istr);
555         if (row[7]) {
556             istringstream istr2(row[7]);
557             r = new SAMLResponse(istr2);
558         }
559     }
560     catch (SAMLException& e) {
561         log->error(string("caught SAML exception while loading objects from SQL record: ") + e.what());
562         delete s;
563         delete r;
564         mysql_free_result(rows);
565         return NULL;
566     }
567 #ifndef _DEBUG
568     catch (...) {
569         log->error("caught unknown exception while loading objects from SQL record");
570         delete s;
571         delete r;
572         mysql_free_result(rows);
573         return NULL;
574     }
575 #endif
576
577     // Insert it into the memory cache
578     m_cache->insert(
579         key,
580         application,
581         row[3],
582         profile,
583         row[5],
584         s,
585         r,
586         role,
587         atoi(row[1]),
588         atoi(row[2])
589         );
590
591     // Free the results, and then re-run the 'find' query
592     mysql_free_result(rows);
593     res = m_cache->find(key,application);
594     if (!res)
595       return NULL;
596   }
597
598   return new ShibMySQLCCacheEntry(key, res, this);
599 }
600
601 void ShibMySQLCCache::insert(
602     const char* key,
603     const IApplication* application,
604     const char* client_addr,
605     ShibProfile profile,
606     const char* providerId,
607     saml::SAMLAuthenticationStatement* s,
608     saml::SAMLResponse* r,
609     const shibboleth::IRoleDescriptor* source,
610     time_t created,
611     time_t accessed
612     )
613 {
614 #ifdef _DEBUG
615   saml::NDC ndc("insert");
616 #endif
617   
618   ostringstream q;
619   q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
620   if (created==0)
621     q << "NOW(),";
622   else
623     q << "FROM_UNIXTIME(" << created << "),";
624   if (accessed==0)
625     q << "NOW(),'";
626   else
627     q << "FROM_UNIXTIME(" << accessed << "),'";
628   q << client_addr << "'," << profile << ",'" << providerId << "',";
629   if (m_storeAttributes && r) {
630     auto_ptr_char id(r->getId());
631     q << "'" << id.get() << "','" << *r << "','";
632   }
633   else
634     q << "null,null,'";
635   q << *s << "')";
636
637   log->debug("Query: %s", q.str().c_str());
638
639   // then add it to the database
640   MYSQL* mysql = getMYSQL();
641   if (mysql_query(mysql, q.str().c_str())) {
642     const char* err=mysql_error(mysql);
643     log->error("Error inserting %s: %s", key, err);
644     if (isCorrupt(err) && repairTable(mysql,"state")) {
645         // Try again...
646         if (mysql_query(mysql, q.str().c_str())) {
647           log->error("Error inserting %s: %s", key, mysql_error(mysql));
648           throw SAMLException("ShibMySQLCCache::insert(): insertion failed");
649         }
650     }
651     else
652         throw SAMLException("ShibMySQLCCache::insert(): insertion failed");
653   }
654
655   // Add it to the memory cache
656   m_cache->insert(key, application, client_addr, profile, providerId, s, r, source, created, accessed);
657 }
658
659 void ShibMySQLCCache::remove(const char* key)
660 {
661 #ifdef _DEBUG
662   saml::NDC ndc("remove");
663 #endif
664
665   // Remove the cached version
666   m_cache->remove(key);
667
668   // Remove from the database
669   string q = string("DELETE FROM state WHERE cookie='") + key + "'";
670   MYSQL* mysql = getMYSQL();
671   if (mysql_query(mysql, q.c_str())) {
672     const char* err=mysql_error(mysql);
673     log->error("Error deleting entry %s: %s", key, err);
674     if (isCorrupt(err) && repairTable(mysql,"state")) {
675         // Try again...
676         if (mysql_query(mysql, q.c_str()))
677           log->error("Error deleting entry %s: %s", key, mysql_error(mysql));
678     }
679   }
680 }
681
682 void ShibMySQLCCache::cleanup()
683 {
684 #ifdef _DEBUG
685   saml::NDC ndc("cleanup");
686 #endif
687
688   Mutex* mutex = Mutex::create();
689   MySQLBase::thread_init();
690
691   int rerun_timer = 0;
692   int timeout_life = 0;
693
694   // Load our configuration details...
695   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
696   if (tag && *tag)
697     rerun_timer = XMLString::parseInt(tag);
698
699   // search for 'mysql-cache-timeout' and then the regular cache timeout
700   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
701   if (tag && *tag)
702     timeout_life = XMLString::parseInt(tag);
703   else {
704       tag=m_root->getAttributeNS(NULL,cacheTimeout);
705       if (tag && *tag)
706         timeout_life = XMLString::parseInt(tag);
707   }
708   
709   if (rerun_timer <= 0)
710     rerun_timer = 300;          // rerun every 5 minutes
711
712   if (timeout_life <= 0)
713     timeout_life = 28800;       // timeout after 8 hours
714
715   mutex->lock();
716
717   MYSQL* mysql = getMYSQL();
718
719   while (shutdown == false) {
720     shutdown_wait->timedwait(mutex, rerun_timer);
721
722     if (shutdown == true)
723       break;
724
725     // Find all the entries in the database that haven't been used
726     // recently In particular, find all entries that have not been
727     // accessed in 'timeout_life' seconds.
728     ostringstream q;
729     q << "SELECT cookie FROM state WHERE " <<
730       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
731
732     if (mysql_query(mysql, q.str().c_str())) {
733       const char* err=mysql_error(mysql);
734       log->error("Error searching for old items: %s", err);
735         if (isCorrupt(err) && repairTable(mysql,"state")) {
736           if (mysql_query(mysql, q.str().c_str()))
737             log->error("Error re-searching for old items: %s", mysql_error(mysql));
738         }
739     }
740
741     MYSQL_RES* rows = mysql_store_result(mysql);
742     if (!rows)
743       continue;
744
745     if (mysql_num_fields(rows) != 1) {
746       log->error("Wrong number of columns, 1 != %d", mysql_num_fields(rows));
747       mysql_free_result(rows);
748       continue;
749     }
750
751     // For each row, remove the entry from the database.
752     MYSQL_ROW row;
753     while ((row = mysql_fetch_row(rows)) != NULL)
754       remove(row[0]);
755
756     mysql_free_result(rows);
757   }
758
759   log->info("cleanup thread exiting...");
760
761   mutex->unlock();
762   delete mutex;
763   MySQLBase::thread_end();
764   Thread::exit(NULL);
765 }
766
767 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
768 {
769   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
770
771   // First, let's block all signals
772   Thread::mask_all_signals();
773
774   // Now run the cleanup process.
775   cache->cleanup();
776   return NULL;
777 }
778
779 /*************************************************************************
780  * The CCacheEntry here is mostly a wrapper around the "memory"
781  * cacheentry provided by shibboleth.  The only difference is that we
782  * intercept isSessionValid() so that we can "touch()" the
783  * database if the session is still valid and getResponse() so we can
784  * store the data if we need to.
785  */
786
787 bool ShibMySQLCCacheEntry::isValid(time_t lifetime, time_t timeout) const
788 {
789   bool res = m_cacheEntry->isValid(lifetime, timeout);
790   if (res == true)
791     res = touch();
792   return res;
793 }
794
795 bool ShibMySQLCCacheEntry::touch() const
796 {
797   string q=string("UPDATE state SET atime=NOW() WHERE cookie='") + m_key + "'";
798
799   MYSQL* mysql = m_cache->getMYSQL();
800   if (mysql_query(mysql, q.c_str())) {
801     m_cache->log->info("Error updating timestamp on %s: %s", m_key.c_str(), mysql_error(mysql));
802     return false;
803   }
804   return true;
805 }
806
807 ISessionCacheEntry::CachedResponse ShibMySQLCCacheEntry::getResponse()
808 {
809     // Let the memory cache do the work first.
810     // If we're hands off, just pass it back.
811     if (!m_cache->m_storeAttributes)
812         return m_cacheEntry->getResponse();
813     
814     CachedResponse r=m_cacheEntry->getResponse();
815     if (r.empty()) return r;
816     
817     // Load the key from state if needed.
818     if (!m_responseId) {
819         string qselect=string("SELECT response_id from state WHERE cookie='") + m_key + "' LIMIT 1";
820         MYSQL* mysql = m_cache->getMYSQL();
821         if (mysql_query(mysql, qselect.c_str())) {
822             const char* err=mysql_error(mysql);
823             m_cache->log->error("error accessing response ID for %s: %s", m_key.c_str(), err);
824             if (isCorrupt(err) && m_cache->repairTable(mysql,"state")) {
825                 // Try again...
826                 if (mysql_query(mysql, qselect.c_str())) {
827                     m_cache->log->error("error accessing response ID for %s: %s", m_key.c_str(), mysql_error(mysql));
828                     return r;
829                 }
830             }
831         }
832         MYSQL_RES* rows = mysql_store_result(mysql);
833     
834         // Make sure we got 1 and only 1 row.
835         if (!rows || mysql_num_rows(rows) != 1) {
836             m_cache->log->error("select returned wrong number of rows");
837             if (rows) mysql_free_result(rows);
838             return r;
839         }
840         
841         MYSQL_ROW row=mysql_fetch_row(rows);
842         if (row)
843             m_responseId=XMLString::transcode(row[0]);
844         mysql_free_result(rows);
845     }
846     
847     // Compare it with what we have now.
848     if (m_responseId && !XMLString::compareString(m_responseId,r.unfiltered->getId()))
849         return r;
850     
851     // No match, so we need to update our copy.
852     if (m_responseId) XMLString::release(&m_responseId);
853     m_responseId = XMLString::replicate(r.unfiltered->getId());
854     auto_ptr_char id(m_responseId);
855
856     ostringstream q;
857     q << "UPDATE state SET response_id='" << id.get() << "',response='" << *r.unfiltered << "' WHERE cookie='" << m_key << "'";
858     m_cache->log->debug("Query: %s", q.str().c_str());
859
860     MYSQL* mysql = m_cache->getMYSQL();
861     if (mysql_query(mysql, q.str().c_str())) {
862         const char* err=mysql_error(mysql);
863         m_cache->log->error("Error updating response for %s: %s", m_key.c_str(), err);
864         if (isCorrupt(err) && m_cache->repairTable(mysql,"state")) {
865             // Try again...
866             if (mysql_query(mysql, q.str().c_str()))
867               m_cache->log->error("Error updating response for %s: %s", m_key.c_str(), mysql_error(mysql));
868         }
869     }
870     
871     return r;
872 }
873
874 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
875 {
876 public:
877   MySQLReplayCache(const DOMElement* e);
878   virtual ~MySQLReplayCache() {}
879
880   void thread_init() {MySQLBase::thread_init();}
881   void thread_end() {MySQLBase::thread_end();}
882
883   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
884   bool check(const char* str, time_t expires);
885 };
886
887 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
888 {
889 #ifdef _DEBUG
890   saml::NDC ndc("MySQLReplayCache");
891 #endif
892
893   log = &(Category::getInstance("shibmysql.ReplayCache"));
894 }
895
896 bool MySQLReplayCache::check(const char* str, time_t expires)
897 {
898 #ifdef _DEBUG
899     saml::NDC ndc("check");
900 #endif
901   
902     // Remove expired entries
903     string q = string("DELETE FROM replay WHERE expires < NOW()");
904     MYSQL* mysql = getMYSQL();
905     if (mysql_query(mysql, q.c_str())) {
906         const char* err=mysql_error(mysql);
907         log->error("Error deleting expired entries: %s", err);
908         if (isCorrupt(err) && repairTable(mysql,"replay")) {
909             // Try again...
910             if (mysql_query(mysql, q.c_str()))
911                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
912         }
913     }
914   
915     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
916     if (mysql_query(mysql, q2.c_str())) {
917         const char* err=mysql_error(mysql);
918         log->error("Error searching for %s: %s", str, err);
919         if (isCorrupt(err) && repairTable(mysql,"replay")) {
920             if (mysql_query(mysql, q2.c_str())) {
921                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
922                 throw SAMLException("Replay cache failed, please inform application support staff.");
923             }
924         }
925         else
926             throw SAMLException("Replay cache failed, please inform application support staff.");
927     }
928
929     // Did we find it?
930     MYSQL_RES* rows = mysql_store_result(mysql);
931     if (rows && mysql_num_rows(rows)>0) {
932       mysql_free_result(rows);
933       return false;
934     }
935
936     ostringstream q3;
937     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
938
939     // then add it to the database
940     if (mysql_query(mysql, q3.str().c_str())) {
941         const char* err=mysql_error(mysql);
942         log->error("Error inserting %s: %s", str, err);
943         if (isCorrupt(err) && repairTable(mysql,"state")) {
944             // Try again...
945             if (mysql_query(mysql, q3.str().c_str())) {
946                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
947                 throw SAMLException("Replay cache failed, please inform application support staff.");
948             }
949         }
950         else
951             throw SAMLException("Replay cache failed, please inform application support staff.");
952     }
953     
954     return true;
955 }
956
957 /*************************************************************************
958  * The registration functions here...
959  */
960
961 IPlugIn* new_mysql_ccache(const DOMElement* e)
962 {
963   return new ShibMySQLCCache(e);
964 }
965
966 IPlugIn* new_mysql_replay(const DOMElement* e)
967 {
968   return new MySQLReplayCache(e);
969 }
970
971 #define REPLAYPLUGINTYPE "edu.internet2.middleware.shibboleth.sp.provider.MySQLReplayCacheProvider"
972 #define SESSIONPLUGINTYPE "edu.internet2.middleware.shibboleth.sp.provider.MySQLSessionCacheProvider"
973
974 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
975 {
976   // register this ccache type
977   SAMLConfig::getConfig().getPlugMgr().regFactory(REPLAYPLUGINTYPE, &new_mysql_replay);
978   SAMLConfig::getConfig().getPlugMgr().regFactory(SESSIONPLUGINTYPE, &new_mysql_ccache);
979   return 0;
980 }
981
982 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
983 {
984   // Shutdown MySQL
985   mysql_server_end();
986   SAMLConfig::getConfig().getPlugMgr().unregFactory(REPLAYPLUGINTYPE);
987   SAMLConfig::getConfig().getPlugMgr().unregFactory(SESSIONPLUGINTYPE);
988 }