Tweaked some logging.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
3  *
4  * Created by:  Derek Atkins <derek@ihtfp.com>
5  *
6  * $Id$
7  */
8
9 /* This file is loosely based off the Shibboleth Credential Cache.
10  * This plug-in is designed as a two-layer cache.  Layer 1, the
11  * long-term cache, stores data in a MySQL embedded database.  The
12  * data stored in layer 1 is only the session id (cookie), the
13  * "posted" SAML statement (expanded into an XML string), and usage
14  * timestamps.
15  *
16  * Short-term data is cached in memory as SAML objects in the layer 2
17  * cache.  Data like Attribute Authority assertions are stored in
18  * the layer 2 cache.
19  */
20
21 // eventually we might be able to support autoconf via cygwin...
22 #if defined (_MSC_VER) || defined(__BORLANDC__)
23 # include "config_win32.h"
24 #else
25 # include "config.h"
26 #endif
27
28 #ifdef WIN32
29 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
30 #else
31 # define SHIBMYSQL_EXPORTS
32 #endif
33
34 #ifdef HAVE_UNISTD_H
35 # include <unistd.h>
36 #endif
37
38 #include <shib-target/shib-target.h>
39 #include <shib-target/ccache-utils.h>
40 #include <shib/shib-threads.h>
41 #include <log4cpp/Category.hh>
42
43 #include <sstream>
44 #include <stdexcept>
45
46 #include <mysql.h>
47
48 #ifdef HAVE_LIBDMALLOCXX
49 #include <dmalloc.h>
50 #endif
51
52 using namespace std;
53 using namespace saml;
54 using namespace shibboleth;
55 using namespace shibtarget;
56
57 class ShibMySQLCCache;
58 class ShibMySQLCCacheEntry : public CCacheEntry
59 {
60 public:
61   ShibMySQLCCacheEntry(const char *, CCacheEntry*, ShibMySQLCCache*);
62   ~ShibMySQLCCacheEntry() {}
63
64   virtual Iterator<SAMLAssertion*> getAssertions(Resource& resource)
65         { return m_cacheEntry->getAssertions(resource); }
66   virtual void preFetch(Resource& resource, int prefetch_window)
67         { m_cacheEntry->preFetch(resource, prefetch_window); }
68   virtual bool isSessionValid(time_t lifetime, time_t timeout);
69   virtual const char* getClientAddress()
70         { return m_cacheEntry->getClientAddress(); }
71   virtual const char* getSerializedStatement()
72         { return m_cacheEntry->getSerializedStatement(); }
73   virtual const SAMLAuthenticationStatement* getStatement()
74         { return m_cacheEntry->getStatement(); }
75   virtual void release()
76         { m_cacheEntry->release(); delete this; }
77
78 private:
79   void touch();
80
81   ShibMySQLCCache* m_cache;
82   CCacheEntry *m_cacheEntry;
83   string m_key;
84 };
85
86 class ShibMySQLCCache : public CCache
87 {
88 public:
89   ShibMySQLCCache();
90   virtual ~ShibMySQLCCache();
91
92   virtual SAMLBinding* getBinding(const XMLCh* bindingProt)
93         { return m_cache->getBinding(bindingProt); }
94   virtual CCacheEntry* find(const char* key);
95   virtual void insert(const char* key, SAMLAuthenticationStatement *s,
96                       const char *client_addr);
97   virtual void remove(const char* key);
98   virtual void thread_init();
99
100   void  cleanup();
101   MYSQL* getMYSQL();
102
103   log4cpp::Category* log;
104
105 private:
106   CCache* m_cache;
107   ThreadKey* m_mysql;
108
109   static void*  cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
110   CondWait* shutdown_wait;
111   bool shutdown;
112   Thread* cleanup_thread;
113
114   bool initialized;
115
116   void createDatabase(MYSQL*, int major, int minor);
117   void getVersion(MYSQL*, int* major_p, int* minor_p);
118   void mysqlInit(void);
119 };
120
121 // Forward declarations
122 extern "C" void shib_mysql_destroy_handle(void* data);
123
124 /*************************************************************************
125  * The CCache here talks to a MySQL database.  The database stores
126  * three items: the cookie (session key index), the lastAccess time, and
127  * the SAMLAuthenticationStatement.  All other access is performed
128  * through the memory cache provided by shibboleth.
129  */
130
131 MYSQL* ShibMySQLCCache::getMYSQL()
132 {
133   void* data = m_mysql->getData();
134   return (MYSQL*)data;
135 }
136
137 void ShibMySQLCCache::thread_init()
138 {
139   saml::NDC ndc("thread_init");
140
141   // Connect to the database
142   MYSQL* mysql = mysql_init(NULL);
143   if (!mysql) {
144     log->error("mysql_init failed");
145     mysql_close(mysql);
146     throw runtime_error("mysql_init()");
147   }
148
149   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
150     if (initialized) {
151       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
152       throw runtime_error("mysql_real_connect");
153
154     } else {
155       log->info("mysql_real_connect failed: %s.  Trying to create",
156                 mysql_error(mysql));
157
158       // This will throw a runtime error if it fails.
159       createDatabase(mysql, 0, 0);
160     }
161   }
162
163   int major = -1, minor = -1;
164   getVersion (mysql, &major, &minor);
165
166   // Make sure we've got the right version
167   if (major != 0 && minor != 0) {
168     log->crit("Invalid database version: %d.%d", major, minor);
169     throw runtime_error("Invalid Database version");
170   }
171
172   // We're all set.. Save off the handle for this thread.
173   m_mysql->setData((void*)mysql);
174 }
175
176 ShibMySQLCCache::ShibMySQLCCache()
177 {
178   saml::NDC ndc("shibmysql::ShibMySQLCCache");
179
180   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
181   string ctx = "shibmysql::ShibMySQLCCache";
182   log = &(log4cpp::Category::getInstance(ctx));
183
184   initialized = false;
185   mysqlInit();
186   thread_init();
187   initialized = true;
188
189   m_cache = CCache::getInstance("memory");
190
191   // Initialize the cleanup thread
192   shutdown_wait = CondWait::create();
193   shutdown = false;
194   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
195 }
196
197 ShibMySQLCCache::~ShibMySQLCCache()
198 {
199   shutdown = true;
200   shutdown_wait->signal();
201   cleanup_thread->join(NULL);
202
203   delete m_cache;
204   delete m_mysql;
205
206   // Shutdown MySQL
207   mysql_server_end();
208 }
209
210 CCacheEntry* ShibMySQLCCache::find(const char* key)
211 {
212   saml::NDC ndc("mysql::find");
213   CCacheEntry* res = m_cache->find(key);
214   if (!res) {
215
216     log->debug("Looking in database...");
217
218     // nothing cached; see if this exists in the database
219     ostringstream q;
220     q << "SELECT addr,statement FROM state WHERE cookie='" << key << "' LIMIT 1";
221
222     MYSQL_RES* rows;
223     MYSQL* mysql = getMYSQL();
224     if (mysql_query(mysql, q.str().c_str()))
225       log->error("Error searching for %s: %s", key, mysql_error(mysql));
226
227     rows = mysql_store_result(mysql);
228
229     // Nope, doesn't exist.
230     if (!rows)
231       return NULL;
232
233     // Make sure we got 1 and only 1 rows.
234     if (mysql_num_rows(rows) != 1) {
235       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
236       mysql_free_result(rows);
237       return NULL;
238     }
239
240     log->debug("Match found.  Parsing...");
241     // Pull apart the row and process the results
242     MYSQL_ROW row = mysql_fetch_row(rows);
243     istringstream str(row[1]);
244     SAMLAuthenticationStatement *s = NULL;
245
246     // Try to parse the AuthStatement
247     try {
248       s = new SAMLAuthenticationStatement(str);
249     } catch (...) {
250       mysql_free_result(rows);
251       throw;
252     }
253
254     // Insert it into the memory cache
255     if (s)
256       m_cache->insert(key, s, row[0]);
257
258     // Free the results, and then re-run the 'find' query
259     mysql_free_result(rows);
260     res = m_cache->find(key);
261     if (!res)
262       return NULL;
263   }
264
265   return new ShibMySQLCCacheEntry(key, res, this);
266 }
267
268 void ShibMySQLCCache::insert(const char* key, SAMLAuthenticationStatement *s,
269                             const char *client_addr)
270 {
271   saml::NDC ndc("mysql::insert");
272   ostringstream os;
273   os << *s;
274
275   ostringstream q;
276   q << "INSERT INTO state VALUES('" << key << "', NOW(), '" << client_addr
277     << "', '" << os.str() << "')";
278
279   log->debug("Query: %s", q.str().c_str());
280
281   // Add it to the memory cache
282   m_cache->insert(key, s, client_addr);
283
284   // then add it to the database
285   MYSQL* mysql = getMYSQL();
286   if (mysql_query(mysql, q.str().c_str()))
287     log->error("Error inserting %s: %s", key, mysql_error(mysql));
288 }
289
290 void ShibMySQLCCache::remove(const char* key)
291 {
292   saml::NDC ndc("mysql::remove");
293
294   // Remove the cached version
295   m_cache->remove(key);
296
297   // Remove from the database
298   ostringstream q;
299   q << "DELETE FROM state WHERE cookie='" << key << "'";
300   MYSQL* mysql = getMYSQL();
301   if (mysql_query(mysql, q.str().c_str()))
302     log->error("Error deleting entry %s: %s", key, mysql_error(mysql));
303 }
304
305 void ShibMySQLCCache::cleanup()
306 {
307   Mutex* mutex = Mutex::create();
308   saml::NDC ndc("mysql::cleanup");
309
310   thread_init();
311
312   ShibTargetConfig& config = ShibTargetConfig::getConfig();
313   ShibINI& ini = config.getINI();
314
315   int rerun_timer = 0;
316   int timeout_life = 0;
317
318   string tag;
319   if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHECLEAN, true, &tag))
320     rerun_timer = atoi(tag.c_str());
321
322   // search for 'mysql-cache-timeout' and then the regular cache timeout
323   if (ini.get_tag (SHIBTARGET_SHAR, "mysql-cache-timeout", true, &tag))
324     timeout_life = atoi(tag.c_str());
325   else if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHETIMEOUT, true, &tag))
326     timeout_life = atoi(tag.c_str());
327
328   if (rerun_timer <= 0)
329     rerun_timer = 300;          // rerun every 5 minutes
330
331   if (timeout_life <= 0)
332     timeout_life = 28800;       // timeout after 8 hours
333
334   mutex->lock();
335
336   MYSQL* mysql = getMYSQL();
337
338   while (shutdown == false) {
339     shutdown_wait->timedwait(mutex, rerun_timer);
340
341     if (shutdown == true)
342       break;
343
344     // Find all the entries in the database that haven't been used
345     // recently In particular, find all entries that have not been
346     // accessed in 'timeout_life' seconds.
347     ostringstream q;
348     q << "SELECT cookie FROM state WHERE " <<
349       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
350
351     MYSQL_RES *rows;
352     if (mysql_query(mysql, q.str().c_str()))
353       log->error("Error searching for old items: %s", mysql_error(mysql));
354
355     rows = mysql_store_result(mysql);
356     if (!rows)
357       continue;
358
359     if (mysql_num_fields(rows) != 1) {
360       log->error("Wrong number of rows, 1 != %d", mysql_num_fields(rows));
361       mysql_free_result(rows);
362       continue;
363     }
364
365     // For each row, remove the entry from the database.
366     MYSQL_ROW row;
367     while ((row = mysql_fetch_row(rows)) != NULL)
368       remove(row[0]);
369
370     mysql_free_result(rows);
371   }
372
373   log->debug("cleanup thread exiting...");
374
375   mutex->unlock();
376   delete mutex;
377   thread_end();
378   Thread::exit(NULL);
379 }
380
381 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
382 {
383   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
384
385   // First, let's block all signals
386   Thread::mask_all_signals();
387
388   // Now run the cleanup process.
389   cache->cleanup();
390   return NULL;
391 }
392
393 void ShibMySQLCCache::createDatabase(MYSQL* mysql, int major, int minor)
394 {
395   log->info("Creating database.");
396
397   MYSQL* ms = NULL;
398   try {
399     ms = mysql_init(NULL);
400     if (!ms) {
401       log->crit("mysql_init failed");
402       throw new ShibTargetException();
403     }
404
405     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
406       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
407       throw new ShibTargetException();
408     }
409
410     if (mysql_query(ms, "CREATE DATABASE shar")) {
411       log->crit("cannot create shar database: %s", mysql_error(ms));
412       throw new ShibTargetException();
413     }
414
415     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
416       log->crit("cannot open SHAR database");
417       throw new ShibTargetException();
418     }
419
420     mysql_close(ms);
421     
422   } catch (ShibTargetException&) {
423     if (ms)
424       mysql_close(ms);
425     mysql_close(mysql);
426     throw runtime_error("mysql_real_connect");
427   }
428
429   // Now create the tables if they don't exist
430   log->info("Creating database tables.");
431
432   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)"))
433     log->error ("Error creating version: %s", mysql_error(mysql));
434
435   if (mysql_query(mysql,
436                   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, "
437                   "atime DATETIME, addr VARCHAR(128), statement TEXT)"))
438     log->error ("Error creating state: %s", mysql_error(mysql));
439
440   ostringstream q;
441   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
442   if (mysql_query(mysql, q.str().c_str()))
443     log->error ("Error setting version: %s", mysql_error(mysql));
444 }
445
446 void ShibMySQLCCache::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
447 {
448   // grab the version number from the database
449   if (mysql_query(mysql, "SELECT * FROM version"))
450     log->error ("Error reading version: %s", mysql_error(mysql));
451
452   MYSQL_RES* rows = mysql_store_result(mysql);
453   if (rows) {
454     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
455       MYSQL_ROW row = mysql_fetch_row(rows);
456
457       int major = row[0] ? atoi(row[0]) : -1;
458       int minor = row[1] ? atoi(row[1]) : -1;
459       log->debug("opening database version %d.%d", major, minor);
460       
461       mysql_free_result (rows);
462
463       *major_p = major;
464       *minor_p = minor;
465       return;
466
467     } else {
468       // Wrong number of rows or wrong number of fields...
469
470       log->crit("Houston, we've got a problem with the database..");
471       mysql_free_result (rows);
472       throw runtime_error("Database version verification failed");
473     }
474   }
475   log->crit("MySQL Read Failed in version verificatoin");
476   throw runtime_error("MySQL Read Failed");
477 }
478
479 void ShibMySQLCCache::mysqlInit(void)
480 {
481   log->info ("Opening MySQL Database");
482
483   // Setup the argument array
484   vector<string> arg_array;
485   string tag;
486
487   tag = SHIBTARGET_SHAR;
488   arg_array.push_back(tag);
489
490   // grab any MySQL parameters from the config file
491   ShibTargetConfig& config = ShibTargetConfig::getConfig();
492   ShibINI& ini = config.getINI();
493
494   if (ini.exists("mysql")) {
495     ShibINI::Iterator* iter = ini.tag_iterator("mysql");
496
497     for (const string* str = iter->begin(); str; str = iter->next()) {
498       string arg = ini.get("mysql", *str);
499       arg_array.push_back(arg);
500     }
501     delete iter;
502   }
503
504   // Compute the argument array
505   int arg_count = arg_array.size();
506   const char** args=new const char*[arg_count];
507   for (int i = 0; i < arg_count; i++)
508     args[i] = arg_array[i].c_str();
509
510   // Initialize MySQL with the arguments
511   mysql_server_init(arg_count, (char **)args, NULL);
512
513   delete[] args;
514 }  
515
516 /*************************************************************************
517  * The CCacheEntry here is mostly a wrapper around the "memory"
518  * cacheentry provided by shibboleth.  The only difference is that we
519  * intercept the isSessionValid() so that we can "touch()" the
520  * database if the session is still valid.
521  */
522
523 ShibMySQLCCacheEntry::ShibMySQLCCacheEntry(const char* key, CCacheEntry *entry,
524                                          ShibMySQLCCache* cache)
525 {
526   m_cacheEntry = entry;
527   m_key = key;
528   m_cache = cache;
529 }
530
531 bool ShibMySQLCCacheEntry::isSessionValid(time_t lifetime, time_t timeout)
532 {
533   bool res = m_cacheEntry->isSessionValid(lifetime, timeout);
534   if (res == true)
535     touch();
536   return res;
537 }
538
539 void ShibMySQLCCacheEntry::touch()
540 {
541   ostringstream q;
542   q << "UPDATE state SET atime=NOW() WHERE cookie='" << m_key << "'";
543
544   MYSQL* mysql = m_cache->getMYSQL();
545   if (mysql_query(mysql, q.str().c_str()))
546     m_cache->log->error("Error updating timestamp on %s: %s",
547                         m_key.c_str(), mysql_error(mysql));
548 }
549
550 /*************************************************************************
551  * The registration functions here...
552  */
553
554 extern "C" CCache* new_mysql_ccache(void)
555 {
556   return new ShibMySQLCCache();
557 }
558
559 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void* context)
560 {
561   // register this ccache type
562   CCache::registerFactory("mysql", &new_mysql_ccache);
563   return 0;
564 }
565
566 /*************************************************************************
567  * Local Functions
568  */
569
570 extern "C" void shib_mysql_destroy_handle(void* data)
571 {
572   MYSQL* mysql = (MYSQL*) data;
573   mysql_close(mysql);
574 }