https://issues.shibboleth.net/jira/browse/SSPCPP-420
[shibboleth/cpp-sp.git] / memcache-store / memcache-store.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * memcache-store.cpp
23  *
24  * Storage Service using memcache (pre memcache tags).
25  */
26
27 #if defined (_MSC_VER) || defined(__BORLANDC__)
28 # include "config_win32.h"
29 #else
30 # include "config.h"
31 #endif
32
33 #ifdef WIN32
34 # define _CRT_NONSTDC_NO_DEPRECATE 1
35 # define _CRT_SECURE_NO_DEPRECATE 1
36 # define MCEXT_EXPORTS __declspec(dllexport)
37 #else
38 # define MCEXT_EXPORTS
39 #endif
40
41 #include <xmltooling/base.h>
42
43 #include <list>
44 #include <iostream> 
45 #include <boost/scoped_ptr.hpp>
46 #include <libmemcached/memcached.h>
47 #include <xercesc/util/XMLUniDefs.hpp>
48
49 #include <xmltooling/logging.h>
50 #include <xmltooling/unicode.h>
51 #include <xmltooling/XMLToolingConfig.h>
52 #include <xmltooling/util/NDC.h>
53 #include <xmltooling/util/StorageService.h>
54 #include <xmltooling/util/Threads.h>
55 #include <xmltooling/util/XMLHelper.h>
56
57 using namespace xmltooling::logging;
58 using namespace xmltooling;
59 using namespace xercesc;
60 using namespace boost;
61 using namespace std;
62
63 namespace {
64     static const XMLCh Hosts[] = UNICODE_LITERAL_5(H,o,s,t,s);
65     static const XMLCh prefix[] = UNICODE_LITERAL_6(p,r,e,f,i,x);
66     static const XMLCh buildMap[] = UNICODE_LITERAL_8(b,u,i,l,d,M,a,p);
67     static const XMLCh sendTimeout[] = UNICODE_LITERAL_11(s,e,n,d,T,i,m,e,o,u,t);
68     static const XMLCh recvTimeout[] = UNICODE_LITERAL_11(r,e,c,v,T,i,m,e,o,u,t);
69     static const XMLCh pollTimeout[] = UNICODE_LITERAL_11(p,o,l,l,T,i,m,e,o,u,t);
70     static const XMLCh failLimit[] = UNICODE_LITERAL_9(f,a,i,l,L,i,m,i,t);
71     static const XMLCh retryTimeout[] = UNICODE_LITERAL_12(r,e,t,r,y,T,i,m,e,o,u,t);
72     static const XMLCh nonBlocking[] = UNICODE_LITERAL_11(n,o,n,B,l,o,c,k,i,n,g);
73   
74     class mc_record {
75     public:
76         string value;
77         time_t expiration;
78         mc_record() {};
79         mc_record(string _v, time_t _e) : value(_v), expiration(_e) {}
80     };
81
82     class MemcacheBase {
83     public:
84         MemcacheBase(const DOMElement* e);
85         ~MemcacheBase();
86         
87         bool addMemcache(const char* key, string &value, time_t timeout, uint32_t flags, bool use_prefix = true);
88         bool setMemcache(const char* key, string &value, time_t timeout, uint32_t flags, bool use_prefix = true);
89         bool replaceMemcache(const char* key, string &value, time_t timeout, uint32_t flags, bool use_prefix = true);
90         bool getMemcache(const char* key, string &dest, uint32_t *flags, bool use_prefix = true);
91         bool deleteMemcache(const char* key, time_t timeout, bool use_prefix = true);
92
93         void serialize(mc_record &source, string &dest);
94         void serialize(list<string> &source, string &dest);
95         void deserialize(string &source, mc_record &dest);
96         void deserialize(string &source, list<string> &dest);
97
98         bool addLock(string what, bool use_prefix = true);
99         void deleteLock(string what, bool use_prefix = true);
100
101     protected:
102         Category& m_log;
103         memcached_st* memc;
104         string m_prefix;
105         scoped_ptr<Mutex> m_lock;
106
107     private:
108         bool handleError(const char*, memcached_return) const;
109     };
110   
111     class MemcacheStorageService : public StorageService, public MemcacheBase {
112
113     public:
114         MemcacheStorageService(const DOMElement* e);
115         ~MemcacheStorageService() {}
116     
117         const Capabilities& getCapabilities() const {
118             return m_caps;
119         }
120
121         bool createString(const char* context, const char* key, const char* value, time_t expiration);
122         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0);
123         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0);
124         bool deleteString(const char* context, const char* key);
125     
126         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
127             return createString(context, key, value, expiration);
128         }
129         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
130             return readString(context, key, pvalue, pexpiration, version);
131         }
132         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
133             return updateString(context, key, value, expiration, version);
134         }
135         bool deleteText(const char* context, const char* key) {
136             return deleteString(context, key);
137         }
138     
139         void reap(const char* context) {}
140
141         void updateContext(const char* context, time_t expiration);
142         void deleteContext(const char* context);
143
144     private:
145         Capabilities m_caps;
146         bool m_buildMap;
147     };
148
149     StorageService* MemcacheStorageServiceFactory(const DOMElement* const & e) {
150         return new MemcacheStorageService(e);
151     }
152 };
153
154 MemcacheBase::MemcacheBase(const DOMElement* e)
155     : m_log(Category::getInstance("XMLTooling.StorageService.MEMCACHE")), memc(nullptr),
156         m_prefix(XMLHelper::getAttrString(e, nullptr, prefix)), m_lock(Mutex::create())
157 {
158     memc = memcached_create(nullptr);
159     if (!memc)
160         throw XMLToolingException("MemcacheBase::Memcache(): memcached_create() failed");
161     m_log.debug("Memcache created");
162
163     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_HASH, MEMCACHED_HASH_CRC);
164     m_log.debug("CRC hash set");
165
166     int prop = XMLHelper::getAttrInt(e, 999999, sendTimeout);
167     m_log.debug("MEMCACHED_BEHAVIOR_SND_TIMEOUT will be set to %d", prop);
168     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_SND_TIMEOUT, prop);
169
170     prop = XMLHelper::getAttrInt(e, 999999, recvTimeout);
171     m_log.debug("MEMCACHED_BEHAVIOR_RCV_TIMEOUT will be set to %d", prop);
172     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_RCV_TIMEOUT, prop);
173
174     prop = XMLHelper::getAttrInt(e, 1000, pollTimeout);
175     m_log.debug("MEMCACHED_BEHAVIOR_POLL_TIMEOUT will be set to %d", prop);
176     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_POLL_TIMEOUT, prop);
177
178     prop = XMLHelper::getAttrInt(e, 5, failLimit);
179     m_log.debug("MEMCACHED_BEHAVIOR_SERVER_FAILURE_LIMIT will be set to %d", prop);
180     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_SERVER_FAILURE_LIMIT, prop);
181
182     prop = XMLHelper::getAttrInt(e, 30, retryTimeout);
183     m_log.debug("MEMCACHED_BEHAVIOR_RETRY_TIMEOUT will be set to %d", prop);
184     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_RETRY_TIMEOUT, prop);
185
186     prop = XMLHelper::getAttrInt(e, 1, nonBlocking);
187     m_log.debug("MEMCACHED_BEHAVIOR_NO_BLOCK will be set to %d", prop);
188     memcached_behavior_set(memc, MEMCACHED_BEHAVIOR_NO_BLOCK, prop);
189
190     // Grab hosts from the configuration.
191     e = e ? XMLHelper::getFirstChildElement(e, Hosts) : nullptr;
192     if (!e || !e->hasChildNodes()) {
193         memcached_free(memc);
194         throw XMLToolingException("Memcache StorageService requires Hosts element in configuration.");
195     }
196     auto_ptr_char h(e->getTextContent());
197     m_log.debug("INIT: GOT Hosts: %s", h.get());
198     memcached_server_st* servers;
199     servers = memcached_servers_parse(const_cast<char*>(h.get()));
200     m_log.debug("Got %u hosts.",  memcached_server_list_count(servers));
201     if (memcached_server_push(memc, servers) != MEMCACHED_SUCCESS) {
202         memcached_server_list_free(servers);
203         memcached_free(memc);
204         throw IOException("MemcacheBase: memcached_server_push() failed");
205     }
206     memcached_server_list_free(servers);
207
208     m_log.debug("Memcache object initialized");
209 }
210
211 MemcacheBase::~MemcacheBase()
212 {
213     memcached_free(memc);
214     m_log.debug("Base object destroyed");
215 }
216
217
218 bool MemcacheBase::handleError(const char* fn, memcached_return rv) const
219 {
220 #ifdef HAVE_MEMCACHED_LAST_ERROR_MESSAGE
221     string error = string("Memcache::") + fn + ": " + memcached_last_error_message(memc);
222 #else
223     string error;
224     if (rv == MEMCACHED_ERRNO) {
225         // System error
226         error = string("Memcache::") + fn + "SYSTEM ERROR: " + strerror(memc->cached_errno);
227     }
228     else {
229         error = string("Memcache::") + fn + " Problems: " + memcached_strerror(memc, rv);
230     }
231 #endif
232     m_log.error(error);
233     throw IOException(error);
234 }
235
236 bool MemcacheBase::addLock(string what, bool use_prefix)
237 {
238     string lock_name = what + ":LOCK";
239     string set_val = "1";
240     unsigned tries = 5;
241     while (!addMemcache(lock_name.c_str(), set_val, 5, 0, use_prefix)) {
242         if (tries-- == 0) {
243             m_log.debug("Unable to get lock %s... FAILED.", lock_name.c_str());
244             return false;
245         }
246         m_log.debug("Unable to get lock %s... Retrying.", lock_name.c_str());
247     
248         // sleep 100ms
249 #ifdef WIN32
250         Sleep(100);
251 #else
252         struct timeval tv = { 0, 100000 };
253         select(0, 0, 0, 0, &tv);
254 #endif
255     }
256     return true;
257 }
258
259 void MemcacheBase::deleteLock(string what, bool use_prefix)
260 {
261     string lock_name = what + ":LOCK";
262     deleteMemcache(lock_name.c_str(), 0, use_prefix);
263     return;
264
265 }  
266
267 void MemcacheBase::deserialize(string& source, mc_record& dest)
268 {
269     istringstream is(source, stringstream::in | stringstream::out);
270     is >> dest.expiration;
271     is.ignore(1); // ignore delimiter
272     dest.value = is.str().c_str() + is.tellg();
273 }
274
275 void MemcacheBase::deserialize(string& source, list<string>& dest)
276 {
277     istringstream is(source, stringstream::in | stringstream::out);
278     while (!is.eof()) {
279         string s;
280         is >> s;
281         dest.push_back(s);
282     }
283 }
284
285 void MemcacheBase::serialize(mc_record& source, string& dest)
286 {
287     ostringstream os(stringstream::in | stringstream::out);
288     os << source.expiration;
289     os << "-"; // delimiter
290     os << source.value;
291     dest = os.str();
292 }
293
294 void MemcacheBase::serialize(list<string>& source, string& dest)
295 {
296     ostringstream os(stringstream::in | stringstream::out);
297     for(list<string>::iterator iter = source.begin(); iter != source.end(); iter++) {
298         if (iter != source.begin()) {
299             os << endl;
300         }
301         os << *iter;
302     }
303     dest = os.str();
304 }
305
306 bool MemcacheBase::deleteMemcache(const char* key, time_t timeout, bool use_prefix)
307 {
308     string final_key;
309     if (use_prefix)
310         final_key = m_prefix + key;
311     else
312         final_key = key;
313
314     Lock lock(m_lock);
315     memcached_return rv = memcached_delete(memc, const_cast<char*>(final_key.c_str()), final_key.length(), timeout);
316
317     switch (rv) {
318         case MEMCACHED_SUCCESS:
319             return true;
320         case MEMCACHED_NOTFOUND:
321             // Key wasn't there... No biggie.
322             return false;
323         default:
324             return handleError("deleteMemcache", rv);
325     }
326 }
327
328 bool MemcacheBase::getMemcache(const char* key, string& dest, uint32_t* flags, bool use_prefix)
329 {
330     string final_key;
331     if (use_prefix)
332         final_key = m_prefix + key;
333     else
334         final_key = key;
335
336     Lock lock(m_lock);
337     size_t len;
338     memcached_return rv;
339     char* result = memcached_get(memc, const_cast<char*>(final_key.c_str()), final_key.length(), &len, flags, &rv);
340
341     switch (rv) {
342         case MEMCACHED_SUCCESS:
343             dest = result;
344             free(result);
345             return true;
346         case MEMCACHED_NOTFOUND:
347             m_log.debug("Key %s not found in memcache...", key);
348             return false;
349         default:
350             return handleError("getMemcache", rv);
351     }
352 }
353
354 bool MemcacheBase::addMemcache(const char* key, string& value, time_t timeout, uint32_t flags, bool use_prefix)
355 {
356     string final_key;
357     if (use_prefix)
358         final_key = m_prefix + key;
359     else
360         final_key = key;
361
362     Lock lock(m_lock);
363     memcached_return rv = memcached_add(
364         memc, const_cast<char*>(final_key.c_str()), final_key.length(), const_cast<char*>(value.c_str()), value.length(), timeout, flags
365         );
366
367     switch (rv) {
368         case MEMCACHED_SUCCESS:
369             return true;
370         case MEMCACHED_NOTSTORED:
371             return false;
372         default:
373             return handleError("addMemcache", rv);
374     }
375 }
376
377 bool MemcacheBase::setMemcache(const char* key, string& value, time_t timeout, uint32_t flags, bool use_prefix)
378 {
379     string final_key;
380     if (use_prefix)
381         final_key = m_prefix + key;
382     else
383         final_key = key;
384
385     Lock lock(m_lock);
386     memcached_return rv = memcached_set(
387         memc, const_cast<char*>(final_key.c_str()), final_key.length(), const_cast<char*>(value.c_str()), value.length(), timeout, flags
388         );
389
390     if (rv == MEMCACHED_SUCCESS)
391         return true;
392     return handleError("setMemcache", rv);
393 }
394
395 bool MemcacheBase::replaceMemcache(const char* key, string& value, time_t timeout, uint32_t flags, bool use_prefix)
396 {
397   
398     string final_key;
399     if (use_prefix)
400         final_key = m_prefix + key;
401     else
402         final_key = key;
403
404     Lock lock(m_lock);
405     memcached_return rv = memcached_replace(
406         memc, const_cast<char*>(final_key.c_str()), final_key.length(), const_cast<char*>(value.c_str()), value.length(), timeout, flags
407         );
408
409     switch (rv) {
410         case MEMCACHED_SUCCESS:
411             return true;
412         case MEMCACHED_NOTSTORED:
413             // not there
414             return false;
415         default:
416             return handleError("replaceMemcache", rv);
417     }
418 }
419
420
421 MemcacheStorageService::MemcacheStorageService(const DOMElement* e)
422     : MemcacheBase(e), m_caps(80, 250 - m_prefix.length() - 1 - 80, 255),
423         m_buildMap(XMLHelper::getAttrBool(e, false, buildMap))
424 {
425     if (m_buildMap)
426         m_log.debug("Cache built with buildMap ON");
427 }
428
429 bool MemcacheStorageService::createString(const char* context, const char* key, const char* value, time_t expiration)
430 {
431     m_log.debug("createString ctx: %s - key: %s", context, key);
432
433     string final_key = string(context) + ":" + string(key);
434
435     mc_record rec(value, expiration);
436     string final_value;
437     serialize(rec, final_value);
438
439     bool result = addMemcache(final_key.c_str(), final_value, expiration, 1); // the flag will be the version
440
441     if (result && m_buildMap) {
442         m_log.debug("Got result, updating map");
443
444         string map_name = context;
445         // we need to update the context map
446         if (!addLock(map_name)) {
447             m_log.error("Unable to get lock for context %s!", context);
448             deleteMemcache(final_key.c_str(), 0);
449             return false;
450         }
451
452         string ser_arr;
453         uint32_t flags;
454         bool result = getMemcache(map_name.c_str(), ser_arr, &flags);
455     
456         list<string> contents;
457         if (result) {
458             m_log.debug("Match found. Parsing...");
459             deserialize(ser_arr, contents);
460             if (m_log.isDebugEnabled()) {
461                 m_log.debug("Iterating retrieved session map...");
462                 for(list<string>::const_iterator iter = contents.begin(); iter != contents.end(); ++iter)
463                     m_log.debug("value = %s", iter->c_str());
464             }
465         }
466         else {
467             m_log.debug("New context: %s", map_name.c_str());
468         }
469
470         contents.push_back(key);
471         serialize(contents, ser_arr);
472         setMemcache(map_name.c_str(), ser_arr, expiration, 0);
473         deleteLock(map_name);
474     }
475     return result;
476 }
477
478 int MemcacheStorageService::readString(const char* context, const char* key, string* pvalue, time_t* pexpiration, int version)
479 {
480     m_log.debug("readString ctx: %s - key: %s", context, key);
481
482     string final_key = string(context) + ":" + string(key);
483     uint32_t rec_version;
484     string value;
485
486     if (m_buildMap) {
487         m_log.debug("Checking context");
488         string map_name = context;
489         string ser_arr;
490         uint32_t flags;
491         bool ctx_found = getMemcache(map_name.c_str(), ser_arr, &flags);
492         if (!ctx_found)
493             return 0;
494     }
495
496     bool found = getMemcache(final_key.c_str(), value, &rec_version);
497     if (!found)
498         return 0;
499
500     if (version && rec_version <= (uint32_t)version)
501         return version;
502
503     if (pexpiration || pvalue) {
504         mc_record rec;
505         deserialize(value, rec);
506     
507         if (pexpiration)
508             *pexpiration = rec.expiration;
509     
510         if (pvalue)
511             *pvalue = rec.value;
512     }
513   
514     return rec_version;
515 }
516
517 int MemcacheStorageService::updateString(const char* context, const char* key, const char* value, time_t expiration, int version)
518 {
519     m_log.debug("updateString ctx: %s - key: %s", context, key);
520
521     time_t final_exp = expiration;
522     time_t* want_expiration = nullptr;
523     if (!final_exp)
524         want_expiration = &final_exp;
525
526     int read_res = readString(context, key, nullptr, want_expiration, version);
527
528     if (!read_res) {
529         // not found
530         return read_res;
531     }
532
533     if (version && version != read_res) {
534         // version incorrect
535         return -1;
536     }
537
538     // Proceding with update
539     string final_key = string(context) + ":" + string(key);
540     mc_record rec(value, final_exp);
541     string final_value;
542     serialize(rec, final_value);
543
544     replaceMemcache(final_key.c_str(), final_value, final_exp, ++version);
545     return version;
546 }
547
548 bool MemcacheStorageService::deleteString(const char* context, const char* key)
549 {
550     m_log.debug("deleteString ctx: %s - key: %s", context, key);
551   
552     string final_key = string(context) + ":" + string(key);
553
554     // Not updating context map, if there is one. There is no need.
555     return deleteMemcache(final_key.c_str(), 0);
556 }
557
558 void MemcacheStorageService::updateContext(const char* context, time_t expiration)
559 {
560
561     m_log.debug("updateContext ctx: %s", context);
562
563     if (!m_buildMap) {
564         m_log.error("updateContext invoked on a Storage with no context map built!");
565         return;
566     }
567
568     string map_name = context;
569     string ser_arr;
570     uint32_t flags;
571     bool result = getMemcache(map_name.c_str(), ser_arr, &flags);
572   
573     list<string> contents;
574     if (result) {
575         m_log.debug("Match found. Parsing...");
576         deserialize(ser_arr, contents);
577     
578         m_log.debug("Iterating retrieved session map...");
579         for(list<string>::const_iterator iter = contents.begin(); iter != contents.end(); ++iter) {
580             // Update expiration times
581             string value;
582             int read_res = readString(context, iter->c_str(), &value, nullptr, 0);
583             if (!read_res) {
584                 // not found
585                 continue;
586             }
587
588             updateString(context, iter->c_str(), value.c_str(), expiration, read_res);
589         }
590         replaceMemcache(map_name.c_str(), ser_arr, expiration, flags);
591     }
592 }
593
594 void MemcacheStorageService::deleteContext(const char* context)
595 {
596
597     m_log.debug("deleteContext ctx: %s", context);
598
599     if (!m_buildMap) {
600         m_log.error("deleteContext invoked on a Storage with no context map built!");
601         return;
602     }
603
604     string map_name = context;
605     string ser_arr;
606     uint32_t flags;
607     bool result = getMemcache(map_name.c_str(), ser_arr, &flags);
608   
609     list<string> contents;
610     if (result) {
611         m_log.debug("Match found. Parsing...");
612         deserialize(ser_arr, contents);
613     
614         m_log.debug("Iterating retrieved session map...");
615         for (list<string>::const_iterator iter = contents.begin(); iter != contents.end(); ++iter) {
616             string final_key = map_name + *iter;
617             deleteMemcache(final_key.c_str(), 0);
618         }
619     
620         deleteMemcache(map_name.c_str(), 0);
621     }
622 }
623
624 extern "C" int MCEXT_EXPORTS xmltooling_extension_init(void*) {
625     // Register this SS type
626     XMLToolingConfig::getConfig().StorageServiceManager.registerFactory("MEMCACHE", MemcacheStorageServiceFactory);
627     return 0;
628 }
629
630 extern "C" void MCEXT_EXPORTS xmltooling_extension_term() {
631     XMLToolingConfig::getConfig().StorageServiceManager.deregisterFactory("MEMCACHE");
632 }