Updated version, try and replace state table from older version.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
3  *
4  * Created by:  Derek Atkins <derek@ihtfp.com>
5  *
6  * $Id$
7  */
8
9 /* This file is loosely based off the Shibboleth Credential Cache.
10  * This plug-in is designed as a two-layer cache.  Layer 1, the
11  * long-term cache, stores data in a MySQL embedded database.  The
12  * data stored in layer 1 is only the session id (cookie), the
13  * "posted" SAML statement (expanded into an XML string), and usage
14  * timestamps.
15  *
16  * Short-term data is cached in memory as SAML objects in the layer 2
17  * cache.  Data like Attribute Authority assertions are stored in
18  * the layer 2 cache.
19  */
20
21 // eventually we might be able to support autoconf via cygwin...
22 #if defined (_MSC_VER) || defined(__BORLANDC__)
23 # include "config_win32.h"
24 #else
25 # include "config.h"
26 #endif
27
28 #ifdef WIN32
29 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
30 #else
31 # define SHIBMYSQL_EXPORTS
32 #endif
33
34 #ifdef HAVE_UNISTD_H
35 # include <unistd.h>
36 #endif
37
38 #include <shib-target/shib-target.h>
39 #include <shib-target/ccache-utils.h>
40 #include <shib/shib-threads.h>
41 #include <log4cpp/Category.hh>
42
43 #include <sstream>
44 #include <stdexcept>
45
46 #include <mysql.h>
47
48 #ifdef HAVE_LIBDMALLOCXX
49 #include <dmalloc.h>
50 #endif
51
52 using namespace std;
53 using namespace saml;
54 using namespace shibboleth;
55 using namespace shibtarget;
56
57 #define PLUGIN_VER_MAJOR 1
58 #define PLUGIN_VER_MINOR 0
59
60 class ShibMySQLCCache;
61 class ShibMySQLCCacheEntry : public CCacheEntry
62 {
63 public:
64   ShibMySQLCCacheEntry(const char *, CCacheEntry*, ShibMySQLCCache*);
65   ~ShibMySQLCCacheEntry() {}
66
67   virtual Iterator<SAMLAssertion*> getAssertions() { return m_cacheEntry->getAssertions(); }
68   virtual void preFetch(int prefetch_window) { m_cacheEntry->preFetch(prefetch_window); }
69   virtual bool isSessionValid(time_t lifetime, time_t timeout);
70   virtual const char* getClientAddress() { return m_cacheEntry->getClientAddress(); }
71   virtual const char* getSerializedStatement() { return m_cacheEntry->getSerializedStatement(); }
72   virtual const SAMLAuthenticationStatement* getStatement() { return m_cacheEntry->getStatement(); }
73   virtual void release() { m_cacheEntry->release(); delete this; }
74
75 private:
76   bool touch();
77
78   ShibMySQLCCache* m_cache;
79   CCacheEntry *m_cacheEntry;
80   string m_key;
81 };
82
83 class ShibMySQLCCache : public CCache
84 {
85 public:
86   ShibMySQLCCache();
87   virtual ~ShibMySQLCCache();
88
89   virtual CCacheEntry* find(const char* key);
90   virtual void insert(
91         const char* key,
92         const char* application_id,
93         saml::SAMLAuthenticationStatement *s,
94         const char *client_addr,
95         saml::SAMLResponse* r=NULL);
96   virtual void remove(const char* key);
97   virtual void thread_init();
98
99   void  cleanup();
100   MYSQL* getMYSQL();
101
102   log4cpp::Category* log;
103
104 private:
105   CCache* m_cache;
106   ThreadKey* m_mysql;
107
108   static void*  cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
109   CondWait* shutdown_wait;
110   bool shutdown;
111   Thread* cleanup_thread;
112
113   bool initialized;
114
115   void createDatabase(MYSQL*, int major, int minor);
116   void upgradeDatabase(MYSQL*);
117   void getVersion(MYSQL*, int* major_p, int* minor_p);
118   void mysqlInit(void);
119 };
120
121 // Forward declarations
122 extern "C" void shib_mysql_destroy_handle(void* data);
123
124 /*************************************************************************
125  * The CCache here talks to a MySQL database.  The database stores
126  * three items: the cookie (session key index), the lastAccess time, and
127  * the SAMLAuthenticationStatement.  All other access is performed
128  * through the memory cache provided by shibboleth.
129  */
130
131 MYSQL* ShibMySQLCCache::getMYSQL()
132 {
133   void* data = m_mysql->getData();
134   return (MYSQL*)data;
135 }
136
137 void ShibMySQLCCache::thread_init()
138 {
139   saml::NDC ndc("thread_init");
140
141   // Connect to the database
142   MYSQL* mysql = mysql_init(NULL);
143   if (!mysql) {
144     log->error("mysql_init failed");
145     mysql_close(mysql);
146     throw runtime_error("mysql_init()");
147   }
148
149   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
150     if (initialized) {
151       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
152       throw runtime_error("mysql_real_connect");
153
154     } else {
155       log->info("mysql_real_connect failed: %s.  Trying to create",
156                 mysql_error(mysql));
157
158       // This will throw a runtime error if it fails.
159       createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
160     }
161   }
162
163   int major = -1, minor = -1;
164   getVersion (mysql, &major, &minor);
165
166   // Make sure we've got the right version
167   if (major != PLUGIN_VER_MAJOR && minor != PLUGIN_VER_MINOR) {
168    
169     // If we're capable, try upgrading on the fly...
170     if (major == 0 && minor == 0) {
171        upgradeDatabase(mysql);
172     }
173     else {
174         log->crit("Invalid database version: %d.%d", major, minor);
175         throw runtime_error("Invalid Database version");
176     }
177   }
178
179   // We're all set.. Save off the handle for this thread.
180   m_mysql->setData((void*)mysql);
181 }
182
183 ShibMySQLCCache::ShibMySQLCCache()
184 {
185   saml::NDC ndc("shibmysql::ShibMySQLCCache");
186
187   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
188   string ctx = "shibmysql::ShibMySQLCCache";
189   log = &(log4cpp::Category::getInstance(ctx));
190
191   initialized = false;
192   mysqlInit();
193   thread_init();
194   initialized = true;
195
196   m_cache = CCache::getInstance("memory");
197
198   // Initialize the cleanup thread
199   shutdown_wait = CondWait::create();
200   shutdown = false;
201   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
202 }
203
204 ShibMySQLCCache::~ShibMySQLCCache()
205 {
206   shutdown = true;
207   shutdown_wait->signal();
208   cleanup_thread->join(NULL);
209
210   delete m_cache;
211   delete m_mysql;
212
213   // Shutdown MySQL
214   mysql_server_end();
215 }
216
217 CCacheEntry* ShibMySQLCCache::find(const char* key)
218 {
219   saml::NDC ndc("mysql::find");
220   CCacheEntry* res = m_cache->find(key);
221   if (!res) {
222
223     log->debug("Looking in database...");
224
225     // nothing cached; see if this exists in the database
226     string q = string("SELECT application_id,addr,statement FROM state WHERE cookie='") + key + "' LIMIT 1";
227
228     MYSQL_RES* rows;
229     MYSQL* mysql = getMYSQL();
230     if (mysql_query(mysql, q.c_str()))
231       log->error("Error searching for %s: %s", key, mysql_error(mysql));
232
233     rows = mysql_store_result(mysql);
234
235     // Nope, doesn't exist.
236     if (!rows)
237       return NULL;
238
239     // Make sure we got 1 and only 1 rows.
240     if (mysql_num_rows(rows) != 1) {
241       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
242       mysql_free_result(rows);
243       return NULL;
244     }
245
246     log->debug("Match found.  Parsing...");
247     // Pull apart the row and process the results
248     MYSQL_ROW row = mysql_fetch_row(rows);
249     istringstream str(row[2]);
250     SAMLAuthenticationStatement *s = NULL;
251
252     // Try to parse the AuthStatement
253     try {
254       s = new SAMLAuthenticationStatement(str);
255     } catch (...) {
256       mysql_free_result(rows);
257       throw;
258     }
259
260     // Insert it into the memory cache
261     if (s)
262       m_cache->insert(key, row[0], s, row[1]);
263
264     // Free the results, and then re-run the 'find' query
265     mysql_free_result(rows);
266     res = m_cache->find(key);
267     if (!res)
268       return NULL;
269   }
270
271   return new ShibMySQLCCacheEntry(key, res, this);
272 }
273
274 void ShibMySQLCCache::insert(
275     const char* key,
276     const char* application_id,
277     saml::SAMLAuthenticationStatement *s,
278     const char *client_addr,
279     saml::SAMLResponse* r)
280 {
281   saml::NDC ndc("mysql::insert");
282   ostringstream os;
283   os << *s;
284
285   string q = string("INSERT INTO state VALUES('") + key + "','" + application_id + "',NOW(),'" + client_addr + "','" + os.str() + "')";
286
287   log->debug("Query: %s", q.c_str());
288
289   // Add it to the memory cache
290   m_cache->insert(key, application_id, s, client_addr, r);
291
292   // then add it to the database
293   MYSQL* mysql = getMYSQL();
294   if (mysql_query(mysql, q.c_str()))
295     log->error("Error inserting %s: %s", key, mysql_error(mysql));
296 }
297
298 void ShibMySQLCCache::remove(const char* key)
299 {
300   saml::NDC ndc("mysql::remove");
301
302   // Remove the cached version
303   m_cache->remove(key);
304
305   // Remove from the database
306   string q = string("DELETE FROM state WHERE cookie='") + key + "'";
307   MYSQL* mysql = getMYSQL();
308   if (mysql_query(mysql, q.c_str()))
309     log->info("Error deleting entry %s: %s", key, mysql_error(mysql));
310 }
311
312 void ShibMySQLCCache::cleanup()
313 {
314   Mutex* mutex = Mutex::create();
315   saml::NDC ndc("mysql::cleanup");
316
317   thread_init();
318
319   ShibTargetConfig& config = ShibTargetConfig::getConfig();
320   ShibINI& ini = config.getINI();
321
322   int rerun_timer = 0;
323   int timeout_life = 0;
324
325   string tag;
326   if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHECLEAN, true, &tag))
327     rerun_timer = atoi(tag.c_str());
328
329   // search for 'mysql-cache-timeout' and then the regular cache timeout
330   if (ini.get_tag (SHIBTARGET_SHAR, "mysql-cache-timeout", true, &tag))
331     timeout_life = atoi(tag.c_str());
332   else if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHETIMEOUT, true, &tag))
333     timeout_life = atoi(tag.c_str());
334
335   if (rerun_timer <= 0)
336     rerun_timer = 300;          // rerun every 5 minutes
337
338   if (timeout_life <= 0)
339     timeout_life = 28800;       // timeout after 8 hours
340
341   mutex->lock();
342
343   MYSQL* mysql = getMYSQL();
344
345   while (shutdown == false) {
346     shutdown_wait->timedwait(mutex, rerun_timer);
347
348     if (shutdown == true)
349       break;
350
351     // Find all the entries in the database that haven't been used
352     // recently In particular, find all entries that have not been
353     // accessed in 'timeout_life' seconds.
354     ostringstream q;
355     q << "SELECT cookie FROM state WHERE " <<
356       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
357
358     MYSQL_RES *rows;
359     if (mysql_query(mysql, q.str().c_str()))
360       log->error("Error searching for old items: %s", mysql_error(mysql));
361
362     rows = mysql_store_result(mysql);
363     if (!rows)
364       continue;
365
366     if (mysql_num_fields(rows) != 1) {
367       log->error("Wrong number of rows, 1 != %d", mysql_num_fields(rows));
368       mysql_free_result(rows);
369       continue;
370     }
371
372     // For each row, remove the entry from the database.
373     MYSQL_ROW row;
374     while ((row = mysql_fetch_row(rows)) != NULL)
375       remove(row[0]);
376
377     mysql_free_result(rows);
378   }
379
380   log->debug("cleanup thread exiting...");
381
382   mutex->unlock();
383   delete mutex;
384   thread_end();
385   Thread::exit(NULL);
386 }
387
388 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
389 {
390   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
391
392   // First, let's block all signals
393   Thread::mask_all_signals();
394
395   // Now run the cleanup process.
396   cache->cleanup();
397   return NULL;
398 }
399
400 void ShibMySQLCCache::createDatabase(MYSQL* mysql, int major, int minor)
401 {
402   log->info("Creating database.");
403
404   MYSQL* ms = NULL;
405   try {
406     ms = mysql_init(NULL);
407     if (!ms) {
408       log->crit("mysql_init failed");
409       throw new ShibTargetException();
410     }
411
412     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
413       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
414       throw new ShibTargetException();
415     }
416
417     if (mysql_query(ms, "CREATE DATABASE shar")) {
418       log->crit("cannot create shar database: %s", mysql_error(ms));
419       throw new ShibTargetException();
420     }
421
422     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
423       log->crit("cannot open SHAR database");
424       throw new ShibTargetException();
425     }
426
427     mysql_close(ms);
428     
429   } catch (ShibTargetException&) {
430     if (ms)
431       mysql_close(ms);
432     mysql_close(mysql);
433     throw runtime_error("mysql_real_connect");
434   }
435
436   // Now create the tables if they don't exist
437   log->info("Creating database tables.");
438
439   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)"))
440     log->error ("Error creating version: %s", mysql_error(mysql));
441
442   if (mysql_query(mysql,
443                   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, application_id VARCHAR(1024),"
444                   "atime DATETIME, addr VARCHAR(128), statement TEXT)"))
445     log->error ("Error creating state: %s", mysql_error(mysql));
446
447   ostringstream q;
448   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
449   if (mysql_query(mysql, q.str().c_str()))
450     log->error ("Error setting version: %s", mysql_error(mysql));
451 }
452
453 void ShibMySQLCCache::upgradeDatabase(MYSQL* mysql)
454 {
455     if (mysql_query(mysql, "DROP TABLE state")) {
456         log->error("Error dropping old session state table: %s", mysql_error(mysql));
457         throw runtime_error("error dropping table");
458     }
459
460     if (mysql_query(mysql,
461         "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, application_id VARCHAR(1024),"
462        "atime DATETIME, addr VARCHAR(128), statement TEXT)")) {
463         log->error ("Error creating state table: %s", mysql_error(mysql));
464         throw runtime_error("error creating table");
465     }
466
467     ostringstream q;
468     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
469     if (mysql_query(mysql, q.str().c_str())) {
470         log->error ("Error updating version: %s", mysql_error(mysql));
471         throw runtime_error("error updating table");
472     }
473 }
474
475 void ShibMySQLCCache::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
476 {
477   // grab the version number from the database
478   if (mysql_query(mysql, "SELECT * FROM version"))
479     log->error ("Error reading version: %s", mysql_error(mysql));
480
481   MYSQL_RES* rows = mysql_store_result(mysql);
482   if (rows) {
483     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
484       MYSQL_ROW row = mysql_fetch_row(rows);
485
486       int major = row[0] ? atoi(row[0]) : -1;
487       int minor = row[1] ? atoi(row[1]) : -1;
488       log->debug("opening database version %d.%d", major, minor);
489       
490       mysql_free_result (rows);
491
492       *major_p = major;
493       *minor_p = minor;
494       return;
495
496     } else {
497       // Wrong number of rows or wrong number of fields...
498
499       log->crit("Houston, we've got a problem with the database..");
500       mysql_free_result (rows);
501       throw runtime_error("Database version verification failed");
502     }
503   }
504   log->crit("MySQL Read Failed in version verificatoin");
505   throw runtime_error("MySQL Read Failed");
506 }
507
508 void ShibMySQLCCache::mysqlInit(void)
509 {
510   log->info ("Opening MySQL Database");
511
512   // Setup the argument array
513   vector<string> arg_array;
514   string tag;
515
516   tag = SHIBTARGET_SHAR;
517   arg_array.push_back(tag);
518
519   // grab any MySQL parameters from the config file
520   ShibTargetConfig& config = ShibTargetConfig::getConfig();
521   ShibINI& ini = config.getINI();
522
523   if (ini.exists("mysql")) {
524     ShibINI::Iterator* iter = ini.tag_iterator("mysql");
525
526     for (const string* str = iter->begin(); str; str = iter->next()) {
527       string arg = ini.get("mysql", *str);
528       arg_array.push_back(arg);
529     }
530     delete iter;
531   }
532
533   // Compute the argument array
534   int arg_count = arg_array.size();
535   const char** args=new const char*[arg_count];
536   for (int i = 0; i < arg_count; i++)
537     args[i] = arg_array[i].c_str();
538
539   // Initialize MySQL with the arguments
540   mysql_server_init(arg_count, (char **)args, NULL);
541
542   delete[] args;
543 }  
544
545 /*************************************************************************
546  * The CCacheEntry here is mostly a wrapper around the "memory"
547  * cacheentry provided by shibboleth.  The only difference is that we
548  * intercept the isSessionValid() so that we can "touch()" the
549  * database if the session is still valid.
550  */
551
552 ShibMySQLCCacheEntry::ShibMySQLCCacheEntry(const char* key, CCacheEntry *entry, ShibMySQLCCache* cache)
553 {
554   m_cacheEntry = entry;
555   m_key = key;
556   m_cache = cache;
557 }
558
559 bool ShibMySQLCCacheEntry::isSessionValid(time_t lifetime, time_t timeout)
560 {
561   bool res = m_cacheEntry->isSessionValid(lifetime, timeout);
562   if (res == true)
563     res = touch();
564   return res;
565 }
566
567 bool ShibMySQLCCacheEntry::touch()
568 {
569   string q=string("UPDATE state SET atime=NOW() WHERE cookie='") + m_key + "'";
570
571   MYSQL* mysql = m_cache->getMYSQL();
572   if (mysql_query(mysql, q.c_str())) {
573     m_cache->log->info("Error updating timestamp on %s: %s",
574                         m_key.c_str(), mysql_error(mysql));
575     return false;
576   }
577   return true;
578 }
579
580 /*************************************************************************
581  * The registration functions here...
582  */
583
584 extern "C" CCache* new_mysql_ccache(void)
585 {
586   return new ShibMySQLCCache();
587 }
588
589 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void* context)
590 {
591   // register this ccache type
592   CCache::registerFactory("mysql", &new_mysql_ccache);
593   return 0;
594 }
595
596 /*************************************************************************
597  * Local Functions
598  */
599
600 extern "C" void shib_mysql_destroy_handle(void* data)
601 {
602   MYSQL* mysql = (MYSQL*) data;
603   mysql_close(mysql);
604 }