Added methods to return tokens or last access.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib/shib-threads.h>
43 #include <shib-target/shib-target.h>
44 #include <log4cpp/Category.hh>
45
46 #include <sstream>
47
48 #ifdef WIN32
49 # include <winsock.h>
50 #endif
51 #include <mysql.h>
52
53 // wanted to use MySQL codes for this, but can't seem to get back a 145
54 #define isCorrupt(s) strstr(s,"(errno: 145)")
55
56 #ifdef HAVE_LIBDMALLOCXX
57 #include <dmalloc.h>
58 #endif
59
60 using namespace std;
61 using namespace saml;
62 using namespace shibboleth;
63 using namespace shibtarget;
64 using namespace log4cpp;
65
66 #define PLUGIN_VER_MAJOR 3
67 #define PLUGIN_VER_MINOR 0
68
69 #define STATE_TABLE \
70   "CREATE TABLE state (" \
71   "cookie VARCHAR(64) PRIMARY KEY, " \
72   "application_id VARCHAR(255)," \
73   "ctime TIMESTAMP," \
74   "atime TIMESTAMP," \
75   "addr VARCHAR(128)," \
76   "major INT," \
77   "minor INT," \
78   "provider VARCHAR(256)," \
79   "subject TEXT," \
80   "authn_context TEXT," \
81   "tokens TEXT)"
82
83 #define REPLAY_TABLE \
84   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
85   "expires TIMESTAMP, " \
86   "INDEX (expires))"
87
88 static const XMLCh Argument[] =
89 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
90 static const XMLCh cleanupInterval[] =
91 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
92   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
93 };
94 static const XMLCh cacheTimeout[] =
95 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
96 static const XMLCh mysqlTimeout[] =
97 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
98 static const XMLCh storeAttributes[] =
99 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
100
101 class MySQLBase : public virtual saml::IPlugIn
102 {
103 public:
104   MySQLBase(const DOMElement* e);
105   virtual ~MySQLBase();
106
107   MYSQL* getMYSQL();
108   bool repairTable(MYSQL*&, const char* table);
109
110   log4cpp::Category* log;
111
112 protected:
113   ThreadKey* m_mysql;
114   const DOMElement* m_root; // can only use this during initialization
115
116   bool initialized;
117
118   void createDatabase(MYSQL*, int major, int minor);
119   void upgradeDatabase(MYSQL*);
120   pair<int,int> getVersion(MYSQL*);
121 };
122
123 // Forward declarations
124 static void mysqlInit(const DOMElement* e, Category& log);
125
126 extern "C" void shib_mysql_destroy_handle(void* data)
127 {
128   MYSQL* mysql = (MYSQL*) data;
129   if (mysql) mysql_close(mysql);
130 }
131
132 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
133 {
134 #ifdef _DEBUG
135   saml::NDC ndc("MySQLBase");
136 #endif
137   log = &(Category::getInstance("shibmysql.MySQLBase"));
138
139   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
140
141   initialized = false;
142   mysqlInit(e,*log);
143   getMYSQL();
144   initialized = true;
145 }
146
147 MySQLBase::~MySQLBase()
148 {
149   delete m_mysql;
150 }
151
152 MYSQL* MySQLBase::getMYSQL()
153 {
154 #ifdef _DEBUG
155     saml::NDC ndc("getMYSQL");
156 #endif
157
158     // Do we already have a handle?
159     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
160     if (mysql)
161         return mysql;
162
163     // Connect to the database
164     mysql = mysql_init(NULL);
165     if (!mysql) {
166         log->error("mysql_init failed");
167         mysql_close(mysql);
168         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
169     }
170
171     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
172         if (initialized) {
173             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
174             mysql_close(mysql);
175             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
176         }
177         else {
178             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
179
180             // This will throw an exception if it fails.
181             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
182         }
183     }
184
185     pair<int,int> v=getVersion (mysql);
186
187     // Make sure we've got the right version
188     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
189    
190         // If we're capable, try upgrading on the fly...
191         if (v.first == 0  || v.first == 1 || v.first == 2) {
192             if (mysql_query(mysql, "DROP TABLE state")) {
193                 log->error("error dropping old session state table: %s", mysql_error(mysql));
194             }
195             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
196                 log->error("error dropping old session state table: %s", mysql_error(mysql));
197             }
198             upgradeDatabase(mysql);
199         }
200         else {
201             mysql_close(mysql);
202             log->crit("Unknown database version: %d.%d", v.first, v.second);
203             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
204         }
205     }
206
207     // We're all set.. Save off the handle for this thread.
208     m_mysql->setData(mysql);
209     return mysql;
210 }
211
212 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
213 {
214     string q = string("REPAIR TABLE ") + table;
215     if (mysql_query(mysql, q.c_str())) {
216         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
217         return false;
218     }
219
220     // seems we have to recycle the connection to get the thread to keep working
221     // other threads seem to be ok, but we should monitor that
222     mysql_close(mysql);
223     m_mysql->setData(NULL);
224     mysql=getMYSQL();
225     return true;
226 }
227
228 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
229 {
230   log->info("creating database");
231
232   MYSQL* ms = NULL;
233   try {
234     ms = mysql_init(NULL);
235     if (!ms) {
236       log->crit("mysql_init failed");
237       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
238     }
239
240     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
241       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
242       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
243     }
244
245     if (mysql_query(ms, "CREATE DATABASE shibd")) {
246       log->crit("cannot create shibd database: %s", mysql_error(ms));
247       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
248     }
249
250     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
251       log->crit("cannot open shibd database");
252       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
253     }
254
255     mysql_close(ms);
256     
257   }
258   catch (SAMLException&) {
259     if (ms)
260       mysql_close(ms);
261     mysql_close(mysql);
262     throw;
263   }
264
265   // Now create the tables if they don't exist
266   log->info("Creating database tables");
267
268   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
269     log->error ("error creating version: %s", mysql_error(mysql));
270     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
271   }
272
273   if (mysql_query(mysql,STATE_TABLE)) {
274     log->error ("error creating state table: %s", mysql_error(mysql));
275     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
276   }
277
278   if (mysql_query(mysql,REPLAY_TABLE)) {
279     log->error ("error creating replay table: %s", mysql_error(mysql));
280     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
281   }
282
283   ostringstream q;
284   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
285   if (mysql_query(mysql, q.str().c_str())) {
286     log->error ("error setting version: %s", mysql_error(mysql));
287     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
288   }
289 }
290
291 void MySQLBase::upgradeDatabase(MYSQL* mysql)
292 {
293     if (mysql_query(mysql,STATE_TABLE)) {
294         log->error ("error creating state table: %s", mysql_error(mysql));
295         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
296     }
297
298     if (mysql_query(mysql,REPLAY_TABLE)) {
299         log->error ("error creating replay table: %s", mysql_error(mysql));
300         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
301     }
302
303     ostringstream q;
304     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
305     if (mysql_query(mysql, q.str().c_str())) {
306         log->error ("error updating version: %s", mysql_error(mysql));
307         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
308     }
309 }
310
311 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
312 {
313     // grab the version number from the database
314     if (mysql_query(mysql, "SELECT * FROM version"))
315         log->error ("Error reading version: %s", mysql_error(mysql));
316
317     MYSQL_RES* rows = mysql_store_result(mysql);
318     if (rows) {
319         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
320           MYSQL_ROW row = mysql_fetch_row(rows);
321           int major = row[0] ? atoi(row[0]) : -1;
322           int minor = row[1] ? atoi(row[1]) : -1;
323           log->debug("opening database version %d.%d", major, minor);
324           mysql_free_result(rows);
325           return make_pair(major,minor);
326         } else {
327             // Wrong number of rows or wrong number of fields...
328             log->crit("Houston, we've got a problem with the database...");
329             mysql_free_result (rows);
330             throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
331         }
332     }
333     log->crit("MySQL Read Failed in version verification");
334     throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
335 }
336
337 static void mysqlInit(const DOMElement* e, Category& log)
338 {
339     static bool done = false;
340     if (done) {
341         log.info("MySQL embedded server already initialized");
342         return;
343     }
344     log.info("initializing MySQL embedded server");
345
346     // Setup the argument array
347     vector<string> arg_array;
348     arg_array.push_back("shibboleth");
349
350     // grab any MySQL parameters from the config file
351     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
352     while (e) {
353         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
354         if (arg.get())
355             arg_array.push_back(arg.get());
356         e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
357     }
358
359     // Compute the argument array
360     vector<string>::size_type arg_count = arg_array.size();
361     const char** args=new const char*[arg_count];
362     for (vector<string>::size_type i = 0; i < arg_count; i++)
363         args[i] = arg_array[i].c_str();
364
365     // Initialize MySQL with the arguments
366     mysql_server_init(arg_count, (char **)args, NULL);
367
368     delete[] args;
369     done = true;
370 }  
371
372 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
373 {
374 public:
375     ShibMySQLCCache(const DOMElement* e);
376     virtual ~ShibMySQLCCache();
377
378     // Delegate all the ISessionCache methods.
379     string insert(
380         const IApplication* application,
381         const IEntityDescriptor* source,
382         const char* client_addr,
383         const SAMLSubject* subject,
384         const char* authnContext,
385         const SAMLResponse* tokens
386         )
387     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
388     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
389     { return m_cache->find(key,application,client_addr); }
390     void remove(const char* key, const IApplication* application, const char* client_addr)
391     { m_cache->remove(key,application,client_addr); }
392
393     bool setBackingStore(ISessionCacheStore*) {return false;}
394
395     // Store methods handle the database work
396     HRESULT onCreate(
397         const char* key,
398         const IApplication* application,
399         const ISessionCacheEntry* entry,
400         int majorVersion,
401         int minorVersion,
402         time_t created
403         );
404     HRESULT onRead(
405         const char* key,
406         string& applicationId,
407         string& clientAddress,
408         string& providerId,
409         string& subject,
410         string& authnContext,
411         string& tokens,
412         int& majorVersion,
413         int& minorVersion,
414         time_t& created,
415         time_t& accessed
416         );
417     HRESULT onRead(const char* key, time_t& accessed);
418     HRESULT onRead(const char* key, string& tokens);
419     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
420     HRESULT onDelete(const char* key);
421
422     void cleanup();
423
424 private:
425     bool m_storeAttributes;
426     ISessionCache* m_cache;
427     CondWait* shutdown_wait;
428     bool shutdown;
429     Thread* cleanup_thread;
430
431     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
432 };
433
434 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
435 {
436 #ifdef _DEBUG
437     saml::NDC ndc("ShibMySQLCCache");
438 #endif
439     log = &(Category::getInstance("shibmysql.SessionCache"));
440
441     m_cache = dynamic_cast<ISessionCache*>(
442         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
443     );
444     if (!m_cache->setBackingStore(this)) {
445         delete m_cache;
446         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
447     }
448     
449     shutdown_wait = CondWait::create();
450     shutdown = false;
451
452     // Load our configuration details...
453     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
454     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
455         m_storeAttributes=true;
456
457     // Initialize the cleanup thread
458     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
459 }
460
461 ShibMySQLCCache::~ShibMySQLCCache()
462 {
463     shutdown = true;
464     shutdown_wait->signal();
465     cleanup_thread->join(NULL);
466     delete m_cache;
467 }
468
469 HRESULT ShibMySQLCCache::onCreate(
470     const char* key,
471     const IApplication* application,
472     const ISessionCacheEntry* entry,
473     int majorVersion,
474     int minorVersion,
475     time_t created
476     )
477 {
478 #ifdef _DEBUG
479     saml::NDC ndc("onCreate");
480 #endif
481
482     // Get XML data from entry. Default is not to return SAML objects.
483     const char* context=entry->getAuthnContext();
484     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
485     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
486
487     ostringstream q;
488     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
489     if (created==0)
490         q << "NOW(),NOW(),'";
491     else
492         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
493     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
494         << subject.first << "','" << context << "',";
495
496     if (m_storeAttributes && tokens.first)
497         q << "'" << tokens.first << "')";
498     else
499         q << "null)";
500
501     if (log->isDebugEnabled())
502         log->debug("SQL insert: %s", q.str().c_str());
503
504     // then add it to the database
505     MYSQL* mysql = getMYSQL();
506     if (mysql_query(mysql, q.str().c_str())) {
507         const char* err=mysql_error(mysql);
508         log->error("error inserting %s: %s", key, err);
509         if (isCorrupt(err) && repairTable(mysql,"state")) {
510             // Try again...
511             if (mysql_query(mysql, q.str().c_str())) {
512                 log->error("error inserting %s: %s", key, mysql_error(mysql));
513                 return E_FAIL;
514             }
515         }
516         else
517             throw E_FAIL;
518     }
519
520     return NOERROR;
521 }
522
523 HRESULT ShibMySQLCCache::onRead(
524     const char* key,
525     string& applicationId,
526     string& clientAddress,
527     string& providerId,
528     string& subject,
529     string& authnContext,
530     string& tokens,
531     int& majorVersion,
532     int& minorVersion,
533     time_t& created,
534     time_t& accessed
535     )
536 {
537 #ifdef _DEBUG
538     saml::NDC ndc("onRead");
539 #endif
540
541     log->debug("searching MySQL database...");
542
543     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
544
545     MYSQL* mysql = getMYSQL();
546     if (mysql_query(mysql, q.c_str())) {
547         const char* err=mysql_error(mysql);
548         log->error("error searching for %s: %s", key, err);
549         if (isCorrupt(err) && repairTable(mysql,"state")) {
550             if (mysql_query(mysql, q.c_str()))
551                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
552         }
553     }
554
555     MYSQL_RES* rows = mysql_store_result(mysql);
556
557     // Nope, doesn't exist.
558     if (!rows)
559         return S_FALSE;
560
561     // Make sure we got 1 and only 1 rows.
562     if (mysql_num_rows(rows) != 1) {
563         log->error("Database select returned wrong number of rows: %d", mysql_num_rows(rows));
564         mysql_free_result(rows);
565         return S_FALSE;
566     }
567
568     log->debug("match found, tranfering data back into memory");
569     
570     /* Columns in query:
571         0: application_id
572         1: ctime
573         2: atime
574         3: address
575         4: major
576         5: minor
577         6: provider
578         7: subject
579         8: authncontext
580         9: tokens
581      */
582
583     MYSQL_ROW row = mysql_fetch_row(rows);
584     applicationId=row[0];
585     created=atoi(row[1]);
586     accessed=atoi(row[2]);
587     clientAddress=row[3];
588     majorVersion=atoi(row[4]);
589     minorVersion=atoi(row[5]);
590     providerId=row[6];
591     subject=row[7];
592     authnContext=row[8];
593     if (row[9])
594         tokens=row[9];
595
596     // Free the results.
597     mysql_free_result(rows);
598
599     return NOERROR;
600 }
601
602 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
603 {
604 #ifdef _DEBUG
605     saml::NDC ndc("onRead");
606 #endif
607
608     log->debug("reading last access time from MySQL database");
609
610     string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
611
612     MYSQL* mysql = getMYSQL();
613     if (mysql_query(mysql, q.c_str())) {
614         const char* err=mysql_error(mysql);
615         log->error("error searching for %s: %s", key, err);
616         if (isCorrupt(err) && repairTable(mysql,"state")) {
617             if (mysql_query(mysql, q.c_str()))
618                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
619         }
620     }
621
622     MYSQL_RES* rows = mysql_store_result(mysql);
623
624     // Nope, doesn't exist.
625     if (!rows)
626         return S_FALSE;
627
628     // Make sure we got 1 and only 1 rows.
629     if (mysql_num_rows(rows) != 1) {
630         log->error("database select returned wrong number of rows: %d", mysql_num_rows(rows));
631         mysql_free_result(rows);
632         return S_FALSE;
633     }
634
635     MYSQL_ROW row = mysql_fetch_row(rows);
636     accessed=atoi(row[0]);
637
638     // Free the results.
639     mysql_free_result(rows);
640
641     return NOERROR;
642 }
643
644 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
645 {
646 #ifdef _DEBUG
647     saml::NDC ndc("onRead");
648 #endif
649
650     if (!m_storeAttributes)
651         return S_FALSE;
652
653     log->debug("reading cached tokens from MySQL database");
654
655     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
656
657     MYSQL* mysql = getMYSQL();
658     if (mysql_query(mysql, q.c_str())) {
659         const char* err=mysql_error(mysql);
660         log->error("error searching for %s: %s", key, err);
661         if (isCorrupt(err) && repairTable(mysql,"state")) {
662             if (mysql_query(mysql, q.c_str()))
663                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
664         }
665     }
666
667     MYSQL_RES* rows = mysql_store_result(mysql);
668
669     // Nope, doesn't exist.
670     if (!rows)
671         return S_FALSE;
672
673     // Make sure we got 1 and only 1 rows.
674     if (mysql_num_rows(rows) != 1) {
675         log->error("database select returned wrong number of rows: %d", mysql_num_rows(rows));
676         mysql_free_result(rows);
677         return S_FALSE;
678     }
679
680     MYSQL_ROW row = mysql_fetch_row(rows);
681     if (row[0])
682         tokens=row[0];
683
684     // Free the results.
685     mysql_free_result(rows);
686
687     return NOERROR;
688 }
689
690 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
691 {
692 #ifdef _DEBUG
693     saml::NDC ndc("onUpdate");
694 #endif
695
696     ostringstream q;
697     if (lastAccess>0)
698         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
699     else if (tokens) {
700         if (!m_storeAttributes)
701             return S_FALSE;
702         q << "UPDATE state SET tokens=";
703         if (*tokens)
704             q << "'" << tokens << "'";
705         else
706             q << "null";
707     }
708     else {
709         log->warn("onUpdate called with nothing to do!");
710         return S_FALSE;
711     }
712  
713     q << " WHERE cookie='" << key << "'";
714
715     MYSQL* mysql = getMYSQL();
716     if (mysql_query(mysql, q.str().c_str())) {
717         const char* err=mysql_error(mysql);
718         log->error("error updating %s: %s", key, err);
719         if (isCorrupt(err) && repairTable(mysql,"state")) {
720             // Try again...
721             if (mysql_query(mysql, q.str().c_str())) {
722                 log->error("error updating %s: %s", key, mysql_error(mysql));
723                 return E_FAIL;
724             }
725         }
726         else
727             return E_FAIL;
728     }
729
730     return NOERROR;
731 }
732
733 HRESULT ShibMySQLCCache::onDelete(const char* key)
734 {
735 #ifdef _DEBUG
736     saml::NDC ndc("onDelete");
737 #endif
738
739     // Remove from the database
740     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
741     MYSQL* mysql = getMYSQL();
742     if (mysql_query(mysql, q.c_str())) {
743         const char* err=mysql_error(mysql);
744         log->error("error deleting entry %s: %s", key, err);
745         if (isCorrupt(err) && repairTable(mysql,"state")) {
746             // Try again...
747             if (mysql_query(mysql, q.c_str())) {
748                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
749                 return E_FAIL;
750             }
751         }
752         else
753             return E_FAIL;
754     }
755
756     return NOERROR;
757 }
758
759 void ShibMySQLCCache::cleanup()
760 {
761 #ifdef _DEBUG
762   saml::NDC ndc("cleanup");
763 #endif
764
765   Mutex* mutex = Mutex::create();
766
767   int rerun_timer = 0;
768   int timeout_life = 0;
769
770   // Load our configuration details...
771   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
772   if (tag && *tag)
773     rerun_timer = XMLString::parseInt(tag);
774
775   // search for 'mysql-cache-timeout' and then the regular cache timeout
776   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
777   if (tag && *tag)
778     timeout_life = XMLString::parseInt(tag);
779   else {
780       tag=m_root->getAttributeNS(NULL,cacheTimeout);
781       if (tag && *tag)
782         timeout_life = XMLString::parseInt(tag);
783   }
784   
785   if (rerun_timer <= 0)
786     rerun_timer = 300;          // rerun every 5 minutes
787
788   if (timeout_life <= 0)
789     timeout_life = 28800;       // timeout after 8 hours
790
791   mutex->lock();
792
793   MYSQL* mysql = getMYSQL();
794
795   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
796
797   while (shutdown == false) {
798     shutdown_wait->timedwait(mutex, rerun_timer);
799
800     if (shutdown == true)
801       break;
802
803     // Find all the entries in the database that haven't been used
804     // recently In particular, find all entries that have not been
805     // accessed in 'timeout_life' seconds.
806     ostringstream q;
807     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
808
809     if (mysql_query(mysql, q.str().c_str())) {
810       const char* err=mysql_error(mysql);
811       log->error("error purging old records: %s", err);
812         if (isCorrupt(err) && repairTable(mysql,"state")) {
813           if (mysql_query(mysql, q.str().c_str()))
814             log->error("error re-purging old records: %s", mysql_error(mysql));
815         }
816     }
817   }
818
819   log->info("cleanup thread exiting...");
820
821   mutex->unlock();
822   delete mutex;
823   Thread::exit(NULL);
824 }
825
826 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
827 {
828   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
829
830   // First, let's block all signals
831   Thread::mask_all_signals();
832
833   // Now run the cleanup process.
834   cache->cleanup();
835   return NULL;
836 }
837
838 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
839 {
840 public:
841   MySQLReplayCache(const DOMElement* e);
842   virtual ~MySQLReplayCache() {}
843
844   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
845   bool check(const char* str, time_t expires);
846 };
847
848 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
849 {
850 #ifdef _DEBUG
851   saml::NDC ndc("MySQLReplayCache");
852 #endif
853
854   log = &(Category::getInstance("shibmysql.ReplayCache"));
855 }
856
857 bool MySQLReplayCache::check(const char* str, time_t expires)
858 {
859 #ifdef _DEBUG
860     saml::NDC ndc("check");
861 #endif
862   
863     // Remove expired entries
864     string q = string("DELETE FROM replay WHERE expires < NOW()");
865     MYSQL* mysql = getMYSQL();
866     if (mysql_query(mysql, q.c_str())) {
867         const char* err=mysql_error(mysql);
868         log->error("Error deleting expired entries: %s", err);
869         if (isCorrupt(err) && repairTable(mysql,"replay")) {
870             // Try again...
871             if (mysql_query(mysql, q.c_str()))
872                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
873         }
874     }
875   
876     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
877     if (mysql_query(mysql, q2.c_str())) {
878         const char* err=mysql_error(mysql);
879         log->error("Error searching for %s: %s", str, err);
880         if (isCorrupt(err) && repairTable(mysql,"replay")) {
881             if (mysql_query(mysql, q2.c_str())) {
882                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
883                 throw SAMLException("Replay cache failed, please inform application support staff.");
884             }
885         }
886         else
887             throw SAMLException("Replay cache failed, please inform application support staff.");
888     }
889
890     // Did we find it?
891     MYSQL_RES* rows = mysql_store_result(mysql);
892     if (rows && mysql_num_rows(rows)>0) {
893       mysql_free_result(rows);
894       return false;
895     }
896
897     ostringstream q3;
898     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
899
900     // then add it to the database
901     if (mysql_query(mysql, q3.str().c_str())) {
902         const char* err=mysql_error(mysql);
903         log->error("Error inserting %s: %s", str, err);
904         if (isCorrupt(err) && repairTable(mysql,"state")) {
905             // Try again...
906             if (mysql_query(mysql, q3.str().c_str())) {
907                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
908                 throw SAMLException("Replay cache failed, please inform application support staff.");
909             }
910         }
911         else
912             throw SAMLException("Replay cache failed, please inform application support staff.");
913     }
914     
915     return true;
916 }
917
918 /*************************************************************************
919  * The registration functions here...
920  */
921
922 IPlugIn* new_mysql_ccache(const DOMElement* e)
923 {
924     return new ShibMySQLCCache(e);
925 }
926
927 IPlugIn* new_mysql_replay(const DOMElement* e)
928 {
929     return new MySQLReplayCache(e);
930 }
931
932 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
933 {
934     // register this ccache type
935     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLReplayCacheType, &new_mysql_replay);
936     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLSessionCacheType, &new_mysql_ccache);
937     return 0;
938 }
939
940 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
941 {
942     // Shutdown MySQL
943     mysql_server_end();
944     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
945     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
946 }