Cleaned up some logging and error codes.
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib/shib-threads.h>
43 #include <shib-target/shib-target.h>
44 #include <log4cpp/Category.hh>
45
46 #include <sstream>
47
48 #ifdef WIN32
49 # include <winsock.h>
50 #endif
51 #include <mysql.h>
52
53 // wanted to use MySQL codes for this, but can't seem to get back a 145
54 #define isCorrupt(s) strstr(s,"(errno: 145)")
55
56 #ifdef HAVE_LIBDMALLOCXX
57 #include <dmalloc.h>
58 #endif
59
60 using namespace std;
61 using namespace saml;
62 using namespace shibboleth;
63 using namespace shibtarget;
64 using namespace log4cpp;
65
66 #define PLUGIN_VER_MAJOR 3
67 #define PLUGIN_VER_MINOR 0
68
69 #define STATE_TABLE \
70   "CREATE TABLE state (" \
71   "cookie VARCHAR(64) PRIMARY KEY, " \
72   "application_id VARCHAR(255)," \
73   "ctime TIMESTAMP," \
74   "atime TIMESTAMP," \
75   "addr VARCHAR(128)," \
76   "major INT," \
77   "minor INT," \
78   "provider VARCHAR(256)," \
79   "subject TEXT," \
80   "authn_context TEXT," \
81   "tokens TEXT)"
82
83 #define REPLAY_TABLE \
84   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
85   "expires TIMESTAMP, " \
86   "INDEX (expires))"
87
88 static const XMLCh Argument[] =
89 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
90 static const XMLCh cleanupInterval[] =
91 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
92   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
93 };
94 static const XMLCh cacheTimeout[] =
95 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
96 static const XMLCh mysqlTimeout[] =
97 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
98 static const XMLCh storeAttributes[] =
99 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
100
101 class MySQLBase : public virtual saml::IPlugIn
102 {
103 public:
104   MySQLBase(const DOMElement* e);
105   virtual ~MySQLBase();
106
107   MYSQL* getMYSQL();
108   bool repairTable(MYSQL*&, const char* table);
109
110   log4cpp::Category* log;
111
112 protected:
113   ThreadKey* m_mysql;
114   const DOMElement* m_root; // can only use this during initialization
115
116   bool initialized;
117
118   void createDatabase(MYSQL*, int major, int minor);
119   void upgradeDatabase(MYSQL*);
120   pair<int,int> getVersion(MYSQL*);
121 };
122
123 // Forward declarations
124 static void mysqlInit(const DOMElement* e, Category& log);
125
126 extern "C" void shib_mysql_destroy_handle(void* data)
127 {
128   MYSQL* mysql = (MYSQL*) data;
129   if (mysql) mysql_close(mysql);
130 }
131
132 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
133 {
134 #ifdef _DEBUG
135   saml::NDC ndc("MySQLBase");
136 #endif
137   log = &(Category::getInstance("shibmysql.MySQLBase"));
138
139   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
140
141   initialized = false;
142   mysqlInit(e,*log);
143   getMYSQL();
144   initialized = true;
145 }
146
147 MySQLBase::~MySQLBase()
148 {
149   delete m_mysql;
150 }
151
152 MYSQL* MySQLBase::getMYSQL()
153 {
154 #ifdef _DEBUG
155     saml::NDC ndc("getMYSQL");
156 #endif
157
158     // Do we already have a handle?
159     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
160     if (mysql)
161         return mysql;
162
163     // Connect to the database
164     mysql = mysql_init(NULL);
165     if (!mysql) {
166         log->error("mysql_init failed");
167         mysql_close(mysql);
168         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
169     }
170
171     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
172         if (initialized) {
173             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
174             mysql_close(mysql);
175             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
176         }
177         else {
178             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
179
180             // This will throw an exception if it fails.
181             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
182         }
183     }
184
185     pair<int,int> v=getVersion (mysql);
186
187     // Make sure we've got the right version
188     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
189    
190         // If we're capable, try upgrading on the fly...
191         if (v.first == 0  || v.first == 1 || v.first == 2) {
192             if (mysql_query(mysql, "DROP TABLE state")) {
193                 log->error("error dropping old session state table: %s", mysql_error(mysql));
194             }
195             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
196                 log->error("error dropping old session state table: %s", mysql_error(mysql));
197             }
198             upgradeDatabase(mysql);
199         }
200         else {
201             mysql_close(mysql);
202             log->crit("Unknown database version: %d.%d", v.first, v.second);
203             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
204         }
205     }
206
207     // We're all set.. Save off the handle for this thread.
208     m_mysql->setData(mysql);
209     return mysql;
210 }
211
212 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
213 {
214     string q = string("REPAIR TABLE ") + table;
215     if (mysql_query(mysql, q.c_str())) {
216         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
217         return false;
218     }
219
220     // seems we have to recycle the connection to get the thread to keep working
221     // other threads seem to be ok, but we should monitor that
222     mysql_close(mysql);
223     m_mysql->setData(NULL);
224     mysql=getMYSQL();
225     return true;
226 }
227
228 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
229 {
230   log->info("creating database");
231
232   MYSQL* ms = NULL;
233   try {
234     ms = mysql_init(NULL);
235     if (!ms) {
236       log->crit("mysql_init failed");
237       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
238     }
239
240     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
241       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
242       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
243     }
244
245     if (mysql_query(ms, "CREATE DATABASE shibd")) {
246       log->crit("cannot create shibd database: %s", mysql_error(ms));
247       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
248     }
249
250     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
251       log->crit("cannot open shibd database");
252       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
253     }
254
255     mysql_close(ms);
256     
257   }
258   catch (SAMLException&) {
259     if (ms)
260       mysql_close(ms);
261     mysql_close(mysql);
262     throw;
263   }
264
265   // Now create the tables if they don't exist
266   log->info("Creating database tables");
267
268   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
269     log->error ("error creating version: %s", mysql_error(mysql));
270     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
271   }
272
273   if (mysql_query(mysql,STATE_TABLE)) {
274     log->error ("error creating state table: %s", mysql_error(mysql));
275     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
276   }
277
278   if (mysql_query(mysql,REPLAY_TABLE)) {
279     log->error ("error creating replay table: %s", mysql_error(mysql));
280     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
281   }
282
283   ostringstream q;
284   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
285   if (mysql_query(mysql, q.str().c_str())) {
286     log->error ("error setting version: %s", mysql_error(mysql));
287     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
288   }
289 }
290
291 void MySQLBase::upgradeDatabase(MYSQL* mysql)
292 {
293     if (mysql_query(mysql,STATE_TABLE)) {
294         log->error ("error creating state table: %s", mysql_error(mysql));
295         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
296     }
297
298     if (mysql_query(mysql,REPLAY_TABLE)) {
299         log->error ("error creating replay table: %s", mysql_error(mysql));
300         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
301     }
302
303     ostringstream q;
304     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
305     if (mysql_query(mysql, q.str().c_str())) {
306         log->error ("error updating version: %s", mysql_error(mysql));
307         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
308     }
309 }
310
311 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
312 {
313     // grab the version number from the database
314     if (mysql_query(mysql, "SELECT * FROM version"))
315         log->error ("Error reading version: %s", mysql_error(mysql));
316
317     MYSQL_RES* rows = mysql_store_result(mysql);
318     if (rows) {
319         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
320           MYSQL_ROW row = mysql_fetch_row(rows);
321           int major = row[0] ? atoi(row[0]) : -1;
322           int minor = row[1] ? atoi(row[1]) : -1;
323           log->debug("opening database version %d.%d", major, minor);
324           mysql_free_result(rows);
325           return make_pair(major,minor);
326         } else {
327             // Wrong number of rows or wrong number of fields...
328             log->crit("Houston, we've got a problem with the database...");
329             mysql_free_result (rows);
330             throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
331         }
332     }
333     log->crit("MySQL Read Failed in version verification");
334     throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
335 }
336
337 static void mysqlInit(const DOMElement* e, Category& log)
338 {
339     static bool done = false;
340     if (done) {
341         log.info("MySQL embedded server already initialized");
342         return;
343     }
344     log.info("initializing MySQL embedded server");
345
346     // Setup the argument array
347     vector<string> arg_array;
348     arg_array.push_back("shibboleth");
349
350     // grab any MySQL parameters from the config file
351     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
352     while (e) {
353         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
354         if (arg.get())
355             arg_array.push_back(arg.get());
356         e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
357     }
358
359     // Compute the argument array
360     vector<string>::size_type arg_count = arg_array.size();
361     const char** args=new const char*[arg_count];
362     for (vector<string>::size_type i = 0; i < arg_count; i++)
363         args[i] = arg_array[i].c_str();
364
365     // Initialize MySQL with the arguments
366     mysql_server_init(arg_count, (char **)args, NULL);
367
368     delete[] args;
369     done = true;
370 }  
371
372 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
373 {
374 public:
375     ShibMySQLCCache(const DOMElement* e);
376     virtual ~ShibMySQLCCache();
377
378     // Delegate all the ISessionCache methods.
379     string insert(
380         const IApplication* application,
381         const IEntityDescriptor* source,
382         const char* client_addr,
383         const SAMLSubject* subject,
384         const char* authnContext,
385         const SAMLResponse* tokens
386         )
387     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
388     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
389     { return m_cache->find(key,application,client_addr); }
390     void remove(const char* key, const IApplication* application, const char* client_addr)
391     { m_cache->remove(key,application,client_addr); }
392
393     bool setBackingStore(ISessionCacheStore*) {return false;}
394
395     // Store methods handle the database work
396     HRESULT onCreate(
397         const char* key,
398         const IApplication* application,
399         const ISessionCacheEntry* entry,
400         int majorVersion,
401         int minorVersion,
402         time_t created
403         );
404     HRESULT onRead(
405         const char* key,
406         string& applicationId,
407         string& clientAddress,
408         string& providerId,
409         string& subject,
410         string& authnContext,
411         string& tokens,
412         int& majorVersion,
413         int& minorVersion,
414         time_t& created,
415         time_t& accessed
416         );
417     HRESULT onRead(const char* key, time_t& accessed);
418     HRESULT onRead(const char* key, string& tokens);
419     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
420     HRESULT onDelete(const char* key);
421
422     void cleanup();
423
424 private:
425     bool m_storeAttributes;
426     ISessionCache* m_cache;
427     CondWait* shutdown_wait;
428     bool shutdown;
429     Thread* cleanup_thread;
430
431     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
432 };
433
434 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
435 {
436 #ifdef _DEBUG
437     saml::NDC ndc("ShibMySQLCCache");
438 #endif
439     log = &(Category::getInstance("shibmysql.SessionCache"));
440
441     m_cache = dynamic_cast<ISessionCache*>(
442         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
443     );
444     if (!m_cache->setBackingStore(this)) {
445         delete m_cache;
446         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
447     }
448     
449     shutdown_wait = CondWait::create();
450     shutdown = false;
451
452     // Load our configuration details...
453     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
454     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
455         m_storeAttributes=true;
456
457     // Initialize the cleanup thread
458     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
459 }
460
461 ShibMySQLCCache::~ShibMySQLCCache()
462 {
463     shutdown = true;
464     shutdown_wait->signal();
465     cleanup_thread->join(NULL);
466     delete m_cache;
467 }
468
469 HRESULT ShibMySQLCCache::onCreate(
470     const char* key,
471     const IApplication* application,
472     const ISessionCacheEntry* entry,
473     int majorVersion,
474     int minorVersion,
475     time_t created
476     )
477 {
478 #ifdef _DEBUG
479     saml::NDC ndc("onCreate");
480 #endif
481
482     // Get XML data from entry. Default is not to return SAML objects.
483     const char* context=entry->getAuthnContext();
484     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
485     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
486
487     ostringstream q;
488     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
489     if (created==0)
490         q << "NOW(),NOW(),'";
491     else
492         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
493     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
494         << subject.first << "','" << context << "',";
495
496     if (m_storeAttributes && tokens.first)
497         q << "'" << tokens.first << "')";
498     else
499         q << "null)";
500
501     if (log->isDebugEnabled())
502         log->debug("SQL insert: %s", q.str().c_str());
503
504     // then add it to the database
505     MYSQL* mysql = getMYSQL();
506     if (mysql_query(mysql, q.str().c_str())) {
507         const char* err=mysql_error(mysql);
508         log->error("error inserting %s: %s", key, err);
509         if (isCorrupt(err) && repairTable(mysql,"state")) {
510             // Try again...
511             if (mysql_query(mysql, q.str().c_str())) {
512                 log->error("error inserting %s: %s", key, mysql_error(mysql));
513                 return E_FAIL;
514             }
515         }
516         else
517             throw E_FAIL;
518     }
519
520     return NOERROR;
521 }
522
523 HRESULT ShibMySQLCCache::onRead(
524     const char* key,
525     string& applicationId,
526     string& clientAddress,
527     string& providerId,
528     string& subject,
529     string& authnContext,
530     string& tokens,
531     int& majorVersion,
532     int& minorVersion,
533     time_t& created,
534     time_t& accessed
535     )
536 {
537 #ifdef _DEBUG
538     saml::NDC ndc("onRead");
539 #endif
540
541     log->debug("searching MySQL database...");
542
543     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
544
545     MYSQL* mysql = getMYSQL();
546     if (mysql_query(mysql, q.c_str())) {
547         const char* err=mysql_error(mysql);
548         log->error("error searching for %s: %s", key, err);
549         if (isCorrupt(err) && repairTable(mysql,"state")) {
550             if (mysql_query(mysql, q.c_str()))
551                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
552         }
553     }
554
555     MYSQL_RES* rows = mysql_store_result(mysql);
556
557     // Nope, doesn't exist.
558     if (!rows || mysql_num_rows(rows)==0) {
559         log->debug("not found in database");
560         if (rows)
561             mysql_free_result(rows);
562         return S_FALSE;
563     }
564
565     // Make sure we got 1 and only 1 row.
566     if (mysql_num_rows(rows) > 1) {
567         log->error("database select returned %d rows!", mysql_num_rows(rows));
568         mysql_free_result(rows);
569         return E_FAIL;
570     }
571
572     log->debug("session found, tranfering data back into memory");
573     
574     /* Columns in query:
575         0: application_id
576         1: ctime
577         2: atime
578         3: address
579         4: major
580         5: minor
581         6: provider
582         7: subject
583         8: authncontext
584         9: tokens
585      */
586
587     MYSQL_ROW row = mysql_fetch_row(rows);
588     applicationId=row[0];
589     created=atoi(row[1]);
590     accessed=atoi(row[2]);
591     clientAddress=row[3];
592     majorVersion=atoi(row[4]);
593     minorVersion=atoi(row[5]);
594     providerId=row[6];
595     subject=row[7];
596     authnContext=row[8];
597     if (row[9])
598         tokens=row[9];
599
600     // Free the results.
601     mysql_free_result(rows);
602
603     return NOERROR;
604 }
605
606 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
607 {
608 #ifdef _DEBUG
609     saml::NDC ndc("onRead");
610 #endif
611
612     log->debug("reading last access time from MySQL database");
613
614     string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
615
616     MYSQL* mysql = getMYSQL();
617     if (mysql_query(mysql, q.c_str())) {
618         const char* err=mysql_error(mysql);
619         log->error("error searching for %s: %s", key, err);
620         if (isCorrupt(err) && repairTable(mysql,"state")) {
621             if (mysql_query(mysql, q.c_str()))
622                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
623         }
624     }
625
626     MYSQL_RES* rows = mysql_store_result(mysql);
627
628     // Nope, doesn't exist.
629     if (!rows || mysql_num_rows(rows)==0) {
630         log->warn("session expected, but not found in database");
631         if (rows)
632             mysql_free_result(rows);
633         return S_FALSE;
634     }
635
636     // Make sure we got 1 and only 1 row.
637     if (mysql_num_rows(rows) != 1) {
638         log->error("database select returned %d rows!", mysql_num_rows(rows));
639         mysql_free_result(rows);
640         return E_FAIL;
641     }
642
643     MYSQL_ROW row = mysql_fetch_row(rows);
644     accessed=atoi(row[0]);
645
646     // Free the results.
647     mysql_free_result(rows);
648
649     return NOERROR;
650 }
651
652 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
653 {
654 #ifdef _DEBUG
655     saml::NDC ndc("onRead");
656 #endif
657
658     if (!m_storeAttributes)
659         return S_FALSE;
660
661     log->debug("reading cached tokens from MySQL database");
662
663     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
664
665     MYSQL* mysql = getMYSQL();
666     if (mysql_query(mysql, q.c_str())) {
667         const char* err=mysql_error(mysql);
668         log->error("error searching for %s: %s", key, err);
669         if (isCorrupt(err) && repairTable(mysql,"state")) {
670             if (mysql_query(mysql, q.c_str()))
671                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
672         }
673     }
674
675     MYSQL_RES* rows = mysql_store_result(mysql);
676
677     // Nope, doesn't exist.
678     if (!rows || mysql_num_rows(rows)==0) {
679         log->warn("session expected, but not found in database");
680         if (rows)
681             mysql_free_result(rows);
682         return S_FALSE;
683     }
684
685     // Make sure we got 1 and only 1 row.
686     if (mysql_num_rows(rows) != 1) {
687         log->error("database select returned %d rows!", mysql_num_rows(rows));
688         mysql_free_result(rows);
689         return E_FAIL;
690     }
691
692     MYSQL_ROW row = mysql_fetch_row(rows);
693     if (row[0])
694         tokens=row[0];
695
696     // Free the results.
697     mysql_free_result(rows);
698
699     return NOERROR;
700 }
701
702 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
703 {
704 #ifdef _DEBUG
705     saml::NDC ndc("onUpdate");
706 #endif
707
708     ostringstream q;
709     if (lastAccess>0)
710         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
711     else if (tokens) {
712         if (!m_storeAttributes)
713             return S_FALSE;
714         q << "UPDATE state SET tokens=";
715         if (*tokens)
716             q << "'" << tokens << "'";
717         else
718             q << "null";
719     }
720     else {
721         log->warn("onUpdate called with nothing to do!");
722         return S_FALSE;
723     }
724  
725     q << " WHERE cookie='" << key << "'";
726
727     MYSQL* mysql = getMYSQL();
728     if (mysql_query(mysql, q.str().c_str())) {
729         const char* err=mysql_error(mysql);
730         log->error("error updating %s: %s", key, err);
731         if (isCorrupt(err) && repairTable(mysql,"state")) {
732             // Try again...
733             if (mysql_query(mysql, q.str().c_str())) {
734                 log->error("error updating %s: %s", key, mysql_error(mysql));
735                 return E_FAIL;
736             }
737         }
738         else
739             return E_FAIL;
740     }
741
742     return NOERROR;
743 }
744
745 HRESULT ShibMySQLCCache::onDelete(const char* key)
746 {
747 #ifdef _DEBUG
748     saml::NDC ndc("onDelete");
749 #endif
750
751     // Remove from the database
752     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
753     MYSQL* mysql = getMYSQL();
754     if (mysql_query(mysql, q.c_str())) {
755         const char* err=mysql_error(mysql);
756         log->error("error deleting entry %s: %s", key, err);
757         if (isCorrupt(err) && repairTable(mysql,"state")) {
758             // Try again...
759             if (mysql_query(mysql, q.c_str())) {
760                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
761                 return E_FAIL;
762             }
763         }
764         else
765             return E_FAIL;
766     }
767
768     return NOERROR;
769 }
770
771 void ShibMySQLCCache::cleanup()
772 {
773 #ifdef _DEBUG
774   saml::NDC ndc("cleanup");
775 #endif
776
777   Mutex* mutex = Mutex::create();
778
779   int rerun_timer = 0;
780   int timeout_life = 0;
781
782   // Load our configuration details...
783   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
784   if (tag && *tag)
785     rerun_timer = XMLString::parseInt(tag);
786
787   // search for 'mysql-cache-timeout' and then the regular cache timeout
788   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
789   if (tag && *tag)
790     timeout_life = XMLString::parseInt(tag);
791   else {
792       tag=m_root->getAttributeNS(NULL,cacheTimeout);
793       if (tag && *tag)
794         timeout_life = XMLString::parseInt(tag);
795   }
796   
797   if (rerun_timer <= 0)
798     rerun_timer = 300;          // rerun every 5 minutes
799
800   if (timeout_life <= 0)
801     timeout_life = 28800;       // timeout after 8 hours
802
803   mutex->lock();
804
805   MYSQL* mysql = getMYSQL();
806
807   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
808
809   while (shutdown == false) {
810     shutdown_wait->timedwait(mutex, rerun_timer);
811
812     if (shutdown == true)
813       break;
814
815     // Find all the entries in the database that haven't been used
816     // recently In particular, find all entries that have not been
817     // accessed in 'timeout_life' seconds.
818     ostringstream q;
819     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
820
821     if (mysql_query(mysql, q.str().c_str())) {
822       const char* err=mysql_error(mysql);
823       log->error("error purging old records: %s", err);
824         if (isCorrupt(err) && repairTable(mysql,"state")) {
825           if (mysql_query(mysql, q.str().c_str()))
826             log->error("error re-purging old records: %s", mysql_error(mysql));
827         }
828     }
829   }
830
831   log->info("cleanup thread exiting...");
832
833   mutex->unlock();
834   delete mutex;
835   Thread::exit(NULL);
836 }
837
838 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
839 {
840   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
841
842   // First, let's block all signals
843   Thread::mask_all_signals();
844
845   // Now run the cleanup process.
846   cache->cleanup();
847   return NULL;
848 }
849
850 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
851 {
852 public:
853   MySQLReplayCache(const DOMElement* e);
854   virtual ~MySQLReplayCache() {}
855
856   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
857   bool check(const char* str, time_t expires);
858 };
859
860 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
861 {
862 #ifdef _DEBUG
863   saml::NDC ndc("MySQLReplayCache");
864 #endif
865
866   log = &(Category::getInstance("shibmysql.ReplayCache"));
867 }
868
869 bool MySQLReplayCache::check(const char* str, time_t expires)
870 {
871 #ifdef _DEBUG
872     saml::NDC ndc("check");
873 #endif
874   
875     // Remove expired entries
876     string q = string("DELETE FROM replay WHERE expires < NOW()");
877     MYSQL* mysql = getMYSQL();
878     if (mysql_query(mysql, q.c_str())) {
879         const char* err=mysql_error(mysql);
880         log->error("Error deleting expired entries: %s", err);
881         if (isCorrupt(err) && repairTable(mysql,"replay")) {
882             // Try again...
883             if (mysql_query(mysql, q.c_str()))
884                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
885         }
886     }
887   
888     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
889     if (mysql_query(mysql, q2.c_str())) {
890         const char* err=mysql_error(mysql);
891         log->error("Error searching for %s: %s", str, err);
892         if (isCorrupt(err) && repairTable(mysql,"replay")) {
893             if (mysql_query(mysql, q2.c_str())) {
894                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
895                 throw SAMLException("Replay cache failed, please inform application support staff.");
896             }
897         }
898         else
899             throw SAMLException("Replay cache failed, please inform application support staff.");
900     }
901
902     // Did we find it?
903     MYSQL_RES* rows = mysql_store_result(mysql);
904     if (rows && mysql_num_rows(rows)>0) {
905       mysql_free_result(rows);
906       return false;
907     }
908
909     ostringstream q3;
910     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
911
912     // then add it to the database
913     if (mysql_query(mysql, q3.str().c_str())) {
914         const char* err=mysql_error(mysql);
915         log->error("Error inserting %s: %s", str, err);
916         if (isCorrupt(err) && repairTable(mysql,"state")) {
917             // Try again...
918             if (mysql_query(mysql, q3.str().c_str())) {
919                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
920                 throw SAMLException("Replay cache failed, please inform application support staff.");
921             }
922         }
923         else
924             throw SAMLException("Replay cache failed, please inform application support staff.");
925     }
926     
927     return true;
928 }
929
930 /*************************************************************************
931  * The registration functions here...
932  */
933
934 IPlugIn* new_mysql_ccache(const DOMElement* e)
935 {
936     return new ShibMySQLCCache(e);
937 }
938
939 IPlugIn* new_mysql_replay(const DOMElement* e)
940 {
941     return new MySQLReplayCache(e);
942 }
943
944 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
945 {
946     // register this ccache type
947     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLReplayCacheType, &new_mysql_replay);
948     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLSessionCacheType, &new_mysql_ccache);
949     return 0;
950 }
951
952 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
953 {
954     // Shutdown MySQL
955     mysql_server_end();
956     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
957     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
958 }