SAE: Fix PMKID calculation for PMKSA cache
[mech_eap.git] / src / rsn_supp / pmksa_cache.c
index b5a87fc..3d8d122 100644 (file)
@@ -1,6 +1,6 @@
 /*
  * WPA Supplicant - RSN PMKSA cache
- * Copyright (c) 2004-2009, 2011-2012, Jouni Malinen <j@w1.fi>
+ * Copyright (c) 2004-2009, 2011-2015, Jouni Malinen <j@w1.fi>
  *
  * This software may be distributed under the terms of the BSD license.
  * See README for more details.
@@ -15,7 +15,7 @@
 #include "wpa_i.h"
 #include "pmksa_cache.h"
 
-#ifdef IEEE8021X_EAPOL
+#if defined(IEEE8021X_EAPOL) && !defined(CONFIG_NO_WPA)
 
 static const int pmksa_cache_max_entries = 32;
 
@@ -35,7 +35,7 @@ static void pmksa_cache_set_expiration(struct rsn_pmksa_cache *pmksa);
 
 static void _pmksa_cache_free_entry(struct rsn_pmksa_cache_entry *entry)
 {
-       os_free(entry);
+       bin_clear_free(entry, sizeof(*entry));
 }
 
 
@@ -109,6 +109,9 @@ static void pmksa_cache_set_expiration(struct rsn_pmksa_cache *pmksa)
  * @pmksa: Pointer to PMKSA cache data from pmksa_cache_init()
  * @pmk: The new pairwise master key
  * @pmk_len: PMK length in bytes, usually PMK_LEN (32)
+ * @pmkid: Calculated PMKID
+ * @kck: Key confirmation key or %NULL if not yet derived
+ * @kck_len: KCK length in bytes
  * @aa: Authenticator address
  * @spa: Supplicant address
  * @network_ctx: Network configuration context for this PMK
@@ -122,12 +125,16 @@ static void pmksa_cache_set_expiration(struct rsn_pmksa_cache *pmksa)
  */
 struct rsn_pmksa_cache_entry *
 pmksa_cache_add(struct rsn_pmksa_cache *pmksa, const u8 *pmk, size_t pmk_len,
+               const u8 *pmkid, const u8 *kck, size_t kck_len,
                const u8 *aa, const u8 *spa, void *network_ctx, int akmp)
 {
        struct rsn_pmksa_cache_entry *entry, *pos, *prev;
        struct os_reltime now;
 
-       if (pmk_len > PMK_LEN)
+       if (pmk_len > PMK_LEN_MAX)
+               return NULL;
+
+       if (wpa_key_mgmt_suite_b(akmp) && !kck)
                return NULL;
 
        entry = os_zalloc(sizeof(*entry));
@@ -135,8 +142,15 @@ pmksa_cache_add(struct rsn_pmksa_cache *pmksa, const u8 *pmk, size_t pmk_len,
                return NULL;
        os_memcpy(entry->pmk, pmk, pmk_len);
        entry->pmk_len = pmk_len;
-       rsn_pmkid(pmk, pmk_len, aa, spa, entry->pmkid,
-                 wpa_key_mgmt_sha256(akmp));
+       if (pmkid)
+               os_memcpy(entry->pmkid, pmkid, PMKID_LEN);
+       else if (akmp == WPA_KEY_MGMT_IEEE8021X_SUITE_B_192)
+               rsn_pmkid_suite_b_192(kck, kck_len, aa, spa, entry->pmkid);
+       else if (wpa_key_mgmt_suite_b(akmp))
+               rsn_pmkid_suite_b(kck, kck_len, aa, spa, entry->pmkid);
+       else
+               rsn_pmkid(pmk, pmk_len, aa, spa, entry->pmkid,
+                         wpa_key_mgmt_sha256(akmp));
        os_get_reltime(&now);
        entry->expiration = now.sec + pmksa->sm->dot11RSNAConfigPMKLifetime;
        entry->reauth_time = now.sec + pmksa->sm->dot11RSNAConfigPMKLifetime *
@@ -333,6 +347,7 @@ pmksa_cache_clone_entry(struct rsn_pmksa_cache *pmksa,
        struct rsn_pmksa_cache_entry *new_entry;
 
        new_entry = pmksa_cache_add(pmksa, old_entry->pmk, old_entry->pmk_len,
+                                   NULL, NULL, 0,
                                    aa, pmksa->sm->own_addr,
                                    old_entry->network_ctx, old_entry->akmp);
        if (new_entry == NULL)
@@ -472,7 +487,7 @@ int pmksa_cache_list(struct rsn_pmksa_cache *pmksa, char *buf, size_t len)
        ret = os_snprintf(pos, buf + len - pos,
                          "Index / AA / PMKID / expiration (in seconds) / "
                          "opportunistic\n");
-       if (ret < 0 || ret >= buf + len - pos)
+       if (os_snprintf_error(buf + len - pos, ret))
                return pos - buf;
        pos += ret;
        i = 0;
@@ -481,7 +496,7 @@ int pmksa_cache_list(struct rsn_pmksa_cache *pmksa, char *buf, size_t len)
                i++;
                ret = os_snprintf(pos, buf + len - pos, "%d " MACSTR " ",
                                  i, MAC2STR(entry->aa));
-               if (ret < 0 || ret >= buf + len - pos)
+               if (os_snprintf_error(buf + len - pos, ret))
                        return pos - buf;
                pos += ret;
                pos += wpa_snprintf_hex(pos, buf + len - pos, entry->pmkid,
@@ -489,7 +504,7 @@ int pmksa_cache_list(struct rsn_pmksa_cache *pmksa, char *buf, size_t len)
                ret = os_snprintf(pos, buf + len - pos, " %d %d\n",
                                  (int) (entry->expiration - now.sec),
                                  entry->opportunistic);
-               if (ret < 0 || ret >= buf + len - pos)
+               if (os_snprintf_error(buf + len - pos, ret))
                        return pos - buf;
                pos += ret;
                entry = entry->next;