Refactor channel binding code
[cyrus-sasl.git] / lib / server.c
index f4a50bd..a34bccb 100644 (file)
@@ -1454,17 +1454,38 @@ int sasl_server_step(sasl_conn_t *conn,
        }
 
         /* Validate channel bindings */
-        if (conn->oparams.chanbindingflag == SASL_CB_FLAG_NONE &&
-            s_conn->sparams->chanbindingcrit) {
-           sasl_seterror(conn, 0,
-                         "server requires channel binding but client provided none");
-            ret = SASL_BADAUTH;
-        } else if (conn->oparams.chanbindingflag == SASL_CB_FLAG_WANT &&
-            SASL_CB_PRESENT(s_conn->sparams)) {
-            sasl_seterror(conn, 0,
-                          "client incorrectly assumed server had no channel binding");
-            ret = SASL_BADAUTH;
-        } else if (conn->oparams.user == NULL || conn->oparams.authid == NULL) {
+       switch (conn->oparams.cbindingdisp) {
+       case SASL_CB_DISP_NONE:
+           if (SASL_CB_CRITICAL(s_conn->sparams)) {
+               sasl_seterror(conn, 0,
+                             "server requires channel binding but client provided none");
+               ret = SASL_BADAUTH;
+           }
+           break;
+       case SASL_CB_DISP_WANT:
+           if (SASL_CB_PRESENT(s_conn->sparams)) {
+               sasl_seterror(conn, 0,
+                             "client incorrectly assumed server had no channel binding");
+               ret = SASL_BADAUTH;
+           }
+           break;
+       case SASL_CB_DISP_USED:
+           if (!SASL_CB_PRESENT(s_conn->sparams)) {
+               sasl_seterror(conn, 0,
+                             "client provided channel binding but server had none");
+               ret = SASL_BADAUTH;
+           } else if (strcmp(conn->oparams.cbindingname,
+                      s_conn->sparams->cbinding->name) != 0) {
+               sasl_seterror(conn, 0,
+                             "client channel binding %s does not match server %s",
+                             conn->oparams.cbindingname, s_conn->sparams->cbinding->name);
+               ret = SASL_BADAUTH;
+           }
+           break;
+       }
+
+        if (ret == SASL_OK &&
+           (conn->oparams.user == NULL || conn->oparams.authid == NULL)) {
            sasl_seterror(conn, 0,
                          "mech did not call canon_user for both authzid " \
                          "and authid");
@@ -1572,7 +1593,8 @@ int _sasl_server_listmech(sasl_conn_t *conn,
            * he non-PLUS-variant due to policy reasons, it MUST advertise
            * only the PLUS-variant.
            */
-          if (!s_conn->sparams->chanbindingcrit) {
+          if (!SASL_CB_PRESENT(s_conn->sparams) ||
+              !SASL_CB_CRITICAL(s_conn->sparams)) {
             if (pcount != NULL)
              (*pcount)++;
             if (flag)