Allow for dynamically expanded PSK.
[freeradius.git] / src / main / tls.c
index debec61..50b1edc 100644 (file)
@@ -40,6 +40,7 @@ USES_APPLE_DEPRECATED_API     /* OpenSSL API has been deprecated by Apple */
 #ifdef HAVE_UTIME_H
 #include <utime.h>
 #endif
+#include <ctype.h>
 
 #ifdef WITH_TLS
 #ifdef HAVE_OPENSSL_RAND_H
@@ -80,28 +81,99 @@ static unsigned int         record_minus(record_t *buf, void *ptr,
                                     unsigned int size);
 
 #ifdef PSK_MAX_IDENTITY_LEN
-static unsigned int psk_server_callback(SSL *ssl, char const *identity,
+static bool identity_is_safe(const char *identity)
+{
+       char c;
+
+       if (!identity) return true;
+
+       while ((c = *(identity++)) != '\0') {
+               if (isalpha((int) c) || isdigit((int) c) || isspace((int) c) ||
+                   (c == '@') || (c == '-') || (c == '_') || (c == '.')) {
+                       continue;
+               }
+
+               return false;
+       }
+
+       return true;
+}
+
+
+/*
+ *     When a client uses TLS-PSK to talk to a server, this callback
+ *     is used by the server to determine the PSK to use.
+ */
+static unsigned int psk_server_callback(SSL *ssl, const char *identity,
                                        unsigned char *psk,
                                        unsigned int max_psk_len)
 {
-       unsigned int psk_len;
+       unsigned int psk_len = 0;
        fr_tls_server_conf_t *conf;
+       REQUEST *request;
 
        conf = (fr_tls_server_conf_t *)SSL_get_ex_data(ssl,
                                                       FR_TLS_EX_INDEX_CONF);
        if (!conf) return 0;
 
+       request = (REQUEST *)SSL_get_ex_data(ssl,
+                                            FR_TLS_EX_INDEX_REQUEST);
+       if (request && conf->psk_query) {
+               size_t hex_len;
+               VALUE_PAIR *vp;
+               char buffer[2 * PSK_MAX_PSK_LEN + 4]; /* allow for too-long keys */
+
+               /*
+                *      The passed identity is weird.  Deny it.
+                */
+               if (!identity_is_safe(identity)) {
+                       RWDEBUG("Invalid characters in PSK identity %s", identity);
+                       return 0;
+               }
+
+               vp = pairmake_packet("TLS-PSK-Identity", identity, T_OP_SET);
+               if (!vp) return 0;
+
+               hex_len = radius_xlat(buffer, sizeof(buffer), request, conf->psk_query,
+                                     NULL, NULL);
+               if (!hex_len) {
+                       RWDEBUG("PSK expansion returned an empty string.");
+                       return 0;
+               }
+
+               /*
+                *      The returned key is truncated at MORE than
+                *      OpenSSL can handle.  That way we can detect
+                *      the truncation, and complain about it.
+                */
+               if (hex_len > (2 * max_psk_len)) {
+                       RWDEBUG("Returned PSK is too long (%u > %u)",
+                               (unsigned int) hex_len, 2 * max_psk_len);
+                       return 0;
+               }
+
+               /*
+                *      Leave the TLS-PSK-Identity in the request, and
+                *      convert the expansion from printable string
+                *      back to hex.
+                */
+               return fr_hex2bin(psk, max_psk_len, buffer, hex_len);
+       }
+
        /*
-        *      FIXME: Look up the PSK password based on the identity!
+        *      No REQUEST, or no dynamic query.  Just look for a
+        *      static identity.
         */
        if (strcmp(identity, conf->psk_identity) != 0) {
+               ERROR("Supplied PSK identity %s does not match configuration.  Rejecting.",
+                     identity);
                return 0;
        }
 
        psk_len = strlen(conf->psk_password);
        if (psk_len > (2 * max_psk_len)) return 0;
 
-       return fr_hex2bin(psk, conf->psk_password, psk_len);
+       return fr_hex2bin(psk, max_psk_len, conf->psk_password, psk_len);
 }
 
 static unsigned int psk_client_callback(SSL *ssl, UNUSED char const *hint,
@@ -120,7 +192,7 @@ static unsigned int psk_client_callback(SSL *ssl, UNUSED char const *hint,
 
        strlcpy(identity, conf->psk_identity, max_identity_len);
 
-       return fr_hex2bin(psk, conf->psk_password, psk_len);
+       return fr_hex2bin(psk, max_psk_len, conf->psk_password, psk_len);
 }
 
 #endif
@@ -129,6 +201,7 @@ tls_session_t *tls_new_client_session(fr_tls_server_conf_t *conf, int fd)
 {
        int verify_mode;
        tls_session_t *ssn = NULL;
+       REQUEST *request;
 
        ssn = talloc_zero(conf, tls_session_t);
        if (!ssn) return NULL;
@@ -138,7 +211,13 @@ tls_session_t *tls_new_client_session(fr_tls_server_conf_t *conf, int fd)
        SSL_CTX_set_mode(ssn->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_AUTO_RETRY);
 
        ssn->ssl = SSL_new(ssn->ctx);
-       rad_assert(ssn->ssl != NULL);
+       if (!ssn->ssl) {
+               talloc_free(ssn);
+               return NULL;
+       }
+
+       request = request_alloc(ssn);
+       SSL_set_ex_data(ssn->ssl, FR_TLS_EX_INDEX_REQUEST, (void *)request);
 
        /*
         *      Add the message callback to identify what type of
@@ -166,6 +245,7 @@ tls_session_t *tls_new_client_session(fr_tls_server_conf_t *conf, int fd)
                }
                SSL_free(ssn->ssl);
                talloc_free(ssn);
+
                return NULL;
        }
 
@@ -174,8 +254,22 @@ tls_session_t *tls_new_client_session(fr_tls_server_conf_t *conf, int fd)
        return ssn;
 }
 
-tls_session_t *tls_new_session(fr_tls_server_conf_t *conf, REQUEST *request,
-                              int client_cert)
+static int _tls_session_free(tls_session_t *ssn)
+{
+       /*
+        *      Free any opaque TTLS or PEAP data.
+        */
+       if ((ssn->opaque) && (ssn->free_opaque)) {
+               ssn->free_opaque(ssn->opaque);
+               ssn->opaque = NULL;
+       }
+
+       session_close(ssn);
+
+       return 0;
+}
+
+tls_session_t *tls_new_session(TALLOC_CTX *ctx, fr_tls_server_conf_t *conf, REQUEST *request, bool client_cert)
 {
        tls_session_t *state = NULL;
        SSL *new_tls = NULL;
@@ -209,8 +303,9 @@ tls_session_t *tls_new_session(fr_tls_server_conf_t *conf, REQUEST *request,
        /* We use the SSL's "app_data" to indicate a call-back */
        SSL_set_app_data(new_tls, NULL);
 
-       state = talloc_zero(conf, tls_session_t);
+       state = talloc_zero(ctx, tls_session_t);
        session_init(state);
+       talloc_set_destructor(state, _tls_session_free);
 
        state->ctx = conf->ctx;
        state->ssl = new_tls;
@@ -512,25 +607,6 @@ void session_close(tls_session_t *ssn)
        session_init(ssn);
 }
 
-void session_free(void *ssn)
-{
-       tls_session_t *sess = (tls_session_t *)ssn;
-
-       if (!ssn) return;
-
-       /*
-        *      Free any opaque TTLS or PEAP data.
-        */
-       if ((sess->opaque) && (sess->free_opaque)) {
-               sess->free_opaque(sess->opaque);
-               sess->opaque = NULL;
-       }
-
-       session_close(sess);
-
-       talloc_free(sess);
-}
-
 static void record_init(record_t *rec)
 {
        rec->used = 0;
@@ -771,8 +847,11 @@ void tls_session_information(tls_session_t *tls_session)
                 str_details1, str_details2);
 
        request = SSL_get_ex_data(tls_session->ssl, FR_TLS_EX_INDEX_REQUEST);
-
-       RDEBUG2("%s\n", tls_session->info.info_description);
+       if (request) {
+               RDEBUG2("%s", tls_session->info.info_description);
+       } else {
+               DEBUG2("%s", tls_session->info.info_description);
+       }
 }
 
 static CONF_PARSER cache_config[] = {
@@ -819,6 +898,7 @@ static CONF_PARSER tls_server_config[] = {
 #ifdef PSK_MAX_IDENTITY_LEN
        { "psk_identity", FR_CONF_OFFSET(PW_TYPE_STRING, fr_tls_server_conf_t, psk_identity), NULL },
        { "psk_hexphrase", FR_CONF_OFFSET(PW_TYPE_STRING | PW_TYPE_SECRET, fr_tls_server_conf_t, psk_password), NULL },
+       { "psk_query", FR_CONF_OFFSET(PW_TYPE_STRING, fr_tls_server_conf_t, psk_query), NULL },
 #endif
        { "dh_file", FR_CONF_OFFSET(PW_TYPE_STRING, fr_tls_server_conf_t, dh_file), NULL },
        { "random_file", FR_CONF_OFFSET(PW_TYPE_STRING, fr_tls_server_conf_t, random_file), NULL },
@@ -1008,8 +1088,10 @@ static int cbtls_new_session(SSL *ssl, SSL_SESSION *sess)
                        return 0;
                }
 
+
+               /* Do not convert to TALLOC - Thread safety */
                /* alloc and convert to ASN.1 */
-               sess_blob = talloc_array(conf, unsigned char, blob_len);
+               sess_blob = malloc(blob_len);
                if (!sess_blob) {
                        DEBUG2("  SSL: could not allocate buffer len=%d to persist session", blob_len);
                        return 0;
@@ -1048,7 +1130,7 @@ static int cbtls_new_session(SSL *ssl, SSL_SESSION *sess)
        }
 
 error:
-       if (sess_blob) talloc_free(sess_blob);
+       free(sess_blob);
 
        return 0;
 }
@@ -1491,12 +1573,9 @@ int cbtls_verify(int ok, X509_STORE_CTX *ctx)
        if (!conf) return 1;
 
        request = (REQUEST *)SSL_get_ex_data(ssl, FR_TLS_EX_INDEX_REQUEST);
-
-       if (!request) return 1; /* FIXME: outbound TLS */
-
        rad_assert(request != NULL);
        certs = (VALUE_PAIR **)SSL_get_ex_data(ssl, FR_TLS_EX_INDEX_CERTS);
-       rad_assert(certs != NULL);
+
        identity = (char **)SSL_get_ex_data(ssl, FR_TLS_EX_INDEX_IDENTITY);
 #ifdef HAVE_OPENSSL_OCSP_H
        ocsp_store = (X509_STORE *)SSL_get_ex_data(ssl, FR_TLS_EX_INDEX_STORE);
@@ -1516,7 +1595,7 @@ int cbtls_verify(int ok, X509_STORE_CTX *ctx)
         *      have a user identity.  i.e. we don't create the
         *      attributes for RadSec connections.
         */
-       if (identity &&
+       if (certs && identity &&
            (lookup <= 1) && sn && ((size_t) sn->length < (sizeof(buf) / 2))) {
                char *p = buf;
                int i;
@@ -1534,7 +1613,7 @@ int cbtls_verify(int ok, X509_STORE_CTX *ctx)
         */
        buf[0] = '\0';
        asn_time = X509_get_notAfter(client_cert);
-       if (identity && (lookup <= 1) && asn_time &&
+       if (certs && identity && (lookup <= 1) && asn_time &&
            (asn_time->length < (int) sizeof(buf))) {
                memcpy(buf, (char*) asn_time->data, asn_time->length);
                buf[asn_time->length] = '\0';
@@ -1548,14 +1627,14 @@ int cbtls_verify(int ok, X509_STORE_CTX *ctx)
        X509_NAME_oneline(X509_get_subject_name(client_cert), subject,
                          sizeof(subject));
        subject[sizeof(subject) - 1] = '\0';
-       if (identity && (lookup <= 1) && subject[0]) {
+       if (certs && identity && (lookup <= 1) && subject[0]) {
                pairmake(talloc_ctx, certs, cert_attr_names[FR_TLS_SUBJECT][lookup], subject, T_OP_SET);
        }
 
        X509_NAME_oneline(X509_get_issuer_name(ctx->current_cert), issuer,
                          sizeof(issuer));
        issuer[sizeof(issuer) - 1] = '\0';
-       if (identity && (lookup <= 1) && issuer[0]) {
+       if (certs && identity && (lookup <= 1) && issuer[0]) {
                pairmake(talloc_ctx, certs, cert_attr_names[FR_TLS_ISSUER][lookup], issuer, T_OP_SET);
        }
 
@@ -1565,7 +1644,7 @@ int cbtls_verify(int ok, X509_STORE_CTX *ctx)
        X509_NAME_get_text_by_NID(X509_get_subject_name(client_cert),
                                  NID_commonName, common_name, sizeof(common_name));
        common_name[sizeof(common_name) - 1] = '\0';
-       if (identity && (lookup <= 1) && common_name[0] && subject[0]) {
+       if (certs && identity && (lookup <= 1) && common_name[0] && subject[0]) {
                pairmake(talloc_ctx, certs, cert_attr_names[FR_TLS_CN][lookup], common_name, T_OP_SET);
        }
 
@@ -1573,7 +1652,7 @@ int cbtls_verify(int ok, X509_STORE_CTX *ctx)
         *      Get the RFC822 Subject Alternative Name
         */
        loc = X509_get_ext_by_NID(client_cert, NID_subject_alt_name, 0);
-       if (lookup <= 1 && loc >= 0) {
+       if (certs && (lookup <= 1) && (loc >= 0)) {
                X509_EXTENSION *ext = NULL;
                GENERAL_NAMES *names = NULL;
                int i;
@@ -1991,7 +2070,7 @@ void tls_global_cleanup(void)
  *     - Load the Private key & the certificate
  *     - Set the Context options & Verify options
  */
-static SSL_CTX *init_tls_ctx(fr_tls_server_conf_t *conf, int client)
+static SSL_CTX *tls_init_ctx(fr_tls_server_conf_t *conf, int client)
 {
        SSL_CTX *ctx;
        X509_STORE *certstore;
@@ -2090,6 +2169,26 @@ static SSL_CTX *init_tls_ctx(fr_tls_server_conf_t *conf, int client)
        }
 
 #ifdef PSK_MAX_IDENTITY_LEN
+       if (!client) {
+               /*
+                *      No dynamic query exists.  There MUST be a
+                *      statically configured identity and password.
+                */
+               if (conf->psk_query && !*conf->psk_query) {
+                       ERROR("Invalid PSK Configuration: psk_query cannot be empty");
+                       return NULL;
+               }
+
+               SSL_CTX_set_psk_server_callback(ctx, psk_server_callback);
+
+       } else if (conf->psk_query) {
+               ERROR("Invalid PSK Configuration: psk_query cannot be used for outgoing connections");
+               return NULL;
+       }
+
+       /*
+        *      Now check that if PSK is being used, the config is valid.
+        */
        if ((conf->psk_identity && !conf->psk_password) ||
            (!conf->psk_identity && conf->psk_password) ||
            (conf->psk_identity && !*conf->psk_identity) ||
@@ -2100,7 +2199,7 @@ static SSL_CTX *init_tls_ctx(fr_tls_server_conf_t *conf, int client)
 
        if (conf->psk_identity) {
                size_t psk_len, hex_len;
-               char buffer[PSK_MAX_PSK_LEN];
+               uint8_t buffer[PSK_MAX_PSK_LEN];
 
                if (conf->certificate_file ||
                    conf->private_key_password || conf->private_key_file ||
@@ -2112,9 +2211,6 @@ static SSL_CTX *init_tls_ctx(fr_tls_server_conf_t *conf, int client)
                if (client) {
                        SSL_CTX_set_psk_client_callback(ctx,
                                                        psk_client_callback);
-               } else {
-                       SSL_CTX_set_psk_server_callback(ctx,
-                                                       psk_server_callback);
                }
 
                psk_len = strlen(conf->psk_password);
@@ -2124,7 +2220,11 @@ static SSL_CTX *init_tls_ctx(fr_tls_server_conf_t *conf, int client)
                        return NULL;
                }
 
-               hex_len = fr_hex2bin((uint8_t *) buffer, conf->psk_password, psk_len);
+               /*
+                *      Check the password now, so that we don't have
+                *      errors at run-time.
+                */
+               hex_len = fr_hex2bin(buffer, sizeof(buffer), conf->psk_password, psk_len);
                if (psk_len != (2 * hex_len)) {
                        ERROR("psk_hexphrase is not all hex");
                        return NULL;
@@ -2272,13 +2372,13 @@ post_ca:
         */
 #ifdef X509_V_FLAG_CRL_CHECK
        if (conf->check_crl) {
-         certstore = SSL_CTX_get_cert_store(ctx);
-         if (certstore == NULL) {
-           ERROR("tls: SSL error %s", ERR_error_string(ERR_get_error(), NULL));
-           ERROR("tls: Error reading Certificate Store");
-           return NULL;
-         }
-         X509_STORE_set_flags(certstore, X509_V_FLAG_CRL_CHECK);
+               certstore = SSL_CTX_get_cert_store(ctx);
+               if (certstore == NULL) {
+                       ERROR("tls: SSL error %s", ERR_error_string(ERR_get_error(), NULL));
+                       ERROR("tls: Error reading Certificate Store");
+                       return NULL;
+               }
+               X509_STORE_set_flags(certstore, X509_V_FLAG_CRL_CHECK);
        }
 #endif
 
@@ -2366,7 +2466,7 @@ post_ca:
  *     added to automatically free the data when the CONF_SECTION
  *     is freed.
  */
-static int tls_server_conf_free(fr_tls_server_conf_t *conf)
+static int _tls_server_conf_free(fr_tls_server_conf_t *conf)
 {
        if (conf->ctx) SSL_CTX_free(conf->ctx);
 
@@ -2381,6 +2481,21 @@ static int tls_server_conf_free(fr_tls_server_conf_t *conf)
        return 0;
 }
 
+static fr_tls_server_conf_t *tls_server_conf_alloc(TALLOC_CTX *ctx)
+{
+       fr_tls_server_conf_t *conf;
+
+       conf = talloc_zero(ctx, fr_tls_server_conf_t);
+       if (!conf) {
+               ERROR("Out of memory");
+               return NULL;
+       }
+
+       talloc_set_destructor(conf, _tls_server_conf_free);
+
+       return conf;
+}
+
 
 fr_tls_server_conf_t *tls_server_conf_parse(CONF_SECTION *cs)
 {
@@ -2396,13 +2511,7 @@ fr_tls_server_conf_t *tls_server_conf_parse(CONF_SECTION *cs)
                return conf;
        }
 
-       conf = talloc_zero(cs, fr_tls_server_conf_t);
-       if (!conf) {
-               ERROR("Out of memory");
-               return NULL;
-       }
-
-       talloc_set_destructor(conf, tls_server_conf_free);
+       conf = tls_server_conf_alloc(cs);
 
        if (cf_section_parse(cs, conf, tls_server_config) < 0) {
        error:
@@ -2428,7 +2537,7 @@ fr_tls_server_conf_t *tls_server_conf_parse(CONF_SECTION *cs)
        /*
         *      Initialize TLS
         */
-       conf->ctx = init_tls_ctx(conf, 0);
+       conf->ctx = tls_init_ctx(conf, 0);
        if (conf->ctx == NULL) {
                goto error;
        }
@@ -2485,13 +2594,7 @@ fr_tls_server_conf_t *tls_client_conf_parse(CONF_SECTION *cs)
                return conf;
        }
 
-       conf = talloc_zero(cs, fr_tls_server_conf_t);
-       if (!conf) {
-               ERROR("Out of memory");
-               return NULL;
-       }
-
-       talloc_set_destructor(conf, tls_server_conf_free);
+       conf = tls_server_conf_alloc(cs);
 
        if (cf_section_parse(cs, conf, tls_client_config) < 0) {
        error:
@@ -2507,7 +2610,7 @@ fr_tls_server_conf_t *tls_client_conf_parse(CONF_SECTION *cs)
        /*
         *      Initialize TLS
         */
-       conf->ctx = init_tls_ctx(conf, 1);
+       conf->ctx = tls_init_ctx(conf, 1);
        if (conf->ctx == NULL) {
                goto error;
        }
@@ -2679,7 +2782,7 @@ int tls_success(tls_session_t *ssn, REQUEST *request)
                                                paircopyvp(request->packet, vp));
                                } else {
                                        pairadd(&request->reply->vps,
-                                               paircopyvp(request->packet, vp));
+                                               paircopyvp(request->reply, vp));
                                }
                        }