complete moving logic to server
[cyrus-sasl.git] / lib / client.c
index 4c35519..2f627c1 100644 (file)
@@ -390,6 +390,72 @@ static int have_prompts(sasl_conn_t *conn,
   return 1; /* we have all the prompts */
 }
 
+static int
+_mech_plus_p(const char *mech, size_t len)
+{
+    return (len > 5 && strncasecmp(&mech[len - 5], "-PLUS", 5) == 0);
+}
+
+/*
+ * Order PLUS mechanisms first. Returns NUL separated list of
+ * *count items.
+ */
+static int
+_sasl_client_order_mechs(const sasl_utils_t *utils,
+                         const char *mechs,
+                         int has_cb_data,
+                         char **ordered_mechs,
+                         size_t *count,
+                         int *server_can_cb)
+{
+    char *list, *listp;
+    size_t i;
+    const char *p, *start  = NULL;
+
+    *count = 0;
+    *server_can_cb = 0;
+
+    listp = list = utils->malloc(strlen(mechs) + 1);
+    if (list == NULL)
+        return SASL_NOMEM;
+
+    if (has_cb_data) {
+        for (start = p = mechs, i = 0; *p != '\0'; p++) {
+            if (isspace(*p) || p[1] == '\0') {
+                size_t len = p - start + 1;
+
+                if (_mech_plus_p(start, len)) {
+                    memcpy(listp, start, len);
+                    listp[len] = '\0';
+                    listp += len + 1;
+                    (*count)++;
+                    *server_can_cb = 1;
+                }
+                start = p + 1;
+            }
+        }
+    }
+
+    for (start = p = mechs, i = 0; *p != '\0'; p++) {
+        if (isspace(*p) || p[1] == '\0') {
+            size_t len = p - start + 1;
+
+            if (!_mech_plus_p(start, len)) {
+                memcpy(listp, start, len);
+                listp[len] = '\0';
+                listp += len + 1;
+                (*count)++;
+            }
+            start = p + 1;
+        }
+    }
+
+    *listp = '\0';
+    *ordered_mechs = list;
+
+    return SASL_OK;
+}
+
 /* select a mechanism for a connection
  *  mechlist      -- mechanisms server has available (punctuation ignored)
  *  secret        -- optional secret from previous session
@@ -421,12 +487,11 @@ int sasl_client_start(sasl_conn_t *conn,
                      const char **mech)
 {
     sasl_client_conn_t *c_conn= (sasl_client_conn_t *) conn;
-    char name[SASL_MECHNAMEMAX + 1];
+    char *ordered_mechs = NULL, *name;
     cmechanism_t *m=NULL,*bestm=NULL;
-    size_t pos=0,place;
-    size_t list_len;
+    size_t i, list_len;
     sasl_ssf_t bestssf = 0, minssf = 0;
-    int result;
+    int result, server_can_cb = 0;
 
     if(_sasl_client_active==0) return SASL_NOTINIT;
 
@@ -451,33 +516,22 @@ int sasl_client_start(sasl_conn_t *conn,
        minssf = conn->props.min_ssf - conn->external.ssf;
     }
 
-    c_conn->cparams->chanbindingflag = SASL_CB_FLAG_NONE;
-
-    /* parse mechlist */
-    list_len = strlen(mechlist);
-
-    while (pos<list_len)
-    {
-       place=0;
-       while ((pos<list_len) && (isalnum((unsigned char)mechlist[pos])
-                                 || mechlist[pos] == '_'
-                                 || mechlist[pos] == '-')) {
-           name[place]=mechlist[pos];
-           pos++;
-           place++;
-           if (SASL_MECHNAMEMAX < place) {
-               place--;
-               while(pos<list_len && (isalnum((unsigned char)mechlist[pos])
-                                      || mechlist[pos] == '_'
-                                      || mechlist[pos] == '-'))
-                   pos++;
-           }
-       }
-       pos++;
-       name[place]=0;
-
-       if (! place) continue;
+    result = _sasl_client_order_mechs(c_conn->cparams->utils,
+                                      mechlist,
+                                      SASL_CB_PRESENT(c_conn->cparams),
+                                      &ordered_mechs,
+                                      &list_len,
+                                      &server_can_cb);
+    if (result != 0)
+        return result;
+
+    /* If we have CB and the server supports it, we should use it */
+    if (SASL_CB_PRESENT(c_conn->cparams) && server_can_cb)
+        c_conn->cparams->chanbindingflag = SASL_CB_FLAG_WANT;
+    else
+        c_conn->cparams->chanbindingflag = SASL_CB_FLAG_NONE;
 
+    for (i = 0, name = ordered_mechs; i < list_len; i++) {
        /* foreach in client list */
        for (m = cmechlist->mech_list; m != NULL; m = m->next) {
            int myflags, plus;
@@ -556,12 +610,9 @@ int sasl_client_start(sasl_conn_t *conn,
                break;
            }
 
-            if (SASL_CB_PRESENT(c_conn->cparams)) {
-                if (plus)
-                    c_conn->cparams->chanbindingflag = SASL_CB_FLAG_USED;
-                else
-                    c_conn->cparams->chanbindingflag = SASL_CB_FLAG_WANT;
-            }
+            /* Prefer server advertised CB mechanisms */
+            if (SASL_CB_PRESENT(c_conn->cparams) && plus)
+                c_conn->cparams->chanbindingflag = SASL_CB_FLAG_USED;
 
            if (mech) {
                *mech = m->m.plug->mech_name;
@@ -570,6 +621,7 @@ int sasl_client_start(sasl_conn_t *conn,
            bestm = m;
            break;
        }
+        name += strlen(name) + 1;
     }
 
     if (bestm == NULL) {
@@ -616,6 +668,8 @@ int sasl_client_start(sasl_conn_t *conn,
        result = SASL_CONTINUE;
 
  done:
+    if (ordered_mechs != NULL)
+        c_conn->cparams->utils->free(ordered_mechs);
     RETURN(conn, result);
 }