Use shibboleth-sp as package name for compatibility.
[shibboleth/cpp-sp.git] / memcache-store / memcache-store.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * memcache-store.cpp
23  *
24  * Storage Service using memcache (pre memcache tags).
25  */
26
27 #if defined (_MSC_VER) || defined(__BORLANDC__)
28 # include "config_win32.h"
29 #else
30 # include "config.h"
31 #endif
32
33 #ifdef WIN32
34 # define _CRT_NONSTDC_NO_DEPRECATE 1
35 # define _CRT_SECURE_NO_DEPRECATE 1
36 # define MCEXT_EXPORTS __declspec(dllexport)
37 #else
38 # define MCEXT_EXPORTS
39 #endif
40
41 #include <xmltooling/base.h>
42
43 #include <list>
44 #include <iostream> 
45 #include <libmemcached/memcached.h>
46 #include <xercesc/util/XMLUniDefs.hpp>
47
48 #include <xmltooling/logging.h>
49 #include <xmltooling/unicode.h>
50 #include <xmltooling/XMLToolingConfig.h>
51 #include <xmltooling/util/NDC.h>
52 #include <xmltooling/util/StorageService.h>
53 #include <xmltooling/util/Threads.h>
54 #include <xmltooling/util/XMLHelper.h>
55
56 using namespace xmltooling::logging;
57 using namespace xmltooling;
58 using namespace xercesc;
59 using namespace boost;
60 using namespace std;
61
62 namespace {
63     static const XMLCh Hosts[] = UNICODE_LITERAL_5(H,o,s,t,s);
64     static const XMLCh prefix[] = UNICODE_LITERAL_6(p,r,e,f,i,x);
65     static const XMLCh buildMap[] = UNICODE_LITERAL_8(b,u,i,l,d,M,a,p);
66     static const XMLCh sendTimeout[] = UNICODE_LITERAL_11(s,e,n,d,T,i,m,e,o,u,t);
67     static const XMLCh recvTimeout[] = UNICODE_LITERAL_11(r,e,c,v,T,i,m,e,o,u,t);
68     static const XMLCh pollTimeout[] = UNICODE_LITERAL_11(p,o,l,l,T,i,m,e,o,u,t);
69     static const XMLCh failLimit[] = UNICODE_LITERAL_9(f,a,i,l,L,i,m,i,t);
70     static const XMLCh retryTimeout[] = UNICODE_LITERAL_12(r,e,t,r,y,T,i,m,e,o,u,t);
71     static const XMLCh nonBlocking[] = UNICODE_LITERAL_11(n,o,n,B,l,o,c,k,i,n,g);
72   
73     class mc_record {
74     public:
75         string value;
76         time_t expiration;
77         mc_record() {};
78         mc_record(string _v, time_t _e) : value(_v), expiration(_e) {}
79     };
80
81     class MemcacheBase {
82     public:
83         MemcacheBase(const DOMElement* e);
84         ~MemcacheBase();
85         
86         bool addMemcache(const char* key, string &value, time_t timeout, uint32_t flags, bool use_prefix = true);
87         bool setMemcache(const char* key, string &value, time_t timeout, uint32_t flags, bool use_prefix = true);
88         bool replaceMemcache(const char* key, string &value, time_t timeout, uint32_t flags, bool use_prefix = true);
89         bool getMemcache(const char* key, string &dest, uint32_t *flags, bool use_prefix = true);
90         bool deleteMemcache(const char* key, time_t timeout, bool use_prefix = true);
91
92         void serialize(mc_record &source, string &dest);
93         void serialize(list<string> &source, string &dest);
94         void deserialize(string &source, mc_record &dest);
95         void deserialize(string &source, list<string> &dest);
96
97         bool addLock(string what, bool use_prefix = true);
98         void deleteLock(string what, bool use_prefix = true);
99
100     protected:
101         Category& m_log;
102         memcached_st* memc;
103         string m_prefix;
104         scoped_ptr<Mutex> m_lock;
105
106     private:
107         bool handleError(const char*, memcached_return) const;
108     };
109   
110     class MemcacheStorageService : public StorageService, public MemcacheBase {
111
112     public:
113         MemcacheStorageService(const DOMElement* e);
114         ~MemcacheStorageService() {}
115     
116         const Capabilities& getCapabilities() const {
117             return m_caps;
118         }
119
120         bool createString(const char* context, const char* key, const char* value, time_t expiration);
121         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0);
122         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0);
123         bool deleteString(const char* context, const char* key);
124     
125         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
126             return createString(context, key, value, expiration);
127         }
128         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
129             return readString(context, key, pvalue, pexpiration, version);
130         }
131         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
132             return updateString(context, key, value, expiration, version);
133         }
134         bool deleteText(const char* context, const char* key) {
135             return deleteString(context, key);
136         }
137     
138         void reap(const char* context) {}
139
140         void updateContext(const char* context, time_t expiration);
141         void deleteContext(const char* context);
142
143     private:
144         Capabilities m_caps;
145         bool m_buildMap;
146     };
147
148     StorageService* MemcacheStorageServiceFactory(const DOMElement* const & e) {
149         return new MemcacheStorageService(e);
150     }
151 };
152
153 MemcacheBase::MemcacheBase(const DOMElement* e)
154     : m_log(Category::getInstance("XMLTooling.StorageService.MEMCACHE")), memc(nullptr),
155         m_prefix(XMLHelper::getAttrString(e, nullptr, prefix)), m_lock(Mutex::create())
156 {
157     memc = memcached_create(nullptr);
158     if (!memc)
159         throw XMLToolingException("MemcacheBase::Memcache(): memcached_create() failed");
160     m_log.debug("Memcache created");
161
162     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_HASH, MEMCACHED_HASH_CRC);
163     m_log.debug("CRC hash set");
164
165     int prop = XMLHelper::getAttrInt(e, 999999, sendTimeout);
166     m_log.debug("MEMCACHED_BEHAVIOR_SND_TIMEOUT will be set to %d", prop);
167     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_SND_TIMEOUT, prop);
168
169     prop = XMLHelper::getAttrInt(e, 999999, recvTimeout);
170     m_log.debug("MEMCACHED_BEHAVIOR_RCV_TIMEOUT will be set to %d", prop);
171     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_RCV_TIMEOUT, prop);
172
173     prop = XMLHelper::getAttrInt(e, 1000, pollTimeout);
174     m_log.debug("MEMCACHED_BEHAVIOR_POLL_TIMEOUT will be set to %d", prop);
175     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_POLL_TIMEOUT, prop);
176
177     prop = XMLHelper::getAttrInt(e, 5, failLimit);
178     m_log.debug("MEMCACHED_BEHAVIOR_SERVER_FAILURE_LIMIT will be set to %d", prop);
179     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_SERVER_FAILURE_LIMIT, prop);
180
181     prop = XMLHelper::getAttrInt(e, 30, retryTimeout);
182     m_log.debug("MEMCACHED_BEHAVIOR_RETRY_TIMEOUT will be set to %d", prop);
183     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_RETRY_TIMEOUT, prop);
184
185     prop = XMLHelper::getAttrInt(e, 1, nonBlocking);
186     m_log.debug("MEMCACHED_BEHAVIOR_NO_BLOCK will be set to %d", prop);
187     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_NO_BLOCK, prop);
188
189     // Grab hosts from the configuration.
190     e = e ? XMLHelper::getFirstChildElement(e, Hosts) : nullptr;
191     if (!e || !e->hasChildNodes()) {
192         memcached_free(memc);
193         throw XMLToolingException("Memcache StorageService requires Hosts element in configuration.");
194     }
195     auto_ptr_char h(e->getTextContent());
196     m_log.debug("INIT: GOT Hosts: %s", h.get());
197     memcached_server_st* servers;
198     servers = memcached_servers_parse(const_cast<char*>(h.get()));
199     m_log.debug("Got %u hosts.",  memcached_server_list_count(servers));
200     if (memcached_server_push(memc, servers) != MEMCACHED_SUCCESS) {
201         memcached_server_list_free(servers);
202         memcached_free(memc);
203         throw IOException("MemcacheBase: memcached_server_push() failed");
204     }
205     memcached_server_list_free(servers);
206
207     m_log.debug("Memcache object initialized");
208 }
209
210 MemcacheBase::~MemcacheBase()
211 {
212     memcached_free(memc);
213     m_log.debug("Base object destroyed");
214 }
215
216
217 bool MemcacheBase::handleError(const char* fn, memcached_return rv) const
218 {
219 #ifdef HAVE_MEMCACHED_LAST_ERROR_MESSAGE
220     string error = string("Memcache::") + fn + ": " + memcached_last_error_message(memc);
221 #else
222     string error;
223     if (rv == MEMCACHED_ERRNO) {
224         // System error
225         error = string("Memcache::") + fn + "SYSTEM ERROR: " + strerror(memc->cached_errno);
226     }
227     else {
228         error = string("Memcache::") + fn + " Problems: " + memcached_strerror(memc, rv);
229     }
230 #endif
231     m_log.error(error);
232     throw IOException(error);
233 }
234
235 bool MemcacheBase::addLock(string what, bool use_prefix)
236 {
237     string lock_name = what + ":LOCK";
238     string set_val = "1";
239     unsigned tries = 5;
240     while (!addMemcache(lock_name.c_str(), set_val, 5, 0, use_prefix)) {
241         if (tries-- == 0) {
242             m_log.debug("Unable to get lock %s... FAILED.", lock_name.c_str());
243             return false;
244         }
245         m_log.debug("Unable to get lock %s... Retrying.", lock_name.c_str());
246     
247         // sleep 100ms
248 #ifdef WIN32
249         Sleep(100);
250 #else
251         struct timeval tv = { 0, 100000 };
252         select(0, 0, 0, 0, &tv);
253 #endif
254     }
255     return true;
256 }
257
258 void MemcacheBase::deleteLock(string what, bool use_prefix)
259 {
260     string lock_name = what + ":LOCK";
261     deleteMemcache(lock_name.c_str(), 0, use_prefix);
262     return;
263
264 }  
265
266 void MemcacheBase::deserialize(string& source, mc_record& dest)
267 {
268     istringstream is(source, stringstream::in | stringstream::out);
269     is >> dest.expiration;
270     is.ignore(1); // ignore delimiter
271     dest.value = is.str().c_str() + is.tellg();
272 }
273
274 void MemcacheBase::deserialize(string& source, list<string>& dest)
275 {
276     istringstream is(source, stringstream::in | stringstream::out);
277     while (!is.eof()) {
278         string s;
279         is >> s;
280         dest.push_back(s);
281     }
282 }
283
284 void MemcacheBase::serialize(mc_record& source, string& dest)
285 {
286     ostringstream os(stringstream::in | stringstream::out);
287     os << source.expiration;
288     os << "-"; // delimiter
289     os << source.value;
290     dest = os.str();
291 }
292
293 void MemcacheBase::serialize(list<string>& source, string& dest)
294 {
295     ostringstream os(stringstream::in | stringstream::out);
296     for(list<string>::iterator iter = source.begin(); iter != source.end(); iter++) {
297         if (iter != source.begin()) {
298             os << endl;
299         }
300         os << *iter;
301     }
302     dest = os.str();
303 }
304
305 bool MemcacheBase::deleteMemcache(const char* key, time_t timeout, bool use_prefix)
306 {
307     string final_key;
308     if (use_prefix)
309         final_key = m_prefix + key;
310     else
311         final_key = key;
312
313     Lock lock(m_lock);
314     memcached_return rv = memcached_delete(memc, const_cast<char*>(final_key.c_str()), final_key.length(), timeout);
315
316     switch (rv) {
317         case MEMCACHED_SUCCESS:
318             return true;
319         case MEMCACHED_NOTFOUND:
320             // Key wasn't there... No biggie.
321             return false;
322         default:
323             return handleError("deleteMemcache", rv);
324     }
325 }
326
327 bool MemcacheBase::getMemcache(const char* key, string& dest, uint32_t* flags, bool use_prefix)
328 {
329     string final_key;
330     if (use_prefix)
331         final_key = m_prefix + key;
332     else
333         final_key = key;
334
335     Lock lock(m_lock);
336     size_t len;
337     memcached_return rv;
338     char* result = memcached_get(memc, const_cast<char*>(final_key.c_str()), final_key.length(), &len, flags, &rv);
339
340     switch (rv) {
341         case MEMCACHED_SUCCESS:
342             dest = result;
343             free(result);
344             return true;
345         case MEMCACHED_NOTFOUND:
346             m_log.debug("Key %s not found in memcache...", key);
347             return false;
348         default:
349             return handleError("getMemcache", rv);
350     }
351 }
352
353 bool MemcacheBase::addMemcache(const char* key, string& value, time_t timeout, uint32_t flags, bool use_prefix)
354 {
355     string final_key;
356     if (use_prefix)
357         final_key = m_prefix + key;
358     else
359         final_key = key;
360
361     Lock lock(m_lock);
362     memcached_return rv = memcached_add(
363         memc, const_cast<char*>(final_key.c_str()), final_key.length(), const_cast<char*>(value.c_str()), value.length(), timeout, flags
364         );
365
366     switch (rv) {
367         case MEMCACHED_SUCCESS:
368             return true;
369         case MEMCACHED_NOTSTORED:
370             return false;
371         default:
372             return handleError("addMemcache", rv);
373     }
374 }
375
376 bool MemcacheBase::setMemcache(const char* key, string& value, time_t timeout, uint32_t flags, bool use_prefix)
377 {
378     string final_key;
379     if (use_prefix)
380         final_key = m_prefix + key;
381     else
382         final_key = key;
383
384     Lock lock(m_lock);
385     memcached_return rv = memcached_set(
386         memc, const_cast<char*>(final_key.c_str()), final_key.length(), const_cast<char*>(value.c_str()), value.length(), timeout, flags
387         );
388
389     if (rv == MEMCACHED_SUCCESS)
390         return true;
391     return handleError("setMemcache", rv);
392 }
393
394 bool MemcacheBase::replaceMemcache(const char* key, string& value, time_t timeout, uint32_t flags, bool use_prefix)
395 {
396   
397     string final_key;
398     if (use_prefix)
399         final_key = m_prefix + key;
400     else
401         final_key = key;
402
403     Lock lock(m_lock);
404     memcached_return rv = memcached_replace(
405         memc, const_cast<char*>(final_key.c_str()), final_key.length(), const_cast<char*>(value.c_str()), value.length(), timeout, flags
406         );
407
408     switch (rv) {
409         case MEMCACHED_SUCCESS:
410             return true;
411         case MEMCACHED_NOTSTORED:
412             // not there
413             return false;
414         default:
415             return handleError("replaceMemcache", rv);
416     }
417 }
418
419
420 MemcacheStorageService::MemcacheStorageService(const DOMElement* e)
421     : MemcacheBase(e), m_caps(80, 250 - m_prefix.length() - 1 - 80, 255),
422         m_buildMap(XMLHelper::getAttrBool(e, false, buildMap))
423 {
424     if (m_buildMap)
425         m_log.debug("Cache built with buildMap ON");
426 }
427
428 bool MemcacheStorageService::createString(const char* context, const char* key, const char* value, time_t expiration)
429 {
430     m_log.debug("createString ctx: %s - key: %s", context, key);
431
432     string final_key = string(context) + ":" + string(key);
433
434     mc_record rec(value, expiration);
435     string final_value;
436     serialize(rec, final_value);
437
438     bool result = addMemcache(final_key.c_str(), final_value, expiration, 1); // the flag will be the version
439
440     if (result && m_buildMap) {
441         m_log.debug("Got result, updating map");
442
443         string map_name = context;
444         // we need to update the context map
445         if (!addLock(map_name)) {
446             m_log.error("Unable to get lock for context %s!", context);
447             deleteMemcache(final_key.c_str(), 0);
448             return false;
449         }
450
451         string ser_arr;
452         uint32_t flags;
453         bool result = getMemcache(map_name.c_str(), ser_arr, &flags);
454     
455         list<string> contents;
456         if (result) {
457             m_log.debug("Match found. Parsing...");
458             deserialize(ser_arr, contents);
459             if (m_log.isDebugEnabled()) {
460                 m_log.debug("Iterating retrieved session map...");
461                 for(list<string>::const_iterator iter = contents.begin(); iter != contents.end(); ++iter)
462                     m_log.debug("value = %s", iter->c_str());
463             }
464         }
465         else {
466             m_log.debug("New context: %s", map_name.c_str());
467         }
468
469         contents.push_back(key);
470         serialize(contents, ser_arr);
471         setMemcache(map_name.c_str(), ser_arr, expiration, 0);
472         deleteLock(map_name);
473     }
474     return result;
475 }
476
477 int MemcacheStorageService::readString(const char* context, const char* key, string* pvalue, time_t* pexpiration, int version)
478 {
479     m_log.debug("readString ctx: %s - key: %s", context, key);
480
481     string final_key = string(context) + ":" + string(key);
482     uint32_t rec_version;
483     string value;
484
485     if (m_buildMap) {
486         m_log.debug("Checking context");
487         string map_name = context;
488         string ser_arr;
489         uint32_t flags;
490         bool ctx_found = getMemcache(map_name.c_str(), ser_arr, &flags);
491         if (!ctx_found)
492             return 0;
493     }
494
495     bool found = getMemcache(final_key.c_str(), value, &rec_version);
496     if (!found)
497         return 0;
498
499     if (version && rec_version <= (uint32_t)version)
500         return version;
501
502     if (pexpiration || pvalue) {
503         mc_record rec;
504         deserialize(value, rec);
505     
506         if (pexpiration)
507             *pexpiration = rec.expiration;
508     
509         if (pvalue)
510             *pvalue = rec.value;
511     }
512   
513     return rec_version;
514 }
515
516 int MemcacheStorageService::updateString(const char* context, const char* key, const char* value, time_t expiration, int version)
517 {
518     m_log.debug("updateString ctx: %s - key: %s", context, key);
519
520     time_t final_exp = expiration;
521     time_t* want_expiration = nullptr;
522     if (!final_exp)
523         want_expiration = &final_exp;
524
525     int read_res = readString(context, key, nullptr, want_expiration, version);
526
527     if (!read_res) {
528         // not found
529         return read_res;
530     }
531
532     if (version && version != read_res) {
533         // version incorrect
534         return -1;
535     }
536
537     // Proceding with update
538     string final_key = string(context) + ":" + string(key);
539     mc_record rec(value, final_exp);
540     string final_value;
541     serialize(rec, final_value);
542
543     replaceMemcache(final_key.c_str(), final_value, final_exp, ++version);
544     return version;
545 }
546
547 bool MemcacheStorageService::deleteString(const char* context, const char* key)
548 {
549     m_log.debug("deleteString ctx: %s - key: %s", context, key);
550   
551     string final_key = string(context) + ":" + string(key);
552
553     // Not updating context map, if there is one. There is no need.
554     return deleteMemcache(final_key.c_str(), 0);
555 }
556
557 void MemcacheStorageService::updateContext(const char* context, time_t expiration)
558 {
559
560     m_log.debug("updateContext ctx: %s", context);
561
562     if (!m_buildMap) {
563         m_log.error("updateContext invoked on a Storage with no context map built!");
564         return;
565     }
566
567     string map_name = context;
568     string ser_arr;
569     uint32_t flags;
570     bool result = getMemcache(map_name.c_str(), ser_arr, &flags);
571   
572     list<string> contents;
573     if (result) {
574         m_log.debug("Match found. Parsing...");
575         deserialize(ser_arr, contents);
576     
577         m_log.debug("Iterating retrieved session map...");
578         for(list<string>::const_iterator iter = contents.begin(); iter != contents.end(); ++iter) {
579             // Update expiration times
580             string value;
581             int read_res = readString(context, iter->c_str(), &value, nullptr, 0);
582             if (!read_res) {
583                 // not found
584                 continue;
585             }
586
587             updateString(context, iter->c_str(), value.c_str(), expiration, read_res);
588         }
589         replaceMemcache(map_name.c_str(), ser_arr, expiration, flags);
590     }
591 }
592
593 void MemcacheStorageService::deleteContext(const char* context)
594 {
595
596     m_log.debug("deleteContext ctx: %s", context);
597
598     if (!m_buildMap) {
599         m_log.error("deleteContext invoked on a Storage with no context map built!");
600         return;
601     }
602
603     string map_name = context;
604     string ser_arr;
605     uint32_t flags;
606     bool result = getMemcache(map_name.c_str(), ser_arr, &flags);
607   
608     list<string> contents;
609     if (result) {
610         m_log.debug("Match found. Parsing...");
611         deserialize(ser_arr, contents);
612     
613         m_log.debug("Iterating retrieved session map...");
614         for (list<string>::const_iterator iter = contents.begin(); iter != contents.end(); ++iter) {
615             string final_key = map_name + *iter;
616             deleteMemcache(final_key.c_str(), 0);
617         }
618     
619         deleteMemcache(map_name.c_str(), 0);
620     }
621 }
622
623 extern "C" int MCEXT_EXPORTS xmltooling_extension_init(void*) {
624     // Register this SS type
625     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("MEMCACHE", MemcacheStorageServiceFactory);
626     return 0;
627 }
628
629 extern "C" void MCEXT_EXPORTS xmltooling_extension_term() {
630     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("MEMCACHE");
631 }