c69ce0954795bddaec25fa0ab2ac69ab920aaf78
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib-target/shib-target.h>
43
44 #include <xmltooling/util/NDC.h>
45 #include <log4cpp/Category.hh>
46
47 #include <sstream>
48
49 #ifdef WIN32
50 # include <winsock.h>
51 #endif
52 #include <mysql.h>
53
54 // wanted to use MySQL codes for this, but can't seem to get back a 145
55 #define isCorrupt(s) strstr(s,"(errno: 145)")
56
57 #ifdef HAVE_LIBDMALLOCXX
58 #include <dmalloc.h>
59 #endif
60
61 using namespace shibtarget;
62 using namespace shibboleth;
63 using namespace saml;
64 using namespace log4cpp;
65 using namespace std;
66
67 #define PLUGIN_VER_MAJOR 3
68 #define PLUGIN_VER_MINOR 0
69
70 #define STATE_TABLE \
71   "CREATE TABLE state (" \
72   "cookie VARCHAR(64) PRIMARY KEY, " \
73   "application_id VARCHAR(255)," \
74   "ctime TIMESTAMP," \
75   "atime TIMESTAMP," \
76   "addr VARCHAR(128)," \
77   "major INT," \
78   "minor INT," \
79   "provider VARCHAR(256)," \
80   "subject TEXT," \
81   "authn_context TEXT," \
82   "tokens TEXT)"
83
84 #define REPLAY_TABLE \
85   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
86   "expires TIMESTAMP, " \
87   "INDEX (expires))"
88
89 static const XMLCh Argument[] =
90 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
91 static const XMLCh cleanupInterval[] =
92 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
93   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
94 };
95 static const XMLCh cacheTimeout[] =
96 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
97 static const XMLCh mysqlTimeout[] =
98 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
99 static const XMLCh storeAttributes[] =
100 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
101
102 static bool g_MySQLInitialized = false;
103
104 class MySQLBase : public virtual saml::IPlugIn
105 {
106 public:
107   MySQLBase(const DOMElement* e);
108   virtual ~MySQLBase();
109
110   MYSQL* getMYSQL();
111   bool repairTable(MYSQL*&, const char* table);
112
113   log4cpp::Category* log;
114
115 protected:
116     xmltooling::ThreadKey* m_mysql;
117   const DOMElement* m_root; // can only use this during initialization
118
119   bool initialized;
120   bool handleShutdown;
121
122   void createDatabase(MYSQL*, int major, int minor);
123   void upgradeDatabase(MYSQL*);
124   pair<int,int> getVersion(MYSQL*);
125 };
126
127 // Forward declarations
128 static void mysqlInit(const DOMElement* e, Category& log);
129
130 extern "C" void shib_mysql_destroy_handle(void* data)
131 {
132   MYSQL* mysql = (MYSQL*) data;
133   if (mysql) mysql_close(mysql);
134 }
135
136 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
137 {
138 #ifdef _DEBUG
139   xmltooling::NDC ndc("MySQLBase");
140 #endif
141   log = &(Category::getInstance("shibtarget.SessionCache.MySQL"));
142
143   m_mysql = xmltooling::ThreadKey::create(&shib_mysql_destroy_handle);
144
145   initialized = false;
146   mysqlInit(e,*log);
147   getMYSQL();
148   initialized = true;
149 }
150
151 MySQLBase::~MySQLBase()
152 {
153   delete m_mysql;
154 }
155
156 MYSQL* MySQLBase::getMYSQL()
157 {
158 #ifdef _DEBUG
159     xmltooling::NDC ndc("getMYSQL");
160 #endif
161
162     // Do we already have a handle?
163     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
164     if (mysql)
165         return mysql;
166
167     // Connect to the database
168     mysql = mysql_init(NULL);
169     if (!mysql) {
170         log->error("mysql_init failed");
171         mysql_close(mysql);
172         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
173     }
174
175     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
176         if (initialized) {
177             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
178             mysql_close(mysql);
179             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
180         }
181         else {
182             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
183
184             // This will throw an exception if it fails.
185             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
186         }
187     }
188
189     pair<int,int> v=getVersion (mysql);
190
191     // Make sure we've got the right version
192     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
193    
194         // If we're capable, try upgrading on the fly...
195         if (v.first == 0  || v.first == 1 || v.first == 2) {
196             if (mysql_query(mysql, "DROP TABLE state")) {
197                 log->error("error dropping old session state table: %s", mysql_error(mysql));
198             }
199             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
200                 log->error("error dropping old session state table: %s", mysql_error(mysql));
201             }
202             upgradeDatabase(mysql);
203         }
204         else {
205             mysql_close(mysql);
206             log->crit("Unknown database version: %d.%d", v.first, v.second);
207             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
208         }
209     }
210
211     // We're all set.. Save off the handle for this thread.
212     m_mysql->setData(mysql);
213     return mysql;
214 }
215
216 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
217 {
218     string q = string("REPAIR TABLE ") + table;
219     if (mysql_query(mysql, q.c_str())) {
220         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
221         return false;
222     }
223
224     // seems we have to recycle the connection to get the thread to keep working
225     // other threads seem to be ok, but we should monitor that
226     mysql_close(mysql);
227     m_mysql->setData(NULL);
228     mysql=getMYSQL();
229     return true;
230 }
231
232 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
233 {
234   log->info("creating database");
235
236   MYSQL* ms = NULL;
237   try {
238     ms = mysql_init(NULL);
239     if (!ms) {
240       log->crit("mysql_init failed");
241       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
242     }
243
244     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
245       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
246       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
247     }
248
249     if (mysql_query(ms, "CREATE DATABASE shibd")) {
250       log->crit("cannot create shibd database: %s", mysql_error(ms));
251       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
252     }
253
254     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
255       log->crit("cannot open shibd database");
256       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
257     }
258
259     mysql_close(ms);
260     
261   }
262   catch (SAMLException&) {
263     if (ms)
264       mysql_close(ms);
265     mysql_close(mysql);
266     throw;
267   }
268
269   // Now create the tables if they don't exist
270   log->info("Creating database tables");
271
272   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
273     log->error ("error creating version: %s", mysql_error(mysql));
274     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
275   }
276
277   if (mysql_query(mysql,STATE_TABLE)) {
278     log->error ("error creating state table: %s", mysql_error(mysql));
279     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
280   }
281
282   if (mysql_query(mysql,REPLAY_TABLE)) {
283     log->error ("error creating replay table: %s", mysql_error(mysql));
284     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
285   }
286
287   ostringstream q;
288   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
289   if (mysql_query(mysql, q.str().c_str())) {
290     log->error ("error setting version: %s", mysql_error(mysql));
291     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
292   }
293 }
294
295 void MySQLBase::upgradeDatabase(MYSQL* mysql)
296 {
297     if (mysql_query(mysql,STATE_TABLE)) {
298         log->error ("error creating state table: %s", mysql_error(mysql));
299         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
300     }
301
302     if (mysql_query(mysql,REPLAY_TABLE)) {
303         log->error ("error creating replay table: %s", mysql_error(mysql));
304         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
305     }
306
307     ostringstream q;
308     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
309     if (mysql_query(mysql, q.str().c_str())) {
310         log->error ("error updating version: %s", mysql_error(mysql));
311         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
312     }
313 }
314
315 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
316 {
317     // grab the version number from the database
318     if (mysql_query(mysql, "SELECT * FROM version")) {
319         log->error("error reading version: %s", mysql_error(mysql));
320         throw SAMLException("MySQLBase::getVersion(): error reading version");
321     }
322
323     MYSQL_RES* rows = mysql_store_result(mysql);
324     if (rows) {
325         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
326           MYSQL_ROW row = mysql_fetch_row(rows);
327           int major = row[0] ? atoi(row[0]) : -1;
328           int minor = row[1] ? atoi(row[1]) : -1;
329           log->debug("opening database version %d.%d", major, minor);
330           mysql_free_result(rows);
331           return make_pair(major,minor);
332         }
333         else {
334             // Wrong number of rows or wrong number of fields...
335             log->crit("Houston, we've got a problem with the database...");
336             mysql_free_result(rows);
337             throw SAMLException("MySQLBase::getVersion(): version verification failed");
338         }
339     }
340     log->crit("MySQL Read Failed in version verification");
341     throw SAMLException("MySQLBase::getVersion(): error reading version");
342 }
343
344 static void mysqlInit(const DOMElement* e, Category& log)
345 {
346     if (g_MySQLInitialized) {
347         log.info("MySQL embedded server already initialized");
348         return;
349     }
350     log.info("initializing MySQL embedded server");
351
352     // Setup the argument array
353     vector<string> arg_array;
354     arg_array.push_back("shibboleth");
355
356     // grab any MySQL parameters from the config file
357     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
358     while (e) {
359         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
360         if (arg.get())
361             arg_array.push_back(arg.get());
362         e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
363     }
364
365     // Compute the argument array
366     vector<string>::size_type arg_count = arg_array.size();
367     const char** args=new const char*[arg_count];
368     for (vector<string>::size_type i = 0; i < arg_count; i++)
369         args[i] = arg_array[i].c_str();
370
371     // Initialize MySQL with the arguments
372     mysql_server_init(arg_count, (char **)args, NULL);
373
374     delete[] args;
375     g_MySQLInitialized = true;
376 }  
377
378 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
379 {
380 public:
381     ShibMySQLCCache(const DOMElement* e);
382     virtual ~ShibMySQLCCache();
383
384     // Delegate all the ISessionCache methods.
385     string insert(
386         const IApplication* application,
387         const IEntityDescriptor* source,
388         const char* client_addr,
389         const SAMLSubject* subject,
390         const char* authnContext,
391         const SAMLResponse* tokens
392         )
393     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
394     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
395     { return m_cache->find(key,application,client_addr); }
396     void remove(const char* key, const IApplication* application, const char* client_addr)
397     { m_cache->remove(key,application,client_addr); }
398
399     bool setBackingStore(ISessionCacheStore*) {return false;}
400
401     // Store methods handle the database work
402     HRESULT onCreate(
403         const char* key,
404         const IApplication* application,
405         const ISessionCacheEntry* entry,
406         int majorVersion,
407         int minorVersion,
408         time_t created
409         );
410     HRESULT onRead(
411         const char* key,
412         string& applicationId,
413         string& clientAddress,
414         string& providerId,
415         string& subject,
416         string& authnContext,
417         string& tokens,
418         int& majorVersion,
419         int& minorVersion,
420         time_t& created,
421         time_t& accessed
422         );
423     HRESULT onRead(const char* key, time_t& accessed);
424     HRESULT onRead(const char* key, string& tokens);
425     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
426     HRESULT onDelete(const char* key);
427
428     void cleanup();
429
430 private:
431     bool m_storeAttributes;
432     ISessionCache* m_cache;
433     xmltooling::CondWait* shutdown_wait;
434     bool shutdown;
435     xmltooling::Thread* cleanup_thread;
436
437     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
438 };
439
440 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
441 {
442 #ifdef _DEBUG
443     xmltooling::NDC ndc("ShibMySQLCCache");
444 #endif
445
446     m_cache = dynamic_cast<ISessionCache*>(
447         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
448     );
449     if (!m_cache->setBackingStore(this)) {
450         delete m_cache;
451         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
452     }
453     
454     shutdown_wait = xmltooling::CondWait::create();
455     shutdown = false;
456
457     // Load our configuration details...
458     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
459     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
460         m_storeAttributes=true;
461
462     // Initialize the cleanup thread
463     cleanup_thread = xmltooling::Thread::create(&cleanup_fcn, (void*)this);
464 }
465
466 ShibMySQLCCache::~ShibMySQLCCache()
467 {
468     shutdown = true;
469     shutdown_wait->signal();
470     cleanup_thread->join(NULL);
471     delete m_cache;
472 }
473
474 HRESULT ShibMySQLCCache::onCreate(
475     const char* key,
476     const IApplication* application,
477     const ISessionCacheEntry* entry,
478     int majorVersion,
479     int minorVersion,
480     time_t created
481     )
482 {
483 #ifdef _DEBUG
484     xmltooling::NDC ndc("onCreate");
485 #endif
486
487     // Get XML data from entry. Default is not to return SAML objects.
488     const char* context=entry->getAuthnContext();
489     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
490     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
491
492     ostringstream q;
493     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
494     if (created==0)
495         q << "NOW(),NOW(),'";
496     else
497         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
498     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
499         << subject.first << "','" << context << "',";
500
501     if (m_storeAttributes && tokens.first)
502         q << "'" << tokens.first << "')";
503     else
504         q << "null)";
505
506     if (log->isDebugEnabled())
507         log->debug("SQL insert: %s", q.str().c_str());
508
509     // then add it to the database
510     MYSQL* mysql = getMYSQL();
511     if (mysql_query(mysql, q.str().c_str())) {
512         const char* err=mysql_error(mysql);
513         log->error("error inserting %s: %s", key, err);
514         if (isCorrupt(err) && repairTable(mysql,"state")) {
515             // Try again...
516             if (mysql_query(mysql, q.str().c_str())) {
517                 log->error("error inserting %s: %s", key, mysql_error(mysql));
518                 return E_FAIL;
519             }
520         }
521         else
522             throw E_FAIL;
523     }
524
525     return NOERROR;
526 }
527
528 HRESULT ShibMySQLCCache::onRead(
529     const char* key,
530     string& applicationId,
531     string& clientAddress,
532     string& providerId,
533     string& subject,
534     string& authnContext,
535     string& tokens,
536     int& majorVersion,
537     int& minorVersion,
538     time_t& created,
539     time_t& accessed
540     )
541 {
542 #ifdef _DEBUG
543     xmltooling::NDC ndc("onRead");
544 #endif
545
546     log->debug("searching MySQL database...");
547
548     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
549
550     MYSQL* mysql = getMYSQL();
551     if (mysql_query(mysql, q.c_str())) {
552         const char* err=mysql_error(mysql);
553         log->error("error searching for %s: %s", key, err);
554         if (isCorrupt(err) && repairTable(mysql,"state")) {
555             if (mysql_query(mysql, q.c_str()))
556                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
557         }
558     }
559
560     MYSQL_RES* rows = mysql_store_result(mysql);
561
562     // Nope, doesn't exist.
563     if (!rows || mysql_num_rows(rows)==0) {
564         log->debug("not found in database");
565         if (rows)
566             mysql_free_result(rows);
567         return S_FALSE;
568     }
569
570     // Make sure we got 1 and only 1 row.
571     if (mysql_num_rows(rows) > 1) {
572         log->error("database select returned %d rows!", mysql_num_rows(rows));
573         mysql_free_result(rows);
574         return E_FAIL;
575     }
576
577     log->debug("session found, tranfering data back into memory");
578     
579     /* Columns in query:
580         0: application_id
581         1: ctime
582         2: atime
583         3: address
584         4: major
585         5: minor
586         6: provider
587         7: subject
588         8: authncontext
589         9: tokens
590      */
591
592     MYSQL_ROW row = mysql_fetch_row(rows);
593     applicationId=row[0];
594     created=atoi(row[1]);
595     accessed=atoi(row[2]);
596     clientAddress=row[3];
597     majorVersion=atoi(row[4]);
598     minorVersion=atoi(row[5]);
599     providerId=row[6];
600     subject=row[7];
601     authnContext=row[8];
602     if (row[9])
603         tokens=row[9];
604
605     // Free the results.
606     mysql_free_result(rows);
607
608     return NOERROR;
609 }
610
611 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
612 {
613 #ifdef _DEBUG
614     xmltooling::NDC ndc("onRead");
615 #endif
616
617     log->debug("reading last access time from MySQL database");
618
619     string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
620
621     MYSQL* mysql = getMYSQL();
622     if (mysql_query(mysql, q.c_str())) {
623         const char* err=mysql_error(mysql);
624         log->error("error searching for %s: %s", key, err);
625         if (isCorrupt(err) && repairTable(mysql,"state")) {
626             if (mysql_query(mysql, q.c_str()))
627                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
628         }
629     }
630
631     MYSQL_RES* rows = mysql_store_result(mysql);
632
633     // Nope, doesn't exist.
634     if (!rows || mysql_num_rows(rows)==0) {
635         log->warn("session expected, but not found in database");
636         if (rows)
637             mysql_free_result(rows);
638         return S_FALSE;
639     }
640
641     // Make sure we got 1 and only 1 row.
642     if (mysql_num_rows(rows) != 1) {
643         log->error("database select returned %d rows!", mysql_num_rows(rows));
644         mysql_free_result(rows);
645         return E_FAIL;
646     }
647
648     MYSQL_ROW row = mysql_fetch_row(rows);
649     accessed=atoi(row[0]);
650
651     // Free the results.
652     mysql_free_result(rows);
653
654     return NOERROR;
655 }
656
657 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
658 {
659 #ifdef _DEBUG
660     xmltooling::NDC ndc("onRead");
661 #endif
662
663     if (!m_storeAttributes)
664         return S_FALSE;
665
666     log->debug("reading cached tokens from MySQL database");
667
668     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
669
670     MYSQL* mysql = getMYSQL();
671     if (mysql_query(mysql, q.c_str())) {
672         const char* err=mysql_error(mysql);
673         log->error("error searching for %s: %s", key, err);
674         if (isCorrupt(err) && repairTable(mysql,"state")) {
675             if (mysql_query(mysql, q.c_str()))
676                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
677         }
678     }
679
680     MYSQL_RES* rows = mysql_store_result(mysql);
681
682     // Nope, doesn't exist.
683     if (!rows || mysql_num_rows(rows)==0) {
684         log->warn("session expected, but not found in database");
685         if (rows)
686             mysql_free_result(rows);
687         return S_FALSE;
688     }
689
690     // Make sure we got 1 and only 1 row.
691     if (mysql_num_rows(rows) != 1) {
692         log->error("database select returned %d rows!", mysql_num_rows(rows));
693         mysql_free_result(rows);
694         return E_FAIL;
695     }
696
697     MYSQL_ROW row = mysql_fetch_row(rows);
698     if (row[0])
699         tokens=row[0];
700
701     // Free the results.
702     mysql_free_result(rows);
703
704     return NOERROR;
705 }
706
707 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
708 {
709 #ifdef _DEBUG
710     xmltooling::NDC ndc("onUpdate");
711 #endif
712
713     ostringstream q;
714     if (lastAccess>0)
715         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
716     else if (tokens) {
717         if (!m_storeAttributes)
718             return S_FALSE;
719         q << "UPDATE state SET tokens=";
720         if (*tokens)
721             q << "'" << tokens << "'";
722         else
723             q << "null";
724     }
725     else {
726         log->warn("onUpdate called with nothing to do!");
727         return S_FALSE;
728     }
729  
730     q << " WHERE cookie='" << key << "'";
731
732     MYSQL* mysql = getMYSQL();
733     if (mysql_query(mysql, q.str().c_str())) {
734         const char* err=mysql_error(mysql);
735         log->error("error updating %s: %s", key, err);
736         if (isCorrupt(err) && repairTable(mysql,"state")) {
737             // Try again...
738             if (mysql_query(mysql, q.str().c_str())) {
739                 log->error("error updating %s: %s", key, mysql_error(mysql));
740                 return E_FAIL;
741             }
742         }
743         else
744             return E_FAIL;
745     }
746
747     return NOERROR;
748 }
749
750 HRESULT ShibMySQLCCache::onDelete(const char* key)
751 {
752 #ifdef _DEBUG
753     xmltooling::NDC ndc("onDelete");
754 #endif
755
756     // Remove from the database
757     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
758     MYSQL* mysql = getMYSQL();
759     if (mysql_query(mysql, q.c_str())) {
760         const char* err=mysql_error(mysql);
761         log->error("error deleting entry %s: %s", key, err);
762         if (isCorrupt(err) && repairTable(mysql,"state")) {
763             // Try again...
764             if (mysql_query(mysql, q.c_str())) {
765                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
766                 return E_FAIL;
767             }
768         }
769         else
770             return E_FAIL;
771     }
772
773     return NOERROR;
774 }
775
776 void ShibMySQLCCache::cleanup()
777 {
778 #ifdef _DEBUG
779   xmltooling::NDC ndc("cleanup");
780 #endif
781
782   xmltooling::Mutex* mutex = xmltooling::Mutex::create();
783
784   int rerun_timer = 0;
785   int timeout_life = 0;
786
787   // Load our configuration details...
788   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
789   if (tag && *tag)
790     rerun_timer = XMLString::parseInt(tag);
791
792   // search for 'mysql-cache-timeout' and then the regular cache timeout
793   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
794   if (tag && *tag)
795     timeout_life = XMLString::parseInt(tag);
796   else {
797       tag=m_root->getAttributeNS(NULL,cacheTimeout);
798       if (tag && *tag)
799         timeout_life = XMLString::parseInt(tag);
800   }
801   
802   if (rerun_timer <= 0)
803     rerun_timer = 300;          // rerun every 5 minutes
804
805   if (timeout_life <= 0)
806     timeout_life = 28800;       // timeout after 8 hours
807
808   mutex->lock();
809
810   MYSQL* mysql = getMYSQL();
811
812   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
813
814   while (shutdown == false) {
815     shutdown_wait->timedwait(mutex, rerun_timer);
816
817     if (shutdown == true)
818       break;
819
820     // Find all the entries in the database that haven't been used
821     // recently In particular, find all entries that have not been
822     // accessed in 'timeout_life' seconds.
823     ostringstream q;
824     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
825
826     if (mysql_query(mysql, q.str().c_str())) {
827       const char* err=mysql_error(mysql);
828       log->error("error purging old records: %s", err);
829         if (isCorrupt(err) && repairTable(mysql,"state")) {
830           if (mysql_query(mysql, q.str().c_str()))
831             log->error("error re-purging old records: %s", mysql_error(mysql));
832         }
833     }
834   }
835
836   log->info("cleanup thread exiting...");
837
838   mutex->unlock();
839   delete mutex;
840   xmltooling::Thread::exit(NULL);
841 }
842
843 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
844 {
845   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
846
847 #ifndef WIN32
848   // First, let's block all signals
849   xmltooling::Thread::mask_all_signals();
850 #endif
851
852   // Now run the cleanup process.
853   cache->cleanup();
854   return NULL;
855 }
856
857 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
858 {
859 public:
860   MySQLReplayCache(const DOMElement* e);
861   virtual ~MySQLReplayCache() {}
862
863   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
864   bool check(const char* str, time_t expires);
865 };
866
867 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e) {}
868
869 bool MySQLReplayCache::check(const char* str, time_t expires)
870 {
871 #ifdef _DEBUG
872     xmltooling::NDC ndc("check");
873 #endif
874   
875     // Remove expired entries
876     string q = string("DELETE FROM replay WHERE expires < NOW()");
877     MYSQL* mysql = getMYSQL();
878     if (mysql_query(mysql, q.c_str())) {
879         const char* err=mysql_error(mysql);
880         log->error("Error deleting expired entries: %s", err);
881         if (isCorrupt(err) && repairTable(mysql,"replay")) {
882             // Try again...
883             if (mysql_query(mysql, q.c_str()))
884                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
885         }
886     }
887   
888     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
889     if (mysql_query(mysql, q2.c_str())) {
890         const char* err=mysql_error(mysql);
891         log->error("Error searching for %s: %s", str, err);
892         if (isCorrupt(err) && repairTable(mysql,"replay")) {
893             if (mysql_query(mysql, q2.c_str())) {
894                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
895                 throw SAMLException("Replay cache failed, please inform application support staff.");
896             }
897         }
898         else
899             throw SAMLException("Replay cache failed, please inform application support staff.");
900     }
901
902     // Did we find it?
903     MYSQL_RES* rows = mysql_store_result(mysql);
904     if (rows && mysql_num_rows(rows)>0) {
905       mysql_free_result(rows);
906       return false;
907     }
908
909     ostringstream q3;
910     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
911
912     // then add it to the database
913     if (mysql_query(mysql, q3.str().c_str())) {
914         const char* err=mysql_error(mysql);
915         log->error("Error inserting %s: %s", str, err);
916         if (isCorrupt(err) && repairTable(mysql,"state")) {
917             // Try again...
918             if (mysql_query(mysql, q3.str().c_str())) {
919                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
920                 throw SAMLException("Replay cache failed, please inform application support staff.");
921             }
922         }
923         else
924             throw SAMLException("Replay cache failed, please inform application support staff.");
925     }
926     
927     return true;
928 }
929
930 /*************************************************************************
931  * The registration functions here...
932  */
933
934 IPlugIn* new_mysql_ccache(const DOMElement* e)
935 {
936     return new ShibMySQLCCache(e);
937 }
938
939 IPlugIn* new_mysql_replay(const DOMElement* e)
940 {
941     return new MySQLReplayCache(e);
942 }
943
944 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
945 {
946     // register this ccache type
947     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLReplayCacheType, &new_mysql_replay);
948     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLSessionCacheType, &new_mysql_ccache);
949     return 0;
950 }
951
952 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
953 {
954     // Shutdown MySQL
955     if (g_MySQLInitialized)
956         mysql_server_end();
957     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
958     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
959 }