Skip to content
Snippets Groups Projects
Commit 2d84c6e1 authored by Jonah Husson's avatar Jonah Husson
Browse files

add host params object for host creation

parent 05fdb4b1
Branches
Tags
No related merge requests found
......@@ -73,7 +73,7 @@ func TestProtoComms_AuthenticatedReceiver(t *testing.T) {
copy(tkn[:], expectedVal)
// Add host
_, err := pc.AddHost(testID, "", nil, false, true)
_, err := pc.AddHost(testID, "", nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Failed to add host: %+v", err)
}
......@@ -123,7 +123,7 @@ func TestProtoComms_AuthenticatedReceiver_BadId(t *testing.T) {
copy(tkn[:], expectedVal)
// Add host
_, err := pc.AddHost(testID, "", nil, false, true)
_, err := pc.AddHost(testID, "", nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Failed to add host: %+v", err)
}
......@@ -196,7 +196,7 @@ func TestProtoComms_PackAuthenticatedMessage(t *testing.T) {
testId := id.NewIdFromString("test", id.Node, t)
host, err := NewHost(testId, "test", nil, false, true)
host, err := NewHost(testId, "test", nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
......@@ -241,7 +241,7 @@ func TestProtoComms_ValidateToken(t *testing.T) {
copy(tkn[:], tokenBytes)
pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath())
host, err := comm.AddHost(testId, "test", pub, false, true)
host, err := comm.AddHost(testId, "test", pub, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
......@@ -295,7 +295,7 @@ func TestProtoComms_ValidateToken_BadId(t *testing.T) {
copy(tkn[:], tokenBytes)
pub := testkeys.LoadFromPath(testkeys.GetNodeCertPath())
host, err := comm.AddHost(testId, "test", pub, false, true)
host, err := comm.AddHost(testId, "test", pub, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to create host: %+v", err)
}
......
......@@ -218,7 +218,7 @@ func TestCircuit_GetHostAtIndex(t *testing.T) {
cert, _ := utils.ReadFile(testkeys.GetNodeCertPath())
circuit := NewCircuit(nodeIdList)
testID := id.NewIdFromString("test", id.Generic, t)
testHost, _ := NewHost(testID, "test", cert, false, false)
testHost, _ := NewHost(testID, "test", cert, GetDefaultHostParams())
circuit.AddHost(testHost)
if !reflect.DeepEqual(circuit.hosts[0], testHost) {
......
......@@ -21,7 +21,6 @@ import (
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"math"
"net"
"sync"
"sync/atomic"
......@@ -60,7 +59,7 @@ type Host struct {
transmissionToken *token.Live
// Configure the maximum number of connection attempts
maxRetries int
maxRetries uint
// GRPC connection object
connection *grpc.ClientConn
......@@ -84,27 +83,20 @@ type Host struct {
}
// Creates a new Host object
func NewHost(id *id.ID, address string, cert []byte, disableTimeout,
enableAuth bool) (host *Host, err error) {
func NewHost(id *id.ID, address string, cert []byte, params HostParams) (host *Host, err error) {
// Initialize the Host object
host = &Host{
id: id,
certificate: cert,
enableAuth: enableAuth,
enableAuth: params.AuthEnabled,
transmissionToken: token.NewLive(),
receptionToken: token.NewLive(),
maxRetries: params.MaxRetries,
}
host.UpdateAddress(address)
// Set the max number of retries for establishing a connection
if disableTimeout {
host.maxRetries = math.MaxInt32
} else {
host.maxRetries = 100
}
// Configure the host credentials
err = host.setCredentials()
return
......@@ -289,7 +281,8 @@ func (h *Host) connectHelper() (err error) {
" credentials: %+v", h.GetAddress(), securityDial)
// Attempt to establish a new connection
for numRetries := 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ {
var numRetries uint
for numRetries = 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ {
h.disconnect()
jww.INFO.Printf("Connecting to %+v. Attempt number %+v of %+v",
......
package connect
// Params object for host creation
type HostParams struct {
MaxRetries uint
AuthEnabled bool
}
// Get default set of host params
func GetDefaultHostParams() HostParams {
return HostParams{
MaxRetries: 100,
AuthEnabled: true,
}
}
......@@ -18,7 +18,7 @@ func TestHost_address(t *testing.T) {
mgr := newManager()
testId := id.NewIdFromString("test", id.Node, t)
testAddress := "test"
host, err := mgr.AddHost(testId, testAddress, nil, false, false)
host, err := mgr.AddHost(testId, testAddress, nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to add host")
return
......@@ -121,8 +121,7 @@ func TestHost_IsOnline(t *testing.T) {
// Create the host
host, err := NewHost(id.NewIdFromString("test", id.Gateway, t), addr, nil,
false,
false)
GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to create host: %+v", host)
return
......
......@@ -53,7 +53,7 @@ func (m *Manager) GetHost(hostId *id.ID) (*Host, bool) {
// Creates and adds a Host object to the Manager using the given id
func (m *Manager) AddHost(hid *id.ID, address string,
cert []byte, disableTimeout, enableAuth bool) (host *Host, err error) {
cert []byte, params HostParams) (host *Host, err error) {
m.mux.Lock()
defer m.mux.Unlock()
......@@ -64,7 +64,7 @@ func (m *Manager) AddHost(hid *id.ID, address string,
}
//create the new host
host, err = NewHost(hid, address, cert, disableTimeout, enableAuth)
host, err = NewHost(hid, address, cert, params)
if err != nil {
return nil, err
}
......
......@@ -59,7 +59,7 @@ func TestConnectionManager_Disconnect(t *testing.T) {
address := ServerAddress
manager := newManager()
testId := id.NewIdFromString("testId", id.Node, t)
host, err := manager.AddHost(testId, address, nil, false, false)
host, err := manager.AddHost(testId, address, nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to call connnect: %+v", err)
}
......@@ -99,7 +99,7 @@ func TestConnectionManager_DisconnectAll(t *testing.T) {
testId := id.NewIdFromString("testId", id.Generic, t)
testId2 := id.NewIdFromString("TestId2", id.Generic, t)
host, err := manager.AddHost(testId, address, nil, false, false)
host, err := manager.AddHost(testId, address, nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to call NewHost: %+v", err)
}
......@@ -112,7 +112,7 @@ func TestConnectionManager_DisconnectAll(t *testing.T) {
pass++
}
host2, err := manager.AddHost(testId2, address2, nil, false, false)
host2, err := manager.AddHost(testId2, address2, nil, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to call NewHost: %+v", err)
}
......@@ -157,7 +157,7 @@ func TestConnectionManager_String(t *testing.T) {
certPath := testkeys.GetNodeCertPath()
certData := testkeys.LoadFromPath(certPath)
testID := id.NewIdFromString("test", id.Node, t)
_, err := manager.AddHost(testID, "test", certData, false, false)
_, err := manager.AddHost(testID, "test", certData, GetDefaultHostParams())
if err != nil {
t.Errorf("Unable to call NewHost: %+v", err)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment