Revised metadata API.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
1 /*
2  * shib-mysql-ccache.cpp: Shibboleth Credential Cache using MySQL.
3  *
4  * Created by:  Derek Atkins <derek@ihtfp.com>
5  *
6  * $Id$
7  */
8
9 /* This file is loosely based off the Shibboleth Credential Cache.
10  * This plug-in is designed as a two-layer cache.  Layer 1, the
11  * long-term cache, stores data in a MySQL embedded database.  The
12  * data stored in layer 1 is only the session id (cookie), the
13  * "posted" SAML statement (expanded into an XML string), and usage
14  * timestamps.
15  *
16  * Short-term data is cached in memory as SAML objects in the layer 2
17  * cache.  Data like Attribute Authority assertions are stored in
18  * the layer 2 cache.
19  */
20
21 // eventually we might be able to support autoconf via cygwin...
22 #if defined (_MSC_VER) || defined(__BORLANDC__)
23 # include "config_win32.h"
24 #else
25 # include "config.h"
26 #endif
27
28 #ifdef WIN32
29 # define SHIBMYSQL_EXPORTS __declspec(dllexport)
30 #else
31 # define SHIBMYSQL_EXPORTS
32 #endif
33
34 #ifdef HAVE_UNISTD_H
35 # include <unistd.h>
36 #endif
37
38 #include <shib-target/shib-target.h>
39 #include <shib/shib-threads.h>
40 #include <log4cpp/Category.hh>
41
42 #include <sstream>
43 #include <stdexcept>
44
45 #include <mysql.h>
46
47 #ifdef HAVE_LIBDMALLOCXX
48 #include <dmalloc.h>
49 #endif
50
51 using namespace std;
52 using namespace saml;
53 using namespace shibboleth;
54 using namespace shibtarget;
55
56 #define PLUGIN_VER_MAJOR 1
57 #define PLUGIN_VER_MINOR 0
58
59 static const XMLCh Argument[] =
60 { chLatin_A, chLatin_r, chLatin_g, chLatin_u, chLatin_m, chLatin_e, chLatin_n, chLatin_t, chNull };
61 static const XMLCh cleanupInterval[] =
62 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
63   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
64 };
65 static const XMLCh cacheTimeout[] =
66 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
67 static const XMLCh mysqlTimeout[] =
68 { chLatin_m, chLatin_y, chLatin_s, chLatin_q, chLatin_l, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
69
70 class ShibMySQLCCache;
71 class ShibMySQLCCacheEntry : public ISessionCacheEntry
72 {
73 public:
74   ShibMySQLCCacheEntry(const char*, ISessionCacheEntry*, ShibMySQLCCache*);
75   ~ShibMySQLCCacheEntry() {}
76
77   virtual void lock() {}
78   virtual void unlock() { m_cacheEntry->unlock(); delete this; }
79   virtual bool isValid(time_t lifetime, time_t timeout) const;
80   virtual const char* getClientAddress() const { return m_cacheEntry->getClientAddress(); }
81   virtual const char* getSerializedStatement() const { return m_cacheEntry->getSerializedStatement(); }
82   virtual const SAMLAuthenticationStatement* getStatement() const { return m_cacheEntry->getStatement(); }
83   virtual Iterator<SAMLAssertion*> getAssertions() { return m_cacheEntry->getAssertions(); }
84   virtual void preFetch(int prefetch_window) { m_cacheEntry->preFetch(prefetch_window); }
85
86 private:
87   bool touch() const;
88
89   ShibMySQLCCache* m_cache;
90   ISessionCacheEntry* m_cacheEntry;
91   string m_key;
92 };
93
94 class ShibMySQLCCache : public ISessionCache
95 {
96 public:
97   ShibMySQLCCache(const DOMElement* e);
98   virtual ~ShibMySQLCCache();
99
100   virtual void thread_init();
101   virtual void thread_end() {}
102
103   virtual string generateKey() const {return m_cache->generateKey();}
104   virtual ISessionCacheEntry* find(const char* key, const IApplication* application);
105   virtual void insert(
106         const char* key,
107         const IApplication* application,
108         SAMLAuthenticationStatement *s,
109         const char *client_addr,
110         SAMLResponse* r=NULL,
111         const IRoleDescriptor* source=NULL);
112   virtual void remove(const char* key);
113
114   void  cleanup();
115   MYSQL* getMYSQL() const;
116
117   log4cpp::Category* log;
118
119 private:
120   ISessionCache* m_cache;
121   ThreadKey* m_mysql;
122   const DOMElement* m_root; // can only use this during initialization
123
124   static void*  cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
125   CondWait* shutdown_wait;
126   bool shutdown;
127   Thread* cleanup_thread;
128
129   bool initialized;
130
131   void createDatabase(MYSQL*, int major, int minor);
132   void upgradeDatabase(MYSQL*);
133   void getVersion(MYSQL*, int* major_p, int* minor_p);
134   void mysqlInit(void);
135 };
136
137 // Forward declarations
138 extern "C" void shib_mysql_destroy_handle(void* data);
139
140 /*************************************************************************
141  * The CCache here talks to a MySQL database.  The database stores
142  * three items: the cookie (session key index), the lastAccess time, and
143  * the SAMLAuthenticationStatement.  All other access is performed
144  * through the memory cache provided by shibboleth.
145  */
146
147 MYSQL* ShibMySQLCCache::getMYSQL() const
148 {
149   void* data = m_mysql->getData();
150   return (MYSQL*)data;
151 }
152
153 void ShibMySQLCCache::thread_init()
154 {
155   saml::NDC ndc("thread_init");
156
157   // Connect to the database
158   MYSQL* mysql = mysql_init(NULL);
159   if (!mysql) {
160     log->error("mysql_init failed");
161     mysql_close(mysql);
162     throw runtime_error("mysql_init()");
163   }
164
165   if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
166     if (initialized) {
167       log->crit("mysql_real_connect failed: %s", mysql_error(mysql));
168       throw runtime_error("mysql_real_connect");
169
170     } else {
171       log->info("mysql_real_connect failed: %s.  Trying to create",
172                 mysql_error(mysql));
173
174       // This will throw a runtime error if it fails.
175       createDatabase(mysql, PLUGIN_VER_MAJOR, PLUGIN_VER_MINOR);
176     }
177   }
178
179   int major = -1, minor = -1;
180   getVersion (mysql, &major, &minor);
181
182   // Make sure we've got the right version
183   if (major != PLUGIN_VER_MAJOR || minor != PLUGIN_VER_MINOR) {
184    
185     // If we're capable, try upgrading on the fly...
186     if (major == 0 && minor == 0) {
187        upgradeDatabase(mysql);
188     }
189     else {
190         log->crit("Invalid database version: %d.%d", major, minor);
191         throw runtime_error("Invalid Database version");
192     }
193   }
194
195   // We're all set.. Save off the handle for this thread.
196   m_mysql->setData((void*)mysql);
197 }
198
199 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e)
200 {
201   saml::NDC ndc("shibmysql::ShibMySQLCCache");
202
203   m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
204   log = &(log4cpp::Category::getInstance("shibmysql::ShibMySQLCCache"));
205
206   m_root=e;
207   initialized = false;
208   mysqlInit();
209   thread_init();
210   initialized = true;
211
212   m_cache = dynamic_cast<ISessionCache*>(
213       ShibConfig::getConfig().m_plugMgr.newPlugin(
214         "edu.internet2.middleware.shibboleth.target.provider.MemorySessionCache", e
215         )
216     );
217
218   // Initialize the cleanup thread
219   shutdown_wait = CondWait::create();
220   shutdown = false;
221   cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
222 }
223
224 ShibMySQLCCache::~ShibMySQLCCache()
225 {
226   shutdown = true;
227   shutdown_wait->signal();
228   cleanup_thread->join(NULL);
229
230   delete m_cache;
231   delete m_mysql;
232
233   // Shutdown MySQL
234   mysql_server_end();
235 }
236
237 ISessionCacheEntry* ShibMySQLCCache::find(const char* key, const IApplication* application)
238 {
239   saml::NDC ndc("mysql::find");
240   ISessionCacheEntry* res = m_cache->find(key,application);
241   if (!res) {
242
243     log->debug("Looking in database...");
244
245     // nothing cached; see if this exists in the database
246     string q = string("SELECT application_id,addr,statement FROM state WHERE cookie='") + key + "' LIMIT 1";
247
248     MYSQL_RES* rows;
249     MYSQL* mysql = getMYSQL();
250     if (mysql_query(mysql, q.c_str()))
251       log->error("Error searching for %s: %s", key, mysql_error(mysql));
252
253     rows = mysql_store_result(mysql);
254
255     // Nope, doesn't exist.
256     if (!rows)
257       return NULL;
258
259     // Make sure we got 1 and only 1 rows.
260     if (mysql_num_rows(rows) != 1) {
261       log->error("Select returned wrong number of rows: %d", mysql_num_rows(rows));
262       mysql_free_result(rows);
263       return NULL;
264     }
265
266     log->debug("Match found.  Parsing...");
267
268     // Pull apart the row and process the results
269     MYSQL_ROW row = mysql_fetch_row(rows);
270     IConfig* conf=ShibTargetConfig::getConfig().getINI();
271     Locker locker(conf);
272     const IApplication* application=conf->getApplication(row[0]);
273     if (!application) {
274         mysql_free_result(rows);
275         throw ShibTargetException(SHIBRPC_INTERNAL_ERROR,"unable to locate application for session, deleted?");
276     }
277     else if (strcmp(row[0],application->getId())) {
278         log->crit("An application (%s) attempted to access another application's (%s) session!", application->getId(), row[0]);
279         mysql_free_result(rows);
280         return NULL;
281     }
282
283     istringstream str(row[2]);
284     SAMLAuthenticationStatement *s = NULL;
285
286     // Try to parse the AuthStatement
287     try {
288       s = new SAMLAuthenticationStatement(str);
289     } catch (...) {
290       mysql_free_result(rows);
291       throw;
292     }
293
294     // Insert it into the memory cache
295     if (s)
296       m_cache->insert(key, application, s, row[1]);
297
298     // Free the results, and then re-run the 'find' query
299     mysql_free_result(rows);
300     res = m_cache->find(key,application);
301     if (!res)
302       return NULL;
303   }
304
305   return new ShibMySQLCCacheEntry(key, res, this);
306 }
307
308 void ShibMySQLCCache::insert(
309     const char* key,
310     const IApplication* application,
311     saml::SAMLAuthenticationStatement *s,
312     const char *client_addr,
313     saml::SAMLResponse* r,
314     const IRoleDescriptor* source)
315 {
316   saml::NDC ndc("mysql::insert");
317   ostringstream os;
318   os << *s;
319
320   string q = string("INSERT INTO state VALUES('") + key + "','" + application->getId() + "',NOW(),'" + client_addr + "','" + os.str() + "')";
321
322   log->debug("Query: %s", q.c_str());
323
324   // Add it to the memory cache
325   m_cache->insert(key, application, s, client_addr, r, source);
326
327   // then add it to the database
328   MYSQL* mysql = getMYSQL();
329   if (mysql_query(mysql, q.c_str()))
330     log->error("Error inserting %s: %s", key, mysql_error(mysql));
331 }
332
333 void ShibMySQLCCache::remove(const char* key)
334 {
335   saml::NDC ndc("mysql::remove");
336
337   // Remove the cached version
338   m_cache->remove(key);
339
340   // Remove from the database
341   string q = string("DELETE FROM state WHERE cookie='") + key + "'";
342   MYSQL* mysql = getMYSQL();
343   if (mysql_query(mysql, q.c_str()))
344     log->info("Error deleting entry %s: %s", key, mysql_error(mysql));
345 }
346
347 void ShibMySQLCCache::cleanup()
348 {
349   Mutex* mutex = Mutex::create();
350   saml::NDC ndc("mysql::cleanup");
351
352   thread_init();
353
354   int rerun_timer = 0;
355   int timeout_life = 0;
356
357   // Load our configuration details...
358   const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
359   if (tag && *tag)
360     rerun_timer = XMLString::parseInt(tag);
361
362   // search for 'mysql-cache-timeout' and then the regular cache timeout
363   tag=m_root->getAttributeNS(NULL,mysqlTimeout);
364   if (tag && *tag)
365     timeout_life = XMLString::parseInt(tag);
366   else {
367       tag=m_root->getAttributeNS(NULL,cacheTimeout);
368       if (tag && *tag)
369         timeout_life = XMLString::parseInt(tag);
370   }
371   
372   if (rerun_timer <= 0)
373     rerun_timer = 300;          // rerun every 5 minutes
374
375   if (timeout_life <= 0)
376     timeout_life = 28800;       // timeout after 8 hours
377
378   mutex->lock();
379
380   MYSQL* mysql = getMYSQL();
381
382   while (shutdown == false) {
383     shutdown_wait->timedwait(mutex, rerun_timer);
384
385     if (shutdown == true)
386       break;
387
388     // Find all the entries in the database that haven't been used
389     // recently In particular, find all entries that have not been
390     // accessed in 'timeout_life' seconds.
391     ostringstream q;
392     q << "SELECT cookie FROM state WHERE " <<
393       "UNIX_TIMESTAMP(NOW()) - UNIX_TIMESTAMP(atime) >= " << timeout_life;
394
395     MYSQL_RES *rows;
396     if (mysql_query(mysql, q.str().c_str()))
397       log->error("Error searching for old items: %s", mysql_error(mysql));
398
399     rows = mysql_store_result(mysql);
400     if (!rows)
401       continue;
402
403     if (mysql_num_fields(rows) != 1) {
404       log->error("Wrong number of rows, 1 != %d", mysql_num_fields(rows));
405       mysql_free_result(rows);
406       continue;
407     }
408
409     // For each row, remove the entry from the database.
410     MYSQL_ROW row;
411     while ((row = mysql_fetch_row(rows)) != NULL)
412       remove(row[0]);
413
414     mysql_free_result(rows);
415   }
416
417   log->debug("cleanup thread exiting...");
418
419   mutex->unlock();
420   delete mutex;
421   thread_end();
422   Thread::exit(NULL);
423 }
424
425 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
426 {
427   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
428
429   // First, let's block all signals
430   Thread::mask_all_signals();
431
432   // Now run the cleanup process.
433   cache->cleanup();
434   return NULL;
435 }
436
437 void ShibMySQLCCache::createDatabase(MYSQL* mysql, int major, int minor)
438 {
439   log->info("Creating database.");
440
441   MYSQL* ms = NULL;
442   try {
443     ms = mysql_init(NULL);
444     if (!ms) {
445       log->crit("mysql_init failed");
446       throw ShibTargetException();
447     }
448
449     if (!mysql_real_connect(ms, NULL, NULL, NULL, NULL, 0, NULL, 0)) {
450       log->crit("cannot open DB file to create DB: %s", mysql_error(ms));
451       throw ShibTargetException();
452     }
453
454     if (mysql_query(ms, "CREATE DATABASE shar")) {
455       log->crit("cannot create shar database: %s", mysql_error(ms));
456       throw ShibTargetException();
457     }
458
459     if (!mysql_real_connect(mysql, NULL, NULL, NULL, "shar", 0, NULL, 0)) {
460       log->crit("cannot open SHAR database");
461       throw ShibTargetException();
462     }
463
464     mysql_close(ms);
465     
466   } catch (ShibTargetException&) {
467     if (ms)
468       mysql_close(ms);
469     mysql_close(mysql);
470     throw runtime_error("mysql_real_connect");
471   }
472
473   // Now create the tables if they don't exist
474   log->info("Creating database tables.");
475
476   if (mysql_query(mysql, "CREATE TABLE version (major INT, minor INT)"))
477     log->error ("Error creating version: %s", mysql_error(mysql));
478
479   if (mysql_query(mysql,
480                   "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, application_id VARCHAR(255),"
481                   "atime DATETIME, addr VARCHAR(128), statement TEXT)"))
482     log->error ("Error creating state: %s", mysql_error(mysql));
483
484   ostringstream q;
485   q << "INSERT INTO version VALUES(" << major << "," << minor << ")";
486   if (mysql_query(mysql, q.str().c_str()))
487     log->error ("Error setting version: %s", mysql_error(mysql));
488 }
489
490 void ShibMySQLCCache::upgradeDatabase(MYSQL* mysql)
491 {
492     if (mysql_query(mysql, "DROP TABLE state")) {
493         log->error("Error dropping old session state table: %s", mysql_error(mysql));
494     }
495
496     if (mysql_query(mysql,
497         "CREATE TABLE state (cookie VARCHAR(64) PRIMARY KEY, application_id VARCHAR(255),"
498        "atime DATETIME, addr VARCHAR(128), statement TEXT)")) {
499         log->error ("Error creating state table: %s", mysql_error(mysql));
500         throw runtime_error("error creating table");
501     }
502
503     ostringstream q;
504     q << "UPDATE version SET major = " << PLUGIN_VER_MAJOR;
505     if (mysql_query(mysql, q.str().c_str())) {
506         log->error ("Error updating version: %s", mysql_error(mysql));
507         throw runtime_error("error updating table");
508     }
509 }
510
511 void ShibMySQLCCache::getVersion(MYSQL* mysql, int* major_p, int* minor_p)
512 {
513   // grab the version number from the database
514   if (mysql_query(mysql, "SELECT * FROM version"))
515     log->error ("Error reading version: %s", mysql_error(mysql));
516
517   MYSQL_RES* rows = mysql_store_result(mysql);
518   if (rows) {
519     if (mysql_num_rows(rows) == 1 && mysql_num_fields(rows) == 2)  {
520       MYSQL_ROW row = mysql_fetch_row(rows);
521
522       int major = row[0] ? atoi(row[0]) : -1;
523       int minor = row[1] ? atoi(row[1]) : -1;
524       log->debug("opening database version %d.%d", major, minor);
525       
526       mysql_free_result (rows);
527
528       *major_p = major;
529       *minor_p = minor;
530       return;
531
532     } else {
533       // Wrong number of rows or wrong number of fields...
534
535       log->crit("Houston, we've got a problem with the database..");
536       mysql_free_result (rows);
537       throw runtime_error("Database version verification failed");
538     }
539   }
540   log->crit("MySQL Read Failed in version verificatoin");
541   throw runtime_error("MySQL Read Failed");
542 }
543
544 void ShibMySQLCCache::mysqlInit(void)
545 {
546   log->info ("Opening MySQL Database");
547
548   // Setup the argument array
549   vector<string> arg_array;
550   arg_array.push_back("shar");
551
552   // grab any MySQL parameters from the config file
553   const DOMElement* e=saml::XML::getFirstChildElement(m_root,ShibTargetConfig::SHIBTARGET_NS,Argument);
554   while (e) {
555       auto_ptr_char arg(e->getFirstChild()->getNodeValue());
556       if (arg.get())
557           arg_array.push_back(arg.get());
558       e=saml::XML::getNextSiblingElement(e,ShibTargetConfig::SHIBTARGET_NS,Argument);
559   }
560
561   // Compute the argument array
562   int arg_count = arg_array.size();
563   const char** args=new const char*[arg_count];
564   for (int i = 0; i < arg_count; i++)
565     args[i] = arg_array[i].c_str();
566
567   // Initialize MySQL with the arguments
568   mysql_server_init(arg_count, (char **)args, NULL);
569
570   delete[] args;
571 }  
572
573 /*************************************************************************
574  * The CCacheEntry here is mostly a wrapper around the "memory"
575  * cacheentry provided by shibboleth.  The only difference is that we
576  * intercept the isSessionValid() so that we can "touch()" the
577  * database if the session is still valid.
578  */
579
580 ShibMySQLCCacheEntry::ShibMySQLCCacheEntry(const char* key, ISessionCacheEntry* entry, ShibMySQLCCache* cache)
581 {
582   m_cacheEntry = entry;
583   m_key = key;
584   m_cache = cache;
585 }
586
587 bool ShibMySQLCCacheEntry::isValid(time_t lifetime, time_t timeout) const
588 {
589   bool res = m_cacheEntry->isValid(lifetime, timeout);
590   if (res == true)
591     res = touch();
592   return res;
593 }
594
595 bool ShibMySQLCCacheEntry::touch() const
596 {
597   string q=string("UPDATE state SET atime=NOW() WHERE cookie='") + m_key + "'";
598
599   MYSQL* mysql = m_cache->getMYSQL();
600   if (mysql_query(mysql, q.c_str())) {
601     m_cache->log->info("Error updating timestamp on %s: %s",
602                         m_key.c_str(), mysql_error(mysql));
603     return false;
604   }
605   return true;
606 }
607
608 /*************************************************************************
609  * The registration functions here...
610  */
611
612 IPlugIn* new_mysql_ccache(const DOMElement* e)
613 {
614   return new ShibMySQLCCache(e);
615 }
616
617 #define PLUGINTYPE "edu.internet2.middleware.shibboleth.target.provider.MySQLSessionCache"
618
619 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
620 {
621   // register this ccache type
622   ShibConfig::getConfig().m_plugMgr.regFactory(PLUGINTYPE, &new_mysql_ccache);
623   return 0;
624 }
625
626 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
627 {
628   ShibConfig::getConfig().m_plugMgr.unregFactory(PLUGINTYPE);
629 }
630
631 /*************************************************************************
632  * Local Functions
633  */
634
635 extern "C" void shib_mysql_destroy_handle(void* data)
636 {
637   MYSQL* mysql = (MYSQL*) data;
638   mysql_close(mysql);
639 }