Added support for opportunistic key caching (OKC)
[libeap.git] / hostapd / wpa_auth_ie.c
index 1d2caa3..7af7531 100644 (file)
@@ -414,6 +414,25 @@ static int wpa_parse_wpa_ie_wpa(const u8 *wpa_ie, size_t wpa_ie_len,
 }
 
 
+struct wpa_auth_okc_iter_data {
+       struct rsn_pmksa_cache_entry *pmksa;
+       const u8 *aa;
+       const u8 *spa;
+       const u8 *pmkid;
+};
+
+
+static int wpa_auth_okc_iter(struct wpa_authenticator *a, void *ctx)
+{
+       struct wpa_auth_okc_iter_data *data = ctx;
+       data->pmksa = pmksa_cache_get_okc(a->pmksa, data->aa, data->spa,
+                                         data->pmkid);
+       if (data->pmksa)
+               return 1;
+       return 0;
+}
+
+
 int wpa_validate_wpa_ie(struct wpa_authenticator *wpa_auth,
                        struct wpa_state_machine *sm,
                        const u8 *wpa_ie, size_t wpa_ie_len,
@@ -423,6 +442,7 @@ int wpa_validate_wpa_ie(struct wpa_authenticator *wpa_auth,
        int ciphers, key_mgmt, res, version;
        u32 selector;
        size_t i;
+       const u8 *pmkid = NULL;
 
        if (wpa_auth == NULL || sm == NULL)
                return WPA_NOT_ENABLED;
@@ -615,22 +635,44 @@ int wpa_validate_wpa_ie(struct wpa_authenticator *wpa_auth,
        else
                sm->wpa = WPA_VERSION_WPA;
 
+       sm->pmksa = NULL;
        for (i = 0; i < data.num_pmkid; i++) {
                wpa_hexdump(MSG_DEBUG, "RSN IE: STA PMKID",
                            &data.pmkid[i * PMKID_LEN], PMKID_LEN);
                sm->pmksa = pmksa_cache_get(wpa_auth->pmksa, sm->addr,
                                            &data.pmkid[i * PMKID_LEN]);
                if (sm->pmksa) {
+                       pmkid = sm->pmksa->pmkid;
+                       break;
+               }
+       }
+       for (i = 0; sm->pmksa == NULL && wpa_auth->conf.okc &&
+                    i < data.num_pmkid; i++) {
+               struct wpa_auth_okc_iter_data idata;
+               idata.pmksa = NULL;
+               idata.aa = wpa_auth->addr;
+               idata.spa = sm->addr;
+               idata.pmkid = &data.pmkid[i * PMKID_LEN];
+               wpa_auth_for_each_auth(wpa_auth, wpa_auth_okc_iter, &idata);
+               if (idata.pmksa) {
                        wpa_auth_vlogger(wpa_auth, sm->addr, LOGGER_DEBUG,
-                                        "PMKID found from PMKSA cache "
-                                        "eap_type=%d vlan_id=%d",
-                                        sm->pmksa->eap_type_authsrv,
-                                        sm->pmksa->vlan_id);
-                       os_memcpy(wpa_auth->dot11RSNAPMKIDUsed,
-                                 sm->pmksa->pmkid, PMKID_LEN);
+                                        "OKC match for PMKID");
+                       sm->pmksa = pmksa_cache_add_okc(wpa_auth->pmksa,
+                                                       idata.pmksa,
+                                                       wpa_auth->addr,
+                                                       idata.pmkid);
+                       pmkid = idata.pmkid;
                        break;
                }
        }
+       if (sm->pmksa) {
+               wpa_auth_vlogger(wpa_auth, sm->addr, LOGGER_DEBUG,
+                                "PMKID found from PMKSA cache "
+                                "eap_type=%d vlan_id=%d",
+                                sm->pmksa->eap_type_authsrv,
+                                sm->pmksa->vlan_id);
+               os_memcpy(wpa_auth->dot11RSNAPMKIDUsed, pmkid, PMKID_LEN);
+       }
 
        if (sm->wpa_ie == NULL || sm->wpa_ie_len < wpa_ie_len) {
                os_free(sm->wpa_ie);