From 0e08489b8f9645686b096298e10b14f506949a61 Mon Sep 17 00:00:00 2001 From: Benjamin Wenger <ben@elixxir.ioo> Date: Fri, 28 Jan 2022 09:39:32 -0800 Subject: [PATCH] added mechanisim to replay all requests. they will replay on recept if the flag is set and if ReplayRequests() is called --- api/client.go | 6 +++--- auth/callback.go | 13 +++++++++++-- auth/manager.go | 17 ++++++++++++++++- bindings/authenticatedChannels.go | 7 ++++++- cmd/root.go | 6 ++---- cmd/single.go | 2 +- cmd/ud.go | 2 +- interfaces/auth.go | 4 +++- interfaces/params/network.go | 3 +++ storage/auth/store.go | 13 +++++++++++++ 10 files changed, 59 insertions(+), 14 deletions(-) diff --git a/api/client.go b/api/client.go index 2c28d84d3..0aee60d57 100644 --- a/api/client.go +++ b/api/client.go @@ -304,7 +304,7 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie } // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network) + c.auth = auth.NewManager(c.switchboard, c.storage, c.network, parameters.ReplayRequests) // Add all processes to the followerServices err = c.registerFollower() @@ -363,7 +363,7 @@ func LoginWithNewBaseNDF_UNSAFE(storageDir string, password []byte, } // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network) + c.auth = auth.NewManager(c.switchboard, c.storage, c.network, parameters.ReplayRequests) err = c.registerFollower() if err != nil { @@ -420,7 +420,7 @@ func LoginWithProtoClient(storageDir string, password []byte, protoClientJSON [] } // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network) + c.auth = auth.NewManager(c.switchboard, c.storage, c.network, parameters.ReplayRequests) err = c.registerFollower() if err != nil { diff --git a/auth/callback.go b/auth/callback.go index c0acff7b6..7d3f26216 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -165,7 +165,7 @@ func (m *Manager) handleRequest(cmixMsg format.Message, return } else { //check if the relationship already exists, - rType, _, _, err := m.storage.Auth().GetRequest(partnerID) + rType, _, c, err := m.storage.Auth().GetRequest(partnerID) if err != nil && !strings.Contains(err.Error(), auth.NoRequest) { // if another error is received, print it and exit em := fmt.Sprintf("Received new Auth request for %s, "+ @@ -183,6 +183,15 @@ func (m *Manager) handleRequest(cmixMsg format.Message, "is a duplicate", partnerID) jww.WARN.Print(em) events.Report(5, "Auth", "DuplicateRequest", em) + // if the caller of the API wants requests replayed, + // replay the duplicate request + if m.replayRequests{ + cbList := m.requestCallbacks.Get(c.ID) + for _, cb := range cbList { + rcb := cb.(interfaces.RequestCallback) + go rcb(c) + } + } return // if we sent a request, then automatically confirm // then exit, nothing else needed @@ -296,7 +305,7 @@ func (m *Manager) handleRequest(cmixMsg format.Message, cbList := m.requestCallbacks.Get(c.ID) for _, cb := range cbList { rcb := cb.(interfaces.RequestCallback) - go rcb(c, "") + go rcb(c) } return } diff --git a/auth/manager.go b/auth/manager.go index b5803d69f..4562ed804 100644 --- a/auth/manager.go +++ b/auth/manager.go @@ -23,16 +23,19 @@ type Manager struct { storage *storage.Session net interfaces.NetworkManager + + replayRequests bool } func NewManager(sw interfaces.Switchboard, storage *storage.Session, - net interfaces.NetworkManager) *Manager { + net interfaces.NetworkManager, replayRequests bool) *Manager { m := &Manager{ requestCallbacks: newCallbackMap(), confirmCallbacks: newCallbackMap(), rawMessages: make(chan message.Receive, 1000), storage: storage, net: net, + replayRequests: replayRequests, } sw.RegisterChannel("Auth", switchboard.AnyUser(), message.Raw, m.rawMessages) @@ -89,3 +92,15 @@ func (m *Manager) AddSpecificConfirmCallback(id *id.ID, cb interfaces.ConfirmCal func (m *Manager) RemoveSpecificConfirmCallback(id *id.ID) { m.confirmCallbacks.RemoveSpecific(id) } + +func (m *Manager)ReplayRequests(){ + cList := m.storage.Auth().GetAllReceived() + for i := range cList{ + c := cList[i] + cbList := m.requestCallbacks.Get(c.ID) + for _, cb := range cbList { + rcb := cb.(interfaces.RequestCallback) + go rcb(c) + } + } +} diff --git a/bindings/authenticatedChannels.go b/bindings/authenticatedChannels.go index 67d2a8850..36888a604 100644 --- a/bindings/authenticatedChannels.go +++ b/bindings/authenticatedChannels.go @@ -63,7 +63,7 @@ func (c *Client) RequestAuthenticatedChannel(recipientMarshaled, func (c *Client) RegisterAuthCallbacks(request AuthRequestCallback, confirm AuthConfirmCallback) { - requestFunc := func(requestor contact.Contact, message string) { + requestFunc := func(requestor contact.Contact) { requestorBind := &Contact{c: &requestor} request.Callback(requestorBind) } @@ -136,3 +136,8 @@ func (c *Client) GetRelationshipFingerprint(partnerID []byte) (string, error) { return c.api.GetRelationshipFingerprint(partner) } + +// ReplayRequests Resends all pending requests over the normal callbacks +func (c *Client) ReplayRequests() () { + c.api.GetAuthRegistrar().ReplayRequests() +} \ No newline at end of file diff --git a/cmd/root.go b/cmd/root.go index 328f814cd..7f38db02a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -473,7 +473,7 @@ func initClientCallbacks(client *api.Client) (chan *id.ID, }) if viper.GetBool("unsafe-channel-creation") { authMgr.AddGeneralRequestCallback(func( - requestor contact.Contact, message string) { + requestor contact.Contact) { jww.INFO.Printf("Channel Request: %s", requestor.ID) _, err := client.ConfirmAuthenticatedChannel( @@ -681,13 +681,11 @@ func deleteChannel(client *api.Client, partnerId *id.ID) { } } -func printChanRequest(requestor contact.Contact, message string) { +func printChanRequest(requestor contact.Contact) { msg := fmt.Sprintf("Authentication channel request from: %s\n", requestor.ID) jww.INFO.Printf(msg) fmt.Printf(msg) - msg = fmt.Sprintf("Authentication channel request message: %s\n", message) - jww.INFO.Printf(msg) // fmt.Printf(msg) } diff --git a/cmd/single.go b/cmd/single.go index cc0cadadd..cd6cc0a88 100644 --- a/cmd/single.go +++ b/cmd/single.go @@ -53,7 +53,7 @@ var singleCmd = &cobra.Command{ // If unsafe channels, then add auto-acceptor if viper.GetBool("unsafe-channel-creation") { authMgr.AddGeneralRequestCallback(func( - requester contact.Contact, message string) { + requester contact.Contact) { jww.INFO.Printf("Got request: %s", requester.ID) _, err := client.ConfirmAuthenticatedChannel(requester) if err != nil { diff --git a/cmd/ud.go b/cmd/ud.go index 767bc12ea..34fff1c09 100644 --- a/cmd/ud.go +++ b/cmd/ud.go @@ -53,7 +53,7 @@ var udCmd = &cobra.Command{ // If unsafe channels, add auto-acceptor if viper.GetBool("unsafe-channel-creation") { authMgr.AddGeneralRequestCallback(func( - requester contact.Contact, message string) { + requester contact.Contact) { jww.INFO.Printf("Got Request: %s", requester.ID) _, err := client.ConfirmAuthenticatedChannel(requester) if err != nil { diff --git a/interfaces/auth.go b/interfaces/auth.go index bb3cae74d..4ce22fba7 100644 --- a/interfaces/auth.go +++ b/interfaces/auth.go @@ -12,7 +12,7 @@ import ( "gitlab.com/xx_network/primitives/id" ) -type RequestCallback func(requestor contact.Contact, message string) +type RequestCallback func(requestor contact.Contact) type ConfirmCallback func(partner contact.Contact) type Auth interface { @@ -42,4 +42,6 @@ type Auth interface { AddSpecificConfirmCallback(id *id.ID, cb ConfirmCallback) // Removes a specific callback to be used on auth confirm. RemoveSpecificConfirmCallback(id *id.ID) + //Replays all pending received requests over tha callbacks + ReplayRequests() } diff --git a/interfaces/params/network.go b/interfaces/params/network.go index 67d083802..259cbef58 100644 --- a/interfaces/params/network.go +++ b/interfaces/params/network.go @@ -34,6 +34,8 @@ type Network struct { // Determines if the state of every round processed is tracked in ram. // This is very memory intensive and is primarily used for debugging VerboseRoundTracking bool + // Resends auth requests up the stack if received multiple times + ReplayRequests bool Rounds Messages @@ -54,6 +56,7 @@ func GetDefaultNetwork() Network { FastPolling: true, BlacklistedNodes: make([]string, 0), VerboseRoundTracking: false, + ReplayRequests: true, } n.Rounds = GetDefaultRounds() n.Messages = GetDefaultMessage() diff --git a/storage/auth/store.go b/storage/auth/store.go index addd3d138..acf15c3d2 100644 --- a/storage/auth/store.go +++ b/storage/auth/store.go @@ -261,6 +261,19 @@ func (s *Store) AddReceived(c contact.Contact, key *sidh.PublicKey) error { return nil } +func (s *Store)GetAllReceived()[]contact.Contact{ + s.mux.RLock() + defer s.mux.RUnlock() + cList := make([]contact.Contact, 0, len(s.requests)) + for key := range s.requests{ + r := s.requests[key] + if r.rt == Receive{ + cList = append(cList,*r.receive) + } + } + return cList +} + // GetFingerprint can return either a private key or a sentRequest if the // fingerprint is found. If it returns a sentRequest, then it takes the lock to // ensure there is only one operator at a time. The user of the API must release -- GitLab