diff --git a/connect/auth_test.go b/connect/auth_test.go index df170fba46ccfef25618fe44f25b8cb1410a351f..ca786f3c49f6a98aeabcd00cb6de94ae51b55234 100644 --- a/connect/auth_test.go +++ b/connect/auth_test.go @@ -73,7 +73,7 @@ func TestProtoComms_AuthenticatedReceiver(t *testing.T) { copy(tkn[:], expectedVal) // Add host - _, err := pc.AddHost(testID, "", nil, false, true) + _, err := pc.AddHost(testID, "", nil, GetDefaultHostParams()) if err != nil { t.Errorf("Failed to add host: %+v", err) } @@ -123,7 +123,7 @@ func TestProtoComms_AuthenticatedReceiver_BadId(t *testing.T) { copy(tkn[:], expectedVal) // Add host - _, err := pc.AddHost(testID, "", nil, false, true) + _, err := pc.AddHost(testID, "", nil, GetDefaultHostParams()) if err != nil { t.Errorf("Failed to add host: %+v", err) } @@ -196,7 +196,7 @@ func TestProtoComms_PackAuthenticatedMessage(t *testing.T) { testId := id.NewIdFromString("test", id.Node, t) - host, err := NewHost(testId, "test", nil, false, true) + host, err := NewHost(testId, "test", nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", err) } @@ -241,7 +241,7 @@ func TestProtoComms_ValidateToken(t *testing.T) { copy(tkn[:], tokenBytes) pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath()) - host, err := comm.AddHost(testId, "test", pub, false, true) + host, err := comm.AddHost(testId, "test", pub, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", err) } @@ -295,7 +295,7 @@ func TestProtoComms_ValidateToken_BadId(t *testing.T) { copy(tkn[:], tokenBytes) pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath()) - host, err := comm.AddHost(testId, "test", pub, false, true) + host, err := comm.AddHost(testId, "test", pub, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", err) } diff --git a/connect/circuit_test.go b/connect/circuit_test.go index 6f4da06a115d4d06ccfa2676b2c47aeadf7109ee..94cd0a82b023ae3fe6450a0494eb3de283e18a5d 100644 --- a/connect/circuit_test.go +++ b/connect/circuit_test.go @@ -218,7 +218,7 @@ func TestCircuit_GetHostAtIndex(t *testing.T) { cert, _ := utils.ReadFile(testkeys.GetNodeCertPath()) circuit := NewCircuit(nodeIdList) testID := id.NewIdFromString("test", id.Generic, t) - testHost, _ := NewHost(testID, "test", cert, false, false) + testHost, _ := NewHost(testID, "test", cert, GetDefaultHostParams()) circuit.AddHost(testHost) if !reflect.DeepEqual(circuit.hosts[0], testHost) { diff --git a/connect/host.go b/connect/host.go index 349287238841a94ec9cb71a1969a251e071286f3..c5ae8484869416a4e155c6ae2e2f176181a9e86f 100644 --- a/connect/host.go +++ b/connect/host.go @@ -21,7 +21,6 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" - "math" "net" "sync" "sync/atomic" @@ -60,7 +59,7 @@ type Host struct { transmissionToken *token.Live // Configure the maximum number of connection attempts - maxRetries int + maxRetries uint // GRPC connection object connection *grpc.ClientConn @@ -84,27 +83,20 @@ type Host struct { } // Creates a new Host object -func NewHost(id *id.ID, address string, cert []byte, disableTimeout, - enableAuth bool) (host *Host, err error) { +func NewHost(id *id.ID, address string, cert []byte, params HostParams) (host *Host, err error) { // Initialize the Host object host = &Host{ id: id, certificate: cert, - enableAuth: enableAuth, + enableAuth: params.AuthEnabled, transmissionToken: token.NewLive(), receptionToken: token.NewLive(), + maxRetries: params.MaxRetries, } host.UpdateAddress(address) - // Set the max number of retries for establishing a connection - if disableTimeout { - host.maxRetries = math.MaxInt32 - } else { - host.maxRetries = 100 - } - // Configure the host credentials err = host.setCredentials() return @@ -289,7 +281,8 @@ func (h *Host) connectHelper() (err error) { " credentials: %+v", h.GetAddress(), securityDial) // Attempt to establish a new connection - for numRetries := 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ { + var numRetries uint + for numRetries = 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ { h.disconnect() jww.INFO.Printf("Connecting to %+v. Attempt number %+v of %+v", diff --git a/connect/hostParams.go b/connect/hostParams.go new file mode 100644 index 0000000000000000000000000000000000000000..027dfb48aaf9a7cdc7b6b984f1b430488d93fc51 --- /dev/null +++ b/connect/hostParams.go @@ -0,0 +1,15 @@ +package connect + +// Params object for host creation +type HostParams struct { + MaxRetries uint + AuthEnabled bool +} + +// Get default set of host params +func GetDefaultHostParams() HostParams { + return HostParams{ + MaxRetries: 100, + AuthEnabled: true, + } +} diff --git a/connect/host_test.go b/connect/host_test.go index 9955d7df2f4d6807d7725223c491d469653b8e68..0fc05f00076cc7145c3bc32951c901b75a80b3fe 100644 --- a/connect/host_test.go +++ b/connect/host_test.go @@ -18,7 +18,7 @@ func TestHost_address(t *testing.T) { mgr := newManager() testId := id.NewIdFromString("test", id.Node, t) testAddress := "test" - host, err := mgr.AddHost(testId, testAddress, nil, false, false) + host, err := mgr.AddHost(testId, testAddress, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to add host") return @@ -121,8 +121,7 @@ func TestHost_IsOnline(t *testing.T) { // Create the host host, err := NewHost(id.NewIdFromString("test", id.Gateway, t), addr, nil, - false, - false) + GetDefaultHostParams()) if err != nil { t.Errorf("Unable to create host: %+v", host) return diff --git a/connect/manager.go b/connect/manager.go index 8678d61d9a4ae01d8599e7a8abd883ee0bac83ff..f766f447d3b51535a0f5245f8ba149ca81fb38d7 100644 --- a/connect/manager.go +++ b/connect/manager.go @@ -53,7 +53,7 @@ func (m *Manager) GetHost(hostId *id.ID) (*Host, bool) { // Creates and adds a Host object to the Manager using the given id func (m *Manager) AddHost(hid *id.ID, address string, - cert []byte, disableTimeout, enableAuth bool) (host *Host, err error) { + cert []byte, params HostParams) (host *Host, err error) { m.mux.Lock() defer m.mux.Unlock() @@ -64,7 +64,7 @@ func (m *Manager) AddHost(hid *id.ID, address string, } //create the new host - host, err = NewHost(hid, address, cert, disableTimeout, enableAuth) + host, err = NewHost(hid, address, cert, params) if err != nil { return nil, err } diff --git a/connect/manager_test.go b/connect/manager_test.go index a726a495c316a608e84d10f763ed4e580e2432a0..aacbdb9b7b6fcefbacb611b3d046e487af78a69f 100644 --- a/connect/manager_test.go +++ b/connect/manager_test.go @@ -59,7 +59,7 @@ func TestConnectionManager_Disconnect(t *testing.T) { address := ServerAddress manager := newManager() testId := id.NewIdFromString("testId", id.Node, t) - host, err := manager.AddHost(testId, address, nil, false, false) + host, err := manager.AddHost(testId, address, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call connnect: %+v", err) } @@ -99,7 +99,7 @@ func TestConnectionManager_DisconnectAll(t *testing.T) { testId := id.NewIdFromString("testId", id.Generic, t) testId2 := id.NewIdFromString("TestId2", id.Generic, t) - host, err := manager.AddHost(testId, address, nil, false, false) + host, err := manager.AddHost(testId, address, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call NewHost: %+v", err) } @@ -112,7 +112,7 @@ func TestConnectionManager_DisconnectAll(t *testing.T) { pass++ } - host2, err := manager.AddHost(testId2, address2, nil, false, false) + host2, err := manager.AddHost(testId2, address2, nil, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call NewHost: %+v", err) } @@ -157,7 +157,7 @@ func TestConnectionManager_String(t *testing.T) { certPath := testkeys.GetNodeCertPath() certData := testkeys.LoadFromPath(certPath) testID := id.NewIdFromString("test", id.Node, t) - _, err := manager.AddHost(testID, "test", certData, false, false) + _, err := manager.AddHost(testID, "test", certData, GetDefaultHostParams()) if err != nil { t.Errorf("Unable to call NewHost: %+v", err) }