diff --git a/Makefile b/Makefile index 2e56e3530bb3ef5d03352961f8eef095b4a47c01..bd6383ff69c75244eaf1b62ef3cc9c0fb91c5154 100644 --- a/Makefile +++ b/Makefile @@ -22,8 +22,8 @@ build: update_release: GOFLAGS="" go get -u gitlab.com/xx_network/primitives@release GOFLAGS="" go get -u gitlab.com/elixxir/primitives@release - GOFLAGS="" go get -u gitlab.com/elixxir/crypto@release GOFLAGS="" go get -u gitlab.com/xx_network/crypto@release + GOFLAGS="" go get -u gitlab.com/elixxir/crypto@release GOFLAGS="" go get -u gitlab.com/xx_network/comms@release GOFLAGS="" go get -u gitlab.com/elixxir/comms@release diff --git a/README.md b/README.md index 98a8c6b0b888063c9ffd447641f55a0d5339a191..0cd7989af96e9851126932c253e8385c3f9ea8d7 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ client --password user1-password --ndf ndf.json -l client1.log -s user1session - client --password user2-password --ndf ndf.json -l client2.log -s user2session --writeContact user2-contact.json --unsafe -m "Hi" # Send E2E Messages -client --password user1-password --ndf ndf.json -l client1.log -s user1session --destfile user2-contact.json --unsafe-channel-creation -m "Hi User 2, from User 1 with E2E Encryption" & +client --password user1-password --ndf ndf.json -l client1.log -s user1session --destfile user1-contact.json --unsafe-channel-creation -m "Hi User 2, from User 1 with E2E Encryption" & client --password user2-password --ndf ndf.json -l client2.log -s user2session --destfile user1-contact.json --unsafe-channel-creation -m "Hi User 1, from User 2 with E2E Encryption" & ``` diff --git a/api/client.go b/api/client.go index ae87fdc6dc36a974875f5a73ddea87ed52fbded1..118615bdcb7d4dc23b6a7fde682a3fdc224b4a6a 100644 --- a/api/client.go +++ b/api/client.go @@ -8,6 +8,7 @@ package api import ( + "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/primitives/id" "time" @@ -188,7 +189,7 @@ func OpenClient(storageDir string, password []byte, parameters params.Network) ( return c, nil } -// Login initalizes a client object from existing storage. +// Login initializes a client object from existing storage. func Login(storageDir string, password []byte, parameters params.Network) (*Client, error) { jww.INFO.Printf("Login()") @@ -199,13 +200,13 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie } u := c.storage.GetUser() - jww.INFO.Printf("Client Logged in: \n\tTransmisstionID: %s " + + jww.INFO.Printf("Client Logged in: \n\tTransmisstionID: %s "+ "\n\tReceptionID: %s", u.TransmissionID, u.ReceptionID) //Attach the services interface c.services = newServiceProcessiesList(c.runner) - //initilize comms + // initialize comms err = c.initComms() if err != nil { return nil, err @@ -222,10 +223,20 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie } } else { jww.WARN.Printf("Registration with permissioning skipped due to " + - "blank permissionign address. Client will not be able to register " + + "blank permissioning address. Client will not be able to register " + "or track network.") } + if def.Notification.Address != "" { + hp := connect.GetDefaultHostParams() + hp.AuthEnabled = false + hp.MaxRetries = 5 + _, err = c.comms.AddHost(&id.NotificationBot, def.Notification.Address, []byte(def.Notification.TlsCertificate), hp) + if err != nil { + jww.WARN.Printf("Failed adding host for notifications: %+v", err) + } + } + // Initialize network and link it to context c.network, err = network.NewManager(c.storage, c.switchboard, c.rng, c.comms, parameters, def) @@ -233,13 +244,7 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie return nil, err } - //update gateway connections - err = c.network.GetInstance().UpdateGatewayConnections() - if err != nil { - return nil, err - } - - //initilize the auth tracker + // initialize the auth tracker c.auth = auth.NewManager(c.switchboard, c.storage, c.network) return c, nil @@ -296,13 +301,7 @@ func LoginWithNewBaseNDF_UNSAFE(storageDir string, password []byte, return nil, err } - //update gateway connections - err = c.network.GetInstance().UpdateGatewayConnections() - if err != nil { - return nil, err - } - - //initilize the auth tracker + // initialize the auth tracker c.auth = auth.NewManager(c.switchboard, c.storage, c.network) return c, nil @@ -343,7 +342,7 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { jww.ERROR.Printf("Client has failed registration: %s", err) return errors.WithMessage(err, "failed to load client") } - jww.INFO.Printf("Client sucsecfully registered with the network") + jww.INFO.Printf("Client successfully registered with the network") } return nil } @@ -381,7 +380,7 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { // Handles both auth confirm and requests func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { u := c.GetUser() - jww.INFO.Printf("StartNetworkFollower() \n\tTransmisstionID: %s " + + jww.INFO.Printf("StartNetworkFollower() \n\tTransmisstionID: %s "+ "\n\tReceptionID: %s", u.TransmissionID, u.ReceptionID) c.clientErrorChannel = make(chan interfaces.ClientError, 1000) @@ -436,18 +435,19 @@ func (c *Client) StopNetworkFollower(timeout time.Duration) error { return errors.WithMessage(err, "Failed to Stop the Network Follower") } err = c.runner.Close(timeout) - if err != nil { - return errors.WithMessage(err, "Failed to Stop the Network Follower") - } c.runner = stoppable.NewMulti("client") - err = c.status.toStopped() - if err != nil { - return errors.WithMessage(err, "Failed to Stop the Network Follower") + err2 := c.status.toStopped() + if err2 != nil { + if err ==nil{ + err = err2 + }else{ + err = errors.WithMessage(err,err2.Error()) + } } - return nil + return err } -// Gets the state of the network follower. Returns: +// NetworkFollowerStatus Gets the state of the network follower. Returns: // Stopped - 0 // Starting - 1000 // Running - 2000 @@ -526,13 +526,13 @@ func (c *Client) GetNodeRegistrationStatus() (int, int, error) { cmixStore := c.storage.Cmix() var numRegistered int - for i, n := range nodes{ + for i, n := range nodes { nid, err := id.Unmarshal(n.ID) - if err!=nil{ - return 0,0, errors.Errorf("Failed to unmarshal node ID %v " + + if err != nil { + return 0, 0, errors.Errorf("Failed to unmarshal node ID %v "+ "(#%d): %s", n.ID, i, err.Error()) } - if cmixStore.Has(nid){ + if cmixStore.Has(nid) { numRegistered++ } } diff --git a/api/notifications.go b/api/notifications.go index dff1977765b31b98297feaa75827e1b2edc71fd6..99ab1d89c79faa1c955d4c2e8f587f20e2f48d60 100644 --- a/api/notifications.go +++ b/api/notifications.go @@ -7,7 +7,16 @@ package api -import jww "github.com/spf13/jwalterweatherman" +import ( + "fmt" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/comms/mixmessages" + "gitlab.com/elixxir/crypto/hash" + "gitlab.com/xx_network/crypto/signature/rsa" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" +) // RegisterForNotifications allows a client to register for push // notifications. @@ -15,24 +24,34 @@ import jww "github.com/spf13/jwalterweatherman" // especially as these rely on third parties (i.e., Firebase *cough* // *cough* google's palantir *cough*) that may represent a security // risk to the user. -func (c *Client) RegisterForNotifications(token []byte) error { +func (c *Client) RegisterForNotifications(token string) error { jww.INFO.Printf("RegisterForNotifications(%s)", token) - // // Pull the host from the manage - // notificationBotHost, ok := cl.receptionManager.Comms.GetHost(&id.NotificationBot) - // if !ok { - // return errors.New("Failed to retrieve host for notification bot") - // } - - // // Send the register message - // _, err := cl.receptionManager.Comms.RegisterForNotifications(notificationBotHost, - // &mixmessages.NotificationToken{ - // Token: notificationToken, - // }) - // if err != nil { - // err := errors.Errorf( - // "RegisterForNotifications: Unable to register for notifications! %s", err) - // return err - // } + fmt.Println("RegisterforNotifications") + // Pull the host from the manage + notificationBotHost, ok := c.comms.GetHost(&id.NotificationBot) + if !ok { + return errors.New("RegisterForNotifications: Failed to retrieve host for notification bot") + } + intermediaryReceptionID, sig, err := c.getIidAndSig() + if err != nil { + return err + } + fmt.Println("Sending message") + // Send the register message + _, err = c.comms.RegisterForNotifications(notificationBotHost, + &mixmessages.NotificationRegisterRequest{ + Token: token, + IntermediaryId: intermediaryReceptionID, + TransmissionRsa: rsa.CreatePublicKeyPem(c.GetUser().TransmissionRSA.GetPublic()), + TransmissionRsaSig: sig, + TransmissionSalt: c.GetUser().TransmissionSalt, + IIDTransmissionRsaSig: []byte("temp"), + }) + if err != nil { + err := errors.Errorf( + "RegisterForNotifications: Unable to register for notifications! %s", err) + return err + } return nil } @@ -40,19 +59,46 @@ func (c *Client) RegisterForNotifications(token []byte) error { // UnregisterForNotifications turns of notifications for this client func (c *Client) UnregisterForNotifications() error { jww.INFO.Printf("UnregisterForNotifications()") - // // Pull the host from the manage - // notificationBotHost, ok := cl.receptionManager.Comms.GetHost(&id.NotificationBot) - // if !ok { - // return errors.New("Failed to retrieve host for notification bot") - // } - - // // Send the unregister message - // _, err := cl.receptionManager.Comms.UnregisterForNotifications(notificationBotHost) - // if err != nil { - // err := errors.Errorf( - // "RegisterForNotifications: Unable to register for notifications! %s", err) - // return err - // } + // Pull the host from the manage + notificationBotHost, ok := c.comms.GetHost(&id.NotificationBot) + if !ok { + return errors.New("Failed to retrieve host for notification bot") + } + intermediaryReceptionID, sig, err := c.getIidAndSig() + if err != nil { + return err + } + // Send the unregister message + _, err = c.comms.UnregisterForNotifications(notificationBotHost, &mixmessages.NotificationUnregisterRequest{ + IntermediaryId: intermediaryReceptionID, + IIDTransmissionRsaSig: sig, + }) + if err != nil { + err := errors.Errorf( + "RegisterForNotifications: Unable to register for notifications! %s", err) + return err + } return nil } + +func (c *Client) getIidAndSig() ([]byte, []byte, error) { + intermediaryReceptionID, err := ephemeral.GetIntermediaryId(c.GetUser().ReceptionID) + if err != nil { + return nil, nil, errors.WithMessage(err, "RegisterForNotifications: Failed to form intermediary ID") + } + h, err := hash.NewCMixHash() + if err != nil { + return nil, nil, errors.WithMessage(err, "RegisterForNotifications: Failed to create cmix hash") + } + _, err = h.Write(intermediaryReceptionID) + if err != nil { + return nil, nil, errors.WithMessage(err, "RegisterForNotifications: Failed to write intermediary ID to hash") + } + + sig, err := rsa.Sign(c.rng.GetStream(), c.GetUser().TransmissionRSA, hash.CMixHash, h.Sum(nil), nil) + if err != nil { + return nil, nil, errors.WithMessage(err, "RegisterForNotifications: Failed to sign intermediary ID") + } + return intermediaryReceptionID, sig, nil +} diff --git a/api/results.go b/api/results.go index cf30e253356567ae2e777f99a061c56b02299940..590634b9646f2b92f3b56e29d422145806807d66 100644 --- a/api/results.go +++ b/api/results.go @@ -11,7 +11,6 @@ import ( "time" jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/network/gateway" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" ds "gitlab.com/elixxir/comms/network/dataStructures" @@ -158,7 +157,7 @@ func (c *Client) getRoundResults(roundList []id.Round, timeout time.Duration, // Skip if the round is nil (unknown from historical rounds) // they default to timed out, so correct behavior is preserved - if roundReport.RoundInfo == nil || roundReport.TimedOut { + if roundReport.RoundInfo == nil || roundReport.TimedOut { allRoundsSucceeded = false } else { // If available, denote the result @@ -181,42 +180,36 @@ func (c *Client) getRoundResults(roundList []id.Round, timeout time.Duration, // Helper function which asynchronously pings a random gateway until // it gets information on it's requested historical rounds func (c *Client) getHistoricalRounds(msg *pb.HistoricalRounds, - instance *network.Instance, sendResults chan ds.EventReturn, - comms historicalRoundsComm) { + instance *network.Instance, sendResults chan ds.EventReturn, comms historicalRoundsComm) { var resp *pb.HistoricalRoundsResponse //retry 5 times - for i:=0;i<5;i++{ + for i := 0; i < 5; i++ { // Find a gateway to request about the roundRequests - gwHost, err := gateway.Get(instance.GetPartialNdf().Get(), comms, c.rng.GetStream()) - if err != nil { - jww.ERROR.Printf("Failed to track network, NDF has corrupt "+ - "data: %s", err) - continue - } + result, err := c.GetNetworkInterface().GetSender().SendToAny(func(host *connect.Host) (interface{}, error) { + return comms.RequestHistoricalRounds(host, msg) + }) // If an error, retry with (potentially) a different gw host. // If no error from received gateway request, exit loop // and process rounds - resp, err = comms.RequestHistoricalRounds(gwHost, msg) if err == nil { - jww.ERROR.Printf("Failed to lookup historical rounds: %s", - err) - }else{ + resp = result.(*pb.HistoricalRoundsResponse) break + } else { + jww.ERROR.Printf("Failed to lookup historical rounds: %s", err) } } - if resp == nil{ + if resp == nil { return } // Process historical rounds, sending back to the caller thread for _, ri := range resp.Rounds { sendResults <- ds.EventReturn{ - ri, - false, + RoundInfo: ri, } } } diff --git a/api/utilsInterfaces_test.go b/api/utilsInterfaces_test.go index ffa5154795045f361afd70d8cb67d0c45eb6e4c4..62ab16f5231d0bdc81ab30f8e0a1b890fb0931c7 100644 --- a/api/utilsInterfaces_test.go +++ b/api/utilsInterfaces_test.go @@ -10,6 +10,7 @@ import ( "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/stoppable" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" @@ -79,6 +80,7 @@ func (ht *historicalRounds) GetHost(hostId *id.ID) (*connect.Host, bool) { // Contains a test implementation of the networkManager interface. type testNetworkManagerGeneric struct { instance *network.Instance + sender *gateway.Sender } /* Below methods built for interface adherence */ @@ -119,3 +121,7 @@ func (t *testNetworkManagerGeneric) GetStoppable() stoppable.Stoppable { func (t *testNetworkManagerGeneric) InProgressRegistrations() int { return 0 } + +func (t *testNetworkManagerGeneric) GetSender() *gateway.Sender { + return t.sender +} diff --git a/api/utils_test.go b/api/utils_test.go index 31e09e2d4f5b87a16e434ef1c6eef58423a903b3..9b9fccac0225438387907c405efbcc023fa0e5a6 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -8,6 +8,7 @@ package api import ( + "gitlab.com/elixxir/client/network/gateway" "testing" "github.com/pkg/errors" @@ -64,7 +65,10 @@ func newTestingClient(face interface{}) (*Client, error) { return nil, nil } - c.network = &testNetworkManagerGeneric{instance: thisInstance} + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + sender, _ := gateway.NewSender(p, c.rng, def, commsManager, c.storage, nil) + c.network = &testNetworkManagerGeneric{instance: thisInstance, sender: sender} return c, nil } diff --git a/api/version_vars.go b/api/version_vars.go index 1a4ed3b5336c44eeed1cf474d589c1175808c16a..75d1adeef89b0d9ed44699e42d0bb7e6306fbbb6 100644 --- a/api/version_vars.go +++ b/api/version_vars.go @@ -1,10 +1,10 @@ // Code generated by go generate; DO NOT EDIT. // This file was generated by robots at -// 2021-04-02 08:56:27.8657223 -0700 PDT m=+0.054483401 +// 2021-05-07 09:33:37.4750421 -0700 PDT m=+0.043142701 package api -const GITVERSION = `6020ab79 removed an extranious print` -const SEMVER = "2.3.0" +const GITVERSION = `51fdae45 made stop network follower always allow restart` +const SEMVER = "2.5.0" const DEPENDENCIES = `module gitlab.com/elixxir/client go 1.13 @@ -13,12 +13,10 @@ require ( github.com/golang-collections/collections v0.0.0-20130729185459-604e922904d3 github.com/golang/protobuf v1.4.3 github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect - github.com/liyue201/goqr v0.0.0-20200803022322-df443203d4ea github.com/magiconair/properties v1.8.4 // indirect github.com/mitchellh/mapstructure v1.4.0 // indirect github.com/pelletier/go-toml v1.8.1 // indirect github.com/pkg/errors v0.9.1 - github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/smartystreets/assertions v1.0.1 // indirect github.com/spf13/afero v1.5.1 // indirect github.com/spf13/cast v1.3.1 // indirect @@ -26,15 +24,16 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 github.com/spf13/viper v1.7.1 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 - gitlab.com/elixxir/comms v0.0.4-0.20210401210158-6053ad2e224c - gitlab.com/elixxir/crypto v0.0.7-0.20210401210040-b7f1da24ef13 - gitlab.com/elixxir/ekv v0.1.4 - gitlab.com/elixxir/primitives v0.0.3-0.20210401175645-9b7b92f74ec4 - gitlab.com/xx_network/comms v0.0.4-0.20210401160731-7b8890cdd8ad - gitlab.com/xx_network/crypto v0.0.5-0.20210401160648-4f06cace9123 - gitlab.com/xx_network/primitives v0.0.4-0.20210331161816-ed23858bdb93 - golang.org/x/crypto v0.0.0-20201221181555-eec23a3978ad - golang.org/x/net v0.0.0-20201224014010-6772e930b67b // indirect + gitlab.com/elixxir/comms v0.0.4-0.20210506225017-37485f5ba063 + gitlab.com/elixxir/crypto v0.0.7-0.20210506223047-3196e4301110 + gitlab.com/elixxir/ekv v0.1.5 + gitlab.com/elixxir/primitives v0.0.3-0.20210504210415-34cf31c2816e + gitlab.com/xx_network/comms v0.0.4-0.20210505205155-48daa8448ad7 + gitlab.com/xx_network/crypto v0.0.5-0.20210504210244-9ddabbad25fd + gitlab.com/xx_network/primitives v0.0.4-0.20210504205835-db68f11de78a + golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 + golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 // indirect google.golang.org/genproto v0.0.0-20210105202744-fe13368bc0e1 // indirect google.golang.org/grpc v1.34.0 // indirect google.golang.org/protobuf v1.26.0-rc.1 diff --git a/auth/confirm.go b/auth/confirm.go index 40c983c6feb73446495bb9307a2c6793fca7b259..30a51181c97f3e8bea4993eac82555152655dc7f 100644 --- a/auth/confirm.go +++ b/auth/confirm.go @@ -13,11 +13,11 @@ import ( "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/storage" - "gitlab.com/xx_network/primitives/id" "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/crypto/diffieHellman" cAuth "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" "io" ) diff --git a/auth/request.go b/auth/request.go index c43707904f835b19d1e64ba53cb4ec9fe550355f..76af994787f8e1658e13345bf14883e091e7790d 100644 --- a/auth/request.go +++ b/auth/request.go @@ -55,17 +55,20 @@ func RequestAuth(partner, me contact.Contact, message string, rng io.Reader, //lookup if an ongoing request is occurring rqType, sr, _, err := storage.Auth().GetRequest(partner.ID) - if err != nil && !strings.Contains(err.Error(), auth.NoRequest){ + if err == nil { if rqType == auth.Receive { - return 0, errors.WithMessage(err, - "Cannot send a request after receiving a request") + return 0, errors.Errorf("Cannot send a request after " + + "receiving a request") } else if rqType == auth.Sent { resend = true }else{ - return 0, errors.WithMessage(err, - "Cannot send a request after receiving unknown error " + - "on requesting contact status") + return 0, errors.Errorf("Cannot send a request after " + + " a stored request with unknown rqType: %d", rqType) } + }else if !strings.Contains(err.Error(), auth.NoRequest){ + return 0, errors.WithMessage(err, + "Cannot send a request after receiving unknown error " + + "on requesting contact status") } grp := storage.E2e().GetGroup() diff --git a/bindings/callback.go b/bindings/callback.go index a14cb5151358933f2f0bd6af6cdb04156ff4ac37..a6526d24d3ba961db2b79bdc2fbd6a1b950821dc 100644 --- a/bindings/callback.go +++ b/bindings/callback.go @@ -104,16 +104,15 @@ type ClientError interface { Report(source, message, trace string) } - -type LogWriter interface{ +type LogWriter interface { Log(string) } -type writerAdapter struct{ +type writerAdapter struct { lw LogWriter } -func (wa *writerAdapter)Write(p []byte) (n int, err error){ +func (wa *writerAdapter) Write(p []byte) (n int, err error) { wa.lw.Log(string(p)) return len(p), nil } diff --git a/bindings/client.go b/bindings/client.go index d5c40e456de1b263a48184250cbb67fecc9e0c88..33ef281983799bab25a50663a71b151749ba3535 100644 --- a/bindings/client.go +++ b/bindings/client.go @@ -18,6 +18,7 @@ import ( "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" "sync" "time" ) @@ -25,6 +26,8 @@ import ( var extantClient bool var loginMux sync.Mutex +var clientSingleton *Client + // sets the log level func init() { jww.SetLogThreshold(jww.LevelInfo) @@ -96,7 +99,15 @@ func Login(storageDir string, password []byte, parameters string) (*Client, erro return nil, errors.New(fmt.Sprintf("Failed to login: %+v", err)) } extantClient = true - return &Client{api: *client}, nil + clientSingleton := &Client{api: *client} + return clientSingleton, nil +} + +// returns a previously created client. IF be used if the garbage collector +// removes the client instance on the app side. Is NOT thread safe relative to +// login, newClient, or newPrecannedClient +func GetClientSingleton() *Client { + return clientSingleton } // sets level of logging. All logs the set level and above will be displayed @@ -219,9 +230,9 @@ func (c *Client) StopNetworkFollower(timeoutMS int) error { // WaitForNewtwork will block until either the network is healthy or the // passed timeout. It will return true if the network is healthy func (c *Client) WaitForNetwork(timeoutMS int) bool { - start := time.Now() + start := netTime.Now() timeout := time.Duration(timeoutMS) * time.Millisecond - for time.Now().Sub(start) < timeout { + for netTime.Now().Sub(start) < timeout { if c.api.GetHealth().IsHealthy() { return true } @@ -366,15 +377,15 @@ func (c *Client) WaitForMessageDelivery(marshaledSendReport []byte, "WaitForMessageDelivery callback due to bad Send Report: %+v", err)) } - if sr==nil || sr.rl == nil || len(sr.rl.list) == 0{ + if sr == nil || sr.rl == nil || len(sr.rl.list) == 0 { return errors.New(fmt.Sprintf("Failed to "+ - "WaitForMessageDelivery callback due to invalid Send Report " + + "WaitForMessageDelivery callback due to invalid Send Report "+ "unmarshal: %s", string(marshaledSendReport))) } f := func(allRoundsSucceeded, timedOut bool, rounds map[id.Round]api.RoundResult) { results := make([]byte, len(sr.rl.list)) - jww.INFO.Printf("Processing WaitForMessageDelivery report " + + jww.INFO.Printf("Processing WaitForMessageDelivery report "+ "for %v, success: %v, timedout: %v", sr.mid, allRoundsSucceeded, timedOut) for i, r := range sr.rl.list { diff --git a/bindings/message.go b/bindings/message.go index 6cead4ad09a8be5fc49433bf9e68e5cf0ab87a9f..32947d5fc3b0fc7e87579bdacd39c1bc69edb404 100644 --- a/bindings/message.go +++ b/bindings/message.go @@ -40,7 +40,7 @@ func (m *Message) GetMessageType() int { // Returns the message's timestamp in ms func (m *Message) GetTimestampMS() int64 { ts := m.r.Timestamp.UnixNano() - ts = (ts+999999)/1000000 + ts = (ts + 999999) / 1000000 return ts } diff --git a/bindings/notifications.go b/bindings/notifications.go new file mode 100644 index 0000000000000000000000000000000000000000..461f15c324d55c144372161157e2d76cb0108869 --- /dev/null +++ b/bindings/notifications.go @@ -0,0 +1,42 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2021 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package bindings + +import ( + "encoding/base64" + "github.com/pkg/errors" + "gitlab.com/elixxir/crypto/fingerprint" + "gitlab.com/xx_network/primitives/id" +) + +// NotificationForMe Check if a notification received is for me +func NotificationForMe(messageHash, idFP string, receptionId []byte) (bool, error) { + messageHashBytes, err := base64.StdEncoding.DecodeString(messageHash) + if err != nil { + return false, errors.WithMessage(err, "Failed to decode message ID") + } + idFpBytes, err := base64.StdEncoding.DecodeString(idFP) + if err != nil { + return false, errors.WithMessage(err, "Failed to decode identity fingerprint") + } + rid, err := id.Unmarshal(receptionId) + if err != nil { + return false, errors.WithMessage(err, "Failed to unmartial reception ID") + } + return fingerprint.CheckIdentityFpFromMessageHash(idFpBytes, messageHashBytes, rid), nil +} + +// RegisterForNotifications accepts firebase messaging token +func (c *Client) RegisterForNotifications(token string) error { + return c.api.RegisterForNotifications(token) +} + +// UnregisterForNotifications unregister user for notifications +func (c *Client) UnregisterForNotifications() error { + return c.api.UnregisterForNotifications() +} diff --git a/bindings/notifications_test.go b/bindings/notifications_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c30e6e175566905a59b60ef4a570adcca77af3f9 --- /dev/null +++ b/bindings/notifications_test.go @@ -0,0 +1,23 @@ +package bindings + +import ( + "encoding/base64" + "gitlab.com/elixxir/crypto/fingerprint" + "gitlab.com/xx_network/primitives/id" + "testing" +) + +func TestNotificationForMe(t *testing.T) { + payload := []byte("I'm a payload") + hash := fingerprint.GetMessageHash(payload) + rid := id.NewIdFromString("zezima", id.User, t) + fp := fingerprint.IdentityFP(payload, rid) + + ok, err := NotificationForMe(base64.StdEncoding.EncodeToString(hash), base64.StdEncoding.EncodeToString(fp), rid.Bytes()) + if err != nil { + t.Errorf("Failed to check notification: %+v", err) + } + if !ok { + t.Error("Should have gotten ok response") + } +} diff --git a/bindings/send.go b/bindings/send.go index 673fad0003ac009b707de67be8f3edc9804f639f..886bf0ea819557c50ddef4b41846d0ae3d5a5d00 100644 --- a/bindings/send.go +++ b/bindings/send.go @@ -135,9 +135,9 @@ type SendReport struct { mid e2e.MessageID } -type SendReportDisk struct{ +type SendReportDisk struct { List []id.Round - Mid []byte + Mid []byte } func (sr *SendReport) GetRoundList() *RoundList { @@ -157,14 +157,13 @@ func (sr *SendReport) Marshal() ([]byte, error) { } func (sr *SendReport) Unmarshal(b []byte) error { - srd := SendReportDisk{ - } - if err := json.Unmarshal(b, &srd); err!=nil{ - return errors.New(fmt.Sprintf("Failed to unmarshal send " + + srd := SendReportDisk{} + if err := json.Unmarshal(b, &srd); err != nil { + return errors.New(fmt.Sprintf("Failed to unmarshal send "+ "report: %s", err.Error())) } - copy(sr.mid[:],srd.Mid) - sr.rl = &RoundList{list:srd.List} + copy(sr.mid[:], srd.Mid) + sr.rl = &RoundList{list: srd.List} return nil } diff --git a/bindings/timeNow.go b/bindings/timeNow.go index dbeda1696669f98b9a489990ced0f66f51c9757b..82cc2d798d8da4c16936b53b2ed3621b175e0e02 100644 --- a/bindings/timeNow.go +++ b/bindings/timeNow.go @@ -13,12 +13,12 @@ import ( ) type TimeSource interface { - NowMs() int + NowMs() int64 } // SetTimeSource sets the network time to a custom source. func SetTimeSource(timeNow TimeSource) { netTime.Now = func() time.Time { - return time.Unix(0, int64(timeNow.NowMs()*int(time.Millisecond))) + return time.Unix(0,timeNow.NowMs()*int64(time.Millisecond)) } } diff --git a/cmd/root.go b/cmd/root.go index 9ff887dad0dbb4d3d6e22a649f4ca43e5039a32e..5379cc08411e557fcc39e2d05996632eddfa259e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -78,7 +78,7 @@ var rootCmd = &cobra.Command{ jww.INFO.Printf("Message ListenerID: %v", listenerID) // Set up auth request handler, which simply prints the - // user id of the requestor. + // user id of the requester. authMgr := client.GetAuthRegistrar() authMgr.AddGeneralRequestCallback(printChanRequest) @@ -115,6 +115,11 @@ var rootCmd = &cobra.Command{ client.GetHealth().AddChannel(connected) waitUntilConnected(connected) + //err = client.RegisterForNotifications([]byte("dJwuGGX3KUyKldWK5PgQH8:APA91bFjuvimRc4LqOyMDiy124aLedifA8DhldtaB_b76ggphnFYQWJc_fq0hzQ-Jk4iYp2wPpkwlpE1fsOjs7XWBexWcNZoU-zgMiM0Mso9vTN53RhbXUferCbAiEylucEOacy9pniN")) + //if err != nil { + // jww.FATAL.Panicf("Failed to register for notifications: %+v", err) + //} + // After connection, make sure we have registered with at least // 85% of the nodes numReg := 1 @@ -163,7 +168,7 @@ var rootCmd = &cobra.Command{ } if !unsafe && !authConfirmed { - jww.INFO.Printf("Waiting for authentication channel "+ + jww.INFO.Printf("Waiting for authentication channel"+ " confirmation with partner %s", recipientID) scnt := uint(0) waitSecs := viper.GetUint("auth-timeout") @@ -249,6 +254,7 @@ var rootCmd = &cobra.Command{ } } fmt.Printf("Received %d\n", receiveCnt) + err = client.StopNetworkFollower(5 * time.Second) if err != nil { jww.WARN.Printf( @@ -322,12 +328,12 @@ func createClient() *api.Client { } else { if userIDprefix != "" { err = api.NewVanityClient(string(ndfJSON), storeDir, - []byte(pass), regCode, userIDprefix) + []byte(pass), regCode, userIDprefix) } else { err = api.NewClient(string(ndfJSON), storeDir, - []byte(pass), regCode) + []byte(pass), regCode) } - + } if err != nil { diff --git a/cmd/version.go b/cmd/version.go index 2a7eca16517f8e7f3d5c547ea35dbc472d4607c5..67b545b53c3ed7e36e85abccbeadf817bc52dc25 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -18,7 +18,7 @@ import ( ) // Change this value to set the version for this build -const currentVersion = "2.3.0" +const currentVersion = "2.5.0" func Version() string { out := fmt.Sprintf("Elixxir Client v%s -- %s\n\n", api.SEMVER, diff --git a/go.mod b/go.mod index c0ce181a7926d457116daa6b0a14bef1801bed98..7142c43f2d9fa81eacd6878befdc445e5d8f8a1f 100644 --- a/go.mod +++ b/go.mod @@ -17,14 +17,15 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 github.com/spf13/viper v1.7.1 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 - gitlab.com/elixxir/comms v0.0.4-0.20210413160356-853e51fc18e5 - gitlab.com/elixxir/crypto v0.0.7-0.20210412231025-6f75c577f803 + gitlab.com/elixxir/comms v0.0.4-0.20210506225017-37485f5ba063 + gitlab.com/elixxir/crypto v0.0.7-0.20210506223047-3196e4301110 gitlab.com/elixxir/ekv v0.1.5 - gitlab.com/elixxir/primitives v0.0.3-0.20210409190923-7bf3cd8d97e7 - gitlab.com/xx_network/comms v0.0.4-0.20210409202820-eb3dca6571d3 - gitlab.com/xx_network/crypto v0.0.5-0.20210405224157-2b1f387b42c1 - gitlab.com/xx_network/primitives v0.0.4-0.20210402222416-37c1c4d3fac4 + gitlab.com/elixxir/primitives v0.0.3-0.20210504210415-34cf31c2816e + gitlab.com/xx_network/comms v0.0.4-0.20210505205155-48daa8448ad7 + gitlab.com/xx_network/crypto v0.0.5-0.20210504210244-9ddabbad25fd + gitlab.com/xx_network/primitives v0.0.4-0.20210504205835-db68f11de78a golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 + golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/sys v0.0.0-20210403161142-5e06dd20ab57 // indirect google.golang.org/genproto v0.0.0-20210105202744-fe13368bc0e1 // indirect google.golang.org/grpc v1.34.0 // indirect diff --git a/go.sum b/go.sum index 2c9ea3fd15a1237eb861a5d9eda5b25d0d4a42db..c02d426f81e92c2cecddcaabb483586a412caa9e 100644 --- a/go.sum +++ b/go.sum @@ -253,22 +253,20 @@ github.com/zeebo/pcg v1.0.0 h1:dt+dx+HvX8g7Un32rY9XWoYnd0NmKmrIzpHF7qiTDj0= github.com/zeebo/pcg v1.0.0/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 h1:Gi6rj4mAlK0BJIk1HIzBVMjWNjIUfstrsXC2VqLYPcA= gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228/go.mod h1:H6jztdm0k+wEV2QGK/KYA+MY9nj9Zzatux/qIvDDv3k= -gitlab.com/elixxir/comms v0.0.4-0.20210409192302-249b5af3dbc8 h1:k9BLWNw7CHwH4H3gNWA0Q/BXNg7923AFflWJtYZr5z4= -gitlab.com/elixxir/comms v0.0.4-0.20210409192302-249b5af3dbc8/go.mod h1:/y5QIivolXMa6TO+ZqFWAV49wxlXXxUCqZH9Zi82kXU= -gitlab.com/elixxir/comms v0.0.4-0.20210413160356-853e51fc18e5 h1:Q/+lhZpIDQdIKy9aXNCLkCk8AavFE7HAuaila0sv5mw= -gitlab.com/elixxir/comms v0.0.4-0.20210413160356-853e51fc18e5/go.mod h1:0XsJ63n7knUeSX9BDKQG7xGtX6w0l5WsfplSsMbP9iM= +gitlab.com/elixxir/comms v0.0.4-0.20210505205202-1d4c18a7fcb2 h1:8aL4V7FaKkDb5iPdJ1rlFFhrHrLWUtbmBjw4BysXzEA= +gitlab.com/elixxir/comms v0.0.4-0.20210505205202-1d4c18a7fcb2/go.mod h1:VN0fNE7GFMrkZwRGnqA7fNNRAXDA4CCP6su/FQQ68RI= +gitlab.com/elixxir/comms v0.0.4-0.20210506161214-6371db79ce6f h1:0hvU+6Y+JGFnBu8ZSMk0ukNuYg+GAnVKD8Yo4VwSdao= +gitlab.com/elixxir/comms v0.0.4-0.20210506161214-6371db79ce6f/go.mod h1:7ff+A4Nom55mKiRW7qWsN7LDjGay4OZwiaaIVXZ4hdk= +gitlab.com/elixxir/comms v0.0.4-0.20210506225017-37485f5ba063 h1:9A2FT1IDzb9E0HaEEcRMAZEVRM4SMXpklYvS6owSyIk= +gitlab.com/elixxir/comms v0.0.4-0.20210506225017-37485f5ba063/go.mod h1:KHV+lNKhcsXoor1KQizUHhCuHugnquldrAR8UU5PNKU= gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4 h1:28ftZDeYEko7xptCZzeFWS1Iam95dj46TWFVVlKmw6A= gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo+K0N2GwRRpaNr+tKXMIOVWzmyUD0SbOu2c= gitlab.com/elixxir/crypto v0.0.3 h1:znCt/x2bL4y8czTPaaFkwzdgSgW3BJc/1+dxyf1jqVw= gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA= -gitlab.com/elixxir/crypto v0.0.7-0.20210409192145-eab67f2f8931 h1:kY/qBfjrZTFHJnvM1IcxB03+ZQL4+ESUjV4I4kCxoE8= -gitlab.com/elixxir/crypto v0.0.7-0.20210409192145-eab67f2f8931/go.mod h1:ZktO3MT3oNo+g2Nq0GuC3ebJWJphh7t5KwwDDGBegnY= -gitlab.com/elixxir/crypto v0.0.7-0.20210412193049-f3718fa4facb h1:9CT5f+nV4sisutLx8Z3BAEiqjktcCL2ZpEqcpmgzyqA= -gitlab.com/elixxir/crypto v0.0.7-0.20210412193049-f3718fa4facb/go.mod h1:ZktO3MT3oNo+g2Nq0GuC3ebJWJphh7t5KwwDDGBegnY= -gitlab.com/elixxir/crypto v0.0.7-0.20210412195114-be927031747a h1:DSYIXSCWrwkyHUs2fJMliI4+Bd9h+WA5PXI78uvhCj4= -gitlab.com/elixxir/crypto v0.0.7-0.20210412195114-be927031747a/go.mod h1:HMMRBuv/yMqB5c31G9OPlOAifOOqGypCyD5v6py+4vo= -gitlab.com/elixxir/crypto v0.0.7-0.20210412231025-6f75c577f803 h1:8sLODlAYRT0Y9NA+uoMoF1qBrBRrW5TikyKAOvyCd+E= -gitlab.com/elixxir/crypto v0.0.7-0.20210412231025-6f75c577f803/go.mod h1:HMMRBuv/yMqB5c31G9OPlOAifOOqGypCyD5v6py+4vo= +gitlab.com/elixxir/crypto v0.0.7-0.20210504210535-3077ddf9984d h1:E16E+gM2jJosFc8YmT2ISGxcfBThG2KAgsAcQXtxSIc= +gitlab.com/elixxir/crypto v0.0.7-0.20210504210535-3077ddf9984d/go.mod h1:pbq80k+R7XXvjyWDqanD2eCJi1ClfESdKS0K8NndoLs= +gitlab.com/elixxir/crypto v0.0.7-0.20210506223047-3196e4301110 h1:4KWUbx1RI5TABBM2omWl5MLW16dwySglz895X2rhSFQ= +gitlab.com/elixxir/crypto v0.0.7-0.20210506223047-3196e4301110/go.mod h1:pbq80k+R7XXvjyWDqanD2eCJi1ClfESdKS0K8NndoLs= gitlab.com/elixxir/ekv v0.1.5 h1:R8M1PA5zRU1HVnTyrtwybdABh7gUJSCvt1JZwUSeTzk= gitlab.com/elixxir/ekv v0.1.5/go.mod h1:e6WPUt97taFZe5PFLPb1Dupk7tqmDCTQu1kkstqJvw4= gitlab.com/elixxir/primitives v0.0.0-20200731184040-494269b53b4d/go.mod h1:OQgUZq7SjnE0b+8+iIAT2eqQF+2IFHn73tOo+aV11mg= @@ -276,25 +274,24 @@ gitlab.com/elixxir/primitives v0.0.0-20200804170709-a1896d262cd9/go.mod h1:p0Vel gitlab.com/elixxir/primitives v0.0.0-20200804182913-788f47bded40/go.mod h1:tzdFFvb1ESmuTCOl1z6+yf6oAICDxH2NPUemVgoNLxc= gitlab.com/elixxir/primitives v0.0.1 h1:q61anawANlNAExfkeQEE1NCsNih6vNV1FFLoUQX6txQ= gitlab.com/elixxir/primitives v0.0.1/go.mod h1:kNp47yPqja2lHSiS4DddTvFpB/4D9dB2YKnw5c+LJCE= -gitlab.com/elixxir/primitives v0.0.3-0.20210409190923-7bf3cd8d97e7 h1:q3cw7WVtD6hDqTi8ydky+yiqJ4RkWp/hkTSNirr9Z6Y= -gitlab.com/elixxir/primitives v0.0.3-0.20210409190923-7bf3cd8d97e7/go.mod h1:h0QHrjrixLNaP24ZXAgDOZXP4eegrQ24BCZPGitg8Jg= +gitlab.com/elixxir/primitives v0.0.3-0.20210504210415-34cf31c2816e h1:6Z5qAqI/xoWYPMVcItUDYEkOe84YWS1FJa+qjWGcJ2c= +gitlab.com/elixxir/primitives v0.0.3-0.20210504210415-34cf31c2816e/go.mod h1:4pNgiFEQQ11hHCXBRQJN1w9AIuKa1HTlPTxs9UYOXFA= gitlab.com/xx_network/comms v0.0.0-20200805174823-841427dd5023/go.mod h1:owEcxTRl7gsoM8c3RQ5KAm5GstxrJp5tn+6JfQ4z5Hw= -gitlab.com/xx_network/comms v0.0.4-0.20210406210737-45d1e87d294a h1:r0mvBjHPBCYEVmhEe6JhLQDc0+dCORf1ejtuZ8IbyKY= -gitlab.com/xx_network/comms v0.0.4-0.20210406210737-45d1e87d294a/go.mod h1:7ciuA+LTE0GC7upviGbyyb2hrpJG9Pnq2cc5oz2N5Ss= -gitlab.com/xx_network/comms v0.0.4-0.20210409202820-eb3dca6571d3 h1:0o9kveRSEQ9ykRh/hd+z9Iq53YNvFArW1RQ6ICdAG5g= -gitlab.com/xx_network/comms v0.0.4-0.20210409202820-eb3dca6571d3/go.mod h1:7ciuA+LTE0GC7upviGbyyb2hrpJG9Pnq2cc5oz2N5Ss= +gitlab.com/xx_network/comms v0.0.4-0.20210505204621-a93ded09b1ff/go.mod h1:RkNZ0CjeXKRhEFdUeAdCAF6QuK8sO1j2bUg9oqK0OEA= +gitlab.com/xx_network/comms v0.0.4-0.20210505205155-48daa8448ad7 h1:0oQfe8YZ51kYKEj1w9UN2ls0Kp2AHRO6CUbkF/T/UH4= +gitlab.com/xx_network/comms v0.0.4-0.20210505205155-48daa8448ad7/go.mod h1:RkNZ0CjeXKRhEFdUeAdCAF6QuK8sO1j2bUg9oqK0OEA= gitlab.com/xx_network/crypto v0.0.3/go.mod h1:DF2HYvvCw9wkBybXcXAgQMzX+MiGbFPjwt3t17VRqRE= gitlab.com/xx_network/crypto v0.0.4 h1:lpKOL5mTJ2awWMfgBy30oD/UvJVrWZzUimSHlOdZZxo= gitlab.com/xx_network/crypto v0.0.4/go.mod h1:+lcQEy+Th4eswFgQDwT0EXKp4AXrlubxalwQFH5O0Mk= -gitlab.com/xx_network/crypto v0.0.5-0.20210405224157-2b1f387b42c1 h1:4Hrphjtqn3vO8LI872YwVKy5dCFJdD5u0dE4O2QCZqU= -gitlab.com/xx_network/crypto v0.0.5-0.20210405224157-2b1f387b42c1/go.mod h1:CUhRpioyLaKIylg+LIyZX1rhOmFaEXQQ6esNycx9dcA= +gitlab.com/xx_network/crypto v0.0.5-0.20210504210244-9ddabbad25fd h1:jSY1ogxa2/MXthD8jadGr7IYBL4vXQID3VZp1g0GWec= +gitlab.com/xx_network/crypto v0.0.5-0.20210504210244-9ddabbad25fd/go.mod h1:bAqc5+q2K9OXWceHkZX+VnneSKlsSeg+G98O5S4Y2cA= gitlab.com/xx_network/primitives v0.0.0-20200803231956-9b192c57ea7c/go.mod h1:wtdCMr7DPePz9qwctNoAUzZtbOSHSedcK++3Df3psjA= gitlab.com/xx_network/primitives v0.0.0-20200804183002-f99f7a7284da h1:CCVslUwNC7Ul7NG5nu3ThGTSVUt1TxNRX+47f5TUwnk= gitlab.com/xx_network/primitives v0.0.0-20200804183002-f99f7a7284da/go.mod h1:OK9xevzWCaPO7b1wiluVJGk7R5ZsuC7pHY5hteZFQug= gitlab.com/xx_network/primitives v0.0.2 h1:r45yKenJ9e7PylI1ZXJ1Es09oYNaYXjxVy9+uYlwo7Y= gitlab.com/xx_network/primitives v0.0.2/go.mod h1:cs0QlFpdMDI6lAo61lDRH2JZz+3aVkHy+QogOB6F/qc= -gitlab.com/xx_network/primitives v0.0.4-0.20210402222416-37c1c4d3fac4 h1:YPYTKF0zQf08y0eQrjQP01C/EWQTypdqawjZPr5c6rc= -gitlab.com/xx_network/primitives v0.0.4-0.20210402222416-37c1c4d3fac4/go.mod h1:9imZHvYwNFobxueSvVtHneZLk9wTK7HQTzxPm+zhFhE= +gitlab.com/xx_network/primitives v0.0.4-0.20210504205835-db68f11de78a h1:Op+Dfm/Swtrs6Lgo/ro28SrCrftrfQtK9a3/EOoXYAo= +gitlab.com/xx_network/primitives v0.0.4-0.20210504205835-db68f11de78a/go.mod h1:9imZHvYwNFobxueSvVtHneZLk9wTK7HQTzxPm+zhFhE= gitlab.com/xx_network/ring v0.0.2 h1:TlPjlbFdhtJrwvRgIg4ScdngMTaynx/ByHBRZiXCoL0= gitlab.com/xx_network/ring v0.0.2/go.mod h1:aLzpP2TiZTQut/PVHR40EJAomzugDdHXetbieRClXIM= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= diff --git a/interfaces/IsRunning.go b/interfaces/IsRunning.go new file mode 100644 index 0000000000000000000000000000000000000000..950b842864712858b3d1dcf769f4765e5af86159 --- /dev/null +++ b/interfaces/IsRunning.go @@ -0,0 +1,8 @@ +package interfaces + +// this interface is used to allow the follower to to be stopped later if it +// fails + +type Running interface{ + IsRunning()bool +} diff --git a/interfaces/networkManager.go b/interfaces/networkManager.go index da548c855e48a30a5fff0895842f04975e00c88b..1daa7d38104731870459b1323c8d918401a398fd 100644 --- a/interfaces/networkManager.go +++ b/interfaces/networkManager.go @@ -10,6 +10,7 @@ package interfaces import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/comms/network" "gitlab.com/elixxir/crypto/e2e" @@ -24,6 +25,7 @@ type NetworkManager interface { SendCMIX(message format.Message, recipient *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) GetInstance() *network.Instance GetHealthTracker() HealthTracker + GetSender() *gateway.Sender Follow(report ClientErrorReport) (stoppable.Stoppable, error) CheckGarbledMessages() InProgressRegistrations() int diff --git a/interfaces/params/network.go b/interfaces/params/network.go index b84c38a09310f13d13a5cd74d4ddfb83c62fc717..24af3d540f2e7824328641c07ddf325c37e86dfc 100644 --- a/interfaces/params/network.go +++ b/interfaces/params/network.go @@ -35,13 +35,13 @@ type Network struct { func GetDefaultNetwork() Network { n := Network{ - TrackNetworkPeriod: 100 * time.Millisecond, - MaxCheckedRounds: 500, - RegNodesBufferLen: 500, - NetworkHealthTimeout: 30 * time.Second, - E2EParams: GetDefaultE2ESessionParams(), + TrackNetworkPeriod: 100 * time.Millisecond, + MaxCheckedRounds: 500, + RegNodesBufferLen: 500, + NetworkHealthTimeout: 30 * time.Second, + E2EParams: GetDefaultE2ESessionParams(), ParallelNodeRegistrations: 8, - KnownRoundsThreshold: 1500, //5 rounds/sec * 60 sec/min * 5 min + KnownRoundsThreshold: 1500, //5 rounds/sec * 60 sec/min * 5 min } n.Rounds = GetDefaultRounds() n.Messages = GetDefaultMessage() diff --git a/interfaces/params/rounds.go b/interfaces/params/rounds.go index 87fa53fc04517b8262ffbc8be624b3989dddeb91..07c4c3c25d83f3c115263bef0269ab1c9226c7ee 100644 --- a/interfaces/params/rounds.go +++ b/interfaces/params/rounds.go @@ -39,9 +39,9 @@ func GetDefaultRounds() Rounds { HistoricalRoundsPeriod: 100 * time.Millisecond, NumMessageRetrievalWorkers: 8, - HistoricalRoundsBufferLen: 1000, - LookupRoundsBufferLen: 2000, - ForceHistoricalRounds: false, + HistoricalRoundsBufferLen: 1000, + LookupRoundsBufferLen: 2000, + ForceHistoricalRounds: false, MaxHistoricalRoundsRetries: 3, } } diff --git a/keyExchange/exchange.go b/keyExchange/exchange.go index d94e8800f60d440c3c0c8b41564c04f4ada32ba7..42b7a177e10595bf45dd1b11fe7715e3a0b0c4f2 100644 --- a/keyExchange/exchange.go +++ b/keyExchange/exchange.go @@ -32,7 +32,7 @@ func Start(switchboard *switchboard.Switchboard, sess *storage.Session, net inte // create the trigger stoppable triggerStop := stoppable.NewSingle(keyExchangeTriggerName) - cleanupTrigger := func(){ + cleanupTrigger := func() { switchboard.Unregister(triggerID) } @@ -46,7 +46,7 @@ func Start(switchboard *switchboard.Switchboard, sess *storage.Session, net inte // register the confirm stoppable confirmStop := stoppable.NewSingle(keyExchangeConfirmName) - cleanupConfirm := func(){ + cleanupConfirm := func() { switchboard.Unregister(confirmID) } diff --git a/keyExchange/utils_test.go b/keyExchange/utils_test.go index a1447b990744f92b063420888ea1c96e1d9fb5dd..a23e673f3a258164e4e2c407defbb5dabd5394d9 100644 --- a/keyExchange/utils_test.go +++ b/keyExchange/utils_test.go @@ -13,6 +13,7 @@ import ( "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage/e2e" @@ -103,6 +104,10 @@ func (t *testNetworkManagerGeneric) InProgressRegistrations() int { return 0 } +func (t *testNetworkManagerGeneric) GetSender() *gateway.Sender { + return nil +} + func InitTestingContextGeneric(i interface{}) (*storage.Session, interfaces.NetworkManager) { switch i.(type) { case *testing.T, *testing.M, *testing.B, *testing.PB: @@ -210,6 +215,10 @@ func (t *testNetworkManagerFullExchange) InProgressRegistrations() int { return 0 } +func (t *testNetworkManagerFullExchange) GetSender() *gateway.Sender { + return nil +} + func InitTestingContextFullExchange(i interface{}) (*storage.Session, *switchboard.Switchboard, interfaces.NetworkManager) { switch i.(type) { case *testing.T, *testing.M, *testing.B, *testing.PB: diff --git a/network/checkedRounds.go b/network/checkedRounds.go deleted file mode 100644 index 8ffa9ba39804fd70017123859cf60137f12f8e78..0000000000000000000000000000000000000000 --- a/network/checkedRounds.go +++ /dev/null @@ -1,75 +0,0 @@ -package network - -import ( - "container/list" - "crypto/md5" - "gitlab.com/elixxir/client/storage/reception" - "gitlab.com/xx_network/primitives/id" -) - - -type idFingerprint [16]byte - -type checkedRounds struct{ - lookup map[idFingerprint]*checklist -} - -type checklist struct{ - m map[id.Round]interface{} - l *list.List -} - -func newCheckedRounds()*checkedRounds{ - return &checkedRounds{ - lookup: make(map[idFingerprint]*checklist), - } -} - -func (cr *checkedRounds)Check(identity reception.IdentityUse, rid id.Round)bool{ - idFp := getIdFingerprint(identity) - cl, exists := cr.lookup[idFp] - if !exists{ - cl = &checklist{ - m: make(map[id.Round]interface{}), - l: list.New().Init(), - } - cr.lookup[idFp]=cl - } - - if _, exists := cl.m[rid]; !exists{ - cl.m[rid] = nil - cl.l.PushBack(rid) - return true - } - return false -} - -func (cr *checkedRounds)Prune(identity reception.IdentityUse, earliestAllowed id.Round){ - idFp := getIdFingerprint(identity) - cl, exists := cr.lookup[idFp] - if !exists { - return - } - - e := cl.l.Front() - for e!=nil { - if e.Value.(id.Round)<earliestAllowed{ - delete(cl.m,e.Value.(id.Round)) - lastE := e - e = e.Next() - cl.l.Remove(lastE) - }else{ - break - } - } -} - -func getIdFingerprint(identity reception.IdentityUse)idFingerprint{ - h := md5.New() - h.Write(identity.EphId[:]) - h.Write(identity.Source[:]) - - fp := idFingerprint{} - copy(fp[:], h.Sum(nil)) - return fp -} \ No newline at end of file diff --git a/network/ephemeral/testutil.go b/network/ephemeral/testutil.go index fbcc4e3e8d031097f4d9447c0904f537451b3d40..d2708b0c71c7d112151e4494d3ef631a78adaf62 100644 --- a/network/ephemeral/testutil.go +++ b/network/ephemeral/testutil.go @@ -8,6 +8,7 @@ package ephemeral import ( + "gitlab.com/elixxir/client/network/gateway" "testing" jww "github.com/spf13/jwalterweatherman" @@ -79,6 +80,10 @@ func (t *testNetworkManager) InProgressRegistrations() int { return 0 } +func (t *testNetworkManager) GetSender() *gateway.Sender { + return nil +} + func NewTestNetworkManager(i interface{}) interfaces.NetworkManager { switch i.(type) { case *testing.T, *testing.M, *testing.B: diff --git a/network/ephemeral/tracker.go b/network/ephemeral/tracker.go index 2f19604c49faea57cb58b5cf7caf3476af4d9eab..7292f73d8cfa8741c4bac631dd9c43ab3859b53d 100644 --- a/network/ephemeral/tracker.go +++ b/network/ephemeral/tracker.go @@ -61,6 +61,12 @@ func track(session *storage.Session, ourId *id.ID, stop *stoppable.Single) { for true { now := netTime.Now() + + //hack for inconsistent time on android + if now.Sub(lastCheck) <=0{ + now = lastCheck.Add(time.Nanosecond) + } + // Generates the IDs since the last track protoIds, err := ephemeral.GetIdsByRange(ourId, receptionStore.GetIDSize(), now, now.Sub(lastCheck)) diff --git a/network/follow.go b/network/follow.go index b3edc5a686bd751b7169f9208bdb219679f7b47a..71eee53edaf362fdad6cfdeb700b74cdf96fbf94 100644 --- a/network/follow.go +++ b/network/follow.go @@ -23,22 +23,17 @@ package network // instance import ( - "bytes" "fmt" - "math" - "sync/atomic" - "time" - jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/rounds" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/primitives/knownRounds" - "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" + "sync/atomic" + "time" ) const debugTrackPeriod = 1 * time.Minute @@ -51,7 +46,7 @@ type followNetworkComms interface { // followNetwork polls the network to get updated on the state of nodes, the // round status, and informs the client when messages can be retrieved. -func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-chan struct{}) { +func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-chan struct{}, isRunning interfaces.Running) { ticker := time.NewTicker(m.param.TrackNetworkPeriod) TrackTicker := time.NewTicker(debugTrackPeriod) rng := m.Rng.GetStream() @@ -63,17 +58,23 @@ func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-ch rng.Close() done = true case <-ticker.C: - m.follow(report, rng, m.Comms) + m.follow(report, rng, m.Comms, isRunning) case <-TrackTicker.C: numPolls := atomic.SwapUint64(m.tracker, 0) - jww.INFO.Printf("Polled the network %d times in the " + + jww.INFO.Printf("Polled the network %d times in the "+ "last %s", numPolls, debugTrackPeriod) } + if !isRunning.IsRunning(){ + jww.ERROR.Printf("Killing network follower " + + "due to failed exit") + return + } } } // executes each iteration of the follower -func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, comms followNetworkComms) { +func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, + comms followNetworkComms, isRunning interfaces.Running) { //get the identity we will poll for identity, err := m.Session.Reception().GetIdentity(rng) @@ -84,14 +85,6 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, atomic.AddUint64(m.tracker, 1) - //randomly select a gateway to poll - //TODO: make this more intelligent - gwHost, err := gateway.Get(m.Instance.GetPartialNdf().Get(), comms, rng) - if err != nil { - jww.FATAL.Panicf("Failed to follow network, NDF has corrupt "+ - "data: %s", err) - } - // Get client version for poll version := m.Session.GetClientVersion() @@ -106,23 +99,32 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, EndTimestamp: identity.EndRequest.UnixNano(), ClientVersion: []byte(version.String()), } - jww.TRACE.Printf("Executing poll for %v(%s) range: %s-%s(%s) from %s", - identity.EphId.Int64(), identity.Source, identity.StartRequest, - identity.EndRequest, identity.EndRequest.Sub(identity.StartRequest), gwHost.GetId()) - pollResp, err := comms.SendPoll(gwHost, &pollReq) + result, err := m.GetSender().SendToAny(func(host *connect.Host) (interface{}, error) { + jww.DEBUG.Printf("Executing poll for %v(%s) range: %s-%s(%s) from %s", + identity.EphId.Int64(), identity.Source, identity.StartRequest, + identity.EndRequest, identity.EndRequest.Sub(identity.StartRequest), host.GetId()) + return comms.SendPoll(host, &pollReq) + }) + if !isRunning.IsRunning(){ + jww.ERROR.Printf("Killing network follower " + + "due to failed exit") + return + } if err != nil { - if report!=nil{ + if report != nil { report( "NetworkFollower", - fmt.Sprintf("Failed to poll network, \"%s\", Gateway: %s", err.Error(), gwHost.String()), + fmt.Sprintf("Failed to poll network, \"%s\":", err.Error()), fmt.Sprintf("%+v", err), ) } - jww.ERROR.Printf("Unable to poll %s for NDF: %+v", gwHost, err) + jww.ERROR.Printf("Unable to poll gateways: %+v", err) return } + pollResp := result.(*pb.GatewayPollResponse) + // ---- Process Network State Update Data ---- gwRoundsState := &knownRounds.KnownRounds{} err = gwRoundsState.Unmarshal(pollResp.KnownRounds) @@ -141,11 +143,8 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, return } - err = m.Instance.UpdateGatewayConnections() - if err != nil { - jww.ERROR.Printf("Unable to update gateway connections: %+v", err) - return - } + // update gateway connections + m.GetSender().UpdateNdf(m.GetInstance().GetPartialNdf().Get()) } //check that the stored address space is correct @@ -161,51 +160,52 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, return } + // TODO: ClientErr needs to know the source of the error and it doesn't yet // Iterate over ClientErrors for each RoundUpdate - for _, update := range pollResp.Updates { - - // Ignore irrelevant updates - if update.State != uint32(states.COMPLETED) && update.State != uint32(states.FAILED) { - continue - } - - for _, clientErr := range update.ClientErrors { - - // If this Client appears in the ClientError - if bytes.Equal(clientErr.ClientId, m.Session.GetUser().TransmissionID.Marshal()) { - - // Obtain relevant NodeGateway information - nGw, err := m.Instance.GetNodeAndGateway(gwHost.GetId()) - if err != nil { - jww.ERROR.Printf("Unable to get NodeGateway: %+v", err) - return - } - nid, err := nGw.Node.GetNodeId() - if err != nil { - jww.ERROR.Printf("Unable to get NodeID: %+v", err) - return - } - - // FIXME: Should be able to trigger proper type of round event - // FIXME: without mutating the RoundInfo. Signature also needs verified - // FIXME: before keys are deleted - update.State = uint32(states.FAILED) - rnd, err := m.Instance.GetWrappedRound(id.Round(update.ID)) - if err != nil { - jww.ERROR.Printf("Failed to report client error: " + - "Could not get round for event triggering: " + - "Unable to get round %d from instance: %+v", - id.Round(update.ID), err) - break - } - m.Instance.GetRoundEvents().TriggerRoundEvent(rnd) - - // delete all existing keys and trigger a re-registration with the relevant Node - m.Session.Cmix().Remove(nid) - m.Instance.GetAddGatewayChan() <- nGw - } - } - } + //for _, update := range pollResp.Updates { + // + // // Ignore irrelevant updates + // if update.State != uint32(states.COMPLETED) && update.State != uint32(states.FAILED) { + // continue + // } + // + // for _, clientErr := range update.ClientErrors { + // // If this Client appears in the ClientError + // if bytes.Equal(clientErr.ClientId, m.Session.GetUser().TransmissionID.Marshal()) { + // + // // Obtain relevant NodeGateway information + // // TODO ??? + // nGw, err := m.Instance.GetNodeAndGateway(gwHost.GetId()) + // if err != nil { + // jww.ERROR.Printf("Unable to get NodeGateway: %+v", err) + // return + // } + // nid, err := nGw.Node.GetNodeId() + // if err != nil { + // jww.ERROR.Printf("Unable to get NodeID: %+v", err) + // return + // } + // + // // FIXME: Should be able to trigger proper type of round event + // // FIXME: without mutating the RoundInfo. Signature also needs verified + // // FIXME: before keys are deleted + // update.State = uint32(states.FAILED) + // rnd, err := m.Instance.GetWrappedRound(id.Round(update.ID)) + // if err != nil { + // jww.ERROR.Printf("Failed to report client error: "+ + // "Could not get round for event triggering: "+ + // "Unable to get round %d from instance: %+v", + // id.Round(update.ID), err) + // break + // } + // m.Instance.GetRoundEvents().TriggerRoundEvent(rnd) + // + // // delete all existing keys and trigger a re-registration with the relevant Node + // m.Session.Cmix().Remove(nid) + // m.Instance.GetAddGatewayChan() <- nGw + // } + // } + //} } // ---- Identity Specific Round Processing ----- @@ -228,20 +228,11 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, return } - firstRound := id.Round(math.MaxUint64) - lastRound := id.Round(0) - //prepare the filter objects for processing - filterList := make([]*rounds.RemoteFilter, filtersEnd-filtersStart) + filterList := make([]*rounds.RemoteFilter, 0, filtersEnd-filtersStart) for i := filtersStart; i < filtersEnd; i++ { if len(pollResp.Filters.Filters[i].Filter) != 0 { - filterList[i-filtersStart] = rounds.NewRemoteFilter(pollResp.Filters.Filters[i]) - if filterList[i-filtersStart].FirstRound() < firstRound { - firstRound = filterList[i-filtersStart].FirstRound() - } - if filterList[i-filtersStart].LastRound() > lastRound { - lastRound = filterList[i-filtersStart].LastRound() - } + filterList = append(filterList, rounds.NewRemoteFilter(pollResp.Filters.Filters[i])) } } @@ -249,7 +240,7 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, // are messages waiting in rounds and then sends signals to the appropriate // handling threads roundChecker := func(rid id.Round) bool { - return rounds.Checker(rid, filterList) + return rounds.Checker(rid, filterList, identity.CR) } // move the earliest unknown round tracker forward to the earliest @@ -264,38 +255,32 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, earliestRemaining, roundsWithMessages, roundsUnknown := gwRoundsState.RangeUnchecked(updated, m.param.KnownRoundsThreshold, roundChecker) _, changed := identity.ER.Set(earliestRemaining) - if changed{ + if changed { jww.TRACE.Printf("External returns of RangeUnchecked: %d, %v, %v", earliestRemaining, roundsWithMessages, roundsUnknown) jww.DEBUG.Printf("New Earliest Remaining: %d", earliestRemaining) } - roundsWithMessages2 := identity.UR.Iterate(func(rid id.Round)bool{ - if gwRoundsState.Checked(rid){ - return rounds.Checker(rid, filterList) + roundsWithMessages2 := identity.UR.Iterate(func(rid id.Round) bool { + if gwRoundsState.Checked(rid) { + return rounds.Checker(rid, filterList, identity.CR) } return false }, roundsUnknown) - for _, rid := range roundsWithMessages{ - if m.checked.Check(identity, rid){ + for _, rid := range roundsWithMessages { + if identity.CR.Check(rid) { m.round.GetMessagesFromRound(rid, identity) } } - for _, rid := range roundsWithMessages2{ - m.round.GetMessagesFromRound(rid, identity) - } - - earliestToKeep := getEarliestToKeep(m.param.KnownRoundsThreshold, - gwRoundsState.GetLastChecked()) - - m.checked.Prune(identity, earliestToKeep) - -} + identity.CR.Prune() + err = identity.CR.SaveCheckedRounds() + if err != nil { + jww.ERROR.Printf("Could not save rounds for identity %d (%s): %+v", + identity.EphId.Int64(), identity.Source, err) + } -func getEarliestToKeep(delta uint, lastchecked id.Round)id.Round{ - if uint(lastchecked)<delta{ - return 0 + for _, rid := range roundsWithMessages2 { + m.round.GetMessagesFromRound(rid, identity) } - return lastchecked - id.Round(delta) } diff --git a/network/gateway/gateway.go b/network/gateway/gateway.go deleted file mode 100644 index f4a16cf2e78a5be6fe9cadbcd496b2cd7bb95b79..0000000000000000000000000000000000000000 --- a/network/gateway/gateway.go +++ /dev/null @@ -1,112 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package gateway - -import ( - "encoding/binary" - "fmt" - "github.com/pkg/errors" - "gitlab.com/elixxir/comms/mixmessages" - "gitlab.com/elixxir/crypto/shuffle" - "gitlab.com/xx_network/comms/connect" - "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/ndf" - "io" - "math" -) - -type HostGetter interface { - GetHost(hostId *id.ID) (*connect.Host, bool) -} - -// Get the Host of a random gateway in the NDF -func Get(ndf *ndf.NetworkDefinition, hg HostGetter, rng io.Reader) (*connect.Host, error) { - gwLen := uint32(len(ndf.Gateways)) - if gwLen == 0 { - return nil, errors.Errorf("no gateways available") - } - - gwIdx := ReadRangeUint32(0, gwLen, rng) - gwID, err := id.Unmarshal(ndf.Nodes[gwIdx].ID) - if err != nil { - return nil, errors.WithMessage(err, "failed to get Gateway") - } - gwID.SetType(id.Gateway) - gwHost, ok := hg.GetHost(gwID) - if !ok { - return nil, errors.Errorf("host for gateway %s could not be "+ - "retrieved", gwID) - } - return gwHost, nil -} - -// GetAllShuffled returns a shufled list of gateway hosts from the specified round -func GetAllShuffled(hg HostGetter, ri *mixmessages.RoundInfo) ([]*connect.Host, error) { - roundTop := ri.GetTopology() - hosts := make([]*connect.Host, 0) - shuffledList := make([]uint64, 0) - - // Collect all host information from the round - for index, _ := range roundTop { - selectedId, err := id.Unmarshal(roundTop[index]) - if err != nil { - return nil, err - } - - selectedId.SetType(id.Gateway) - - gwHost, ok := hg.GetHost(selectedId) - if !ok { - return nil, errors.Errorf("Could not find host for gateway %s", selectedId) - } - hosts = append(hosts, gwHost) - shuffledList = append(shuffledList, uint64(index)) - } - - returnHosts := make([]*connect.Host, len(hosts)) - - // Shuffle a list corresponding to the valid gateway hosts - shuffle.Shuffle(&shuffledList) - - // Index through the shuffled list, building a list - // of shuffled gateways from the round - for index, shuffledIndex := range shuffledList { - returnHosts[index] = hosts[shuffledIndex] - } - - return returnHosts, nil - -} - -// ReadUint32 reads an integer from an io.Reader (which should be a CSPRNG) -func ReadUint32(rng io.Reader) uint32 { - var rndBytes [4]byte - i, err := rng.Read(rndBytes[:]) - if i != 4 || err != nil { - panic(fmt.Sprintf("cannot read from rng: %+v", err)) - } - return binary.BigEndian.Uint32(rndBytes[:]) -} - -// ReadRangeUint32 reduces an integer from 0, MaxUint32 to the range start, end -func ReadRangeUint32(start, end uint32, rng io.Reader) uint32 { - size := end - start - // note we could just do the part inside the () here, but then extra - // can == size which means a little bit of range is wastes, either - // choice seems negligible so we went with the "more correct" - extra := (math.MaxUint32%size + 1) % size - limit := math.MaxUint32 - extra - // Loop until we read something inside the limit - for { - res := ReadUint32(rng) - if res > limit { - continue - } - return (res % size) + start - } -} diff --git a/network/gateway/hostPool.go b/network/gateway/hostPool.go new file mode 100644 index 0000000000000000000000000000000000000000..34b381c67a928f0bce6c9e6d14e3fa9c87458644 --- /dev/null +++ b/network/gateway/hostPool.go @@ -0,0 +1,467 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +// Package gateway Handles functionality related to providing Gateway connect.Host objects +// for message sending to the rest of the client repo +// Used to minimize # of open connections on mobile clients + +package gateway + +import ( + "encoding/binary" + "fmt" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/ndf" + "golang.org/x/net/context" + "io" + "math" + "strings" + "sync" + "time" +) + +// List of errors that initiate a Host replacement +var errorsList = []string{context.DeadlineExceeded.Error(), "connection refused", "host disconnected", + "transport is closing", "all SubConns are in TransientFailure", "Last try to connect", + ndf.NO_NDF, "Host is in cool down"} + +// HostManager Interface allowing storage and retrieval of Host objects +type HostManager interface { + GetHost(hostId *id.ID) (*connect.Host, bool) + AddHost(hid *id.ID, address string, cert []byte, params connect.HostParams) (host *connect.Host, err error) + RemoveHost(hid *id.ID) +} + +// HostPool Handles providing hosts to the Client +type HostPool struct { + hostMap map[id.ID]uint32 // map key to its index in the slice + hostList []*connect.Host // each index in the slice contains the value + hostMux sync.RWMutex // Mutex for the above map/list combination + + ndfMap map[id.ID]int // map gateway ID to its index in the ndf + ndf *ndf.NetworkDefinition + ndfMux sync.RWMutex + + poolParams PoolParams + rng *fastRNG.StreamGenerator + storage *storage.Session + manager HostManager + addGatewayChan chan network.NodeGateway +} + +// PoolParams Allows configuration of HostPool parameters +type PoolParams struct { + MaxPoolSize uint32 // Maximum number of Hosts in the HostPool + PoolSize uint32 // Allows override of HostPool size. Set to zero for dynamic size calculation + // TODO: Move up a layer + ProxyAttempts uint32 // How many proxies will be used in event of send failure + HostParams connect.HostParams // Parameters for the creation of new Host objects +} + +// DefaultPoolParams Returns a default set of PoolParams +func DefaultPoolParams() PoolParams { + p := PoolParams{ + MaxPoolSize: 30, + ProxyAttempts: 5, + PoolSize: 0, + HostParams: connect.GetDefaultHostParams(), + } + p.HostParams.MaxRetries = 1 + p.HostParams.AuthEnabled = false + p.HostParams.EnableCoolOff = true + p.HostParams.NumSendsBeforeCoolOff = 1 + p.HostParams.CoolOffTimeout = 5 * time.Minute + p.HostParams.SendTimeout = 3500 * time.Millisecond + return p +} + +// Build and return new HostPool object +func newHostPool(poolParams PoolParams, rng *fastRNG.StreamGenerator, ndf *ndf.NetworkDefinition, getter HostManager, + storage *storage.Session, addGateway chan network.NodeGateway) (*HostPool, error) { + var err error + + // Determine size of HostPool + if poolParams.PoolSize == 0 { + poolParams.PoolSize, err = getPoolSize(uint32(len(ndf.Gateways)), poolParams.MaxPoolSize) + if err != nil { + return nil, err + } + } + + result := &HostPool{ + manager: getter, + hostMap: make(map[id.ID]uint32), + hostList: make([]*connect.Host, poolParams.PoolSize), + poolParams: poolParams, + ndf: ndf, + rng: rng, + storage: storage, + addGatewayChan: addGateway, + } + + // Propagate the NDF + err = result.updateConns() + if err != nil { + return nil, err + } + + // Build the initial HostPool and return + for i := 0; i < len(result.hostList); i++ { + err := result.forceReplace(uint32(i)) + if err != nil { + return nil, err + } + } + + jww.INFO.Printf("Initialized HostPool with size: %d/%d", poolParams.PoolSize, len(ndf.Gateways)) + return result, nil +} + +// UpdateNdf Mutates internal ndf to the given ndf +func (h *HostPool) UpdateNdf(ndf *ndf.NetworkDefinition) { + if len(ndf.Gateways) == 0 { + jww.WARN.Printf("Unable to UpdateNdf: no gateways available") + return + } + + h.ndfMux.Lock() + h.ndf = ndf + + h.hostMux.Lock() + err := h.updateConns() + h.hostMux.Unlock() + if err != nil { + jww.ERROR.Printf("Unable to updateConns: %+v", err) + } + h.ndfMux.Unlock() +} + +// Obtain a random, unique list of Hosts of the given length from the HostPool +func (h *HostPool) getAny(length uint32, excluded []*id.ID) []*connect.Host { + if length > h.poolParams.PoolSize { + length = h.poolParams.PoolSize + } + + checked := make(map[uint32]interface{}) // Keep track of Hosts already selected to avoid duplicates + if excluded != nil { + // Add excluded Hosts to already-checked list + for i := range excluded { + gwId := excluded[i] + if idx, ok := h.hostMap[*gwId]; ok { + checked[idx] = nil + } + } + } + + result := make([]*connect.Host, 0, length) + rng := h.rng.GetStream() + h.hostMux.RLock() + for i := uint32(0); i < length; { + // If we've checked the entire HostPool, bail + if uint32(len(checked)) >= h.poolParams.PoolSize { + break + } + + // Check the next HostPool index + gwIdx := readRangeUint32(0, h.poolParams.PoolSize, rng) + if _, ok := checked[gwIdx]; !ok { + result = append(result, h.hostList[gwIdx]) + checked[gwIdx] = nil + i++ + } + } + h.hostMux.RUnlock() + rng.Close() + + return result +} + +// Obtain a specific connect.Host from the manager, irrespective of the HostPool +func (h *HostPool) getSpecific(target *id.ID) (*connect.Host, bool) { + return h.manager.GetHost(target) +} + +// Try to obtain the given targets from the HostPool +// If each is not present, obtain a random replacement from the HostPool +func (h *HostPool) getPreferred(targets []*id.ID) []*connect.Host { + checked := make(map[uint32]interface{}) // Keep track of Hosts already selected to avoid duplicates + length := len(targets) + if length > int(h.poolParams.PoolSize) { + length = int(h.poolParams.PoolSize) + } + result := make([]*connect.Host, length) + + rng := h.rng.GetStream() + h.hostMux.RLock() + for i := 0; i < length; { + if hostIdx, ok := h.hostMap[*targets[i]]; ok { + result[i] = h.hostList[hostIdx] + checked[hostIdx] = nil + i++ + continue + } + + gwIdx := readRangeUint32(0, h.poolParams.PoolSize, rng) + if _, ok := checked[gwIdx]; !ok { + result[i] = h.hostList[gwIdx] + checked[gwIdx] = nil + i++ + } + } + h.hostMux.RUnlock() + rng.Close() + + return result +} + +// Replaces the given hostId in the HostPool if the given hostErr is in errorList +func (h *HostPool) checkReplace(hostId *id.ID, hostErr error) error { + // Check if Host should be replaced + doReplace := false + if hostErr != nil { + for _, errString := range errorsList { + if strings.Contains(hostErr.Error(), errString) { + // Host needs replaced, flag and continue + doReplace = true + break + } + } + } + + if doReplace { + h.hostMux.Lock() + defer h.hostMux.Unlock() + + // If the Host is still in the pool + if oldPoolIndex, ok := h.hostMap[*hostId]; ok { + // Replace it + h.ndfMux.RLock() + err := h.forceReplace(oldPoolIndex) + h.ndfMux.RUnlock() + return err + } + } + return nil +} + +// Replace given Host index with a new, randomly-selected Host from the NDF +func (h *HostPool) forceReplace(oldPoolIndex uint32) error { + rng := h.rng.GetStream() + defer rng.Close() + + // Loop until a replacement Host is found + for { + // Randomly select a new Gw by index in the NDF + ndfIdx := readRangeUint32(0, uint32(len(h.ndf.Gateways)), rng) + jww.DEBUG.Printf("Attempting to replace Host at HostPool %d with Host at NDF %d...", oldPoolIndex, ndfIdx) + + // Use the random ndfIdx to obtain a GwId from the NDF + gwId, err := id.Unmarshal(h.ndf.Gateways[ndfIdx].ID) + if err != nil { + return errors.WithMessage(err, "failed to get Gateway for pruning") + } + + // Verify the new GwId is not already in the hostMap + if _, ok := h.hostMap[*gwId]; !ok { + // If it is a new GwId, replace the old Host with the new Host + return h.replaceHost(gwId, oldPoolIndex) + } + } +} + +// Replace the given slot in the HostPool with a new Gateway with the specified ID +func (h *HostPool) replaceHost(newId *id.ID, oldPoolIndex uint32) error { + // Obtain that GwId's Host object + newHost, ok := h.manager.GetHost(newId) + if !ok { + return errors.Errorf("host for gateway %s could not be "+ + "retrieved", newId) + } + + // Keep track of oldHost for cleanup + oldHost := h.hostList[oldPoolIndex] + + // Use the poolIdx to overwrite the random Host in the corresponding index in the hostList + h.hostList[oldPoolIndex] = newHost + // Use the GwId to keep track of the new random Host's index in the hostList + h.hostMap[*newId] = oldPoolIndex + + // Clean up and move onto next Host + if oldHost != nil { + delete(h.hostMap, *oldHost.GetId()) + go oldHost.Disconnect() + } + jww.DEBUG.Printf("Replaced Host at %d with new Host %s", oldPoolIndex, newId.String()) + return nil +} + +// Force-add the Gateways to the HostPool, each replacing a random Gateway +func (h *HostPool) forceAdd(gwId *id.ID) error { + rng := h.rng.GetStream() + h.hostMux.Lock() + defer h.hostMux.Unlock() + defer rng.Close() + + // Verify the GwId is not already in the hostMap + if _, ok := h.hostMap[*gwId]; ok { + // If it is, skip + return nil + } + + // Randomly select another Gateway in the HostPool for replacement + poolIdx := readRangeUint32(0, h.poolParams.PoolSize, rng) + return h.replaceHost(gwId, poolIdx) +} + +// Updates the internal HostPool with any changes to the NDF +func (h *HostPool) updateConns() error { + // Prepare NDFs for comparison + newMap, err := convertNdfToMap(h.ndf) + if err != nil { + return errors.Errorf("Unable to convert new NDF to set: %+v", err) + } + + // Handle adding Gateways + for gwId, ndfIdx := range newMap { + if _, ok := h.ndfMap[gwId]; !ok { + // If GwId in newMap is not in ndfMap, add the Gateway + h.addGateway(gwId.DeepCopy(), ndfIdx) + } + } + + // Handle removing Gateways + for gwId := range h.ndfMap { + if _, ok := newMap[gwId]; !ok { + // If GwId in ndfMap is not in newMap, remove the Gateway + h.removeGateway(gwId.DeepCopy()) + } + } + + // Update the internal NDF set + h.ndfMap = newMap + return nil +} + +// Takes ndf.Gateways and puts their IDs into a map object +func convertNdfToMap(ndf *ndf.NetworkDefinition) (map[id.ID]int, error) { + result := make(map[id.ID]int) + if ndf == nil { + return result, nil + } + + // Process gateway Id's into set + for i := range ndf.Gateways { + gw := ndf.Gateways[i] + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + return nil, err + } + result[*gwId] = i + } + + return result, nil +} + +// updateConns helper for removing old Gateways +func (h *HostPool) removeGateway(gwId *id.ID) { + h.manager.RemoveHost(gwId) + // If needed, replace the removed Gateway in the HostPool with a new one + if poolIndex, ok := h.hostMap[*gwId]; ok { + err := h.forceReplace(poolIndex) + if err != nil { + jww.ERROR.Printf("Unable to removeGateway: %+v", err) + } + } +} + +// updateConns helper for adding new Gateways +func (h *HostPool) addGateway(gwId *id.ID, ndfIndex int) { + gw := h.ndf.Gateways[ndfIndex] + + //check if the host exists + host, ok := h.manager.GetHost(gwId) + if !ok { + + // Check if gateway ID collides with an existing hard coded ID + if id.CollidesWithHardCodedID(gwId) { + jww.ERROR.Printf("Gateway ID invalid, collides with a "+ + "hard coded ID. Invalid ID: %v", gwId.Marshal()) + } + + // Add the new gateway host + _, err := h.manager.AddHost(gwId, gw.Address, []byte(gw.TlsCertificate), h.poolParams.HostParams) + if err != nil { + jww.ERROR.Printf("Could not add gateway host %s: %+v", gwId, err) + } + + // Send AddGateway event if we do not already possess keys for the GW + if !h.storage.Cmix().Has(gwId) { + ng := network.NodeGateway{ + Node: h.ndf.Nodes[ndfIndex], + Gateway: gw, + } + + select { + case h.addGatewayChan <- ng: + default: + jww.WARN.Printf("Unable to send AddGateway event for id %s", gwId.String()) + } + } + + } else if host.GetAddress() != gw.Address { + host.UpdateAddress(gw.Address) + } +} + +// getPoolSize determines the size of the HostPool based on the size of the NDF +func getPoolSize(ndfLen, maxSize uint32) (uint32, error) { + // Verify the NDF has at least one Gateway for the HostPool + if ndfLen == 0 { + return 0, errors.Errorf("Unable to create HostPool: no gateways available") + } + + // PoolSize = ceil(sqrt(len(ndf,Gateways))) + poolSize := uint32(math.Ceil(math.Sqrt(float64(ndfLen)))) + if poolSize > maxSize { + return maxSize, nil + } + return poolSize, nil +} + +// readUint32 reads an integer from an io.Reader (which should be a CSPRNG) +func readUint32(rng io.Reader) uint32 { + var rndBytes [4]byte + i, err := rng.Read(rndBytes[:]) + if i != 4 || err != nil { + panic(fmt.Sprintf("cannot read from rng: %+v", err)) + } + return binary.BigEndian.Uint32(rndBytes[:]) +} + +// readRangeUint32 reduces an integer from 0, MaxUint32 to the range start, end +func readRangeUint32(start, end uint32, rng io.Reader) uint32 { + size := end - start + // note we could just do the part inside the () here, but then extra + // can == size which means a little bit of range is wastes, either + // choice seems negligible so we went with the "more correct" + extra := (math.MaxUint32%size + 1) % size + limit := math.MaxUint32 - extra + // Loop until we read something inside the limit + for { + res := readUint32(rng) + if res > limit { + continue + } + return (res % size) + start + } +} diff --git a/network/gateway/hostpool_test.go b/network/gateway/hostpool_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8181e6d49f0ea5522e8ae33eb6ac143516c45a26 --- /dev/null +++ b/network/gateway/hostpool_test.go @@ -0,0 +1,832 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package gateway + +import ( + "fmt" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/ndf" + "reflect" + "testing" +) + +// Unit test +func TestNewHostPool(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, "", nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + t.FailNow() + } + + } + + // Call the constructor + _, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } +} + +// Unit test +func TestHostPool_ManageHostPool(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + + // Construct custom params + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + t.FailNow() + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + // Construct a list of new gateways/nodes to add to ndf + newGatewayLen := len(testNdf.Gateways) + newGateways := make([]ndf.Gateway, newGatewayLen) + newNodes := make([]ndf.Node, newGatewayLen) + for i := 0; i < newGatewayLen; i++ { + // Construct gateways + gwId := id.NewIdFromUInt(uint64(100+i), id.Gateway, t) + newGateways[i] = ndf.Gateway{ID: gwId.Bytes()} + // Construct nodes + nodeId := gwId.DeepCopy() + nodeId.SetType(id.Node) + newNodes[i] = ndf.Node{ID: nodeId.Bytes()} + + } + + newNdf := getTestNdf(t) + // Update the ndf, removing some gateways at a cutoff + newNdf.Gateways = newGateways + newNdf.Nodes = newNodes + + testPool.UpdateNdf(newNdf) + + // Check that old gateways are not in pool + for _, ndfGw := range testNdf.Gateways { + gwId, err := id.Unmarshal(ndfGw.ID) + if err != nil { + t.Errorf("Failed to marshal gateway id for %v", ndfGw) + } + if _, ok := testPool.hostMap[*gwId]; ok { + t.Errorf("Expected gateway %v to be removed from pool", gwId) + } + } +} + +// Full happy path test +func TestHostPool_ReplaceHost(t *testing.T) { + manager := newMockManager() + testNdf := getTestNdf(t) + newIndex := uint32(20) + + // Construct a manager (bypass business logic in constructor) + hostPool := &HostPool{ + manager: manager, + hostList: make([]*connect.Host, newIndex+1), + hostMap: make(map[id.ID]uint32), + ndf: testNdf, + } + + /* "Replace" a host with no entry */ + + // Pull a gateway ID from the ndf + gwIdOne, err := id.Unmarshal(testNdf.Gateways[0].ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + + // Add mock gateway to manager + _, err = manager.AddHost(gwIdOne, "", nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + } + + // "Replace" (insert) the host + err = hostPool.replaceHost(gwIdOne, newIndex) + if err != nil { + t.Errorf("Could not replace host: %v", err) + } + + // Check the state of the map has been correctly updated + retrievedIndex, ok := hostPool.hostMap[*gwIdOne] + if !ok { + t.Errorf("Expected insertion of gateway ID into map") + } + if retrievedIndex != newIndex { + t.Errorf("Index pulled from map not expected value."+ + "\n\tExpected: %d"+ + "\n\tReceived: %d", newIndex, retrievedIndex) + } + + // Check that the state of the list list been correctly updated + retrievedHost := hostPool.hostList[newIndex] + if !gwIdOne.Cmp(retrievedHost.GetId()) { + t.Errorf("Id pulled from list is not expected."+ + "\n\tExpected: %s"+ + "\n\tReceived: %s", gwIdOne, retrievedHost.GetId()) + } + + /* Replace the initial host with a new host */ + + // Pull a different gateway ID from the ndf + gwIdTwo, err := id.Unmarshal(testNdf.Gateways[1].ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + + // Add second mock gateway to manager + _, err = manager.AddHost(gwIdTwo, "", nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + } + + // Replace the old host + err = hostPool.replaceHost(gwIdTwo, newIndex) + if err != nil { + t.Errorf("Could not replace host: %v", err) + } + + // Check that the state of the list been correctly updated for new host + retrievedHost = hostPool.hostList[newIndex] + if !gwIdTwo.Cmp(retrievedHost.GetId()) { + t.Errorf("Id pulled from list is not expected."+ + "\n\tExpected: %s"+ + "\n\tReceived: %s", gwIdTwo, retrievedHost.GetId()) + } + + // Check the state of the map has been correctly removed for the old gateway + retrievedOldIndex, ok := hostPool.hostMap[*gwIdOne] + if ok { + t.Errorf("Exoected old gateway to be cleared from map") + } + if retrievedOldIndex != 0 { + t.Errorf("Index pulled from map with old gateway as the key "+ + "was not cleared."+ + "\n\tExpected: %d"+ + "\n\tReceived: %d", 0, retrievedOldIndex) + } + + // Check the state of the map has been correctly updated for the old gateway + retrievedIndex, ok = hostPool.hostMap[*gwIdTwo] + if !ok { + t.Errorf("Expected insertion of gateway ID into map") + } + if retrievedIndex != newIndex { + t.Errorf("Index pulled from map using new gateway as the key "+ + "was not updated."+ + "\n\tExpected: %d"+ + "\n\tReceived: %d", newIndex, retrievedIndex) + } + +} + +// Error path, could not get host +func TestHostPool_ReplaceHost_Error(t *testing.T) { + manager := newMockManager() + + // Construct a manager (bypass business logic in constructor) + hostPool := &HostPool{ + manager: manager, + hostList: make([]*connect.Host, 1), + hostMap: make(map[id.ID]uint32), + } + + // Construct an unknown gateway ID to the manager + gatewayId := id.NewIdFromString("BadGateway", id.Gateway, t) + + err := hostPool.replaceHost(gatewayId, 0) + if err == nil { + t.Errorf("Expected error in happy path: Should not be able to find a host") + } + +} + +// Unit test +func TestHostPool_ForceReplace(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + + // Construct custom params + params := DefaultPoolParams() + params.PoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + t.FailNow() + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + // Add all gateways to hostPool's map + for index, gw := range testNdf.Gateways { + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + + err = testPool.replaceHost(gwId, uint32(index)) + if err != nil { + t.Fatalf("Failed to replace host in set-up: %v", err) + } + } + + oldGatewayIndex := 0 + oldHost := testPool.hostList[oldGatewayIndex] + + // Force replace the gateway at a given index + err = testPool.forceReplace(uint32(oldGatewayIndex)) + if err != nil { + t.Errorf("Failed to force replace: %v", err) + } + + // Ensure that old gateway has been removed from the map + if _, ok := testPool.hostMap[*oldHost.GetId()]; ok { + t.Errorf("Expected old host to be removed from map") + } + + // Ensure we are disconnected from the old host + if isConnected, _ := oldHost.Connected(); isConnected { + t.Errorf("Failed to disconnect from old host %s", oldHost) + } + +} + +// Unit test +func TestHostPool_CheckReplace(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + + // Construct custom params + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) - 5 + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + t.FailNow() + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + // Call check replace + oldGatewayIndex := 0 + oldHost := testPool.hostList[oldGatewayIndex] + expectedError := fmt.Errorf(errorsList[0]) + err = testPool.checkReplace(oldHost.GetId(), expectedError) + if err != nil { + t.Errorf("Failed to check replace: %v", err) + } + + // Ensure that old gateway has been removed from the map + if _, ok := testPool.hostMap[*oldHost.GetId()]; ok { + t.Errorf("Expected old host to be removed from map") + } + + // Ensure we are disconnected from the old host + if isConnected, _ := oldHost.Connected(); isConnected { + t.Errorf("Failed to disconnect from old host %s", oldHost) + } + + // Check that an error not in the global list results in a no-op + goodGatewayIndex := 0 + goodGateway := testPool.hostList[goodGatewayIndex] + unexpectedErr := fmt.Errorf("not in global error list") + err = testPool.checkReplace(oldHost.GetId(), unexpectedErr) + if err != nil { + t.Errorf("Failed to check replace: %v", err) + } + + // Ensure that gateway with an unexpected error was not modified + if _, ok := testPool.hostMap[*goodGateway.GetId()]; !ok { + t.Errorf("Expected gateway with non-expected error to not be modified") + } + + // Ensure gateway host has not been disconnected + if isConnected, _ := oldHost.Connected(); isConnected { + t.Errorf("Should not disconnect from %s", oldHost) + } + +} + +// Unit test +func TestHostPool_UpdateNdf(t *testing.T) { + manager := newMockManager() + testNdf := getTestNdf(t) + newIndex := uint32(20) + + // Construct a manager (bypass business logic in constructor) + hostPool := &HostPool{ + manager: manager, + hostList: make([]*connect.Host, newIndex+1), + hostMap: make(map[id.ID]uint32), + ndf: testNdf, + storage: storage.InitTestingSession(t), + } + + // Construct a new Ndf different from original one above + newNdf := getTestNdf(t) + newGateway := ndf.Gateway{ + ID: id.NewIdFromUInt(27, id.Gateway, t).Bytes(), + } + newNode := ndf.Node{ + ID: id.NewIdFromUInt(27, id.Node, t).Bytes(), + } + newNdf.Gateways = append(newNdf.Gateways, newGateway) + newNdf.Nodes = append(newNdf.Nodes, newNode) + + // Update pool with the new Ndf + hostPool.UpdateNdf(newNdf) + + // Check that the host pool's ndf has been modified properly + if !reflect.DeepEqual(newNdf, hostPool.ndf) { + t.Errorf("Host pool ndf not updated to new ndf.") + } +} + +// Full test +func TestHostPool_GetPreferred(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.PoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + hostMap := make(map[id.ID]bool, 0) + targets := make([]*id.ID, 0) + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + hostMap[*gwId] = true + targets = append(targets, gwId) + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + retrievedList := testPool.getPreferred(targets) + if len(retrievedList) != len(targets) { + t.Errorf("Requested list did not output requested length."+ + "\n\tExpected: %d"+ + "\n\tReceived: %v", len(targets), len(retrievedList)) + } + + // In case where all requested gateways are present + // ensure requested hosts were returned + for _, h := range retrievedList { + if !hostMap[*h.GetId()] { + t.Errorf("A target gateways which should have been returned was not."+ + "\n\tExpected: %v", h.GetId()) + } + } + + // Replace a request with a gateway not in pool + targets[3] = id.NewIdFromUInt(74, id.Gateway, t) + retrievedList = testPool.getPreferred(targets) + if len(retrievedList) != len(targets) { + t.Errorf("Requested list did not output requested length."+ + "\n\tExpected: %d"+ + "\n\tReceived: %v", len(targets), len(retrievedList)) + } + + // In case where a requested gateway is not present + for _, h := range retrievedList { + if h.GetId().Cmp(targets[3]) { + t.Errorf("Should not have returned ID not in pool") + } + } + +} + +// Unit test +func TestHostPool_GetAny(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + requested := 3 + anyList := testPool.getAny(uint32(requested), nil) + if len(anyList) != requested { + t.Errorf("GetAnyList did not get requested length."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", requested, len(anyList)) + } + + for _, h := range anyList { + _, ok := manager.GetHost(h.GetId()) + if !ok { + t.Errorf("Host %s in retrieved list not in manager", h) + } + } + + // Request more than are in host list + largeRequest := uint32(requested * 1000) + largeRetrieved := testPool.getAny(largeRequest, nil) + if len(largeRetrieved) != len(testPool.hostList) { + t.Errorf("Large request should result in a list of all in host list") + } + +} + +// Unit test +func TestHostPool_ForceAdd(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.PoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + // Construct a new gateway to add + gwId := id.NewIdFromUInt(uint64(100), id.Gateway, t) + // Add mock gateway to manager + _, err = manager.AddHost(gwId, "", nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + // forceAdd gateway + err = testPool.forceAdd(gwId) + if err != nil { + t.Errorf("Could not add gateways: %v", err) + } + + // check that gateways have been added to the map + if _, ok := testPool.hostMap[*gwId]; !ok { + t.Errorf("Failed to forcefully add new gateway ID: %v", gwId) + } +} + +// Unit test which only adds information to ndf +func TestHostPool_UpdateConns_AddGateways(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + // Construct a list of new gateways/nodes to add to ndf + newGatewayLen := 10 + newGateways := make([]ndf.Gateway, newGatewayLen) + newNodes := make([]ndf.Node, newGatewayLen) + for i := 0; i < newGatewayLen; i++ { + // Construct gateways + gwId := id.NewIdFromUInt(uint64(100+i), id.Gateway, t) + newGateways[i] = ndf.Gateway{ID: gwId.Bytes()} + // Construct nodes + nodeId := gwId.DeepCopy() + nodeId.SetType(id.Node) + newNodes[i] = ndf.Node{ID: nodeId.Bytes()} + + } + + // Update the ndf + newNdf := getTestNdf(t) + newNdf.Gateways = append(newNdf.Gateways, newGateways...) + newNdf.Nodes = append(newNdf.Nodes, newNodes...) + + testPool.UpdateNdf(newNdf) + + // Update the connections + err = testPool.updateConns() + if err != nil { + t.Errorf("Failed to update connections: %v", err) + } + + // Check that new gateways are in manager + for _, ndfGw := range newGateways { + gwId, err := id.Unmarshal(ndfGw.ID) + if err != nil { + t.Errorf("Failed to marshal gateway id for %v", ndfGw) + } + _, ok := testPool.getSpecific(gwId) + if !ok { + t.Errorf("Failed to find gateway %v in manager", gwId) + } + } + +} + +// Unit test which only adds information to ndf +func TestHostPool_UpdateConns_RemoveGateways(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Errorf("Could not add mock host to manager: %v", err) + t.FailNow() + } + + } + + // Call the constructor + testPool, err := newHostPool(params, rng, testNdf, manager, + testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock host pool: %v", err) + } + + // Construct a list of new gateways/nodes to add to ndf + newGatewayLen := len(testNdf.Gateways) + newGateways := make([]ndf.Gateway, newGatewayLen) + newNodes := make([]ndf.Node, newGatewayLen) + for i := 0; i < newGatewayLen; i++ { + // Construct gateways + gwId := id.NewIdFromUInt(uint64(100+i), id.Gateway, t) + newGateways[i] = ndf.Gateway{ID: gwId.Bytes()} + // Construct nodes + nodeId := gwId.DeepCopy() + nodeId.SetType(id.Node) + newNodes[i] = ndf.Node{ID: nodeId.Bytes()} + + } + + // Update the ndf, replacing old data entirely + newNdf := getTestNdf(t) + newNdf.Gateways = newGateways + newNdf.Nodes = newNodes + + testPool.UpdateNdf(newNdf) + + // Update the connections + err = testPool.updateConns() + if err != nil { + t.Errorf("Failed to update connections: %v", err) + } + + // Check that old gateways are not in pool + for _, ndfGw := range testNdf.Gateways { + gwId, err := id.Unmarshal(ndfGw.ID) + if err != nil { + t.Errorf("Failed to marshal gateway id for %v", ndfGw) + } + if _, ok := testPool.hostMap[*gwId]; ok { + t.Errorf("Expected gateway %v to be removed from pool", gwId) + } + } +} + +// Unit test +func TestHostPool_AddGateway(t *testing.T) { + manager := newMockManager() + testNdf := getTestNdf(t) + newIndex := uint32(20) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Construct a manager (bypass business logic in constructor) + hostPool := &HostPool{ + manager: manager, + hostList: make([]*connect.Host, newIndex+1), + hostMap: make(map[id.ID]uint32), + ndf: testNdf, + addGatewayChan: make(chan network.NodeGateway), + storage: storage.InitTestingSession(t), + } + + ndfIndex := 0 + + gwId, err := id.Unmarshal(testNdf.Gateways[ndfIndex].ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + + hostPool.addGateway(gwId, ndfIndex) + + _, ok := manager.GetHost(gwId) + if !ok { + t.Errorf("Unsuccessfully added host to manager") + } +} + +// Unit test +func TestHostPool_RemoveGateway(t *testing.T) { + manager := newMockManager() + testNdf := getTestNdf(t) + newIndex := uint32(20) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + // Construct a manager (bypass business logic in constructor) + hostPool := &HostPool{ + manager: manager, + hostList: make([]*connect.Host, newIndex+1), + hostMap: make(map[id.ID]uint32), + ndf: testNdf, + addGatewayChan: make(chan network.NodeGateway), + storage: storage.InitTestingSession(t), + rng: fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + } + + ndfIndex := 0 + + gwId, err := id.Unmarshal(testNdf.Gateways[ndfIndex].ID) + if err != nil { + t.Errorf("Failed to unmarshal ID in mock ndf: %v", err) + } + + // Manually add host information + hostPool.addGateway(gwId, ndfIndex) + + // Call the removal + hostPool.removeGateway(gwId) + + // Check that the map and list have been updated + if hostPool.hostList[ndfIndex] != nil { + t.Errorf("Host list index was not set to nil after removal") + } + + if _, ok := hostPool.hostMap[*gwId]; ok { + t.Errorf("Host map did not delete host entry") + } +} diff --git a/network/gateway/sender.go b/network/gateway/sender.go new file mode 100644 index 0000000000000000000000000000000000000000..dcfcdbb5b88a0be02b67f4f61dfcaa01c1169489 --- /dev/null +++ b/network/gateway/sender.go @@ -0,0 +1,132 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2021 Privategrity Corporation / +// / +// All rights reserved. / +//////////////////////////////////////////////////////////////////////////////// + +// Contains gateway message sending wrappers + +package gateway + +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/ndf" +) + +// Sender Object used for sending that wraps the HostPool for providing destinations +type Sender struct { + *HostPool +} + +// NewSender Create a new Sender object wrapping a HostPool object +func NewSender(poolParams PoolParams, rng *fastRNG.StreamGenerator, ndf *ndf.NetworkDefinition, getter HostManager, + storage *storage.Session, addGateway chan network.NodeGateway) (*Sender, error) { + + hostPool, err := newHostPool(poolParams, rng, ndf, getter, storage, addGateway) + if err != nil { + return nil, err + } + return &Sender{hostPool}, nil +} + +// SendToSpecific Call given sendFunc to a specific Host in the HostPool, +// attempting with up to numProxies destinations in case of failure +func (s *Sender) SendToSpecific(target *id.ID, + sendFunc func(host *connect.Host, target *id.ID) (interface{}, bool, error)) (interface{}, error) { + host, ok := s.getSpecific(target) + if ok { + result, didAbort, err := sendFunc(host, target) + if err == nil { + return result, s.forceAdd(target) + } else { + if didAbort { + return nil, errors.WithMessagef(err, "Aborted SendToSpecific gateway %s", host.GetId().String()) + } + jww.WARN.Printf("Unable to SendToSpecific %s: %s", host.GetId().String(), err) + } + } + + proxies := s.getAny(s.poolParams.ProxyAttempts, []*id.ID{target}) + for i := range proxies { + result, didAbort, err := sendFunc(proxies[i], target) + if err == nil { + return result, nil + } else { + if didAbort { + return nil, errors.WithMessagef(err, "Aborted SendToSpecific gateway proxy %s", + host.GetId().String()) + } + jww.WARN.Printf("Unable to SendToSpecific proxy %s: %s", proxies[i].GetId().String(), err) + err = s.checkReplace(proxies[i].GetId(), err) + if err != nil { + jww.ERROR.Printf("Unable to checkReplace: %+v", err) + } + } + } + + return nil, errors.Errorf("Unable to send to specific with proxies") +} + +// SendToAny Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations +func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error)) (interface{}, error) { + + proxies := s.getAny(s.poolParams.ProxyAttempts, nil) + for i := range proxies { + result, err := sendFunc(proxies[i]) + if err == nil { + return result, nil + } else { + jww.WARN.Printf("Unable to SendToAny %s: %s", proxies[i].GetId().String(), err) + err = s.checkReplace(proxies[i].GetId(), err) + if err != nil { + jww.ERROR.Printf("Unable to checkReplace: %+v", err) + } + } + } + + return nil, errors.Errorf("Unable to send to any proxies") +} + +// SendToPreferred Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations +func (s *Sender) SendToPreferred(targets []*id.ID, + sendFunc func(host *connect.Host, target *id.ID) (interface{}, error)) (interface{}, error) { + + targetHosts := s.getPreferred(targets) + for i := range targetHosts { + result, err := sendFunc(targetHosts[i], targets[i]) + if err == nil { + return result, nil + } else { + jww.WARN.Printf("Unable to SendToPreferred %s via %s: %s", + targets[i], targetHosts[i].GetId(), err) + err = s.checkReplace(targetHosts[i].GetId(), err) + if err != nil { + jww.ERROR.Printf("Unable to checkReplace: %+v", err) + } + } + } + + proxies := s.getAny(s.poolParams.ProxyAttempts, targets) + for i := range proxies { + target := targets[i%len(targets)].DeepCopy() + result, err := sendFunc(proxies[i], target) + if err == nil { + return result, nil + } else { + jww.WARN.Printf("Unable to SendToPreferred %s via proxy "+ + "%s: %s", target, proxies[i].GetId(), err) + err = s.checkReplace(proxies[i].GetId(), err) + if err != nil { + jww.ERROR.Printf("Unable to checkReplace: %+v", err) + } + } + } + + return nil, errors.Errorf("Unable to send to any preferred") +} diff --git a/network/gateway/sender_test.go b/network/gateway/sender_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4dd5d49c02ea525e547164f8c2c6392b526b30de --- /dev/null +++ b/network/gateway/sender_test.go @@ -0,0 +1,249 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package gateway + +import ( + "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/comms/network" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "reflect" + "testing" +) + +// Unit test +func TestNewSender(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) + + _, err := NewSender(params, rng, testNdf, manager, testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock sender: %v", err) + } +} + +// Unit test +func TestSender_SendToAny(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.PoolSize = uint32(len(testNdf.Gateways)) + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + } + + sender, err := NewSender(params, rng, testNdf, manager, testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock sender: %v", err) + } + + // Add all gateways to hostPool's map + for index, gw := range testNdf.Gateways { + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + + err = sender.replaceHost(gwId, uint32(index)) + if err != nil { + t.Fatalf("Failed to replace host in set-up: %v", err) + } + } + + // Test sendToAny with test interfaces + result, err := sender.SendToAny(SendToAny_HappyPath) + if err != nil { + t.Errorf("Should not error in SendToAny happy path: %v", err) + } + + if !reflect.DeepEqual(result, happyPathReturn) { + t.Errorf("Expected result not returnev via SendToAny interface."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", happyPathReturn, result) + } + + _, err = sender.SendToAny(SendToAny_KnownError) + if err == nil { + t.Fatalf("Expected error path did not receive error") + } + + _, err = sender.SendToAny(SendToAny_UnknownError) + if err == nil { + t.Fatalf("Expected error path did not receive error") + } + +} + +// Unit test +func TestSender_SendToPreferred(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.PoolSize = uint32(len(testNdf.Gateways)) - 5 + + // Do not test proxy attempts code in this test + // (self contain to code specific in sendPreferred) + params.ProxyAttempts = 0 + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + } + + sender, err := NewSender(params, rng, testNdf, manager, testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock sender: %v", err) + } + + preferredIndex := 0 + preferredHost := sender.hostList[preferredIndex] + + // Happy path + result, err := sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_HappyPath) + if err != nil { + t.Errorf("Should not error in SendToPreferred happy path: %v", err) + } + + if !reflect.DeepEqual(result, happyPathReturn) { + t.Errorf("Expected result not returnev via SendToPreferred interface."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", happyPathReturn, result) + } + + // Call a send which returns an error which triggers replacement + _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_KnownError) + if err == nil { + t.Fatalf("Expected error path did not receive error") + } + + // Check the host has been replaced + if _, ok := sender.hostMap[*preferredHost.GetId()]; ok { + t.Errorf("Expected host %s to be removed due to error", preferredHost) + } + + // Ensure we are disconnected from the old host + if isConnected, _ := preferredHost.Connected(); isConnected { + t.Errorf("ForceReplace error: Failed to disconnect from old host %s", preferredHost) + } + + // Get a new host to test on + preferredIndex = 4 + preferredHost = sender.hostList[preferredIndex] + + // Unknown error return will not trigger replacement + _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_UnknownError) + if err == nil { + t.Fatalf("Expected error path did not receive error") + } + + // Check the host has not been replaced + if _, ok := sender.hostMap[*preferredHost.GetId()]; !ok { + t.Errorf("Host %s should not have been removed due on an unknown error", preferredHost) + } + + // Ensure we are disconnected from the old host + if isConnected, _ := preferredHost.Connected(); isConnected { + t.Errorf("Should not disconnect from %s", preferredHost) + } + +} + +func TestSender_SendToSpecific(t *testing.T) { + manager := newMockManager() + rng := fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG) + testNdf := getTestNdf(t) + testStorage := storage.InitTestingSession(t) + addGwChan := make(chan network.NodeGateway) + params := DefaultPoolParams() + params.MaxPoolSize = uint32(len(testNdf.Gateways)) - 5 + + // Do not test proxy attempts code in this test + // (self contain to code specific in sendPreferred) + params.ProxyAttempts = 0 + + // Pull all gateways from ndf into host manager + for _, gw := range testNdf.Gateways { + + gwId, err := id.Unmarshal(gw.ID) + if err != nil { + t.Fatalf("Failed to unmarshal ID in mock ndf: %v", err) + } + // Add mock gateway to manager + _, err = manager.AddHost(gwId, gw.Address, nil, connect.GetDefaultHostParams()) + if err != nil { + t.Fatalf("Could not add mock host to manager: %v", err) + } + + } + + sender, err := NewSender(params, rng, testNdf, manager, testStorage, addGwChan) + if err != nil { + t.Fatalf("Failed to create mock sender: %v", err) + } + + preferredIndex := 0 + preferredHost := sender.hostList[preferredIndex] + + // Happy path + result, err := sender.SendToSpecific(preferredHost.GetId(), SendToSpecific_HappyPath) + if err != nil { + t.Errorf("Should not error in SendToSpecific happy path: %v", err) + } + + if !reflect.DeepEqual(result, happyPathReturn) { + t.Errorf("Expected result not returnev via SendToSpecific interface."+ + "\n\tExpected: %v"+ + "\n\tReceived: %v", happyPathReturn, result) + } + + // Ensure host is now in map + if _, ok := sender.hostMap[*preferredHost.GetId()]; !ok { + t.Errorf("Failed to forcefully add new gateway ID: %v", preferredHost.GetId()) + } + + _, err = sender.SendToSpecific(preferredHost.GetId(), SendToSpecific_Abort) + if err == nil { + t.Errorf("Expected sendSpecific to return an abort") + } + +} diff --git a/network/gateway/utils_test.go b/network/gateway/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0ec7dc11f8edde72a82affd3b718bcec121bf60c --- /dev/null +++ b/network/gateway/utils_test.go @@ -0,0 +1,162 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package gateway + +import ( + "fmt" + "gitlab.com/xx_network/comms/connect" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/ndf" +) + +// Mock structure adhering to HostManager to be used for happy path +type mockManager struct { + hosts map[string]*connect.Host +} + +// Constructor for mockManager +func newMockManager() *mockManager { + return &mockManager{ + hosts: make(map[string]*connect.Host), + } +} + +func (mhp *mockManager) GetHost(hostId *id.ID) (*connect.Host, bool) { + h, ok := mhp.hosts[hostId.String()] + return h, ok +} + +func (mhp *mockManager) AddHost(hid *id.ID, address string, + cert []byte, params connect.HostParams) (host *connect.Host, err error) { + host, err = connect.NewHost(hid, address, cert, params) + if err != nil { + return nil, err + } + + mhp.hosts[hid.String()] = host + + return +} + +func (mhp *mockManager) RemoveHost(hid *id.ID) { + delete(mhp.hosts, hid.String()) +} + +// Returns a mock +func getTestNdf(face interface{}) *ndf.NetworkDefinition { + return &ndf.NetworkDefinition{ + Gateways: []ndf.Gateway{{ + ID: id.NewIdFromUInt(0, id.Gateway, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(1, id.Gateway, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(2, id.Gateway, face)[:], + Address: "0.0.0.3", + }, { + ID: id.NewIdFromUInt(3, id.Gateway, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(4, id.Gateway, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(5, id.Gateway, face)[:], + Address: "0.0.0.3", + }, { + ID: id.NewIdFromUInt(6, id.Gateway, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(7, id.Gateway, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(8, id.Gateway, face)[:], + Address: "0.0.0.3", + }, { + ID: id.NewIdFromUInt(9, id.Gateway, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(10, id.Gateway, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(11, id.Gateway, face)[:], + Address: "0.0.0.3", + }}, + Nodes: []ndf.Node{{ + ID: id.NewIdFromUInt(0, id.Node, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(1, id.Node, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(2, id.Node, face)[:], + Address: "0.0.0.3", + }, { + ID: id.NewIdFromUInt(3, id.Node, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(4, id.Node, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(5, id.Node, face)[:], + Address: "0.0.0.3", + }, { + ID: id.NewIdFromUInt(6, id.Node, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(7, id.Node, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(8, id.Node, face)[:], + Address: "0.0.0.3", + }, { + ID: id.NewIdFromUInt(9, id.Node, face)[:], + Address: "0.0.0.1", + }, { + ID: id.NewIdFromUInt(10, id.Node, face)[:], + Address: "0.0.0.2", + }, { + ID: id.NewIdFromUInt(11, id.Node, face)[:], + Address: "0.0.0.3", + }}, + } +} + +const happyPathReturn = "happyPathReturn" + +func SendToPreferred_HappyPath(host *connect.Host, target *id.ID) (interface{}, error) { + return happyPathReturn, nil +} + +func SendToPreferred_KnownError(host *connect.Host, target *id.ID) (interface{}, error) { + return nil, fmt.Errorf(errorsList[0]) +} + +func SendToPreferred_UnknownError(host *connect.Host, target *id.ID) (interface{}, error) { + return nil, fmt.Errorf("Unexpected error: Oopsie") +} + +func SendToAny_HappyPath(host *connect.Host) (interface{}, error) { + return happyPathReturn, nil +} + +func SendToAny_KnownError(host *connect.Host) (interface{}, error) { + return nil, fmt.Errorf(errorsList[0]) +} + +func SendToAny_UnknownError(host *connect.Host) (interface{}, error) { + return nil, fmt.Errorf("Unexpected error: Oopsie") +} + +func SendToSpecific_HappyPath(host *connect.Host, target *id.ID) (interface{}, bool, error) { + return happyPathReturn, false, nil +} + +func SendToSpecific_Abort(host *connect.Host, target *id.ID) (interface{}, bool, error) { + return nil, true, fmt.Errorf(errorsList[0]) +} diff --git a/network/manager.go b/network/manager.go index 5d726cb088e53c4c12db98b8ac6fc9e13bd3de78..1cf065d691b44ca061404b426af7c26df965ca40 100644 --- a/network/manager.go +++ b/network/manager.go @@ -15,6 +15,7 @@ import ( "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/network/ephemeral" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/health" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message" @@ -35,6 +36,8 @@ import ( type manager struct { // parameters of the network param params.Network + // handles message sending + sender *gateway.Sender //Shared data with all sub managers internal.Internal @@ -45,9 +48,6 @@ type manager struct { //number of polls done in a period of time tracker *uint64 - - //tracks already checked rounds - checked *checkedRounds } // NewManager builds a new reception manager object using inputted key fields @@ -73,7 +73,6 @@ func NewManager(session *storage.Session, switchboard *switchboard.Switchboard, m := manager{ param: params, tracker: &tracker, - checked: newCheckedRounds(), } m.Internal = internal.Internal{ @@ -88,18 +87,22 @@ func NewManager(session *storage.Session, switchboard *switchboard.Switchboard, ReceptionID: session.User().GetCryptographicIdentity().GetReceptionID(), } - // register the node registration channel early so login connection updates - // get triggered for registration if necessary - instance.SetAddGatewayChan(m.NodeRegistration) + // Set up gateway.Sender + poolParams := gateway.DefaultPoolParams() + m.sender, err = gateway.NewSender(poolParams, rng, + ndf, comms, session, m.NodeRegistration) + if err != nil { + return nil, err + } //create sub managers - m.message = message.NewManager(m.Internal, m.param.Messages, m.NodeRegistration) - m.round = rounds.NewManager(m.Internal, m.param.Rounds, m.message.GetMessageReceptionChannel()) + m.message = message.NewManager(m.Internal, m.param.Messages, m.NodeRegistration, m.sender) + m.round = rounds.NewManager(m.Internal, m.param.Rounds, m.message.GetMessageReceptionChannel(), m.sender) return &m, nil } -// StartRunners kicks off all network reception goroutines ("threads"). +// Follow StartRunners kicks off all network reception goroutines ("threads"). // Started Threads are: // - Network Follower (/network/follow.go) // - Historical Round Retrieval (/network/rounds/historical.go) @@ -120,14 +123,14 @@ func (m *manager) Follow(report interfaces.ClientErrorReport) (stoppable.Stoppab multi.Add(healthStop) // Node Updates - multi.Add(node.StartRegistration(m.Instance, m.Session, m.Rng, + multi.Add(node.StartRegistration(m.GetSender(), m.Session, m.Rng, m.Comms, m.NodeRegistration, m.param.ParallelNodeRegistrations)) // Adding/Keys //TODO-remover //m.runners.Add(StartNodeRemover(m.Context)) // Removing // Start the Network Tracker trackNetworkStopper := stoppable.NewSingle("TrackNetwork") - go m.followNetwork(report, trackNetworkStopper.Quit()) + go m.followNetwork(report, trackNetworkStopper.Quit(), trackNetworkStopper) multi.Add(trackNetworkStopper) // Message reception @@ -151,7 +154,12 @@ func (m *manager) GetInstance() *network.Instance { return m.Instance } -// triggers a check on garbled messages to see if they can be decrypted +// GetSender returns the gateway.Sender object +func (m *manager) GetSender() *gateway.Sender { + return m.sender +} + +// CheckGarbledMessages triggers a check on garbled messages to see if they can be decrypted // this should be done when a new e2e client is added in case messages were // received early or arrived out of order func (m *manager) CheckGarbledMessages() { diff --git a/network/message/critical.go b/network/message/critical.go index 0ebaed5d32390c35f92a6a96557782be6336ff1e..1dad92c821e756e3c89c6d013c5d78c9e09a1015 100644 --- a/network/message/critical.go +++ b/network/message/critical.go @@ -95,7 +95,7 @@ func (m *Manager) criticalMessages() { jww.INFO.Printf("Resending critical raw message to %s "+ "(msgDigest: %s)", rid, msg.Digest()) //send the message - round, _, err := m.SendCMIX(msg, rid, param) + round, _, err := m.SendCMIX(m.sender, msg, rid, param) //if the message fail to send, notify the buffer so it can be handled //in the future and exit if err != nil { diff --git a/network/message/garbled_test.go b/network/message/garbled_test.go index 0f5c3a074641af02ef3e156372bd7957eedba5aa..9d021ae488ed26717aade2d76d87ba688c7f9001 100644 --- a/network/message/garbled_test.go +++ b/network/message/garbled_test.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message/parse" "gitlab.com/elixxir/client/storage" @@ -57,12 +58,18 @@ func TestManager_CheckGarbledMessages(t *testing.T) { Instance: nil, NodeRegistration: nil, } + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + sender, err := gateway.NewSender(p, i.Rng, getNDF(), &MockSendCMIXComms{t: t}, i.Session, nil) + if err != nil { + t.Errorf(err.Error()) + } m := NewManager(i, params.Messages{ MessageReceptionBuffLen: 20, MessageReceptionWorkerPoolSize: 20, MaxChecksGarbledMessage: 20, GarbledMessageWait: time.Hour, - }, nil) + }, nil, sender) e2ekv := i.Session.E2e() err = e2ekv.AddPartner(sess2.GetUser().TransmissionID, sess2.E2e().GetDHPublicKey(), e2ekv.GetDHPrivateKey(), diff --git a/network/message/handler.go b/network/message/handler.go index 88484e6fb98672520bf3e39af0cd3175597b53b4..b8892cff56f51009558998fcd74b1fa027c4a712 100644 --- a/network/message/handler.go +++ b/network/message/handler.go @@ -48,14 +48,11 @@ func (m *Manager) handleMessage(ecrMsg format.Message, identity reception.Identi var relationshipFingerprint []byte //check if the identity fingerprint matches - forMe, err := fingerprint2.CheckIdentityFP(ecrMsg.GetIdentityFP(), + forMe := fingerprint2.CheckIdentityFP(ecrMsg.GetIdentityFP(), ecrMsg.GetContents(), identity.Source) - if err != nil { - jww.FATAL.Panicf("Could not check IdentityFingerprint: %+v", err) - } if !forMe { if jww.GetLogThreshold() == jww.LevelTrace { - expectedFP, _ := fingerprint2.IdentityFP(ecrMsg.GetContents(), + expectedFP := fingerprint2.IdentityFP(ecrMsg.GetContents(), identity.Source) jww.TRACE.Printf("Message for %d (%s) failed identity "+ "check: %v (expected) vs %v (received)", identity.EphId, diff --git a/network/message/manager.go b/network/message/manager.go index 807ccc07066cf2281e557b6e07d09dd4180d3c15..7728910aa7e62186c682dcb97b297cb470dcac58 100644 --- a/network/message/manager.go +++ b/network/message/manager.go @@ -10,6 +10,7 @@ package message import ( "fmt" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message/parse" "gitlab.com/elixxir/client/stoppable" @@ -21,6 +22,7 @@ type Manager struct { param params.Messages partitioner parse.Partitioner internal.Internal + sender *gateway.Sender messageReception chan Bundle nodeRegistration chan network.NodeGateway @@ -29,7 +31,7 @@ type Manager struct { } func NewManager(internal internal.Internal, param params.Messages, - nodeRegistration chan network.NodeGateway) *Manager { + nodeRegistration chan network.NodeGateway, sender *gateway.Sender) *Manager { dummyMessage := format.NewMessage(internal.Session.Cmix().GetGroup().GetP().ByteLen()) m := Manager{ param: param, @@ -38,6 +40,7 @@ func NewManager(internal internal.Internal, param params.Messages, networkIsHealthy: make(chan bool, 1), triggerGarbled: make(chan struct{}, 100), nodeRegistration: nodeRegistration, + sender: sender, } m.Internal = internal return &m diff --git a/network/message/sendCmix.go b/network/message/sendCmix.go index 1aced6efe44eaef52f17a08392d4c52a8d6eae7c..2a7095843eb57c6f9af2af16ea9725fab0aac627 100644 --- a/network/message/sendCmix.go +++ b/network/message/sendCmix.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/storage" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" @@ -29,7 +30,6 @@ import ( // interface for SendCMIX comms; allows mocking this in testing type sendCmixCommsInterface interface { - GetHost(hostId *id.ID) (*connect.Host, bool) SendPutMessage(host *connect.Host, message *pb.GatewaySlot) (*pb.GatewaySlotResponse, error) } @@ -38,9 +38,9 @@ const sendTimeBuffer = 2500 * time.Millisecond // WARNING: Potentially Unsafe // Public manager function to send a message over CMIX -func (m *Manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CMIX) (id.Round, ephemeral.Id, error) { +func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, recipient *id.ID, param params.CMIX) (id.Round, ephemeral.Id, error) { msgCopy := msg.Copy() - return sendCmixHelper(msgCopy, recipient, param, m.Instance, m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) + return sendCmixHelper(sender, msgCopy, recipient, param, m.Instance, m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) } // Payloads send are not End to End encrypted, MetaData is NOT protected with @@ -51,7 +51,7 @@ func (m *Manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM // If the message is successfully sent, the id of the round sent it is returned, // which can be registered with the network instance to get a callback on // its status -func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, instance *network.Instance, +func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID, param params.CMIX, instance *network.Instance, session *storage.Session, nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, senderId *id.ID, comms sendCmixCommsInterface) (id.Round, ephemeral.Id, error) { @@ -108,13 +108,7 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins msg.SetEphemeralRID(ephIdFilled[:]) //set the identity fingerprint - ifp, err := fingerprint.IdentityFP(msg.GetContents(), recipient) - if err != nil { - jww.FATAL.Panicf("failed to generate the Identity "+ - "fingerprint due to unrecoverable error when sending to %s "+ - "(msgDigest: %s): %+v", recipient, msg.Digest(), err) - } - + ifp := fingerprint.IdentityFP(msg.GetContents(), recipient) msg.SetIdentityFP(ifp) //build the topology @@ -142,15 +136,6 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins firstGateway := topology.GetNodeAtIndex(0).DeepCopy() firstGateway.SetType(id.Gateway) - transmitGateway, ok := comms.GetHost(firstGateway) - if !ok { - jww.ERROR.Printf("Failed to get host for gateway %s when "+ - "sending to %s (msgDigest: %s)", transmitGateway, recipient, - msg.Digest()) - time.Sleep(param.RetryDelay) - continue - } - //encrypt the message stream = rng.GetStream() salt := make([]byte, 32) @@ -187,46 +172,57 @@ func sendCmixHelper(msg format.Message, recipient *id.ID, param params.CMIX, ins jww.INFO.Printf("Sending to EphID %d (%s) on round %d, "+ "(msgDigest: %s, ecrMsgDigest: %s) via gateway %s", ephID.Int64(), recipient, bestRound.ID, msg.Digest(), - encMsg.Digest(), transmitGateway.GetId()) - // //Send the payload - gwSlotResp, err := comms.SendPutMessage(transmitGateway, wrappedMsg) + encMsg.Digest(), firstGateway.String()) + + // Send the payload + result, err := sender.SendToSpecific(firstGateway, func(host *connect.Host, target *id.ID) (interface{}, bool, error) { + wrappedMsg.Target = target.Marshal() + result, err := comms.SendPutMessage(host, wrappedMsg) + if err != nil { + if strings.Contains(err.Error(), + "try a different round.") { + jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ + "due to round error with round %d, retrying: %+v", + recipient, msg.Digest(), bestRound.ID, err) + return nil, true, err + } else if strings.Contains(err.Error(), + "Could not authenticate client. Is the client registered "+ + "with this node?") { + jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ + "via %s due to failed authentication: %s", + recipient, msg.Digest(), firstGateway.String(), err) + //if we failed to send due to the gateway not recognizing our + // authorization, renegotiate with the node to refresh it + nodeID := firstGateway.DeepCopy() + nodeID.SetType(id.Node) + //delete the keys + session.Cmix().Remove(nodeID) + //trigger + go handleMissingNodeKeys(instance, nodeRegistration, []*id.ID{nodeID}) + return nil, true, err + } + } + return result, false, err + }) + //if the comm errors or the message fails to send, continue retrying. //return if it sends properly if err != nil { - if strings.Contains(err.Error(), - "try a different round.") { - jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ - "due to round error with round %d, retrying: %+v", - recipient, msg.Digest(), bestRound.ID, err) - continue - } else if strings.Contains(err.Error(), - "Could not authenticate client. Is the client registered "+ - "with this node?") { - jww.WARN.Printf("Failed to send to %s (msgDigest: %s) "+ - "via %s due to failed authentication: %s", - recipient, msg.Digest(), transmitGateway.GetId(), err) - //if we failed to send due to the gateway not recognizing our - // authorization, renegotiate with the node to refresh it - nodeID := transmitGateway.GetId().DeepCopy() - nodeID.SetType(id.Node) - //delete the keys - session.Cmix().Remove(nodeID) - //trigger - go handleMissingNodeKeys(instance, nodeRegistration, []*id.ID{nodeID}) - continue - } jww.ERROR.Printf("Failed to send to EphID %d (%s) on "+ - "round %d, bailing: %+v", ephID.Int64(), recipient, + "round %d, trying a new round: %+v", ephID.Int64(), recipient, bestRound.ID, err) - return 0, ephemeral.Id{}, errors.WithMessage(err, "Failed to put cmix message") - } else if gwSlotResp.Accepted { + continue + } + + gwSlotResp := result.(*pb.GatewaySlotResponse) + if gwSlotResp.Accepted { jww.INFO.Printf("Successfully sent to EphID %v (source: %s) "+ "in round %d", ephID.Int64(), recipient, bestRound.ID) return id.Round(bestRound.ID), ephID, nil } else { jww.FATAL.Panicf("Gateway %s returned no error, but failed "+ "to accept message when sending to EphID %d (%s) on round %d", - transmitGateway.GetId(), ephID.Int64(), recipient, bestRound.ID) + firstGateway.String(), ephID.Int64(), recipient, bestRound.ID) } } return 0, ephemeral.Id{}, errors.New("failed to send the message, " + diff --git a/network/message/sendCmix_test.go b/network/message/sendCmix_test.go index 82b3ebcb2c3cfa87a72f62519f68541ae78bd06d..04da108426dda151bec4777338fe4a0036366832 100644 --- a/network/message/sendCmix_test.go +++ b/network/message/sendCmix_test.go @@ -1,8 +1,10 @@ package message import ( + "github.com/pkg/errors" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/switchboard" @@ -40,6 +42,16 @@ func (mc *MockSendCMIXComms) GetHost(hostId *id.ID) (*connect.Host, bool) { }) return h, true } + +func (mc *MockSendCMIXComms) AddHost(hid *id.ID, address string, cert []byte, params connect.HostParams) (host *connect.Host, err error) { + host, _ = mc.GetHost(nil) + return host, nil +} + +func (mc *MockSendCMIXComms) RemoveHost(hid *id.ID) { + +} + func (mc *MockSendCMIXComms) SendPutMessage(host *connect.Host, message *mixmessages.GatewaySlot) (*mixmessages.GatewaySlotResponse, error) { return &mixmessages.GatewaySlotResponse{ Accepted: true, @@ -115,24 +127,36 @@ func Test_attemptSendCmix(t *testing.T) { Instance: inst, NodeRegistration: nil, } + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + sender, err := gateway.NewSender(p, i.Rng, getNDF(), &MockSendCMIXComms{t: t}, i.Session, nil) + if err != nil { + t.Errorf("%+v", errors.New(err.Error())) + return + } m := NewManager(i, params.Messages{ MessageReceptionBuffLen: 20, MessageReceptionWorkerPoolSize: 20, MaxChecksGarbledMessage: 20, GarbledMessageWait: time.Hour, - }, nil) + }, nil, sender) msgCmix := format.NewMessage(m.Session.Cmix().GetGroup().GetP().ByteLen()) msgCmix.SetContents([]byte("test")) e2e.SetUnencrypted(msgCmix, m.Session.User().GetCryptographicIdentity().GetTransmissionID()) - _, _, err = sendCmixHelper(msgCmix, sess2.GetUser().ReceptionID, params.GetDefaultCMIX(), + _, _, err = sendCmixHelper(sender, msgCmix, sess2.GetUser().ReceptionID, params.GetDefaultCMIX(), m.Instance, m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, &MockSendCMIXComms{t: t}) if err != nil { t.Errorf("Failed to sendcmix: %+v", err) + panic("t") + return } } func getNDF() *ndf.NetworkDefinition { + nodeId := id.NewIdFromString("zezima", id.Node, &testing.T{}) + gwId := nodeId.DeepCopy() + gwId.SetType(id.Gateway) return &ndf.NetworkDefinition{ E2E: ndf.Group{ Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B" + @@ -168,5 +192,19 @@ func getNDF() *ndf.NetworkDefinition { "BA9AE3F1DD2487199874393CD4D832186800654760E1E34C09E4D155179F9EC0" + "DC4473F996BDCE6EED1CABED8B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", }, + Gateways: []ndf.Gateway{ + { + ID: gwId.Marshal(), + Address: "0.0.0.0", + TlsCertificate: "", + }, + }, + Nodes: []ndf.Node{ + { + ID: nodeId.Marshal(), + Address: "0.0.0.0", + TlsCertificate: "", + }, + }, } } diff --git a/network/message/sendE2E.go b/network/message/sendE2E.go index 78061232830f167adf91d21f5805fa2e00b58916..16c14701d47ead025dae9758fa4697ce6c5ea007 100644 --- a/network/message/sendE2E.go +++ b/network/message/sendE2E.go @@ -95,7 +95,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M wg.Add(1) go func(i int) { var err error - roundIds[i], _, err = m.SendCMIX(msgEnc, msg.Recipient, + roundIds[i], _, err = m.SendCMIX(m.sender, msgEnc, msg.Recipient, param.CMIX) if err != nil { errCh <- err diff --git a/network/message/sendUnsafe.go b/network/message/sendUnsafe.go index ef28cf2d789d1f5f0a097413046b5a0c50b0a230..19f7d5d5ce0bfe6dd14df77b76f0a21c824b2404 100644 --- a/network/message/sendUnsafe.go +++ b/network/message/sendUnsafe.go @@ -64,7 +64,7 @@ func (m *Manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, wg.Add(1) go func(i int) { var err error - roundIds[i], _, err = m.SendCMIX(msgCmix, msg.Recipient, param.CMIX) + roundIds[i], _, err = m.SendCMIX(m.sender, msgCmix, msg.Recipient, param.CMIX) if err != nil { errCh <- err } @@ -82,7 +82,7 @@ func (m *Manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, return nil, errors.Errorf("Failed to send %v/%v sub payloads:"+ " %s", numFail, len(partitions), errRtn) } else { - jww.INFO.Printf("Sucesfully Unsafe sent %d/%d to %s", + jww.INFO.Printf("Successfully Unsafe sent %d/%d to %s", len(partitions)-numFail, len(partitions), msg.Recipient) } diff --git a/network/node/register.go b/network/node/register.go index 2b59d1e7ffd23853705541c46eb14242e8c3ca76..c89d2dd55371bef074a3d417f5ca05ae66d4e550 100644 --- a/network/node/register.go +++ b/network/node/register.go @@ -13,6 +13,7 @@ import ( "fmt" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage/cmix" @@ -33,29 +34,28 @@ import ( ) type RegisterNodeCommsInterface interface { - GetHost(hostId *id.ID) (*connect.Host, bool) SendRequestNonceMessage(host *connect.Host, message *pb.NonceRequest) (*pb.Nonce, error) SendConfirmNonceMessage(host *connect.Host, message *pb.RequestRegistrationConfirmation) (*pb.RegistrationConfirmation, error) } -func StartRegistration(instance *network.Instance, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, +func StartRegistration(sender *gateway.Sender, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, c chan network.NodeGateway, numParallel uint) stoppable.Stoppable { multi := stoppable.NewMulti("NodeRegistrations") - for i:=uint(0);i<numParallel;i++{ + for i := uint(0); i < numParallel; i++ { stop := stoppable.NewSingle(fmt.Sprintf("NodeRegistration %d", i)) - go registerNodes(session, rngGen, comms, stop, c) + go registerNodes(sender, session, rngGen, comms, stop, c) multi.Add(stop) } return multi } -func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, +func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, stop *stoppable.Single, c chan network.NodeGateway) { u := session.User() regSignature := u.GetTransmissionRegistrationValidationSignature() @@ -71,7 +71,7 @@ func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, co t.Stop() return case gw := <-c: - err := registerWithNode(comms, gw, regSignature, uci, cmix, rng) + err := registerWithNode(sender, comms, gw, regSignature, uci, cmix, rng) if err != nil { jww.ERROR.Printf("Failed to register node: %+v", err) } @@ -82,7 +82,7 @@ func registerNodes(session *storage.Session, rngGen *fastRNG.StreamGenerator, co //registerWithNode serves as a helper for RegisterWithNodes // It registers a user with a specific in the client's ndf. -func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, regSig []byte, +func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw network.NodeGateway, regSig []byte, uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source) error { nodeID, err := ngw.Node.GetNodeId() if err != nil { @@ -109,7 +109,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, userNum := int(uci.GetTransmissionID().Bytes()[7]) h := sha256.New() h.Reset() - h.Write([]byte(strconv.Itoa(int(4000 + userNum)))) + h.Write([]byte(strconv.Itoa(4000 + userNum))) transmissionKey = store.GetGroup().NewIntFromBytes(h.Sum(nil)) jww.INFO.Printf("transmissionKey: %v", transmissionKey.Bytes()) @@ -118,7 +118,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, // keys transmissionHash, _ := hash.NewCMixHash() - nonce, dhPub, err := requestNonce(comms, gatewayID, regSig, uci, store, rng) + nonce, dhPub, err := requestNonce(sender, comms, gatewayID, regSig, uci, store, rng) if err != nil { return errors.Errorf("Failed to request nonce: %+v", err) } @@ -127,8 +127,8 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, serverPubDH := store.GetGroup().NewIntFromBytes(dhPub) // Confirm received nonce - jww.INFO.Println("Register: Confirming received nonce") - err = confirmNonce(comms, uci.GetTransmissionID().Bytes(), + jww.INFO.Printf("Register: Confirming received nonce from node %s", nodeID.String()) + err = confirmNonce(sender, comms, uci.GetTransmissionID().Bytes(), nonce, uci.GetTransmissionRSA(), gatewayID) if err != nil { errMsg := fmt.Sprintf("Register: Unable to confirm nonce: %v", err) @@ -145,7 +145,7 @@ func registerWithNode(comms RegisterNodeCommsInterface, ngw network.NodeGateway, return nil } -func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, +func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source) ([]byte, []byte, error) { dhPub := store.GetDHPublicKey().Bytes() opts := rsa.NewDefaultOptions() @@ -162,34 +162,37 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, } // Request nonce message from gateway - jww.INFO.Printf("Register: Requesting nonce from gateway %v", - gwId.Bytes()) - - host, ok := comms.GetHost(gwId) - if !ok { - return nil, nil, errors.Errorf("Failed to find host with ID %s", gwId.String()) - } - nonceResponse, err := comms.SendRequestNonceMessage(host, - &pb.NonceRequest{ - Salt: uci.GetTransmissionSalt(), - ClientRSAPubKey: string(rsa.CreatePublicKeyPem(uci.GetTransmissionRSA().GetPublic())), - ClientSignedByServer: &messages.RSASignature{ - Signature: regHash, - }, - ClientDHPubKey: dhPub, - RequestSignature: &messages.RSASignature{ - Signature: clientSig, - }, - }) - + jww.INFO.Printf("Register: Requesting nonce from gateway %v", gwId.String()) + + result, err := sender.SendToAny(func(host *connect.Host) (interface{}, error) { + nonceResponse, err := comms.SendRequestNonceMessage(host, + &pb.NonceRequest{ + Salt: uci.GetTransmissionSalt(), + ClientRSAPubKey: string(rsa.CreatePublicKeyPem(uci.GetTransmissionRSA().GetPublic())), + ClientSignedByServer: &messages.RSASignature{ + Signature: regHash, + }, + ClientDHPubKey: dhPub, + RequestSignature: &messages.RSASignature{ + Signature: clientSig, + }, + Target: gwId.Marshal(), + }) + if err != nil { + errMsg := fmt.Sprintf("Register: Failed requesting nonce from gateway: %+v", err) + return nil, errors.New(errMsg) + } + if nonceResponse.Error != "" { + err := errors.New(fmt.Sprintf("requestNonce: nonceResponse error: %s", nonceResponse.Error)) + return nil, err + } + return nonceResponse, nil + }) if err != nil { - errMsg := fmt.Sprintf("Register: Failed requesting nonce from gateway: %+v", err) - return nil, nil, errors.New(errMsg) - } - if nonceResponse.Error != "" { - err := errors.New(fmt.Sprintf("requestNonce: nonceResponse error: %s", nonceResponse.Error)) return nil, nil, err } + nonceResponse := result.(*pb.Nonce) + // Use Client keypair to sign Server nonce return nonceResponse.Nonce, nonceResponse.DHPubKey, nil } @@ -197,7 +200,7 @@ func requestNonce(comms RegisterNodeCommsInterface, gwId *id.ID, regHash []byte, // confirmNonce is a helper for the Register function // It signs a nonce and sends it for confirmation // Returns nil if successful, error otherwise -func confirmNonce(comms RegisterNodeCommsInterface, UID, nonce []byte, +func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, nonce []byte, privateKeyRSA *rsa.PrivateKey, gwID *id.ID) error { opts := rsa.NewDefaultOptions() opts.Hash = hash.CMixHash @@ -224,22 +227,19 @@ func confirmNonce(comms RegisterNodeCommsInterface, UID, nonce []byte, NonceSignedByClient: &messages.RSASignature{ Signature: sig, }, + Target: gwID.Marshal(), } - host, ok := comms.GetHost(gwID) - if !ok { - return errors.Errorf("Failed to find host with ID %s", gwID.String()) - } - confirmResponse, err := comms.SendConfirmNonceMessage(host, msg) - if err != nil { - err := errors.New(fmt.Sprintf( - "confirmNonce: Unable to send signed nonce! %s", err)) - return err - } - if confirmResponse.Error != "" { - err := errors.New(fmt.Sprintf( - "confirmNonce: Error confirming nonce: %s", confirmResponse.Error)) - return err - } - return nil + _, err = sender.SendToAny(func(host *connect.Host) (interface{}, error) { + confirmResponse, err := comms.SendConfirmNonceMessage(host, msg) + if err != nil { + return nil, err + } else if confirmResponse.Error != "" { + err := errors.New(fmt.Sprintf( + "confirmNonce: Error confirming nonce: %s", confirmResponse.Error)) + return nil, err + } + return confirmResponse, nil + }) + return err } diff --git a/network/polltracker.go b/network/polltracker.go index 48574c7d8e19eb4e1d73090af0c017e3c9b8187b..63739e19afbb1a959face8c7859f23c704a7cb66 100644 --- a/network/polltracker.go +++ b/network/polltracker.go @@ -13,7 +13,7 @@ func newPollTracker() *pollTracker { return &pt } -//tracks a single poll +// Track tracks a single poll func (pt *pollTracker) Track(ephID ephemeral.Id, source *id.ID) { if _, exists := (*pt)[*source]; !exists { (*pt)[*source] = make(map[int64]uint) @@ -25,7 +25,7 @@ func (pt *pollTracker) Track(ephID ephemeral.Id, source *id.ID) { } } -//reports all resent polls +// Report reports all recent polls func (pt *pollTracker) Report() string { report := "" numReports := uint(0) diff --git a/network/rounds/check.go b/network/rounds/check.go index 9deb3ec6c16fb227a6406c9fd6181ffcf6e4d8b1..03cd83718495b71ef5455e9112dfe95fdda89fe3 100644 --- a/network/rounds/check.go +++ b/network/rounds/check.go @@ -11,6 +11,7 @@ import ( "encoding/binary" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/storage/reception" + "gitlab.com/elixxir/client/storage/rounds" "gitlab.com/xx_network/primitives/id" ) @@ -27,7 +28,12 @@ import ( // Retrieval // false: no message // true: message -func Checker(roundID id.Round, filters []*RemoteFilter) bool { +func Checker(roundID id.Round, filters []*RemoteFilter, cr *rounds.CheckedRounds) bool { + // Skip checking if the round is already checked + if cr.IsChecked(roundID) { + return true + } + //find filters that could have the round and check them serialRid := serializeRound(roundID) for _, filter := range filters { @@ -47,9 +53,9 @@ func serializeRound(roundId id.Round) []byte { return b } -func (m *Manager) GetMessagesFromRound(roundID id.Round, identity reception.IdentityUse){ +func (m *Manager) GetMessagesFromRound(roundID id.Round, identity reception.IdentityUse) { ri, err := m.Instance.GetRound(roundID) - if err !=nil || m.params.ForceHistoricalRounds { + if err != nil || m.params.ForceHistoricalRounds { if m.params.ForceHistoricalRounds { jww.WARN.Printf("Forcing use of historical rounds for round ID %d.", roundID) @@ -59,8 +65,8 @@ func (m *Manager) GetMessagesFromRound(roundID id.Round, identity reception.Iden identity.Source) // If we didn't find it, send to Historical Rounds Retrieval m.historicalRounds <- historicalRoundRequest{ - rid: roundID, - identity: identity, + rid: roundID, + identity: identity, numAttempts: 0, } } else { @@ -74,4 +80,4 @@ func (m *Manager) GetMessagesFromRound(roundID id.Round, identity reception.Iden } } -} \ No newline at end of file +} diff --git a/network/rounds/historical.go b/network/rounds/historical.go index d5451989c4f3092652943d308d37c4a21bbf0524..45aed7fdfaa7b296a46de79645e19fd010f88634 100644 --- a/network/rounds/historical.go +++ b/network/rounds/historical.go @@ -9,7 +9,6 @@ package rounds import ( jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/xx_network/comms/connect" @@ -34,8 +33,8 @@ type historicalRoundsComms interface { //structure which contains a historical round lookup type historicalRoundRequest struct { - rid id.Round - identity reception.IdentityUse + rid id.Round + identity reception.IdentityUse numAttempts uint } @@ -87,13 +86,6 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c continue } - //find a gateway to request about the roundRequests - gwHost, err := gateway.Get(m.Instance.GetPartialNdf().Get(), comm, rng) - if err != nil { - jww.FATAL.Panicf("Failed to track network, NDF has corrupt "+ - "data: %s", err) - } - rounds := make([]uint64, len(roundRequests)) for i, rr := range roundRequests { rounds[i] = uint64(rr.rid) @@ -104,18 +96,21 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c Rounds: rounds, } - jww.DEBUG.Printf("Requesting Historical rounds %v from "+ - "gateway %s", rounds, gwHost.GetId()) + result, err := m.sender.SendToAny(func(host *connect.Host) (interface{}, error) { + jww.DEBUG.Printf("Requesting Historical rounds %v from "+ + "gateway %s", rounds, host.GetId()) + return comm.RequestHistoricalRounds(host, hr) + }) - response, err := comm.RequestHistoricalRounds(gwHost, hr) if err != nil { jww.ERROR.Printf("Failed to request historical roundRequests "+ - "data for rounds %v: %s", rounds, response) + "data for rounds %v: %s", rounds, err) // if the check fails to resolve, break the loop and so they will be // checked again timerCh = time.NewTimer(m.params.HistoricalRoundsPeriod).C continue } + response := result.(*pb.HistoricalRoundsResponse) // process the returned historical roundRequests. for i, roundInfo := range response.Rounds { @@ -124,19 +119,19 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c // pick them up in the future. if roundInfo == nil { roundRequests[i].numAttempts++ - if roundRequests[i].numAttempts==m.params.MaxHistoricalRoundsRetries{ - jww.ERROR.Printf("Failed to retreive historical " + + if roundRequests[i].numAttempts == m.params.MaxHistoricalRoundsRetries { + jww.ERROR.Printf("Failed to retreive historical "+ "round %d on last attempt, will not try again", roundRequests[i].rid) - }else{ + } else { select { - case m.historicalRounds <-roundRequests[i]: - jww.WARN.Printf("Failed to retreive historical " + + case m.historicalRounds <- roundRequests[i]: + jww.WARN.Printf("Failed to retreive historical "+ "round %d, will try up to %d more times", roundRequests[i].rid, m.params.MaxHistoricalRoundsRetries-roundRequests[i].numAttempts) default: - jww.WARN.Printf("Failed to retreive historical " + - "round %d, failed to try again, round will not be " + + jww.WARN.Printf("Failed to retreive historical "+ + "round %d, failed to try again, round will not be "+ "retreived", roundRequests[i].rid) } } diff --git a/network/rounds/manager.go b/network/rounds/manager.go index 177c1c318b4807be0f17ccbb36bbe8720104deb4..942e86319efe8ab05711b901360edbcd37865978 100644 --- a/network/rounds/manager.go +++ b/network/rounds/manager.go @@ -10,6 +10,7 @@ package rounds import ( "fmt" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message" "gitlab.com/elixxir/client/stoppable" @@ -19,6 +20,7 @@ type Manager struct { params params.Rounds internal.Internal + sender *gateway.Sender historicalRounds chan historicalRoundRequest lookupRoundMessages chan roundLookup @@ -26,13 +28,14 @@ type Manager struct { } func NewManager(internal internal.Internal, params params.Rounds, - bundles chan<- message.Bundle) *Manager { + bundles chan<- message.Bundle, sender *gateway.Sender) *Manager { m := &Manager{ params: params, historicalRounds: make(chan historicalRoundRequest, params.HistoricalRoundsBufferLen), lookupRoundMessages: make(chan roundLookup, params.LookupRoundsBufferLen), messageBundles: bundles, + sender: sender, } m.Internal = internal @@ -55,4 +58,4 @@ func (m *Manager) StartProcessors() stoppable.Stoppable { multi.Add(stopper) } return multi -} \ No newline at end of file +} diff --git a/network/rounds/retrieve.go b/network/rounds/retrieve.go index effd66b7fcac84db95255cc458c997a36d134f2b..0d90355c64bb2adaff003442737e09468a5b05f8 100644 --- a/network/rounds/retrieve.go +++ b/network/rounds/retrieve.go @@ -10,7 +10,6 @@ package rounds import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/message" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" @@ -44,52 +43,35 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, done = true case rl := <-m.lookupRoundMessages: ri := rl.roundInfo - var bundle message.Bundle - // Get a shuffled list of gateways in the round - gwHosts, err := gateway.GetAllShuffled(comms, ri) - if err != nil { - jww.WARN.Printf("Failed to get gateway hosts from "+ - "round %v, not requesting from them", - ri.ID) - break - } - - // Attempt to request messages for every gateway in the list. - // If we retrieve without error, then we exit. If we error, then - // we retry with the next gateway in the list until we exhaust the list - for i, gwHost := range gwHosts { - // Attempt to request for this gateway - bundle, err = m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwHost) + // Convert gateways in round to proper ID format + gwIds := make([]*id.ID, len(ri.Topology)) + for i, idBytes := range ri.Topology { + gwId, err := id.Unmarshal(idBytes) if err != nil { - - jww.WARN.Printf("Failed on gateway [%d/%d] to get messages for round %v", - i, len(gwHosts), ri.ID) - - // Retry for the next gateway in the list - continue + jww.FATAL.Panicf("processMessageRetrieval: Unable to unmarshal: %+v", err) } - - // If a non-error request, no longer retry - break - - } - gwIDs := make([]*id.ID, 0) - for _, gwHost := range gwHosts { - gwIDs = append(gwIDs, gwHost.GetId()) + gwId.SetType(id.Gateway) + gwIds[i] = gwId } + // Attempt to request for this gateway + bundle, err := m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds) + // After trying all gateways, if none returned we mark the round as a // failure and print out the last error if err != nil { jww.ERROR.Printf("Failed to get pickup round %d "+ - "from all gateways (%v): final gateway %s returned : %s", - id.Round(ri.ID), gwIDs, gwHosts[len(gwHosts)-1].GetId(), err) - } else if len(bundle.Messages) != 0 { + "from all gateways (%v): %s", + id.Round(ri.ID), gwIds, err) + } + + if len(bundle.Messages) != 0 { // If successful and there are messages, we send them to another thread bundle.Identity = rl.identity m.messageBundles <- bundle } + } } } @@ -97,40 +79,49 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, // getMessagesFromGateway attempts to get messages from their assigned // gateway host in the round specified. If successful func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.IdentityUse, - comms messageRetrievalComms, gwHost *connect.Host) (message.Bundle, error) { + comms messageRetrievalComms, gwIds []*id.ID) (message.Bundle, error) { + + // Send to the gateways using backup proxies + result, err := m.sender.SendToPreferred(gwIds, func(host *connect.Host, target *id.ID) (interface{}, error) { + jww.DEBUG.Printf("Trying to get messages for round %v for ephemeralID %d (%v) "+ + "via Gateway: %s", roundID, identity.EphId.Int64(), identity.Source.String(), host.GetId()) + + // send the request + msgReq := &pb.GetMessages{ + ClientID: identity.EphId[:], + RoundID: uint64(roundID), + Target: target.Marshal(), + } - jww.DEBUG.Printf("Trying to get messages for round %v for ephmeralID %d (%v) "+ - "via Gateway: %s", roundID, identity.EphId.Int64(), identity.Source.String(), gwHost.GetId()) + // If the gateway doesnt have the round, return an error + msgResp, err := comms.RequestMessages(host, msgReq) + if err == nil && !msgResp.GetHasRound() { + return message.Bundle{}, errors.Errorf(noRoundError) + } + + return msgResp, err + }) - // send the request - msgReq := &pb.GetMessages{ - ClientID: identity.EphId[:], - RoundID: uint64(roundID), - } - msgResp, err := comms.RequestMessages(gwHost, msgReq) // Fail the round if an error occurs so it can be tried again later if err != nil { return message.Bundle{}, errors.WithMessagef(err, "Failed to "+ - "request messages from %s for round %d", gwHost.GetId(), roundID) - } - // if the gateway doesnt have the round, return an error - if !msgResp.GetHasRound() { - return message.Bundle{}, errors.Errorf(noRoundError) + "request messages for round %d", roundID) } + msgResp := result.(*pb.GetMessagesResponse) // If there are no messages print a warning. Due to the probabilistic nature // of the bloom filters, false positives will happen some times msgs := msgResp.GetMessages() if msgs == nil || len(msgs) == 0 { - jww.WARN.Printf("host %s has no messages for client %s "+ + jww.WARN.Printf("no messages for client %s "+ " in round %d. This happening every once in a while is normal,"+ - " but can be indicitive of a problem if it is consistant", gwHost, + " but can be indicative of a problem if it is consistent", m.TransmissionID, roundID) return message.Bundle{}, nil } - jww.INFO.Printf("Received %d messages in Round %v via Gateway %s for %d (%s)", - len(msgs), roundID, gwHost.GetId(), identity.EphId.Int64(), identity.Source) + jww.INFO.Printf("Received %d messages in Round %v for %d (%s)", + len(msgs), roundID, identity.EphId.Int64(), identity.Source) //build the bundle of messages to send to the message processor bundle := message.Bundle{ diff --git a/network/rounds/retrieve_test.go b/network/rounds/retrieve_test.go index 2efd2c89c6865ba335d2e167cab343bb786370db..abd79533dd9189cdef93f8f37aba4b8ad810b8d2 100644 --- a/network/rounds/retrieve_test.go +++ b/network/rounds/retrieve_test.go @@ -8,11 +8,15 @@ package rounds import ( "bytes" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/message" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" + "gitlab.com/xx_network/primitives/ndf" "reflect" "testing" "time" @@ -25,6 +29,17 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} quitChan := make(chan struct{}) + testNdf := getNDF() + nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) + gwId := nodeId.DeepCopy() + gwId.SetType(id.Gateway) + testNdf.Gateways = []ndf.Gateway{{ID: gwId.Marshal()}} + + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + testManager.sender, _ = gateway.NewSender(p, + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + testNdf, mockComms, testManager.Session, nil) // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) @@ -103,8 +118,19 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { // General initializations testManager := newManager(t) + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} + testNdf := getNDF() + nodeId := id.NewIdFromString(FalsePositive, id.Node, &testing.T{}) + gwId := nodeId.DeepCopy() + gwId.SetType(id.Gateway) + testNdf.Gateways = []ndf.Gateway{{ID: gwId.Marshal()}} + + testManager.sender, _ = gateway.NewSender(p, + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + testNdf, mockComms, testManager.Session, nil) quitChan := make(chan struct{}) // Create a local channel so reception is possible (testManager.messageBundles is @@ -130,7 +156,7 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { }, } - idList := [][]byte{dummyGateway.Bytes()} + idList := [][]byte{dummyGateway.Marshal()} roundInfo := &pb.RoundInfo{ ID: uint64(roundId), @@ -172,6 +198,17 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} quitChan := make(chan struct{}) + testNdf := getNDF() + nodeId := id.NewIdFromString(FalsePositive, id.Node, &testing.T{}) + gwId := nodeId.DeepCopy() + gwId.SetType(id.Gateway) + testNdf.Gateways = []ndf.Gateway{{ID: gwId.Marshal()}} + + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + testManager.sender, _ = gateway.NewSender(p, + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + testNdf, mockComms, testManager.Session, nil) // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) @@ -306,6 +343,17 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} quitChan := make(chan struct{}) + testNdf := getNDF() + nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) + gwId := nodeId.DeepCopy() + gwId.SetType(id.Gateway) + testNdf.Gateways = []ndf.Gateway{{ID: gwId.Marshal()}} + + p := gateway.DefaultPoolParams() + p.MaxPoolSize = 1 + testManager.sender, _ = gateway.NewSender(p, + fastRNG.NewStreamGenerator(1, 1, csprng.NewSystemRNG), + testNdf, mockComms, testManager.Session, nil) // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) diff --git a/network/rounds/utils_test.go b/network/rounds/utils_test.go index 69c5925ed5e5b787f2440169d5da9a24d373b219..d352078dcd1219b188c8c7fde0b807748d8c3521 100644 --- a/network/rounds/utils_test.go +++ b/network/rounds/utils_test.go @@ -14,6 +14,7 @@ import ( pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/ndf" "testing" ) @@ -28,7 +29,6 @@ func newManager(face interface{}) *Manager { TransmissionID: sess1.GetUser().TransmissionID, }, } - return testManager } @@ -38,10 +38,20 @@ const ReturningGateway = "GetMessageRequest" const FalsePositive = "FalsePositive" const PayloadMessage = "Payload" const ErrorGateway = "Error" + type mockMessageRetrievalComms struct { testingSignature *testing.T } +func (mmrc *mockMessageRetrievalComms) AddHost(hid *id.ID, address string, cert []byte, params connect.HostParams) (host *connect.Host, err error) { + host, _ = mmrc.GetHost(nil) + return host, nil +} + +func (mmrc *mockMessageRetrievalComms) RemoveHost(hid *id.ID) { + +} + func (mmrc *mockMessageRetrievalComms) GetHost(hostId *id.ID) (*connect.Host, bool) { h, _ := connect.NewHost(hostId, "0.0.0.0", []byte(""), connect.HostParams{ MaxRetries: 0, @@ -91,3 +101,42 @@ func (mmrc *mockMessageRetrievalComms) RequestMessages(host *connect.Host, return nil, nil } + +func getNDF() *ndf.NetworkDefinition { + return &ndf.NetworkDefinition{ + E2E: ndf.Group{ + Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B" + + "7A8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3DD2AE" + + "DF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E7861575E745D31F" + + "8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC6ADC718DD2A3E041" + + "023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C4A530E8FFB1BC51DADDF45" + + "3B0B2717C2BC6669ED76B4BDD5C9FF558E88F26E5785302BEDBCA23EAC5ACE9209" + + "6EE8A60642FB61E8F3D24990B8CB12EE448EEF78E184C7242DD161C7738F32BF29" + + "A841698978825B4111B4BC3E1E198455095958333D776D8B2BEEED3A1A1A221A6E" + + "37E664A64B83981C46FFDDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F2" + + "78DE8014A47323631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696" + + "015CB79C3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E" + + "6319BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC35873" + + "847AEF49F66E43873", + Generator: "2", + }, + CMIX: ndf.Group{ + Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642F0B5C48" + + "C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757264E5A1A44F" + + "FE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F9716BFE6117C6B5" + + "B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091EB51743BF33050C38DE2" + + "35567E1B34C3D6A5C0CEAA1A0F368213C3D19843D0B4B09DCB9FC72D39C8DE41" + + "F1BF14D4BB4563CA28371621CAD3324B6A2D392145BEBFAC748805236F5CA2FE" + + "92B871CD8F9C36D3292B5509CA8CAA77A2ADFC7BFD77DDA6F71125A7456FEA15" + + "3E433256A2261C6A06ED3693797E7995FAD5AABBCFBE3EDA2741E375404AE25B", + Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E24809670716C613" + + "D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D1AA58C4328A06C4" + + "6A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A338661D10461C0D135472" + + "085057F3494309FFA73C611F78B32ADBB5740C361C9F35BE90997DB2014E2EF5" + + "AA61782F52ABEB8BD6432C4DD097BC5423B285DAFB60DC364E8161F4A2A35ACA" + + "3A10B1C4D203CC76A470A33AFDCBDD92959859ABD8B56E1725252D78EAC66E71" + + "BA9AE3F1DD2487199874393CD4D832186800654760E1E34C09E4D155179F9EC0" + + "DC4473F996BDCE6EED1CABED8B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", + }, + } +} diff --git a/network/send.go b/network/send.go index 0e3e9020aa187716e3a7dd7da9f538bff47ba3e9..d70a0c6661c2ee6f4c40f313219baa7cbb86fd52 100644 --- a/network/send.go +++ b/network/send.go @@ -28,7 +28,7 @@ func (m *manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM "network is not healthy") } - return m.message.SendCMIX(msg, recipient, param) + return m.message.SendCMIX(m.GetSender(), msg, recipient, param) } // SendUnsafe sends an unencrypted payload to the provided recipient diff --git a/single/manager_test.go b/single/manager_test.go index 5ee50ed606ff203253729aae941e248c224df533..2ce788f7b86ba44fd61a2aaa68814d1d62a49a02 100644 --- a/single/manager_test.go +++ b/single/manager_test.go @@ -14,6 +14,7 @@ import ( "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage/reception" @@ -323,6 +324,10 @@ func (tnm *testNetworkManager) InProgressRegistrations() int { return 0 } +func (t *testNetworkManager) GetSender() *gateway.Sender { + return nil +} + func getNDF() *ndf.NetworkDefinition { return &ndf.NetworkDefinition{ E2E: ndf.Group{ diff --git a/stoppable/cleanup.go b/stoppable/cleanup.go index 8111abe206492d4e22cfaffe31195acab58b77c5..b1e2c4561dc87c2af1ec654d48b787d47ca8182e 100644 --- a/stoppable/cleanup.go +++ b/stoppable/cleanup.go @@ -87,7 +87,7 @@ func (c *Cleanup) Close(timeout time.Duration) error { } }) - if err!=nil{ + if err != nil { jww.ERROR.Printf(err.Error()) } diff --git a/stoppable/single.go b/stoppable/single.go index 8821db88eb4934a06db4f96484ef700b2c6beeff..2e8fa78a2a1c3f4f12752f045cb93c541da8412b 100644 --- a/stoppable/single.go +++ b/stoppable/single.go @@ -52,7 +52,6 @@ func (s *Single) Name() string { func (s *Single) Close(timeout time.Duration) error { var err error s.once.Do(func() { - atomic.StoreUint32(&s.running, 0) timer := time.NewTimer(timeout) select { case <-timer.C: @@ -61,6 +60,7 @@ func (s *Single) Close(timeout time.Duration) error { err = errors.Errorf("%s failed to close", s.name) case s.quit <- struct{}{}: } + atomic.StoreUint32(&s.running, 0) }) return err } diff --git a/storage/auth/store.go b/storage/auth/store.go index d6730d03d8fcde7db0ee9edb76cb056942be8b7c..9ce2e6154e3b5bce28f4df6c36cbffca43255fc8 100644 --- a/storage/auth/store.go +++ b/storage/auth/store.go @@ -365,7 +365,7 @@ func (s *Store) Done(partner *id.ID) { s.mux.RUnlock() if !ok { - jww.ERROR.Panicf("Request cannot be finished, not " + + jww.ERROR.Panicf("Request cannot be finished, not "+ "found: %s", partner) return } diff --git a/storage/cmix/store_test.go b/storage/cmix/store_test.go index f4dfc3a7af59ee29511a2172779c6b919ef6ec6b..63329b769b1706dc4a341988a03e28fe84a37642 100644 --- a/storage/cmix/store_test.go +++ b/storage/cmix/store_test.go @@ -70,7 +70,6 @@ func TestStore_AddRemove(t *testing.T) { } } - // Happy path Add/Has test func TestStore_AddHas(t *testing.T) { // Uncomment to print keys that Set and Get are called on diff --git a/storage/e2e/relationship.go b/storage/e2e/relationship.go index 6efd7081b133fae3845c83e1fbc5ecf3997da1ae..a492a72ed51d0f48f08657d786bc8dc6a3c861ed 100644 --- a/storage/e2e/relationship.go +++ b/storage/e2e/relationship.go @@ -303,9 +303,9 @@ func (r *relationship) getNewestRekeyableSession() *Session { // always valid. It isn't clear it can fail though because we are // accessing the data in the same order it would be written (i think) if s.Status() != RekeyEmpty { - if s.IsConfirmed(){ + if s.IsConfirmed() { return s - }else if unconfirmed == nil{ + } else if unconfirmed == nil { unconfirmed = s } } diff --git a/storage/e2e/relationship_test.go b/storage/e2e/relationship_test.go index 49fa51de4abe7df24421ae8c91b24c09f8c10b45..457e4f6dcd508abd2014248db7738f46badf947a 100644 --- a/storage/e2e/relationship_test.go +++ b/storage/e2e/relationship_test.go @@ -571,4 +571,4 @@ func Test_relationship_getNewestRekeyableSession(t *testing.T) { } }) } -} \ No newline at end of file +} diff --git a/storage/e2e/session.go b/storage/e2e/session.go index 5f0344378ae7f106edd4c447904dd89cb2b6ca94..ad684725a04918389f706566e44642c6c1fbaf16 100644 --- a/storage/e2e/session.go +++ b/storage/e2e/session.go @@ -119,7 +119,7 @@ func newSession(ship *relationship, t RelationshipType, myPrivKey, partnerPubKey relationshipFingerprint: relationshipFingerprint, negotiationStatus: negotiationStatus, partnerSource: trigger, - partner: ship.manager.partner.DeepCopy(), + partner: ship.manager.partner.DeepCopy(), } session.kv = session.generate(ship.kv) @@ -173,8 +173,8 @@ func loadSession(ship *relationship, kv *versioned.KV, } session.relationshipFingerprint = relationshipFingerprint - if !session.partner.Cmp(ship.manager.partner){ - return nil, errors.Errorf("Stored partner (%s) did not match " + + if !session.partner.Cmp(ship.manager.partner) { + return nil, errors.Errorf("Stored partner (%s) did not match "+ "relationship partner (%s)", session.partner, ship.manager.partner) } diff --git a/storage/e2e/session_test.go b/storage/e2e/session_test.go index 800cac5410e8b5236d9c89aa5ed90c36a49bc654..258cea9edab79a1e0fe6f214b4b6a78e4814e9d3 100644 --- a/storage/e2e/session_test.go +++ b/storage/e2e/session_test.go @@ -628,8 +628,8 @@ func makeTestSession() (*Session, *context) { e2eParams: params.GetDefaultE2ESessionParams(), relationship: &relationship{ manager: &Manager{ - ctx: ctx, - kv: kv, + ctx: ctx, + kv: kv, partner: &id.ID{}, }, kv: kv, @@ -638,7 +638,7 @@ func makeTestSession() (*Session, *context) { t: Receive, negotiationStatus: Confirmed, rekeyThreshold: 5, - partner: &id.ID{}, + partner: &id.ID{}, } var err error s.keyState, err = newStateVector(s.kv, diff --git a/storage/e2e/store.go b/storage/e2e/store.go index a0a427e0abbff557514b6d4197335cc75c427a69..e647cacab042388b667d383f670c8082c450b591 100644 --- a/storage/e2e/store.go +++ b/storage/e2e/store.go @@ -260,8 +260,8 @@ func (s *Store) unmarshal(b []byte) error { partnerID, err.Error()) } - if !manager.GetPartnerID().Cmp(partnerID){ - jww.FATAL.Panicf("Loaded a manager with the wrong partner " + + if !manager.GetPartnerID().Cmp(partnerID) { + jww.FATAL.Panicf("Loaded a manager with the wrong partner "+ "ID: \n\t loaded: %s \n\t present: %s", partnerID, manager.GetPartnerID()) } diff --git a/storage/reception/IdentityUse.go b/storage/reception/IdentityUse.go index 93fc6b8d6c374cbf166d1a55e3c302a46192704f..2c8b9df3b64601651d7489e96935ff1a3d360db9 100644 --- a/storage/reception/IdentityUse.go +++ b/storage/reception/IdentityUse.go @@ -22,6 +22,7 @@ type IdentityUse struct { UR *rounds.UnknownRounds ER *rounds.EarliestRound + CR *rounds.CheckedRounds } // setSamplingPeriod add the Request mask as a random buffer around the sampling diff --git a/storage/reception/fake_test.go b/storage/reception/fake_test.go index 35ff0e1727a76645ef350b8d1e4a7cd5a76691e9..2c748ac2258e0950e8901a2748c0057c498770ef 100644 --- a/storage/reception/fake_test.go +++ b/storage/reception/fake_test.go @@ -24,7 +24,7 @@ func Test_generateFakeIdentity(t *testing.T) { "\"EndValid\":" + string(endValid) + "," + "\"RequestMask\":86400000000000,\"Ephemeral\":true," + "\"StartRequest\":\"0001-01-01T00:00:00Z\"," + - "\"EndRequest\":\"0001-01-01T00:00:00Z\",\"Fake\":true,\"UR\":null,\"ER\":null}" + "\"EndRequest\":\"0001-01-01T00:00:00Z\",\"Fake\":true,\"UR\":null,\"ER\":null,\"CR\":null}" timestamp := time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) @@ -36,8 +36,8 @@ func Test_generateFakeIdentity(t *testing.T) { receivedJson, _ := json.Marshal(received) if expected != string(receivedJson) { - t.Errorf("The fake identity was generated incorrectly.\n "+ - "expected: %s\nreceived: %s", expected, receivedJson) + t.Errorf("The fake identity was generated incorrectly."+ + "\nexpected: %s\nreceived: %s", expected, receivedJson) } } diff --git a/storage/reception/registration.go b/storage/reception/registration.go index 70647129c7c621b9a757b9d42ec82ef66f3c1bd3..0535f092366ce03f556f21ad61b39b8b65a97b63 100644 --- a/storage/reception/registration.go +++ b/storage/reception/registration.go @@ -2,6 +2,8 @@ package reception import ( "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/storage/rounds" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/xx_network/primitives/id" @@ -17,6 +19,7 @@ type registration struct { Identity UR *rounds.UnknownRounds ER *rounds.EarliestRound + CR *rounds.CheckedRounds kv *versioned.KV } @@ -46,6 +49,11 @@ func newRegistration(reg Identity, kv *versioned.KV) (*registration, error) { urParams.Stored = !reg.Ephemeral r.UR = rounds.NewUnknownRounds(kv, urParams) r.ER = rounds.NewEarliestRound(!reg.Ephemeral, kv) + cr, err := rounds.NewCheckedRounds(int(params.GetDefaultNetwork().KnownRoundsThreshold), kv) + if err != nil { + jww.FATAL.Printf("Failed to create new CheckedRounds for registration: %+v", err) + } + r.CR = cr // If this is not ephemeral, then store everything if !reg.Ephemeral { @@ -71,11 +79,24 @@ func loadRegistration(EphId ephemeral.Id, Source *id.ID, startValid time.Time, "for %s", regPrefix(EphId, Source, startValid)) } + cr, err := rounds.LoadCheckedRounds(int(params.GetDefaultNetwork().KnownRoundsThreshold), kv) + if err != nil { + jww.ERROR.Printf("Making new CheckedRounds, loading of CheckedRounds "+ + "failed: %+v", err) + + cr, err = rounds.NewCheckedRounds(int(params.GetDefaultNetwork().KnownRoundsThreshold), kv) + if err != nil { + jww.FATAL.Printf("Failed to create new CheckedRounds for "+ + "registration after CheckedRounds load failure: %+v", err) + } + } + r := ®istration{ Identity: reg, kv: kv, UR: rounds.LoadUnknownRounds(kv, rounds.DefaultUnknownRoundsParams()), ER: rounds.LoadEarliestRound(kv), + CR: cr, } return r, nil diff --git a/storage/reception/store.go b/storage/reception/store.go index f87ee416abaecd01e3955549080e462b3ae1e3ad..bd78f0b7682d3c06c2b47211f518729cb63cbfc1 100644 --- a/storage/reception/store.go +++ b/storage/reception/store.go @@ -382,5 +382,6 @@ func (s *Store) selectIdentity(rng io.Reader, now time.Time) (IdentityUse, error Fake: false, UR: selected.UR, ER: selected.ER, + CR: selected.CR, }, nil } diff --git a/storage/rounds/checkedRounds.go b/storage/rounds/checkedRounds.go new file mode 100644 index 0000000000000000000000000000000000000000..a7e3ebdecd425a3ba97902a2197fd3189b7b6f59 --- /dev/null +++ b/storage/rounds/checkedRounds.go @@ -0,0 +1,164 @@ +package rounds + +import ( + "container/list" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/id" +) + +const itemsPerBlock = 50 + +type CheckedRounds struct { + // Map of round IDs for quick lookup if an ID is stored + m map[id.Round]interface{} + + // List of round IDs in order of age; oldest in front and newest in back + l *list.List + + // List of recently added round IDs that need to be stored + recent []id.Round + + // Saves round IDs in blocks to storage + store *utility.BlockStore + + // The maximum number of round IDs to store before pruning the oldest + maxRounds int +} + +// NewCheckedRounds returns a new CheckedRounds with an initialized map. +func NewCheckedRounds(maxRounds int, kv *versioned.KV) (*CheckedRounds, error) { + // Calculate the number of blocks of size itemsPerBlock are needed to store + // numRoundsToKeep number of round IDs + numBlocks := maxRounds / itemsPerBlock + if maxRounds%itemsPerBlock != 0 { + numBlocks++ + } + + // Create a new BlockStore for storing the round IDs to storage + store, err := utility.NewBlockStore(itemsPerBlock, numBlocks, kv) + if err != nil { + return nil, errors.Errorf("failed to save new checked rounds to storage: %+v", err) + } + + // Create new CheckedRounds + return newCheckedRounds(maxRounds, store), nil +} + +// newCheckedRounds initialises the lists in CheckedRounds. +func newCheckedRounds(maxRounds int, store *utility.BlockStore) *CheckedRounds { + return &CheckedRounds{ + m: make(map[id.Round]interface{}), + l: list.New(), + recent: []id.Round{}, + store: store, + maxRounds: maxRounds, + } +} + +// LoadCheckedRounds restores the list from storage. +func LoadCheckedRounds(maxRounds int, kv *versioned.KV) (*CheckedRounds, error) { + // Get rounds from storage + store, rounds, err := utility.LoadBlockStore(kv) + if err != nil { + return nil, errors.Errorf("failed to load CheckedRounds from storage: %+v", err) + } + + // Create new CheckedRounds + cr := newCheckedRounds(maxRounds, store) + + // Unmarshal round ID byte list into the new CheckedRounds + cr.unmarshal(rounds) + + return cr, nil +} + +// SaveCheckedRounds stores the list to storage. +func (cr *CheckedRounds) SaveCheckedRounds() error { + + // Save to disk + err := cr.store.Store(cr) + if err != nil { + return errors.Errorf("failed to store recent CheckedRounds: %+v", err) + } + + // Save to disk + return nil +} + +// Next pops the oldest recent round ID from the list and returns it as bytes. +// Returns false if the list is empty +func (cr *CheckedRounds) Next() ([]byte, bool) { + if len(cr.recent) > 0 { + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(cr.recent[0])) + cr.recent = cr.recent[1:] + return b, true + } + + return nil, false +} + +// Check determines if the round ID has been added to the checklist. If it has +// not, then it is added and the function returns true. Otherwise, if it already +// exists, then the function returns false. +func (cr *CheckedRounds) Check(rid id.Round) bool { + // Add the round ID to the checklist if it does not exist and return true + if _, exists := cr.m[rid]; !exists { + cr.m[rid] = nil // Add ID to the map + cr.l.PushBack(rid) // Add ID to the end of the list + cr.recent = append(cr.recent, rid) + + // The commented out code below works the same as the Prune function but + // occurs when adding a round ID to the list. It was decided to use + // Prune instead so that it does not block even though the savings are + // probably negligible. + // // Remove the oldest round ID the list is full + // if cr.l.Len() > cr.maxRounds { + // oldestID := cr.l.Remove(cr.l.Front()) // Remove oldest from list + // delete(cr.m, oldestID.(id.Round)) // Remove oldest from map + // } + + return true + } + + return false +} + +// IsChecked determines if the round has been added to the checklist. +func (cr *CheckedRounds) IsChecked(rid id.Round) bool { + _, exists := cr.m[rid] + return exists +} + +// Prune any rounds that are earlier than the earliestAllowed. +func (cr *CheckedRounds) Prune() { + if len(cr.m) < cr.maxRounds { + return + } + earliestAllowed := cr.l.Back().Value.(id.Round) - id.Round(cr.maxRounds) + 1 + + // Iterate over all the round IDs and remove any that are too old + for e := cr.l.Front(); e != nil; { + if e.Value.(id.Round) < earliestAllowed { + delete(cr.m, e.Value.(id.Round)) + lastE := e + e = e.Next() + cr.l.Remove(lastE) + } else { + break + } + } +} + +// unmarshal unmarshalls the list of byte slices into the CheckedRounds map and +// list. +func (cr *CheckedRounds) unmarshal(rounds [][]byte) { + for _, round := range rounds { + rid := id.Round(binary.LittleEndian.Uint64(round)) + cr.m[rid] = nil + cr.l.PushBack(rid) + } +} diff --git a/storage/rounds/checkedRounds_test.go b/storage/rounds/checkedRounds_test.go new file mode 100644 index 0000000000000000000000000000000000000000..816c6fce7c311b71c5dd2d33f4b5fa07b2ec7ec9 --- /dev/null +++ b/storage/rounds/checkedRounds_test.go @@ -0,0 +1,217 @@ +package rounds + +import ( + "container/list" + "encoding/binary" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "reflect" + "testing" +) + +// Happy path. +func Test_newCheckedRounds(t *testing.T) { + maxRounds := 230 + kv := versioned.NewKV(make(ekv.Memstore)) + + // Create a new BlockStore for storing the round IDs to storage + store, err := utility.NewBlockStore(itemsPerBlock, maxRounds/itemsPerBlock+1, kv) + if err != nil { + t.Errorf("Failed to create new BlockStore: %+v", err) + } + + expected := &CheckedRounds{ + m: make(map[id.Round]interface{}), + l: list.New(), + recent: []id.Round{}, + store: store, + maxRounds: maxRounds, + } + + received, err := NewCheckedRounds(maxRounds, kv) + if err != nil { + t.Errorf("NewCheckedRounds() returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, received) { + t.Errorf("NewCheckedRounds() did not return the exepcted CheckedRounds."+ + "\nexpected: %+v\nreceived: %+v", expected, received) + } +} + +// Tests that a CheckedRounds that has been saved and loaded from storage +// matches the original. +func TestCheckedRounds_SaveCheckedRounds_TestLoadCheckedRounds(t *testing.T) { + // Create new CheckedRounds and add rounds to it + kv := versioned.NewKV(make(ekv.Memstore)) + cr, err := NewCheckedRounds(50, kv) + if err != nil { + t.Errorf("failed to make new CheckedRounds: %+v", err) + } + for i := id.Round(0); i < 100; i++ { + cr.Check(i) + } + + err = cr.SaveCheckedRounds() + if err != nil { + t.Errorf("SaveCheckedRounds() returned an error: %+v", err) + } + + cr.Prune() + + newCR, err := LoadCheckedRounds(50, kv) + if err != nil { + t.Errorf("LoadCheckedRounds() returned an error: %+v", err) + } + + if !reflect.DeepEqual(cr, newCR) { + t.Errorf("Failed to store and load CheckedRounds."+ + "\nexpected: %+v\nreceived: %+v", cr, newCR) + } +} + +// Happy path. +func TestCheckedRounds_Next(t *testing.T) { + cr := newCheckedRounds(100, nil) + rounds := make([][]byte, 10) + for i := id.Round(0); i < 10; i++ { + cr.Check(i) + } + + for i := id.Round(0); i < 10; i++ { + round, exists := cr.Next() + if !exists { + t.Error("Next() returned false when there should be more IDs.") + } + + rounds[i] = round + } + round, exists := cr.Next() + if exists { + t.Errorf("Next() returned true when the list should be empty: %d", round) + } + + testCR := newCheckedRounds(100, nil) + testCR.unmarshal(rounds) + + if !reflect.DeepEqual(cr, testCR) { + t.Errorf("unmarshal() did not return the expected CheckedRounds."+ + "\nexpected: %+v\nreceived: %+v", cr, testCR) + } +} + +// Happy path. +func Test_checkedRounds_Check(t *testing.T) { + cr := newCheckedRounds(100, nil) + var expected []id.Round + for i := id.Round(1); i < 11; i++ { + if i%2 == 0 { + if !cr.Check(i) { + t.Errorf("Check() returned false when the round ID should have been added (%d).", i) + } + + val := cr.l.Back().Value.(id.Round) + if val != i { + t.Errorf("Check() did not add the round ID to the back of the list."+ + "\nexpected: %d\nreceived: %d", i, val) + } + expected = append(expected, i) + } + } + + if !reflect.DeepEqual(cr.recent, expected) { + t.Errorf("Unexpected list of recent rounds."+ + "\nexpected: %+v\nreceived: %+v", expected, cr.recent) + } + + for i := id.Round(1); i < 11; i++ { + result := cr.Check(i) + if i%2 == 0 { + if result { + t.Errorf("Check() returned true when the round ID should not have been added (%d).", i) + } + } else if !result { + t.Errorf("Check() returned false when the round ID should have been added (%d).", i) + } else { + expected = append(expected, i) + } + } + + if !reflect.DeepEqual(cr.recent, expected) { + t.Errorf("Unexpected list of recent rounds."+ + "\nexpected: %+v\nreceived: %+v", expected, cr.recent) + } +} + +// Happy path. +func TestCheckedRounds_IsChecked(t *testing.T) { + cr := newCheckedRounds(100, nil) + + for i := id.Round(0); i < 100; i += 2 { + cr.Check(i) + } + + for i := id.Round(0); i < 100; i++ { + if i%2 == 0 { + if !cr.IsChecked(i) { + t.Errorf("IsChecked() falsly reported round ID %d as not checked.", i) + } + } else if cr.IsChecked(i) { + t.Errorf("IsChecked() falsly reported round ID %d as checked.", i) + } + } +} + +// Happy path. +func Test_checkedRounds_Prune(t *testing.T) { + cr := newCheckedRounds(5, nil) + for i := id.Round(0); i < 10; i++ { + cr.Check(i) + } + + cr.Prune() + + if len(cr.m) != 5 || cr.l.Len() != 5 { + t.Errorf("Prune() did not remove the correct number of round IDs."+ + "\nexpected: %d\nmap: %d\nlist: %d", 5, + len(cr.m), cr.l.Len()) + } +} + +// Happy path: length of the list is not too long and does not need to be pruned. +func Test_checkedRounds_Prune_NoChange(t *testing.T) { + cr := newCheckedRounds(100, nil) + for i := id.Round(0); i < 10; i++ { + cr.Check(i) + } + + cr.Prune() + + if len(cr.m) != 10 || cr.l.Len() != 10 { + t.Errorf("Prune() did not remove the correct number of round IDs."+ + "\nexpected: %d\nmap: %d\nlist: %d", 5, + len(cr.m), cr.l.Len()) + } +} + +// Happy path. +func TestCheckedRounds_unmarshal(t *testing.T) { + expected := newCheckedRounds(100, nil) + rounds := make([][]byte, 10) + for i := id.Round(0); i < 10; i++ { + expected.Check(i) + rounds[i] = make([]byte, 8) + binary.LittleEndian.PutUint64(rounds[i], uint64(i)) + } + expected.recent = []id.Round{} + + cr := newCheckedRounds(100, nil) + cr.unmarshal(rounds) + + if !reflect.DeepEqual(expected, cr) { + t.Errorf("unmarshal() did not return the expected CheckedRounds."+ + "\nexpected: %+v\nreceived: %+v", expected, cr) + } +} diff --git a/storage/rounds/unknownRounds_test.go b/storage/rounds/unknownRounds_test.go index ffc9bdd46a74060948a7a06d724cf2590d203da4..f437e0839f7d37208f2d7b5c826e6a6cfeb7d5b2 100644 --- a/storage/rounds/unknownRounds_test.go +++ b/storage/rounds/unknownRounds_test.go @@ -27,7 +27,7 @@ func TestNewUnknownRoundsStore(t *testing.T) { params: DefaultUnknownRoundsParams(), } - store := NewUnknownRounds(kv, DefaultUnknownRoundsParams()) + store := NewUnknownRounds(kv, DefaultUnknownRoundsParams()) // Compare manually created object with NewUnknownRoundsStore if !reflect.DeepEqual(expectedStore, store) { @@ -60,7 +60,7 @@ func TestNewUnknownRoundsStore(t *testing.T) { // Full test func TestUnknownRoundsStore_Iterate(t *testing.T) { kv := versioned.NewKV(make(ekv.Memstore)) - store := NewUnknownRounds(kv, DefaultUnknownRoundsParams()) + store := NewUnknownRounds(kv, DefaultUnknownRoundsParams()) // Return true only for rounds that are even mockChecker := func(rid id.Round) bool { @@ -116,7 +116,7 @@ func TestUnknownRoundsStore_Iterate(t *testing.T) { // Iterate over map until all rounds have checks incremented over // maxCheck for i := 0; i < defaultMaxCheck+1; i++ { - _ = store.Iterate(mockChecker,[]id.Round{}) + _ = store.Iterate(mockChecker, []id.Round{}) } diff --git a/storage/utility/blockStore.go b/storage/utility/blockStore.go new file mode 100644 index 0000000000000000000000000000000000000000..b29dcd12d56037d6e7d0622b1fbc47f8faea62b9 --- /dev/null +++ b/storage/utility/blockStore.go @@ -0,0 +1,281 @@ +package utility + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/netTime" + "strconv" +) + +// Sizes in bytes +const ( + int64Size = 8 + marshalledSize = 4 * int64Size +) + +// Error messages +const ( + bsBuffLengthErr = "length of buffer %d != %d expected" + bsKvSaveErr = "failed to save blockStore to KV: %+v" + bsKvInitSaveErr = "failed to save initial block: %+v" + bsKvLoadErr = "failed to get BlockStore from storage: %+v" + bsKvUnmarshalErr = "failed to unmarshal BlockStore loaded from storage: %+v" + bJsonMarshalErr = "failed to JSON marshal block %d: %+v" + bKvSaveErr = "failed to save block %d to KV: %+v" + bKvDeleteErr = "failed to delete block %d from KV: %+v" + bKvLoadErr = "failed to get block %d from KV: %+v" + bJsonUnmarshalErr = "failed to JSON marshal block %d: %+v" +) + +// Storage keys and parts +const ( + delimiter = "/" + blockStoreKey = "blockStore" + blockStoreVersion = 0 + blockKey = "block" + blockVersion = 0 +) + +type Iterator interface { + Next() ([]byte, bool) +} + +type BlockStore struct { + block [][]byte + numBlocks int // The maximum number of blocks saved to the kv + blockSize int // The maximum number of items allowed in a block + firstSaved int // The index of the oldest block in the list + lastSaved int // The index of the newest block in the list + kv *versioned.KV +} + +// NewBlockStore returns a new BlockStore and saves it to storage. +func NewBlockStore(numBlocks, blockSize int, kv *versioned.KV) (*BlockStore, error) { + bs := &BlockStore{ + block: make([][]byte, 0, blockSize), + numBlocks: numBlocks, + blockSize: blockSize, + firstSaved: 0, + lastSaved: 0, + kv: kv, + } + + return bs, bs.save() +} + +// LoadBlockStore returns the BlockStore from storage and a concatenation of all +// blocks in storage. +func LoadBlockStore(kv *versioned.KV) (*BlockStore, [][]byte, error) { + bs := &BlockStore{kv: kv} + + // Get BlockStore parameters from storage + err := bs.load() + if err != nil { + return nil, nil, err + } + + // LoadBlockStore each block from storage and join together into single slice + var data, block [][]byte + for i := bs.firstSaved; i <= bs.lastSaved; i++ { + // Get the block from storage + block, err = bs.loadBlock(i) + if err != nil { + return nil, nil, err + } + + // Append block to the rest of the data + data = append(data, block...) + } + + // Save the last block into memory + bs.block = block + + return bs, data, nil +} + +// Store stores all items in the Iterator to storage in blocks. +func (bs *BlockStore) Store(iter Iterator) error { + // Iterate through all items in the Iterator and add each to the block in + // memory. When the block is full, it is saved to storage and a new block is + // added to until the iterator returns false. + for item, exists := iter.Next(); exists; item, exists = iter.Next() { + // If the block is full, save it to storage and start a new block + if len(bs.block) >= bs.blockSize { + if err := bs.saveBlock(); err != nil { + return err + } + + bs.lastSaved++ + bs.block = make([][]byte, 0, bs.blockSize) + } + + // Append the item to the block in memory + bs.block = append(bs.block, item) + } + + // Save the current partially filled block to storage + if err := bs.saveBlock(); err != nil { + return err + } + + // Calculate the new first saved index + oldFirstSaved := bs.firstSaved + if (bs.lastSaved+1)-bs.firstSaved > bs.numBlocks { + bs.firstSaved = (bs.lastSaved + 1) - bs.numBlocks + } + + // Save the BlockStorage parameters to storage + if err := bs.save(); err != nil { + return err + } + + // Delete all old blocks + bs.pruneBlocks(oldFirstSaved) + + return nil +} + +// saveBlock saves the block to an indexed storage. +func (bs *BlockStore) saveBlock() error { + // JSON marshal block + data, err := json.Marshal(bs.block) + if err != nil { + return errors.Errorf(bJsonMarshalErr, bs.lastSaved, err) + } + + // Construct versioning object + obj := versioned.Object{ + Version: blockVersion, + Timestamp: netTime.Now(), + Data: data, + } + + // Save to storage + err = bs.kv.Set(bs.getKey(bs.lastSaved), blockVersion, &obj) + if err != nil { + return errors.Errorf(bKvSaveErr, bs.lastSaved, err) + } + + return nil +} + +// loadBlock loads the block with the index from storage. +func (bs *BlockStore) loadBlock(i int) ([][]byte, error) { + // Get the data from the kv + obj, err := bs.kv.Get(bs.getKey(i), blockVersion) + if err != nil { + return nil, errors.Errorf(bKvLoadErr, i, err) + } + + // Unmarshal the block + var block [][]byte + err = json.Unmarshal(obj.Data, &block) + if err != nil { + return nil, errors.Errorf(bJsonUnmarshalErr, i, err) + } + + return block, nil +} + +// pruneBlocks reduces the number of saved blocks to numBlocks by removing the +// oldest blocks. +func (bs *BlockStore) pruneBlocks(firstSaved int) { + // Exit if no blocks need to be pruned + if (bs.lastSaved+1)-firstSaved < bs.numBlocks { + return + } + + // Delete all blocks before the firstSaved index + for ; firstSaved < bs.firstSaved; firstSaved++ { + err := bs.kv.Delete(bs.getKey(firstSaved), blockVersion) + if err != nil { + jww.WARN.Printf(bKvDeleteErr, bs.firstSaved, err) + } + } +} + +// getKey produces a block storage key for the given index. +func (bs *BlockStore) getKey(i int) string { + return blockKey + delimiter + strconv.Itoa(i) +} + +// save saves the parameters in BlockStore to storage. It does not save any +// block data. +func (bs *BlockStore) save() error { + // Construct versioning object + obj := versioned.Object{ + Version: blockStoreVersion, + Timestamp: netTime.Now(), + Data: bs.marshal(), + } + + // Save to storage + err := bs.kv.Set(blockStoreKey, blockStoreVersion, &obj) + if err != nil { + return errors.Errorf(bsKvSaveErr, err) + } + + // Save initial block + err = bs.saveBlock() + if err != nil { + return errors.Errorf(bsKvInitSaveErr, err) + } + + return nil +} + +// load loads BlockStore parameters from storage. +func (bs *BlockStore) load() error { + // Get the data from the kv + obj, err := bs.kv.Get(blockStoreKey, blockStoreVersion) + if err != nil { + return errors.Errorf(bsKvLoadErr, err) + } + + // Unmarshal the data into a BlockStore + err = bs.unmarshal(obj.Data) + if err != nil { + return errors.Errorf(bsKvUnmarshalErr, err) + } + + return nil +} + +// marshal marshals the BlockStore integer values to a byte slice. +func (bs *BlockStore) marshal() []byte { + // Build list of values to store + values := []int{bs.numBlocks, bs.blockSize, bs.firstSaved, bs.lastSaved} + + // Convert each value to a byte slice and store + var buff bytes.Buffer + for _, val := range values { + b := make([]byte, int64Size) + binary.LittleEndian.PutUint64(b, uint64(val)) + buff.Write(b) + } + + // Return the bytes + return buff.Bytes() +} + +// unmarshal unmarshalls the BlockStore int values from the buffer. An error is +// returned if the length of the bytes is incorrect. +func (bs *BlockStore) unmarshal(b []byte) error { + // Return an error if the buffer is not the expected length + if len(b) != marshalledSize { + return errors.Errorf(bsBuffLengthErr, len(b), marshalledSize) + } + + // Convert the byte slices to ints and store + buff := bytes.NewBuffer(b) + bs.numBlocks = int(binary.LittleEndian.Uint64(buff.Next(int64Size))) + bs.blockSize = int(binary.LittleEndian.Uint64(buff.Next(int64Size))) + bs.firstSaved = int(binary.LittleEndian.Uint64(buff.Next(int64Size))) + bs.lastSaved = int(binary.LittleEndian.Uint64(buff.Next(int64Size))) + + return nil +} diff --git a/storage/utility/blockStore_test.go b/storage/utility/blockStore_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c9e2fb56daef0cff51eb7e283b463f9782ac4cf6 --- /dev/null +++ b/storage/utility/blockStore_test.go @@ -0,0 +1,414 @@ +package utility + +import ( + "bytes" + "encoding/binary" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/netTime" + "math/rand" + "reflect" + "strings" + "testing" +) + +type iter [][]byte + +func (it *iter) Next() ([]byte, bool) { + if len(*it) > 0 { + item := (*it)[0] + *it = (*it)[1:] + return item, true + } + + return nil, false +} + +// Happy path. +func TestNewBlockStore(t *testing.T) { + expected := &BlockStore{ + block: make([][]byte, 0, 20), + numBlocks: 50, + blockSize: 20, + firstSaved: 0, + lastSaved: 0, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + bs, err := NewBlockStore(expected.numBlocks, expected.blockSize, expected.kv) + if err != nil { + t.Errorf("NewBlockStore() returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, bs) { + t.Errorf("NewBlockStore() did not return the expected BlockStore."+ + "\nexpected: %+v\nreceived: %+v", expected, bs) + } +} + +// Tests BlockStore storing and loading data in multiple situations. +func TestBlockStore_Store_LoadBlockStore(t *testing.T) { + values := []struct { + blockSize, numBlocks int + expectedFirstSaved, expectedLastSaved int + dataCutIndex int + }{ + {3, 5, 0, 3, 0}, // Multiple blocks, last block partial, no pruning + {10, 5, 0, 0, 0}, // Single block, last block full, no pruning + {15, 5, 0, 0, 0}, // Single block, last block partial, no pruning + + {2, 3, 2, 4, 4}, // Multiple blocks, last block partial, pruned + {5, 1, 1, 1, 5}, // Single block, last block full, pruned + {4, 1, 2, 2, 8}, // Single block, last block partial, pruned + } + + for i, v := range values { + // Create the initial data to store + iter := make(iter, 10) + for i := uint64(0); i < 10; i++ { + iter[i] = make([]byte, 8) + binary.LittleEndian.PutUint64(iter[i], i) + } + + // Calculate the expected data + expected := make([][]byte, len(iter[v.dataCutIndex:])) + copy(expected, iter[v.dataCutIndex:]) + + bs, err := NewBlockStore(v.numBlocks, v.blockSize, versioned.NewKV(make(ekv.Memstore))) + if err != nil { + t.Errorf("Failed to create new BlockStore (%d): %+v", i, err) + } + + // Attempt to store the data + err = bs.Store(&iter) + if err != nil { + t.Errorf("Store() returned an error (%d): %+v", i, err) + } + + if bs.firstSaved != v.expectedFirstSaved { + t.Errorf("Store() did not return the expected firstSaved (%d)."+ + "\nexpected: %d\nreceived: %d", i, v.expectedFirstSaved, bs.firstSaved) + } + + if bs.lastSaved != v.expectedLastSaved { + t.Errorf("Store() did not return the expected lastSaved (%d)."+ + "\nexpected: %d\nreceived: %d", i, v.expectedLastSaved, bs.lastSaved) + } + + // Attempt to load the data + loadBS, data, err := LoadBlockStore(bs.kv) + if err != nil { + t.Errorf("LoadBlockStore() returned an error (%d): %+v", i, err) + } + + // Check if the loaded BlockStore is correct + if !reflect.DeepEqual(bs, loadBS) { + t.Errorf("Loading wrong BlockStore from storage (%d)."+ + "\nexpected: %+v\nreceived: %+v", i, bs, loadBS) + } + + // Check if the loaded data is correct + if !reflect.DeepEqual(expected, data) { + t.Errorf("Loading wrong data from storage (%d)."+ + "\nexpected: %+v\nreceived: %+v", i, expected, data) + } + } +} + +// Tests that a block is successfully saved and loaded from storage. +func TestBlockStore_saveBlock_loadBlock(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + bs := &BlockStore{ + block: make([][]byte, 0, 20), + numBlocks: 50, + blockSize: 20, + firstSaved: 0, + lastSaved: 0, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + for i := range bs.block { + bs.block[i] = make([]byte, 32) + prng.Read(bs.block[i]) + } + + err := bs.saveBlock() + if err != nil { + t.Errorf("saveBlock() returned an error: %+v", err) + } + + newBS := &BlockStore{kv: bs.kv} + block, err := newBS.loadBlock(0) + if err != nil { + t.Errorf("loadBlock() returned an error: %+v", err) + } + + if !reflect.DeepEqual(bs.block, block) { + t.Errorf("Failed to save and load block to storage."+ + "\nexpected: %+v\nreceived: %+v", bs.block, block) + } +} + +// Error path: failed to save to KV. +func TestBlockStore_saveBlock_SaveError(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + bs := &BlockStore{ + block: make([][]byte, 0, 20), + numBlocks: 50, + blockSize: 20, + firstSaved: 0, + lastSaved: 0, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + for i := range bs.block { + bs.block[i] = make([]byte, 32) + prng.Read(bs.block[i]) + } + + err := bs.saveBlock() + if err != nil { + t.Errorf("saveBlock() returned an error: %+v", err) + } + + newBS := &BlockStore{kv: bs.kv} + block, err := newBS.loadBlock(0) + if err != nil { + t.Errorf("loadBlock() returned an error: %+v", err) + } + + if !reflect.DeepEqual(bs.block, block) { + t.Errorf("Failed to save and load block to storage."+ + "\nexpected: %+v\nreceived: %+v", bs.block, block) + } +} + +// Error path: loading of nonexistent key returns an error. +func TestBlockStore_loadBlock_LoadStorageError(t *testing.T) { + expectedErr := strings.SplitN(bKvLoadErr, "%", 2)[0] + bs := &BlockStore{kv: versioned.NewKV(make(ekv.Memstore))} + _, err := bs.loadBlock(0) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadBlock() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: unmarshalling of invalid data fails. +func TestBlockStore_loadBlock_UnmarshalError(t *testing.T) { + bs := &BlockStore{kv: versioned.NewKV(make(ekv.Memstore))} + expectedErr := strings.SplitN(bJsonUnmarshalErr, "%", 2)[0] + + // Construct object with invalid data + obj := versioned.Object{ + Version: blockVersion, + Timestamp: netTime.Now(), + Data: []byte("invalid JSON"), + } + + // Save to storage + err := bs.kv.Set(bs.getKey(bs.lastSaved), blockVersion, &obj) + if err != nil { + t.Errorf("Failed to save data to KV: %+v", err) + } + + _, err = bs.loadBlock(0) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("loadBlock() did not return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Happy path. +func TestBlockStore_pruneBlocks(t *testing.T) { + bs := &BlockStore{ + block: make([][]byte, 0, 32), + numBlocks: 5, + blockSize: 32, + firstSaved: 0, + lastSaved: 0, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + // Save blocks to storage + for ; bs.lastSaved < 15; bs.lastSaved++ { + if err := bs.saveBlock(); err != nil { + t.Errorf("Failed to save block %d: %+v", bs.lastSaved, err) + } + } + + // Calculate the new first saved index + oldFirstSaved := bs.firstSaved + bs.firstSaved = bs.lastSaved - bs.numBlocks + + // Prune blocks + bs.pruneBlocks(oldFirstSaved) + + // Check that the old blocks were deleted + for i := 0; i < bs.lastSaved-bs.numBlocks; i++ { + if _, err := bs.kv.Get(bs.getKey(i), blockVersion); err == nil { + t.Errorf("pruneBlocks() failed to delete old block %d: %+v", i, err) + } + } + + // Check that the new blocks were not deleted + for i := bs.firstSaved; i < bs.lastSaved; i++ { + if _, err := bs.kv.Get(bs.getKey(i), blockVersion); err != nil { + t.Errorf("pruneBlocks() deleted block %d: %+v", i, err) + } + } + + // Call pruneBlocks when there are no blocks to prune + oldFirstSaved = bs.firstSaved + bs.firstSaved = bs.lastSaved - bs.numBlocks + bs.pruneBlocks(oldFirstSaved) + + // Check that the new blocks were not deleted + for i := bs.firstSaved; i < bs.lastSaved; i++ { + if _, err := bs.kv.Get(bs.getKey(i), blockVersion); err != nil { + t.Errorf("pruneBlocks() deleted block %d: %+v", i, err) + } + } +} + +// Consistency test. +func TestBlockStore_getKey_Consistency(t *testing.T) { + expectedKeys := []string{ + "block/0", "block/1", "block/2", "block/3", "block/4", + "block/5", "block/6", "block/7", "block/8", "block/9", + } + var bs BlockStore + + for i, expected := range expectedKeys { + key := bs.getKey(i) + if key != expected { + t.Errorf("getKey did not return the correct key for the index %d."+ + "\nexpected: %s\nreceived: %s", i, expected, key) + } + } +} + +// Tests that a BlockStore can be saved and loaded from the KV correctly. +func TestBlockStore_save_load(t *testing.T) { + bs := &BlockStore{ + numBlocks: 5, blockSize: 6, firstSaved: 7, lastSaved: 8, + kv: versioned.NewKV(make(ekv.Memstore)), + } + + err := bs.save() + if err != nil { + t.Errorf("save() returned an error: %+v", err) + } + + testBS := &BlockStore{kv: bs.kv} + err = testBS.load() + if err != nil { + t.Errorf("load() returned an error: %+v", err) + } + + if !reflect.DeepEqual(bs, testBS) { + t.Errorf("Failed to save and load BlockStore to KV."+ + "\nexpected: %+v\nreceived: %+v", bs, testBS) + } +} + +// Error path: loading of unsaved BlockStore fails. +func TestBlockStore_load_KvGetError(t *testing.T) { + expectedErr := strings.SplitN(bsKvLoadErr, "%", 2)[0] + + testBS := &BlockStore{kv: versioned.NewKV(make(ekv.Memstore))} + err := testBS.load() + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("load() did not return an error for a nonexistent item in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Error path: unmarshalling of invalid data fails. +func TestBlockStore_load_UnmarshalError(t *testing.T) { + expectedErr := strings.SplitN(bsKvUnmarshalErr, "%", 2)[0] + kv := versioned.NewKV(make(ekv.Memstore)) + + // Construct invalid versioning object + obj := versioned.Object{ + Version: blockStoreVersion, + Timestamp: netTime.Now(), + Data: []byte("invalid data"), + } + + // Save to storage + err := kv.Set(blockStoreKey, blockStoreVersion, &obj) + if err != nil { + t.Fatalf("failed to save object to storage: %+v", err) + } + + // Try to retrieve invalid object + testBS := &BlockStore{kv: kv} + err = testBS.load() + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("load() did not return an error for a nonexistent item in storage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Consistency test. +func TestBlockStore_unmarshal(t *testing.T) { + buff := []byte{5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0} + expected := &BlockStore{numBlocks: 5, blockSize: 6, firstSaved: 7, lastSaved: 8} + + bs := &BlockStore{} + err := bs.unmarshal(buff) + if err != nil { + t.Errorf("unmarshal() returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, bs) { + t.Errorf("unmarshal() did not return the expected BlockStore."+ + "\nexpected: %+v\nreceived: %+v", expected, bs) + } +} + +// Error path: length of buffer incorrect. +func TestBlockStore_unmarshal_BuffLengthError(t *testing.T) { + expectedErr := fmt.Sprintf(bsBuffLengthErr, 0, marshalledSize) + bs := BlockStore{} + err := bs.unmarshal([]byte{}) + if err == nil || err.Error() != expectedErr { + t.Errorf("unmarshal() did not return the expected error for a buffer "+ + "of the wrong size.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Consistency test. +func TestBlockStore_marshal(t *testing.T) { + expected := []byte{5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, + 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0} + bs := &BlockStore{numBlocks: 5, blockSize: 6, firstSaved: 7, lastSaved: 8} + + buff := bs.marshal() + + if !bytes.Equal(expected, buff) { + t.Errorf("marshal() did not return the expected bytes."+ + "\nexpected: %+v\nreceived: %+v", expected, buff) + } +} + +// Tests that marshal and unmarshal work together. +func TestBlockStore_marshal_unmarshal(t *testing.T) { + bs := &BlockStore{numBlocks: 5, blockSize: 6, firstSaved: 7, lastSaved: 8} + + buff := bs.marshal() + + testBS := &BlockStore{} + err := testBS.unmarshal(buff) + if err != nil { + t.Errorf("unmarshal() returned an error: %+v", err) + } + + if !reflect.DeepEqual(bs, testBS) { + t.Errorf("failed to marshal and unmarshal BlockStore."+ + "\nexpected: %+v\nreceived: %+v", bs, testBS) + } +}