Starting to refactor session cache, eliminated IConfig class.
[shibboleth/cpp-sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
index d944ff8..c6bccf4 100644 (file)
 # include <unistd.h>
 #endif
 
-#include <shib/shib-threads.h>
 #include <shib-target/shib-target.h>
+
 #include <log4cpp/Category.hh>
+#include <xmltooling/util/NDC.h>
+#include <xmltooling/util/Threads.h>
+#include <xmltooling/util/XMLHelper.h>
+#include <shibsp/SPConfig.h>
+using xmltooling::XMLHelper;
 
 #include <sstream>
 
 #include <dmalloc.h>
 #endif
 
-using namespace std;
-using namespace saml;
-using namespace shibboleth;
+using namespace shibsp;
 using namespace shibtarget;
+using namespace opensaml::saml2md;
+using namespace saml;
 using namespace log4cpp;
+using namespace std;
 
 #define PLUGIN_VER_MAJOR 3
 #define PLUGIN_VER_MINOR 0
@@ -98,6 +104,8 @@ static const XMLCh mysqlTimeout[] =
 static const XMLCh storeAttributes[] =
 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
 
+static bool g_MySQLInitialized = false;
+
 class MySQLBase : public virtual saml::IPlugIn
 {
 public:
@@ -110,10 +118,11 @@ public:
   log4cpp::Category* log;
 
 protected:
-  ThreadKey* m_mysql;
+    xmltooling::ThreadKey* m_mysql;
   const DOMElement* m_root; // can only use this during initialization
 
   bool initialized;
+  bool handleShutdown;
 
   void createDatabase(MYSQL*, int major, int minor);
   void upgradeDatabase(MYSQL*);
@@ -132,11 +141,11 @@ extern "C" void shib_mysql_destroy_handle(void* data)
 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
 {
 #ifdef _DEBUG
-  saml::NDC ndc("MySQLBase");
+  xmltooling::NDC ndc("MySQLBase");
 #endif
-  log = &(Category::getInstance("shibmysql.MySQLBase"));
+  log = &(Category::getInstance("shibtarget.SessionCache.MySQL"));
 
-  m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
+  m_mysql = xmltooling::ThreadKey::create(&shib_mysql_destroy_handle);
 
   initialized = false;
   mysqlInit(e,*log);
@@ -152,7 +161,7 @@ MySQLBase::~MySQLBase()
 MYSQL* MySQLBase::getMYSQL()
 {
 #ifdef _DEBUG
-    saml::NDC ndc("getMYSQL");
+    xmltooling::NDC ndc("getMYSQL");
 #endif
 
     // Do we already have a handle?
@@ -311,8 +320,10 @@ void MySQLBase::upgradeDatabase(MYSQL* mysql)
 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
 {
     // grab the version number from the database
-    if (mysql_query(mysql, "SELECT * FROM version"))
-        log->error ("Error reading version: %s", mysql_error(mysql));
+    if (mysql_query(mysql, "SELECT * FROM version")) {
+        log->error("error reading version: %s", mysql_error(mysql));
+        throw SAMLException("MySQLBase::getVersion(): error reading version");
+    }
 
     MYSQL_RES* rows = mysql_store_result(mysql);
     if (rows) {
@@ -323,21 +334,21 @@ pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
           log->debug("opening database version %d.%d", major, minor);
           mysql_free_result(rows);
           return make_pair(major,minor);
-        } else {
+        }
+        else {
             // Wrong number of rows or wrong number of fields...
             log->crit("Houston, we've got a problem with the database...");
-            mysql_free_result (rows);
-            throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
+            mysql_free_result(rows);
+            throw SAMLException("MySQLBase::getVersion(): version verification failed");
         }
     }
     log->crit("MySQL Read Failed in version verification");
-    throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
+    throw SAMLException("MySQLBase::getVersion(): error reading version");
 }
 
 static void mysqlInit(const DOMElement* e, Category& log)
 {
-    static bool done = false;
-    if (done) {
+    if (g_MySQLInitialized) {
         log.info("MySQL embedded server already initialized");
         return;
     }
@@ -348,12 +359,12 @@ static void mysqlInit(const DOMElement* e, Category& log)
     arg_array.push_back("shibboleth");
 
     // grab any MySQL parameters from the config file
-    e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
+    e=XMLHelper::getFirstChildElement(e,Argument);
     while (e) {
         auto_ptr_char arg(e->getFirstChild()->getNodeValue());
         if (arg.get())
             arg_array.push_back(arg.get());
-        e=saml::XML::getNextSiblingElement(e,shibtarget::XML::SHIBTARGET_NS,Argument);
+        e=XMLHelper::getNextSiblingElement(e,Argument);
     }
 
     // Compute the argument array
@@ -366,7 +377,7 @@ static void mysqlInit(const DOMElement* e, Category& log)
     mysql_server_init(arg_count, (char **)args, NULL);
 
     delete[] args;
-    done = true;
+    g_MySQLInitialized = true;
 }  
 
 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
@@ -378,13 +389,13 @@ public:
     // Delegate all the ISessionCache methods.
     string insert(
         const IApplication* application,
-        const IEntityDescriptor* source,
+        const RoleDescriptor* role,
         const char* client_addr,
         const SAMLSubject* subject,
         const char* authnContext,
         const SAMLResponse* tokens
         )
-    { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
+    { return m_cache->insert(application,role,client_addr,subject,authnContext,tokens); }
     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
     { return m_cache->find(key,application,client_addr); }
     void remove(const char* key, const IApplication* application, const char* client_addr)
@@ -424,9 +435,9 @@ public:
 private:
     bool m_storeAttributes;
     ISessionCache* m_cache;
-    CondWait* shutdown_wait;
+    xmltooling::CondWait* shutdown_wait;
     bool shutdown;
-    Thread* cleanup_thread;
+    xmltooling::Thread* cleanup_thread;
 
     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
 };
@@ -434,19 +445,18 @@ private:
 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("ShibMySQLCCache");
+    xmltooling::NDC ndc("ShibMySQLCCache");
 #endif
-    log = &(Category::getInstance("shibmysql.SessionCache"));
 
     m_cache = dynamic_cast<ISessionCache*>(
-        SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
+        SAMLConfig::getConfig().getPlugMgr().newPlugin(MEMORY_SESSIONCACHE, e)
     );
     if (!m_cache->setBackingStore(this)) {
         delete m_cache;
         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
     }
     
-    shutdown_wait = CondWait::create();
+    shutdown_wait = xmltooling::CondWait::create();
     shutdown = false;
 
     // Load our configuration details...
@@ -455,7 +465,7 @@ ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAtt
         m_storeAttributes=true;
 
     // Initialize the cleanup thread
-    cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
+    cleanup_thread = xmltooling::Thread::create(&cleanup_fcn, (void*)this);
 }
 
 ShibMySQLCCache::~ShibMySQLCCache()
@@ -476,7 +486,7 @@ HRESULT ShibMySQLCCache::onCreate(
     )
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onCreate");
+    xmltooling::NDC ndc("onCreate");
 #endif
 
     // Get XML data from entry. Default is not to return SAML objects.
@@ -535,7 +545,7 @@ HRESULT ShibMySQLCCache::onRead(
     )
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     log->debug("searching MySQL database...");
@@ -555,17 +565,21 @@ HRESULT ShibMySQLCCache::onRead(
     MYSQL_RES* rows = mysql_store_result(mysql);
 
     // Nope, doesn't exist.
-    if (!rows)
+    if (!rows || mysql_num_rows(rows)==0) {
+        log->debug("not found in database");
+        if (rows)
+            mysql_free_result(rows);
         return S_FALSE;
+    }
 
-    // Make sure we got 1 and only 1 rows.
-    if (mysql_num_rows(rows) != 1) {
-        log->error("Database select returned wrong number of rows: %d", mysql_num_rows(rows));
+    // Make sure we got 1 and only 1 row.
+    if (mysql_num_rows(rows) > 1) {
+        log->error("database select returned %d rows!", mysql_num_rows(rows));
         mysql_free_result(rows);
-        return S_FALSE;
+        return E_FAIL;
     }
 
-    log->debug("match found, tranfering data back into memory");
+    log->debug("session found, tranfering data back into memory");
     
     /* Columns in query:
         0: application_id
@@ -602,7 +616,7 @@ HRESULT ShibMySQLCCache::onRead(
 HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     log->debug("reading last access time from MySQL database");
@@ -622,14 +636,18 @@ HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
     MYSQL_RES* rows = mysql_store_result(mysql);
 
     // Nope, doesn't exist.
-    if (!rows)
+    if (!rows || mysql_num_rows(rows)==0) {
+        log->warn("session expected, but not found in database");
+        if (rows)
+            mysql_free_result(rows);
         return S_FALSE;
+    }
 
-    // Make sure we got 1 and only 1 rows.
+    // Make sure we got 1 and only 1 row.
     if (mysql_num_rows(rows) != 1) {
-        log->error("database select returned wrong number of rows: %d", mysql_num_rows(rows));
+        log->error("database select returned %d rows!", mysql_num_rows(rows));
         mysql_free_result(rows);
-        return S_FALSE;
+        return E_FAIL;
     }
 
     MYSQL_ROW row = mysql_fetch_row(rows);
@@ -644,7 +662,7 @@ HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
 HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     if (!m_storeAttributes)
@@ -667,14 +685,18 @@ HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
     MYSQL_RES* rows = mysql_store_result(mysql);
 
     // Nope, doesn't exist.
-    if (!rows)
+    if (!rows || mysql_num_rows(rows)==0) {
+        log->warn("session expected, but not found in database");
+        if (rows)
+            mysql_free_result(rows);
         return S_FALSE;
+    }
 
-    // Make sure we got 1 and only 1 rows.
+    // Make sure we got 1 and only 1 row.
     if (mysql_num_rows(rows) != 1) {
-        log->error("database select returned wrong number of rows: %d", mysql_num_rows(rows));
+        log->error("database select returned %d rows!", mysql_num_rows(rows));
         mysql_free_result(rows);
-        return S_FALSE;
+        return E_FAIL;
     }
 
     MYSQL_ROW row = mysql_fetch_row(rows);
@@ -690,7 +712,7 @@ HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onUpdate");
+    xmltooling::NDC ndc("onUpdate");
 #endif
 
     ostringstream q;
@@ -733,7 +755,7 @@ HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t la
 HRESULT ShibMySQLCCache::onDelete(const char* key)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onDelete");
+    xmltooling::NDC ndc("onDelete");
 #endif
 
     // Remove from the database
@@ -759,10 +781,10 @@ HRESULT ShibMySQLCCache::onDelete(const char* key)
 void ShibMySQLCCache::cleanup()
 {
 #ifdef _DEBUG
-  saml::NDC ndc("cleanup");
+  xmltooling::NDC ndc("cleanup");
 #endif
 
-  Mutex* mutex = Mutex::create();
+  xmltooling::Mutex* mutex = xmltooling::Mutex::create();
 
   int rerun_timer = 0;
   int timeout_life = 0;
@@ -820,15 +842,17 @@ void ShibMySQLCCache::cleanup()
 
   mutex->unlock();
   delete mutex;
-  Thread::exit(NULL);
+  xmltooling::Thread::exit(NULL);
 }
 
 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
 {
   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
 
-  // First, let's block all signals
-  Thread::mask_all_signals();
+#ifndef WIN32
+  // First, let'block all signals
+  xmltooling::Thread::mask_all_signals();
+#endif
 
   // Now run the cleanup process.
   cache->cleanup();
@@ -845,19 +869,12 @@ public:
   bool check(const char* str, time_t expires);
 };
 
-MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
-{
-#ifdef _DEBUG
-  saml::NDC ndc("MySQLReplayCache");
-#endif
-
-  log = &(Category::getInstance("shibmysql.ReplayCache"));
-}
+MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e) {}
 
 bool MySQLReplayCache::check(const char* str, time_t expires)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("check");
+    xmltooling::NDC ndc("check");
 #endif
   
     // Remove expired entries
@@ -919,7 +936,7 @@ bool MySQLReplayCache::check(const char* str, time_t expires)
  * The registration functions here...
  */
 
-IPlugIn* new_mysql_ccache(const DOMElement* e)
+SessionCache* new_mysql_ccache(const DOMElement* const & e)
 {
     return new ShibMySQLCCache(e);
 }
@@ -932,15 +949,15 @@ IPlugIn* new_mysql_replay(const DOMElement* e)
 extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
 {
     // register this ccache type
-    SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLReplayCacheType, &new_mysql_replay);
-    SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::MySQLSessionCacheType, &new_mysql_ccache);
+    SAMLConfig::getConfig().getPlugMgr().regFactory(MYSQL_REPLAYCACHE, &new_mysql_replay);
+    SPConfig::getConfig().SessionCacheManager.registerFactory(MYSQL_SESSIONCACHE, &new_mysql_ccache);
     return 0;
 }
 
 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
 {
     // Shutdown MySQL
-    mysql_server_end();
-    SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
-    SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
+    if (g_MySQLInitialized)
+        mysql_server_end();
+    SAMLConfig::getConfig().getPlugMgr().unregFactory(MYSQL_REPLAYCACHE);
 }