cleanups to minimise merging hassle
[cyrus-sasl.git] / lib / server.c
index 59c61ff..928447d 100644 (file)
@@ -1214,6 +1214,7 @@ int sasl_server_start(sasl_conn_t *conn,
     int result;
     context_list_t *cur, **prev;
     mechanism_t *m;
+    int plus = 0;
 
     if (_sasl_server_active==0) return SASL_NOTINIT;
 
@@ -1230,13 +1231,11 @@ int sasl_server_start(sasl_conn_t *conn,
     if(serverout) *serverout = NULL;
     if(serveroutlen) *serveroutlen = 0;
 
-    while (m!=NULL)
-    {
-       if ( strcasecmp(mech, m->m.plug->mech_name)==0)
-       {
+    while (m != NULL) {
+       if (_sasl_is_equal_mech(mech, m->m.plug->mech_name, &plus))
            break;
-       }
-       m=m->next;
+
+       m = m->next;
     }
   
     if (m==NULL) {
@@ -1427,7 +1426,6 @@ int sasl_server_step(sasl_conn_t *conn,
     if(serverout) *serverout = NULL;
     if(serveroutlen) *serveroutlen = 0;
 
-    s_conn->sparams->plug = s_conn->mech->m.plug;
     ret = s_conn->mech->m.plug->mech_step(conn->context,
                                        s_conn->sparams,
                                        clientin,
@@ -1436,12 +1434,10 @@ int sasl_server_step(sasl_conn_t *conn,
                                        serveroutlen,
                                        &conn->oparams);
 
-    s_conn->sparams->plug = NULL;
     if (ret == SASL_OK) {
        ret = do_authorization(s_conn);
     }
 
-
     if (ret == SASL_OK) {
        /* if we're done, we need to watch out for the following:
         * 1. the mech does server-send-last
@@ -1457,7 +1453,39 @@ int sasl_server_step(sasl_conn_t *conn,
            conn->oparams.maxoutbuf = conn->props.maxbufsize;
        }
 
-       if(conn->oparams.user == NULL || conn->oparams.authid == NULL) {
+        /* Validate channel bindings */
+       switch (conn->oparams.cbindingdisp) {
+       case SASL_CB_DISP_NONE:
+           if (SASL_CB_CRITICAL(s_conn->sparams)) {
+               sasl_seterror(conn, 0,
+                             "server requires channel binding but client provided none");
+               ret = SASL_BADAUTH;
+           }
+           break;
+       case SASL_CB_DISP_WANT:
+           if (SASL_CB_PRESENT(s_conn->sparams)) {
+               sasl_seterror(conn, 0,
+                             "client incorrectly assumed server had no channel binding");
+               ret = SASL_BADAUTH;
+           }
+           break;
+       case SASL_CB_DISP_USED:
+           if (!SASL_CB_PRESENT(s_conn->sparams)) {
+               sasl_seterror(conn, 0,
+                             "client provided channel binding but server had none");
+               ret = SASL_BADAUTH;
+           } else if (strcmp(conn->oparams.cbindingname,
+                      s_conn->sparams->cbinding->name) != 0) {
+               sasl_seterror(conn, 0,
+                             "client channel binding %s does not match server %s",
+                             conn->oparams.cbindingname, s_conn->sparams->cbinding->name);
+               ret = SASL_BADAUTH;
+           }
+           break;
+       }
+
+        if (ret == SASL_OK &&
+           (conn->oparams.user == NULL || conn->oparams.authid == NULL)) {
            sasl_seterror(conn, 0,
                          "mech did not call canon_user for both authzid " \
                          "and authid");
@@ -1514,6 +1542,7 @@ int _sasl_server_listmech(sasl_conn_t *conn,
   size_t resultlen;
   int flag;
   const char *mysep;
+  sasl_server_conn_t *s_conn = (sasl_server_conn_t *) conn;  /* cast */
 
   /* if there hasn't been a sasl_sever_init() fail */
   if (_sasl_server_active==0) return SASL_NOTINIT;
@@ -1538,8 +1567,9 @@ int _sasl_server_listmech(sasl_conn_t *conn,
       INTERROR(conn, SASL_NOMECH);
 
   resultlen = (prefix ? strlen(prefix) : 0)
-            + (strlen(mysep) * (mechlist->mech_length - 1))
-           + mech_names_len()
+            + (strlen(mysep) * (mechlist->mech_length - 1) * 2)
+           + (mech_names_len() * 2) /* including -PLUS variant */
+           + (mechlist->mech_length * (sizeof("-PLUS") - 1))
             + (suffix ? strlen(suffix) : 0)
            + 1;
   ret = _buf_alloc(&conn->mechlist_buf,
@@ -1558,18 +1588,38 @@ int _sasl_server_listmech(sasl_conn_t *conn,
   for (lup = 0; lup < mechlist->mech_length; lup++) {
       /* currently, we don't use the "user" parameter for anything */
       if (mech_permitted(conn, listptr) == SASL_OK) {
-         if (pcount != NULL)
+          /*
+           * If the server would never succeed in the authentication of
+           * he non-PLUS-variant due to policy reasons, it MUST advertise
+           * only the PLUS-variant.
+           */
+          if (!SASL_CB_PRESENT(s_conn->sparams) ||
+              !SASL_CB_CRITICAL(s_conn->sparams)) {
+            if (pcount != NULL)
              (*pcount)++;
-
-         /* print separator */
-         if (flag) {
-             strcat(conn->mechlist_buf, mysep);
-         } else {
-             flag = 1;
+            if (flag)
+              strcat(conn->mechlist_buf, mysep);
+            else
+              flag = 1;
+           strcat(conn->mechlist_buf, listptr->m.plug->mech_name);
+          }
+          /*
+           * If the server cannot support channel binding, it SHOULD
+           * advertise only the non-PLUS-variant. Here, supporting channel
+           * binding means the underlying SASL mechanism supports it and
+           * the application has set some channel binding data.
+           */
+         if ((listptr->m.plug->features & SASL_FEAT_CHANNEL_BINDING) &&
+             SASL_CB_PRESENT(s_conn->sparams)) {
+           if (pcount != NULL)
+               (*pcount)++;
+            if (flag)
+              strcat(conn->mechlist_buf, mysep);
+            else
+              flag = 1;
+           strcat(conn->mechlist_buf, listptr->m.plug->mech_name);
+           strcat(conn->mechlist_buf, "-PLUS");
          }
-
-         /* now print the mechanism name */
-         strcat(conn->mechlist_buf, listptr->m.plug->mech_name);
       }
 
       listptr = listptr->next;
@@ -2049,6 +2099,16 @@ _sasl_print_mechanism (
            printf ("%cNEED_GETSECRET", delimiter);
            delimiter = '|';
        }
+
+        if (m->plug->features & SASL_FEAT_GSS_FRAMING) {
+           printf ("%cGSS_FRAMING", delimiter);
+           delimiter = '|';
+       }
+
+        if (m->plug->features & SASL_FEAT_CHANNEL_BINDING) {
+           printf ("%cCHANNEL_BINDING", delimiter);
+           delimiter = '|';
+       }
     }
 
     if (m->f) {