Move Shib constants to new lib, fixed symbol conflicts.
[shibboleth/cpp-sp.git] / odbc_ccache / odbc-ccache.cpp
index 88d84d5..350bb27 100644 (file)
 #ifdef WIN32
 # define _CRT_NONSTDC_NO_DEPRECATE 1
 # define _CRT_SECURE_NO_DEPRECATE 1
+# define NOMINMAX
 # define SHIBODBC_EXPORTS __declspec(dllexport)
 #else
 # define SHIBODBC_EXPORTS
 #endif
 
-#include <shib/shib-threads.h>
 #include <shib-target/shib-target.h>
 #include <log4cpp/Category.hh>
+#include <xmltooling/util/NDC.h>
 
+#include <ctime>
+#include <algorithm>
 #include <sstream>
 
 #include <sql.h>
 #include <dmalloc.h>
 #endif
 
-using namespace std;
-using namespace saml;
-using namespace shibboleth;
 using namespace shibtarget;
+using namespace shibboleth;
+using namespace saml;
+using namespace xmltooling;
 using namespace log4cpp;
+using namespace std;
 
 #define PLUGIN_VER_MAJOR 3
 #define PLUGIN_VER_MINOR 0
@@ -63,7 +67,7 @@ using namespace log4cpp;
 #define COLSIZE_APPLICATION_ID 256
 #define COLSIZE_ADDRESS 128
 #define COLSIZE_PROVIDER_ID 256
-#define LONGDATA_BUFLEN 2048
+#define LONGDATA_BUFLEN 32768
 
 /*
   CREATE TABLE state (
@@ -116,10 +120,9 @@ public:
 
     SQLHDBC getHDBC();
 
-    log4cpp::Category* log;
+    Category* log;
 
 protected:
-    //ThreadKey* m_mysql;
     const DOMElement* m_root; // can only use this during initialization
     string m_connstring;
 
@@ -137,7 +140,7 @@ const char* ODBCBase::p_connstring = NULL;
 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("ODBCBase");
+    xmltooling::NDC ndc("ODBCBase");
 #endif
     log = &(Category::getInstance("shibtarget.ODBC"));
 
@@ -166,7 +169,7 @@ ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
         m_connstring=p_connstring;
     }
     else {
-        auto_ptr_char arg(e->getFirstChild()->getNodeValue());
+        xmltooling::auto_ptr_char arg(e->getFirstChild()->getNodeValue());
         m_connstring=arg.get();
         p_connstring=m_connstring.c_str();
     }
@@ -213,7 +216,7 @@ void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
 SQLHDBC ODBCBase::getHDBC()
 {
 #ifdef _DEBUG
-    saml::NDC ndc("getMYSQL");
+    xmltooling::NDC ndc("getMYSQL");
 #endif
 
     // Get a handle.
@@ -318,9 +321,9 @@ public:
 private:
     bool m_storeAttributes;
     ISessionCache* m_cache;
-    CondWait* shutdown_wait;
+    xmltooling::CondWait* shutdown_wait;
     bool shutdown;
-    Thread* cleanup_thread;
+    xmltooling::Thread* cleanup_thread;
 
     static void* cleanup_fcn(void*); // XXX Assumed an ODBCCCache
 };
@@ -328,7 +331,7 @@ private:
 ODBCCCache::ODBCCCache(const DOMElement* e) : ODBCBase(e), m_storeAttributes(false)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("ODBCCCache");
+    xmltooling::NDC ndc("ODBCCCache");
 #endif
     log = &(Category::getInstance("shibtarget.SessionCache.ODBC"));
 
@@ -360,6 +363,18 @@ ODBCCCache::~ODBCCCache()
     delete m_cache;
 }
 
+void appendXML(ostream& os, const char* str)
+{
+    const char* pos=strchr(str,'\'');
+    while (pos) {
+        os.write(str,pos-str);
+        os << "''";
+        str=pos+1;
+        pos=strchr(str,'\'');
+    }
+    os << str;
+}
+
 HRESULT ODBCCCache::onCreate(
     const char* key,
     const IApplication* application,
@@ -370,7 +385,7 @@ HRESULT ODBCCCache::onCreate(
     )
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onCreate");
+    xmltooling::NDC ndc("onCreate");
 #endif
 
     // Get XML data from entry. Default is not to return SAML objects.
@@ -394,8 +409,18 @@ HRESULT ODBCCCache::onCreate(
     ostringstream q;
     q << "INSERT state VALUES ('" << key << "','" << application->getId() << "'," << timebuf << "," << timebuf
         << ",'" << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId()
-        << "',?,?,?)";
-
+        << "','";
+    appendXML(q,subject.first);
+    q << "','";
+    appendXML(q,context);
+    q << "',";
+    if (m_storeAttributes && tokens.first) {
+        q << "'";
+        appendXML(q,tokens.first);
+       q << "')";
+    }
+    else
+        q << "null)";
     if (log->isDebugEnabled())
         log->debug("SQL insert: %s", q.str().c_str());
 
@@ -404,42 +429,9 @@ HRESULT ODBCCCache::onCreate(
     ODBCConn conn(getHDBC());
     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
 
-    // Bind text parameters to statement.
-    SQLINTEGER cbSubject=SQL_LEN_DATA_AT_EXEC(0),cbContext=SQL_LEN_DATA_AT_EXEC(0),cbTokens;
-    if (!m_storeAttributes || !tokens.first)
-        cbTokens=SQL_NULL_DATA;
-    else
-        cbTokens=SQL_LEN_DATA_AT_EXEC(0);
-    SQLRETURN sr=SQLBindParameter(hstmt,1,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)subject.first,0,&cbSubject);
-    if (!SQL_SUCCEEDED(sr))
-        log_error(hstmt, SQL_HANDLE_STMT);
-    sr=SQLBindParameter(hstmt,2,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)context,0,&cbContext);
-    if (!SQL_SUCCEEDED(sr))
-        log_error(hstmt, SQL_HANDLE_STMT);
-    sr=SQLBindParameter(hstmt,3,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)tokens.first,0,&cbTokens);
-    if (!SQL_SUCCEEDED(sr))
-        log_error(hstmt, SQL_HANDLE_STMT);
-
     // Execute statement.
-    sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
-    if (sr==SQL_NEED_DATA) {
-        // Loop to send text data into driver.
-        // pData is set each round by the driver to the pointers we bound above.
-        char* pData;
-        sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
-        while (sr==SQL_NEED_DATA) {
-            size_t len=strlen(pData);
-            while (len>0) {
-                size_t amt = min(LONGDATA_BUFLEN,len);
-                SQLPutData(hstmt, pData, amt);
-                pData += amt;
-                len = len - amt;
-            }
-            sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
-       }
-    }
-
     HRESULT hr=NOERROR;
+    SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         log->error("failed to insert record into database");
         log_error(hstmt, SQL_HANDLE_STMT);
@@ -465,7 +457,7 @@ HRESULT ODBCCCache::onRead(
     )
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     log->debug("searching database...");
@@ -564,7 +556,7 @@ HRESULT ODBCCCache::onRead(
 HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     log->debug("reading last access time from database");
@@ -613,7 +605,7 @@ HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
 HRESULT ODBCCCache::onRead(const char* key, string& tokens)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onRead");
+    xmltooling::NDC ndc("onRead");
 #endif
 
     if (!m_storeAttributes)
@@ -660,12 +652,10 @@ HRESULT ODBCCCache::onRead(const char* key, string& tokens)
 HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onUpdate");
+    xmltooling::NDC ndc("onUpdate");
 #endif
 
-    SQLRETURN sr;
-    SQLHSTMT hstmt;
-    ODBCConn conn(getHDBC());
+    ostringstream q;
 
     if (lastAccess>0) {
 #ifndef HAVE_GMTIME_R
@@ -677,48 +667,31 @@ HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAcc
         char timebuf[32];
         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
 
-        ostringstream q;
         q << "UPDATE state SET atime=" << timebuf << " WHERE cookie='" << key << "'";
-
-        SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
-        sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     }
     else if (tokens) {
         if (!m_storeAttributes)
             return S_FALSE;
-        string q = string("UPDATE state SET tokens=? WHERE cookie='") + key + "'";
-
-        SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
-
-        // Bind text parameters to statement.
-        SQLINTEGER cbTokens = tokens ? SQL_LEN_DATA_AT_EXEC(0) : SQL_NULL_DATA;
-        sr=SQLBindParameter(hstmt,1,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)tokens,0,&cbTokens);
-
-        // Execute statement.
-        sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
-        if (sr==SQL_NEED_DATA) {
-            // Loop to send text data into driver.
-            // pData is set each round by the driver to the pointers we bound above.
-            char* pData;
-            sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
-            while (sr==SQL_NEED_DATA) {
-                size_t len=strlen(pData);
-                while (len>0) {
-                    size_t amt=min(LONGDATA_BUFLEN,len);
-                    SQLPutData(hstmt, pData, amt);
-                    pData += amt;
-                    len = len - amt;
-                }
-                sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
-           }
-        }
+        q << "UPDATE state SET tokens=";
+       if (*tokens) {
+           q << "'";
+           appendXML(q,tokens);
+           q << "' ";
+       }
+       else
+           q << "null ";
+       q << "WHERE cookie='" << key << "'";
     }
     else {
         log->warn("onUpdate called with nothing to do!");
         return S_FALSE;
     }
  
-    HRESULT hr;
+    HRESULT hr=NOERROR;
+    SQLHSTMT hstmt;
+    ODBCConn conn(getHDBC());
+    SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
+    SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     if (sr==SQL_NO_DATA)
         hr=S_FALSE;
     else if (!SQL_SUCCEEDED(sr)) {
@@ -726,8 +699,6 @@ HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAcc
         log_error(hstmt, SQL_HANDLE_STMT);
         hr=E_FAIL;
     }
-    else
-        hr=NOERROR;
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
     return hr;
@@ -736,7 +707,7 @@ HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAcc
 HRESULT ODBCCCache::onDelete(const char* key)
 {
 #ifdef _DEBUG
-    saml::NDC ndc("onDelete");
+    xmltooling::NDC ndc("onDelete");
 #endif
 
     SQLHSTMT hstmt;
@@ -745,7 +716,7 @@ HRESULT ODBCCCache::onDelete(const char* key)
     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
  
-    HRESULT hr;
+    HRESULT hr=NOERROR;
     if (sr==SQL_NO_DATA)
         hr=S_FALSE;
     else if (!SQL_SUCCEEDED(sr)) {
@@ -753,8 +724,6 @@ HRESULT ODBCCCache::onDelete(const char* key)
         log_error(hstmt, SQL_HANDLE_STMT);
         hr=E_FAIL;
     }
-    else
-        hr=NOERROR;
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
     return hr;
@@ -763,10 +732,10 @@ HRESULT ODBCCCache::onDelete(const char* key)
 void ODBCCCache::cleanup()
 {
 #ifdef _DEBUG
-    saml::NDC ndc("cleanup");
+    xmltooling::NDC ndc("cleanup");
 #endif
 
-    Mutex* mutex = Mutex::create();
+    Mutex* mutex = xmltooling::Mutex::create();
 
     int rerun_timer = 0;
     int timeout_life = 0;
@@ -839,15 +808,17 @@ void ODBCCCache::cleanup()
 
     mutex->unlock();
     delete mutex;
-    Thread::exit(NULL);
+    xmltooling::Thread::exit(NULL);
 }
 
 void* ODBCCCache::cleanup_fcn(void* cache_p)
 {
   ODBCCCache* cache = (ODBCCCache*)cache_p;
 
+#ifndef WIN32
   // First, let's block all signals
   Thread::mask_all_signals();
+#endif
 
   // Now run the cleanup process.
   cache->cleanup();
@@ -861,7 +832,7 @@ public:
   ODBCReplayCache(const DOMElement* e);
   virtual ~ODBCReplayCache() {}
 
-  bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
+  bool check(const XMLCh* str, time_t expires) {xmltooling::auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
   bool check(const char* str, time_t expires);
 };