Skip to content
Snippets Groups Projects
Commit a93ded09 authored by Jake Taylor's avatar Jake Taylor
Browse files

protected context building

parent 17962663
No related branches found
No related tags found
No related merge requests found
...@@ -44,7 +44,7 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) { ...@@ -44,7 +44,7 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) {
// Set up the context // Set up the context
client := pb.NewGenericClient(host.connection) client := pb.NewGenericClient(host.connection)
ctx, cancel := MessagingContext(host.GetSendTimeout()) ctx, cancel := host.GetMessagingContext()
defer cancel() defer cancel()
// Send the token request message // Send the token request message
...@@ -75,7 +75,7 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) { ...@@ -75,7 +75,7 @@ func (c *ProtoComms) clientHandshake(host *Host) (err error) {
} }
// Set up the context // Set up the context
ctx, cancel = MessagingContext(host.GetSendTimeout()) ctx, cancel = host.GetMessagingContext()
defer cancel() defer cancel()
// Send the authenticate token message // Send the authenticate token message
......
...@@ -17,8 +17,8 @@ import ( ...@@ -17,8 +17,8 @@ import (
"time" "time"
) )
// MessagingContext builds a context for connections and message sending with the given timeout // newContext builds a context for connections and message sending with the given timeout
func MessagingContext(waitingPeriod time.Duration) (context.Context, context.CancelFunc) { func newContext(waitingPeriod time.Duration) (context.Context, context.CancelFunc) {
jww.DEBUG.Printf("Timing out in: %s", waitingPeriod) jww.DEBUG.Printf("Timing out in: %s", waitingPeriod)
ctx, cancel := context.WithTimeout(context.Background(), ctx, cancel := context.WithTimeout(context.Background(),
waitingPeriod) waitingPeriod)
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
package connect package connect
import ( import (
"context"
"fmt" "fmt"
"github.com/pkg/errors" "github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
...@@ -167,9 +168,9 @@ func (h *Host) Connected() (bool, uint64) { ...@@ -167,9 +168,9 @@ func (h *Host) Connected() (bool, uint64) {
return h.isAlive() && !h.authenticationRequired(), h.connectionCount return h.isAlive() && !h.authenticationRequired(), h.connectionCount
} }
// GetSendTimeout returns the timeout for message sending // GetMessagingContext returns a context object for message sending configured according to HostParams
func (h *Host) GetSendTimeout() time.Duration { func (h *Host) GetMessagingContext() (context.Context, context.CancelFunc) {
return h.params.SendTimeout return newContext(h.params.SendTimeout)
} }
// GetId returns the id of the host // GetId returns the id of the host
...@@ -375,7 +376,7 @@ func (h *Host) connectHelper() (err error) { ...@@ -375,7 +376,7 @@ func (h *Host) connectHelper() (err error) {
if backoffTime > 15000 { if backoffTime > 15000 {
backoffTime = 15000 backoffTime = 15000
} }
ctx, cancel := MessagingContext(time.Duration(backoffTime) * time.Millisecond) ctx, cancel := newContext(time.Duration(backoffTime) * time.Millisecond)
// Create the connection // Create the connection
h.connection, err = grpc.DialContext(ctx, h.GetAddress(), h.connection, err = grpc.DialContext(ctx, h.GetAddress(),
......
...@@ -26,7 +26,7 @@ func (c *CMixServer) GetNdf(host *connect.Host, ...@@ -26,7 +26,7 @@ func (c *CMixServer) GetNdf(host *connect.Host,
// Create the Send Function // Create the Send Function
f := func(conn *grpc.ClientConn) (*any.Any, error) { f := func(conn *grpc.ClientConn) (*any.Any, error) {
// Set up the context // Set up the context
ctx, cancel := connect.MessagingContext(host.GetSendTimeout()) ctx, cancel := host.GetMessagingContext()
defer cancel() defer cancel()
//Format to authenticated message type //Format to authenticated message type
// Send the message // Send the message
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment