diff --git a/storage/auth/previousNegotiations.go b/storage/auth/previousNegotiations.go index 6a538f51825098f8602a741ea2617825eab4245a..b8394f553e1137840ed139a1622bc786de225074 100644 --- a/storage/auth/previousNegotiations.go +++ b/storage/auth/previousNegotiations.go @@ -22,6 +22,7 @@ import ( "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/netTime" + "strings" ) const ( @@ -157,11 +158,14 @@ func (s *Store) savePreviousNegotiations() error { return s.kv.Set(negotiationPartnersKey, negotiationPartnersVersion, obj) } -// loadPreviousNegotiations loads the list of previousNegotiations partners from -// storage. -func (s *Store) loadPreviousNegotiations() (map[id.ID]struct{}, error) { +// newOrLoadPreviousNegotiations loads the list of previousNegotiations partners +// from storage. +func (s *Store) newOrLoadPreviousNegotiations() (map[id.ID]struct{}, error) { obj, err := s.kv.Get(negotiationPartnersKey, negotiationPartnersVersion) if err != nil { + if strings.Contains(err.Error(), "object not found") { + return make(map[id.ID]struct{}), nil + } return nil, err } diff --git a/storage/auth/previousNegotiations_test.go b/storage/auth/previousNegotiations_test.go index 94351aac983f8b58439177a927c05bc9df6ee616..8d096fc21d054110d0fd048151a8873d90291b4d 100644 --- a/storage/auth/previousNegotiations_test.go +++ b/storage/auth/previousNegotiations_test.go @@ -229,9 +229,9 @@ func TestStore_deletePreviousNegotiationPartner(t *testing.T) { } // Check previousNegotiations in storage - previousNegotiations, err := s.loadPreviousNegotiations() + previousNegotiations, err := s.newOrLoadPreviousNegotiations() if err != nil { - t.Errorf("loadPreviousNegotiations returned an error (%d): %+v", + t.Errorf("newOrLoadPreviousNegotiations returned an error (%d): %+v", i, err) } _, exists = previousNegotiations[*v.partner] @@ -260,8 +260,8 @@ func TestStore_deletePreviousNegotiationPartner(t *testing.T) { } // Tests that Store.previousNegotiations can be saved and loaded from storage -// via Store.savePreviousNegotiations andStore.loadPreviousNegotiations. -func TestStore_savePreviousNegotiations_loadPreviousNegotiations(t *testing.T) { +// via Store.savePreviousNegotiations andStore.newOrLoadPreviousNegotiations. +func TestStore_savePreviousNegotiations_newOrLoadPreviousNegotiations(t *testing.T) { s := &Store{ kv: versioned.NewKV(make(ekv.Memstore)), previousNegotiations: make(map[id.ID]struct{}), @@ -280,9 +280,9 @@ func TestStore_savePreviousNegotiations_loadPreviousNegotiations(t *testing.T) { i, err) } - s.previousNegotiations, err = s.loadPreviousNegotiations() + s.previousNegotiations, err = s.newOrLoadPreviousNegotiations() if err != nil { - t.Errorf("loadPreviousNegotiations returned an error (%d): %+v", + t.Errorf("newOrLoadPreviousNegotiations returned an error (%d): %+v", i, err) } @@ -293,6 +293,26 @@ func TestStore_savePreviousNegotiations_loadPreviousNegotiations(t *testing.T) { } } +// Tests that Store.newOrLoadPreviousNegotiations returns blank negotiations if +// they do not exist. +func TestStore_newOrLoadPreviousNegotiations_noNegotiations(t *testing.T) { + s := &Store{ + kv: versioned.NewKV(make(ekv.Memstore)), + previousNegotiations: make(map[id.ID]struct{}), + } + expected := make(map[id.ID]struct{}) + + blankNegotations, err := s.newOrLoadPreviousNegotiations() + if err != nil { + t.Errorf("newOrLoadPreviousNegotiations returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, blankNegotations) { + t.Errorf("Loaded previousNegotiations does not match expected."+ + "\nexpected: %v\nreceived: %v", expected, blankNegotations) + } +} + // Tests that a list of partner IDs that is marshalled and unmarshalled via // marshalPreviousNegotiations and unmarshalPreviousNegotiations matches the // original list diff --git a/storage/auth/store.go b/storage/auth/store.go index 83a306d9bbb50f43cf15b4091c8b275feff414d1..069b828dbd663aa7e2fb658612041d70d5888b81 100644 --- a/storage/auth/store.go +++ b/storage/auth/store.go @@ -155,7 +155,7 @@ func LoadStore(kv *versioned.KV, grp *cyclic.Group, privKeys []*cyclic.Int) (*St } // Load previous negotiations from storage - s.previousNegotiations, err = s.loadPreviousNegotiations() + s.previousNegotiations, err = s.newOrLoadPreviousNegotiations() if err != nil { return nil, errors.Errorf("failed to load list of previouse "+ "negotation partner IDs: %+v", err)