7547d079fef7b9f8125139eaee172b9d5000b579
[shibboleth/cpp-xmltooling.git] / xmltooling / impl / MemoryStorageService.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * MemoryStorageService.cpp
23  *
24  * In-memory "persistent" storage, suitable for simple applications.
25  */
26
27 #include "internal.h"
28 #include "logging.h"
29 #include "util/NDC.h"
30 #include "util/StorageService.h"
31 #include "util/Threads.h"
32 #include "util/XMLHelper.h"
33
34 #include <memory>
35 #include <xercesc/util/XMLUniDefs.hpp>
36
37 using namespace xmltooling::logging;
38 using namespace xmltooling;
39 using namespace std;
40
41 using xercesc::DOMElement;
42
43 namespace {
44     // Reasonably extended sizes to avoid callers needing to shrink unduly.
45     static const XMLTOOL_DLLLOCAL StorageService::Capabilities g_memCaps(0x4000, 0x4000, 0x4000);
46 };
47
48 namespace xmltooling {
49     class XMLTOOL_DLLLOCAL MemoryStorageService : public StorageService
50     {
51     public:
52         MemoryStorageService(const DOMElement* e);
53         virtual ~MemoryStorageService();
54
55         const Capabilities& getCapabilities() const {
56             return g_memCaps;
57         }
58
59         bool createString(const char* context, const char* key, const char* value, time_t expiration);
60         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0);
61         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0);
62         bool deleteString(const char* context, const char* key);
63
64         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
65             return createString(context, key, value, expiration);
66         }
67         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
68             return readString(context, key, pvalue, pexpiration, version);
69         }
70         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
71             return updateString(context, key, value, expiration, version);
72         }
73         bool deleteText(const char* context, const char* key) {
74             return deleteString(context, key);
75         }
76
77         void reap(const char* context);
78         void updateContext(const char* context, time_t expiration);
79         void deleteContext(const char* context) {
80             m_lock->wrlock();
81             m_contextMap.erase(context);
82             m_lock->unlock();
83         }
84
85     private:
86         struct XMLTOOL_DLLLOCAL Record {
87             Record() : expiration(0), version(1) {}
88             Record(const string& s, time_t t) : data(s), expiration(t), version(1) {}
89             string data;
90             time_t expiration;
91             int version;
92         };
93
94         struct XMLTOOL_DLLLOCAL Context {
95             Context() {}
96             Context(const Context& src) {
97                 m_dataMap = src.m_dataMap;
98             }
99             map<string,Record> m_dataMap;
100             unsigned long reap(time_t exp);
101         };
102
103         Context& readContext(const char* context) {
104             m_lock->rdlock();
105             map<string,Context>::iterator i = m_contextMap.find(context);
106             if (i != m_contextMap.end())
107                 return i->second;
108             m_lock->unlock();
109             m_lock->wrlock();
110             return m_contextMap[context];
111         }
112
113         Context& writeContext(const char* context) {
114             m_lock->wrlock();
115             return m_contextMap[context];
116         }
117
118         map<string,Context> m_contextMap;
119         RWLock* m_lock;
120         CondWait* shutdown_wait;
121         Thread* cleanup_thread;
122         static void* cleanup_fn(void*);
123         bool shutdown;
124         int m_cleanupInterval;
125         Category& m_log;
126     };
127
128     StorageService* XMLTOOL_DLLLOCAL MemoryStorageServiceFactory(const DOMElement* const & e)
129     {
130         return new MemoryStorageService(e);
131     }
132 };
133
134 static const XMLCh cleanupInterval[] = UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
135
136 MemoryStorageService::MemoryStorageService(const DOMElement* e)
137     : m_lock(nullptr), shutdown_wait(nullptr), cleanup_thread(nullptr), shutdown(false),
138         m_cleanupInterval(XMLHelper::getAttrInt(e, 900, cleanupInterval)),
139         m_log(Category::getInstance(XMLTOOLING_LOGCAT".StorageService"))
140 {
141     m_lock = RWLock::create();
142     shutdown_wait = CondWait::create();
143     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
144 }
145
146 MemoryStorageService::~MemoryStorageService()
147 {
148     // Shut down the cleanup thread and let it know...
149     shutdown = true;
150     shutdown_wait->signal();
151     cleanup_thread->join(nullptr);
152
153     delete cleanup_thread;
154     delete shutdown_wait;
155     delete m_lock;
156 }
157
158 void* MemoryStorageService::cleanup_fn(void* pv)
159 {
160     MemoryStorageService* cache = reinterpret_cast<MemoryStorageService*>(pv);
161
162 #ifndef WIN32
163     // First, let's block all signals
164     Thread::mask_all_signals();
165 #endif
166
167 #ifdef _DEBUG
168     NDC ndc("cleanup");
169 #endif
170
171     auto_ptr<Mutex> mutex(Mutex::create());
172     mutex->lock();
173
174     cache->m_log.info("cleanup thread started...running every %d seconds", cache->m_cleanupInterval);
175
176     while (!cache->shutdown) {
177         cache->shutdown_wait->timedwait(mutex.get(), cache->m_cleanupInterval);
178         if (cache->shutdown)
179             break;
180
181         unsigned long count=0;
182         time_t now = time(nullptr);
183         cache->m_lock->wrlock();
184         SharedLock locker(cache->m_lock, false);
185         for (map<string,Context>::iterator i=cache->m_contextMap.begin(); i!=cache->m_contextMap.end(); ++i)
186             count += i->second.reap(now);
187
188         if (count)
189             cache->m_log.info("purged %d expired record(s) from storage", count);
190     }
191
192     cache->m_log.info("cleanup thread finished");
193
194     mutex->unlock();
195     return nullptr;
196 }
197
198 void MemoryStorageService::reap(const char* context)
199 {
200     Context& ctx = writeContext(context);
201     SharedLock locker(m_lock, false);
202     ctx.reap(time(nullptr));
203 }
204
205 unsigned long MemoryStorageService::Context::reap(time_t exp)
206 {
207     // Garbage collect any expired entries.
208     unsigned long count=0;
209     map<string,Record>::iterator cur = m_dataMap.begin();
210     map<string,Record>::iterator stop = m_dataMap.end();
211     while (cur != stop) {
212         if (cur->second.expiration <= exp) {
213             map<string,Record>::iterator tmp = cur++;
214             m_dataMap.erase(tmp);
215             ++count;
216         }
217         else {
218             cur++;
219         }
220     }
221     return count;
222 }
223
224 bool MemoryStorageService::createString(const char* context, const char* key, const char* value, time_t expiration)
225 {
226     Context& ctx = writeContext(context);
227     SharedLock locker(m_lock, false);
228
229     // Check for a duplicate.
230     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
231     if (i!=ctx.m_dataMap.end()) {
232         // Not yet expired?
233         if (time(nullptr) < i->second.expiration)
234             return false;
235         // It's dead, so we can just remove it now and create the new record.
236         ctx.m_dataMap.erase(i);
237     }
238
239     ctx.m_dataMap[key]=Record(value,expiration);
240
241     m_log.debug("inserted record (%s) in context (%s) with expiration (%lu)", key, context, expiration);
242     return true;
243 }
244
245 int MemoryStorageService::readString(const char* context, const char* key, string* pvalue, time_t* pexpiration, int version)
246 {
247     Context& ctx = readContext(context);
248     SharedLock locker(m_lock, false);
249
250     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
251     if (i==ctx.m_dataMap.end())
252         return 0;
253     else if (time(nullptr) >= i->second.expiration)
254         return 0;
255     if (pexpiration)
256         *pexpiration = i->second.expiration;
257     if (i->second.version == version)
258         return version; // nothing's changed, so just echo back the version
259     if (pvalue)
260         *pvalue = i->second.data;
261     return i->second.version;
262 }
263
264 int MemoryStorageService::updateString(const char* context, const char* key, const char* value, time_t expiration, int version)
265 {
266     Context& ctx = writeContext(context);
267     SharedLock locker(m_lock, false);
268
269     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
270     if (i==ctx.m_dataMap.end())
271         return 0;
272     else if (time(nullptr) >= i->second.expiration)
273         return 0;
274
275     if (version > 0 && version != i->second.version)
276         return -1;  // caller's out of sync
277
278     if (value) {
279         i->second.data = value;
280         ++(i->second.version);
281     }
282
283     if (expiration && expiration != i->second.expiration)
284         i->second.expiration = expiration;
285
286     m_log.debug("updated record (%s) in context (%s) with expiration (%lu)", key, context, i->second.expiration);
287     return i->second.version;
288 }
289
290 bool MemoryStorageService::deleteString(const char* context, const char* key)
291 {
292     Context& ctx = writeContext(context);
293     SharedLock locker(m_lock, false);
294
295     // Find the record.
296     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
297     if (i!=ctx.m_dataMap.end()) {
298         ctx.m_dataMap.erase(i);
299         m_log.debug("deleted record (%s) in context (%s)", key, context);
300         return true;
301     }
302
303     m_log.debug("deleting record (%s) in context (%s)....not found", key, context);
304     return false;
305 }
306
307 void MemoryStorageService::updateContext(const char* context, time_t expiration)
308 {
309     Context& ctx = writeContext(context);
310     SharedLock locker(m_lock, false);
311
312     time_t now = time(nullptr);
313     map<string,Record>::iterator stop=ctx.m_dataMap.end();
314     for (map<string,Record>::iterator i = ctx.m_dataMap.begin(); i!=stop; ++i) {
315         if (now < i->second.expiration)
316             i->second.expiration = expiration;
317     }
318
319     m_log.debug("updated expiration of valid records in context (%s) to (%lu)", context, expiration);
320 }