From 17962663722cc93f76415f4845c975eeb560b496 Mon Sep 17 00:00:00 2001
From: Jake Taylor <jake@elixxir.io>
Date: Wed, 5 May 2021 15:37:12 -0500
Subject: [PATCH] Made contexts require timeout, added default timeout to
 hostparams

---
 connect/auth.go                 |  4 ++--
 connect/context.go              | 16 ++++------------
 connect/host.go                 | 11 ++++++++---
 connect/hostParams.go           | 12 ++++++++----
 interconnect/consensusClient.go |  8 ++++----
 5 files changed, 26 insertions(+), 25 deletions(-)

diff --git a/connect/auth.go b/connect/auth.go
index d54ad74..c442549 100644
--- a/connect/auth.go
+++ b/connect/auth.go
@@ -44,7 +44,7 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) {
 
 	// Set up the context
 	client := pb.NewGenericClient(host.connection)
-	ctx, cancel := MessagingContext()
+	ctx, cancel := MessagingContext(host.GetSendTimeout())
 	defer cancel()
 
 	// Send the token request message
@@ -75,7 +75,7 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) {
 	}
 
 	// Set up the context
-	ctx, cancel = MessagingContext()
+	ctx, cancel = MessagingContext(host.GetSendTimeout())
 	defer cancel()
 
 	// Send the authenticate token message
diff --git a/connect/context.go b/connect/context.go
index 3c95ea9..362a862 100644
--- a/connect/context.go
+++ b/connect/context.go
@@ -17,23 +17,15 @@ import (
 	"time"
 )
 
-// Used for creating connections with the default timeout
-func ConnectionContext(waitingPeriod time.Duration) (context.Context, context.CancelFunc) {
+// MessagingContext builds a context for connections and message sending with the given timeout
+func MessagingContext(waitingPeriod time.Duration) (context.Context, context.CancelFunc) {
 	jww.DEBUG.Printf("Timing out in: %s", waitingPeriod)
 	ctx, cancel := context.WithTimeout(context.Background(),
 		waitingPeriod)
 	return ctx, cancel
-
-}
-
-// Used for sending messages with the default timeout
-func MessagingContext() (context.Context, context.CancelFunc) {
-	ctx, cancel := context.WithTimeout(context.Background(),
-		2*time.Minute)
-	return ctx, cancel
 }
 
-// Creates a context object with the default context
+// StreamingContext Creates a context object with the default context
 // for all client streaming messages. This is primarily used to
 // allow a cancel option for clients and is suitable for unary streaming.
 func StreamingContext() (context.Context, context.CancelFunc) {
@@ -41,7 +33,7 @@ func StreamingContext() (context.Context, context.CancelFunc) {
 	return ctx, cancel
 }
 
-// Obtain address:port from the context of an incoming communication
+// GetAddressFromContext obtains address:port from the context of an incoming communication
 func GetAddressFromContext(ctx context.Context) (address string, port string, err error) {
 	info, _ := peer.FromContext(ctx)
 	address, port, err = net.SplitHostPort(info.Addr.String())
diff --git a/connect/host.go b/connect/host.go
index 5a6d74a..449fd9e 100644
--- a/connect/host.go
+++ b/connect/host.go
@@ -40,7 +40,7 @@ var KaClientOpts = keepalive.ClientParameters{
 	Time: infinityTime,
 	// 60s after ping before closing
 	Timeout: 60 * time.Second,
-	// For all connections, streaming and nonstreaming
+	// For all connections, with and without streaming
 	PermitWithoutStream: true,
 }
 
@@ -84,7 +84,7 @@ type Host struct {
 	// that connections do not interrupt sends
 	connectionMux sync.RWMutex
 	// lock which ensures transmissions are not interrupted by disconnections
-	transmitMux   sync.RWMutex
+	transmitMux sync.RWMutex
 
 	coolOffBucket *rateLimiting.Bucket
 	inCoolOff     bool
@@ -167,6 +167,11 @@ func (h *Host) Connected() (bool, uint64) {
 	return h.isAlive() && !h.authenticationRequired(), h.connectionCount
 }
 
+// GetSendTimeout returns the timeout for message sending
+func (h *Host) GetSendTimeout() time.Duration {
+	return h.params.SendTimeout
+}
+
 // GetId returns the id of the host
 func (h *Host) GetId() *id.ID {
 	if h == nil {
@@ -370,7 +375,7 @@ func (h *Host) connectHelper() (err error) {
 		if backoffTime > 15000 {
 			backoffTime = 15000
 		}
-		ctx, cancel := ConnectionContext(time.Duration(backoffTime) * time.Millisecond)
+		ctx, cancel := MessagingContext(time.Duration(backoffTime) * time.Millisecond)
 
 		// Create the connection
 		h.connection, err = grpc.DialContext(ctx, h.GetAddress(),
diff --git a/connect/hostParams.go b/connect/hostParams.go
index 0e04fe2..f5fc5d2 100644
--- a/connect/hostParams.go
+++ b/connect/hostParams.go
@@ -11,12 +11,12 @@ import (
 	"time"
 )
 
-// Params object for host creation
+// HostParams is the configuration object for Host creation
 type HostParams struct {
 	MaxRetries  uint32
 	AuthEnabled bool
 
-	// Toggles cool off of connections
+	// Toggles connection cool off
 	EnableCoolOff bool
 
 	// Number of leaky bucket sends before it stops
@@ -25,15 +25,18 @@ type HostParams struct {
 	// Amount of time after a cool off is triggered before allowed to send again
 	CoolOffTimeout time.Duration
 
+	// Message sending timeout
+	SendTimeout time.Duration
+
 	// If set, metric handling will be enabled on this host
 	EnableMetrics bool
 
 	// List of sending errors that are deemed unimportant
-	// Reception of these errors will not update the Metric's state
+	// Reception of these errors will not update the Metric state
 	ExcludeMetricErrors []string
 }
 
-// Get default set of host params
+// GetDefaultHostParams Get default set of host params
 func GetDefaultHostParams() HostParams {
 	return HostParams{
 		MaxRetries:            100,
@@ -41,6 +44,7 @@ func GetDefaultHostParams() HostParams {
 		EnableCoolOff:         false,
 		NumSendsBeforeCoolOff: 3,
 		CoolOffTimeout:        60 * time.Second,
+		SendTimeout:           2 * time.Minute,
 		EnableMetrics:         false,
 		ExcludeMetricErrors:   make([]string, 0),
 	}
diff --git a/interconnect/consensusClient.go b/interconnect/consensusClient.go
index d10a366..4065074 100644
--- a/interconnect/consensusClient.go
+++ b/interconnect/consensusClient.go
@@ -19,14 +19,14 @@ import (
 	"google.golang.org/grpc"
 )
 
-//  CMixServer -> consensus node Send Function
-func (s *CMixServer) GetNdf(host *connect.Host,
+// CMixServer -> consensus node Send Function
+func (c *CMixServer) GetNdf(host *connect.Host,
 	message *messages.Ping) (*NDF, error) {
 
 	// Create the Send Function
 	f := func(conn *grpc.ClientConn) (*any.Any, error) {
 		// Set up the context
-		ctx, cancel := connect.MessagingContext()
+		ctx, cancel := connect.MessagingContext(host.GetSendTimeout())
 		defer cancel()
 		//Format to authenticated message type
 		// Send the message
@@ -40,7 +40,7 @@ func (s *CMixServer) GetNdf(host *connect.Host,
 
 	// Execute the Send function
 	jww.DEBUG.Printf("Sending Post Phase message: %+v", message)
-	resultMsg, err := s.Send(host, f)
+	resultMsg, err := c.Send(host, f)
 	if err != nil {
 		return nil, err
 	}
-- 
GitLab