Fix race condition on exit of trpc threads
authorJennifer Richards <jennifer@painless-security.com>
Fri, 27 Apr 2018 01:46:36 +0000 (21:46 -0400)
committerJennifer Richards <jennifer@painless-security.com>
Fri, 27 Apr 2018 01:46:36 +0000 (21:46 -0400)
The messaging between the main thread and the trpc (outgoing connection)
threads allowed the trpc data to be cleaned up before the message queue
was empty, causing incorrect mutex behavior and seg faults.

This is (I hope!) solved adding an additional shutdown phase in which
the main thread indicates that it has recognized that the trpc thread
is done and that the trpc thread can safely exit.

So far, I have not seen a failure of the system to handle a peer
disconnecting. Prior to these changes, it failed every time with my
current setup.

include/tr_trp.h
include/trp_internal.h
tr/tr_trp.c
trp/trp_conn.c
trp/trpc.c
trp/trps.c

index dea20e4..42b586e 100644 (file)
@@ -68,6 +68,7 @@ struct tr_instance {
 /* messages between threads */
 #define TR_MQMSG_MSG_RECEIVED "msg received"
 #define TR_MQMSG_TRPC_DISCONNECTED "trpc disconnected"
+#define TR_MQMSG_TRPC_EXIT_OK "trpc exit ok"
 #define TR_MQMSG_TRPC_CONNECTED "trpc connected"
 #define TR_MQMSG_TRPS_DISCONNECTED "trps disconnected"
 #define TR_MQMSG_TRPS_CONNECTED "trps connected"
index a35a043..3fa4034 100644 (file)
@@ -116,8 +116,8 @@ struct trp_connection {
   TRP_CONNECTION *next;
   pthread_t *thread; /* thread servicing this connection */
   int fd;
-  TR_NAME *gssname;
-  TR_NAME *peer; /* TODO: why is there a peer and a gssname? jlr */
+  TR_NAME *gssname; /* the gss service name we presented for passive auth */
+  TR_NAME *peer; /* gssname of incoming peer */
   gss_ctx_id_t *gssctx;
   TRP_CONNECTION_STATUS status;
   void (*status_change_cb)(TRP_CONNECTION *conn, void *cookie);
@@ -142,6 +142,7 @@ struct trpc_instance {
   unsigned int port;
   TRP_CONNECTION *conn;
   TR_MQ *mq; /* msgs from master to trpc */
+  int shutting_down; /* 0 unless the TRPC is in the shutdown process */
 };
 
 /* TRP Server Instance Data */
@@ -187,7 +188,7 @@ TRP_CONNECTION *trp_connection_get_next(TRP_CONNECTION *conn);
 TRP_CONNECTION *trp_connection_remove(TRP_CONNECTION *conn, TRP_CONNECTION *remove);
 void trp_connection_append(TRP_CONNECTION *conn, TRP_CONNECTION *new);
 int trp_connection_auth(TRP_CONNECTION *conn, TRP_AUTH_FUNC auth_callback, void *callback_data);
-TRP_CONNECTION *trp_connection_accept(TALLOC_CTX *mem_ctx, int listen, TR_NAME *gssname);
+TRP_CONNECTION *trp_connection_accept(TALLOC_CTX *mem_ctx, int listen, TR_NAME *gss_servicename);
 TRP_RC trp_connection_initiate(TRP_CONNECTION *conn, char *server, unsigned int port);
 
 TRPC_INSTANCE *trpc_new (TALLOC_CTX *mem_ctx);
index e075d58..ad1084a 100644 (file)
@@ -200,7 +200,7 @@ static void tr_trps_event_cb(int listener, short event, void *arg)
       }
       thread_data->conn=conn;
       thread_data->trps=trps;
-      trps_add_connection(trps, conn); /* remember the connection */
+      trps_add_connection(trps, conn); /* remember the connection - this puts conn and the thread data in trps's talloc context */
       pthread_create(trp_connection_get_thread(conn), NULL, tr_trps_thread, thread_data);
     }
   }
@@ -222,6 +222,17 @@ static void tr_trps_cleanup_conn(TRPS_INSTANCE *trps, TRP_CONNECTION *conn)
 
 static void tr_trps_cleanup_trpc(TRPS_INSTANCE *trps, TRPC_INSTANCE *trpc)
 {
+  TR_MQ_MSG *msg;
+
+  /* tell the trpc thread to exit */
+  msg = tr_mq_msg_new(NULL, TR_MQMSG_TRPC_EXIT_OK, TR_MQ_PRIO_NORMAL);
+  if (msg) {
+    tr_debug("tr_trps_cleanup_trpc: Notifying thread that it may now exit");
+    trpc_mq_add(trpc, msg);
+  } else {
+    tr_crit("tr_trps_cleanup_trpc: Unable to acknowledge disconnection, thread will probably never terminate");
+  }
+
   pthread_join(*trp_connection_get_thread(trpc_get_conn(trpc)), NULL);
   trps_remove_trpc(trps, trpc);
   trpc_free(trpc);
@@ -268,55 +279,64 @@ static void tr_trps_process_mq(int socket, short event, void *arg)
   TRPS_INSTANCE *trps=talloc_get_type_abort(arg, TRPS_INSTANCE);
   TR_MQ_MSG *msg=NULL;
   const char *s=NULL;
+  TRP_PEER *peer = NULL;
+  char *tmp = NULL;
 
   msg=trps_mq_pop(trps);
   while (msg!=NULL) {
     s=tr_mq_msg_get_message(msg);
     if (0==strcmp(s, TR_MQMSG_TRPS_CONNECTED)) {
-      TR_NAME *gssname=(TR_NAME *)tr_mq_msg_get_payload(msg);
-      TRP_PEER *peer=trps_get_peer_by_gssname(trps, gssname);
+      TR_NAME *peer_gssname=(TR_NAME *)tr_mq_msg_get_payload(msg);
+      peer=trps_get_peer_by_gssname(trps, peer_gssname); /* get the peer record */
+      tmp = tr_name_strdup(peer_gssname); /* get the name as a null-terminated string */
       if (peer==NULL)
-        tr_err("tr_trps_process_mq: incoming connection from unknown peer (%s) reported.", gssname->buf);
+        tr_err("tr_trps_process_mq: incoming connection from unknown peer (%s) reported.", tmp);
       else {
         trp_peer_set_incoming_status(peer, PEER_CONNECTED);
-        tr_err("tr_trps_process_mq: incoming connection from %s established.", gssname->buf);
+        tr_err("tr_trps_process_mq: incoming connection from %s established.", tmp);
       }
+      free(tmp);
     }
     else if (0==strcmp(s, TR_MQMSG_TRPS_DISCONNECTED)) {
       TRP_CONNECTION *conn=talloc_get_type_abort(tr_mq_msg_get_payload(msg), TRP_CONNECTION);
-      TR_NAME *gssname=trp_connection_get_gssname(conn);
-      TRP_PEER *peer=trps_get_peer_by_gssname(trps, gssname);
+      TR_NAME *peer_gssname=trp_connection_get_peer(conn);
+      peer=trps_get_peer_by_gssname(trps, peer_gssname); /* get the peer record */
+      tmp = tr_name_strdup(peer_gssname); /* get the name as a null-terminated string */
       if (peer==NULL) {
-        tr_err("tr_trps_process_mq: incoming connection from unknown peer (%s) lost.",
-               trp_connection_get_gssname(conn)->buf);
+        tr_err("tr_trps_process_mq: incoming connection from unknown peer (%.*s) lost.", tmp);
       } else {
         trp_peer_set_incoming_status(peer, PEER_DISCONNECTED);
         tr_trps_cleanup_conn(trps, conn);
-        tr_err("tr_trps_process_mq: incoming connection from %s lost.", gssname->buf);
+        tr_err("tr_trps_process_mq: incoming connection from %s lost.", tmp);
       }
+      free(tmp);
     }
     else if (0==strcmp(s, TR_MQMSG_TRPC_CONNECTED)) {
       TR_NAME *svcname=(TR_NAME *)tr_mq_msg_get_payload(msg);
-      TRP_PEER *peer=trps_get_peer_by_servicename(trps, svcname);
+      peer=trps_get_peer_by_servicename(trps, svcname);
+      tmp = tr_name_strdup(svcname);
       if (peer==NULL)
-        tr_err("tr_trps_process_mq: outgoing connection to unknown peer (%s) reported.", svcname->buf);
+        tr_err("tr_trps_process_mq: outgoing connection to unknown peer (%s) reported.", tmp);
       else {
         trp_peer_set_outgoing_status(peer, PEER_CONNECTED);
-        tr_err("tr_trps_process_mq: outgoing connection to %s established.", svcname->buf);
+        tr_err("tr_trps_process_mq: outgoing connection to %s established.", tmp);
       }
+      free(tmp);
     }
     else if (0==strcmp(s, TR_MQMSG_TRPC_DISCONNECTED)) {
       /* trpc connection died */
       TRPC_INSTANCE *trpc=talloc_get_type_abort(tr_mq_msg_get_payload(msg), TRPC_INSTANCE);
-      TR_NAME *gssname=trpc_get_gssname(trpc);
-      TRP_PEER *peer=trps_get_peer_by_servicename(trps, gssname);
+      TR_NAME *svcname=trpc_get_gssname(trpc);
+      peer=trps_get_peer_by_servicename(trps, svcname);
+      tmp = tr_name_strdup(svcname);
       if (peer==NULL)
-        tr_err("tr_trps_process_mq: outgoing connection to unknown peer (%s) lost.", gssname->buf);
+        tr_err("tr_trps_process_mq: outgoing connection to unknown peer (%s) lost.", tmp);
       else {
         trp_peer_set_outgoing_status(peer, PEER_DISCONNECTED);
-        tr_err("tr_trps_process_mq: outgoing connection to %s lost.", gssname->buf);
+        tr_err("tr_trps_process_mq: outgoing connection to %s lost.", tmp);
         tr_trps_cleanup_trpc(trps, trpc);
       }
+      free(tmp);
     }
 
     else if (0==strcmp(s, TR_MQMSG_MSG_RECEIVED)) {
@@ -584,7 +604,8 @@ static void *tr_trpc_thread(void *arg)
   const char *msg_type=NULL;
   char *encoded_msg=NULL;
   TR_NAME *peer_gssname=NULL;
-  int n_sent=0;
+  int n_sent = 0;
+  int n_popped = 0;
   int exit_loop=0;
 
   struct trpc_notify_cb_data cb_data={0,
@@ -627,50 +648,77 @@ static void *tr_trpc_thread(void *arg)
     while(!exit_loop) {
       cb_data.msg_ready=0;
       pthread_cond_wait(&(cb_data.cond), &(cb_data.mutex));
-      /* verify the condition */
+      /* verify the condition - remember, we have the mutex! */
       if (cb_data.msg_ready) {
-        for (msg=trpc_mq_pop(trpc),n_sent=0; msg!=NULL; msg=trpc_mq_pop(trpc),n_sent++) {
-          msg_type=tr_mq_msg_get_message(msg);
-
-          if (0==strcmp(msg_type, TR_MQMSG_ABORT)) {
-            exit_loop=1;
+        n_popped = 0; /* have not popped any messages from the queue */
+        n_sent = 0; /* have not sent any messages yet */
+        for (msg = trpc_mq_pop(trpc);
+             msg != NULL;
+             msg = trpc_mq_pop(trpc)) {
+          n_popped++;
+          msg_type = tr_mq_msg_get_message(msg);
+          if (0 == strcmp(msg_type, TR_MQMSG_ABORT)) {
+            exit_loop = 1;
             break;
-          }
-          else if (0==strcmp(msg_type, TR_MQMSG_TRPC_SEND)) {
-            encoded_msg=tr_mq_msg_get_payload(msg);
-            if (encoded_msg==NULL)
+          } else if (0 == strcmp(msg_type, TR_MQMSG_TRPC_SEND)) {
+            encoded_msg = tr_mq_msg_get_payload(msg);
+            if (encoded_msg == NULL)
               tr_notice("tr_trpc_thread: null outgoing TRP message.");
             else {
               rc = trpc_send_msg(trpc, encoded_msg);
-              if (rc!=TRP_SUCCESS) {
+              if (rc == TRP_SUCCESS) {
+                n_sent++;
+              } else {
                 tr_notice("tr_trpc_thread: trpc_send_msg failed.");
-                exit_loop=1;
+                /* Assume this means we lost the connection. */
+                exit_loop = 1;
                 break;
               }
             }
-          }
-          else
+          } else
             tr_notice("tr_trpc_thread: unknown message '%s' received.", msg_type);
 
           tr_mq_msg_free(msg);
         }
-        if (n_sent==0)
-          tr_err("tr_trpc_thread: notified of msg, but queue empty");
+
+        /* if n_popped == 0, then n_sent must be zero (it's only set after at
+         * least one msg is popped) */
+        if (n_popped==0)
+          tr_err("tr_trpc_thread: notified of message, but queue empty");
         else 
           tr_debug("tr_trpc_thread: sent %d messages.", n_sent);
       }
     }
   }
 
-  tr_debug("tr_trpc_thread: exiting.");
+  tr_debug("tr_trpc_thread: Disconnected. Waiting to terminate thread.");
+  trpc->shutting_down = 1;
+
+  // trpc_mq_clear(trpc); /* clear any queued messages */
+
   msg=tr_mq_msg_new(tmp_ctx, TR_MQMSG_TRPC_DISCONNECTED, TR_MQ_PRIO_HIGH);
   tr_mq_msg_set_payload(msg, (void *)trpc, NULL); /* do not pass a free routine */
-  if (msg==NULL)
+  if (msg==NULL) {
+    /* can't notify main thread of exit - just do it and hope for the best */
     tr_err("tr_trpc_thread: error allocating TR_MQ_MSG");
-  else
+  } else {
     trps_mq_add(trps, msg);
-
-  trpc_mq_clear(trpc); /* clear any queued messages */
+    /* now wait for an acknowledgement */
+    exit_loop = 0;
+    while (!exit_loop) {
+      cb_data.msg_ready = 0;
+      pthread_cond_wait(&(cb_data.cond), &(cb_data.mutex));
+      /* verify the condition - remember, we have the mutex! */
+      if (cb_data.msg_ready) {
+        while (NULL != (msg = trpc_mq_pop(trpc))) {
+          msg_type = tr_mq_msg_get_message(msg);
+          /* ignore anything except an exit ack */
+          if (0 == strcmp(msg_type, TR_MQMSG_TRPC_EXIT_OK))
+            exit_loop = 1;
+        }
+      }
+    }
+  }
 
   talloc_free(tmp_ctx);
   return NULL;
index 64ddf3c..dfa666c 100644 (file)
@@ -338,8 +338,13 @@ int trp_connection_auth(TRP_CONNECTION *conn, TRP_AUTH_FUNC auth_callback, void
   return !auth;
 }
 
-/* Accept connection */
-TRP_CONNECTION *trp_connection_accept(TALLOC_CTX *mem_ctx, int listen, TR_NAME *gssname)
+/**
+ * Accept connection
+ *
+ * @param mem_ctx talloc context for return value
+ * @param listen socket fd for incoming connection
+ * @param gss_servicename our GSS service name to use for passive auth */
+TRP_CONNECTION *trp_connection_accept(TALLOC_CTX *mem_ctx, int listen, TR_NAME *gss_servicename)
 {
   int conn_fd=-1;
   TRP_CONNECTION *conn=NULL;
@@ -352,7 +357,7 @@ TRP_CONNECTION *trp_connection_accept(TALLOC_CTX *mem_ctx, int listen, TR_NAME *
   }
   conn=trp_connection_new(mem_ctx);
   trp_connection_set_fd(conn, conn_fd);
-  trp_connection_set_gssname(conn, gssname);
+  trp_connection_set_gssname(conn, gss_servicename);
   trp_connection_set_status(conn, TRP_CONNECTION_AUTHORIZING);
   return conn;
 }
index b6a2cea..6c1a4bc 100644 (file)
@@ -59,6 +59,7 @@ TRPC_INSTANCE *trpc_new (TALLOC_CTX *mem_ctx)
     trpc->server=NULL;
     trpc->port=0;
     trpc->conn=NULL;
+    trpc->shutting_down = 0;
     trpc->mq=tr_mq_new(trpc);
     if (trpc->mq==NULL) {
       talloc_free(trpc);
@@ -216,7 +217,7 @@ TRP_RC trpc_send_msg (TRPC_INSTANCE *trpc,
                                          *trp_connection_get_gssctx(trpc_get_conn(trpc)),
                                          msg_content, 
                                          strlen(msg_content))) {
-    tr_err( "trpc_send_msg: Error sending message over connection.\n");
+    tr_err( "trpc_send_msg: Error sending message over connection.");
     rc=TRP_ERROR;
   }
   return rc;
index 573f0a3..83461d9 100644 (file)
@@ -269,6 +269,9 @@ TRP_RC trps_send_msg(TRPS_INSTANCE *trps, TRP_PEER *peer, const char *msg)
    * connect fails */
   if (trpc==NULL) {
     tr_warning("trps_send_msg: skipping message queued for missing TRP client entry.");
+  } else if (trpc->shutting_down) {
+    tr_debug("trps_send_msg: skipping message because TRP client is shutting down.");
+    rc = TRP_SUCCESS; /* it's ok that this didn't get sent, the connection will be gone in a moment */
   } else {
     mq_msg=tr_mq_msg_new(tmp_ctx, TR_MQMSG_TRPC_SEND, TR_MQ_PRIO_NORMAL);
     msg_dup=talloc_strdup(mq_msg, msg); /* get local copy in mq_msg context */