Implement TLS-PSK.
[radsecproxy.git] / lib / tls.c
index 6fcf5a0..9427f78 100644 (file)
--- a/lib/tls.c
+++ b/lib/tls.c
@@ -8,6 +8,7 @@
 #include <assert.h>
 #include <openssl/ssl.h>
 #include <openssl/err.h>
+#include <openssl/bn.h>
 #include <radsec/radsec.h>
 #include <radsec/radsec-impl.h>
 
@@ -41,6 +42,72 @@ _get_tlsconf (struct rs_connection *conn, const struct rs_realm *realm)
   return c;
 }
 
+static unsigned int
+psk_client_cb (SSL *ssl,
+               const char *hint,
+               char *identity,
+               unsigned int max_identity_len,
+               unsigned char *psk,
+               unsigned int max_psk_len)
+{
+  struct rs_connection *conn = NULL;
+  struct rs_credentials *cred = NULL;
+
+  conn = SSL_get_ex_data (ssl, 0);
+  assert (conn != NULL);
+  cred = conn->active_peer->realm->transport_cred;
+  assert (cred != NULL);
+  /* NOTE: Ignoring identity hint from server.  */
+
+  if (strlen (cred->identity) + 1 > max_identity_len)
+    {
+      rs_err_conn_push (conn, RSE_CRED, "PSK identity longer than max %d",
+                        max_identity_len - 1);
+      return 0;
+    }
+  strcpy (identity, cred->identity);
+
+  switch (cred->secret_encoding)
+    {
+    case RS_KEY_ENCODING_UTF8:
+      cred->secret_len = strlen (cred->secret);
+      if (cred->secret_len > max_psk_len)
+        {
+          rs_err_conn_push (conn, RSE_CRED, "PSK secret longer than max %d",
+                            max_psk_len);
+          return 0;
+        }
+      memcpy (psk, cred->secret, cred->secret_len);
+      break;
+    case RS_KEY_ENCODING_ASCII_HEX:
+      {
+        BIGNUM *bn = NULL;
+
+        if (BN_hex2bn (&bn, cred->secret) == 0)
+          {
+            rs_err_conn_push (conn, RSE_CRED, "Unable to convert pskhexstr");
+            if (bn != NULL)
+              BN_clear_free (bn);
+            return 0;
+          }
+        if ((unsigned int) BN_num_bytes (bn) > max_psk_len)
+          {
+            rs_err_conn_push (conn, RSE_CRED, "PSK secret longer than max %d",
+                             max_psk_len);
+            BN_clear_free (bn);
+            return 0;
+          }
+        cred->secret_len = BN_bn2bin (bn, psk);
+        BN_clear_free (bn);
+      }
+      break;
+    default:
+      assert (!"unknown psk encoding");
+    }
+
+  return cred->secret_len;
+}
+
 int
 rs_tls_init (struct rs_connection *conn)
 {
@@ -73,6 +140,11 @@ rs_tls_init (struct rs_connection *conn)
       return -1;
     }
 
+  if (conn->active_peer->realm->transport_cred != NULL)
+    {
+      SSL_set_psk_client_callback (ssl, psk_client_cb);
+      SSL_set_ex_data (ssl, 0, conn);
+    }
   conn->tls_ctx = ssl_ctx;
   conn->tls_ssl = ssl;
   rs_free (ctx, tlsconf);