Expose SAML objects from ICacheEntry
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 /* This file is loosely based off the Shibboleth Credential Cache.
26  * This plug-in is designed as a two-layer cache.  Layer 1, the
27  * long-term cache, stores data in a MySQL embedded database.
28  *
29  * Short-term data is cached in memory as SAML objects in the layer 2
30  * cache.
31  */
32
33 // eventually we might be able to support autoconf via cygwin...
34 #if defined (_MSC_VER) || defined(__BORLANDC__)
35 # include "config_win32.h"
36 #else
37 # include "config.h"
38 #endif
39
40 #ifdef WIN32
41 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
42 #else
43 # define SHIBMYSQL_EXPORTS
44 #endif
45
46 #ifdef HAVE_UNISTD_H
47 # include <unistd.h>
48 #endif
49
50 #include <shib-target/shib-target.h>
51 #include <shib/shib-threads.h>
52 #include <log4cpp/Category.hh>
53
54 #include <sstream>
55 #include <stdexcept>
56
57 #include <mysql.h>
58
59 // wanted to use MySQL codes for this, but can't seem to get back a 145
60 #define isCorrupt(s) strstr(s,"(errno: 145)")
61
62 #ifdef HAVE_LIBDMALLOCXX
63 #include <dmalloc.h>
64 #endif
65
66 using namespace std;
67 using namespace saml;
68 using namespace shibboleth;
69 using namespace shibtarget;
70 using namespace log4cpp;
71
72 #define PLUGIN_VER_MAJOR 2
73 #define PLUGIN_VER_MINOR 0
74
75 #define STATE_TABLE \
76   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, " \
77   "application_id VARCHAR(255)," \
78   "ctime TIMESTAMP," \
79   "atime TIMESTAMP," \
80   "addr VARCHAR(128)," \
81   "profile INT," \
82   "provider VARCHAR(256)," \
83   "response_id VARCHAR(128)," \
84   "response TEXT," \
85   "statement TEXT)"
86
87 #define REPLAY_TABLE \
88   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
89   "expires TIMESTAMP, " \
90   "INDEX (expires))"
91
92 static const XMLCh Argument[] =
93 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
94 static const XMLCh cleanupInterval[] =
95 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
96   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
97 };
98 static const XMLCh cacheTimeout[] =
99 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
100 static const XMLCh mysqlTimeout[] =
101 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
102 static const XMLCh storeAttributes[] =
103 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
104
105 class MySQLBase
106 {
107 public:
108   MySQLBase(const DOMElement* e);
109   virtual ~MySQLBase();
110
111   void thread_init();
112   void thread_end() {}
113
114   MYSQL* getMYSQL() const;
115   bool repairTable(MYSQL*&, const char* table);
116
117   log4cpp::Category* log;
118
119 protected:
120   ThreadKey* m_mysql;
121   const DOMElement* m_root; // can only use this during initialization
122
123   bool initialized;
124
125   void createDatabase(MYSQL*, int major, int minor);
126   void upgradeDatabase(MYSQL*);
127   void getVersion(MYSQL*, int* major_p, int* minor_p);
128 };
129
130 // Forward declarations
131 static void mysqlInit(const DOMElement* e, Category& log);
132
133 extern "C" void shib_mysql_destroy_handle(void* data)
134 {
135   MYSQL* mysql = (MYSQL*) data;
136   if (mysql) mysql_close(mysql);
137 }
138
139 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
140 {
141 #ifdef _DEBUG
142   saml::NDC ndc("MySQLBase");
143 #endif
144   log = &(Category::getInstance("shibmysql.MySQLBase"));
145
146   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
147
148   initialized = false;
149   mysqlInit(e,*log);
150   thread_init();
151   initialized = true;
152 }
153
154 MySQLBase::~MySQLBase()
155 {
156   thread_end();
157
158   delete m_mysql;
159 }
160
161 MYSQL* MySQLBase::getMYSQL() const
162 {
163   return (MYSQL*)m_mysql->getData();
164 }
165
166 void MySQLBase::thread_init()
167 {
168 #ifdef _DEBUG
169   saml::NDC ndc("thread_init");
170 #endif
171
172   // Connect to the database
173   MYSQL* mysql = mysql_init(NULL);
174   if (!mysql) {
175     log->error("mysql_init failed");
176     mysql_close(mysql);
177     throw SAMLException("MySQLBase::thread_init(): mysql_init() failed");
178   }
179
180   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
181     if (initialized) {
182       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
183       mysql_close(mysql);
184       throw SAMLException("MySQLBase::thread_init(): mysql_real_connect() failed");
185     } else {
186       log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
187
188       // This will throw an exception if it fails.
189       createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
190     }
191   }
192
193   int major = -1, minor = -1;
194   getVersion (mysql, &major, &minor);
195
196   // Make sure we've got the right version
197   if (major != PLUGIN_VER_MAJOR || minor != PLUGIN_VER_MINOR) {
198    
199     // If we're capable, try upgrading on the fly...
200     if (major == 0  || major == 1) {
201        upgradeDatabase(mysql);
202     }
203     else {
204         mysql_close(mysql);
205         log->crit("Unknown database version: %d.%d", major, minor);
206         throw SAMLException("MySQLBase::thread_init(): Unknown database version");
207     }
208   }
209
210   // We're all set.. Save off the handle for this thread.
211   m_mysql->setData(mysql);
212 }
213
214 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
215 {
216   string q = string("REPAIR TABLE ") + table;
217   if (mysql_query(mysql, q.c_str())) {
218     log->error("Error repairing table %s: %s", table, mysql_error(mysql));
219     return false;
220   }
221
222   // seems we have to recycle the connection to get the thread to keep working
223   // other threads seem to be ok, but we should monitor that
224   mysql_close(mysql);
225   m_mysql->setData(NULL);
226   thread_init();
227   mysql=getMYSQL();
228   return true;
229 }
230
231 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
232 {
233   log->info("Creating database.");
234
235   MYSQL* ms = NULL;
236   try {
237     ms = mysql_init(NULL);
238     if (!ms) {
239       log->crit("mysql_init failed");
240       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
241     }
242
243     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
244       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
245       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
246     }
247
248     if (mysql_query(ms, "CREATE DATABASE shar")) {
249       log->crit("cannot create shar database: %s", mysql_error(ms));
250       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
251     }
252
253     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
254       log->crit("cannot open SHAR database");
255       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
256     }
257
258     mysql_close(ms);
259     
260   }
261   catch (SAMLException&) {
262     if (ms)
263       mysql_close(ms);
264     mysql_close(mysql);
265     throw;
266   }
267
268   // Now create the tables if they don't exist
269   log->info("Creating database tables");
270
271   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
272     log->error ("Error creating version: %s", mysql_error(mysql));
273     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
274   }
275
276   if (mysql_query(mysql,STATE_TABLE)) {
277     log->error ("Error creating state table: %s", mysql_error(mysql));
278     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
279   }
280
281   if (mysql_query(mysql,REPLAY_TABLE)) {
282     log->error ("Error creating replay table: %s", mysql_error(mysql));
283     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
284   }
285
286   ostringstream q;
287   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
288   if (mysql_query(mysql, q.str().c_str())) {
289     log->error ("Error setting version: %s", mysql_error(mysql));
290     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
291   }
292 }
293
294 void MySQLBase::upgradeDatabase(MYSQL* mysql)
295 {
296     if (mysql_query(mysql, "DROP TABLE state")) {
297         log->error("Error dropping old session state table: %s", mysql_error(mysql));
298     }
299
300     if (mysql_query(mysql,STATE_TABLE)) {
301         log->error ("Error creating state table: %s", mysql_error(mysql));
302         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
303     }
304
305     if (mysql_query(mysql,REPLAY_TABLE)) {
306         log->error ("Error creating replay table: %s", mysql_error(mysql));
307         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
308     }
309
310     ostringstream q;
311     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
312     if (mysql_query(mysql, q.str().c_str())) {
313         log->error ("Error updating version: %s", mysql_error(mysql));
314         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
315     }
316 }
317
318 void MySQLBase::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
319 {
320   // grab the version number from the database
321   if (mysql_query(mysql, "SELECT * FROM version"))
322     log->error ("Error reading version: %s", mysql_error(mysql));
323
324   MYSQL_RES* rows = mysql_store_result(mysql);
325   if (rows) {
326     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
327       MYSQL_ROW row = mysql_fetch_row(rows);
328
329       int major = row[0] ? atoi(row[0]) : -1;
330       int minor = row[1] ? atoi(row[1]) : -1;
331       log->debug("opening database version %d.%d", major, minor);
332       
333       mysql_free_result (rows);
334
335       *major_p = major;
336       *minor_p = minor;
337       return;
338
339     } else {
340       // Wrong number of rows or wrong number of fields...
341
342       log->crit("Houston, we've got a problem with the database...");
343       mysql_free_result (rows);
344       throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
345     }
346   }
347   log->crit("MySQL Read Failed in version verificatoin");
348   throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
349 }
350
351 static void mysqlInit(const DOMElement* e, Category& log)
352 {
353   static bool done = false;
354   if (done) {
355     log.info("MySQL embedded server already initialized");
356     return;
357   }
358   log.info("initializing MySQL embedded server");
359
360   // Setup the argument array
361   vector<string> arg_array;
362   arg_array.push_back("shibboleth");
363
364   // grab any MySQL parameters from the config file
365   e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
366   while (e) {
367       auto_ptr_char arg(e->getFirstChild()->getNodeValue());
368       if (arg.get())
369           arg_array.push_back(arg.get());
370       e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
371   }
372
373   // Compute the argument array
374   vector<string>::size_type arg_count = arg_array.size();
375   const char** args=new const char*[arg_count];
376   for (vector<string>::size_type i = 0; i < arg_count; i++)
377     args[i] = arg_array[i].c_str();
378
379   // Initialize MySQL with the arguments
380   mysql_server_init(arg_count, (char **)args, NULL);
381
382   delete[] args;
383   done = true;
384 }  
385
386 class ShibMySQLCCache;
387 class ShibMySQLCCacheEntry : public ISessionCacheEntry
388 {
389 public:
390   ShibMySQLCCacheEntry(const char* key, ISessionCacheEntry* entry, ShibMySQLCCache* cache)
391     : m_cacheEntry(entry), m_key(key), m_cache(cache), m_responseId(NULL) {}
392   ~ShibMySQLCCacheEntry() {if (m_responseId) XMLString::release(&m_responseId);}
393
394   virtual void lock() {}
395   virtual void unlock() { m_cacheEntry->unlock(); delete this; }
396   virtual bool isValid(time_t lifetime, time_t timeout) const;
397   virtual const char* getClientAddress() const { return m_cacheEntry->getClientAddress(); }
398   virtual ShibProfile getProfile() const { return m_cacheEntry->getProfile(); }
399   virtual const char* getProviderId() const { return m_cacheEntry->getProviderId(); }
400   virtual const char* getAuthnStatementXML() const { return m_cacheEntry->getAuthnStatementXML(); }
401   virtual const SAMLAuthenticationStatement* getAuthnStatementSAML() const { return m_cacheEntry->getAuthnStatementSAML(); }
402   virtual CachedResponseXML getResponseXML();
403   virtual CachedResponseSAML getResponseSAML() { return m_cacheEntry->getResponseSAML(); }
404
405 private:
406   bool touch() const;
407
408   ShibMySQLCCache* m_cache;
409   ISessionCacheEntry* m_cacheEntry;
410   string m_key;
411   char* m_responseId;
412 };
413
414 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache
415 {
416 public:
417   ShibMySQLCCache(const DOMElement* e);
418   virtual ~ShibMySQLCCache();
419
420   virtual void thread_init() {MySQLBase::thread_init();}
421   virtual void thread_end() {MySQLBase::thread_end();}
422
423   virtual string generateKey() const {return m_cache->generateKey();}
424   virtual ISessionCacheEntry* find(const char* key, const IApplication* application);
425   virtual void insert(
426     const char* key,
427     const IApplication* application,
428     const char* client_addr,
429     ShibProfile profile,
430     const char* providerId,
431     const saml::SAMLAuthenticationStatement* s,
432     saml::SAMLResponse* r=NULL,
433     const shibboleth::IRoleDescriptor* source=NULL,
434     time_t created=0,
435     time_t accessed=0
436     );
437   virtual void remove(const char* key);
438
439   virtual void cleanup();
440
441   bool m_storeAttributes;
442
443 private:
444   ISessionCache* m_cache;
445   CondWait* shutdown_wait;
446   bool shutdown;
447   Thread* cleanup_thread;
448
449   static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
450 };
451
452 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
453 {
454 #ifdef _DEBUG
455   saml::NDC ndc("ShibMySQLCCache");
456 #endif
457
458   log = &(Category::getInstance("shibmysql.SessionCache"));
459
460   shutdown_wait = CondWait::create();
461   shutdown = false;
462
463   m_cache = dynamic_cast<ISessionCache*>(
464       SAMLConfig::getConfig().getPlugMgr().newPlugin(
465         "edu.internet2.middleware.shibboleth.sp.provider.MemorySessionCacheProvider", e
466         )
467     );
468     
469   // Load our configuration details...
470   const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
471   if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
472     m_storeAttributes=true;
473
474   // Initialize the cleanup thread
475   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
476 }
477
478 ShibMySQLCCache::~ShibMySQLCCache()
479 {
480   shutdown = true;
481   shutdown_wait->signal();
482   cleanup_thread->join(NULL);
483
484   delete m_cache;
485 }
486
487 ISessionCacheEntry* ShibMySQLCCache::find(const char* key, const IApplication* application)
488 {
489 #ifdef _DEBUG
490   saml::NDC ndc("find");
491 #endif
492
493   ISessionCacheEntry* res = m_cache->find(key, application);
494   if (!res) {
495
496     log->debug("Looking in database...");
497
498     // nothing cached; see if this exists in the database
499     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,profile,provider,statement,response FROM state WHERE cookie='") + key + "' LIMIT 1";
500
501     MYSQL* mysql = getMYSQL();
502     if (mysql_query(mysql, q.c_str())) {
503       const char* err=mysql_error(mysql);
504       log->error("Error searching for %s: %s", key, err);
505       if (isCorrupt(err) && repairTable(mysql,"state")) {
506         if (mysql_query(mysql, q.c_str()))
507           log->error("Error retrying search for %s: %s", key, mysql_error(mysql));
508       }
509     }
510
511     MYSQL_RES* rows = mysql_store_result(mysql);
512
513     // Nope, doesn't exist.
514     if (!rows)
515       return NULL;
516
517     // Make sure we got 1 and only 1 rows.
518     if (mysql_num_rows(rows) != 1) {
519       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
520       mysql_free_result(rows);
521       return NULL;
522     }
523
524     log->debug("Match found.  Parsing...");
525     
526     /* Columns in query:
527         0: application_id
528         1: ctime
529         2: atime
530         3: address
531         4: profile
532         5: provider
533         6: statement
534         7: response
535      */
536
537     // Pull apart the row and process the results
538     MYSQL_ROW row = mysql_fetch_row(rows);
539     if (strcmp(application->getId(),row[0])) {
540         log->crit("An application (%s) attempted to access another application's (%s) session!", application->getId(), row[0]);
541         mysql_free_result(rows);
542         return NULL;
543     }
544
545     Metadata m(application->getMetadataProviders());
546     const IEntityDescriptor* provider=m.lookup(row[5]);
547     if (!provider) {
548         log->crit("no metadata found for identity provider (%s) responsible for the session.", row[5]);
549         mysql_free_result(rows);
550         return NULL;
551     }
552
553     ShibProfile profile=static_cast<ShibProfile>(atoi(row[4]));
554     const IRoleDescriptor* role=NULL;
555     if (profile==SAML11_POST || profile==SAML11_ARTIFACT)
556         role=provider->getIDPSSODescriptor(saml::XML::SAML11_PROTOCOL_ENUM);
557     else if (profile==SAML10_POST || profile==SAML10_ARTIFACT)
558         role=provider->getIDPSSODescriptor(saml::XML::SAML10_PROTOCOL_ENUM);
559     if (!role) {
560         log->crit(
561             "no matching IdP role for profile (%s) found for identity provider (%s) responsible for the session.", row[4], row[5]
562             );
563         mysql_free_result(rows);
564         return NULL;
565     }
566
567     // Try to parse the SAML data
568     try {
569         istringstream istr(row[6]);
570         auto_ptr<SAMLAuthenticationStatement> s(new SAMLAuthenticationStatement(istr));
571
572         SAMLResponse* r=NULL;
573         if (row[7]) {
574             istringstream istr2(row[7]);
575             r = new SAMLResponse(istr2);
576         }
577         auto_ptr<SAMLResponse> rwrap(r);
578
579         // Insert it into the memory cache
580         m_cache->insert(
581             key,
582             application,
583             row[3],
584             profile,
585             row[5],
586             s.get(),
587             r,
588             role,
589             atoi(row[1]),
590             atoi(row[2])
591             );
592     }
593     catch (SAMLException& e) {
594         log->error(string("caught SAML exception while loading objects from SQL record: ") + e.what());
595         mysql_free_result(rows);
596         return NULL;
597     }
598 #ifndef _DEBUG
599     catch (...) {
600         log->error("caught unknown exception while loading objects from SQL record");
601         mysql_free_result(rows);
602         return NULL;
603     }
604 #endif
605
606     // Free the results, and then re-run the 'find' query
607     mysql_free_result(rows);
608     res = m_cache->find(key,application);
609     if (!res)
610       return NULL;
611   }
612
613   return new ShibMySQLCCacheEntry(key, res, this);
614 }
615
616 void ShibMySQLCCache::insert(
617     const char* key,
618     const IApplication* application,
619     const char* client_addr,
620     ShibProfile profile,
621     const char* providerId,
622     const saml::SAMLAuthenticationStatement* s,
623     saml::SAMLResponse* r,
624     const shibboleth::IRoleDescriptor* source,
625     time_t created,
626     time_t accessed
627     )
628 {
629 #ifdef _DEBUG
630   saml::NDC ndc("insert");
631 #endif
632   
633   ostringstream q;
634   q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
635   if (created==0)
636     q << "NOW(),";
637   else
638     q << "FROM_UNIXTIME(" << created << "),";
639   if (accessed==0)
640     q << "NOW(),'";
641   else
642     q << "FROM_UNIXTIME(" << accessed << "),'";
643   q << client_addr << "'," << profile << ",'" << providerId << "',";
644   if (m_storeAttributes && r) {
645     auto_ptr_char id(r->getId());
646     q << "'" << id.get() << "','" << *r << "','";
647   }
648   else
649     q << "null,null,'";
650   q << *s << "')";
651
652   log->debug("Query: %s", q.str().c_str());
653
654   // then add it to the database
655   MYSQL* mysql = getMYSQL();
656   if (mysql_query(mysql, q.str().c_str())) {
657     const char* err=mysql_error(mysql);
658     log->error("Error inserting %s: %s", key, err);
659     if (isCorrupt(err) && repairTable(mysql,"state")) {
660         // Try again...
661         if (mysql_query(mysql, q.str().c_str())) {
662           log->error("Error inserting %s: %s", key, mysql_error(mysql));
663           throw SAMLException("ShibMySQLCCache::insert(): insertion failed");
664         }
665     }
666     else
667         throw SAMLException("ShibMySQLCCache::insert(): insertion failed");
668   }
669
670   // Add it to the memory cache
671   m_cache->insert(key, application, client_addr, profile, providerId, s, r, source, created, accessed);
672 }
673
674 void ShibMySQLCCache::remove(const char* key)
675 {
676 #ifdef _DEBUG
677   saml::NDC ndc("remove");
678 #endif
679
680   // Remove the cached version
681   m_cache->remove(key);
682
683   // Remove from the database
684   string q = string("DELETE FROM state WHERE cookie='") + key + "'";
685   MYSQL* mysql = getMYSQL();
686   if (mysql_query(mysql, q.c_str())) {
687     const char* err=mysql_error(mysql);
688     log->error("Error deleting entry %s: %s", key, err);
689     if (isCorrupt(err) && repairTable(mysql,"state")) {
690         // Try again...
691         if (mysql_query(mysql, q.c_str()))
692           log->error("Error deleting entry %s: %s", key, mysql_error(mysql));
693     }
694   }
695 }
696
697 void ShibMySQLCCache::cleanup()
698 {
699 #ifdef _DEBUG
700   saml::NDC ndc("cleanup");
701 #endif
702
703   Mutex* mutex = Mutex::create();
704   MySQLBase::thread_init();
705
706   int rerun_timer = 0;
707   int timeout_life = 0;
708
709   // Load our configuration details...
710   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
711   if (tag && *tag)
712     rerun_timer = XMLString::parseInt(tag);
713
714   // search for 'mysql-cache-timeout' and then the regular cache timeout
715   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
716   if (tag && *tag)
717     timeout_life = XMLString::parseInt(tag);
718   else {
719       tag=m_root->getAttributeNS(NULL,cacheTimeout);
720       if (tag && *tag)
721         timeout_life = XMLString::parseInt(tag);
722   }
723   
724   if (rerun_timer <= 0)
725     rerun_timer = 300;          // rerun every 5 minutes
726
727   if (timeout_life <= 0)
728     timeout_life = 28800;       // timeout after 8 hours
729
730   mutex->lock();
731
732   MYSQL* mysql = getMYSQL();
733
734   while (shutdown == false) {
735     shutdown_wait->timedwait(mutex, rerun_timer);
736
737     if (shutdown == true)
738       break;
739
740     // Find all the entries in the database that haven't been used
741     // recently In particular, find all entries that have not been
742     // accessed in 'timeout_life' seconds.
743     ostringstream q;
744     q << "SELECT cookie FROM state WHERE " <<
745       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
746
747     if (mysql_query(mysql, q.str().c_str())) {
748       const char* err=mysql_error(mysql);
749       log->error("Error searching for old items: %s", err);
750         if (isCorrupt(err) && repairTable(mysql,"state")) {
751           if (mysql_query(mysql, q.str().c_str()))
752             log->error("Error re-searching for old items: %s", mysql_error(mysql));
753         }
754     }
755
756     MYSQL_RES* rows = mysql_store_result(mysql);
757     if (!rows)
758       continue;
759
760     if (mysql_num_fields(rows) != 1) {
761       log->error("Wrong number of columns, 1 != %d", mysql_num_fields(rows));
762       mysql_free_result(rows);
763       continue;
764     }
765
766     // For each row, remove the entry from the database.
767     MYSQL_ROW row;
768     while ((row = mysql_fetch_row(rows)) != NULL)
769       remove(row[0]);
770
771     mysql_free_result(rows);
772   }
773
774   log->info("cleanup thread exiting...");
775
776   mutex->unlock();
777   delete mutex;
778   MySQLBase::thread_end();
779   Thread::exit(NULL);
780 }
781
782 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
783 {
784   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
785
786   // First, let's block all signals
787   Thread::mask_all_signals();
788
789   // Now run the cleanup process.
790   cache->cleanup();
791   return NULL;
792 }
793
794 /*************************************************************************
795  * The CCacheEntry here is mostly a wrapper around the "memory"
796  * cacheentry provided by shibboleth.  The only difference is that we
797  * intercept isSessionValid() so that we can "touch()" the
798  * database if the session is still valid and getResponse() so we can
799  * store the data if we need to.
800  */
801
802 bool ShibMySQLCCacheEntry::isValid(time_t lifetime, time_t timeout) const
803 {
804   bool res = m_cacheEntry->isValid(lifetime, timeout);
805   if (res == true)
806     res = touch();
807   return res;
808 }
809
810 bool ShibMySQLCCacheEntry::touch() const
811 {
812   string q=string("UPDATE state SET atime=NOW() WHERE cookie='") + m_key + "'";
813
814   MYSQL* mysql = m_cache->getMYSQL();
815   if (mysql_query(mysql, q.c_str())) {
816     m_cache->log->info("Error updating timestamp on %s: %s", m_key.c_str(), mysql_error(mysql));
817     return false;
818   }
819   return true;
820 }
821
822 ISessionCacheEntry::CachedResponseXML ShibMySQLCCacheEntry::getResponseXML()
823 {
824     // Let the memory cache do the work first.
825     // If we're hands off, just pass it back.
826     if (!m_cache->m_storeAttributes)
827         return m_cacheEntry->getResponseXML();
828     
829     CachedResponseXML r=m_cacheEntry->getResponseXML();
830     if (!r.unfiltered || !*r.unfiltered) return r;
831     
832     // Load the key from state if needed.
833     if (!m_responseId) {
834         string qselect=string("SELECT response_id from state WHERE cookie='") + m_key + "' LIMIT 1";
835         MYSQL* mysql = m_cache->getMYSQL();
836         if (mysql_query(mysql, qselect.c_str())) {
837             const char* err=mysql_error(mysql);
838             m_cache->log->error("error accessing response ID for %s: %s", m_key.c_str(), err);
839             if (isCorrupt(err) && m_cache->repairTable(mysql,"state")) {
840                 // Try again...
841                 if (mysql_query(mysql, qselect.c_str())) {
842                     m_cache->log->error("error accessing response ID for %s: %s", m_key.c_str(), mysql_error(mysql));
843                     return r;
844                 }
845             }
846         }
847         MYSQL_RES* rows = mysql_store_result(mysql);
848     
849         // Make sure we got 1 and only 1 row.
850         if (!rows || mysql_num_rows(rows) != 1) {
851             m_cache->log->error("select returned wrong number of rows");
852             if (rows) mysql_free_result(rows);
853             return r;
854         }
855         
856         MYSQL_ROW row=mysql_fetch_row(rows);
857         if (row)
858             m_responseId=strdup(row[0]);
859         mysql_free_result(rows);
860     }
861     
862     // Compare it with what we have now.
863     char* start=NULL;
864     if (m_responseId) {
865         start=strstr(r.unfiltered,"ResponseID=\"");
866         if (start && !strncmp(start + 12,m_responseId,strlen(m_responseId)))
867             return r;
868         start+=12;
869         free(m_responseId);
870     }
871     
872     // No match, so we need to update our copy.
873     char* end=strchr(start,'"');
874     m_responseId = (char*)malloc(sizeof(char)*(end-start+1));
875     memset(m_responseId,0,sizeof(char)*(end-start+1));
876     strncpy(m_responseId,start,end-start);
877
878     ostringstream q;
879     q << "UPDATE state SET response_id='" << m_responseId << "',response='" << r.unfiltered << "' WHERE cookie='" << m_key << "'";
880     m_cache->log->debug("Query: %s", q.str().c_str());
881
882     MYSQL* mysql = m_cache->getMYSQL();
883     if (mysql_query(mysql, q.str().c_str())) {
884         const char* err=mysql_error(mysql);
885         m_cache->log->error("Error updating response for %s: %s", m_key.c_str(), err);
886         if (isCorrupt(err) && m_cache->repairTable(mysql,"state")) {
887             // Try again...
888             if (mysql_query(mysql, q.str().c_str()))
889               m_cache->log->error("Error updating response for %s: %s", m_key.c_str(), mysql_error(mysql));
890         }
891     }
892     
893     return r;
894 }
895
896 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
897 {
898 public:
899   MySQLReplayCache(const DOMElement* e);
900   virtual ~MySQLReplayCache() {}
901
902   void thread_init() {MySQLBase::thread_init();}
903   void thread_end() {MySQLBase::thread_end();}
904
905   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
906   bool check(const char* str, time_t expires);
907 };
908
909 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
910 {
911 #ifdef _DEBUG
912   saml::NDC ndc("MySQLReplayCache");
913 #endif
914
915   log = &(Category::getInstance("shibmysql.ReplayCache"));
916 }
917
918 bool MySQLReplayCache::check(const char* str, time_t expires)
919 {
920 #ifdef _DEBUG
921     saml::NDC ndc("check");
922 #endif
923   
924     // Remove expired entries
925     string q = string("DELETE FROM replay WHERE expires < NOW()");
926     MYSQL* mysql = getMYSQL();
927     if (mysql_query(mysql, q.c_str())) {
928         const char* err=mysql_error(mysql);
929         log->error("Error deleting expired entries: %s", err);
930         if (isCorrupt(err) && repairTable(mysql,"replay")) {
931             // Try again...
932             if (mysql_query(mysql, q.c_str()))
933                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
934         }
935     }
936   
937     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
938     if (mysql_query(mysql, q2.c_str())) {
939         const char* err=mysql_error(mysql);
940         log->error("Error searching for %s: %s", str, err);
941         if (isCorrupt(err) && repairTable(mysql,"replay")) {
942             if (mysql_query(mysql, q2.c_str())) {
943                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
944                 throw SAMLException("Replay cache failed, please inform application support staff.");
945             }
946         }
947         else
948             throw SAMLException("Replay cache failed, please inform application support staff.");
949     }
950
951     // Did we find it?
952     MYSQL_RES* rows = mysql_store_result(mysql);
953     if (rows && mysql_num_rows(rows)>0) {
954       mysql_free_result(rows);
955       return false;
956     }
957
958     ostringstream q3;
959     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
960
961     // then add it to the database
962     if (mysql_query(mysql, q3.str().c_str())) {
963         const char* err=mysql_error(mysql);
964         log->error("Error inserting %s: %s", str, err);
965         if (isCorrupt(err) && repairTable(mysql,"state")) {
966             // Try again...
967             if (mysql_query(mysql, q3.str().c_str())) {
968                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
969                 throw SAMLException("Replay cache failed, please inform application support staff.");
970             }
971         }
972         else
973             throw SAMLException("Replay cache failed, please inform application support staff.");
974     }
975     
976     return true;
977 }
978
979 /*************************************************************************
980  * The registration functions here...
981  */
982
983 IPlugIn* new_mysql_ccache(const DOMElement* e)
984 {
985   return new ShibMySQLCCache(e);
986 }
987
988 IPlugIn* new_mysql_replay(const DOMElement* e)
989 {
990   return new MySQLReplayCache(e);
991 }
992
993 #define REPLAYPLUGINTYPE "edu.internet2.middleware.shibboleth.sp.provider.MySQLReplayCacheProvider"
994 #define SESSIONPLUGINTYPE "edu.internet2.middleware.shibboleth.sp.provider.MySQLSessionCacheProvider"
995
996 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
997 {
998   // register this ccache type
999   SAMLConfig::getConfig().getPlugMgr().regFactory(REPLAYPLUGINTYPE, &new_mysql_replay);
1000   SAMLConfig::getConfig().getPlugMgr().regFactory(SESSIONPLUGINTYPE, &new_mysql_ccache);
1001   return 0;
1002 }
1003
1004 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
1005 {
1006   // Shutdown MySQL
1007   mysql_server_end();
1008   SAMLConfig::getConfig().getPlugMgr().unregFactory(REPLAYPLUGINTYPE);
1009   SAMLConfig::getConfig().getPlugMgr().unregFactory(SESSIONPLUGINTYPE);
1010 }