diff --git a/network/rounds/utils_test.go b/network/rounds/utils_test.go index f15e4887ea5351d8385c52f1da65b7dc0a88a6c9..7e538f014c0fe59d800e1f4380cec57660ccd82b 100644 --- a/network/rounds/utils_test.go +++ b/network/rounds/utils_test.go @@ -42,10 +42,18 @@ const ErrorGateway = "Error" type mockMessageRetrievalComms struct { testingSignature *testing.T + hosts map[string]*connect.Host } func (mmrc *mockMessageRetrievalComms) AddHost(hid *id.ID, address string, cert []byte, params connect.HostParams) (host *connect.Host, err error) { - return nil, err + host, err = connect.NewHost(hid, address, cert, params) + if err != nil { + return nil, err + } + + mmrc.hosts[hid.String()] = host + + return host, err } func (mmrc *mockMessageRetrievalComms) RemoveHost(hid *id.ID) { @@ -53,11 +61,9 @@ func (mmrc *mockMessageRetrievalComms) RemoveHost(hid *id.ID) { } func (mmrc *mockMessageRetrievalComms) GetHost(hostId *id.ID) (*connect.Host, bool) { - h, _ := connect.NewHost(hostId, "0.0.0.0", []byte(""), connect.HostParams{ - MaxRetries: 0, - AuthEnabled: false, - }) - return h, true + h, ok := mmrc.hosts[hostId.String()] + return h, ok + } // Mock comm which returns differently based on the host ID