remove unnecessary assert
[freeradius.git] / src / main / state.c
index 4e288d8..b195a7c 100644 (file)
@@ -30,23 +30,6 @@ RCSID("$Id$")
 #include <freeradius-devel/state.h>
 #include <freeradius-devel/rad_assert.h>
 
-static rbtree_t *state_tree;
-
-#ifdef HAVE_PTHREAD_H
-static pthread_mutex_t state_mutex;
-
-#define PTHREAD_MUTEX_LOCK pthread_mutex_lock
-#define PTHREAD_MUTEX_UNLOCK pthread_mutex_unlock
-
-#else
-/*
- *     This is easier than ifdef's throughout the code.
- */
-#define PTHREAD_MUTEX_LOCK(_x)
-#define PTHREAD_MUTEX_UNLOCK(_x)
-
-#endif
-
 typedef struct state_entry_t {
        uint8_t         state[AUTH_VECTOR_LEN];
 
@@ -56,14 +39,38 @@ typedef struct state_entry_t {
 
        int             tries;
 
-       VALUE_PAIR      *vps;
+       TALLOC_CTX              *ctx;
+       VALUE_PAIR              *vps;
 
        void            *opaque;
        void            (*free_opaque)(void *opaque);
 } state_entry_t;
 
-static state_entry_t *state_head = NULL;
-static state_entry_t *state_tail = NULL;
+struct fr_state_t {
+       rbtree_t *tree;
+
+       state_entry_t *head, *tail;
+
+#ifdef HAVE_PTHREAD_H
+       pthread_mutex_t mutex;
+#endif
+};
+
+static fr_state_t global_state;
+
+#ifdef HAVE_PTHREAD_H
+
+#define PTHREAD_MUTEX_LOCK pthread_mutex_lock
+#define PTHREAD_MUTEX_UNLOCK pthread_mutex_unlock
+
+#else
+/*
+ *     This is easier than ifdef's throughout the code.
+ */
+#define PTHREAD_MUTEX_LOCK(_x)
+#define PTHREAD_MUTEX_UNLOCK(_x)
+
+#endif
 
 /*
  *     rbtree callback.
@@ -82,7 +89,7 @@ static int state_entry_cmp(void const *one, void const *two)
  *
  *     Note that
  */
-static void state_entry_free(state_entry_t *entry)
+static void state_entry_free(fr_state_t *state, state_entry_t *entry)
 {
        state_entry_t *prev, *next;
 
@@ -90,25 +97,25 @@ static void state_entry_free(state_entry_t *entry)
         *      If we're deleting the whole tree, don't bother doing
         *      all of the fixups.
         */
-       if (!state_tree) return;
+       if (!state || !state->tree) return;
 
        prev = entry->prev;
        next = entry->next;
 
        if (prev) {
-               rad_assert(state_head != entry);
+               rad_assert(state->head != entry);
                prev->next = next;
-       } else if (state_head) {
-               rad_assert(state_head == entry);
-               state_head = next;
+       } else if (state->head) {
+               rad_assert(state->head == entry);
+               state->head = next;
        }
 
        if (next) {
-               rad_assert(state_tail != entry);
+               rad_assert(state->tail != entry);
                next->prev = prev;
-       } else if (state_tail) {
-               rad_assert(state_tail == entry);
-               state_tail = prev;
+       } else if (state->tail) {
+               rad_assert(state->tail == entry);
+               state->tail = prev;
        }
 
        if (entry->opaque) {
@@ -118,47 +125,66 @@ static void state_entry_free(state_entry_t *entry)
 #ifdef WITH_VERIFY_PTR
        (void) talloc_get_type_abort(entry, state_entry_t);
 #endif
-       rbtree_deletebydata(state_tree, entry);
+       rbtree_deletebydata(state->tree, entry);
+
+       if (entry->ctx) talloc_free(entry->ctx);
+
        talloc_free(entry);
 }
 
-bool fr_state_init(void)
+fr_state_t *fr_state_init(TALLOC_CTX *ctx)
 {
+       fr_state_t *state;
+
+       if (!ctx) {
+               state = &global_state;
+               if (state->tree) return state;
+       } else {
+               state = talloc_zero(ctx, fr_state_t);
+               if (!state) return 0;
+       }
+
 #ifdef HAVE_PTHREAD_H
-       if (pthread_mutex_init(&state_mutex, NULL) != 0) {
-               return false;
+       if (pthread_mutex_init(&state->mutex, NULL) != 0) {
+               talloc_free(state);
+               return NULL;
        }
 #endif
 
-       state_tree = rbtree_create(NULL, state_entry_cmp, NULL, 0);
-       if (!state_tree) {
-               return false;
+       state->tree = rbtree_create(NULL, state_entry_cmp, NULL, 0);
+       if (!state->tree) {
+               talloc_free(state);
+               return NULL;
        }
 
-       return true;
+       return state;
 }
 
-void fr_state_delete(void)
+void fr_state_delete(fr_state_t *state)
 {
        rbtree_t *my_tree;
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
+       if (!state) return;
+
+       PTHREAD_MUTEX_LOCK(&state->mutex);
 
        /*
         *      Tell the talloc callback to NOT delete the entry from
         *      the tree.  We're deleting the entire tree.
         */
-       my_tree = state_tree;
-       state_tree = NULL;
+       my_tree = state->tree;
+       state->tree = NULL;
 
        rbtree_free(my_tree);
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
+
+       if (state != &global_state) talloc_free(state);
 }
 
 /*
  *     Create a new entry.  Called with the mutex held.
  */
-static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
+static state_entry_t *fr_state_create(fr_state_t *state, const char *server, RADIUS_PACKET *packet, state_entry_t *old)
 {
        size_t i;
        uint32_t x;
@@ -169,7 +195,7 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
        /*
         *      Clean up old entries.
         */
-       for (entry = state_head; entry != NULL; entry = next) {
+       for (entry = state->head; entry != NULL; entry = next) {
                next = entry->next;
 
                if (entry == old) continue;
@@ -178,7 +204,7 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
                 *      Too old, we can delete it.
                 */
                if (entry->cleanup < now) {
-                       state_entry_free(entry);
+                       state_entry_free(state, entry);
                        continue;
                }
 
@@ -186,8 +212,8 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
                 *      Unused.  We can delete it, even if now isn't
                 *      the time to clean it up.
                 */
-               if (!entry->vps && !entry->opaque) {
-                       state_entry_free(entry);
+               if (!entry->ctx && !entry->opaque) {
+                       state_entry_free(state, entry);
                        continue;
                }
 
@@ -198,14 +224,14 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
         *      Limit the size of the cache based on how many requests
         *      we can handle at the same time.
         */
-       if (rbtree_num_elements(state_tree) >= main_config.max_requests * 2) {
+       if (rbtree_num_elements(state->tree) >= main_config.max_requests * 2) {
                return NULL;
        }
 
        /*
         *      Allocate a new one.
         */
-       entry = talloc_zero(state_tree, state_entry_t);
+       entry = talloc_zero(state->tree, state_entry_t);
        if (!entry) return NULL;
 
        /*
@@ -222,7 +248,7 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
         *      The EAP module creates it's own State attribute, so we
         *      want to use that one in preference to one we create.
         */
-       vp = pairfind(packet->vps, PW_STATE, 0, TAG_ANY);
+       vp = fr_pair_find_by_num(packet->vps, PW_STATE, 0, TAG_ANY);
 
        /*
         *      If possible, base the new one off of the old one.
@@ -230,8 +256,6 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
        if (old) {
                entry->tries = old->tries + 1;
 
-               rad_assert(old->vps == NULL);
-
                /*
                 *      Track State
                 */
@@ -247,7 +271,7 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
                /*
                 *      The old one isn't used any more, so we can free it.
                 */
-               if (!old->opaque) state_entry_free(old);
+               if (!old->opaque) state_entry_free(state, old);
 
        } else if (!vp) {
                /*
@@ -265,19 +289,23 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
         *      one we created above.
         */
        if (vp) {
-               if (debug_flag && (vp->vp_length > sizeof(entry->state))) {
+               if (rad_debug_lvl && (vp->vp_length > sizeof(entry->state))) {
                        WARN("State should be %zd octets!",
                             sizeof(entry->state));
                }
                memcpy(entry->state, vp->vp_octets, sizeof(entry->state));
 
        } else {
-               vp = paircreate(packet, PW_STATE, 0);
-               pairmemcpy(vp, entry->state, sizeof(entry->state));
-               pairadd(&packet->vps, vp);
+               vp = fr_pair_afrom_num(packet, PW_STATE, 0);
+               fr_pair_value_memcpy(vp, entry->state, sizeof(entry->state));
+               fr_pair_add(&packet->vps, vp);
        }
 
-       if (!rbtree_insert(state_tree, entry)) {
+       /*      Make unique for different virtual servers handling same request
+        */
+       if (server) *((uint32_t *)(&entry->state[4])) ^= fr_hash_string(server);
+
+       if (!rbtree_insert(state->tree, entry)) {
                talloc_free(entry);
                return NULL;
        }
@@ -286,17 +314,17 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
         *      Link it to the end of the list, which is implicitely
         *      ordered by cleanup time.
         */
-       if (!state_head) {
+       if (!state->head) {
                entry->prev = entry->next = NULL;
-               state_head = state_tail = entry;
+               state->head = state->tail = entry;
        } else {
-               rad_assert(state_tail != NULL);
+               rad_assert(state->tail != NULL);
 
-               entry->prev = state_tail;
-               state_tail->next = entry;
+               entry->prev = state->tail;
+               state->tail->next = entry;
 
                entry->next = NULL;
-               state_tail = entry;
+               state->tail = entry;
        }
 
        return entry;
@@ -306,19 +334,23 @@ static state_entry_t *fr_state_create(RADIUS_PACKET *packet, state_entry_t *old)
 /*
  *     Find the entry, based on the State attribute.
  */
-static state_entry_t *fr_state_find(RADIUS_PACKET *packet)
+static state_entry_t *fr_state_find(fr_state_t *state, const char *server, RADIUS_PACKET *packet)
 {
        VALUE_PAIR *vp;
        state_entry_t *entry, my_entry;
 
-       vp = pairfind(packet->vps, PW_STATE, 0, TAG_ANY);
+       vp = fr_pair_find_by_num(packet->vps, PW_STATE, 0, TAG_ANY);
        if (!vp) return NULL;
 
        if (vp->vp_length != sizeof(my_entry.state)) return NULL;
 
        memcpy(my_entry.state, vp->vp_octets, sizeof(my_entry.state));
 
-       entry = rbtree_finddata(state_tree, &my_entry);
+       /*      Make unique for different virtual servers handling same request
+        */
+       if (server) *((uint32_t *)(&my_entry.state[4])) ^= fr_hash_string(server);
+
+       entry = rbtree_finddata(state->tree, &my_entry);
 
 #ifdef WITH_VERIFY_PTR
        if (entry)  (void) talloc_get_type_abort(entry, state_entry_t);
@@ -334,19 +366,20 @@ static state_entry_t *fr_state_find(RADIUS_PACKET *packet)
 void fr_state_discard(REQUEST *request, RADIUS_PACKET *original)
 {
        state_entry_t *entry;
+       fr_state_t *state = &global_state;
 
-       pairfree(&request->state);
+       fr_pair_list_free(&request->state);
        request->state = NULL;
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
-       entry = fr_state_find(original);
+       PTHREAD_MUTEX_LOCK(&state->mutex);
+       entry = fr_state_find(state, request->server, original);
        if (!entry) {
-               PTHREAD_MUTEX_UNLOCK(&state_mutex);
+               PTHREAD_MUTEX_UNLOCK(&state->mutex);
                return;
        }
 
-       state_entry_free(entry);
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       state_entry_free(state, entry);
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
        return;
 }
 
@@ -356,34 +389,49 @@ void fr_state_discard(REQUEST *request, RADIUS_PACKET *original)
 void fr_state_get_vps(REQUEST *request, RADIUS_PACKET *packet)
 {
        state_entry_t *entry;
+       fr_state_t *state = &global_state;
+       TALLOC_CTX *old_ctx = NULL;
 
        rad_assert(request->state == NULL);
 
        /*
         *      No State, don't do anything.
         */
-       if (!pairfind(request->packet->vps, PW_STATE, 0, TAG_ANY)) {
+       if (!fr_pair_find_by_num(request->packet->vps, PW_STATE, 0, TAG_ANY)) {
                RDEBUG3("session-state: No State attribute");
                return;
        }
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
-       entry = fr_state_find(packet);
+       PTHREAD_MUTEX_LOCK(&state->mutex);
+       entry = fr_state_find(state, request->server, packet);
 
        /*
         *      This has to be done in a mutex lock, because talloc
         *      isn't thread-safe.
         */
        if (entry) {
-               pairfilter(request, &request->state, &entry->vps, 0, 0, TAG_ANY);
-               RDEBUG2("session-state: Found cached attributes");
-               rdebug_pair_list(L_DBG_LVL_1, request, request->state, NULL);
+               RDEBUG2("Restoring &session-state");
+
+               if (request->state_ctx) old_ctx = request->state_ctx;
+
+               request->state_ctx = entry->ctx;
+               request->state = entry->vps;
+
+               entry->ctx = NULL;
+               entry->vps = NULL;
+
+               rdebug_pair_list(L_DBG_LVL_2, request, request->state, "&session-state:");
 
        } else {
                RDEBUG2("session-state: No cached attributes");
        }
 
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
+
+       /*
+        *      Free this outside of the mutex for less contention.
+        */
+       if (old_ctx) talloc_free(old_ctx);
 
        VERIFY_REQUEST(request);
        return;
@@ -398,6 +446,7 @@ void fr_state_get_vps(REQUEST *request, RADIUS_PACKET *packet)
 bool fr_state_put_vps(REQUEST *request, RADIUS_PACKET *original, RADIUS_PACKET *packet)
 {
        state_entry_t *entry, *old;
+       fr_state_t *state = &global_state;
 
        if (!request->state) {
                RDEBUG3("session-state: Nothing to cache");
@@ -407,28 +456,29 @@ bool fr_state_put_vps(REQUEST *request, RADIUS_PACKET *original, RADIUS_PACKET *
        RDEBUG2("session-state: Saving cached attributes");
        rdebug_pair_list(L_DBG_LVL_1, request, request->state, NULL);
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
+       PTHREAD_MUTEX_LOCK(&state->mutex);
 
        if (original) {
-               old = fr_state_find(original);
+               old = fr_state_find(state, request->server, original);
        } else {
                old = NULL;
        }
 
-       entry = fr_state_create(packet, old);
+       entry = fr_state_create(state, request->server, packet, old);
        if (!entry) {
-               PTHREAD_MUTEX_UNLOCK(&state_mutex);
+               PTHREAD_MUTEX_UNLOCK(&state->mutex);
                return false;
        }
 
-       /*
-        *      This has to be done in a mutex lock, because talloc
-        *      isn't thread-safe.
-        */
-       pairfilter(entry, &entry->vps, &request->state, 0, 0, TAG_ANY);
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       rad_assert(entry->ctx == NULL);
+       entry->ctx = request->state_ctx;
+       entry->vps = request->state;
+
+       request->state_ctx = NULL;
+       request->state = NULL;
+
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
 
-       rad_assert(request->state == NULL);
        VERIFY_REQUEST(request);
        return true;
 }
@@ -437,20 +487,22 @@ bool fr_state_put_vps(REQUEST *request, RADIUS_PACKET *original, RADIUS_PACKET *
  *     Find the opaque data associated with a State attribute.
  *     Leave the data in the entry.
  */
-void *fr_state_find_data(UNUSED REQUEST *request, RADIUS_PACKET *packet)
+void *fr_state_find_data(fr_state_t *state, REQUEST *request, RADIUS_PACKET *packet)
 {
        void *data;
        state_entry_t *entry;
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
-       entry = fr_state_find(packet);
+       if (!state) return false;
+
+       PTHREAD_MUTEX_LOCK(&state->mutex);
+       entry = fr_state_find(state, request->server, packet);
        if (!entry) {
-               PTHREAD_MUTEX_UNLOCK(&state_mutex);
+               PTHREAD_MUTEX_UNLOCK(&state->mutex);
                return NULL;
        }
 
        data = entry->opaque;
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
 
        return data;
 }
@@ -460,21 +512,23 @@ void *fr_state_find_data(UNUSED REQUEST *request, RADIUS_PACKET *packet)
  *     Get the opaque data associated with a State attribute.
  *     and remove the data from the entry.
  */
-void *fr_state_get_data(UNUSED REQUEST *request, RADIUS_PACKET *packet)
+void *fr_state_get_data(fr_state_t *state, REQUEST *request, RADIUS_PACKET *packet)
 {
        void *data;
        state_entry_t *entry;
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
-       entry = fr_state_find(packet);
+       if (!state) return NULL;
+
+       PTHREAD_MUTEX_LOCK(&state->mutex);
+       entry = fr_state_find(state, request->server, packet);
        if (!entry) {
-               PTHREAD_MUTEX_UNLOCK(&state_mutex);
+               PTHREAD_MUTEX_UNLOCK(&state->mutex);
                return NULL;
        }
 
        data = entry->opaque;
        entry->opaque = NULL;
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
 
        return data;
 }
@@ -484,22 +538,24 @@ void *fr_state_get_data(UNUSED REQUEST *request, RADIUS_PACKET *packet)
  *     Get the opaque data associated with a State attribute.
  *     and remove the data from the entry.
  */
-bool fr_state_put_data(UNUSED REQUEST *request, RADIUS_PACKET *original, RADIUS_PACKET *packet,
+bool fr_state_put_data(fr_state_t *state, REQUEST *request, RADIUS_PACKET *original, RADIUS_PACKET *packet,
                       void *data, void (*free_data)(void *))
 {
        state_entry_t *entry, *old;
 
-       PTHREAD_MUTEX_LOCK(&state_mutex);
+       if (!state) return false;
+
+       PTHREAD_MUTEX_LOCK(&state->mutex);
 
        if (original) {
-               old = fr_state_find(original);
+               old = fr_state_find(state, request->server, original);
        } else {
                old = NULL;
        }
 
-       entry = fr_state_create(packet, old);
+       entry = fr_state_create(state, request->server, packet, old);
        if (!entry) {
-               PTHREAD_MUTEX_UNLOCK(&state_mutex);
+               PTHREAD_MUTEX_UNLOCK(&state->mutex);
                return false;
        }
 
@@ -514,6 +570,6 @@ bool fr_state_put_data(UNUSED REQUEST *request, RADIUS_PACKET *original, RADIUS_
        entry->opaque = data;
        entry->free_opaque = free_data;
 
-       PTHREAD_MUTEX_UNLOCK(&state_mutex);
+       PTHREAD_MUTEX_UNLOCK(&state->mutex);
        return true;
 }