Redesigned target around URL->application mapping
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
3  *
4  * Created by:  Derek Atkins <derek@ihtfp.com>
5  *
6  * $Id$
7  */
8
9 /* This file is loosely based off the Shibboleth Credential Cache.
10  * This plug-in is designed as a two-layer cache.  Layer 1, the
11  * long-term cache, stores data in a MySQL embedded database.  The
12  * data stored in layer 1 is only the session id (cookie), the
13  * "posted" SAML statement (expanded into an XML string), and usage
14  * timestamps.
15  *
16  * Short-term data is cached in memory as SAML objects in the layer 2
17  * cache.  Data like Attribute Authority assertions are stored in
18  * the layer 2 cache.
19  */
20
21 // eventually we might be able to support autoconf via cygwin...
22 #if defined (_MSC_VER) || defined(__BORLANDC__)
23 # include "config_win32.h"
24 #else
25 # include "config.h"
26 #endif
27
28 #ifdef WIN32
29 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
30 #else
31 # define SHIBMYSQL_EXPORTS
32 #endif
33
34 #ifdef HAVE_UNISTD_H
35 # include <unistd.h>
36 #endif
37
38 #include <shib-target/shib-target.h>
39 #include <shib-target/ccache-utils.h>
40 #include <shib/shib-threads.h>
41 #include <log4cpp/Category.hh>
42
43 #include <sstream>
44 #include <stdexcept>
45
46 #include <mysql.h>
47
48 #ifdef HAVE_LIBDMALLOCXX
49 #include <dmalloc.h>
50 #endif
51
52 using namespace std;
53 using namespace saml;
54 using namespace shibboleth;
55 using namespace shibtarget;
56
57 class ShibMySQLCCache;
58 class ShibMySQLCCacheEntry : public CCacheEntry
59 {
60 public:
61   ShibMySQLCCacheEntry(const char *, CCacheEntry*, ShibMySQLCCache*);
62   ~ShibMySQLCCacheEntry() {}
63
64   virtual Iterator<SAMLAssertion*> getAssertions(const char* resource)
65         { return m_cacheEntry->getAssertions(resource); }
66   virtual void preFetch(const char* resource, int prefetch_window)
67         { m_cacheEntry->preFetch(resource, prefetch_window); }
68   virtual bool isSessionValid(time_t lifetime, time_t timeout);
69   virtual const char* getClientAddress()
70         { return m_cacheEntry->getClientAddress(); }
71   virtual const char* getSerializedStatement()
72         { return m_cacheEntry->getSerializedStatement(); }
73   virtual const SAMLAuthenticationStatement* getStatement()
74         { return m_cacheEntry->getStatement(); }
75   virtual void release()
76         { m_cacheEntry->release(); delete this; }
77
78 private:
79   bool touch();
80
81   ShibMySQLCCache* m_cache;
82   CCacheEntry *m_cacheEntry;
83   string m_key;
84 };
85
86 class ShibMySQLCCache : public CCache
87 {
88 public:
89   ShibMySQLCCache();
90   virtual ~ShibMySQLCCache();
91
92   virtual CCacheEntry* find(const char* key);
93   virtual void insert(const char* key, SAMLAuthenticationStatement *s,
94                       const char *client_addr);
95   virtual void remove(const char* key);
96   virtual void thread_init();
97
98   void  cleanup();
99   MYSQL* getMYSQL();
100
101   log4cpp::Category* log;
102
103 private:
104   CCache* m_cache;
105   ThreadKey* m_mysql;
106
107   static void*  cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
108   CondWait* shutdown_wait;
109   bool shutdown;
110   Thread* cleanup_thread;
111
112   bool initialized;
113
114   void createDatabase(MYSQL*, int major, int minor);
115   void getVersion(MYSQL*, int* major_p, int* minor_p);
116   void mysqlInit(void);
117 };
118
119 // Forward declarations
120 extern "C" void shib_mysql_destroy_handle(void* data);
121
122 /*************************************************************************
123  * The CCache here talks to a MySQL database.  The database stores
124  * three items: the cookie (session key index), the lastAccess time, and
125  * the SAMLAuthenticationStatement.  All other access is performed
126  * through the memory cache provided by shibboleth.
127  */
128
129 MYSQL* ShibMySQLCCache::getMYSQL()
130 {
131   void* data = m_mysql->getData();
132   return (MYSQL*)data;
133 }
134
135 void ShibMySQLCCache::thread_init()
136 {
137   saml::NDC ndc("thread_init");
138
139   // Connect to the database
140   MYSQL* mysql = mysql_init(NULL);
141   if (!mysql) {
142     log->error("mysql_init failed");
143     mysql_close(mysql);
144     throw runtime_error("mysql_init()");
145   }
146
147   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
148     if (initialized) {
149       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
150       throw runtime_error("mysql_real_connect");
151
152     } else {
153       log->info("mysql_real_connect failed: %s.  Trying to create",
154                 mysql_error(mysql));
155
156       // This will throw a runtime error if it fails.
157       createDatabase(mysql, 0, 0);
158     }
159   }
160
161   int major = -1, minor = -1;
162   getVersion (mysql, &major, &minor);
163
164   // Make sure we've got the right version
165   if (major != 0 && minor != 0) {
166     log->crit("Invalid database version: %d.%d", major, minor);
167     throw runtime_error("Invalid Database version");
168   }
169
170   // We're all set.. Save off the handle for this thread.
171   m_mysql->setData((void*)mysql);
172 }
173
174 ShibMySQLCCache::ShibMySQLCCache()
175 {
176   saml::NDC ndc("shibmysql::ShibMySQLCCache");
177
178   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
179   string ctx = "shibmysql::ShibMySQLCCache";
180   log = &(log4cpp::Category::getInstance(ctx));
181
182   initialized = false;
183   mysqlInit();
184   thread_init();
185   initialized = true;
186
187   m_cache = CCache::getInstance("memory");
188
189   // Initialize the cleanup thread
190   shutdown_wait = CondWait::create();
191   shutdown = false;
192   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
193 }
194
195 ShibMySQLCCache::~ShibMySQLCCache()
196 {
197   shutdown = true;
198   shutdown_wait->signal();
199   cleanup_thread->join(NULL);
200
201   delete m_cache;
202   delete m_mysql;
203
204   // Shutdown MySQL
205   mysql_server_end();
206 }
207
208 CCacheEntry* ShibMySQLCCache::find(const char* key)
209 {
210   saml::NDC ndc("mysql::find");
211   CCacheEntry* res = m_cache->find(key);
212   if (!res) {
213
214     log->debug("Looking in database...");
215
216     // nothing cached; see if this exists in the database
217     ostringstream q;
218     q << "SELECT addr,statement FROM state WHERE cookie='" << key << "' LIMIT 1";
219
220     MYSQL_RES* rows;
221     MYSQL* mysql = getMYSQL();
222     if (mysql_query(mysql, q.str().c_str()))
223       log->error("Error searching for %s: %s", key, mysql_error(mysql));
224
225     rows = mysql_store_result(mysql);
226
227     // Nope, doesn't exist.
228     if (!rows)
229       return NULL;
230
231     // Make sure we got 1 and only 1 rows.
232     if (mysql_num_rows(rows) != 1) {
233       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
234       mysql_free_result(rows);
235       return NULL;
236     }
237
238     log->debug("Match found.  Parsing...");
239     // Pull apart the row and process the results
240     MYSQL_ROW row = mysql_fetch_row(rows);
241     istringstream str(row[1]);
242     SAMLAuthenticationStatement *s = NULL;
243
244     // Try to parse the AuthStatement
245     try {
246       s = new SAMLAuthenticationStatement(str);
247     } catch (...) {
248       mysql_free_result(rows);
249       throw;
250     }
251
252     // Insert it into the memory cache
253     if (s)
254       m_cache->insert(key, s, row[0]);
255
256     // Free the results, and then re-run the 'find' query
257     mysql_free_result(rows);
258     res = m_cache->find(key);
259     if (!res)
260       return NULL;
261   }
262
263   return new ShibMySQLCCacheEntry(key, res, this);
264 }
265
266 void ShibMySQLCCache::insert(const char* key, SAMLAuthenticationStatement *s,
267                             const char *client_addr)
268 {
269   saml::NDC ndc("mysql::insert");
270   ostringstream os;
271   os << *s;
272
273   ostringstream q;
274   q << "INSERT INTO state VALUES('" << key << "', NOW(), '" << client_addr
275     << "', '" << os.str() << "')";
276
277   log->debug("Query: %s", q.str().c_str());
278
279   // Add it to the memory cache
280   m_cache->insert(key, s, client_addr);
281
282   // then add it to the database
283   MYSQL* mysql = getMYSQL();
284   if (mysql_query(mysql, q.str().c_str()))
285     log->error("Error inserting %s: %s", key, mysql_error(mysql));
286 }
287
288 void ShibMySQLCCache::remove(const char* key)
289 {
290   saml::NDC ndc("mysql::remove");
291
292   // Remove the cached version
293   m_cache->remove(key);
294
295   // Remove from the database
296   ostringstream q;
297   q << "DELETE FROM state WHERE cookie='" << key << "'";
298   MYSQL* mysql = getMYSQL();
299   if (mysql_query(mysql, q.str().c_str()))
300     log->info("Error deleting entry %s: %s", key, mysql_error(mysql));
301 }
302
303 void ShibMySQLCCache::cleanup()
304 {
305   Mutex* mutex = Mutex::create();
306   saml::NDC ndc("mysql::cleanup");
307
308   thread_init();
309
310   ShibTargetConfig& config = ShibTargetConfig::getConfig();
311   ShibINI& ini = config.getINI();
312
313   int rerun_timer = 0;
314   int timeout_life = 0;
315
316   string tag;
317   if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHECLEAN, true, &tag))
318     rerun_timer = atoi(tag.c_str());
319
320   // search for 'mysql-cache-timeout' and then the regular cache timeout
321   if (ini.get_tag (SHIBTARGET_SHAR, "mysql-cache-timeout", true, &tag))
322     timeout_life = atoi(tag.c_str());
323   else if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHETIMEOUT, true, &tag))
324     timeout_life = atoi(tag.c_str());
325
326   if (rerun_timer <= 0)
327     rerun_timer = 300;          // rerun every 5 minutes
328
329   if (timeout_life <= 0)
330     timeout_life = 28800;       // timeout after 8 hours
331
332   mutex->lock();
333
334   MYSQL* mysql = getMYSQL();
335
336   while (shutdown == false) {
337     shutdown_wait->timedwait(mutex, rerun_timer);
338
339     if (shutdown == true)
340       break;
341
342     // Find all the entries in the database that haven't been used
343     // recently In particular, find all entries that have not been
344     // accessed in 'timeout_life' seconds.
345     ostringstream q;
346     q << "SELECT cookie FROM state WHERE " <<
347       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
348
349     MYSQL_RES *rows;
350     if (mysql_query(mysql, q.str().c_str()))
351       log->error("Error searching for old items: %s", mysql_error(mysql));
352
353     rows = mysql_store_result(mysql);
354     if (!rows)
355       continue;
356
357     if (mysql_num_fields(rows) != 1) {
358       log->error("Wrong number of rows, 1 != %d", mysql_num_fields(rows));
359       mysql_free_result(rows);
360       continue;
361     }
362
363     // For each row, remove the entry from the database.
364     MYSQL_ROW row;
365     while ((row = mysql_fetch_row(rows)) != NULL)
366       remove(row[0]);
367
368     mysql_free_result(rows);
369   }
370
371   log->debug("cleanup thread exiting...");
372
373   mutex->unlock();
374   delete mutex;
375   thread_end();
376   Thread::exit(NULL);
377 }
378
379 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
380 {
381   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
382
383   // First, let's block all signals
384   Thread::mask_all_signals();
385
386   // Now run the cleanup process.
387   cache->cleanup();
388   return NULL;
389 }
390
391 void ShibMySQLCCache::createDatabase(MYSQL* mysql, int major, int minor)
392 {
393   log->info("Creating database.");
394
395   MYSQL* ms = NULL;
396   try {
397     ms = mysql_init(NULL);
398     if (!ms) {
399       log->crit("mysql_init failed");
400       throw new ShibTargetException();
401     }
402
403     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
404       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
405       throw new ShibTargetException();
406     }
407
408     if (mysql_query(ms, "CREATE DATABASE shar")) {
409       log->crit("cannot create shar database: %s", mysql_error(ms));
410       throw new ShibTargetException();
411     }
412
413     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
414       log->crit("cannot open SHAR database");
415       throw new ShibTargetException();
416     }
417
418     mysql_close(ms);
419     
420   } catch (ShibTargetException&) {
421     if (ms)
422       mysql_close(ms);
423     mysql_close(mysql);
424     throw runtime_error("mysql_real_connect");
425   }
426
427   // Now create the tables if they don't exist
428   log->info("Creating database tables.");
429
430   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)"))
431     log->error ("Error creating version: %s", mysql_error(mysql));
432
433   if (mysql_query(mysql,
434                   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, "
435                   "atime DATETIME, addr VARCHAR(128), statement TEXT)"))
436     log->error ("Error creating state: %s", mysql_error(mysql));
437
438   ostringstream q;
439   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
440   if (mysql_query(mysql, q.str().c_str()))
441     log->error ("Error setting version: %s", mysql_error(mysql));
442 }
443
444 void ShibMySQLCCache::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
445 {
446   // grab the version number from the database
447   if (mysql_query(mysql, "SELECT * FROM version"))
448     log->error ("Error reading version: %s", mysql_error(mysql));
449
450   MYSQL_RES* rows = mysql_store_result(mysql);
451   if (rows) {
452     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
453       MYSQL_ROW row = mysql_fetch_row(rows);
454
455       int major = row[0] ? atoi(row[0]) : -1;
456       int minor = row[1] ? atoi(row[1]) : -1;
457       log->debug("opening database version %d.%d", major, minor);
458       
459       mysql_free_result (rows);
460
461       *major_p = major;
462       *minor_p = minor;
463       return;
464
465     } else {
466       // Wrong number of rows or wrong number of fields...
467
468       log->crit("Houston, we've got a problem with the database..");
469       mysql_free_result (rows);
470       throw runtime_error("Database version verification failed");
471     }
472   }
473   log->crit("MySQL Read Failed in version verificatoin");
474   throw runtime_error("MySQL Read Failed");
475 }
476
477 void ShibMySQLCCache::mysqlInit(void)
478 {
479   log->info ("Opening MySQL Database");
480
481   // Setup the argument array
482   vector<string> arg_array;
483   string tag;
484
485   tag = SHIBTARGET_SHAR;
486   arg_array.push_back(tag);
487
488   // grab any MySQL parameters from the config file
489   ShibTargetConfig& config = ShibTargetConfig::getConfig();
490   ShibINI& ini = config.getINI();
491
492   if (ini.exists("mysql")) {
493     ShibINI::Iterator* iter = ini.tag_iterator("mysql");
494
495     for (const string* str = iter->begin(); str; str = iter->next()) {
496       string arg = ini.get("mysql", *str);
497       arg_array.push_back(arg);
498     }
499     delete iter;
500   }
501
502   // Compute the argument array
503   int arg_count = arg_array.size();
504   const char** args=new const char*[arg_count];
505   for (int i = 0; i < arg_count; i++)
506     args[i] = arg_array[i].c_str();
507
508   // Initialize MySQL with the arguments
509   mysql_server_init(arg_count, (char **)args, NULL);
510
511   delete[] args;
512 }  
513
514 /*************************************************************************
515  * The CCacheEntry here is mostly a wrapper around the "memory"
516  * cacheentry provided by shibboleth.  The only difference is that we
517  * intercept the isSessionValid() so that we can "touch()" the
518  * database if the session is still valid.
519  */
520
521 ShibMySQLCCacheEntry::ShibMySQLCCacheEntry(const char* key, CCacheEntry *entry,
522                                          ShibMySQLCCache* cache)
523 {
524   m_cacheEntry = entry;
525   m_key = key;
526   m_cache = cache;
527 }
528
529 bool ShibMySQLCCacheEntry::isSessionValid(time_t lifetime, time_t timeout)
530 {
531   bool res = m_cacheEntry->isSessionValid(lifetime, timeout);
532   if (res == true)
533     res = touch();
534   return res;
535 }
536
537 bool ShibMySQLCCacheEntry::touch()
538 {
539   ostringstream q;
540   q << "UPDATE state SET atime=NOW() WHERE cookie='" << m_key << "'";
541
542   MYSQL* mysql = m_cache->getMYSQL();
543   if (mysql_query(mysql, q.str().c_str())) {
544     m_cache->log->info("Error updating timestamp on %s: %s",
545                         m_key.c_str(), mysql_error(mysql));
546     return false;
547   }
548   return true;
549 }
550
551 /*************************************************************************
552  * The registration functions here...
553  */
554
555 extern "C" CCache* new_mysql_ccache(void)
556 {
557   return new ShibMySQLCCache();
558 }
559
560 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void* context)
561 {
562   // register this ccache type
563   CCache::registerFactory("mysql", &new_mysql_ccache);
564   return 0;
565 }
566
567 /*************************************************************************
568  * Local Functions
569  */
570
571 extern "C" void shib_mysql_destroy_handle(void* data)
572 {
573   MYSQL* mysql = (MYSQL*) data;
574   mysql_close(mysql);
575 }