diff --git a/api/client.go b/api/client.go index 2c28d84d33b193a3236f2332a9c69a16f4e3e5d9..0aee60d576cae5301f0e7d5e48f0d91e68f50a21 100644 --- a/api/client.go +++ b/api/client.go @@ -304,7 +304,7 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie } // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network) + c.auth = auth.NewManager(c.switchboard, c.storage, c.network, parameters.ReplayRequests) // Add all processes to the followerServices err = c.registerFollower() @@ -363,7 +363,7 @@ func LoginWithNewBaseNDF_UNSAFE(storageDir string, password []byte, } // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network) + c.auth = auth.NewManager(c.switchboard, c.storage, c.network, parameters.ReplayRequests) err = c.registerFollower() if err != nil { @@ -420,7 +420,7 @@ func LoginWithProtoClient(storageDir string, password []byte, protoClientJSON [] } // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network) + c.auth = auth.NewManager(c.switchboard, c.storage, c.network, parameters.ReplayRequests) err = c.registerFollower() if err != nil { diff --git a/auth/callback.go b/auth/callback.go index c0acff7b66784d4834cdfddc78c297b5496ff84b..7d3f262165df1f8cb5eb1f30a5c3283a99da9150 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -165,7 +165,7 @@ func (m *Manager) handleRequest(cmixMsg format.Message, return } else { //check if the relationship already exists, - rType, _, _, err := m.storage.Auth().GetRequest(partnerID) + rType, _, c, err := m.storage.Auth().GetRequest(partnerID) if err != nil && !strings.Contains(err.Error(), auth.NoRequest) { // if another error is received, print it and exit em := fmt.Sprintf("Received new Auth request for %s, "+ @@ -183,6 +183,15 @@ func (m *Manager) handleRequest(cmixMsg format.Message, "is a duplicate", partnerID) jww.WARN.Print(em) events.Report(5, "Auth", "DuplicateRequest", em) + // if the caller of the API wants requests replayed, + // replay the duplicate request + if m.replayRequests{ + cbList := m.requestCallbacks.Get(c.ID) + for _, cb := range cbList { + rcb := cb.(interfaces.RequestCallback) + go rcb(c) + } + } return // if we sent a request, then automatically confirm // then exit, nothing else needed @@ -296,7 +305,7 @@ func (m *Manager) handleRequest(cmixMsg format.Message, cbList := m.requestCallbacks.Get(c.ID) for _, cb := range cbList { rcb := cb.(interfaces.RequestCallback) - go rcb(c, "") + go rcb(c) } return } diff --git a/auth/manager.go b/auth/manager.go index b5803d69feb7ae67073e351b1c7fd24c3ac5b597..20b341c3891212edf308ec9c7afcf43a5f1588f5 100644 --- a/auth/manager.go +++ b/auth/manager.go @@ -23,16 +23,19 @@ type Manager struct { storage *storage.Session net interfaces.NetworkManager + + replayRequests bool } func NewManager(sw interfaces.Switchboard, storage *storage.Session, - net interfaces.NetworkManager) *Manager { + net interfaces.NetworkManager, replayRequests bool) *Manager { m := &Manager{ requestCallbacks: newCallbackMap(), confirmCallbacks: newCallbackMap(), rawMessages: make(chan message.Receive, 1000), storage: storage, net: net, + replayRequests: replayRequests, } sw.RegisterChannel("Auth", switchboard.AnyUser(), message.Raw, m.rawMessages) @@ -89,3 +92,17 @@ func (m *Manager) AddSpecificConfirmCallback(id *id.ID, cb interfaces.ConfirmCal func (m *Manager) RemoveSpecificConfirmCallback(id *id.ID) { m.confirmCallbacks.RemoveSpecific(id) } + +// ReplayRequests will iterate through all pending contact requests and resend them +// to the desired contact. +func (m *Manager) ReplayRequests() { + cList := m.storage.Auth().GetAllReceived() + for i := range cList { + c := cList[i] + cbList := m.requestCallbacks.Get(c.ID) + for _, cb := range cbList { + rcb := cb.(interfaces.RequestCallback) + go rcb(c) + } + } +} diff --git a/auth/manager_test.go b/auth/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7332a2db6d45ba7e2c0d9ddbfb18bc02d06ab286 --- /dev/null +++ b/auth/manager_test.go @@ -0,0 +1,105 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package auth + +import ( + "github.com/cloudflare/circl/dh/sidh" + "gitlab.com/elixxir/client/interfaces" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/auth" + util "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/contact" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/crypto/large" + "gitlab.com/xx_network/primitives/id" + "io" + "math/rand" + "testing" + "time" +) + +func TestManager_ReplayRequests(t *testing.T) { + s := storage.InitTestingSession(t) + numReceived := 10 + + // Construct barebones manager + m := Manager{ + requestCallbacks: newCallbackMap(), + storage: s, + replayRequests: true, + } + + ch := make(chan struct{}, numReceived) + + // Add multiple received contact requests + for i := 0; i < numReceived; i++ { + c := contact.Contact{ID: id.NewIdFromUInt(rand.Uint64(), id.User, t)} + rng := csprng.NewSystemRNG() + _, sidhPubKey := genSidhAKeys(rng) + + if err := m.storage.Auth().AddReceived(c, sidhPubKey); err != nil { + t.Fatalf("AddReceived() returned an error: %+v", err) + } + + m.requestCallbacks.AddSpecific(c.ID, interfaces.RequestCallback(func(c contact.Contact) { + ch <- struct{}{} + })) + + } + + m.ReplayRequests() + + timeout := time.NewTimer(1 * time.Second) + numChannelReceived := 0 +loop: + for { + select { + case <-ch: + numChannelReceived++ + case <-timeout.C: + break loop + } + } + + if numReceived != numChannelReceived { + t.Errorf("Unexpected number of callbacks called"+ + "\nExpected: %d"+ + "\nReceived: %d", numChannelReceived, numReceived) + } +} + +func makeTestStore(t *testing.T) (*auth.Store, *versioned.KV, []*cyclic.Int) { + kv := versioned.NewKV(make(ekv.Memstore)) + grp := cyclic.NewGroup(large.NewInt(173), large.NewInt(0)) + privKeys := make([]*cyclic.Int, 10) + for i := range privKeys { + privKeys[i] = grp.NewInt(rand.Int63n(170) + 1) + } + + store, err := auth.NewStore(kv, grp, privKeys) + if err != nil { + t.Fatalf("Failed to create new Store: %+v", err) + } + + return store, kv, privKeys +} + +func genSidhAKeys(rng io.Reader) (*sidh.PrivateKey, *sidh.PublicKey) { + sidHPrivKeyA := util.NewSIDHPrivateKey(sidh.KeyVariantSidhA) + sidHPubKeyA := util.NewSIDHPublicKey(sidh.KeyVariantSidhA) + + if err := sidHPrivKeyA.Generate(rng); err != nil { + panic("failure to generate SidH A private key") + } + sidHPrivKeyA.GeneratePublicKey(sidHPubKeyA) + + return sidHPrivKeyA, sidHPubKeyA +} diff --git a/bindings/authenticatedChannels.go b/bindings/authenticatedChannels.go index 67d2a885098b34671d75014189b682ed75a6a32c..36888a604177ae36b37d80f97ad61b01700cb51a 100644 --- a/bindings/authenticatedChannels.go +++ b/bindings/authenticatedChannels.go @@ -63,7 +63,7 @@ func (c *Client) RequestAuthenticatedChannel(recipientMarshaled, func (c *Client) RegisterAuthCallbacks(request AuthRequestCallback, confirm AuthConfirmCallback) { - requestFunc := func(requestor contact.Contact, message string) { + requestFunc := func(requestor contact.Contact) { requestorBind := &Contact{c: &requestor} request.Callback(requestorBind) } @@ -136,3 +136,8 @@ func (c *Client) GetRelationshipFingerprint(partnerID []byte) (string, error) { return c.api.GetRelationshipFingerprint(partner) } + +// ReplayRequests Resends all pending requests over the normal callbacks +func (c *Client) ReplayRequests() () { + c.api.GetAuthRegistrar().ReplayRequests() +} \ No newline at end of file diff --git a/cmd/root.go b/cmd/root.go index 328f814cd5f67dd93b13075e93d62795b3684861..7f38db02a83e6cbcd869637e9c5a6fbec920cd14 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -473,7 +473,7 @@ func initClientCallbacks(client *api.Client) (chan *id.ID, }) if viper.GetBool("unsafe-channel-creation") { authMgr.AddGeneralRequestCallback(func( - requestor contact.Contact, message string) { + requestor contact.Contact) { jww.INFO.Printf("Channel Request: %s", requestor.ID) _, err := client.ConfirmAuthenticatedChannel( @@ -681,13 +681,11 @@ func deleteChannel(client *api.Client, partnerId *id.ID) { } } -func printChanRequest(requestor contact.Contact, message string) { +func printChanRequest(requestor contact.Contact) { msg := fmt.Sprintf("Authentication channel request from: %s\n", requestor.ID) jww.INFO.Printf(msg) fmt.Printf(msg) - msg = fmt.Sprintf("Authentication channel request message: %s\n", message) - jww.INFO.Printf(msg) // fmt.Printf(msg) } diff --git a/cmd/single.go b/cmd/single.go index cc0cadadd56ea05f1656a7ff7ed3858213ee3169..cd6cc0a8815caa9969f02123274a0023512988e9 100644 --- a/cmd/single.go +++ b/cmd/single.go @@ -53,7 +53,7 @@ var singleCmd = &cobra.Command{ // If unsafe channels, then add auto-acceptor if viper.GetBool("unsafe-channel-creation") { authMgr.AddGeneralRequestCallback(func( - requester contact.Contact, message string) { + requester contact.Contact) { jww.INFO.Printf("Got request: %s", requester.ID) _, err := client.ConfirmAuthenticatedChannel(requester) if err != nil { diff --git a/cmd/ud.go b/cmd/ud.go index 767bc12eab4e84c82c93efc2a80e6b4613bc0a7a..34fff1c098f490195a952ac302cc7256ccf7b577 100644 --- a/cmd/ud.go +++ b/cmd/ud.go @@ -53,7 +53,7 @@ var udCmd = &cobra.Command{ // If unsafe channels, add auto-acceptor if viper.GetBool("unsafe-channel-creation") { authMgr.AddGeneralRequestCallback(func( - requester contact.Contact, message string) { + requester contact.Contact) { jww.INFO.Printf("Got Request: %s", requester.ID) _, err := client.ConfirmAuthenticatedChannel(requester) if err != nil { diff --git a/interfaces/auth.go b/interfaces/auth.go index bb3cae74d0a37f0fda2c34e8ce0623af6f8a789f..4ce22fba789b45a5616a640f0409a054f6a05ecb 100644 --- a/interfaces/auth.go +++ b/interfaces/auth.go @@ -12,7 +12,7 @@ import ( "gitlab.com/xx_network/primitives/id" ) -type RequestCallback func(requestor contact.Contact, message string) +type RequestCallback func(requestor contact.Contact) type ConfirmCallback func(partner contact.Contact) type Auth interface { @@ -42,4 +42,6 @@ type Auth interface { AddSpecificConfirmCallback(id *id.ID, cb ConfirmCallback) // Removes a specific callback to be used on auth confirm. RemoveSpecificConfirmCallback(id *id.ID) + //Replays all pending received requests over tha callbacks + ReplayRequests() } diff --git a/interfaces/params/network.go b/interfaces/params/network.go index 67d083802e516de5d0feca95a299416489561a8a..259cbef58fdc572195b0e1ecea284e427afad675 100644 --- a/interfaces/params/network.go +++ b/interfaces/params/network.go @@ -34,6 +34,8 @@ type Network struct { // Determines if the state of every round processed is tracked in ram. // This is very memory intensive and is primarily used for debugging VerboseRoundTracking bool + // Resends auth requests up the stack if received multiple times + ReplayRequests bool Rounds Messages @@ -54,6 +56,7 @@ func GetDefaultNetwork() Network { FastPolling: true, BlacklistedNodes: make([]string, 0), VerboseRoundTracking: false, + ReplayRequests: true, } n.Rounds = GetDefaultRounds() n.Messages = GetDefaultMessage() diff --git a/storage/auth/store.go b/storage/auth/store.go index addd3d138499ae09cfc5e9dc5d221d79d95d060a..2bb90faf774bb777b28b33799934d89f9a966fe9 100644 --- a/storage/auth/store.go +++ b/storage/auth/store.go @@ -261,6 +261,20 @@ func (s *Store) AddReceived(c contact.Contact, key *sidh.PublicKey) error { return nil } +// GetAllReceived returns all pending received contact requests from storage. +func (s *Store) GetAllReceived() []contact.Contact { + s.mux.RLock() + defer s.mux.RUnlock() + cList := make([]contact.Contact, 0, len(s.requests)) + for key := range s.requests { + r := s.requests[key] + if r.rt == Receive { + cList = append(cList, *r.receive) + } + } + return cList +} + // GetFingerprint can return either a private key or a sentRequest if the // fingerprint is found. If it returns a sentRequest, then it takes the lock to // ensure there is only one operator at a time. The user of the API must release diff --git a/storage/auth/store_test.go b/storage/auth/store_test.go index b1c4bb92b95c98eadc23db6e66d1ab5e353efc51..6568cd75540f7633b711d64380e106435b50d359 100644 --- a/storage/auth/store_test.go +++ b/storage/auth/store_test.go @@ -8,6 +8,7 @@ package auth import ( + "bytes" "github.com/cloudflare/circl/dh/sidh" sidhinterface "gitlab.com/elixxir/client/interfaces/sidh" util "gitlab.com/elixxir/client/storage/utility" @@ -23,6 +24,7 @@ import ( "io" "math/rand" "reflect" + "sort" "sync" "testing" ) @@ -764,6 +766,145 @@ func TestStore_Delete_RequestNotInMap(t *testing.T) { } } +// Unit test of Store.GetAllReceived. +func TestStore_GetAllReceived(t *testing.T) { + s, _, _ := makeTestStore(t) + numReceived := 10 + + expectContactList := make([]contact.Contact, 0, numReceived) + // Add multiple received contact requests + for i := 0; i < numReceived; i++ { + c := contact.Contact{ID: id.NewIdFromUInt(rand.Uint64(), id.User, t)} + rng := csprng.NewSystemRNG() + _, sidhPubKey := genSidhAKeys(rng) + + if err := s.AddReceived(c, sidhPubKey); err != nil { + t.Fatalf("AddReceived() returned an error: %+v", err) + } + + expectContactList = append(expectContactList, c) + } + + // Check that GetAllReceived returns all contacts + receivedContactList := s.GetAllReceived() + if len(receivedContactList) != numReceived { + t.Errorf("GetAllReceived did not return expected amount of contacts."+ + "\nExpected: %d"+ + "\nReceived: %d", numReceived, len(receivedContactList)) + } + + // Sort expected and received lists so that they are in the same order + // since extraction from a map does not maintain order + sort.Slice(expectContactList, func(i, j int) bool { + return bytes.Compare(expectContactList[i].ID.Bytes(), expectContactList[j].ID.Bytes()) == -1 + }) + sort.Slice(receivedContactList, func(i, j int) bool { + return bytes.Compare(receivedContactList[i].ID.Bytes(), receivedContactList[j].ID.Bytes()) == -1 + }) + + // Check validity of contacts + if !reflect.DeepEqual(expectContactList, receivedContactList) { + t.Errorf("GetAllReceived did not return expected contact list."+ + "\nExpected: %+v"+ + "\nReceived: %+v", expectContactList, receivedContactList) + } + +} + +// Tests that Store.GetAllReceived returns an empty list when there are no +// received requests. +func TestStore_GetAllReceived_EmptyList(t *testing.T) { + s, _, _ := makeTestStore(t) + + // Check that GetAllReceived returns all contacts + receivedContactList := s.GetAllReceived() + if len(receivedContactList) != 0 { + t.Errorf("GetAllReceived did not return expected amount of contacts."+ + "\nExpected: %d"+ + "\nReceived: %d", 0, len(receivedContactList)) + } + + // Add Sent and Receive requests + for i := 0; i < 10; i++ { + partnerID := id.NewIdFromUInt(rand.Uint64(), id.User, t) + rng := csprng.NewSystemRNG() + sidhPrivKey, sidhPubKey := genSidhAKeys(rng) + sr := &SentRequest{ + kv: s.kv, + partner: partnerID, + partnerHistoricalPubKey: s.grp.NewInt(1), + myPrivKey: s.grp.NewInt(2), + myPubKey: s.grp.NewInt(3), + mySidHPrivKeyA: sidhPrivKey, + mySidHPubKeyA: sidhPubKey, + fingerprint: format.Fingerprint{5}, + } + if err := s.AddSent(sr.partner, sr.partnerHistoricalPubKey, + sr.myPrivKey, sr.myPubKey, sr.mySidHPrivKeyA, + sr.mySidHPubKeyA, sr.fingerprint); err != nil { + t.Fatalf("AddSent() returned an error: %+v", err) + } + } + + // Check that GetAllReceived returns all contacts + receivedContactList = s.GetAllReceived() + if len(receivedContactList) != 0 { + t.Errorf("GetAllReceived did not return expected amount of contacts. "+ + "It may be pulling from Sent Requests."+ + "\nExpected: %d"+ + "\nReceived: %d", 0, len(receivedContactList)) + } + +} + +// Tests that Store.GetAllReceived returns only Sent requests when there +// are both Sent and Receive requests in Store. +func TestStore_GetAllReceived_MixSentReceived(t *testing.T) { + s, _, _ := makeTestStore(t) + numReceived := 10 + + // Add multiple received contact requests + for i := 0; i < numReceived; i++ { + // Add received request + c := contact.Contact{ID: id.NewIdFromUInt(rand.Uint64(), id.User, t)} + rng := csprng.NewSystemRNG() + _, sidhPubKey := genSidhAKeys(rng) + + if err := s.AddReceived(c, sidhPubKey); err != nil { + t.Fatalf("AddReceived() returned an error: %+v", err) + } + + // Add sent request + partnerID := id.NewIdFromUInt(rand.Uint64(), id.User, t) + sidhPrivKey, sidhPubKey := genSidhAKeys(rng) + sr := &SentRequest{ + kv: s.kv, + partner: partnerID, + partnerHistoricalPubKey: s.grp.NewInt(1), + myPrivKey: s.grp.NewInt(2), + myPubKey: s.grp.NewInt(3), + mySidHPrivKeyA: sidhPrivKey, + mySidHPubKeyA: sidhPubKey, + fingerprint: format.Fingerprint{5}, + } + if err := s.AddSent(sr.partner, sr.partnerHistoricalPubKey, + sr.myPrivKey, sr.myPubKey, sr.mySidHPrivKeyA, + sr.mySidHPubKeyA, sr.fingerprint); err != nil { + t.Fatalf("AddSent() returned an error: %+v", err) + } + } + + // Check that GetAllReceived returns all contacts + receivedContactList := s.GetAllReceived() + if len(receivedContactList) != numReceived { + t.Errorf("GetAllReceived did not return expected amount of contacts. "+ + "It may be pulling from Sent Requests."+ + "\nExpected: %d"+ + "\nReceived: %d", numReceived, len(receivedContactList)) + } + +} + func makeTestStore(t *testing.T) (*Store, *versioned.KV, []*cyclic.Int) { kv := versioned.NewKV(make(ekv.Memstore)) grp := cyclic.NewGroup(large.NewInt(173), large.NewInt(0)) diff --git a/storage/session.go b/storage/session.go index e54ab01fbc4a2002bac5ffa445d906ffd8710186..d11e11fc6697edd4e52d3550c6de2d6677a37bd1 100644 --- a/storage/session.go +++ b/storage/session.go @@ -486,6 +486,18 @@ func InitTestingSession(i interface{}) *Session { s.hostList = hostList.NewStore(s.kv) + privKeys := make([]*cyclic.Int, 10) + pubKeys := make([]*cyclic.Int, 10) + for i := range privKeys { + privKeys[i] = cmixGrp.NewInt(5) + pubKeys[i] = cmixGrp.ExpG(privKeys[i], cmixGrp.NewInt(1)) + } + + s.auth, err = auth.NewStore(s.kv, cmixGrp, privKeys) + if err != nil { + jww.FATAL.Panicf("Failed to create auth store: %v", err) + } + s.edgeCheck, err = edge.NewStore(s.kv, uid) if err != nil { jww.FATAL.Panicf("Failed to create new edge Store: %+v", err)