Corrected a variety of ODBC bugs.
[shibboleth/cpp-sp.git] / odbc_ccache / odbc-ccache.cpp
1 /*
2  *  Copyright 2001-2005 Internet2
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * odbc-ccache.cpp - Shibboleth Credential Cache using ODBC
19  *
20  * Scott Cantor <cantor.2@osu.edu>
21  *
22  * $Id$
23  */
24
25 // eventually we might be able to support autoconf via cygwin...
26 #if defined (_MSC_VER) || defined(__BORLANDC__)
27 # include "config_win32.h"
28 #else
29 # include "config.h"
30 #endif
31
32 #ifdef WIN32
33 # define _CRT_NONSTDC_NO_DEPRECATE 1
34 # define _CRT_SECURE_NO_DEPRECATE 1
35 # define SHIBODBC_EXPORTS __declspec(dllexport)
36 #else
37 # define SHIBODBC_EXPORTS
38 #endif
39
40 #include <shib/shib-threads.h>
41 #include <shib-target/shib-target.h>
42 #include <log4cpp/Category.hh>
43
44 #include <sstream>
45
46 #include <sql.h>
47 #include <sqlext.h>
48
49 #ifdef HAVE_LIBDMALLOCXX
50 #include <dmalloc.h>
51 #endif
52
53 using namespace std;
54 using namespace saml;
55 using namespace shibboleth;
56 using namespace shibtarget;
57 using namespace log4cpp;
58
59 #define PLUGIN_VER_MAJOR 3
60 #define PLUGIN_VER_MINOR 0
61
62 #define COLSIZE_KEY 64
63 #define COLSIZE_APPLICATION_ID 256
64 #define COLSIZE_ADDRESS 128
65 #define COLSIZE_PROVIDER_ID 256
66 #define LONGDATA_BUFLEN 2048
67
68 /*
69   CREATE TABLE state (
70       cookie VARCHAR(64) PRIMARY KEY,
71       application_id VARCHAR(256),
72       ctime TIMESTAMP,
73       atime TIMESTAMP,
74       addr VARCHAR(128),
75       major INT,
76       minor INT,
77       provider VARCHAR(256),
78       subject TEXT,
79       authn_context TEXT,
80       tokens TEXT
81       )
82 */
83
84 #define REPLAY_TABLE \
85   "CREATE TABLE replay (id VARCHAR(255) PRIMARY KEY, " \
86   "expires TIMESTAMP, " \
87   "INDEX (expires))"
88
89 static const XMLCh ConnectionString[] =
90 { chLatin_C, chLatin_o, chLatin_n, chLatin_n, chLatin_e, chLatin_c, chLatin_t, chLatin_i, chLatin_o, chLatin_n,
91   chLatin_S, chLatin_t, chLatin_r, chLatin_i, chLatin_n, chLatin_g, chNull
92 };
93 static const XMLCh cleanupInterval[] =
94 { chLatin_c, chLatin_l, chLatin_e, chLatin_a, chLatin_n, chLatin_u, chLatin_p,
95   chLatin_I, chLatin_n, chLatin_t, chLatin_e, chLatin_r, chLatin_v, chLatin_a, chLatin_l, chNull
96 };
97 static const XMLCh cacheTimeout[] =
98 { chLatin_c, chLatin_a, chLatin_c, chLatin_h, chLatin_e, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
99 static const XMLCh odbcTimeout[] =
100 { chLatin_o, chLatin_d, chLatin_b, chLatin_c, chLatin_T, chLatin_i, chLatin_m, chLatin_e, chLatin_o, chLatin_u, chLatin_t, chNull };
101 static const XMLCh storeAttributes[] =
102 { chLatin_s, chLatin_t, chLatin_o, chLatin_r, chLatin_e, chLatin_A, chLatin_t, chLatin_t, chLatin_r, chLatin_i, chLatin_b, chLatin_u, chLatin_t, chLatin_e, chLatin_s, chNull };
103
104 struct ODBCConn {
105     ODBCConn(SQLHDBC conn) : handle(conn) {}
106     ~ODBCConn() {SQLFreeHandle(SQL_HANDLE_DBC,handle);}
107     operator SQLHDBC() {return handle;}
108     SQLHDBC handle;
109 };
110
111 class ODBCBase : public virtual saml::IPlugIn
112 {
113 public:
114     ODBCBase(const DOMElement* e);
115     virtual ~ODBCBase();
116
117     SQLHDBC getHDBC();
118
119     log4cpp::Category* log;
120
121 protected:
122     //ThreadKey* m_mysql;
123     const DOMElement* m_root; // can only use this during initialization
124     string m_connstring;
125
126     static SQLHENV m_henv;          // single handle for both plugins
127     bool m_bInitializedODBC;        // tracks which class handled the process
128
129     pair<int,int> getVersion(SQLHDBC);
130     void log_error(SQLHANDLE handle, SQLSMALLINT htype);
131 };
132
133 SQLHENV ODBCBase::m_henv = SQL_NULL_HANDLE;
134
135 /*
136 extern "C" void shib_mysql_destroy_handle(void* data)
137 {
138   MYSQL* mysql = (MYSQL*) data;
139   if (mysql) mysql_close(mysql);
140 }
141 */
142
143 ODBCBase::ODBCBase(const DOMElement* e) : m_root(e), m_bInitializedODBC(false)
144 {
145 #ifdef _DEBUG
146     saml::NDC ndc("ODBCBase");
147 #endif
148     log = &(Category::getInstance("shibtarget.SessionCache.ODBC"));
149     //m_mysql = ThreadKey::create(&shib_mysql_destroy_handle);
150
151     if (m_henv != SQL_NULL_HANDLE) {
152         log->info("ODBC already initialized");
153         return;
154     }
155
156     // Enable connection pooling.
157     SQLSetEnvAttr(SQL_NULL_HANDLE, SQL_ATTR_CONNECTION_POOLING, (void*)SQL_CP_ONE_PER_HENV, 0);
158
159     // Allocate the environment.
160     if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_henv)))
161         throw ConfigurationException("ODBC failed to initialize.");
162
163     // Specify ODBC 3.x
164     SQLSetEnvAttr(m_henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);
165
166     log->info("ODBC initialized");
167     m_bInitializedODBC = true;
168
169     // Grab connection string from the configuration.
170     e=saml::XML::getFirstChildElement(e,shibtarget::XML::SHIBTARGET_NS,ConnectionString);
171     if (!e || !e->hasChildNodes()) {
172         this->~ODBCBase();
173         throw ConfigurationException("ODBC cache requires ConnectionString element in configuration.");
174     }
175     auto_ptr_char arg(e->getFirstChild()->getNodeValue());
176     m_connstring=arg.get();
177
178     // Connect and check version.
179     SQLHDBC conn=getHDBC();
180     pair<int,int> v=getVersion(conn);
181     SQLFreeHandle(SQL_HANDLE_DBC,conn);
182
183     // Make sure we've got the right version.
184     if (v.first != PLUGIN_VER_MAJOR) {
185         this->~ODBCBase();
186         log->crit("unknown database version: %d.%d", v.first, v.second);
187         throw SAMLException("Unknown cache database version.");
188     }
189 }
190
191 ODBCBase::~ODBCBase()
192 {
193     //delete m_mysql;
194     if (m_bInitializedODBC)
195         SQLFreeHandle(SQL_HANDLE_ENV,m_henv);
196     m_henv = SQL_NULL_HANDLE;
197 }
198
199 void ODBCBase::log_error(SQLHANDLE handle, SQLSMALLINT htype)
200 {
201     SQLSMALLINT  i = 0;
202     SQLINTEGER   native;
203     SQLCHAR      state[7];
204     SQLCHAR      text[256];
205     SQLSMALLINT  len;
206     SQLRETURN    ret;
207
208     do {
209         ret = SQLGetDiagRec(htype, handle, ++i, state, &native, text, sizeof(text), &len);
210         if (SQL_SUCCEEDED(ret))
211             log->error("ODBC Error: %s:%ld:%ld:%s", state, i, native, text);
212     } while(SQL_SUCCEEDED(ret));
213 }
214
215 SQLHDBC ODBCBase::getHDBC()
216 {
217 #ifdef _DEBUG
218     saml::NDC ndc("getMYSQL");
219 #endif
220
221     // Get a handle.
222     SQLHDBC handle;
223     SQLRETURN sr=SQLAllocHandle(SQL_HANDLE_DBC, m_henv, &handle);
224     if (!SQL_SUCCEEDED(sr)) {
225         log->error("failed to allocate connection handle");
226         log_error(m_henv, SQL_HANDLE_ENV);
227         throw SAMLException("ODBCBase::getHDBC failed to allocate connection handle");
228     }
229
230     sr=SQLDriverConnect(handle,NULL,(SQLCHAR*)m_connstring.c_str(),m_connstring.length(),NULL,0,NULL,SQL_DRIVER_NOPROMPT);
231     if (!SQL_SUCCEEDED(sr)) {
232         log->error("failed to connect to database");
233         log_error(handle, SQL_HANDLE_DBC);
234         throw SAMLException("ODBCBase::getHDBC failed to connect to database");
235     }
236
237     return handle;
238 }
239
240 pair<int,int> ODBCBase::getVersion(SQLHDBC conn)
241 {
242     // Grab the version number from the database.
243     SQLHSTMT hstmt;
244     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
245     
246     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)"SELECT major,minor FROM version", SQL_NTS);
247     if (!SQL_SUCCEEDED(sr)) {
248         log->error("failed to read version from database");
249         log_error(hstmt, SQL_HANDLE_STMT);
250         throw SAMLException("ODBCBase::getVersion failed to read version from database");
251     }
252
253     SQLINTEGER major;
254     SQLINTEGER minor;
255     SQLBindCol(hstmt,1,SQL_C_SLONG,&major,0,NULL);
256     SQLBindCol(hstmt,2,SQL_C_SLONG,&minor,0,NULL);
257
258     if ((sr=SQLFetch(hstmt)) != SQL_NO_DATA) {
259         SQLCloseCursor(hstmt);
260         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
261         return pair<int,int>(major,minor);
262     }
263
264     SQLCloseCursor(hstmt);
265     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
266     log->error("no rows returned in version query");
267     throw SAMLException("ODBCBase::getVersion failed to read version from database");
268 }
269
270 class ODBCCCache : public ODBCBase, virtual public ISessionCache, virtual public ISessionCacheStore
271 {
272 public:
273     ODBCCCache(const DOMElement* e);
274     virtual ~ODBCCCache();
275
276     // Delegate all the ISessionCache methods.
277     string insert(
278         const IApplication* application,
279         const IEntityDescriptor* source,
280         const char* client_addr,
281         const SAMLSubject* subject,
282         const char* authnContext,
283         const SAMLResponse* tokens
284         )
285     { return m_cache->insert(application,source,client_addr,subject,authnContext,tokens); }
286     ISessionCacheEntry* find(const char* key, const IApplication* application, const char* client_addr)
287     { return m_cache->find(key,application,client_addr); }
288     void remove(const char* key, const IApplication* application, const char* client_addr)
289     { m_cache->remove(key,application,client_addr); }
290
291     bool setBackingStore(ISessionCacheStore*) {return false;}
292
293     // Store methods handle the database work
294     HRESULT onCreate(
295         const char* key,
296         const IApplication* application,
297         const ISessionCacheEntry* entry,
298         int majorVersion,
299         int minorVersion,
300         time_t created
301         );
302     HRESULT onRead(
303         const char* key,
304         string& applicationId,
305         string& clientAddress,
306         string& providerId,
307         string& subject,
308         string& authnContext,
309         string& tokens,
310         int& majorVersion,
311         int& minorVersion,
312         time_t& created,
313         time_t& accessed
314         );
315     HRESULT onRead(const char* key, time_t& accessed);
316     HRESULT onRead(const char* key, string& tokens);
317     HRESULT onUpdate(const char* key, const char* tokens=NULL, time_t accessed=0);
318     HRESULT onDelete(const char* key);
319
320     void cleanup();
321
322 private:
323     bool m_storeAttributes;
324     ISessionCache* m_cache;
325     CondWait* shutdown_wait;
326     bool shutdown;
327     Thread* cleanup_thread;
328
329     static void* cleanup_fcn(void*); // XXX Assumed an ODBCCCache
330 };
331
332 ODBCCCache::ODBCCCache(const DOMElement* e) : ODBCBase(e), m_storeAttributes(false)
333 {
334 #ifdef _DEBUG
335     saml::NDC ndc("ODBCCCache");
336 #endif
337
338     m_cache = dynamic_cast<ISessionCache*>(
339         SAMLConfig::getConfig().getPlugMgr().newPlugin(shibtarget::XML::MemorySessionCacheType, m_root)
340     );
341     if (!m_cache->setBackingStore(this)) {
342         delete m_cache;
343         throw SAMLException("Unable to register ODBC cache plugin as a cache store.");
344     }
345     
346     shutdown_wait = CondWait::create();
347     shutdown = false;
348
349     // Load our configuration details...
350     const XMLCh* tag=m_root->getAttributeNS(NULL,storeAttributes);
351     if (tag && *tag && (*tag==chLatin_t || *tag==chDigit_1))
352         m_storeAttributes=true;
353
354     // Initialize the cleanup thread
355     cleanup_thread = Thread::create(&cleanup_fcn, (void*)this);
356 }
357
358 ODBCCCache::~ODBCCCache()
359 {
360     shutdown = true;
361     shutdown_wait->signal();
362     cleanup_thread->join(NULL);
363     delete m_cache;
364 }
365
366 HRESULT ODBCCCache::onCreate(
367     const char* key,
368     const IApplication* application,
369     const ISessionCacheEntry* entry,
370     int majorVersion,
371     int minorVersion,
372     time_t created
373     )
374 {
375 #ifdef _DEBUG
376     saml::NDC ndc("onCreate");
377 #endif
378
379     // Get XML data from entry. Default is not to return SAML objects.
380     const char* context=entry->getAuthnContext();
381     pair<const char*,const SAMLSubject*> subject=entry->getSubject();
382     pair<const char*,const SAMLResponse*> tokens=entry->getTokens();
383
384     // Stringify timestamp.
385     if (created==0)
386         created=time(NULL);
387 #ifndef HAVE_GMTIME_R
388     struct tm* ptime=gmtime(&created);
389 #else
390     struct tm res;
391     struct tm* ptime=gmtime_r(&created,&res);
392 #endif
393     char timebuf[32];
394     strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
395
396     // Prepare insert statement.
397     ostringstream q;
398     q << "INSERT state VALUES ('" << key << "','" << application->getId() << "'," << timebuf << "," << timebuf
399         << ",'" << entry->getClientAddress() << "'," << majorVersion << "," << minorVersion << ",'" << entry->getProviderId()
400         << "',?,?,?)";
401
402     if (log->isDebugEnabled())
403         log->debug("SQL insert: %s", q.str().c_str());
404
405     // Get statement handle.
406     SQLHSTMT hstmt;
407     ODBCConn conn(getHDBC());
408     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
409
410     // Bind text parameters to statement.
411     SQLINTEGER cbSubject=SQL_LEN_DATA_AT_EXEC(0),cbContext=SQL_LEN_DATA_AT_EXEC(0),cbTokens;
412     if (!m_storeAttributes || !tokens.first)
413         cbTokens=SQL_NULL_DATA;
414     else
415         cbTokens=SQL_LEN_DATA_AT_EXEC(0);
416     SQLRETURN sr=SQLBindParameter(hstmt,1,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)subject.first,0,&cbSubject);
417     if (!SQL_SUCCEEDED(sr))
418         log_error(hstmt, SQL_HANDLE_STMT);
419     sr=SQLBindParameter(hstmt,2,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)context,0,&cbContext);
420     if (!SQL_SUCCEEDED(sr))
421         log_error(hstmt, SQL_HANDLE_STMT);
422     sr=SQLBindParameter(hstmt,3,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)tokens.first,0,&cbTokens);
423     if (!SQL_SUCCEEDED(sr))
424         log_error(hstmt, SQL_HANDLE_STMT);
425
426     // Execute statement.
427     sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
428     if (sr==SQL_NEED_DATA) {
429         // Loop to send text data into driver.
430         // pData is set each round by the driver to the pointers we bound above.
431         char* pData;
432         sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
433         while (sr==SQL_NEED_DATA) {
434             size_t len=strlen(pData);
435             while (len>0) {
436                 size_t amt = min(LONGDATA_BUFLEN,len);
437                 SQLPutData(hstmt, pData, amt);
438                 pData += amt;
439                 len = len - amt;
440             }
441             sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
442        }
443     }
444
445     HRESULT hr=NOERROR;
446     if (!SQL_SUCCEEDED(sr)) {
447         log->error("failed to insert record into database");
448         log_error(hstmt, SQL_HANDLE_STMT);
449         hr=E_FAIL;
450     }
451
452     SQLCloseCursor(hstmt);
453     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
454     return hr;
455 }
456
457 HRESULT ODBCCCache::onRead(
458     const char* key,
459     string& applicationId,
460     string& clientAddress,
461     string& providerId,
462     string& subject,
463     string& authnContext,
464     string& tokens,
465     int& majorVersion,
466     int& minorVersion,
467     time_t& created,
468     time_t& accessed
469     )
470 {
471 #ifdef _DEBUG
472     saml::NDC ndc("onRead");
473 #endif
474
475     log->debug("searching database...");
476
477     SQLHSTMT hstmt;
478     ODBCConn conn(getHDBC());
479     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
480
481     string q = string("SELECT application_id,ctime,atime,addr,major,minor,provider,subject,authn_context,tokens FROM state WHERE cookie='") + key + "'";
482     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
483     if (!SQL_SUCCEEDED(sr)) {
484         log->error("error searching for (%s)",key);
485         log_error(hstmt, SQL_HANDLE_STMT);
486         SQLCloseCursor(hstmt);
487         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
488         return E_FAIL;
489     }
490
491     SQLINTEGER major,minor;
492     SQL_TIMESTAMP_STRUCT atime,ctime;
493     SQLCHAR application_id[COLSIZE_APPLICATION_ID+1];
494     SQLCHAR addr[COLSIZE_ADDRESS+1];
495     SQLCHAR provider_id[COLSIZE_PROVIDER_ID+1];
496
497     // Bind simple output columns.
498     SQLBindCol(hstmt,1,SQL_C_CHAR,application_id,sizeof(application_id),NULL);
499     SQLBindCol(hstmt,2,SQL_C_TYPE_TIMESTAMP,&ctime,0,NULL);
500     SQLBindCol(hstmt,3,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
501     SQLBindCol(hstmt,4,SQL_C_CHAR,addr,sizeof(addr),NULL);
502     SQLBindCol(hstmt,5,SQL_C_SLONG,&major,0,NULL);
503     SQLBindCol(hstmt,6,SQL_C_SLONG,&minor,0,NULL);
504     SQLBindCol(hstmt,7,SQL_C_CHAR,provider_id,sizeof(provider_id),NULL);
505
506     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
507         SQLCloseCursor(hstmt);
508         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
509         return S_FALSE;
510     }
511
512     log->debug("session found, tranfering data back into memory");
513
514     // Copy back simple data.
515     applicationId = (char*)application_id;
516     clientAddress = (char*)addr;
517     majorVersion = major;
518     minorVersion = minor;
519     providerId = (char*)provider_id;
520
521     struct tm t;
522     t.tm_sec=ctime.second;
523     t.tm_min=ctime.minute;
524     t.tm_hour=ctime.hour;
525     t.tm_mday=ctime.day;
526     t.tm_mon=ctime.month-1;
527     t.tm_year=ctime.year-1900;
528     t.tm_isdst=0;
529 #if defined(HAVE_TIMEGM)
530     created=timegm(&t);
531 #else
532     // Windows, and hopefully most others...?
533     created = mktime(&t) - timezone;
534 #endif
535     t.tm_sec=atime.second;
536     t.tm_min=atime.minute;
537     t.tm_hour=atime.hour;
538     t.tm_mday=atime.day;
539     t.tm_mon=atime.month-1;
540     t.tm_year=atime.year-1900;
541     t.tm_isdst=0;
542 #if defined(HAVE_TIMEGM)
543     accessed=timegm(&t);
544 #else
545     // Windows, and hopefully most others...?
546     accessed = mktime(&t) - timezone;
547 #endif
548
549     // Extract text data.
550     string* ptrs[] = {&subject, &authnContext, &tokens};
551     HRESULT hr=NOERROR;
552     SQLINTEGER len;
553     SQLCHAR buf[LONGDATA_BUFLEN];
554     for (int i=0; i<3; i++) {
555         while ((sr=SQLGetData(hstmt,i+8,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
556             if (!SUCCEEDED(sr)) {
557                 log->error("error while reading text field from result set");
558                 log_error(hstmt, SQL_HANDLE_STMT);
559                 hr=E_FAIL;
560                 break;
561             }
562             ptrs[i]->append((char*)buf);
563         }
564     }
565
566     SQLCloseCursor(hstmt);
567     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
568     return hr;
569 }
570
571 HRESULT ODBCCCache::onRead(const char* key, time_t& accessed)
572 {
573 #ifdef _DEBUG
574     saml::NDC ndc("onRead");
575 #endif
576
577     log->debug("reading last access time from database");
578
579     SQLHSTMT hstmt;
580     ODBCConn conn(getHDBC());
581     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
582     
583     string q = string("SELECT atime FROM state WHERE cookie='") + key + "'";
584     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
585     if (!SQL_SUCCEEDED(sr)) {
586         log->error("error searching for (%s)",key);
587         log_error(hstmt, SQL_HANDLE_STMT);
588         SQLCloseCursor(hstmt);
589         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
590         return E_FAIL;
591     }
592
593     SQL_TIMESTAMP_STRUCT atime;
594     SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&atime,0,NULL);
595
596     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
597         log->warn("session expected, but not found in database");
598         SQLCloseCursor(hstmt);
599         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
600         return S_FALSE;
601     }
602
603     SQLCloseCursor(hstmt);
604     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
605
606     struct tm t;
607     t.tm_sec=atime.second;
608     t.tm_min=atime.minute;
609     t.tm_hour=atime.hour;
610     t.tm_mday=atime.day;
611     t.tm_mon=atime.month-1;
612     t.tm_year=atime.year-1900;
613     t.tm_isdst=0;
614 #if defined(HAVE_TIMEGM)
615     accessed=timegm(&t);
616 #else
617     // Windows, and hopefully most others...?
618     accessed = mktime(&t) - timezone;
619 #endif
620     return NOERROR;
621 }
622
623 HRESULT ODBCCCache::onRead(const char* key, string& tokens)
624 {
625 #ifdef _DEBUG
626     saml::NDC ndc("onRead");
627 #endif
628
629     if (!m_storeAttributes)
630         return S_FALSE;
631
632     log->debug("reading cached tokens from database");
633
634     SQLHSTMT hstmt;
635     ODBCConn conn(getHDBC());
636     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
637     
638     string q = string("SELECT tokens FROM state WHERE cookie='") + key + "'";
639     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
640     if (!SQL_SUCCEEDED(sr)) {
641         log->error("error searching for (%s)",key);
642         log_error(hstmt, SQL_HANDLE_STMT);
643         SQLCloseCursor(hstmt);
644         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
645         return E_FAIL;
646     }
647
648     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
649         log->warn("session expected, but not found in database");
650         SQLCloseCursor(hstmt);
651         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
652         return S_FALSE;
653     }
654
655     HRESULT hr=NOERROR;
656     SQLINTEGER len;
657     SQLCHAR buf[LONGDATA_BUFLEN];
658     while ((sr=SQLGetData(hstmt,1,SQL_C_CHAR,buf,sizeof(buf),&len)) != SQL_NO_DATA) {
659         if (!SUCCEEDED(sr)) {
660             log->error("error while reading text field from result set");
661             log_error(hstmt, SQL_HANDLE_STMT);
662             hr=E_FAIL;
663             break;
664         }
665         tokens += (char*)buf;
666     }
667
668     SQLCloseCursor(hstmt);
669     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
670     return hr;
671 }
672
673 HRESULT ODBCCCache::onUpdate(const char* key, const char* tokens, time_t lastAccess)
674 {
675 #ifdef _DEBUG
676     saml::NDC ndc("onUpdate");
677 #endif
678
679     SQLRETURN sr;
680     SQLHSTMT hstmt;
681     ODBCConn conn(getHDBC());
682
683     if (lastAccess>0) {
684 #ifndef HAVE_GMTIME_R
685         struct tm* ptime=gmtime(&lastAccess);
686 #else
687         struct tm res;
688         struct tm* ptime=gmtime_r(&lastAccess,&res);
689 #endif
690         char timebuf[32];
691         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
692
693         ostringstream q;
694         q << "UPDATE state SET atime=" << timebuf << " WHERE cookie='" << key << "'";
695
696         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
697         sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
698     }
699     else if (tokens) {
700         if (!m_storeAttributes)
701             return S_FALSE;
702         string q = string("UPDATE state SET tokens=? WHERE cookie='") + key + "'";
703
704         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
705
706         // Bind text parameters to statement.
707         SQLINTEGER cbTokens = tokens ? SQL_LEN_DATA_AT_EXEC(0) : SQL_NULL_DATA;
708         sr=SQLBindParameter(hstmt,1,SQL_PARAM_INPUT,SQL_C_CHAR,SQL_LONGVARCHAR,LONGDATA_BUFLEN,0,(SQLPOINTER)tokens,0,&cbTokens);
709
710         // Execute statement.
711         sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
712         if (sr==SQL_NEED_DATA) {
713             // Loop to send text data into driver.
714             // pData is set each round by the driver to the pointers we bound above.
715             char* pData;
716             sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
717             while (sr==SQL_NEED_DATA) {
718                 size_t len=strlen(pData);
719                 while (len>0) {
720                     size_t amt=min(LONGDATA_BUFLEN,len);
721                     SQLPutData(hstmt, pData, amt);
722                     pData += amt;
723                     len = len - amt;
724                 }
725                 sr=SQLParamData(hstmt,(SQLPOINTER*)&pData);
726            }
727         }
728     }
729     else {
730         log->warn("onUpdate called with nothing to do!");
731         return S_FALSE;
732     }
733  
734     HRESULT hr;
735     if (sr==SQL_NO_DATA)
736         hr=S_FALSE;
737     else if (!SQL_SUCCEEDED(sr)) {
738         log->error("error updating record (key=%s)", key);
739         log_error(hstmt, SQL_HANDLE_STMT);
740         hr=E_FAIL;
741     }
742     else
743         hr=NOERROR;
744
745     SQLCloseCursor(hstmt);
746     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
747     return hr;
748 }
749
750 HRESULT ODBCCCache::onDelete(const char* key)
751 {
752 #ifdef _DEBUG
753     saml::NDC ndc("onDelete");
754 #endif
755
756     SQLHSTMT hstmt;
757     ODBCConn conn(getHDBC());
758     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
759     string q = string("DELETE FROM state WHERE cookie='") + key + "'";
760     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
761  
762     HRESULT hr;
763     if (sr==SQL_NO_DATA)
764         hr=S_FALSE;
765     else if (!SQL_SUCCEEDED(sr)) {
766         log->error("error deleting record (key=%s)", key);
767         log_error(hstmt, SQL_HANDLE_STMT);
768         hr=E_FAIL;
769     }
770     else
771         hr=NOERROR;
772
773     SQLCloseCursor(hstmt);
774     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
775     return hr;
776 }
777
778 void ODBCCCache::cleanup()
779 {
780 #ifdef _DEBUG
781     saml::NDC ndc("cleanup");
782 #endif
783
784     Mutex* mutex = Mutex::create();
785
786     int rerun_timer = 0;
787     int timeout_life = 0;
788
789     // Load our configuration details...
790     const XMLCh* tag=m_root->getAttributeNS(NULL,cleanupInterval);
791     if (tag && *tag)
792         rerun_timer = XMLString::parseInt(tag);
793
794     // search for 'mysql-cache-timeout' and then the regular cache timeout
795     tag=m_root->getAttributeNS(NULL,odbcTimeout);
796     if (tag && *tag)
797         timeout_life = XMLString::parseInt(tag);
798     else {
799         tag=m_root->getAttributeNS(NULL,cacheTimeout);
800         if (tag && *tag)
801             timeout_life = XMLString::parseInt(tag);
802     }
803   
804     if (rerun_timer <= 0)
805         rerun_timer = 300;              // rerun every 5 minutes
806
807     if (timeout_life <= 0)
808         timeout_life = 28800;   // timeout after 8 hours
809     
810     mutex->lock();
811
812     log->info("cleanup thread started...Run every %d secs; timeout after %d secs", rerun_timer, timeout_life);
813
814     while (shutdown == false) {
815         shutdown_wait->timedwait(mutex, rerun_timer);
816
817         if (shutdown == true)
818             break;
819
820         // Find all the entries in the database that haven't been used
821         // recently In particular, find all entries that have not been
822         // accessed in 'timeout_life' seconds.
823
824         time_t stale=time(NULL)-timeout_life;
825 #ifndef HAVE_GMTIME_R
826         struct tm* ptime=gmtime(&stale);
827 #else
828         struct tm res;
829         struct tm* ptime=gmtime_r(&stale,&res);
830 #endif
831         char timebuf[32];
832         strftime(timebuf,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
833
834         string q = string("DELETE state WHERE atime < ") +  timebuf;
835
836         SQLHSTMT hstmt;
837         ODBCConn conn(getHDBC());
838         SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
839         SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
840         if (sr!=SQL_NO_DATA && !SQL_SUCCEEDED(sr)) {
841             log->error("error purging old records");
842             log_error(hstmt, SQL_HANDLE_STMT);
843         }
844
845         SQLINTEGER rowcount=0;
846         sr=SQLRowCount(hstmt,&rowcount);
847         if (SQL_SUCCEEDED(sr) && rowcount > 0)
848             log->info("purging %d old sessions",rowcount);
849
850         SQLCloseCursor(hstmt);
851         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
852      }
853
854     log->info("cleanup thread exiting...");
855
856     mutex->unlock();
857     delete mutex;
858     Thread::exit(NULL);
859 }
860
861 void* ODBCCCache::cleanup_fcn(void* cache_p)
862 {
863   ODBCCCache* cache = (ODBCCCache*)cache_p;
864
865   // First, let's block all signals
866   Thread::mask_all_signals();
867
868   // Now run the cleanup process.
869   cache->cleanup();
870   return NULL;
871 }
872
873 /*
874 class MySQLReplayCache : public MySQLBase, virtual public IReplayCache
875 {
876 public:
877   MySQLReplayCache(const DOMElement* e);
878   virtual ~MySQLReplayCache() {}
879
880   bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
881   bool check(const char* str, time_t expires);
882 };
883
884 MySQLReplayCache::MySQLReplayCache(const DOMElement* e) : MySQLBase(e)
885 {
886 #ifdef _DEBUG
887   saml::NDC ndc("MySQLReplayCache");
888 #endif
889
890   log = &(Category::getInstance("shibmysql.ReplayCache"));
891 }
892
893 bool MySQLReplayCache::check(const char* str, time_t expires)
894 {
895 #ifdef _DEBUG
896     saml::NDC ndc("check");
897 #endif
898   
899     // Remove expired entries
900     string q = string("DELETE FROM replay WHERE expires < NOW()");
901     MYSQL* mysql = getMYSQL();
902     if (mysql_query(mysql, q.c_str())) {
903         const char* err=mysql_error(mysql);
904         log->error("Error deleting expired entries: %s", err);
905         if (isCorrupt(err) && repairTable(mysql,"replay")) {
906             // Try again...
907             if (mysql_query(mysql, q.c_str()))
908                 log->error("Error deleting expired entries: %s", mysql_error(mysql));
909         }
910     }
911   
912     string q2 = string("SELECT id FROM replay WHERE id='") + str + "'";
913     if (mysql_query(mysql, q2.c_str())) {
914         const char* err=mysql_error(mysql);
915         log->error("Error searching for %s: %s", str, err);
916         if (isCorrupt(err) && repairTable(mysql,"replay")) {
917             if (mysql_query(mysql, q2.c_str())) {
918                 log->error("Error retrying search for %s: %s", str, mysql_error(mysql));
919                 throw SAMLException("Replay cache failed, please inform application support staff.");
920             }
921         }
922         else
923             throw SAMLException("Replay cache failed, please inform application support staff.");
924     }
925
926     // Did we find it?
927     MYSQL_RES* rows = mysql_store_result(mysql);
928     if (rows && mysql_num_rows(rows)>0) {
929       mysql_free_result(rows);
930       return false;
931     }
932
933     ostringstream q3;
934     q3 << "INSERT INTO replay VALUES('" << str << "'," << "FROM_UNIXTIME(" << expires << "))";
935
936     // then add it to the database
937     if (mysql_query(mysql, q3.str().c_str())) {
938         const char* err=mysql_error(mysql);
939         log->error("Error inserting %s: %s", str, err);
940         if (isCorrupt(err) && repairTable(mysql,"state")) {
941             // Try again...
942             if (mysql_query(mysql, q3.str().c_str())) {
943                 log->error("Error inserting %s: %s", str, mysql_error(mysql));
944                 throw SAMLException("Replay cache failed, please inform application support staff.");
945             }
946         }
947         else
948             throw SAMLException("Replay cache failed, please inform application support staff.");
949     }
950     
951     return true;
952 }
953
954 IPlugIn* new_mysql_replay(const DOMElement* e)
955 {
956     return new MySQLReplayCache(e);
957 }
958 */
959
960 /*************************************************************************
961  * The registration functions here...
962  */
963
964 IPlugIn* new_odbc_ccache(const DOMElement* e)
965 {
966     return new ODBCCCache(e);
967 }
968
969
970 extern "C" int SHIBODBC_EXPORTS saml_extension_init(void*)
971 {
972     // register this ccache type
973 //    SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::ODBCReplayCacheType, &new_odbc_replay);
974     SAMLConfig::getConfig().getPlugMgr().regFactory(shibtarget::XML::ODBCSessionCacheType, &new_odbc_ccache);
975     return 0;
976 }
977
978 extern "C" void SHIBODBC_EXPORTS saml_extension_term()
979 {
980     SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::ODBCSessionCacheType);
981 //    SAMLConfig::getConfig().getPlugMgr().unregFactory(shibtarget::XML::ODBCReplayCacheType);
982 }