Skip to content
Snippets Groups Projects
Commit 72b66ed2 authored by Benjamin Wenger's avatar Benjamin Wenger
Browse files

fixed improper access of the address field of host

parent 7cf66475
Branches
Tags
No related merge requests found
......@@ -24,6 +24,7 @@ import (
"math"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)
......@@ -45,7 +46,7 @@ type Host struct {
id *id.ID
// address:Port being connected to
address string
addressAtomic atomic.Value
// PEM-format TLS Certificate
certificate []byte
......@@ -89,13 +90,14 @@ func NewHost(id *id.ID, address string, cert []byte, disableTimeout,
// Initialize the Host object
host = &Host{
id: id,
address: address,
certificate: cert,
enableAuth: enableAuth,
transmissionToken: token.NewLive(),
receptionToken: token.NewLive(),
}
host.UpdateAddress(address)
// Set the max number of retries for establishing a connection
if disableTimeout {
host.maxRetries = math.MaxInt32
......@@ -155,12 +157,16 @@ func (h *Host) GetId() *id.ID {
// GetAddress returns the address of the host.
func (h *Host) GetAddress() string {
return h.address
a := h.addressAtomic.Load()
if a == nil {
return ""
}
return a.(string)
}
// UpdateAddress updates the address of the host
func (h *Host) UpdateAddress(address string) {
h.address = address
h.addressAtomic.Store(address)
}
// Disconnect closes a the Host connection under the write lock
......@@ -256,7 +262,7 @@ func (h *Host) disconnect() {
err := h.connection.Close()
if err != nil {
jww.ERROR.Printf("Unable to close connection to %s: %+v",
h.address, errors.New(err.Error()))
h.GetAddress(), errors.New(err.Error()))
} else {
h.connection = nil
}
......@@ -275,19 +281,19 @@ func (h *Host) connectHelper() (err error) {
securityDial = grpc.WithTransportCredentials(h.credentials)
} else {
// Create the gRPC client without TLS
jww.WARN.Printf("Connecting to %v without TLS!", h.address)
jww.WARN.Printf("Connecting to %v without TLS!", h.GetAddress())
securityDial = grpc.WithInsecure()
}
jww.DEBUG.Printf("Attempting to establish connection to %s using"+
" credentials: %+v", h.address, securityDial)
" credentials: %+v", h.GetAddress(), securityDial)
// Attempt to establish a new connection
for numRetries := 0; numRetries < h.maxRetries && !h.isAlive(); numRetries++ {
h.disconnect()
jww.INFO.Printf("Connecting to %+v. Attempt number %+v of %+v",
h.address, numRetries, h.maxRetries)
h.GetAddress(), numRetries, h.maxRetries)
// If timeout is enabled, the max wait time becomes
// ~14 seconds (with maxRetries=100)
......@@ -298,13 +304,13 @@ func (h *Host) connectHelper() (err error) {
ctx, cancel := ConnectionContext(time.Duration(backoffTime))
// Create the connection
h.connection, err = grpc.DialContext(ctx, h.address,
h.connection, err = grpc.DialContext(ctx, h.GetAddress(),
securityDial,
grpc.WithBlock(),
grpc.WithKeepaliveParams(KaClientOpts))
if err != nil {
jww.ERROR.Printf("Attempt number %+v to connect to %s failed: %+v\n",
numRetries, h.address, errors.New(err.Error()))
numRetries, h.GetAddress(), errors.New(err.Error()))
}
cancel()
}
......@@ -313,11 +319,12 @@ func (h *Host) connectHelper() (err error) {
if !h.isAlive() {
h.disconnect()
return errors.New(fmt.Sprintf(
"Last try to connect to %s failed. Giving up", h.address))
"Last try to connect to %s failed. Giving up",
h.GetAddress()))
}
// Add the successful connection to the Manager
jww.INFO.Printf("Successfully connected to %v", h.address)
jww.INFO.Printf("Successfully connected to %v", h.GetAddress())
return
}
......@@ -361,7 +368,7 @@ func (h *Host) setCredentials() error {
func (h *Host) String() string {
h.sendMux.RLock()
defer h.sendMux.RUnlock()
addr := h.address
addr := h.GetAddress()
actualConnection := h.connection
creds := h.credentials
......
......@@ -24,7 +24,7 @@ func TestHost_address(t *testing.T) {
return
}
if host.address != testAddress {
if host.GetAddress() != testAddress {
t.Errorf("Expected addresses to match")
}
}
......@@ -33,7 +33,6 @@ func TestHost_GetCertificate(t *testing.T) {
testCert := []byte("TEST")
host := Host{
address: "",
certificate: testCert,
maxRetries: 0,
connection: nil,
......@@ -65,7 +64,8 @@ func TestHost_GetId(t *testing.T) {
func TestHost_GetAddress(t *testing.T) {
// Test values
testAddress := "192.167.1.1:8080"
testHost := Host{address: testAddress}
testHost := Host{}
testHost.UpdateAddress(testAddress)
// Test function
if testHost.GetAddress() != testAddress {
......@@ -78,7 +78,8 @@ func TestHost_GetAddress(t *testing.T) {
func TestHost_UpdateAddress(t *testing.T) {
testAddress := "192.167.1.1:8080"
testUpdatedAddress := "192.167.1.1:8080"
testHost := Host{address: testAddress}
testHost := Host{}
testHost.UpdateAddress(testAddress)
// Test function
if testHost.GetAddress() != testAddress {
......
......@@ -81,6 +81,8 @@ func (m *Manager) RemoveHost(hid *id.ID) {
// Closes all client connections and removes them from Manager
func (m *Manager) DisconnectAll() {
m.mux.RLock()
defer m.mux.RUnlock()
for _, host := range m.connections {
host.Disconnect()
}
......
......@@ -42,7 +42,6 @@ func TestMain(m *testing.M) {
func TestSetCredentials_InvalidCert(t *testing.T) {
host := &Host{
address: "",
certificate: []byte("test"),
}
err := host.setCredentials()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment