diff --git a/connect/host.go b/connect/host.go
index 162ab99d5fafea325a83c822505ff0818fd05a7f..07a930227fab57c01dc14032129f4229c9b1a071 100644
--- a/connect/host.go
+++ b/connect/host.go
@@ -70,8 +70,6 @@ type Host struct {
// lock which ensures only a single thread is connecting at a time and
// that connections do not interrupt sends
connectionMux sync.RWMutex
- // lock which ensures transmissions are not interrupted by disconnections
- transmitMux sync.RWMutex
coolOffBucket *rateLimiting.Bucket
inCoolOff bool
@@ -164,6 +162,13 @@ func (h *Host) Connected() (bool, uint64) {
h.connectionMux.RLock()
defer h.connectionMux.RUnlock()
+ return h.connectedUnsafe()
+}
+
+// connectedUnsafe checks if the given Host's connection is alive without taking
+// a connection lock. Only use if already under a connection lock. The uint is
+// the connection count, it increments every time a reconnect occurs
+func (h *Host) connectedUnsafe() (bool, uint64) {
return h.isAlive() && !h.authenticationRequired(), h.connectionCount
}
@@ -228,19 +233,19 @@ func (h *Host) SetMetricsTesting(m *Metric, face interface{}) {
}
// Disconnect closes a the Host connection under the write lock
+// Due to asynchronous connection handling, this may result in
+// killing a good connection and could result in an immediate
+// reconnection by a separate thread
func (h *Host) Disconnect() {
- h.transmitMux.Lock()
- defer h.transmitMux.Unlock()
-
- h.disconnect()
+ h.connectionMux.Lock()
+ defer h.connectionMux.Unlock()
+
+ h.disconnect()
}
// ConditionalDisconnect closes a the Host connection under the write lock only
// if the connection count has not increased
func (h *Host) conditionalDisconnect(count uint64) {
- h.connectionMux.Lock()
- defer h.connectionMux.Unlock()
-
if count == h.connectionCount {
h.disconnect()
}
@@ -269,13 +272,10 @@ func (h *Host) IsOnline() bool {
}
// send checks that the host has a connection and sends if it does.
-// Operates under the host's read lock.
+// must be called under host's connection read lock.
func (h *Host) transmit(f func(conn *grpc.ClientConn) (interface{},
error)) (interface{}, error) {
- h.connectionMux.RLock()
- defer h.connectionMux.RUnlock()
-
// Check if connection is down
if h.connection == nil {
return nil, errors.New("Failed to transmit: host disconnected")
@@ -315,6 +315,7 @@ func (h *Host) authenticationRequired() bool {
}
// isAlive returns true if the connection is non-nil and alive
+// must already be under the connectionMux
func (h *Host) isAlive() bool {
if h.connection == nil {
return false
diff --git a/connect/transmit.go b/connect/transmit.go
index f1576ac5587486db0ef9e0d5b43d19805f4f058e..d2c42d243e4495274cc0d053345419b3b2313b4e 100644
--- a/connect/transmit.go
+++ b/connect/transmit.go
@@ -18,8 +18,13 @@ const MaxRetries = 3
const inCoolDownErr = "Host is in cool down. Cannot connect."
const lastTryErr = "Last try to connect to"
-// Sets up or recovers the Host's connection
+// transmit sets up or recovers the Host's connection
// Then runs the given Send function
+// This has a bug where if a disconnect happens after "host.transmit(f)"
+// and the comms was unsuccessful and is retried, this code will connect
+// and then do the operation, leaving the host as connected. In a system
+// like the host pool in client, this will cause untracked connections.
+// Given that connections have timeouts, this is a minor issue
func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interface{},
error)) (result interface{}, err error) {
@@ -27,15 +32,16 @@ func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interfa
return nil, errors.New("Host address is blank, host might be receive only.")
}
- host.transmitMux.RLock()
- defer host.transmitMux.RUnlock()
-
for numRetries := 0; numRetries < MaxRetries; numRetries++ {
err = nil
//reconnect if necessary
- connected, connectionCount := host.Connected()
+ host.connectionMux.RLock()
+ connected, connectionCount := host.connectedUnsafe()
if !connected {
+ host.connectionMux.RUnlock()
+ host.connectionMux.Lock()
connectionCount, err = c.connect(host, connectionCount)
+ host.connectionMux.Unlock()
if err != nil {
if strings.Contains(err.Error(), inCoolDownErr) ||
strings.Contains(err.Error(), lastTryErr) {
@@ -45,16 +51,20 @@ func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interfa
"%v/%v : %s", numRetries+1, MaxRetries, err)
continue
}
+ host.connectionMux.RLock()
}
//transmit
result, err = host.transmit(f)
+ host.connectionMux.RUnlock()
// if the transmission goes well or it is a domain specific error, return
if err == nil || !(isConnError(err) || IsAuthError(err)) {
return result, err
}
+ host.connectionMux.Lock()
host.conditionalDisconnect(connectionCount)
+ host.connectionMux.Unlock()
jww.WARN.Printf("Failed to send to Host on attempt %v/%v: %+v",
numRetries+1, MaxRetries, err)
}
@@ -63,9 +73,6 @@ func (c *ProtoComms) transmit(host *Host, f func(conn *grpc.ClientConn) (interfa
}
func (c *ProtoComms) connect(host *Host, count uint64) (uint64, error) {
- host.connectionMux.Lock()
- defer host.connectionMux.Unlock()
-
if host.coolOffBucket != nil {
if host.inCoolOff {
if host.coolOffBucket.IsEmpty() {