OpenSSL: Add wrapper struct for tls_init() result
authorJouni Malinen <j@w1.fi>
Sun, 23 Aug 2015 16:22:13 +0000 (19:22 +0300)
committerJouni Malinen <j@w1.fi>
Sun, 23 Aug 2015 23:29:25 +0000 (02:29 +0300)
This new struct tls_data is needed to store per-tls_init() information
in the followup commits.

Signed-off-by: Jouni Malinen <j@w1.fi>
src/crypto/tls_openssl.c

index de1b2c7..8b84171 100644 (file)
@@ -86,6 +86,10 @@ struct tls_context {
 static struct tls_context *tls_global = NULL;
 
 
+struct tls_data {
+       SSL_CTX *ssl;
+};
+
 struct tls_connection {
        struct tls_context *context;
        SSL_CTX *ssl_ctx;
@@ -746,6 +750,7 @@ static int tls_engine_load_dynamic_opensc(const char *opensc_so_path)
 
 void * tls_init(const struct tls_config *conf)
 {
+       struct tls_data *data;
        SSL_CTX *ssl;
        struct tls_context *context;
        const char *ciphers;
@@ -810,7 +815,11 @@ void * tls_init(const struct tls_config *conf)
        }
        tls_openssl_ref_count++;
 
-       ssl = SSL_CTX_new(SSLv23_method());
+       data = os_zalloc(sizeof(*data));
+       if (data)
+               ssl = SSL_CTX_new(SSLv23_method());
+       else
+               ssl = NULL;
        if (ssl == NULL) {
                tls_openssl_ref_count--;
                if (context != tls_global)
@@ -821,6 +830,7 @@ void * tls_init(const struct tls_config *conf)
                }
                return NULL;
        }
+       data->ssl = ssl;
 
        SSL_CTX_set_options(ssl, SSL_OP_NO_SSLv2);
        SSL_CTX_set_options(ssl, SSL_OP_NO_SSLv3);
@@ -839,7 +849,7 @@ void * tls_init(const struct tls_config *conf)
                if (tls_engine_load_dynamic_opensc(conf->opensc_engine_path) ||
                    tls_engine_load_dynamic_pkcs11(conf->pkcs11_engine_path,
                                                   conf->pkcs11_module_path)) {
-                       tls_deinit(ssl);
+                       tls_deinit(data);
                        return NULL;
                }
        }
@@ -853,17 +863,18 @@ void * tls_init(const struct tls_config *conf)
                wpa_printf(MSG_ERROR,
                           "OpenSSL: Failed to set cipher string '%s'",
                           ciphers);
-               tls_deinit(ssl);
+               tls_deinit(data);
                return NULL;
        }
 
-       return ssl;
+       return data;
 }
 
 
 void tls_deinit(void *ssl_ctx)
 {
-       SSL_CTX *ssl = ssl_ctx;
+       struct tls_data *data = ssl_ctx;
+       SSL_CTX *ssl = data->ssl;
        struct tls_context *context = SSL_CTX_get_app_data(ssl);
        if (context != tls_global)
                os_free(context);
@@ -883,6 +894,8 @@ void tls_deinit(void *ssl_ctx)
                os_free(tls_global);
                tls_global = NULL;
        }
+
+       os_free(data);
 }
 
 
@@ -1058,7 +1071,8 @@ static void tls_msg_cb(int write_p, int version, int content_type,
 
 struct tls_connection * tls_connection_init(void *ssl_ctx)
 {
-       SSL_CTX *ssl = ssl_ctx;
+       struct tls_data *data = ssl_ctx;
+       SSL_CTX *ssl = data->ssl;
        struct tls_connection *conn;
        long options;
        struct tls_context *context = SSL_CTX_get_app_data(ssl);
@@ -1066,7 +1080,7 @@ struct tls_connection * tls_connection_init(void *ssl_ctx)
        conn = os_zalloc(sizeof(*conn));
        if (conn == NULL)
                return NULL;
-       conn->ssl_ctx = ssl_ctx;
+       conn->ssl_ctx = ssl;
        conn->ssl = SSL_new(ssl);
        if (conn->ssl == NULL) {
                tls_show_errors(MSG_INFO, __func__,
@@ -1641,9 +1655,9 @@ static int tls_verify_cb(int preverify_ok, X509_STORE_CTX *x509_ctx)
 
 
 #ifndef OPENSSL_NO_STDIO
-static int tls_load_ca_der(void *_ssl_ctx, const char *ca_cert)
+static int tls_load_ca_der(struct tls_data *data, const char *ca_cert)
 {
-       SSL_CTX *ssl_ctx = _ssl_ctx;
+       SSL_CTX *ssl_ctx = data->ssl;
        X509_LOOKUP *lookup;
        int ret = 0;
 
@@ -1673,11 +1687,12 @@ static int tls_load_ca_der(void *_ssl_ctx, const char *ca_cert)
 #endif /* OPENSSL_NO_STDIO */
 
 
-static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
+static int tls_connection_ca_cert(struct tls_data *data,
+                                 struct tls_connection *conn,
                                  const char *ca_cert, const u8 *ca_cert_blob,
                                  size_t ca_cert_blob_len, const char *ca_path)
 {
-       SSL_CTX *ssl_ctx = _ssl_ctx;
+       SSL_CTX *ssl_ctx = data->ssl;
        X509_STORE *store;
 
        /*
@@ -1812,7 +1827,7 @@ static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
                        tls_show_errors(MSG_WARNING, __func__,
                                        "Failed to load root certificates");
                        if (ca_cert &&
-                           tls_load_ca_der(ssl_ctx, ca_cert) == 0) {
+                           tls_load_ca_der(data, ca_cert) == 0) {
                                wpa_printf(MSG_DEBUG, "OpenSSL: %s - loaded "
                                           "DER format CA certificate",
                                           __func__);
@@ -1821,7 +1836,7 @@ static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
                } else {
                        wpa_printf(MSG_DEBUG, "TLS: Trusted root "
                                   "certificate(s) loaded");
-                       tls_get_errors(ssl_ctx);
+                       tls_get_errors(data);
                }
 #else /* OPENSSL_NO_STDIO */
                wpa_printf(MSG_DEBUG, "OpenSSL: %s - OPENSSL_NO_STDIO",
@@ -1838,8 +1853,10 @@ static int tls_connection_ca_cert(void *_ssl_ctx, struct tls_connection *conn,
 }
 
 
-static int tls_global_ca_cert(SSL_CTX *ssl_ctx, const char *ca_cert)
+static int tls_global_ca_cert(struct tls_data *data, const char *ca_cert)
 {
+       SSL_CTX *ssl_ctx = data->ssl;
+
        if (ca_cert) {
                if (SSL_CTX_load_verify_locations(ssl_ctx, ca_cert, NULL) != 1)
                {
@@ -1867,7 +1884,8 @@ int tls_global_set_verify(void *ssl_ctx, int check_crl)
        int flags;
 
        if (check_crl) {
-               X509_STORE *cs = SSL_CTX_get_cert_store(ssl_ctx);
+               struct tls_data *data = ssl_ctx;
+               X509_STORE *cs = SSL_CTX_get_cert_store(data->ssl);
                if (cs == NULL) {
                        tls_show_errors(MSG_INFO, __func__, "Failed to get "
                                        "certificate store when enabling "
@@ -2028,9 +2046,12 @@ static int tls_connection_client_cert(struct tls_connection *conn,
 }
 
 
-static int tls_global_client_cert(SSL_CTX *ssl_ctx, const char *client_cert)
+static int tls_global_client_cert(struct tls_data *data,
+                                 const char *client_cert)
 {
 #ifndef OPENSSL_NO_STDIO
+       SSL_CTX *ssl_ctx = data->ssl;
+
        if (client_cert == NULL)
                return 0;
 
@@ -2064,7 +2085,7 @@ static int tls_passwd_cb(char *buf, int size, int rwflag, void *password)
 
 
 #ifdef PKCS12_FUNCS
-static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
+static int tls_parse_pkcs12(struct tls_data *data, SSL *ssl, PKCS12 *p12,
                            const char *passwd)
 {
        EVP_PKEY *pkey;
@@ -2095,7 +2116,7 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
                        if (SSL_use_certificate(ssl, cert) != 1)
                                res = -1;
                } else {
-                       if (SSL_CTX_use_certificate(ssl_ctx, cert) != 1)
+                       if (SSL_CTX_use_certificate(data->ssl, cert) != 1)
                                res = -1;
                }
                X509_free(cert);
@@ -2107,7 +2128,7 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
                        if (SSL_use_PrivateKey(ssl, pkey) != 1)
                                res = -1;
                } else {
-                       if (SSL_CTX_use_PrivateKey(ssl_ctx, pkey) != 1)
+                       if (SSL_CTX_use_PrivateKey(data->ssl, pkey) != 1)
                                res = -1;
                }
                EVP_PKEY_free(pkey);
@@ -2146,7 +2167,7 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
                res = 0;
 #else /* OPENSSL_VERSION_NUMBER >= 0x10002000L */
 #if OPENSSL_VERSION_NUMBER >= 0x10001000L
-               SSL_CTX_clear_extra_chain_certs(ssl_ctx);
+               SSL_CTX_clear_extra_chain_certs(data->ssl);
 #endif /* OPENSSL_VERSION_NUMBER >= 0x10001000L */
                while ((cert = sk_X509_pop(certs)) != NULL) {
                        X509_NAME_oneline(X509_get_subject_name(cert), buf,
@@ -2157,7 +2178,8 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
                         * There is no SSL equivalent for the chain cert - so
                         * always add it to the context...
                         */
-                       if (SSL_CTX_add_extra_chain_cert(ssl_ctx, cert) != 1) {
+                       if (SSL_CTX_add_extra_chain_cert(data->ssl, cert) != 1)
+                       {
                                res = -1;
                                break;
                        }
@@ -2169,15 +2191,15 @@ static int tls_parse_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, PKCS12 *p12,
        PKCS12_free(p12);
 
        if (res < 0)
-               tls_get_errors(ssl_ctx);
+               tls_get_errors(data);
 
        return res;
 }
 #endif  /* PKCS12_FUNCS */
 
 
-static int tls_read_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, const char *private_key,
-                          const char *passwd)
+static int tls_read_pkcs12(struct tls_data *data, SSL *ssl,
+                          const char *private_key, const char *passwd)
 {
 #ifdef PKCS12_FUNCS
        FILE *f;
@@ -2196,7 +2218,7 @@ static int tls_read_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, const char *private_key,
                return -1;
        }
 
-       return tls_parse_pkcs12(ssl_ctx, ssl, p12, passwd);
+       return tls_parse_pkcs12(data, ssl, p12, passwd);
 
 #else /* PKCS12_FUNCS */
        wpa_printf(MSG_INFO, "TLS: PKCS12 support disabled - cannot read "
@@ -2206,7 +2228,7 @@ static int tls_read_pkcs12(SSL_CTX *ssl_ctx, SSL *ssl, const char *private_key,
 }
 
 
-static int tls_read_pkcs12_blob(SSL_CTX *ssl_ctx, SSL *ssl,
+static int tls_read_pkcs12_blob(struct tls_data *data, SSL *ssl,
                                const u8 *blob, size_t len, const char *passwd)
 {
 #ifdef PKCS12_FUNCS
@@ -2219,7 +2241,7 @@ static int tls_read_pkcs12_blob(SSL_CTX *ssl_ctx, SSL *ssl,
                return -1;
        }
 
-       return tls_parse_pkcs12(ssl_ctx, ssl, p12, passwd);
+       return tls_parse_pkcs12(data, ssl, p12, passwd);
 
 #else /* PKCS12_FUNCS */
        wpa_printf(MSG_INFO, "TLS: PKCS12 support disabled - cannot parse "
@@ -2290,13 +2312,13 @@ static int tls_connection_engine_client_cert(struct tls_connection *conn,
 }
 
 
-static int tls_connection_engine_ca_cert(void *_ssl_ctx,
+static int tls_connection_engine_ca_cert(struct tls_data *data,
                                         struct tls_connection *conn,
                                         const char *ca_cert_id)
 {
 #ifndef OPENSSL_NO_ENGINE
        X509 *cert;
-       SSL_CTX *ssl_ctx = _ssl_ctx;
+       SSL_CTX *ssl_ctx = data->ssl;
        X509_STORE *store;
 
        if (tls_engine_get_cert(conn, ca_cert_id, &cert))
@@ -2362,14 +2384,14 @@ static int tls_connection_engine_private_key(struct tls_connection *conn)
 }
 
 
-static int tls_connection_private_key(void *_ssl_ctx,
+static int tls_connection_private_key(struct tls_data *data,
                                      struct tls_connection *conn,
                                      const char *private_key,
                                      const char *private_key_passwd,
                                      const u8 *private_key_blob,
                                      size_t private_key_blob_len)
 {
-       SSL_CTX *ssl_ctx = _ssl_ctx;
+       SSL_CTX *ssl_ctx = data->ssl;
        char *passwd;
        int ok;
 
@@ -2415,7 +2437,7 @@ static int tls_connection_private_key(void *_ssl_ctx,
                        break;
                }
 
-               if (tls_read_pkcs12_blob(ssl_ctx, conn->ssl, private_key_blob,
+               if (tls_read_pkcs12_blob(data, conn->ssl, private_key_blob,
                                         private_key_blob_len, passwd) == 0) {
                        wpa_printf(MSG_DEBUG, "OpenSSL: PKCS#12 as blob --> "
                                   "OK");
@@ -2448,7 +2470,7 @@ static int tls_connection_private_key(void *_ssl_ctx,
                           __func__);
 #endif /* OPENSSL_NO_STDIO */
 
-               if (tls_read_pkcs12(ssl_ctx, conn->ssl, private_key, passwd)
+               if (tls_read_pkcs12(data, conn->ssl, private_key, passwd)
                    == 0) {
                        wpa_printf(MSG_DEBUG, "OpenSSL: Reading PKCS#12 file "
                                   "--> OK");
@@ -2487,9 +2509,11 @@ static int tls_connection_private_key(void *_ssl_ctx,
 }
 
 
-static int tls_global_private_key(SSL_CTX *ssl_ctx, const char *private_key,
+static int tls_global_private_key(struct tls_data *data,
+                                 const char *private_key,
                                  const char *private_key_passwd)
 {
+       SSL_CTX *ssl_ctx = data->ssl;
        char *passwd;
 
        if (private_key == NULL)
@@ -2511,7 +2535,7 @@ static int tls_global_private_key(SSL_CTX *ssl_ctx, const char *private_key,
            SSL_CTX_use_PrivateKey_file(ssl_ctx, private_key,
                                        SSL_FILETYPE_PEM) != 1 &&
 #endif /* OPENSSL_NO_STDIO */
-           tls_read_pkcs12(ssl_ctx, NULL, private_key, passwd)) {
+           tls_read_pkcs12(data, NULL, private_key, passwd)) {
                tls_show_errors(MSG_INFO, __func__,
                                "Failed to load private key");
                os_free(passwd);
@@ -2606,7 +2630,7 @@ static int tls_connection_dh(struct tls_connection *conn, const char *dh_file)
 }
 
 
-static int tls_global_dh(SSL_CTX *ssl_ctx, const char *dh_file)
+static int tls_global_dh(struct tls_data *data, const char *dh_file)
 {
 #ifdef OPENSSL_NO_DH
        if (dh_file == NULL)
@@ -2615,6 +2639,7 @@ static int tls_global_dh(SSL_CTX *ssl_ctx, const char *dh_file)
                   "dh_file specified");
        return -1;
 #else /* OPENSSL_NO_DH */
+       SSL_CTX *ssl_ctx = data->ssl;
        DH *dh;
        BIO *bio;
 
@@ -2778,7 +2803,7 @@ static int openssl_get_keyblock_size(SSL *ssl)
 #endif /* CONFIG_FIPS */
 
 
-static int openssl_tls_prf(void *tls_ctx, struct tls_connection *conn,
+static int openssl_tls_prf(struct tls_connection *conn,
                           const char *label, int server_random_first,
                           int skip_keyblock, u8 *out, size_t out_len)
 {
@@ -2946,7 +2971,7 @@ int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
        if (conn == NULL)
                return -1;
        if (server_random_first || skip_keyblock)
-               return openssl_tls_prf(tls_ctx, conn, label,
+               return openssl_tls_prf(conn, label,
                                       server_random_first, skip_keyblock,
                                       out, out_len);
        ssl = conn->ssl;
@@ -2956,7 +2981,7 @@ int tls_connection_prf(void *tls_ctx, struct tls_connection *conn,
                return 0;
        }
 #endif
-       return openssl_tls_prf(tls_ctx, conn, label, server_random_first,
+       return openssl_tls_prf(conn, label, server_random_first,
                               skip_keyblock, out, out_len);
 }
 
@@ -3633,6 +3658,7 @@ static int ocsp_status_cb(SSL *s, void *arg)
 int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
                              const struct tls_connection_params *params)
 {
+       struct tls_data *data = tls_ctx;
        int ret;
        unsigned long err;
        int can_pkcs11 = 0;
@@ -3708,10 +3734,9 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
                return -1;
 
        if (engine_id && ca_cert_id) {
-               if (tls_connection_engine_ca_cert(tls_ctx, conn,
-                                                 ca_cert_id))
+               if (tls_connection_engine_ca_cert(data, conn, ca_cert_id))
                        return TLS_SET_PARAMS_ENGINE_PRV_VERIFY_FAILED;
-       } else if (tls_connection_ca_cert(tls_ctx, conn, params->ca_cert,
+       } else if (tls_connection_ca_cert(data, conn, params->ca_cert,
                                          params->ca_cert_blob,
                                          params->ca_cert_blob_len,
                                          params->ca_path))
@@ -3729,7 +3754,7 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
                wpa_printf(MSG_DEBUG, "TLS: Using private key from engine");
                if (tls_connection_engine_private_key(conn))
                        return TLS_SET_PARAMS_ENGINE_PRV_VERIFY_FAILED;
-       } else if (tls_connection_private_key(tls_ctx, conn,
+       } else if (tls_connection_private_key(data, conn,
                                              params->private_key,
                                              params->private_key_passwd,
                                              params->private_key_blob,
@@ -3783,7 +3808,7 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 
 #ifdef HAVE_OCSP
        if (params->flags & TLS_CONN_REQUEST_OCSP) {
-               SSL_CTX *ssl_ctx = tls_ctx;
+               SSL_CTX *ssl_ctx = data->ssl;
                SSL_set_tlsext_status_type(conn->ssl, TLSEXT_STATUSTYPE_ocsp);
                SSL_CTX_set_tlsext_status_cb(ssl_ctx, ocsp_resp_cb);
                SSL_CTX_set_tlsext_status_arg(ssl_ctx, conn);
@@ -3802,7 +3827,7 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 
        conn->flags = params->flags;
 
-       tls_get_errors(tls_ctx);
+       tls_get_errors(data);
 
        return 0;
 }
@@ -3811,7 +3836,8 @@ int tls_connection_set_params(void *tls_ctx, struct tls_connection *conn,
 int tls_global_set_params(void *tls_ctx,
                          const struct tls_connection_params *params)
 {
-       SSL_CTX *ssl_ctx = tls_ctx;
+       struct tls_data *data = tls_ctx;
+       SSL_CTX *ssl_ctx = data->ssl;
        unsigned long err;
 
        while ((err = ERR_get_error())) {
@@ -3819,11 +3845,11 @@ int tls_global_set_params(void *tls_ctx,
                           __func__, ERR_error_string(err, NULL));
        }
 
-       if (tls_global_ca_cert(ssl_ctx, params->ca_cert) ||
-           tls_global_client_cert(ssl_ctx, params->client_cert) ||
-           tls_global_private_key(ssl_ctx, params->private_key,
+       if (tls_global_ca_cert(data, params->ca_cert) ||
+           tls_global_client_cert(data, params->client_cert) ||
+           tls_global_private_key(data, params->private_key,
                                   params->private_key_passwd) ||
-           tls_global_dh(ssl_ctx, params->dh_file)) {
+           tls_global_dh(data, params->dh_file)) {
                wpa_printf(MSG_INFO, "TLS: Failed to set global parameters");
                return -1;
        }