diff --git a/network/gateway/hostPool.go b/network/gateway/hostPool.go index 93e033d8695037da6ae05c1a47ee9da56d9f5023..6624066ff872a71908c4a72235bc0078593bb0d2 100644 --- a/network/gateway/hostPool.go +++ b/network/gateway/hostPool.go @@ -65,7 +65,6 @@ type HostPool struct { poolParams PoolParams rng *fastRNG.StreamGenerator - storage *storage.Session manager HostManager addGatewayChan chan<- network.NodeGateway @@ -128,7 +127,6 @@ func newHostPool(poolParams PoolParams, rng *fastRNG.StreamGenerator, poolParams: poolParams, ndf: netDef.DeepCopy(), rng: rng, - storage: storage, addGatewayChan: addGateway, kv: storage.GetKV().Prefix(hostListPrefix), diff --git a/network/gateway/hostpool_test.go b/network/gateway/hostpool_test.go index bda350f1e655adbd8cc427ab2145361ccc02f942..91885a6cb7b925df935375cd88d852970fdab406 100644 --- a/network/gateway/hostpool_test.go +++ b/network/gateway/hostpool_test.go @@ -78,7 +78,7 @@ func TestNewHostPool_HostListStore(t *testing.T) { id.NewIdFromString("testID2", id.Gateway, t), id.NewIdFromString("testID3", id.Gateway, t), } - err := testStorage.HostList().Store(addedIDs) + err := saveHostList(testStorage.GetKV().Prefix(hostListPrefix), addedIDs) if err != nil { t.Fatalf("Failed to store host list: %+v", err) } @@ -94,7 +94,7 @@ func TestNewHostPool_HostListStore(t *testing.T) { } // Check that the host list was saved to storage - hostList, err := hp.storage.HostList().Get() + hostList, err := getHostList(hp.kv) if err != nil { t.Errorf("Failed to get host list: %+v", err) } @@ -180,13 +180,15 @@ func TestHostPool_ReplaceHost(t *testing.T) { testNdf := getTestNdf(t) newIndex := uint32(20) + testStorage := storage.InitTestingSession(t) + // Construct a manager (bypass business logic in constructor) hostPool := &HostPool{ manager: manager, hostList: make([]*connect.Host, newIndex+1), hostMap: make(map[id.ID]uint32), ndf: testNdf, - storage: storage.InitTestingSession(t), + kv: testStorage.GetKV().Prefix(hostListPrefix), } /* "Replace" a host with no entry */ @@ -281,7 +283,7 @@ func TestHostPool_ReplaceHost(t *testing.T) { } // Check that the host list was saved to storage - hostList, err := hostPool.storage.HostList().Get() + hostList, err := getHostList(hostPool.kv) if err != nil { t.Errorf("Failed to get host list: %+v", err) } @@ -483,13 +485,15 @@ func TestHostPool_UpdateNdf(t *testing.T) { testNdf := getTestNdf(t) newIndex := uint32(20) + testStorage := storage.InitTestingSession(t) + // Construct a manager (bypass business logic in constructor) hostPool := &HostPool{ manager: manager, hostList: make([]*connect.Host, newIndex+1), hostMap: make(map[id.ID]uint32), ndf: testNdf, - storage: storage.InitTestingSession(t), + kv: testStorage.GetKV().Prefix(hostListPrefix), poolParams: DefaultPoolParams(), filter: func(m map[id.ID]int, _ *ndf.NetworkDefinition) map[id.ID]int { return m @@ -850,6 +854,8 @@ func TestHostPool_AddGateway(t *testing.T) { params := DefaultPoolParams() params.MaxPoolSize = uint32(len(testNdf.Gateways)) + testStorage := storage.InitTestingSession(t) + // Construct a manager (bypass business logic in constructor) hostPool := &HostPool{ manager: manager, @@ -858,7 +864,7 @@ func TestHostPool_AddGateway(t *testing.T) { ndf: testNdf, poolParams: params, addGatewayChan: make(chan network.NodeGateway), - storage: storage.InitTestingSession(t), + kv: testStorage.GetKV().Prefix(hostListPrefix), } ndfIndex := 0 @@ -892,7 +898,6 @@ func TestHostPool_RemoveGateway(t *testing.T) { ndf: testNdf, poolParams: params, addGatewayChan: make(chan network.NodeGateway), - storage: storage.InitTestingSession(t), rng: fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), } diff --git a/network/gateway/storeHostList_test.go b/network/gateway/storeHostList_test.go index d56f60bbe5d2c204857d38e8900466e5b8a4f935..e8ac8088d81b83afe9caf8a75ee3971e7fa36547 100644 --- a/network/gateway/storeHostList_test.go +++ b/network/gateway/storeHostList_test.go @@ -16,31 +16,17 @@ package gateway import ( "fmt" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/client/storage" "gitlab.com/xx_network/primitives/id" "reflect" "strings" "testing" ) -// Unit test of NewStore. -func TestNewStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expected := &Store{kv: kv.Prefix(hostListPrefix)} - - s := NewStore(kv) - - if !reflect.DeepEqual(expected, s) { - t.Errorf("NewStore did not return the expected object."+ - "\nexpected: %+v\nreceived: %+v", expected, s) - } -} - // Tests that a host list saved by Store.Store matches the host list returned // by Store.Get. func TestStore_Store_Get(t *testing.T) { - s := NewStore(versioned.NewKV(make(ekv.Memstore))) + // Init list to store list := []*id.ID{ id.NewIdFromString("histID_1", id.Node, t), nil, @@ -48,29 +34,45 @@ func TestStore_Store_Get(t *testing.T) { id.NewIdFromString("histID_3", id.Node, t), } - err := s.Store(list) + // Init storage + testStorage := storage.InitTestingSession(t) + storeKv := testStorage.GetKV().Prefix(hostListPrefix) + + // Save into storage + err := saveHostList(storeKv, list) if err != nil { t.Errorf("Store returned an error: %+v", err) } - newList, err := s.Get() + // Retrieve stored data from storage + newList, err := getHostList(storeKv) if err != nil { t.Errorf("get returned an error: %+v", err) } + // Ensure retrieved data from storage matches + // what was stored if !reflect.DeepEqual(list, newList) { t.Errorf("Failed to save and load host list."+ "\nexpected: %+v\nreceived: %+v", list, newList) } } -// Error path: tests that Store.Get returns an error if not host list is +// Error path: tests that Store.Get returns an error if no host list is // saved in storage. func TestStore_Get_StorageError(t *testing.T) { - s := NewStore(versioned.NewKV(make(ekv.Memstore))) + + // Init storage + testStorage := storage.InitTestingSession(t) + storeKv := testStorage.GetKV().Prefix(hostListPrefix) + + // Construct expected error expectedErr := strings.SplitN(getStorageErr, "%", 2)[0] - _, err := s.Get() + // Attempt to pull from an empty store + _, err := getHostList(storeKv) + + // Check that the expected error is received if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("get failed to return the expected error."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) @@ -80,6 +82,7 @@ func TestStore_Get_StorageError(t *testing.T) { // Tests that a list of IDs that is marshalled using marshalHostList and // unmarshalled using unmarshalHostList matches the original. func Test_marshalHostList_unmarshalHostList(t *testing.T) { + // Construct example list list := []*id.ID{ id.NewIdFromString("histID_1", id.Node, t), nil, @@ -87,13 +90,16 @@ func Test_marshalHostList_unmarshalHostList(t *testing.T) { id.NewIdFromString("histID_3", id.Node, t), } + // Marshal list data := marshalHostList(list) + // Unmarshal marsalled data into new object newList, err := unmarshalHostList(data) if err != nil { t.Errorf("unmarshalHostList produced an error: %+v", err) } + // Ensure original data and unmarshalled data is consistent if !reflect.DeepEqual(list, newList) { t.Errorf("Failed to marshal and unmarshal ID list."+ "\nexpected: %+v\nreceived: %+v", list, newList)