Fix for shibd init script install.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
19  *
20  * Created by:  Derek Atkins <derek@ihtfp.com>
21  *
22  * $Id$
23  */
24
25 /* This file is loosely based off the Shibboleth Credential Cache.
26  * This plug-in is designed as a two-layer cache.  Layer 1, the
27  * long-term cache, stores data in a MySQL embedded database.
28  *
29  * Short-term data is cached in memory as SAML objects in the layer 2
30  * cache.
31  */
32
33 // eventually we might be able to support autoconf via cygwin...
34 #if defined (_MSC_VER) || defined(__BORLANDC__)
35 # include "config_win32.h"
36 #else
37 # include "config.h"
38 #endif
39
40 #ifdef WIN32
41 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
42 #else
43 # define SHIBMYSQL_EXPORTS
44 #endif
45
46 #ifdef HAVE_UNISTD_H
47 # include <unistd.h>
48 #endif
49
50 #include <shib-target/shib-target.h>
51 #include <shib/shib-threads.h>
52
53 #if defined(HAVE_LOG4SHIB)
54 # include <log4shib/Category.hh>
55 namespace shibmysql {
56     namespace logging = log4shib;
57 };
58 #elif defined(HAVE_LOG4CPP)
59 # include <log4cpp/Category.hh>
60 namespace shibmysql {
61     namespace logging = log4cpp;
62 };
63 #else
64 # error "Supported logging library not available."
65 #endif
66
67 #include <sstream>
68 #include <stdexcept>
69
70 #include <mysql.h>
71
72 // wanted to use MySQL codes for this, but can't seem to get back a 145
73 #define isCorrupt(s) strstr(s,"(errno: 145)")
74
75 #ifdef HAVE_LIBDMALLOCXX
76 #include <dmalloc.h>
77 #endif
78
79 using namespace std;
80 using namespace saml;
81 using namespace shibboleth;
82 using namespace shibtarget;
83 using namespace shibmysql::logging;
84
85 #define PLUGIN_VER_MAJOR 2
86 #define PLUGIN_VER_MINOR 0
87
88 #define STATE_TABLE \
89   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, " \
90   "application_id VARCHAR(255)," \
91   "ctime TIMESTAMP," \
92   "atime TIMESTAMP," \
93   "addr VARCHAR(128)," \
94   "profile INT," \
95   "provider VARCHAR(256)," \
96   "response_id VARCHAR(128)," \
97   "response TEXT," \
98   "statement TEXT)"
99
100 #define REPLAY_TABLE \
101   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
102   "expires TIMESTAMP, " \
103   "INDEX (expires))"
104
105 static const XMLCh Argument[] =
106 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
107 static const XMLCh cleanupInterval[] =
108 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
109   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
110 };
111 static const XMLCh cacheTimeout[] =
112 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
113 static const XMLCh mysqlTimeout[] =
114 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
115 static const XMLCh storeAttributes[] =
116 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
117
118 class MySQLBase
119 {
120 public:
121   MySQLBase(const DOMElement* e);
122   virtual ~MySQLBase();
123
124   void thread_init();
125   void thread_end() {}
126
127   MYSQL* getMYSQL() const;
128   bool repairTable(MYSQL*&, const char* table);
129
130   Category* log;
131
132 protected:
133   ThreadKey* m_mysql;
134   const DOMElement* m_root; // can only use this during initialization
135
136   bool initialized;
137
138   void createDatabase(MYSQL*, int major, int minor);
139   void upgradeDatabase(MYSQL*);
140   void getVersion(MYSQL*, int* major_p, int* minor_p);
141 };
142
143 // Forward declarations
144 static void mysqlInit(const DOMElement* e, Category& log);
145
146 extern "C" void shib_mysql_destroy_handle(void* data)
147 {
148   MYSQL* mysql = (MYSQL*) data;
149   if (mysql) mysql_close(mysql);
150 }
151
152 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
153 {
154 #ifdef _DEBUG
155   saml::NDC ndc("MySQLBase");
156 #endif
157   log = &(Category::getInstance("shibmysql.MySQLBase"));
158
159   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
160
161   initialized = false;
162   mysqlInit(e,*log);
163   thread_init();
164   initialized = true;
165 }
166
167 MySQLBase::~MySQLBase()
168 {
169   thread_end();
170
171   delete m_mysql;
172 }
173
174 MYSQL* MySQLBase::getMYSQL() const
175 {
176   return (MYSQL*)m_mysql->getData();
177 }
178
179 void MySQLBase::thread_init()
180 {
181 #ifdef _DEBUG
182   saml::NDC ndc("thread_init");
183 #endif
184
185   // Connect to the database
186   MYSQL* mysql = mysql_init(NULL);
187   if (!mysql) {
188     log->error("mysql_init failed");
189     mysql_close(mysql);
190     throw SAMLException("MySQLBase::thread_init(): mysql_init() failed");
191   }
192
193   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
194     if (initialized) {
195       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
196       mysql_close(mysql);
197       throw SAMLException("MySQLBase::thread_init(): mysql_real_connect() failed");
198     } else {
199       log->info("mysql_real_connect failed: %s.  Trying to create", mysql_error(mysql));
200
201       // This will throw an exception if it fails.
202       createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
203     }
204   }
205
206   int major = -1, minor = -1;
207   getVersion (mysql, &major, &minor);
208
209   // Make sure we've got the right version
210   if (major != PLUGIN_VER_MAJOR || minor != PLUGIN_VER_MINOR) {
211    
212     // If we're capable, try upgrading on the fly...
213     if (major == 0  || major == 1) {
214        upgradeDatabase(mysql);
215     }
216     else {
217         mysql_close(mysql);
218         log->crit("Unknown database version: %d.%d", major, minor);
219         throw SAMLException("MySQLBase::thread_init(): Unknown database version");
220     }
221   }
222
223   // We're all set.. Save off the handle for this thread.
224   m_mysql->setData(mysql);
225 }
226
227 bool MySQLBase::repairTable(MYSQL*& mysql, const char* table)
228 {
229   string q = string("REPAIR TABLE ") + table;
230   if (mysql_query(mysql, q.c_str())) {
231     log->error("Error repairing table %s: %s", table, mysql_error(mysql));
232     return false;
233   }
234
235   // seems we have to recycle the connection to get the thread to keep working
236   // other threads seem to be ok, but we should monitor that
237   mysql_close(mysql);
238   m_mysql->setData(NULL);
239   thread_init();
240   mysql=getMYSQL();
241   return true;
242 }
243
244 void MySQLBase::createDatabase(MYSQL* mysql, int major, int minor)
245 {
246   log->info("Creating database.");
247
248   MYSQL* ms = NULL;
249   try {
250     ms = mysql_init(NULL);
251     if (!ms) {
252       log->crit("mysql_init failed");
253       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_init failed");
254     }
255
256     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
257       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
258       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect failed");
259     }
260
261     if (mysql_query(ms, "CREATE DATABASE shar")) {
262       log->crit("cannot create shar database: %s", mysql_error(ms));
263       throw SAMLException("ShibMySQLCCache::createDatabase(): create db cmd failed");
264     }
265
266     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
267       log->crit("cannot open SHAR database");
268       throw SAMLException("ShibMySQLCCache::createDatabase(): mysql_real_connect to plugin db failed");
269     }
270
271     mysql_close(ms);
272     
273   }
274   catch (SAMLException&) {
275     if (ms)
276       mysql_close(ms);
277     mysql_close(mysql);
278     throw;
279   }
280
281   // Now create the tables if they don't exist
282   log->info("Creating database tables");
283
284   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)")) {
285     log->error ("Error creating version: %s", mysql_error(mysql));
286     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
287   }
288
289   if (mysql_query(mysql,STATE_TABLE)) {
290     log->error ("Error creating state table: %s", mysql_error(mysql));
291     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
292   }
293
294   if (mysql_query(mysql,REPLAY_TABLE)) {
295     log->error ("Error creating replay table: %s", mysql_error(mysql));
296     throw SAMLException("ShibMySQLCCache::createDatabase(): create table cmd failed");
297   }
298
299   ostringstream q;
300   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
301   if (mysql_query(mysql, q.str().c_str())) {
302     log->error ("Error setting version: %s", mysql_error(mysql));
303     throw SAMLException("ShibMySQLCCache::createDatabase(): version insert failed");
304   }
305 }
306
307 void MySQLBase::upgradeDatabase(MYSQL* mysql)
308 {
309     if (mysql_query(mysql, "DROP TABLE state")) {
310         log->error("Error dropping old session state table: %s", mysql_error(mysql));
311     }
312
313     if (mysql_query(mysql,STATE_TABLE)) {
314         log->error ("Error creating state table: %s", mysql_error(mysql));
315         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating state table");
316     }
317
318     if (mysql_query(mysql,REPLAY_TABLE)) {
319         log->error ("Error creating replay table: %s", mysql_error(mysql));
320         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error creating replay table");
321     }
322
323     ostringstream q;
324     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
325     if (mysql_query(mysql, q.str().c_str())) {
326         log->error ("Error updating version: %s", mysql_error(mysql));
327         throw SAMLException("ShibMySQLCCache::upgradeDatabase(): error updating version");
328     }
329 }
330
331 void MySQLBase::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
332 {
333   // grab the version number from the database
334   if (mysql_query(mysql, "SELECT * FROM version"))
335     log->error ("Error reading version: %s", mysql_error(mysql));
336
337   MYSQL_RES* rows = mysql_store_result(mysql);
338   if (rows) {
339     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
340       MYSQL_ROW row = mysql_fetch_row(rows);
341
342       int major = row[0] ? atoi(row[0]) : -1;
343       int minor = row[1] ? atoi(row[1]) : -1;
344       log->debug("opening database version %d.%d", major, minor);
345       
346       mysql_free_result (rows);
347
348       *major_p = major;
349       *minor_p = minor;
350       return;
351
352     } else {
353       // Wrong number of rows or wrong number of fields...
354
355       log->crit("Houston, we've got a problem with the database...");
356       mysql_free_result (rows);
357       throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
358     }
359   }
360   log->crit("MySQL Read Failed in version verificatoin");
361   throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
362 }
363
364 static void mysqlInit(const DOMElement* e, Category& log)
365 {
366   static bool done = false;
367   if (done) {
368     log.info("MySQL embedded server already initialized");
369     return;
370   }
371   log.info("initializing MySQL embedded server");
372
373   // Setup the argument array
374   vector<string> arg_array;
375   arg_array.push_back("shibboleth");
376
377   // grab any MySQL parameters from the config file
378   e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
379   while (e) {
380       auto_ptr_char arg(e->getFirstChild()->getNodeValue());
381       if (arg.get())
382           arg_array.push_back(arg.get());
383       e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
384   }
385
386   // Compute the argument array
387   int arg_count = arg_array.size();
388   const char** args=new const char*[arg_count];
389   for (int i = 0; i < arg_count; i++)
390     args[i] = arg_array[i].c_str();
391
392   // Initialize MySQL with the arguments
393   mysql_server_init(arg_count, (char **)args, NULL);
394
395   delete[] args;
396   done = true;
397 }  
398
399 class ShibMySQLCCache;
400 class ShibMySQLCCacheEntry : public ISessionCacheEntry
401 {
402 public:
403   ShibMySQLCCacheEntry(const char* key, ISessionCacheEntry* entry, ShibMySQLCCache* cache)
404     : m_cacheEntry(entry), m_key(key), m_cache(cache), m_responseId(NULL) {}
405   ~ShibMySQLCCacheEntry() {if (m_responseId) XMLString::release(&m_responseId);}
406
407   virtual void lock() {}
408   virtual void unlock() { m_cacheEntry->unlock(); delete this; }
409   virtual bool isValid(time_t lifetime, time_t timeout) const;
410   virtual const char* getClientAddress() const { return m_cacheEntry->getClientAddress(); }
411   virtual ShibProfile getProfile() const { return m_cacheEntry->getProfile(); }
412   virtual const char* getProviderId() const { return m_cacheEntry->getProviderId(); }
413   virtual const SAMLAuthenticationStatement* getAuthnStatement() const { return m_cacheEntry->getAuthnStatement(); }
414   virtual CachedResponse getResponse();
415
416 private:
417   bool touch() const;
418
419   ShibMySQLCCache* m_cache;
420   ISessionCacheEntry* m_cacheEntry;
421   string m_key;
422   XMLCh* m_responseId;
423 };
424
425 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache
426 {
427 public:
428   ShibMySQLCCache(const DOMElement* e);
429   virtual ~ShibMySQLCCache();
430
431   virtual void thread_init() {MySQLBase::thread_init();}
432   virtual void thread_end() {MySQLBase::thread_end();}
433
434   virtual string generateKey() const {return m_cache->generateKey();}
435   virtual ISessionCacheEntry* find(const char* key, const IApplication* application);
436   virtual void insert(
437     const char* key,
438     const IApplication* application,
439     const char* client_addr,
440     ShibProfile profile,
441     const char* providerId,
442     saml::SAMLAuthenticationStatement* s,
443     saml::SAMLResponse* r=NULL,
444     const shibboleth::IRoleDescriptor* source=NULL,
445     time_t created=0,
446     time_t accessed=0
447     );
448   virtual void remove(const char* key);
449
450   virtual void cleanup();
451
452   bool m_storeAttributes;
453
454 private:
455   ISessionCache* m_cache;
456   CondWait* shutdown_wait;
457   bool shutdown;
458   Thread* cleanup_thread;
459
460   static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
461 };
462
463 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
464 {
465 #ifdef _DEBUG
466   saml::NDC ndc("ShibMySQLCCache");
467 #endif
468
469   log = &(Category::getInstance("shibmysql.SessionCache"));
470
471   shutdown_wait = CondWait::create();
472   shutdown = false;
473
474   m_cache = dynamic_cast<ISessionCache*>(
475       SAMLConfig::getConfig().getPlugMgr().newPlugin(
476         "edu.internet2.middleware.shibboleth.sp.provider.MemorySessionCacheProvider", e
477         )
478     );
479     
480   // Load our configuration details...
481   const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
482   if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
483     m_storeAttributes=true;
484
485   // Initialize the cleanup thread
486   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
487 }
488
489 ShibMySQLCCache::~ShibMySQLCCache()
490 {
491   shutdown = true;
492   shutdown_wait->signal();
493   cleanup_thread->join(NULL);
494
495   delete m_cache;
496 }
497
498 ISessionCacheEntry* ShibMySQLCCache::find(const char* key, const IApplication* application)
499 {
500 #ifdef _DEBUG
501   saml::NDC ndc("find");
502 #endif
503
504   ISessionCacheEntry* res = m_cache->find(key, application);
505   if (!res) {
506
507     log->debug("Looking in database...");
508
509     // nothing cached; see if this exists in the database
510     string q = string("SELECT application_id,UNIX_TIMESTAMP(ctime),UNIX_TIMESTAMP(atime),addr,profile,provider,statement,response FROM state WHERE cookie='") + key + "' LIMIT 1";
511
512     MYSQL* mysql = getMYSQL();
513     if (mysql_query(mysql, q.c_str())) {
514       const char* err=mysql_error(mysql);
515       log->error("Error searching for %s: %s", key, err);
516       if (isCorrupt(err) && repairTable(mysql,"state")) {
517         if (mysql_query(mysql, q.c_str()))
518           log->error("Error retrying search for %s: %s", key, mysql_error(mysql));
519       }
520     }
521
522     MYSQL_RES* rows = mysql_store_result(mysql);
523
524     // Nope, doesn't exist.
525     if (!rows)
526       return NULL;
527
528     // Make sure we got 1 and only 1 rows.
529     if (mysql_num_rows(rows) != 1) {
530       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
531       mysql_free_result(rows);
532       return NULL;
533     }
534
535     log->debug("Match found.  Parsing...");
536     
537     /* Columns in query:
538         0: application_id
539         1: ctime
540         2: atime
541         3: address
542         4: profile
543         5: provider
544         6: statement
545         7: response
546      */
547
548     // Pull apart the row and process the results
549     MYSQL_ROW row = mysql_fetch_row(rows);
550     if (strcmp(application->getId(),row[0])) {
551         log->crit("An application (%s) attempted to access another application's (%s) session!", application->getId(), row[0]);
552         mysql_free_result(rows);
553         return NULL;
554     }
555
556     Metadata m(application->getMetadataProviders());
557     const IEntityDescriptor* provider=m.lookup(row[5]);
558     if (!provider) {
559         log->crit("no metadata found for identity provider (%s) responsible for the session.", row[5]);
560         mysql_free_result(rows);
561         return NULL;
562     }
563
564     SAMLAuthenticationStatement* s=NULL;
565     SAMLResponse* r=NULL;
566     ShibProfile profile=static_cast<ShibProfile>(atoi(row[4]));
567     const IRoleDescriptor* role=NULL;
568     if (profile==SAML11_POST || profile==SAML11_ARTIFACT)
569         role=provider->getIDPSSODescriptor(saml::XML::SAML11_PROTOCOL_ENUM);
570     else if (profile==SAML10_POST || profile==SAML10_ARTIFACT)
571         role=provider->getIDPSSODescriptor(saml::XML::SAML10_PROTOCOL_ENUM);
572     if (!role) {
573         log->crit(
574             "no matching IdP role for profile (%s) found for identity provider (%s) responsible for the session.", row[4], row[5]
575             );
576         mysql_free_result(rows);
577         return NULL;
578     }
579
580     // Try to parse the SAML data
581     try {
582         istringstream istr(row[6]);
583         s = new SAMLAuthenticationStatement(istr);
584         if (row[7]) {
585             istringstream istr2(row[7]);
586             r = new SAMLResponse(istr2);
587         }
588     }
589     catch (SAMLException& e) {
590         log->error(string("caught SAML exception while loading objects from SQL record: ") + e.what());
591         delete s;
592         delete r;
593         mysql_free_result(rows);
594         return NULL;
595     }
596 #ifndef _DEBUG
597     catch (...) {
598         log->error("caught unknown exception while loading objects from SQL record");
599         delete s;
600         delete r;
601         mysql_free_result(rows);
602         return NULL;
603     }
604 #endif
605
606     // Insert it into the memory cache
607     m_cache->insert(
608         key,
609         application,
610         row[3],
611         profile,
612         row[5],
613         s,
614         r,
615         role,
616         atoi(row[1]),
617         atoi(row[2])
618         );
619
620     // Free the results, and then re-run the 'find' query
621     mysql_free_result(rows);
622     res = m_cache->find(key,application);
623     if (!res)
624       return NULL;
625   }
626
627   return new ShibMySQLCCacheEntry(key, res, this);
628 }
629
630 void ShibMySQLCCache::insert(
631     const char* key,
632     const IApplication* application,
633     const char* client_addr,
634     ShibProfile profile,
635     const char* providerId,
636     saml::SAMLAuthenticationStatement* s,
637     saml::SAMLResponse* r,
638     const shibboleth::IRoleDescriptor* source,
639     time_t created,
640     time_t accessed
641     )
642 {
643 #ifdef _DEBUG
644   saml::NDC ndc("insert");
645 #endif
646   
647   ostringstream q;
648   q << "INSERT INTO state VALUES('" << key << "','" << application->getId() << "',";
649   if (created==0)
650     q << "NOW(),";
651   else
652     q << "FROM_UNIXTIME(" << created << "),";
653   if (accessed==0)
654     q << "NOW(),'";
655   else
656     q << "FROM_UNIXTIME(" << accessed << "),'";
657   q << client_addr << "'," << profile << ",'" << providerId << "',";
658   if (m_storeAttributes && r) {
659     auto_ptr_char id(r->getId());
660     q << "'" << id.get() << "','" << *r << "','";
661   }
662   else
663     q << "null,null,'";
664   q << *s << "')";
665
666   log->debug("Query: %s", q.str().c_str());
667
668   // then add it to the database
669   MYSQL* mysql = getMYSQL();
670   if (mysql_query(mysql, q.str().c_str())) {
671     const char* err=mysql_error(mysql);
672     log->error("Error inserting %s: %s", key, err);
673     if (isCorrupt(err) && repairTable(mysql,"state")) {
674         // Try again...
675         if (mysql_query(mysql, q.str().c_str())) {
676           log->error("Error inserting %s: %s", key, mysql_error(mysql));
677           throw SAMLException("ShibMySQLCCache::insert(): insertion failed");
678         }
679     }
680     else
681         throw SAMLException("ShibMySQLCCache::insert(): insertion failed");
682   }
683
684   // Add it to the memory cache
685   m_cache->insert(key, application, client_addr, profile, providerId, s, r, source, created, accessed);
686 }
687
688 void ShibMySQLCCache::remove(const char* key)
689 {
690 #ifdef _DEBUG
691   saml::NDC ndc("remove");
692 #endif
693
694   // Remove the cached version
695   m_cache->remove(key);
696
697   // Remove from the database
698   string q = string("DELETE FROM state WHERE cookie='") + key + "'";
699   MYSQL* mysql = getMYSQL();
700   if (mysql_query(mysql, q.c_str())) {
701     const char* err=mysql_error(mysql);
702     log->error("Error deleting entry %s: %s", key, err);
703     if (isCorrupt(err) && repairTable(mysql,"state")) {
704         // Try again...
705         if (mysql_query(mysql, q.c_str()))
706           log->error("Error deleting entry %s: %s", key, mysql_error(mysql));
707     }
708   }
709 }
710
711 void ShibMySQLCCache::cleanup()
712 {
713 #ifdef _DEBUG
714   saml::NDC ndc("cleanup");
715 #endif
716
717   Mutex* mutex = Mutex::create();
718   MySQLBase::thread_init();
719
720   int rerun_timer = 0;
721   int timeout_life = 0;
722
723   // Load our configuration details...
724   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
725   if (tag && *tag)
726     rerun_timer = XMLString::parseInt(tag);
727
728   // search for 'mysql-cache-timeout' and then the regular cache timeout
729   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
730   if (tag && *tag)
731     timeout_life = XMLString::parseInt(tag);
732   else {
733       tag=m_root->getAttributeNS(NULL,cacheTimeout);
734       if (tag && *tag)
735         timeout_life = XMLString::parseInt(tag);
736   }
737   
738   if (rerun_timer <= 0)
739     rerun_timer = 300;          // rerun every 5 minutes
740
741   if (timeout_life <= 0)
742     timeout_life = 28800;       // timeout after 8 hours
743
744   mutex->lock();
745
746   MYSQL* mysql = getMYSQL();
747
748   while (shutdown == false) {
749     shutdown_wait->timedwait(mutex, rerun_timer);
750
751     if (shutdown == true)
752       break;
753
754     // Find all the entries in the database that haven't been used
755     // recently In particular, find all entries that have not been
756     // accessed in 'timeout_life' seconds.
757     ostringstream q;
758     q << "SELECT cookie FROM state WHERE " <<
759       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
760
761     if (mysql_query(mysql, q.str().c_str())) {
762       const char* err=mysql_error(mysql);
763       log->error("Error searching for old items: %s", err);
764         if (isCorrupt(err) && repairTable(mysql,"state")) {
765           if (mysql_query(mysql, q.str().c_str()))
766             log->error("Error re-searching for old items: %s", mysql_error(mysql));
767         }
768     }
769
770     MYSQL_RES* rows = mysql_store_result(mysql);
771     if (!rows)
772       continue;
773
774     if (mysql_num_fields(rows) != 1) {
775       log->error("Wrong number of columns, 1 != %d", mysql_num_fields(rows));
776       mysql_free_result(rows);
777       continue;
778     }
779
780     // For each row, remove the entry from the database.
781     MYSQL_ROW row;
782     while ((row = mysql_fetch_row(rows)) != NULL)
783       remove(row[0]);
784
785     mysql_free_result(rows);
786   }
787
788   log->info("cleanup thread exiting...");
789
790   mutex->unlock();
791   delete mutex;
792   MySQLBase::thread_end();
793   Thread::exit(NULL);
794 }
795
796 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
797 {
798   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
799
800   // First, let's block all signals
801   Thread::mask_all_signals();
802
803   // Now run the cleanup process.
804   cache->cleanup();
805   return NULL;
806 }
807
808 /*************************************************************************
809  * The CCacheEntry here is mostly a wrapper around the "memory"
810  * cacheentry provided by shibboleth.  The only difference is that we
811  * intercept isSessionValid() so that we can "touch()" the
812  * database if the session is still valid and getResponse() so we can
813  * store the data if we need to.
814  */
815
816 bool ShibMySQLCCacheEntry::isValid(time_t lifetime, time_t timeout) const
817 {
818   bool res = m_cacheEntry->isValid(lifetime, timeout);
819   if (res == true)
820     res = touch();
821   return res;
822 }
823
824 bool ShibMySQLCCacheEntry::touch() const
825 {
826   string q=string("UPDATE state SET atime=NOW() WHERE cookie='") + m_key + "'";
827
828   MYSQL* mysql = m_cache->getMYSQL();
829   if (mysql_query(mysql, q.c_str())) {
830     m_cache->log->info("Error updating timestamp on %s: %s", m_key.c_str(), mysql_error(mysql));
831     return false;
832   }
833   return true;
834 }
835
836 ISessionCacheEntry::CachedResponse ShibMySQLCCacheEntry::getResponse()
837 {
838     // Let the memory cache do the work first.
839     // If we're hands off, just pass it back.
840     if (!m_cache->m_storeAttributes)
841         return m_cacheEntry->getResponse();
842     
843     CachedResponse r=m_cacheEntry->getResponse();
844     if (r.empty()) return r;
845     
846     // Load the key from state if needed.
847     if (!m_responseId) {
848         string qselect=string("SELECT response_id from state WHERE cookie='") + m_key + "' LIMIT 1";
849         MYSQL* mysql = m_cache->getMYSQL();
850         if (mysql_query(mysql, qselect.c_str())) {
851             const char* err=mysql_error(mysql);
852             m_cache->log->error("error accessing response ID for %s: %s", m_key.c_str(), err);
853             if (isCorrupt(err) && m_cache->repairTable(mysql,"state")) {
854                 // Try again...
855                 if (mysql_query(mysql, qselect.c_str())) {
856                     m_cache->log->error("error accessing response ID for %s: %s", m_key.c_str(), mysql_error(mysql));
857                     return r;
858                 }
859             }
860         }
861         MYSQL_RES* rows = mysql_store_result(mysql);
862     
863         // Make sure we got 1 and only 1 row.
864         if (!rows || mysql_num_rows(rows) != 1) {
865             m_cache->log->error("select returned wrong number of rows");
866             if (rows) mysql_free_result(rows);
867             return r;
868         }
869         
870         MYSQL_ROW row=mysql_fetch_row(rows);
871         if (row)
872             m_responseId=XMLString::transcode(row[0]);
873         mysql_free_result(rows);
874     }
875     
876     // Compare it with what we have now.
877     if (m_responseId && !XMLString::compareString(m_responseId,r.unfiltered->getId()))
878         return r;
879     
880     // No match, so we need to update our copy.
881     if (m_responseId) XMLString::release(&m_responseId);
882     m_responseId = XMLString::replicate(r.unfiltered->getId());
883     auto_ptr_char id(m_responseId);
884
885     ostringstream q;
886     q << "UPDATE state SET response_id='" << id.get() << "',response='" << *r.unfiltered << "' WHERE cookie='" << m_key << "'";
887     m_cache->log->debug("Query: %s", q.str().c_str());
888
889     MYSQL* mysql = m_cache->getMYSQL();
890     if (mysql_query(mysql, q.str().c_str())) {
891         const char* err=mysql_error(mysql);
892         m_cache->log->error("Error updating response for %s: %s", m_key.c_str(), err);
893         if (isCorrupt(err) && m_cache->repairTable(mysql,"state")) {
894             // Try again...
895             if (mysql_query(mysql, q.str().c_str()))
896               m_cache->log->error("Error updating response for %s: %s", m_key.c_str(), mysql_error(mysql));
897         }
898     }
899     
900     return r;
901 }
902
903 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
904 {
905 public:
906   MySQLReplayCache(const DOMElement* e);
907   virtual ~MySQLReplayCache() {}
908
909   void thread_init() {MySQLBase::thread_init();}
910   void thread_end() {MySQLBase::thread_end();}
911
912   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
913   bool check(const char* str, time_t expires);
914 };
915
916 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
917 {
918 #ifdef _DEBUG
919   saml::NDC ndc("MySQLReplayCache");
920 #endif
921
922   log = &(Category::getInstance("shibmysql.ReplayCache"));
923 }
924
925 bool MySQLReplayCache::check(const char* str, time_t expires)
926 {
927 #ifdef _DEBUG
928     saml::NDC ndc("check");
929 #endif
930   
931     // Remove expired entries
932     string q = string("DELETE FROM replay WHERE expires < NOW()");
933     MYSQL* mysql = getMYSQL();
934     if (mysql_query(mysql, q.c_str())) {
935         const char* err=mysql_error(mysql);
936         log->error("Error deleting expired entries: %s", err);
937         if (isCorrupt(err) && repairTable(mysql,"replay")) {
938             // Try again...
939             if (mysql_query(mysql, q.c_str()))
940                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
941         }
942     }
943   
944     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
945     if (mysql_query(mysql, q2.c_str())) {
946         const char* err=mysql_error(mysql);
947         log->error("Error searching for %s: %s", str, err);
948         if (isCorrupt(err) && repairTable(mysql,"replay")) {
949             if (mysql_query(mysql, q2.c_str())) {
950                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
951                 throw SAMLException("Replay cache failed, please inform application support staff.");
952             }
953         }
954         else
955             throw SAMLException("Replay cache failed, please inform application support staff.");
956     }
957
958     // Did we find it?
959     MYSQL_RES* rows = mysql_store_result(mysql);
960     if (rows && mysql_num_rows(rows)>0) {
961       mysql_free_result(rows);
962       return false;
963     }
964
965     ostringstream q3;
966     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
967
968     // then add it to the database
969     if (mysql_query(mysql, q3.str().c_str())) {
970         const char* err=mysql_error(mysql);
971         log->error("Error inserting %s: %s", str, err);
972         if (isCorrupt(err) && repairTable(mysql,"state")) {
973             // Try again...
974             if (mysql_query(mysql, q3.str().c_str())) {
975                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
976                 throw SAMLException("Replay cache failed, please inform application support staff.");
977             }
978         }
979         else
980             throw SAMLException("Replay cache failed, please inform application support staff.");
981     }
982     
983     return true;
984 }
985
986 /*************************************************************************
987  * The registration functions here...
988  */
989
990 IPlugIn* new_mysql_ccache(const DOMElement* e)
991 {
992   return new ShibMySQLCCache(e);
993 }
994
995 IPlugIn* new_mysql_replay(const DOMElement* e)
996 {
997   return new MySQLReplayCache(e);
998 }
999
1000 #define REPLAYPLUGINTYPE "edu.internet2.middleware.shibboleth.sp.provider.MySQLReplayCacheProvider"
1001 #define SESSIONPLUGINTYPE "edu.internet2.middleware.shibboleth.sp.provider.MySQLSessionCacheProvider"
1002
1003 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
1004 {
1005   // register this ccache type
1006   SAMLConfig::getConfig().getPlugMgr().regFactory(REPLAYPLUGINTYPE, &new_mysql_replay);
1007   SAMLConfig::getConfig().getPlugMgr().regFactory(SESSIONPLUGINTYPE, &new_mysql_ccache);
1008   return 0;
1009 }
1010
1011 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
1012 {
1013   // Shutdown MySQL
1014   mysql_server_end();
1015   SAMLConfig::getConfig().getPlugMgr().unregFactory(REPLAYPLUGINTYPE);
1016   SAMLConfig::getConfig().getPlugMgr().unregFactory(SESSIONPLUGINTYPE);
1017 }