Only ask for password if we can't get creds
[cyrus-sasl.git] / plugins / gs2.c
index 99298c7..917d9ed 100644 (file)
@@ -156,7 +156,8 @@ static int gs2_make_message(context_t *text,
 static int gs2_get_mech_attrs(const sasl_utils_t *utils,
                               const gss_OID mech,
                               unsigned int *security_flags,
-                              unsigned int *features);
+                              unsigned int *features,
+                              const unsigned long **prompts);
 
 static int gs2_indicate_mechs(const sasl_utils_t *utils);
 
@@ -245,11 +246,6 @@ sasl_gs2_free_context_contents(context_t *text)
         text->authzid = NULL;
     }
 
-    if (text->mechanism != NULL) {
-        gss_release_oid(&min_stat, &text->mechanism);
-        text->mechanism = GSS_C_NO_OID;
-    }
-
     gss_release_buffer(&min_stat, &text->gss_cbindings.application_data);
 
     if (text->out_buf != NULL) {
@@ -345,7 +341,6 @@ gs2_server_mech_step(void *conn_context,
     gss_buffer_desc short_name_buf = GSS_C_EMPTY_BUFFER;
     gss_name_t without = GSS_C_NO_NAME;
     gss_OID_set_desc mechs;
-    gss_OID actual_mech = GSS_C_NO_OID;
     OM_uint32 out_flags = 0;
     int ret = 0, equal = 0;
     int initialContextToken = (text->gss_ctx == GSS_C_NO_CONTEXT);
@@ -417,7 +412,7 @@ gs2_server_mech_step(void *conn_context,
                                       &input_token,
                                       &text->gss_cbindings,
                                       &text->client_name,
-                                      &actual_mech,
+                                      NULL,
                                       &output_token,
                                       &out_flags,
                                       &text->lifetime,
@@ -451,10 +446,6 @@ gs2_server_mech_step(void *conn_context,
 
     assert(maj_stat == GSS_S_COMPLETE);
 
-    if (!g_OID_equal(text->mechanism, actual_mech)) {
-        ret = SASL_WRONGMECH;
-        goto cleanup;
-    }
     if ((out_flags & GSS_C_SEQUENCE_FLAG) == 0)  {
         ret = SASL_BADAUTH;
         goto cleanup;
@@ -549,7 +540,6 @@ cleanup:
     gss_release_buffer(&min_stat, &short_name_buf);
     gss_release_buffer(&min_stat, &output_token);
     gss_release_name(&min_stat, &without);
-    gss_release_oid(&min_stat, &actual_mech);
 
     if (ret == SASL_OK && maj_stat != GSS_S_COMPLETE) {
         sasl_gs2_seterror(text->utils, maj_stat, min_stat);
@@ -634,7 +624,8 @@ gs2_server_plug_alloc(const sasl_utils_t *utils,
 
     ret = gs2_get_mech_attrs(utils, mech,
                              &splug->security_flags,
-                             &splug->features);
+                             &splug->features,
+                             NULL);
     if (ret != SASL_OK)
         return ret;
 
@@ -705,7 +696,6 @@ static int gs2_client_mech_step(void *conn_context,
     gss_buffer_desc name_buf = GSS_C_EMPTY_BUFFER;
     OM_uint32 maj_stat = GSS_S_FAILURE, min_stat = 0;
     OM_uint32 req_flags, ret_flags;
-    gss_OID actual_mech = GSS_C_NO_OID;
     int ret = SASL_FAIL;
     int initialContextToken;
 
@@ -717,40 +707,6 @@ static int gs2_client_mech_step(void *conn_context,
         if (ret != SASL_OK)
             goto cleanup;
 
-        if (params->gss_creds == GSS_C_NO_CREDENTIAL &&
-            text->password != NULL && text->password->len != 0) {
-            gss_buffer_desc password_buf;
-            gss_buffer_desc name_buf;
-            gss_OID_set_desc mechs;
-
-            name_buf.length = strlen(oparams->authid);
-            name_buf.value = (void *)oparams->authid;
-
-            password_buf.length = text->password->len;
-            password_buf.value = text->password->data;
-
-            mechs.count = 1;
-            mechs.elements = (gss_OID)text->mechanism;
-
-            maj_stat = gss_import_name(&min_stat,
-                                       &name_buf,
-                                       GSS_C_NT_USER_NAME,
-                                       &text->client_name);
-            if (GSS_ERROR(maj_stat))
-                goto cleanup;
-
-            maj_stat = gss_acquire_cred_with_password(&min_stat,
-                                                      text->client_name,
-                                                      &password_buf,
-                                                      GSS_C_INDEFINITE,
-                                                      &mechs,
-                                                      GSS_C_INITIATE,
-                                                      &text->client_creds,
-                                                      NULL,
-                                                      &text->lifetime);
-            if (GSS_ERROR(maj_stat))
-                goto cleanup;
-        }
 
         initialContextToken = 1;
     } else
@@ -858,17 +814,13 @@ static int gs2_client_mech_step(void *conn_context,
                                    &text->client_name,
                                    NULL,
                                    &text->lifetime,
-                                   &actual_mech,
+                                   NULL,
                                    &ret_flags, /* flags */
                                    NULL,
                                    NULL);
     if (GSS_ERROR(maj_stat))
         goto cleanup;
 
-    if (!g_OID_equal(text->mechanism, actual_mech)) {
-        ret = SASL_WRONGMECH;
-        goto cleanup;
-    }
     if ((ret_flags & req_flags) != req_flags) {
         maj_stat = SASL_BADAUTH;
         goto cleanup;
@@ -892,7 +844,6 @@ static int gs2_client_mech_step(void *conn_context,
 cleanup:
     gss_release_buffer(&min_stat, &output_token);
     gss_release_buffer(&min_stat, &name_buf);
-    gss_release_oid(&min_stat, &actual_mech);
 
     if (ret == SASL_OK && maj_stat != GSS_S_COMPLETE) {
         sasl_gs2_seterror(text->utils, maj_stat, min_stat);
@@ -935,10 +886,6 @@ static int gs2_client_mech_new(void *glob_context,
     return SASL_OK;
 }
 
-static const unsigned long gs2_required_prompts[] = {
-    SASL_CB_LIST_END
-};
-
 static int
 gs2_client_plug_alloc(const sasl_utils_t *utils,
                       void *plug,
@@ -953,7 +900,8 @@ gs2_client_plug_alloc(const sasl_utils_t *utils,
 
     ret = gs2_get_mech_attrs(utils, mech,
                              &cplug->security_flags,
-                             &cplug->features);
+                             &cplug->features,
+                             &cplug->required_prompts);
     if (ret != SASL_OK)
         return ret;
 
@@ -968,7 +916,6 @@ gs2_client_plug_alloc(const sasl_utils_t *utils,
     cplug->mech_step = gs2_client_mech_step;
     cplug->mech_dispose = gs2_common_mech_dispose;
     cplug->mech_free = gs2_common_mech_free;
-    cplug->required_prompts = gs2_required_prompts;
 
     return SASL_OK;
 }
@@ -1325,6 +1272,10 @@ gs2_make_message(context_t *text,
     return SASL_OK;
 }
 
+static const unsigned long gs2_required_prompts[] = {
+    SASL_CB_LIST_END
+};
+
 /*
  * Map GSS mechanism attributes to SASL ones
  */
@@ -1332,10 +1283,11 @@ static int
 gs2_get_mech_attrs(const sasl_utils_t *utils,
                    const gss_OID mech,
                    unsigned int *security_flags,
-                   unsigned int *features)
+                   unsigned int *features,
+                   const unsigned long **prompts)
 {
     OM_uint32 major, minor;
-    int present, ret;
+    int present;
     gss_OID_set attrs = GSS_C_NO_OID_SET;
 
     major = gss_inquire_attrs_for_mech(&minor, mech, &attrs, NULL);
@@ -1347,13 +1299,13 @@ gs2_get_mech_attrs(const sasl_utils_t *utils,
 
     *security_flags = SASL_SEC_NOPLAINTEXT | SASL_SEC_NOACTIVE;
     *features = SASL_FEAT_WANT_CLIENT_FIRST | SASL_FEAT_CHANNEL_BINDING;
+    if (prompts != NULL)
+        *prompts = gs2_required_prompts;
 
 #define MA_PRESENT(a)   (gss_test_oid_set_member(&minor, (gss_OID)(a), \
                                                  attrs, &present) == GSS_S_COMPLETE && \
                          present)
 
-    ret = SASL_OK;
-
     if (MA_PRESENT(GSS_C_MA_PFS))
         *security_flags |= SASL_SEC_FORWARD_SECRECY;
     if (!MA_PRESENT(GSS_C_MA_AUTH_INIT_ANON))
@@ -1362,11 +1314,14 @@ gs2_get_mech_attrs(const sasl_utils_t *utils,
         *security_flags |= SASL_SEC_PASS_CREDENTIALS;
     if (MA_PRESENT(GSS_C_MA_AUTH_TARG))
         *security_flags |= SASL_SEC_MUTUAL_AUTH;
+    if (MA_PRESENT(GSS_C_MA_AUTH_INIT_INIT) && prompts != NULL)
+        *prompts = NULL;
     if (MA_PRESENT(GSS_C_MA_ITOK_FRAMED))
         *features |= SASL_FEAT_GSS_FRAMING;
 
     gss_release_oid_set(&minor, &attrs);
-    return ret;
+
+    return SASL_OK;
 }
 
 /*
@@ -1548,31 +1503,84 @@ gs2_ask_user_info(context_t *text,
     int user_result = SASL_OK;
     int auth_result = SASL_OK;
     int pass_result = SASL_OK;
+    OM_uint32 maj_stat, min_stat;
+    gss_buffer_desc authid_buf = GSS_C_EMPTY_BUFFER;
+    gss_OID_set_desc mechs;
 
     /* try to get the authid */
     if (oparams->authid == NULL) {
         auth_result = _plug_get_authid(params->utils, &authid, prompt_need);
-
-        if (auth_result != SASL_OK && auth_result != SASL_INTERACT) {
+        if (auth_result != SASL_OK && auth_result != SASL_INTERACT) 
             return auth_result;
-        }
     }
 
     /* try to get the userid */
     if (oparams->user == NULL) {
         user_result = _plug_get_userid(params->utils, &userid, prompt_need);
 
-        if (user_result != SASL_OK && user_result != SASL_INTERACT) {
+        if (user_result != SASL_OK && user_result != SASL_INTERACT)
             return user_result;
+    }
+
+    mechs.count = 1;
+    mechs.elements = (gss_OID)text->mechanism;
+
+    if (authid != NULL) {
+        authid_buf.length = strlen(authid);
+        authid_buf.value = (void *)authid;
+    }
+
+    if (params->gss_creds == GSS_C_NO_CREDENTIAL && authid != NULL) {
+        maj_stat = gss_import_name(&min_stat, &authid_buf,
+                                   GSS_C_NT_USER_NAME, &text->client_name);
+        if (GSS_ERROR(maj_stat)) {
+            sasl_gs2_seterror(text->utils, maj_stat, min_stat);
+            return SASL_FAIL;
+        }
+
+        /* See if we have a default credential */
+        maj_stat = gss_acquire_cred(&min_stat,
+                                    text->client_name,
+                                    GSS_C_INDEFINITE,
+                                    &mechs,
+                                    GSS_C_INITIATE,
+                                    &text->client_creds,
+                                    NULL,
+                                    &text->lifetime);
+        if (maj_stat != GSS_S_COMPLETE && maj_stat != GSS_S_CRED_UNAVAIL) {
+            sasl_gs2_seterror(text->utils, maj_stat, min_stat);
+            return SASL_FAIL;
         }
     }
 
-    /* try to get the password */
-    if (text->password == NULL) {
+    /* try to get the password, only if necessary */
+    if (text->password == NULL &&
+        params->gss_creds == GSS_C_NO_CREDENTIAL &&
+        text->client_creds == GSS_C_NO_CREDENTIAL) {
         pass_result = _plug_get_password(params->utils, &text->password,
                                          &text->free_password, prompt_need);
-        if (pass_result != SASL_OK && pass_result != SASL_INTERACT) {
+        if (pass_result != SASL_OK && pass_result != SASL_INTERACT)
             return pass_result;
+
+        if (text->password != NULL && text->password->len != 0) {
+            gss_buffer_desc password_buf;
+
+            password_buf.length = text->password->len;
+            password_buf.value = text->password->data;
+
+            maj_stat = gss_acquire_cred_with_password(&min_stat,
+                                                      text->client_name,
+                                                      &password_buf,
+                                                      GSS_C_INDEFINITE,
+                                                      &mechs,
+                                                      GSS_C_INITIATE,
+                                                      &text->client_creds,
+                                                      NULL,
+                                                      &text->lifetime);
+            if (GSS_ERROR(maj_stat)) {
+                sasl_gs2_seterror(text->utils, maj_stat, min_stat);
+                return SASL_FAIL;
+            }
         }
     }
 
@@ -1600,7 +1608,8 @@ gs2_ask_user_info(context_t *text,
                                NULL, NULL, NULL,
                                NULL,
                                NULL, NULL);
-        if (result == SASL_OK) return SASL_INTERACT;
+        if (result == SASL_OK)
+            return SASL_INTERACT;
 
         return result;
     }
@@ -1613,7 +1622,8 @@ gs2_ask_user_info(context_t *text,
         } else {
             result = params->canon_user(params->utils->conn,
                                         authid, 0, SASL_CU_AUTHID, oparams);
-            if (result != SASL_OK) return result;
+            if (result != SASL_OK)
+                return result;
 
             result = params->canon_user(params->utils->conn,
                                         userid, 0, SASL_CU_AUTHZID, oparams);