Reworked arg array (error'd on Windows)
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
3  *
4  * Created by:  Derek Atkins <derek@ihtfp.com>
5  *
6  * $Id$
7  */
8
9 /* This file is loosely based off the Shibboleth Credential Cache.
10  * This plug-in is designed as a two-layer cache.  Layer 1, the
11  * long-term cache, stores data in a MySQL embedded database.  The
12  * data stored in layer 1 is only the session id (cookie), the
13  * "posted" SAML statement (expanded into an XML string), and usage
14  * timestamps.
15  *
16  * Short-term data is cached in memory as SAML objects in the layer 2
17  * cache.  Data like Attribute Authority assertions are stored in
18  * the layer 2 cache.
19  */
20
21 #ifndef WIN32
22 # include <unistd.h>
23 #endif
24
25 #include <shib-target/shib-target.h>
26 #include <shib-target/ccache-utils.h>
27 #include <shib/shib-threads.h>
28 #include <log4cpp/Category.hh>
29
30 #include <sstream>
31 #include <stdexcept>
32
33 #include <mysql.h>
34
35 #ifdef HAVE_LIBDMALLOCXX
36 #include <dmalloc.h>
37 #endif
38
39 using namespace std;
40 using namespace saml;
41 using namespace shibboleth;
42 using namespace shibtarget;
43
44 class ShibMySQLCCache;
45 class ShibMySQLCCacheEntry : public CCacheEntry
46 {
47 public:
48   ShibMySQLCCacheEntry(const char *, CCacheEntry*, ShibMySQLCCache*);
49   ~ShibMySQLCCacheEntry() {}
50
51   virtual Iterator<SAMLAssertion*> getAssertions(Resource& resource)
52         { return m_cacheEntry->getAssertions(resource); }
53   virtual void preFetch(Resource& resource, int prefetch_window)
54         { m_cacheEntry->preFetch(resource, prefetch_window); }
55   virtual bool isSessionValid(time_t lifetime, time_t timeout);
56   virtual const char* getClientAddress()
57         { return m_cacheEntry->getClientAddress(); }
58   virtual const char* getSerializedStatement()
59         { return m_cacheEntry->getSerializedStatement(); }
60   virtual const SAMLAuthenticationStatement* getStatement()
61         { return m_cacheEntry->getStatement(); }
62   virtual void release()
63         { m_cacheEntry->release(); delete this; }
64
65 private:
66   void touch();
67
68   ShibMySQLCCache* m_cache;
69   CCacheEntry *m_cacheEntry;
70   string m_key;
71 };
72
73 class ShibMySQLCCache : public CCache
74 {
75 public:
76   ShibMySQLCCache();
77   virtual ~ShibMySQLCCache();
78
79   virtual SAMLBinding* getBinding(const XMLCh* bindingProt)
80         { return m_cache->getBinding(bindingProt); }
81   virtual CCacheEntry* find(const char* key);
82   virtual void insert(const char* key, SAMLAuthenticationStatement *s,
83                       const char *client_addr);
84   virtual void remove(const char* key);
85   virtual void thread_init();
86
87   void  cleanup();
88   MYSQL* getMYSQL();
89
90   log4cpp::Category* log;
91
92 private:
93   CCache* m_cache;
94   ThreadKey* m_mysql;
95
96   static void*  cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
97   CondWait* shutdown_wait;
98   bool shutdown;
99   Thread* cleanup_thread;
100
101   bool initialized;
102
103   void createDatabase(MYSQL*, int major, int minor);
104   void getVersion(MYSQL*, int* major_p, int* minor_p);
105   void mysqlInit(void);
106 };
107
108 // Forward declarations
109 extern "C" void shib_mysql_destroy_handle(void* data);
110
111 /*************************************************************************
112  * The CCache here talks to a MySQL database.  The database stores
113  * three items: the cookie (session key index), the lastAccess time, and
114  * the SAMLAuthenticationStatement.  All other access is performed
115  * through the memory cache provided by shibboleth.
116  */
117
118 MYSQL* ShibMySQLCCache::getMYSQL()
119 {
120   void* data = m_mysql->getData();
121   return (MYSQL*)data;
122 }
123
124 void ShibMySQLCCache::thread_init()
125 {
126   saml::NDC ndc("open_db");
127
128   // Connect to the database
129   MYSQL* mysql = mysql_init(NULL);
130   if (!mysql) {
131     log->error("mysql_init failed");
132     mysql_close(mysql);
133     throw runtime_error("mysql_init()");
134   }
135
136   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
137     if (initialized) {
138       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
139       throw runtime_error("mysql_real_connect");
140
141     } else {
142       log->info("mysql_real_connect failed: %s.  Trying to create",
143                 mysql_error(mysql));
144
145       // This will throw a runtime error if it fails.
146       createDatabase(mysql, 0, 0);
147     }
148   }
149
150   int major = -1, minor = -1;
151   getVersion (mysql, &major, &minor);
152
153   // Make sure we've got the right version
154   if (major != 0 && minor != 0) {
155     log->crit("Invalid database version: %d.%d", major, minor);
156     throw runtime_error("Invalid Database version");
157   }
158
159   // We're all set.. Save off the handle for this thread.
160   m_mysql->setData((void*)mysql);
161 }
162
163 ShibMySQLCCache::ShibMySQLCCache()
164 {
165   saml::NDC ndc("shibmysql::ShibMySQLCCache");
166
167   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
168   string ctx = "shibmysql::ShibMySQLCCache";
169   log = &(log4cpp::Category::getInstance(ctx));
170
171   initialized = false;
172   mysqlInit();
173   thread_init();
174   initialized = true;
175
176   m_cache = CCache::getInstance("memory");
177
178   // Initialize the cleanup thread
179   shutdown_wait = CondWait::create();
180   shutdown = false;
181   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
182 }
183
184 ShibMySQLCCache::~ShibMySQLCCache()
185 {
186   shutdown = true;
187   shutdown_wait->signal();
188   cleanup_thread->join(NULL);
189
190   delete m_cache;
191   delete m_mysql;
192
193   // Shutdown MySQL
194   mysql_server_end();
195 }
196
197 CCacheEntry* ShibMySQLCCache::find(const char* key)
198 {
199   saml::NDC ndc("mysql::find");
200   CCacheEntry* res = m_cache->find(key);
201   if (!res) {
202
203     log->debug("Looking in database...");
204
205     // nothing cached; see if this exists in the database
206     ostringstream q;
207     q << "SELECT addr,statement FROM state WHERE cookie='" << key << "' LIMIT 1";
208
209     MYSQL_RES* rows;
210     MYSQL* mysql = getMYSQL();
211     if (mysql_query(mysql, q.str().c_str()))
212       log->error("Error searching for %s: %s", key, mysql_error(mysql));
213
214     rows = mysql_store_result(mysql);
215
216     // Nope, doesn't exist.
217     if (!rows)
218       return NULL;
219
220     // Make sure we got 1 and only 1 rows.
221     if (mysql_num_rows(rows) != 1) {
222       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
223       mysql_free_result(rows);
224       return NULL;
225     }
226
227     log->debug("Match found.  Parsing...");
228     // Pull apart the row and process the results
229     MYSQL_ROW row = mysql_fetch_row(rows);
230     istringstream str(row[1]);
231     SAMLAuthenticationStatement *s = NULL;
232
233     // Try to parse the AuthStatement
234     try {
235       s = new SAMLAuthenticationStatement(str);
236     } catch (...) {
237       mysql_free_result(rows);
238       throw;
239     }
240
241     // Insert it into the memory cache
242     if (s)
243       m_cache->insert(key, s, row[0]);
244
245     // Free the results, and then re-run the 'find' query
246     mysql_free_result(rows);
247     res = m_cache->find(key);
248     if (!res)
249       return NULL;
250   }
251
252   return new ShibMySQLCCacheEntry(key, res, this);
253 }
254
255 void ShibMySQLCCache::insert(const char* key, SAMLAuthenticationStatement *s,
256                             const char *client_addr)
257 {
258   saml::NDC ndc("mysql::insert");
259   ostringstream os;
260   os << *s;
261
262   ostringstream q;
263   q << "INSERT INTO state VALUES('" << key << "', NOW(), '" << client_addr
264     << "', '" << os.str() << "')";
265
266   log->debug("Query: %s", q.str().c_str());
267
268   // Add it to the memory cache
269   m_cache->insert(key, s, client_addr);
270
271   // then add it to the database
272   MYSQL* mysql = getMYSQL();
273   if (mysql_query(mysql, q.str().c_str()))
274     log->error("Error inserting %s: %s", key, mysql_error(mysql));
275 }
276
277 void ShibMySQLCCache::remove(const char* key)
278 {
279   saml::NDC ndc("mysql::remove");
280
281   // Remove the cached version
282   m_cache->remove(key);
283
284   // Remove from the database
285   ostringstream q;
286   q << "DELETE FROM state WHERE cookie='" << key << "'";
287   MYSQL* mysql = getMYSQL();
288   if (mysql_query(mysql, q.str().c_str()))
289     log->error("Error deleting entry %s: %s", key, mysql_error(mysql));
290 }
291
292 void ShibMySQLCCache::cleanup()
293 {
294   Mutex* mutex = Mutex::create();
295   saml::NDC ndc("mysql::cleanup()");
296
297   thread_init();
298
299   ShibTargetConfig& config = ShibTargetConfig::getConfig();
300   ShibINI& ini = config.getINI();
301
302   int rerun_timer = 0;
303   int timeout_life = 0;
304
305   string tag;
306   if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHECLEAN, true, &tag))
307     rerun_timer = atoi(tag.c_str());
308
309   // search for 'mysql-cache-timeout' and then the regular cache timeout
310   if (ini.get_tag (SHIBTARGET_SHAR, "mysql-cache-timeout", true, &tag))
311     timeout_life = atoi(tag.c_str());
312   else if (ini.get_tag (SHIBTARGET_SHAR, SHIBTARGET_TAG_CACHETIMEOUT, true, &tag))
313     timeout_life = atoi(tag.c_str());
314
315   if (rerun_timer <= 0)
316     rerun_timer = 300;          // rerun every 5 minutes
317
318   if (timeout_life <= 0)
319     timeout_life = 28800;       // timeout after 8 hours
320
321   mutex->lock();
322
323   MYSQL* mysql = getMYSQL();
324
325   while (shutdown == false) {
326     shutdown_wait->timedwait(mutex, rerun_timer);
327
328     if (shutdown == true)
329       break;
330
331     // Find all the entries in the database that haven't been used
332     // recently In particular, find all entries that have not been
333     // accessed in 'timeout_life' seconds.
334     ostringstream q;
335     q << "SELECT cookie FROM state WHERE " <<
336       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
337
338     MYSQL_RES *rows;
339     if (mysql_query(mysql, q.str().c_str()))
340       log->error("Error searching for old items: %s", mysql_error(mysql));
341
342     rows = mysql_store_result(mysql);
343     if (!rows)
344       continue;
345
346     if (mysql_num_fields(rows) != 1) {
347       log->error("Wrong number of rows, 1 != %d", mysql_num_fields(rows));
348       mysql_free_result(rows);
349       continue;
350     }
351
352     // For each row, remove the entry from the database.
353     MYSQL_ROW row;
354     while ((row = mysql_fetch_row(rows)) != NULL)
355       remove(row[0]);
356
357     mysql_free_result(rows);
358   }
359
360   mutex->unlock();
361   delete mutex;
362   thread_end();
363   Thread::exit(NULL);
364 }
365
366 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
367 {
368   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
369
370   // First, let's block all signals
371   Thread::mask_all_signals();
372
373   // Now run the cleanup process.
374   cache->cleanup();
375   return NULL;
376 }
377
378 void ShibMySQLCCache::createDatabase(MYSQL* mysql, int major, int minor)
379 {
380   log->info("Creating database.");
381
382   MYSQL* ms = NULL;
383   try {
384     ms = mysql_init(NULL);
385     if (!ms) {
386       log->crit("mysql_init failed");
387       throw new ShibTargetException();
388     }
389
390     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
391       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
392       throw new ShibTargetException();
393     }
394
395     if (mysql_query(ms, "CREATE DATABASE shar")) {
396       log->crit("cannot create shar database: %s", mysql_error(ms));
397       throw new ShibTargetException();
398     }
399
400     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
401       log->crit("cannot open SHAR database");
402       throw new ShibTargetException();
403     }
404
405     mysql_close(ms);
406     
407   } catch (ShibTargetException&) {
408     if (ms)
409       mysql_close(ms);
410     mysql_close(mysql);
411     throw runtime_error("mysql_real_connect");
412   }
413
414   // Now create the tables if they don't exist
415   log->info("Creating database tables.");
416
417   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)"))
418     log->error ("Error creating version: %s", mysql_error(mysql));
419
420   if (mysql_query(mysql,
421                   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, "
422                   "atime DATETIME, addr VARCHAR(128), statement TEXT)"))
423     log->error ("Error creating state: %s", mysql_error(mysql));
424
425   ostringstream q;
426   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
427   if (mysql_query(mysql, q.str().c_str()))
428     log->error ("Error setting version: %s", mysql_error(mysql));
429 }
430
431 void ShibMySQLCCache::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
432 {
433   // grab the version number from the database
434   if (mysql_query(mysql, "SELECT * FROM version"))
435     log->error ("Error reading version: %s", mysql_error(mysql));
436
437   MYSQL_RES* rows = mysql_store_result(mysql);
438   if (rows) {
439     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
440       MYSQL_ROW row = mysql_fetch_row(rows);
441
442       int major = row[0] ? atoi(row[0]) : -1;
443       int minor = row[1] ? atoi(row[1]) : -1;
444       log->debug("opening database version %d.%d", major, minor);
445       
446       mysql_free_result (rows);
447
448       *major_p = major;
449       *minor_p = minor;
450       return;
451
452     } else {
453       // Wrong number of rows or wrong number of fields...
454
455       log->crit("Houston, we've got a problem with the database..");
456       mysql_free_result (rows);
457       throw runtime_error("Database version verification failed");
458     }
459   }
460   log->crit("MySQL Read Failed in version verificatoin");
461   throw runtime_error("MySQL Read Failed");
462 }
463
464 void ShibMySQLCCache::mysqlInit(void)
465 {
466   log->info ("Opening MySQL Database");
467
468   // Setup the argument array
469   vector<string> arg_array;
470   string tag;
471
472   tag = SHIBTARGET_SHAR;
473   arg_array.push_back(tag);
474
475   // grab any MySQL parameters from the config file
476   ShibTargetConfig& config = ShibTargetConfig::getConfig();
477   ShibINI& ini = config.getINI();
478
479   if (ini.exists("mysql")) {
480     ShibINI::Iterator* iter = ini.tag_iterator("mysql");
481
482     for (const string* str = iter->begin(); str; str = iter->next()) {
483       string arg = ini.get("mysql", *str);
484       arg_array.push_back(arg);
485     }
486     delete iter;
487   }
488
489   // Compute the argument array
490   int arg_count = arg_array.size();
491   const char** args=new const char*[arg_count];
492   for (int i = 0; i < arg_count; i++)
493     args[i] = arg_array[i].c_str();
494
495   // Initialize MySQL with the arguments
496   mysql_server_init(arg_count, (char **)args, NULL);
497
498   delete[] args;
499 }  
500
501 /*************************************************************************
502  * The CCacheEntry here is mostly a wrapper around the "memory"
503  * cacheentry provided by shibboleth.  The only difference is that we
504  * intercept the isSessionValid() so that we can "touch()" the
505  * database if the session is still valid.
506  */
507
508 ShibMySQLCCacheEntry::ShibMySQLCCacheEntry(const char* key, CCacheEntry *entry,
509                                          ShibMySQLCCache* cache)
510 {
511   m_cacheEntry = entry;
512   m_key = key;
513   m_cache = cache;
514 }
515
516 bool ShibMySQLCCacheEntry::isSessionValid(time_t lifetime, time_t timeout)
517 {
518   bool res = m_cacheEntry->isSessionValid(lifetime, timeout);
519   if (res == true)
520     touch();
521   return res;
522 }
523
524 void ShibMySQLCCacheEntry::touch()
525 {
526   ostringstream q;
527   q << "UPDATE state SET atime=NOW() WHERE cookie='" << m_key << "'";
528
529   MYSQL* mysql = m_cache->getMYSQL();
530   if (mysql_query(mysql, q.str().c_str()))
531     m_cache->log->error("Error updating timestamp on %s: %s",
532                         m_key.c_str(), mysql_error(mysql));
533 }
534
535 /*************************************************************************
536  * The registration functions here...
537  */
538
539 extern "C" CCache* new_mysql_ccache(void)
540 {
541   return new ShibMySQLCCache();
542 }
543
544 extern "C" int saml_extension_init(void* context)
545 {
546   // register this ccache type
547   CCache::registerFactory("mysql", &new_mysql_ccache);
548   return 0;
549 }
550
551 /*************************************************************************
552  * Local Functions
553  */
554
555 extern "C" void shib_mysql_destroy_handle(void* data)
556 {
557   MYSQL* mysql = (MYSQL*) data;
558   mysql_close(mysql);
559 }