From 74318f483a0634cd163405034506c4b38b0c87fb Mon Sep 17 00:00:00 2001
From: Benjamin Wenger <ben@elixxir.ioo>
Date: Tue, 5 Oct 2021 08:05:44 -0700
Subject: [PATCH] fixed deadlocks

---
 connect/host.go     | 27 ++++++++++++++-------------
 connect/transmit.go | 21 ++++++++++++++-------
 2 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/connect/host.go b/connect/host.go
index 162ab99..07a9302 100644
--- a/connect/host.go
+++ b/connect/host.go
@@ -70,8 +70,6 @@ type Host struct {
 	// lock which ensures only a single thread is connecting at a time and
 	// that connections do not interrupt sends
 	connectionMux sync.RWMutex
-	// lock which ensures transmissions are not interrupted by disconnections
-	transmitMux sync.RWMutex
 
 	coolOffBucket *rateLimiting.Bucket
 	inCoolOff     bool
@@ -164,6 +162,13 @@ func (h *Host) Connected() (bool, uint64) {
 	h.connectionMux.RLock()
 	defer h.connectionMux.RUnlock()
 
+	return h.connectedUnsafe()
+}
+
+// connectedUnsafe checks if the given Host's connection is alive without taking
+// a connection lock. Only use if already under a connection lock. The uint is
+// the connection count; it increments every time a reconnect occurs.
+func (h *Host) connectedUnsafe() (bool, uint64) {
 	return h.isAlive() && !h.authenticationRequired(), h.connectionCount
 }
 
@@ -228,19 +233,20 @@ func (h *Host) SetMetricsTesting(m *Metric, face interface{}) {
 }
 
 // Disconnect closes a the Host connection under the write lock
+// Due to asynchronous connection handling, this may result in
+// killing a good connection and could result in an immediate
+// reconnection by a separate thread
 func (h *Host) Disconnect() {
-	h.transmitMux.Lock()
-	defer h.transmitMux.Unlock()
+	h.connectionMux.Lock()
+	defer h.connectionMux.Unlock()
 
 	h.disconnect()
 }
 
 // ConditionalDisconnect closes a the Host connection under the write lock only
 // if the connection count has not increased
+// Must be called with the connectionMux write lock already held.
 func (h *Host) conditionalDisconnect(count uint64) {
-	h.connectionMux.Lock()
-	defer h.connectionMux.Unlock()
-
 	if count == h.connectionCount {
 		h.disconnect()
 	}
@@ -269,13 +275,10 @@ func (h *Host) IsOnline() bool {
 }
 
 // send checks that the host has a connection and sends if it does.
-// Operates under the host's read lock.
+// must be called under host's connection read lock.
 func (h *Host) transmit(f func(conn *grpc.ClientConn) (interface{},
 	error)) (interface{}, error) {
 
-	h.connectionMux.RLock()
-	defer h.connectionMux.RUnlock()
-
 	// Check if connection is down
 	if h.connection == nil {
 		return nil, errors.New("Failed to transmit: host disconnected")
@@ -315,6 +318,7 @@ func (h *Host) authenticationRequired() bool {
 }
 
 // isAlive returns true if the connection is non-nil and alive
+// must already be under the connectionMux
 func (h *Host) isAlive() bool {
 	if h.connection == nil {
 		return false
diff --git a/connect/transmit.go b/connect/transmit.go
index f1576ac..d2c42d2 100644
--- a/connect/transmit.go
+++ b/connect/transmit.go
@@ -18,8 +18,13 @@ const MaxRetries = 3
 const inCoolDownErr = "Host is in cool down. Cannot connect."
 const lastTryErr = "Last try to connect to"
 
-// Sets up or recovers the Host's connection
+// transmit sets up or recovers the Host's connection
 // Then runs the given Send function
+// Known issue: if a disconnect happens after "host.transmit(f)", the
+// send was unsuccessful, and it is retried, this code will reconnect and
+// then do the operation, leaving the host connected. In a system like
+// the host pool in client, this will cause untracked connections.
+// Given that connections have timeouts, this is a minor issue.
 func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interface{},
 	error)) (result interface{}, err error) {
 
@@ -27,15 +32,16 @@ func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interfa
 		return nil, errors.New("Host address is blank, host might be receive only.")
 	}
 
-	host.transmitMux.RLock()
-	defer host.transmitMux.RUnlock()
-
 	for numRetries := 0; numRetries < MaxRetries; numRetries++ {
 		err = nil
 		//reconnect if necessary
-		connected, connectionCount := host.Connected()
+		host.connectionMux.RLock()
+		connected, connectionCount := host.connectedUnsafe()
 		if !connected {
+			host.connectionMux.RUnlock()
+			host.connectionMux.Lock()
 			connectionCount, err = c.connect(host, connectionCount)
+			host.connectionMux.Unlock()
 			if err != nil {
 				if strings.Contains(err.Error(), inCoolDownErr) ||
 					strings.Contains(err.Error(), lastTryErr) {
@@ -45,16 +51,20 @@ func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interfa
 					"%v/%v : %s", numRetries+1, MaxRetries, err)
 				continue
 			}
+			host.connectionMux.RLock()
 		}
 
 		//transmit
 		result, err = host.transmit(f)
+		host.connectionMux.RUnlock()
 
 		// if the transmission goes well or it is a domain specific error, return
 		if err == nil || !(isConnError(err) || IsAuthError(err)) {
 			return result, err
 		}
+		host.connectionMux.Lock()
 		host.conditionalDisconnect(connectionCount)
+		host.connectionMux.Unlock()
 		jww.WARN.Printf("Failed to send to Host on attempt %v/%v: %+v",
 			numRetries+1, MaxRetries, err)
 	}
@@ -63,9 +73,6 @@ func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interfa
 }
 
 func (c *ProtoComms) connect(host *Host, count uint64) (uint64, error) {
-	host.connectionMux.Lock()
-	defer host.connectionMux.Unlock()
-
 	if host.coolOffBucket != nil {
 		if host.inCoolOff {
 			if host.coolOffBucket.IsEmpty() {
-- 
GitLab