Merge branch '1.x' of ssh://authdev.it.ohio-state.edu/~scantor/git/cpp-xmltooling...
[shibboleth/cpp-xmltooling.git] / xmltooling / impl / MemoryStorageService.cpp
1 /**
2  * Licensed to the University Corporation for Advanced Internet
3  * Development, Inc. (UCAID) under one or more contributor license
4  * agreements. See the NOTICE file distributed with this work for
5  * additional information regarding copyright ownership.
6  *
7  * UCAID licenses this file to you under the Apache License,
8  * Version 2.0 (the "License"); you may not use this file except
9  * in compliance with the License. You may obtain a copy of the
10  * License at
11  *
12  * http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
17  * either express or implied. See the License for the specific
18  * language governing permissions and limitations under the License.
19  */
20
21 /**
22  * MemoryStorageService.cpp
23  *
24  * In-memory "persistent" storage, suitable for simple applications.
25  */
26
27 #include "internal.h"
28 #include "logging.h"
29 #include "util/NDC.h"
30 #include "util/StorageService.h"
31 #include "util/Threads.h"
32 #include "util/XMLHelper.h"
33
34 #include <memory>
35 #include <xercesc/util/XMLUniDefs.hpp>
36
37 using namespace xmltooling::logging;
38 using namespace xmltooling;
39 using namespace std;
40
41 using xercesc::DOMElement;
42
43 namespace xmltooling {
44     class XMLTOOL_DLLLOCAL MemoryStorageService : public StorageService
45     {
46     public:
47         MemoryStorageService(const DOMElement* e);
48         virtual ~MemoryStorageService();
49
50         bool createString(const char* context, const char* key, const char* value, time_t expiration);
51         int readString(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0);
52         int updateString(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0);
53         bool deleteString(const char* context, const char* key);
54
55         bool createText(const char* context, const char* key, const char* value, time_t expiration) {
56             return createString(context, key, value, expiration);
57         }
58         int readText(const char* context, const char* key, string* pvalue=nullptr, time_t* pexpiration=nullptr, int version=0) {
59             return readString(context, key, pvalue, pexpiration, version);
60         }
61         int updateText(const char* context, const char* key, const char* value=nullptr, time_t expiration=0, int version=0) {
62             return updateString(context, key, value, expiration, version);
63         }
64         bool deleteText(const char* context, const char* key) {
65             return deleteString(context, key);
66         }
67
68         void reap(const char* context);
69         void updateContext(const char* context, time_t expiration);
70         void deleteContext(const char* context) {
71             m_lock->wrlock();
72             m_contextMap.erase(context);
73             m_lock->unlock();
74         }
75
76     private:
77         struct XMLTOOL_DLLLOCAL Record {
78             Record() : expiration(0), version(1) {}
79             Record(const string& s, time_t t) : data(s), expiration(t), version(1) {}
80             string data;
81             time_t expiration;
82             int version;
83         };
84
85         struct XMLTOOL_DLLLOCAL Context {
86             Context() {}
87             Context(const Context& src) {
88                 m_dataMap = src.m_dataMap;
89             }
90             map<string,Record> m_dataMap;
91             unsigned long reap(time_t exp);
92         };
93
94         Context& readContext(const char* context) {
95             m_lock->rdlock();
96             map<string,Context>::iterator i = m_contextMap.find(context);
97             if (i != m_contextMap.end())
98                 return i->second;
99             m_lock->unlock();
100             m_lock->wrlock();
101             return m_contextMap[context];
102         }
103
104         Context& writeContext(const char* context) {
105             m_lock->wrlock();
106             return m_contextMap[context];
107         }
108
109         map<string,Context> m_contextMap;
110         RWLock* m_lock;
111         CondWait* shutdown_wait;
112         Thread* cleanup_thread;
113         static void* cleanup_fn(void*);
114         bool shutdown;
115         int m_cleanupInterval;
116         Category& m_log;
117     };
118
119     StorageService* XMLTOOL_DLLLOCAL MemoryStorageServiceFactory(const DOMElement* const & e)
120     {
121         return new MemoryStorageService(e);
122     }
123 };
124
125 static const XMLCh cleanupInterval[] = UNICODE_LITERAL_15(c,l,e,a,n,u,p,I,n,t,e,r,v,a,l);
126
127 MemoryStorageService::MemoryStorageService(const DOMElement* e)
128     : m_lock(nullptr), shutdown_wait(nullptr), cleanup_thread(nullptr), shutdown(false),
129         m_cleanupInterval(XMLHelper::getAttrInt(e, 900, cleanupInterval)),
130         m_log(Category::getInstance(XMLTOOLING_LOGCAT".StorageService"))
131 {
132     m_lock = RWLock::create();
133     shutdown_wait = CondWait::create();
134     cleanup_thread = Thread::create(&cleanup_fn, (void*)this);
135 }
136
137 MemoryStorageService::~MemoryStorageService()
138 {
139     // Shut down the cleanup thread and let it know...
140     shutdown = true;
141     shutdown_wait->signal();
142     cleanup_thread->join(nullptr);
143
144     delete cleanup_thread;
145     delete shutdown_wait;
146     delete m_lock;
147 }
148
149 void* MemoryStorageService::cleanup_fn(void* pv)
150 {
151     MemoryStorageService* cache = reinterpret_cast<MemoryStorageService*>(pv);
152
153 #ifndef WIN32
154     // First, let's block all signals
155     Thread::mask_all_signals();
156 #endif
157
158 #ifdef _DEBUG
159     NDC ndc("cleanup");
160 #endif
161
162     auto_ptr<Mutex> mutex(Mutex::create());
163     mutex->lock();
164
165     cache->m_log.info("cleanup thread started...running every %d seconds", cache->m_cleanupInterval);
166
167     while (!cache->shutdown) {
168         cache->shutdown_wait->timedwait(mutex.get(), cache->m_cleanupInterval);
169         if (cache->shutdown)
170             break;
171
172         unsigned long count=0;
173         time_t now = time(nullptr);
174         cache->m_lock->wrlock();
175         SharedLock locker(cache->m_lock, false);
176         for (map<string,Context>::iterator i=cache->m_contextMap.begin(); i!=cache->m_contextMap.end(); ++i)
177             count += i->second.reap(now);
178
179         if (count)
180             cache->m_log.info("purged %d expired record(s) from storage", count);
181     }
182
183     cache->m_log.info("cleanup thread finished");
184
185     mutex->unlock();
186     return nullptr;
187 }
188
189 void MemoryStorageService::reap(const char* context)
190 {
191     Context& ctx = writeContext(context);
192     SharedLock locker(m_lock, false);
193     ctx.reap(time(nullptr));
194 }
195
196 unsigned long MemoryStorageService::Context::reap(time_t exp)
197 {
198     // Garbage collect any expired entries.
199     unsigned long count=0;
200     map<string,Record>::iterator cur = m_dataMap.begin();
201     map<string,Record>::iterator stop = m_dataMap.end();
202     while (cur != stop) {
203         if (cur->second.expiration <= exp) {
204             map<string,Record>::iterator tmp = cur++;
205             m_dataMap.erase(tmp);
206             ++count;
207         }
208         else {
209             cur++;
210         }
211     }
212     return count;
213 }
214
215 bool MemoryStorageService::createString(const char* context, const char* key, const char* value, time_t expiration)
216 {
217     Context& ctx = writeContext(context);
218     SharedLock locker(m_lock, false);
219
220     // Check for a duplicate.
221     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
222     if (i!=ctx.m_dataMap.end()) {
223         // Not yet expired?
224         if (time(nullptr) < i->second.expiration)
225             return false;
226         // It's dead, so we can just remove it now and create the new record.
227         ctx.m_dataMap.erase(i);
228     }
229
230     ctx.m_dataMap[key]=Record(value,expiration);
231
232     m_log.debug("inserted record (%s) in context (%s) with expiration (%lu)", key, context, expiration);
233     return true;
234 }
235
236 int MemoryStorageService::readString(const char* context, const char* key, string* pvalue, time_t* pexpiration, int version)
237 {
238     Context& ctx = readContext(context);
239     SharedLock locker(m_lock, false);
240
241     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
242     if (i==ctx.m_dataMap.end())
243         return 0;
244     else if (time(nullptr) >= i->second.expiration)
245         return 0;
246     if (pexpiration)
247         *pexpiration = i->second.expiration;
248     if (i->second.version == version)
249         return version; // nothing's changed, so just echo back the version
250     if (pvalue)
251         *pvalue = i->second.data;
252     return i->second.version;
253 }
254
255 int MemoryStorageService::updateString(const char* context, const char* key, const char* value, time_t expiration, int version)
256 {
257     Context& ctx = writeContext(context);
258     SharedLock locker(m_lock, false);
259
260     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
261     if (i==ctx.m_dataMap.end())
262         return 0;
263     else if (time(nullptr) >= i->second.expiration)
264         return 0;
265
266     if (version > 0 && version != i->second.version)
267         return -1;  // caller's out of sync
268
269     if (value) {
270         i->second.data = value;
271         ++(i->second.version);
272     }
273
274     if (expiration && expiration != i->second.expiration)
275         i->second.expiration = expiration;
276
277     m_log.debug("updated record (%s) in context (%s) with expiration (%lu)", key, context, i->second.expiration);
278     return i->second.version;
279 }
280
281 bool MemoryStorageService::deleteString(const char* context, const char* key)
282 {
283     Context& ctx = writeContext(context);
284     SharedLock locker(m_lock, false);
285
286     // Find the record.
287     map<string,Record>::iterator i=ctx.m_dataMap.find(key);
288     if (i!=ctx.m_dataMap.end()) {
289         ctx.m_dataMap.erase(i);
290         m_log.debug("deleted record (%s) in context (%s)", key, context);
291         return true;
292     }
293
294     m_log.debug("deleting record (%s) in context (%s)....not found", key, context);
295     return false;
296 }
297
298 void MemoryStorageService::updateContext(const char* context, time_t expiration)
299 {
300     Context& ctx = writeContext(context);
301     SharedLock locker(m_lock, false);
302
303     time_t now = time(nullptr);
304     map<string,Record>::iterator stop=ctx.m_dataMap.end();
305     for (map<string,Record>::iterator i = ctx.m_dataMap.begin(); i!=stop; ++i) {
306         if (now < i->second.expiration)
307             i->second.expiration = expiration;
308     }
309
310     m_log.debug("updated expiration of valid records in context (%s) to (%lu)", context, expiration);
311 }