Next integration phase, metadata and trust conversion.
[shibboleth/sp.git] / shib-mysql-ccache / shib-mysql-ccache.cpp
index 74ac433..e8f8321 100644 (file)
@@ -39,8 +39,9 @@
 # include <unistd.h>
 #endif
 
-#include <shib/shib-threads.h>
 #include <shib-target/shib-target.h>
+
+#include <xmltooling/util/NDC.h>
 #include <log4cpp/Category.hh>
 
 #include <sstream>
 #include <dmalloc.h>
 #endif
 
-using namespace std;
-using namespace saml;
-using namespace shibboleth;
 using namespace shibtarget;
+using namespace opensaml::saml2md;
+using namespace saml;
 using namespace log4cpp;
+using namespace std;
 
 #define PLUGIN_VER_MAJOR 3
 #define PLUGIN_VER_MINOR 0
@@ -98,6 +99,8 @@ static const XMLCh mysqlTimeout[] =
 static const XMLCh storeAttributes[] =
 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
 
+static bool g_MySQLInitialized = false;
+
 class MySQLBase : public virtual saml::IPlugIn
 {
 public:
@@ -110,10 +113,11 @@ public:
   log4cpp::Category* log;
 
 protected:
-  ThreadKey* m_mysql;
+    xmltooling::ThreadKey* m_mysql;
   const DOMElement* m_root; // can only use this during initialization
 
   bool initialized;
+  bool handleShutdown;
 
   void createDatabase(MYSQL*, int major, int minor);
   void upgradeDatabase(MYSQL*);
@@ -132,11 +136,11 @@ extern "C" void shib_mysql_destroy_handle(void* data)
 MySQLBase::MySQLBase(const DOMElement* e) : m_root(e)
 {
 #ifdef _DEBUG
-  saml::NDC ndc("MySQLBase");
+  xmltooling::NDC ndc("MySQLBase");
 #endif
-  log = &(Category::getInstance("shibmysql.MySQLBase"));
+  log = &(Category::getInstance("shibtarget.SessionCache.MySQL"));
 
-  m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
+  m_mysql = xmltooling::ThreadKey::create(&shib_mysql_destroy_handle);
 
   initialized = false;
   mysqlInit(e,*log);
@@ -152,7 +156,7 @@ MySQLBase::~MySQLBase()
 MYSQL* MySQLBase::getMYSQL()
 {
 #ifdef _DEBUG
-    saml::NDC ndc("getMYSQL");
+    xmltooling::NDC ndc("getMYSQL");
 #endif
 
     // Do we already have a handle?
@@ -311,8 +315,10 @@ void MySQLBase::upgradeDatabase(MYSQL* mysql)
 pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
 {
     // grab the version number from the database
-    if (mysql_query(mysql, "SELECT * FROM version"))
-        log->error ("Error reading version: %s", mysql_error(mysql));
+    if (mysql_query(mysql, "SELECT * FROM version")) {
+        log->error("error reading version: %s", mysql_error(mysql));
+        throw SAMLException("MySQLBase::getVersion(): error reading version");
+    }
 
     MYSQL_RES* rows = mysql_store_result(mysql);
     if (rows) {
@@ -323,21 +329,21 @@ pair<int,int> MySQLBase::getVersion(MYSQL* mysql)
           log->debug("opening database version %d.%d", major, minor);
           mysql_free_result(rows);
           return make_pair(major,minor);
-        } else {
+        }
+        else {
             // Wrong number of rows or wrong number of fields...
             log->crit("Houston, we've got a problem with the database...");
-            mysql_free_result (rows);
-            throw SAMLException("ShibMySQLCCache::getVersion(): version verification failed");
+            mysql_free_result(rows);
+            throw SAMLException("MySQLBase::getVersion(): version verification failed");
         }
     }
     log->crit("MySQL Read Failed in version verification");
-    throw SAMLException("ShibMySQLCCache::getVersion(): error reading version");
+    throw SAMLException("MySQLBase::getVersion(): error reading version");
 }
 
 static void mysqlInit(const DOMElement* e, Category& log)
 {
-    static bool done = false;
-    if (done) {
+    if (g_MySQLInitialized) {
         log.info("MySQL embedded server already initialized");
         return;
     }
@@ -366,7 +372,7 @@ static void mysqlInit(const DOMElement* e, Category& log)
     mysql_server_init(arg_count, (char **)args, NULL);
 
     delete[] args;
-    done = true;
+    g_MySQLInitialized = true;
 }  
 
 class ShibMySQLCCache : public MySQLBase, virtual public ISessionCache, virtual public ISessionCacheStore
@@ -378,13 +384,13 @@ public:
     // Delegate all the ISessionCache methods.
     string insert(
         const IApplication* application,
-        const IEntityDescriptor* source,
+        const RoleDescriptor* role,
         const char* client_addr,
         const SAMLSubject* subject,
         const char* authnContext,
         const SAMLResponse* tokens
         )
-    { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
+    { return m_cache->insert(application,role,client_addr,subject,authnContext,tokens); }
     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
     { return m_cache->find(key,application,client_addr); }
     void remove(const char* key, const IApplication* application, const char* client_addr)
@@ -414,7 +420,9 @@ public:
         time_t& created,
         time_t& accessed
         );
-    HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t lastAccess=0);
+    HRESULT onRead(const char* key, time_t& accessed);
+    HRESULT onRead(const char* key, string& tokens);
+    HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
     HRESULT onDelete(const char* key);
 
     void cleanup();
@@ -422,9 +430,9 @@ public:
 private:
     bool m_storeAttributes;
     ISessionCache* m_cache;
-    CondWait* shutdown_wait;
+    xmltooling::CondWait* shutdown_wait;
     bool shutdown;
-    Thread* cleanup_thread;
+    xmltooling::Thread* cleanup_thread;
 
     static void* cleanup_fcn(void*); // XXX Assumed an ShibMySQLCCache
 };
@@ -432,9 +440,8 @@ private:
 ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAttributes(false)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("ShibMySQLCCache");
+    xmltooling::NDC ndc("ShibMySQLCCache");
 #endif
-    log = &(Category::getInstance("shibmysql.SessionCache"));
 
     m_cache = dynamic_cast<ISessionCache*>(
         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, e)
@@ -444,7 +451,7 @@ ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAtt
         throw SAMLException("Unable to register MySQL cache plugin as a cache store.");
     }
     
-    shutdown_wait = CondWait::create();
+    shutdown_wait = xmltooling::CondWait::create();
     shutdown = false;
 
     // Load our configuration details...
@@ -453,7 +460,7 @@ ShibMySQLCCache::ShibMySQLCCache(const DOMElement* e) : MySQLBase(e), m_storeAtt
         m_storeAttributes=true;
 
     // Initialize the cleanup thread
-    cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
+    cleanup_thread = xmltooling::Thread::create(&cleanup_fcn, (void*)this);
 }
 
 ShibMySQLCCache::~ShibMySQLCCache()
@@ -474,7 +481,7 @@ HRESULT ShibMySQLCCache::onCreate(
     )
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onCreate");
+    xmltooling::NDC ndc("onCreate");
 #endif
 
     // Get XML data from entry. Default is not to return SAML objects.
@@ -533,7 +540,7 @@ HRESULT ShibMySQLCCache::onRead(
     )
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     log->debug("searching MySQL database...");
@@ -553,17 +560,21 @@ HRESULT ShibMySQLCCache::onRead(
     MYSQL_RES* rows = mysql_store_result(mysql);
 
     // Nope, doesn't exist.
-    if (!rows)
+    if (!rows || mysql_num_rows(rows)==0) {
+        log->debug("not found in database");
+        if (rows)
+            mysql_free_result(rows);
         return S_FALSE;
+    }
 
-    // Make sure we got 1 and only 1 rows.
-    if (mysql_num_rows(rows) != 1) {
-        log->error("Database select returned wrong number of rows: %d", mysql_num_rows(rows));
+    // Make sure we got 1 and only 1 row.
+    if (mysql_num_rows(rows) > 1) {
+        log->error("database select returned %d rows!", mysql_num_rows(rows));
         mysql_free_result(rows);
-        return S_FALSE;
+        return E_FAIL;
     }
 
-    log->debug("match found, tranfering data back into memory...");
+    log->debug("session found, tranfering data back into memory");
     
     /* Columns in query:
         0: application_id
@@ -597,22 +608,120 @@ HRESULT ShibMySQLCCache::onRead(
     return NOERROR;
 }
 
+HRESULT ShibMySQLCCache::onRead(const char* key, time_t& accessed)
+{
+#ifdef _DEBUG
+    xmltooling::NDC ndc("onRead");
+#endif
+
+    log->debug("reading last access time from MySQL database");
+
+    string q = string("SELECT UNIX_TIMESTAMP(atime) FROM state WHERE cookie='") + key + "' LIMIT 1";
+
+    MYSQL* mysql = getMYSQL();
+    if (mysql_query(mysql, q.c_str())) {
+        const char* err=mysql_error(mysql);
+        log->error("error searching for %s: %s", key, err);
+        if (isCorrupt(err) && repairTable(mysql,"state")) {
+            if (mysql_query(mysql, q.c_str()))
+                log->error("error retrying search for %s: %s", key, mysql_error(mysql));
+        }
+    }
+
+    MYSQL_RES* rows = mysql_store_result(mysql);
+
+    // Nope, doesn't exist.
+    if (!rows || mysql_num_rows(rows)==0) {
+        log->warn("session expected, but not found in database");
+        if (rows)
+            mysql_free_result(rows);
+        return S_FALSE;
+    }
+
+    // Make sure we got 1 and only 1 row.
+    if (mysql_num_rows(rows) != 1) {
+        log->error("database select returned %d rows!", mysql_num_rows(rows));
+        mysql_free_result(rows);
+        return E_FAIL;
+    }
+
+    MYSQL_ROW row = mysql_fetch_row(rows);
+    accessed=atoi(row[0]);
+
+    // Free the results.
+    mysql_free_result(rows);
+
+    return NOERROR;
+}
+
+HRESULT ShibMySQLCCache::onRead(const char* key, string& tokens)
+{
+#ifdef _DEBUG
+    xmltooling::NDC ndc("onRead");
+#endif
+
+    if (!m_storeAttributes)
+        return S_FALSE;
+
+    log->debug("reading cached tokens from MySQL database");
+
+    string q = string("SELECT tokens FROM state WHERE cookie='") + key + "' LIMIT 1";
+
+    MYSQL* mysql = getMYSQL();
+    if (mysql_query(mysql, q.c_str())) {
+        const char* err=mysql_error(mysql);
+        log->error("error searching for %s: %s", key, err);
+        if (isCorrupt(err) && repairTable(mysql,"state")) {
+            if (mysql_query(mysql, q.c_str()))
+                log->error("error retrying search for %s: %s", key, mysql_error(mysql));
+        }
+    }
+
+    MYSQL_RES* rows = mysql_store_result(mysql);
+
+    // Nope, doesn't exist.
+    if (!rows || mysql_num_rows(rows)==0) {
+        log->warn("session expected, but not found in database");
+        if (rows)
+            mysql_free_result(rows);
+        return S_FALSE;
+    }
+
+    // Make sure we got 1 and only 1 row.
+    if (mysql_num_rows(rows) != 1) {
+        log->error("database select returned %d rows!", mysql_num_rows(rows));
+        mysql_free_result(rows);
+        return E_FAIL;
+    }
+
+    MYSQL_ROW row = mysql_fetch_row(rows);
+    if (row[0])
+        tokens=row[0];
+
+    // Free the results.
+    mysql_free_result(rows);
+
+    return NOERROR;
+}
+
 HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onUpdate");
+    xmltooling::NDC ndc("onUpdate");
 #endif
 
     ostringstream q;
-    if (tokens) {
+    if (lastAccess>0)
+        q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
+    else if (tokens) {
+        if (!m_storeAttributes)
+            return S_FALSE;
         q << "UPDATE state SET tokens=";
         if (*tokens)
             q << "'" << tokens << "'";
         else
             q << "null";
     }
-    else if (lastAccess>0)
-        q << "UPDATE state SET atime=FROM_UNIXTIME(" << lastAccess << ")";
     else {
         log->warn("onUpdate called with nothing to do!");
         return S_FALSE;
@@ -641,7 +750,7 @@ HRESULT ShibMySQLCCache::onUpdate(const char* key, const char* tokens, time_t la
 HRESULT ShibMySQLCCache::onDelete(const char* key)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onDelete");
+    xmltooling::NDC ndc("onDelete");
 #endif
 
     // Remove from the database
@@ -667,10 +776,10 @@ HRESULT ShibMySQLCCache::onDelete(const char* key)
 void ShibMySQLCCache::cleanup()
 {
 #ifdef _DEBUG
-  saml::NDC ndc("cleanup");
+  xmltooling::NDC ndc("cleanup");
 #endif
 
-  Mutex* mutex = Mutex::create();
+  xmltooling::Mutex* mutex = xmltooling::Mutex::create();
 
   int rerun_timer = 0;
   int timeout_life = 0;
@@ -728,15 +837,17 @@ void ShibMySQLCCache::cleanup()
 
   mutex->unlock();
   delete mutex;
-  Thread::exit(NULL);
+  xmltooling::Thread::exit(NULL);
 }
 
 void* ShibMySQLCCache::cleanup_fcn(void* cache_p)
 {
   ShibMySQLCCache* cache = (ShibMySQLCCache*)cache_p;
 
-  // First, let's block all signals
-  Thread::mask_all_signals();
+#ifndef WIN32
+  // First, let'block all signals
+  xmltooling::Thread::mask_all_signals();
+#endif
 
   // Now run the cleanup process.
   cache->cleanup();
@@ -753,19 +864,12 @@ public:
   bool check(const char* str, time_t expires);
 };
 
-MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
-{
-#ifdef _DEBUG
-  saml::NDC ndc("MySQLReplayCache");
-#endif
-
-  log = &(Category::getInstance("shibmysql.ReplayCache"));
-}
+MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e) {}
 
 bool MySQLReplayCache::check(const char* str, time_t expires)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("check");
+    xmltooling::NDC ndc("check");
 #endif
   
     // Remove expired entries
@@ -848,7 +952,8 @@ extern "C" int SHIBMYSQL_EXPORTS saml_extension_init(void*)
 extern "C" void SHIBMYSQL_EXPORTS saml_extension_term()
 {
     // Shutdown MySQL
-    mysql_server_end();
+    if (g_MySQLInitialized)
+        mysql_server_end();
     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLReplayCacheType);
     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::MySQLSessionCacheType);
 }