From 1ebef8ab3b61ad462c3072612ce5c210900e4a45 Mon Sep 17 00:00:00 2001
From: Jono Wenger <jono@elixxir.io>
Date: Wed, 29 Dec 2021 12:56:59 -0800
Subject: [PATCH] Add dynamic timeout to sendToPreferredFunc

---
 network/gateway/sender.go       | 13 +++++++++----
 network/gateway/utils_test.go   | 12 +++++++-----
 network/message/sendCmix.go     | 13 +++++++++----
 network/message/sendManyCmix.go | 15 +++++++++++----
 network/rounds/retrieve.go      |  2 +-
 5 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/network/gateway/sender.go b/network/gateway/sender.go
index 311bcf3dd..d09decce4 100644
--- a/network/gateway/sender.go
+++ b/network/gateway/sender.go
@@ -78,11 +78,14 @@ func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error
 	return nil, errors.Errorf("Unable to send to any proxies")
 }
 
+// sendToPreferredFunc is the send function passed into Sender.SendToPreferred.
+type sendToPreferredFunc func(host *connect.Host, target *id.ID,
+	timeout time.Duration) (interface{}, error)
+
 // SendToPreferred Call given sendFunc to any Host in the HostPool, attempting
 // with up to numProxies destinations. Returns an error if the timeout is
 // reached.
-func (s *Sender) SendToPreferred(targets []*id.ID,
-	sendFunc func(host *connect.Host, target *id.ID) (interface{}, error),
+func (s *Sender) SendToPreferred(targets []*id.ID, sendFunc sendToPreferredFunc,
 	stop *stoppable.Single, timeout time.Duration) (interface{}, error) {
 
 	startTime := netTime.Now()
@@ -98,7 +101,8 @@ func (s *Sender) SendToPreferred(targets []*id.ID,
 				"sending to targets in HostPool timed out after %s", timeout)
 		}
 
-		result, err := sendFunc(targetHosts[i], targets[i])
+		remainingTimeout := timeout - netTime.Since(startTime)
+		result, err := sendFunc(targetHosts[i], targets[i], remainingTimeout)
 		if stop != nil && !stop.IsRunning() {
 			return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToPreferred")
 		} else if err == nil {
@@ -168,7 +172,8 @@ func (s *Sender) SendToPreferred(targets []*id.ID,
 				continue
 			}
 
-			result, err := sendFunc(proxy, target)
+			remainingTimeout := timeout - netTime.Since(startTime)
+			result, err := sendFunc(proxy, target, remainingTimeout)
 			if stop != nil && !stop.IsRunning() {
 				return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToPreferred")
 			} else if err == nil {
diff --git a/network/gateway/utils_test.go b/network/gateway/utils_test.go
index 09081da17..7bd89f123 100644
--- a/network/gateway/utils_test.go
+++ b/network/gateway/utils_test.go
@@ -9,9 +9,11 @@ package gateway
 
 import (
 	"fmt"
+	"github.com/pkg/errors"
 	"gitlab.com/xx_network/comms/connect"
 	"gitlab.com/xx_network/primitives/id"
 	"gitlab.com/xx_network/primitives/ndf"
+	"time"
 )
 
 // Mock structure adhering to HostManager to be used for happy path
@@ -141,16 +143,16 @@ func getTestNdf(face interface{}) *ndf.NetworkDefinition {
 
 const happyPathReturn = "happyPathReturn"
 
-func SendToPreferred_HappyPath(host *connect.Host, target *id.ID) (interface{}, error) {
+func SendToPreferred_HappyPath(*connect.Host, *id.ID, time.Duration) (interface{}, error) {
 	return happyPathReturn, nil
 }
 
-func SendToPreferred_KnownError(host *connect.Host, target *id.ID) (interface{}, error) {
-	return nil, fmt.Errorf(errorsList[0])
+func SendToPreferred_KnownError(*connect.Host, *id.ID, time.Duration) (interface{}, error) {
+	return nil, errors.Errorf(errorsList[0])
 }
 
-func SendToPreferred_UnknownError(host *connect.Host, target *id.ID) (interface{}, error) {
-	return nil, fmt.Errorf("Unexpected error: Oopsie")
+func SendToPreferred_UnknownError(*connect.Host, *id.ID, time.Duration) (interface{}, error) {
+	return nil, errors.Errorf("Unexpected error: Oopsie")
 }
 
 func SendToAny_HappyPath(host *connect.Host) (interface{}, error) {
diff --git a/network/message/sendCmix.go b/network/message/sendCmix.go
index 0c8aa4533..2126e1ac2 100644
--- a/network/message/sendCmix.go
+++ b/network/message/sendCmix.go
@@ -160,12 +160,17 @@ func sendCmixHelper(sender *gateway.Sender, msg format.Message,
 			encMsg.Digest(), firstGateway.String())
 
 		// Send the payload
-		sendFunc := func(host *connect.Host, target *id.ID) (interface{}, error) {
+		sendFunc := func(host *connect.Host, target *id.ID,
+			timeout time.Duration) (interface{}, error) {
 			wrappedMsg.Target = target.Marshal()
 
-			timeout := calculateSendTimeout(bestRound, maxTimeout)
-			result, err := comms.SendPutMessage(host, wrappedMsg,
-				timeout)
+			// Use the smaller of the two timeout durations
+			calculatedTimeout := calculateSendTimeout(bestRound, maxTimeout)
+			if calculatedTimeout < timeout {
+				timeout = calculatedTimeout
+			}
+
+			result, err := comms.SendPutMessage(host, wrappedMsg, timeout)
 			if err != nil {
 				// fixme: should we provide as a slice the whole topology?
 				err := handlePutMessageError(firstGateway, instance, session, nodeRegistration, recipient.String(), bestRound, err)
diff --git a/network/message/sendManyCmix.go b/network/message/sendManyCmix.go
index 4a5df6d10..58edff807 100644
--- a/network/message/sendManyCmix.go
+++ b/network/message/sendManyCmix.go
@@ -27,6 +27,7 @@ import (
 	"gitlab.com/xx_network/primitives/id/ephemeral"
 	"gitlab.com/xx_network/primitives/netTime"
 	"strings"
+	"time"
 )
 
 // SendManyCMIX sends many "raw" cMix message payloads to each of the provided
@@ -165,11 +166,17 @@ func sendManyCmixHelper(sender *gateway.Sender,
 		}
 
 		// Send the payload
-		sendFunc := func(host *connect.Host, target *id.ID) (interface{}, error) {
+		sendFunc := func(host *connect.Host, target *id.ID,
+			timeout time.Duration) (interface{}, error) {
+			// Use the smaller of the two timeout durations
+			calculatedTimeout := calculateSendTimeout(bestRound, maxTimeout)
+			if calculatedTimeout < timeout {
+				timeout = calculatedTimeout
+			}
+
 			wrappedMessage.Target = target.Marshal()
-			timeout := calculateSendTimeout(bestRound, maxTimeout)
-			result, err := comms.SendPutManyMessages(host,
-				wrappedMessage, timeout)
+			result, err := comms.SendPutManyMessages(
+				host, wrappedMessage, timeout)
 			if err != nil {
 				err := handlePutMessageError(firstGateway, instance,
 					session, nodeRegistration, recipientString, bestRound, err)
diff --git a/network/rounds/retrieve.go b/network/rounds/retrieve.go
index bcebcf5a4..1db833b43 100644
--- a/network/rounds/retrieve.go
+++ b/network/rounds/retrieve.go
@@ -148,7 +148,7 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round,
 	stop *stoppable.Single) (message.Bundle, error) {
 	start := time.Now()
 	// Send to the gateways using backup proxies
-	result, err := m.sender.SendToPreferred(gwIds, func(host *connect.Host, target *id.ID) (interface{}, error) {
+	result, err := m.sender.SendToPreferred(gwIds, func(host *connect.Host, target *id.ID, _ time.Duration) (interface{}, error) {
 		jww.DEBUG.Printf("Trying to get messages for round %v for ephemeralID %d (%v)  "+
 			"via Gateway: %s", roundID, identity.EphId.Int64(), identity.Source.String(), host.GetId())
 
-- 
GitLab