Only shutdown MySQL if it's been started.
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
34 #else
35 # define SHIBMYSQL_EXPORTS
36 #endif
37
38 #ifdef HAVE_UNISTD_H
39 # include <unistd.h>
40 #endif
41
42 #include <shib/shib-threads.h>
43 #include <shib-target/shib-target.h>
44 #include <log4cpp/Category.hh>
45
46 #include <sstream>
47
48 #ifdef WIN32
49 # include <winsock.h>
50 #endif
51 #include <mysql.h>
52
53 // wanted to use MySQL codes for this, but can't seem to get back a 145
54 #define isCorrupt(s) strstr(s,"(errno: 145)")
55
56 #ifdef HAVE_LIBDMALLOCXX
57 #include <dmalloc.h>
58 #endif
59
60 using namespace std;
61 using namespace saml;
62 using namespace shibboleth;
63 using namespace shibtarget;
64 using namespace log4cpp;
65
66 #define PLUGIN_VER_MAJOR 3
67 #define PLUGIN_VER_MINOR 0
68
69 #define STATE_TABLE \
70   "CREATE TABLE state (" \
71   "cookie VARCHAR(64) PRIMARY KEY, " \
72   "application_id VARCHAR(255)," \
73   "ctime TIMESTAMP," \
74   "atime TIMESTAMP," \
75   "addr VARCHAR(128)," \
76   "major INT," \
77   "minor INT," \
78   "provider VARCHAR(256)," \
79   "subject TEXT," \
80   "authn_context TEXT," \
81   "tokens TEXT)"
82
83 #define REPLAY_TABLE \
84   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
85   "expires TIMESTAMP, " \
86   "INDEX (expires))"
87
88 static const XMLCh Argument[] =
89 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
90 static const XMLCh cleanupInterval[] =
91 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
92   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
93 };
94 static const XMLCh cacheTimeout[] =
95 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
96 static const XMLCh mysqlTimeout[] =
97 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
98 static const XMLCh storeAttributes[] =
99 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
100
101 static bool g_MySQLInitialized = false;
102
103 class MySQLBase : public virtual saml::IPlugIn
104 {
105 public:
106   MySQLBase(const DOMElement* e);
107   virtual ~MySQLBase();
108
109   MYSQL* getMYSQL();
110   bool repairTable(MYSQL*&, const char* table);
111
112   log4cpp::Category* log;
113
114 protected:
115   ThreadKey* m_mysql;
116   const DOMElement* m_root; // can only use this during initialization
117
118   bool initialized;
119   bool handleShutdown;
120
121   void createDatabase(MYSQL*, int major, int minor);
122   void upgradeDatabase(MYSQL*);
123   pair<int,int> getVersion(MYSQL*);
124 };
125
126 // Forward declarations
127 static void mysqlInit(const DOMElement* e, Category& log);
128
129 extern "C" void shib_mysql_destroy_handle(void* data)
130 {
131   MYSQL* mysql = (MYSQL*) data;
132   if (mysql) mysql_close(mysql);
133 }
134
135 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
136 {
137 #ifdef _DEBUG
138   saml::NDC ndc("MySQLBase");
139 #endif
140   log = &(Category::getInstance("shibtarget.SessionCache.MySQL"));
141
142   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
143
144   initialized = false;
145   mysqlInit(e,*log);
146   getMYSQL();
147   initialized = true;
148 }
149
150 MySQLBase::~MySQLBase()
151 {
152   delete m_mysql;
153 }
154
155 MYSQL* MySQLBase::getMYSQL()
156 {
157 #ifdef _DEBUG
158     saml::NDC ndc("getMYSQL");
159 #endif
160
161     // Do we already have a handle?
162     MYSQL* mysql=reinterpret_cast<MYSQL*>(m_mysql->getData());
163     if (mysql)
164         return mysql;
165
166     // Connect to the database
167     mysql = mysql_init(NULL);
168     if (!mysql) {
169         log->error("mysql_init failed");
170         mysql_close(mysql);
171         throw SAMLException("MySQLBase::getMYSQL(): mysql_init() failed");
172     }
173
174     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
175         if (initialized) {
176             log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
177             mysql_close(mysql);
178             throw SAMLException("MySQLBase::getMYSQL(): mysql_real_connect() failed");
179         }
180         else {
181             log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
182
183             // This will throw an exception if it fails.
184             createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
185         }
186     }
187
188     pair<int,int> v=getVersion (mysql);
189
190     // Make sure we've got the right version
191     if (v.first != PLUGIN_VER_MAJOR || v.second != PLUGIN_VER_MINOR) {
192    
193         // If we're capable, try upgrading on the fly...
194         if (v.first == 0  || v.first == 1 || v.first == 2) {
195             if (mysql_query(mysql, "DROP TABLE state")) {
196                 log->error("error dropping old session state table: %s", mysql_error(mysql));
197             }
198             if (v.first==2 && mysql_query(mysql, "DROP TABLE replay")) {
199                 log->error("error dropping old session state table: %s", mysql_error(mysql));
200             }
201             upgradeDatabase(mysql);
202         }
203         else {
204             mysql_close(mysql);
205             log->crit("Unknown database version: %d.%d", v.first, v.second);
206             throw SAMLException("MySQLBase::getMYSQL(): Unknown database version");
207         }
208     }
209
210     // We're all set.. Save off the handle for this thread.
211     m_mysql->setData(mysql);
212     return mysql;
213 }
214
215 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
216 {
217     string q = string("REPAIR TABLE ") + table;
218     if (mysql_query(mysql, q.c_str())) {
219         log->error("Error repairing table %s: %s", table, mysql_error(mysql));
220         return false;
221     }
222
223     // seems we have to recycle the connection to get the thread to keep working
224     // other threads seem to be ok, but we should monitor that
225     mysql_close(mysql);
226     m_mysql->setData(NULL);
227     mysql=getMYSQL();
228     return true;
229 }
230
231 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
232 {
233   log->info("creating database");
234
235   MYSQL* ms = NULL;
236   try {
237     ms = mysql_init(NULL);
238     if (!ms) {
239       log->crit("mysql_init failed");
240       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
241     }
242
243     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
244       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
245       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
246     }
247
248     if (mysql_query(ms, "CREATE DATABASE shibd")) {
249       log->crit("cannot create shibd database: %s", mysql_error(ms));
250       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
251     }
252
253     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shibd", 0, NULL, 0)) {
254       log->crit("cannot open shibd database");
255       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
256     }
257
258     mysql_close(ms);
259     
260   }
261   catch (SAMLException&) {
262     if (ms)
263       mysql_close(ms);
264     mysql_close(mysql);
265     throw;
266   }
267
268   // Now create the tables if they don't exist
269   log->info("Creating database tables");
270
271   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
272     log->error ("error creating version: %s", mysql_error(mysql));
273     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
274   }
275
276   if (mysql_query(mysql,STATE_TABLE)) {
277     log->error ("error creating state table: %s", mysql_error(mysql));
278     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
279   }
280
281   if (mysql_query(mysql,REPLAY_TABLE)) {
282     log->error ("error creating replay table: %s", mysql_error(mysql));
283     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
284   }
285
286   ostringstream q;
287   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
288   if (mysql_query(mysql, q.str().c_str())) {
289     log->error ("error setting version: %s", mysql_error(mysql));
290     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
291   }
292 }
293
294 void MySQLBase::upgradeDatabase(MYSQL* mysql)
295 {
296     if (mysql_query(mysql,STATE_TABLE)) {
297         log->error ("error creating state table: %s", mysql_error(mysql));
298         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
299     }
300
301     if (mysql_query(mysql,REPLAY_TABLE)) {
302         log->error ("error creating replay table: %s", mysql_error(mysql));
303         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
304     }
305
306     ostringstream q;
307     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
308     if (mysql_query(mysql, q.str().c_str())) {
309         log->error ("error updating version: %s", mysql_error(mysql));
310         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
311     }
312 }
313
314 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
315 {
316     // grab the version number from the database
317     if (mysql_query(mysql, "SELECT * FROM version")) {
318         log->error("error reading version: %s", mysql_error(mysql));
319         throw SAMLException("MySQLBase::getVersion(): error reading version");
320     }
321
322     MYSQL_RES* rows = mysql_store_result(mysql);
323     if (rows) {
324         if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
325           MYSQL_ROW row = mysql_fetch_row(rows);
326           int major = row[0] ? atoi(row[0]) : -1;
327           int minor = row[1] ? atoi(row[1]) : -1;
328           log->debug("opening database version %d.%d", major, minor);
329           mysql_free_result(rows);
330           return make_pair(major,minor);
331         }
332         else {
333             // Wrong number of rows or wrong number of fields...
334             log->crit("Houston, we've got a problem with the database...");
335             mysql_free_result(rows);
336             throw SAMLException("MySQLBase::getVersion(): version verification failed");
337         }
338     }
339     log->crit("MySQL Read Failed in version verification");
340     throw SAMLException("MySQLBase::getVersion(): error reading version");
341 }
342
343 static void mysqlInit(const DOMElement* e, Category& log)
344 {
345     if (g_MySQLInitialized) {
346         log.info("MySQL embedded server already initialized");
347         return;
348     }
349     log.info("initializing MySQL embedded server");
350
351     // Setup the argument array
352     vector<string> arg_array;
353     arg_array.push_back("shibboleth");
354
355     // grab any MySQL parameters from the config file
356     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
357     while (e) {
358         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
359         if (arg.get())
360             arg_array.push_back(arg.get());
361         e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
362     }
363
364     // Compute the argument array
365     vector<string>::size_type arg_count = arg_array.size();
366     const char** args=new const char*[arg_count];
367     for (vector<string>::size_type i = 0; i < arg_count; i++)
368         args[i] = arg_array[i].c_str();
369
370     // Initialize MySQL with the arguments
371     mysql_server_init(arg_count, (char **)args, NULL);
372
373     delete[] args;
374     g_MySQLInitialized = true;
375 }  
376
377 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
378 {
379 public:
380     ShibMySQLCCache(const DOMElement* e);
381     virtual ~ShibMySQLCCache();
382
383     // Delegate all the ISessionCache methods.
384     string insert(
385         const IApplication* application,
386         const IEntityDescriptor* source,
387         const char* client_addr,
388         const SAMLSubject* subject,
389         const char* authnContext,
390         const SAMLResponse* tokens
391         )
392     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
393     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
394     { return m_cache->find(key,application,client_addr); }
395     void remove(const char* key, const IApplication* application, const char* client_addr)
396     { m_cache->remove(key,application,client_addr); }
397
398     bool setBackingStore(ISessionCacheStore*) {return false;}
399
400     // Store methods handle the database work
401     HRESULT onCreate(
402         const char* key,
403         const IApplication* application,
404         const ISessionCacheEntry* entry,
405         int majorVersion,
406         int minorVersion,
407         time_t created
408         );
409     HRESULT onRead(
410         const char* key,
411         string& applicationId,
412         string& clientAddress,
413         string& providerId,
414         string& subject,
415         string& authnContext,
416         string& tokens,
417         int& majorVersion,
418         int& minorVersion,
419         time_t& created,
420         time_t& accessed
421         );
422     HRESULT onRead(const char* key, time_t& accessed);
423     HRESULT onRead(const char* key, string& tokens);
424     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
425     HRESULT onDelete(const char* key);
426
427     void cleanup();
428
429 private:
430     bool m_storeAttributes;
431     ISessionCache* m_cache;
432     CondWait* shutdown_wait;
433     bool shutdown;
434     Thread* cleanup_thread;
435
436     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
437 };
438
439 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
440 {
441 #ifdef _DEBUG
442     saml::NDC ndc("ShibMySQLCCache");
443 #endif
444
445     m_cache = dynamic_cast<ISessionCache*>(
446         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
447     );
448     if (!m_cache->setBackingStore(this)) {
449         delete m_cache;
450         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
451     }
452     
453     shutdown_wait = CondWait::create();
454     shutdown = false;
455
456     // Load our configuration details...
457     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
458     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
459         m_storeAttributes=true;
460
461     // Initialize the cleanup thread
462     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
463 }
464
465 ShibMySQLCCache::~ShibMySQLCCache()
466 {
467     shutdown = true;
468     shutdown_wait->signal();
469     cleanup_thread->join(NULL);
470     delete m_cache;
471 }
472
473 HRESULT ShibMySQLCCache::onCreate(
474     const char* key,
475     const IApplication* application,
476     const ISessionCacheEntry* entry,
477     int majorVersion,
478     int minorVersion,
479     time_t created
480     )
481 {
482 #ifdef _DEBUG
483     saml::NDC ndc("onCreate");
484 #endif
485
486     // Get XML data from entry. Default is not to return SAML objects.
487     const char* context=entry->getAuthnContext();
488     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
489     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
490
491     ostringstream q;
492     q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
493     if (created==0)
494         q << "NOW(),NOW(),'";
495     else
496         q << "FROM_UNIXTIME(" << created << "),NOW(),'";
497     q << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId() << "','"
498         << subject.first << "','" << context << "',";
499
500     if (m_storeAttributes && tokens.first)
501         q << "'" << tokens.first << "')";
502     else
503         q << "null)";
504
505     if (log->isDebugEnabled())
506         log->debug("SQL insert: %s", q.str().c_str());
507
508     // then add it to the database
509     MYSQL* mysql = getMYSQL();
510     if (mysql_query(mysql, q.str().c_str())) {
511         const char* err=mysql_error(mysql);
512         log->error("error inserting %s: %s", key, err);
513         if (isCorrupt(err) && repairTable(mysql,"state")) {
514             // Try again...
515             if (mysql_query(mysql, q.str().c_str())) {
516                 log->error("error inserting %s: %s", key, mysql_error(mysql));
517                 return E_FAIL;
518             }
519         }
520         else
521             throw E_FAIL;
522     }
523
524     return NOERROR;
525 }
526
527 HRESULT ShibMySQLCCache::onRead(
528     const char* key,
529     string& applicationId,
530     string& clientAddress,
531     string& providerId,
532     string& subject,
533     string& authnContext,
534     string& tokens,
535     int& majorVersion,
536     int& minorVersion,
537     time_t& created,
538     time_t& accessed
539     )
540 {
541 #ifdef _DEBUG
542     saml::NDC ndc("onRead");
543 #endif
544
545     log->debug("searching MySQL database...");
546
547     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
548
549     MYSQL* mysql = getMYSQL();
550     if (mysql_query(mysql, q.c_str())) {
551         const char* err=mysql_error(mysql);
552         log->error("error searching for %s: %s", key, err);
553         if (isCorrupt(err) && repairTable(mysql,"state")) {
554             if (mysql_query(mysql, q.c_str()))
555                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
556         }
557     }
558
559     MYSQL_RES* rows = mysql_store_result(mysql);
560
561     // Nope, doesn't exist.
562     if (!rows || mysql_num_rows(rows)==0) {
563         log->debug("not found in database");
564         if (rows)
565             mysql_free_result(rows);
566         return S_FALSE;
567     }
568
569     // Make sure we got 1 and only 1 row.
570     if (mysql_num_rows(rows) > 1) {
571         log->error("database select returned %d rows!", mysql_num_rows(rows));
572         mysql_free_result(rows);
573         return E_FAIL;
574     }
575
576     log->debug("session found, tranfering data back into memory");
577     
578     /* Columns in query:
579         0: application_id
580         1: ctime
581         2: atime
582         3: address
583         4: major
584         5: minor
585         6: provider
586         7: subject
587         8: authncontext
588         9: tokens
589      */
590
591     MYSQL_ROW row = mysql_fetch_row(rows);
592     applicationId=row[0];
593     created=atoi(row[1]);
594     accessed=atoi(row[2]);
595     clientAddress=row[3];
596     majorVersion=atoi(row[4]);
597     minorVersion=atoi(row[5]);
598     providerId=row[6];
599     subject=row[7];
600     authnContext=row[8];
601     if (row[9])
602         tokens=row[9];
603
604     // Free the results.
605     mysql_free_result(rows);
606
607     return NOERROR;
608 }
609
610 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
611 {
612 #ifdef _DEBUG
613     saml::NDC ndc("onRead");
614 #endif
615
616     log->debug("reading last access time from MySQL database");
617
618     string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
619
620     MYSQL* mysql = getMYSQL();
621     if (mysql_query(mysql, q.c_str())) {
622         const char* err=mysql_error(mysql);
623         log->error("error searching for %s: %s", key, err);
624         if (isCorrupt(err) && repairTable(mysql,"state")) {
625             if (mysql_query(mysql, q.c_str()))
626                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
627         }
628     }
629
630     MYSQL_RES* rows = mysql_store_result(mysql);
631
632     // Nope, doesn't exist.
633     if (!rows || mysql_num_rows(rows)==0) {
634         log->warn("session expected, but not found in database");
635         if (rows)
636             mysql_free_result(rows);
637         return S_FALSE;
638     }
639
640     // Make sure we got 1 and only 1 row.
641     if (mysql_num_rows(rows) != 1) {
642         log->error("database select returned %d rows!", mysql_num_rows(rows));
643         mysql_free_result(rows);
644         return E_FAIL;
645     }
646
647     MYSQL_ROW row = mysql_fetch_row(rows);
648     accessed=atoi(row[0]);
649
650     // Free the results.
651     mysql_free_result(rows);
652
653     return NOERROR;
654 }
655
656 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
657 {
658 #ifdef _DEBUG
659     saml::NDC ndc("onRead");
660 #endif
661
662     if (!m_storeAttributes)
663         return S_FALSE;
664
665     log->debug("reading cached tokens from MySQL database");
666
667     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
668
669     MYSQL* mysql = getMYSQL();
670     if (mysql_query(mysql, q.c_str())) {
671         const char* err=mysql_error(mysql);
672         log->error("error searching for %s: %s", key, err);
673         if (isCorrupt(err) && repairTable(mysql,"state")) {
674             if (mysql_query(mysql, q.c_str()))
675                 log->error("error retrying search for %s: %s", key, mysql_error(mysql));
676         }
677     }
678
679     MYSQL_RES* rows = mysql_store_result(mysql);
680
681     // Nope, doesn't exist.
682     if (!rows || mysql_num_rows(rows)==0) {
683         log->warn("session expected, but not found in database");
684         if (rows)
685             mysql_free_result(rows);
686         return S_FALSE;
687     }
688
689     // Make sure we got 1 and only 1 row.
690     if (mysql_num_rows(rows) != 1) {
691         log->error("database select returned %d rows!", mysql_num_rows(rows));
692         mysql_free_result(rows);
693         return E_FAIL;
694     }
695
696     MYSQL_ROW row = mysql_fetch_row(rows);
697     if (row[0])
698         tokens=row[0];
699
700     // Free the results.
701     mysql_free_result(rows);
702
703     return NOERROR;
704 }
705
706 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
707 {
708 #ifdef _DEBUG
709     saml::NDC ndc("onUpdate");
710 #endif
711
712     ostringstream q;
713     if (lastAccess>0)
714         q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
715     else if (tokens) {
716         if (!m_storeAttributes)
717             return S_FALSE;
718         q << "UPDATE state SET tokens=";
719         if (*tokens)
720             q << "'" << tokens << "'";
721         else
722             q << "null";
723     }
724     else {
725         log->warn("onUpdate called with nothing to do!");
726         return S_FALSE;
727     }
728  
729     q << " WHERE cookie='" << key << "'";
730
731     MYSQL* mysql = getMYSQL();
732     if (mysql_query(mysql, q.str().c_str())) {
733         const char* err=mysql_error(mysql);
734         log->error("error updating %s: %s", key, err);
735         if (isCorrupt(err) && repairTable(mysql,"state")) {
736             // Try again...
737             if (mysql_query(mysql, q.str().c_str())) {
738                 log->error("error updating %s: %s", key, mysql_error(mysql));
739                 return E_FAIL;
740             }
741         }
742         else
743             return E_FAIL;
744     }
745
746     return NOERROR;
747 }
748
749 HRESULT ShibMySQLCCache::onDelete(const char* key)
750 {
751 #ifdef _DEBUG
752     saml::NDC ndc("onDelete");
753 #endif
754
755     // Remove from the database
756     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
757     MYSQL* mysql = getMYSQL();
758     if (mysql_query(mysql, q.c_str())) {
759         const char* err=mysql_error(mysql);
760         log->error("error deleting entry %s: %s", key, err);
761         if (isCorrupt(err) && repairTable(mysql,"state")) {
762             // Try again...
763             if (mysql_query(mysql, q.c_str())) {
764                 log->error("error deleting entry %s: %s", key, mysql_error(mysql));
765                 return E_FAIL;
766             }
767         }
768         else
769             return E_FAIL;
770     }
771
772     return NOERROR;
773 }
774
775 void ShibMySQLCCache::cleanup()
776 {
777 #ifdef _DEBUG
778   saml::NDC ndc("cleanup");
779 #endif
780
781   Mutex* mutex = Mutex::create();
782
783   int rerun_timer = 0;
784   int timeout_life = 0;
785
786   // Load our configuration details...
787   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
788   if (tag && *tag)
789     rerun_timer = XMLString::parseInt(tag);
790
791   // search for 'mysql-cache-timeout' and then the regular cache timeout
792   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
793   if (tag && *tag)
794     timeout_life = XMLString::parseInt(tag);
795   else {
796       tag=m_root->getAttributeNS(NULL,cacheTimeout);
797       if (tag && *tag)
798         timeout_life = XMLString::parseInt(tag);
799   }
800   
801   if (rerun_timer <= 0)
802     rerun_timer = 300;          // rerun every 5 minutes
803
804   if (timeout_life <= 0)
805     timeout_life = 28800;       // timeout after 8 hours
806
807   mutex->lock();
808
809   MYSQL* mysql = getMYSQL();
810
811   log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
812
813   while (shutdown == false) {
814     shutdown_wait->timedwait(mutex, rerun_timer);
815
816     if (shutdown == true)
817       break;
818
819     // Find all the entries in the database that haven't been used
820     // recently In particular, find all entries that have not been
821     // accessed in 'timeout_life' seconds.
822     ostringstream q;
823     q << "DELETE FROM state WHERE " << "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
824
825     if (mysql_query(mysql, q.str().c_str())) {
826       const char* err=mysql_error(mysql);
827       log->error("error purging old records: %s", err);
828         if (isCorrupt(err) && repairTable(mysql,"state")) {
829           if (mysql_query(mysql, q.str().c_str()))
830             log->error("error re-purging old records: %s", mysql_error(mysql));
831         }
832     }
833   }
834
835   log->info("cleanup thread exiting...");
836
837   mutex->unlock();
838   delete mutex;
839   Thread::exit(NULL);
840 }
841
842 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
843 {
844   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
845
846   // First, let's block all signals
847   Thread::mask_all_signals();
848
849   // Now run the cleanup process.
850   cache->cleanup();
851   return NULL;
852 }
853
854 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
855 {
856 public:
857   MySQLReplayCache(const DOMElement* e);
858   virtual ~MySQLReplayCache() {}
859
860   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
861   bool check(const char* str, time_t expires);
862 };
863
864 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e) {}
865
866 bool MySQLReplayCache::check(const char* str, time_t expires)
867 {
868 #ifdef _DEBUG
869     saml::NDC ndc("check");
870 #endif
871   
872     // Remove expired entries
873     string q = string("DELETE FROM replay WHERE expires < NOW()");
874     MYSQL* mysql = getMYSQL();
875     if (mysql_query(mysql, q.c_str())) {
876         const char* err=mysql_error(mysql);
877         log->error("Error deleting expired entries: %s", err);
878         if (isCorrupt(err) && repairTable(mysql,"replay")) {
879             // Try again...
880             if (mysql_query(mysql, q.c_str()))
881                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
882         }
883     }
884   
885     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
886     if (mysql_query(mysql, q2.c_str())) {
887         const char* err=mysql_error(mysql);
888         log->error("Error searching for %s: %s", str, err);
889         if (isCorrupt(err) && repairTable(mysql,"replay")) {
890             if (mysql_query(mysql, q2.c_str())) {
891                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
892                 throw SAMLException("Replay cache failed, please inform application support staff.");
893             }
894         }
895         else
896             throw SAMLException("Replay cache failed, please inform application support staff.");
897     }
898
899     // Did we find it?
900     MYSQL_RES* rows = mysql_store_result(mysql);
901     if (rows && mysql_num_rows(rows)>0) {
902       mysql_free_result(rows);
903       return false;
904     }
905
906     ostringstream q3;
907     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
908
909     // then add it to the database
910     if (mysql_query(mysql, q3.str().c_str())) {
911         const char* err=mysql_error(mysql);
912         log->error("Error inserting %s: %s", str, err);
913         if (isCorrupt(err) && repairTable(mysql,"state")) {
914             // Try again...
915             if (mysql_query(mysql, q3.str().c_str())) {
916                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
917                 throw SAMLException("Replay cache failed, please inform application support staff.");
918             }
919         }
920         else
921             throw SAMLException("Replay cache failed, please inform application support staff.");
922     }
923     
924     return true;
925 }
926
927 /*************************************************************************
928  * The registration functions here...
929  */
930
931 IPlugIn* new_mysql_ccache(const DOMElement* e)
932 {
933     return new ShibMySQLCCache(e);
934 }
935
936 IPlugIn* new_mysql_replay(const DOMElement* e)
937 {
938     return new MySQLReplayCache(e);
939 }
940
941 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
942 {
943     // register this ccache type
944     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLReplayCacheType, &new_mysql_replay);
945     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLSessionCacheType, &new_mysql_ccache);
946     return 0;
947 }
948
949 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
950 {
951     // Shutdown MySQL
952     if (g_MySQLInitialized)
953         mysql_server_end();
954     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
955     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
956 }