diff --git a/api/authenticatedChannel.go b/api/auth.go similarity index 71% rename from api/authenticatedChannel.go rename to api/auth.go index 4b9ef33192279c0d5a6d36c520e17ddddf77c35d..d7a5cc362ea49dc9108e220d05c5fd45eeb5586f 100644 --- a/api/authenticatedChannel.go +++ b/api/auth.go @@ -9,15 +9,13 @@ package api import ( "encoding/binary" - "gitlab.com/elixxir/client/catalog" "math/rand" "github.com/cloudflare/circl/dh/sidh" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/auth" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/edge" + "gitlab.com/elixxir/client/e2e/ratchet/partner/session" util "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/primitives/fact" @@ -36,13 +34,12 @@ func (c *Client) RequestAuthenticatedChannel(recipient, me contact.Contact, message string) (id.Round, error) { jww.INFO.Printf("RequestAuthenticatedChannel(%s)", recipient.ID) - if !c.network.GetHealthTracker().IsHealthy() { + if !c.network.IsHealthy() { return 0, errors.New("Cannot request authenticated channel " + "creation when the network is not healthy") } - return auth.RequestAuth(recipient, me, c.rng.GetStream(), - c.storage, c.network) + return c.auth.Request(recipient, c.GetUser().GetContact().Facts) } // ResetSession resets an authenticate channel that already exists @@ -50,18 +47,17 @@ func (c *Client) ResetSession(recipient, me contact.Contact, message string) (id.Round, error) { jww.INFO.Printf("ResetSession(%s)", recipient.ID) - if !c.network.GetHealthTracker().IsHealthy() { + if !c.network.IsHealthy() { return 0, errors.New("Cannot request authenticated channel " + "creation when the network is not healthy") } - return auth.ResetSession(recipient, me, c.rng.GetStream(), - c.storage, c.network) + return c.auth.Reset(recipient) } // GetAuthRegistrar gets the object which allows the registration of auth // callbacks -func (c *Client) GetAuthRegistrar() interfaces.Auth { +func (c *Client) GetAuthRegistrar() auth.State { jww.INFO.Printf("GetAuthRegistrar(...)") return c.auth @@ -72,7 +68,7 @@ func (c *Client) GetAuthRegistrar() interfaces.Auth { func (c *Client) GetAuthenticatedChannelRequest(partner *id.ID) (contact.Contact, error) { jww.INFO.Printf("GetAuthenticatedChannelRequest(%s)", partner) - return c.storage.Auth().GetReceivedRequestData(partner) + return c.auth.GetReceivedRequest(partner) } // ConfirmAuthenticatedChannel creates an authenticated channel out of a valid @@ -86,7 +82,7 @@ func (c *Client) GetAuthenticatedChannelRequest(partner *id.ID) (contact.Contact func (c *Client) ConfirmAuthenticatedChannel(recipient contact.Contact) (id.Round, error) { jww.INFO.Printf("ConfirmAuthenticatedChannel(%s)", recipient.ID) - if !c.network.GetHealthTracker().IsHealthy() { + if !c.network.IsHealthy() { return 0, errors.New("Cannot request authenticated channel " + "creation when the network is not healthy") } @@ -99,18 +95,19 @@ func (c *Client) ConfirmAuthenticatedChannel(recipient contact.Contact) (id.Roun func (c *Client) VerifyOwnership(received, verified contact.Contact) bool { jww.INFO.Printf("VerifyOwnership(%s)", received.ID) - return auth.VerifyOwnership(received, verified, c.storage) + return c.auth.VerifyOwnership(received, verified, c.e2e) } // HasAuthenticatedChannel returns true if an authenticated channel exists for // the partner func (c *Client) HasAuthenticatedChannel(partner *id.ID) bool { - m, err := c.storage.E2e().GetPartner(partner) + m, err := c.e2e.GetPartner(partner) return m != nil && err == nil } // Create an insecure e2e relationship with a precanned user -func (c *Client) MakePrecannedAuthenticatedChannel(precannedID uint) (contact.Contact, error) { +func (c *Client) MakePrecannedAuthenticatedChannel(precannedID uint) ( + contact.Contact, error) { precan := c.MakePrecannedContact(precannedID) @@ -135,49 +132,15 @@ func (c *Client) MakePrecannedAuthenticatedChannel(precannedID uint) (contact.Co mySIDHPrivKey.GeneratePublicKey(mySIDHPubKey) // add the precanned user as a e2e contact - sesParam := c.parameters.E2EParams - err := c.storage.E2e().AddPartner(precan.ID, precan.DhPubKey, - c.storage.E2e().GetDHPrivateKey(), theirSIDHPubKey, + // FIXME: these params need to be threaded through... + sesParam := session.GetDefaultParams() + _, err := c.e2e.AddPartner(precan.ID, precan.DhPubKey, + c.e2e.GetHistoricalDHPrivkey(), theirSIDHPubKey, mySIDHPrivKey, sesParam, sesParam) // check garbled messages in case any messages arrived before creating // the channel - c.network.CheckGarbledMessages() - - //add the e2e and rekey firngeprints - //e2e - sessionPartner, err := c.storage.E2e().GetPartner(precan.ID) - if err != nil { - jww.FATAL.Panicf("Cannot find %s right after creating: %+v", precan.ID, err) - } - me := c.storage.GetUser().ReceptionID - - c.storage.GetEdge().Add(edge.Preimage{ - Data: sessionPartner.GetE2EPreimage(), - Type: catalog.E2e, - Source: precan.ID[:], - }, me) - - // slient (rekey) - c.storage.GetEdge().Add(edge.Preimage{ - Data: sessionPartner.GetSilentPreimage(), - Type: catalog.Silent, - Source: precan.ID[:], - }, me) - - // File transfer end - c.storage.GetEdge().Add(edge.Preimage{ - Data: sessionPartner.GetFileTransferPreimage(), - Type: catalog.EndFT, - Source: precan.ID[:], - }, me) - - // group request - c.storage.GetEdge().Add(edge.Preimage{ - Data: sessionPartner.GetGroupRequestPreimage(), - Type: catalog.GroupRq, - Source: precan.ID[:], - }, me) + c.network.CheckInProgressMessages() return precan, err } @@ -185,14 +148,14 @@ func (c *Client) MakePrecannedAuthenticatedChannel(precannedID uint) (contact.Co // Create an insecure e2e contact object for a precanned user func (c *Client) MakePrecannedContact(precannedID uint) contact.Contact { - e2eGrp := c.storage.E2e().GetGroup() + e2eGrp := c.storage.GetE2EGroup() - // get the user definition precanned := createPrecannedUser(precannedID, c.rng.GetStream(), - c.storage.Cmix().GetGroup(), e2eGrp) + c.storage.GetCmixGroup(), e2eGrp) // compute their public e2e key - partnerPubKey := e2eGrp.ExpG(precanned.E2eDhPrivateKey, e2eGrp.NewInt(1)) + partnerPubKey := e2eGrp.ExpG(precanned.E2eDhPrivateKey, + e2eGrp.NewInt(1)) return contact.Contact{ ID: precanned.ReceptionID, @@ -206,12 +169,14 @@ func (c *Client) MakePrecannedContact(precannedID uint) contact.Contact { // E2E relationship. An error is returned if no relationship with the partner // is found. func (c *Client) GetRelationshipFingerprint(partner *id.ID) (string, error) { - m, err := c.storage.E2e().GetPartner(partner) + m, err := c.e2e.GetPartner(partner) if err != nil { - return "", errors.Errorf("could not get partner %s: %+v", partner, err) + return "", errors.Errorf("could not get partner %s: %+v", + partner, err) } else if m == nil { - return "", errors.Errorf("manager for partner %s is nil.", partner) + return "", errors.Errorf("manager for partner %s is nil.", + partner) } - return m.GetRelationshipFingerprint(), nil + return m.GetConnectionFingerprint(), nil } diff --git a/api/client.go b/api/client.go index 85894ca42f590f64843d820d50339a6d22d3ec32..458d4ed47196febd1f846ca8eaed29611c07f17f 100644 --- a/api/client.go +++ b/api/client.go @@ -9,20 +9,23 @@ package api import ( "encoding/json" + "math" + "time" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/auth" - "gitlab.com/elixxir/client/catalog" + "gitlab.com/elixxir/client/backup" "gitlab.com/elixxir/client/cmix" - keyExchange2 "gitlab.com/elixxir/client/e2e/rekey" + "gitlab.com/elixxir/client/e2e" "gitlab.com/elixxir/client/event" "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/user" "gitlab.com/elixxir/client/registration" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/user" "gitlab.com/elixxir/comms/client" - "gitlab.com/elixxir/crypto/backup" + cryptoBackup "gitlab.com/elixxir/crypto/backup" "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/elixxir/primitives/version" @@ -33,8 +36,6 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/ndf" "gitlab.com/xx_network/primitives/region" - "math" - "time" ) const followerStoppableName = "client" @@ -44,20 +45,23 @@ type Client struct { rng *fastRNG.StreamGenerator // the storage session securely stores data to disk and memoizes as is // appropriate - storage *storage.Session + storage storage.Session + + // user state object + userState *user.User + //object used for communications comms *client.Comms - // Network parameters - parameters cmix.Params + // Network parameters, note e2e params wrap CMIXParams + parameters e2e.Params - // note that the manager has a pointer to the context in many cases, but - // this interface allows it to be mocked for easy testing without the - // loop - network interfaces.NetworkManager + network cmix.Client //object used to register and communicate with permissioning permissioning *registration.Registration //object containing auth interactions - auth *auth.State + auth auth.State + + e2e e2e.Handler //services system to track running threads followerServices *services @@ -68,7 +72,7 @@ type Client struct { events *event.Manager // Handles the triggering and delivery of backups - backup *interfaces.BackupContainer + backup *backup.Backup } // NewClient creates client storage, generates keys, connects, and registers @@ -78,10 +82,9 @@ type Client struct { func NewClient(ndfJSON, storageDir string, password []byte, registrationCode string) error { jww.INFO.Printf("NewClient(dir: %s)", storageDir) - // Use fastRNG for RNG ops (AES fortuna based RNG using system RNG) - rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, csprng.NewSystemRNG) + rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, + csprng.NewSystemRNG) - // Parse the NDF def, err := parseNDF(ndfJSON) if err != nil { return err @@ -90,10 +93,12 @@ func NewClient(ndfJSON, storageDir string, password []byte, cmixGrp, e2eGrp := decodeGroups(def) start := time.Now() protoUser := createNewUser(rngStreamGen, cmixGrp, e2eGrp) - jww.DEBUG.Printf("PortableUserInfo generation took: %s", time.Now().Sub(start)) + jww.DEBUG.Printf("PortableUserInfo generation took: %s", + time.Now().Sub(start)) - _, err = checkVersionAndSetupStorage(def, storageDir, password, protoUser, - cmixGrp, e2eGrp, rngStreamGen, false, registrationCode) + _, err = checkVersionAndSetupStorage(def, storageDir, password, + protoUser, cmixGrp, e2eGrp, rngStreamGen, false, + registrationCode) if err != nil { return err } @@ -102,29 +107,29 @@ func NewClient(ndfJSON, storageDir string, password []byte, return nil } -// NewPrecannedClient creates an insecure user with predetermined keys with nodes -// It creates client storage, generates keys, connects, and registers -// with the network. Note that this does not register a username/identity, but -// merely creates a new cryptographic identity for adding such information -// at a later date. +// NewPrecannedClient creates an insecure user with predetermined keys +// with nodes It creates client storage, generates keys, connects, and +// registers with the network. Note that this does not register a +// username/identity, but merely creates a new cryptographic identity +// for adding such information at a later date. func NewPrecannedClient(precannedID uint, defJSON, storageDir string, password []byte) error { jww.INFO.Printf("NewPrecannedClient()") - // Use fastRNG for RNG ops (AES fortuna based RNG using system RNG) - rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, csprng.NewSystemRNG) + rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, + csprng.NewSystemRNG) rngStream := rngStreamGen.GetStream() - // Parse the NDF def, err := parseNDF(defJSON) if err != nil { return err } cmixGrp, e2eGrp := decodeGroups(def) - protoUser := createPrecannedUser(precannedID, rngStream, cmixGrp, e2eGrp) + protoUser := createPrecannedUser(precannedID, rngStream, + cmixGrp, e2eGrp) - _, err = checkVersionAndSetupStorage(def, storageDir, password, protoUser, - cmixGrp, e2eGrp, rngStreamGen, true, "") + _, err = checkVersionAndSetupStorage(def, storageDir, password, + protoUser, cmixGrp, e2eGrp, rngStreamGen, true, "") if err != nil { return err } @@ -132,29 +137,31 @@ func NewPrecannedClient(precannedID uint, defJSON, storageDir string, return nil } -// NewVanityClient creates a user with a receptionID that starts with the supplied prefix -// It creates client storage, generates keys, connects, and registers -// with the network. Note that this does not register a username/identity, but -// merely creates a new cryptographic identity for adding such information -// at a later date. +// NewVanityClient creates a user with a receptionID that starts with +// the supplied prefix It creates client storage, generates keys, +// connects, and registers with the network. Note that this does not +// register a username/identity, but merely creates a new +// cryptographic identity for adding such information at a later date. func NewVanityClient(ndfJSON, storageDir string, password []byte, registrationCode string, userIdPrefix string) error { jww.INFO.Printf("NewVanityClient()") - // Use fastRNG for RNG ops (AES fortuna based RNG using system RNG) - rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, csprng.NewSystemRNG) + + rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, + csprng.NewSystemRNG) rngStream := rngStreamGen.GetStream() - // Parse the NDF def, err := parseNDF(ndfJSON) if err != nil { return err } cmixGrp, e2eGrp := decodeGroups(def) - protoUser := createNewVanityUser(rngStream, cmixGrp, e2eGrp, userIdPrefix) + protoUser := createNewVanityUser(rngStream, cmixGrp, e2eGrp, + userIdPrefix) - _, err = checkVersionAndSetupStorage(def, storageDir, password, protoUser, - cmixGrp, e2eGrp, rngStreamGen, false, registrationCode) + _, err = checkVersionAndSetupStorage(def, storageDir, password, + protoUser, cmixGrp, e2eGrp, rngStreamGen, false, + registrationCode) if err != nil { return err } @@ -163,22 +170,24 @@ func NewVanityClient(ndfJSON, storageDir string, password []byte, return nil } -// NewClientFromBackup constructs a new Client from an encrypted backup. The backup -// is decrypted using the backupPassphrase. On success a successful client creation, -// the function will return a JSON encoded list of the E2E partners -// contained in the backup and a json-encoded string containing parameters stored in the backup +// NewClientFromBackup constructs a new Client from an encrypted +// backup. The backup is decrypted using the backupPassphrase. On +// success a successful client creation, the function will return a +// JSON encoded list of the E2E partners contained in the backup and a +// json-encoded string containing parameters stored in the backup func NewClientFromBackup(ndfJSON, storageDir string, sessionPassword, - backupPassphrase []byte, backupFileContents []byte) ([]*id.ID, string, error) { + backupPassphrase []byte, backupFileContents []byte) ([]*id.ID, + string, error) { - backUp := &backup.Backup{} + backUp := &cryptoBackup.Backup{} err := backUp.Decrypt(string(backupPassphrase), backupFileContents) if err != nil { - return nil, "", errors.WithMessage(err, "Failed to unmarshal decrypted client contents.") + return nil, "", errors.WithMessage(err, + "Failed to unmarshal decrypted client contents.") } usr := user.NewUserFromBackup(backUp) - // Parse the NDF def, err := parseNDF(ndfJSON) if err != nil { return nil, "", err @@ -186,21 +195,23 @@ func NewClientFromBackup(ndfJSON, storageDir string, sessionPassword, cmixGrp, e2eGrp := decodeGroups(def) - // Use fastRNG for RNG ops (AES fortuna based RNG using system RNG) rngStreamGen := fastRNG.NewStreamGenerator(12, 3, csprng.NewSystemRNG) - // Create storage object. - // Note we do not need registration - storageSess, err := checkVersionAndSetupStorage(def, storageDir, []byte(sessionPassword), usr, - cmixGrp, e2eGrp, rngStreamGen, false, backUp.RegistrationCode) + // Note we do not need registration here + storageSess, err := checkVersionAndSetupStorage(def, storageDir, + []byte(sessionPassword), usr, cmixGrp, e2eGrp, rngStreamGen, + false, backUp.RegistrationCode) - // Set registration values in storage - storageSess.User().SetReceptionRegistrationValidationSignature(backUp.ReceptionIdentity.RegistrarSignature) - storageSess.User().SetTransmissionRegistrationValidationSignature(backUp.TransmissionIdentity.RegistrarSignature) - storageSess.User().SetRegistrationTimestamp(backUp.RegistrationTimestamp) + storageSess.SetReceptionRegistrationValidationSignature( + backUp.ReceptionIdentity.RegistrarSignature) + storageSess.SetTransmissionRegistrationValidationSignature( + backUp.TransmissionIdentity.RegistrarSignature) + storageSess.SetRegistrationTimestamp(backUp.RegistrationTimestamp) - //move the registration state to indicate registered with registration on proto client - err = storageSess.ForwardRegistrationStatus(storage.PermissioningComplete) + //move the registration state to indicate registered with + // registration on proto client + err = storageSess.ForwardRegistrationStatus( + storage.PermissioningComplete) if err != nil { return nil, "", err } @@ -209,37 +220,36 @@ func NewClientFromBackup(ndfJSON, storageDir string, sessionPassword, } // OpenClient session, but don't connect to the network or log in -func OpenClient(storageDir string, password []byte, parameters params.Network) (*Client, error) { +func OpenClient(storageDir string, password []byte, + parameters e2e.Params) (*Client, error) { jww.INFO.Printf("OpenClient()") - // Use fastRNG for RNG ops (AES fortuna based RNG using system RNG) - rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, csprng.NewSystemRNG) - // get current client version + rngStreamGen := fastRNG.NewStreamGenerator(12, 1024, + csprng.NewSystemRNG) + currentVersion, err := version.ParseVersion(SEMVER) if err != nil { - return nil, errors.WithMessage(err, "Could not parse version string.") + return nil, errors.WithMessage(err, + "Could not parse version string.") } - // Load Storage passwordStr := string(password) - storageSess, err := storage.Load(storageDir, passwordStr, currentVersion, - rngStreamGen) + storageSess, err := storage.Load(storageDir, passwordStr, + currentVersion) if err != nil { return nil, err } - // Set up a new context c := &Client{ storage: storageSess, - switchboard: switchboard.New(), rng: rngStreamGen, comms: nil, network: nil, followerServices: newServices(), parameters: parameters, clientErrorChannel: make(chan interfaces.ClientError, 1000), - events: event.newEventManager(), - backup: &interfaces.BackupContainer{}, + events: event.NewEventManager(), + backup: &backup.Backup{}, } return c, nil @@ -252,10 +262,8 @@ func NewProtoClient_Unsafe(ndfJSON, storageDir string, password, protoClientJSON []byte) error { jww.INFO.Printf("NewProtoClient_Unsafe") - // Use fastRNG for RNG ops (AES fortuna based RNG using system RNG) rngStreamGen := fastRNG.NewStreamGenerator(12, 3, csprng.NewSystemRNG) - // Parse the NDF def, err := parseNDF(ndfJSON) if err != nil { return err @@ -263,30 +271,31 @@ func NewProtoClient_Unsafe(ndfJSON, storageDir string, password, cmixGrp, e2eGrp := decodeGroups(def) - // Pull the proto user from the JSON protoUser := &user.Proto{} err = json.Unmarshal(protoClientJSON, protoUser) if err != nil { return err } - // Initialize a user object for storage set up usr := user.NewUserFromProto(protoUser) - // Set up storage - storageSess, err := checkVersionAndSetupStorage(def, storageDir, password, usr, - cmixGrp, e2eGrp, rngStreamGen, false, protoUser.RegCode) + storageSess, err := checkVersionAndSetupStorage(def, storageDir, + password, usr, cmixGrp, e2eGrp, rngStreamGen, false, + protoUser.RegCode) if err != nil { return err } - // Set registration values in storage - storageSess.User().SetReceptionRegistrationValidationSignature(protoUser.ReceptionRegValidationSig) - storageSess.User().SetTransmissionRegistrationValidationSignature(protoUser.TransmissionRegValidationSig) - storageSess.User().SetRegistrationTimestamp(protoUser.RegistrationTimestamp) + storageSess.SetReceptionRegistrationValidationSignature( + protoUser.ReceptionRegValidationSig) + storageSess.SetTransmissionRegistrationValidationSignature( + protoUser.TransmissionRegValidationSig) + storageSess.SetRegistrationTimestamp(protoUser.RegistrationTimestamp) - //move the registration state to indicate registered with registration on proto client - err = storageSess.ForwardRegistrationStatus(storage.PermissioningComplete) + // move the registration state to indicate registered with + // registration on proto client + err = storageSess.ForwardRegistrationStatus( + storage.PermissioningComplete) if err != nil { return err } @@ -295,26 +304,24 @@ func NewProtoClient_Unsafe(ndfJSON, storageDir string, password, } // Login initializes a client object from existing storage. -func Login(storageDir string, password []byte, parameters params.Network) (*Client, error) { +func Login(storageDir string, password []byte, + authCallbacks auth.Callbacks, parameters e2e.Params) (*Client, error) { jww.INFO.Printf("Login()") - //Open the client c, err := OpenClient(storageDir, password, parameters) if err != nil { return nil, err } - u := c.storage.GetUser() + u := c.GetUser() jww.INFO.Printf("Client Logged in: \n\tTransmisstionID: %s "+ "\n\tReceptionID: %s", u.TransmissionID, u.ReceptionID) - // initialize comms err = c.initComms() if err != nil { return nil, err } - //get the NDF to pass into registration and the network manager def := c.storage.GetNDF() //initialize registration @@ -324,9 +331,9 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie return nil, err } } else { - jww.WARN.Printf("Registration with permissioning skipped due to " + - "blank permissioning address. Client will not be able to register " + - "or track network.") + jww.WARN.Printf("Registration with permissioning skipped due " + + "to blank permissioning address. Client will not be " + + "able to register or track network.") } if def.Notification.Address != "" { @@ -335,25 +342,38 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie hp.KaClientOpts.Time = time.Duration(math.MaxInt64) hp.AuthEnabled = false hp.MaxRetries = 5 - _, err = c.comms.AddHost(&id.NotificationBot, def.Notification.Address, + _, err = c.comms.AddHost(&id.NotificationBot, + def.Notification.Address, []byte(def.Notification.TlsCertificate), hp) if err != nil { - jww.WARN.Printf("Failed adding host for notifications: %+v", err) + jww.WARN.Printf("Failed adding host for "+ + "notifications: %+v", err) } } - // Initialize network and link it to context - c.network, err = cmix.NewClient(c.storage, c.switchboard, c.rng, - c.events, c.comms, parameters, def) + c.network, err = cmix.NewClient(parameters.Network, c.comms, c.storage, + c.storage.GetNDF(), c.rng, c.events) + if err != nil { + return nil, err + } + + c.e2e, err = e2e.Load(c.storage.GetKV(), c.network, + c.GetUser().ReceptionID, c.storage.GetE2EGroup(), + c.rng, c.events) if err != nil { return nil, err } - // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network, c.rng, - c.backup.TriggerBackup, parameters.ReplayRequests) + // FIXME: The callbacks need to be set, so I suppose we would need to + // either set them via a special type or add them + // to the login call? + authParams := auth.GetDefaultParams() + c.auth, err = auth.NewState(c.storage.GetKV(), c.network, c.e2e, c.rng, + c.events, authParams, authCallbacks, c.backup.TriggerBackup) + if err != nil { + return nil, err + } - // Add all processes to the followerServices err = c.registerFollower() if err != nil { return nil, err @@ -366,22 +386,20 @@ func Login(storageDir string, password []byte, parameters params.Network) (*Clie // while replacing the base NDF. This is designed for some specific deployment // procedures and is generally unsafe. func LoginWithNewBaseNDF_UNSAFE(storageDir string, password []byte, - newBaseNdf string, parameters params.Network) (*Client, error) { + newBaseNdf string, authCallbacks auth.Callbacks, + parameters e2e.Params) (*Client, error) { jww.INFO.Printf("LoginWithNewBaseNDF_UNSAFE()") - // Parse the NDF def, err := parseNDF(newBaseNdf) if err != nil { return nil, err } - //Open the client c, err := OpenClient(storageDir, password, parameters) if err != nil { return nil, err } - //initialize comms err = c.initComms() if err != nil { return nil, err @@ -390,28 +408,39 @@ func LoginWithNewBaseNDF_UNSAFE(storageDir string, password []byte, //store the updated base NDF c.storage.SetNDF(def) - //initialize registration if def.Registration.Address != "" { err = c.initPermissioning(def) if err != nil { return nil, err } } else { - jww.WARN.Printf("Registration with permissioning skipped due to " + - "blank permissionign address. Client will not be able to register " + - "or track network.") + jww.WARN.Printf("Registration with permissioning skipped due " + + "to blank permissionign address. Client will not be " + + "able to register or track network.") } - // Initialize network and link it to context - c.network, err = cmix.NewClient(c.storage, c.switchboard, c.rng, - c.events, c.comms, parameters, def) + c.network, err = cmix.NewClient(parameters.Network, c.comms, c.storage, + c.storage.GetNDF(), c.rng, c.events) if err != nil { return nil, err } - // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network, c.rng, - c.backup.TriggerBackup, parameters.ReplayRequests) + c.e2e, err = e2e.Load(c.storage.GetKV(), c.network, + c.GetUser().ReceptionID, c.storage.GetE2EGroup(), + c.rng, c.events) + if err != nil { + return nil, err + } + + // FIXME: The callbacks need to be set, so I suppose we would need to + // either set them via a special type or add them + // to the login call? + authParams := auth.GetDefaultParams() + c.auth, err = auth.NewState(c.storage.GetKV(), c.network, c.e2e, c.rng, + c.events, authParams, authCallbacks, c.backup.TriggerBackup) + if err != nil { + return nil, err + } err = c.registerFollower() if err != nil { @@ -421,38 +450,35 @@ func LoginWithNewBaseNDF_UNSAFE(storageDir string, password []byte, return c, nil } -// LoginWithProtoClient creates a client object with a protoclient JSON containing the -// cryptographic primitives. This is designed for some specific deployment -//// procedures and is generally unsafe. -func LoginWithProtoClient(storageDir string, password []byte, protoClientJSON []byte, - newBaseNdf string, parameters params.Network) (*Client, error) { +// LoginWithProtoClient creates a client object with a protoclient +// JSON containing the cryptographic primitives. This is designed for +// some specific deployment procedures and is generally unsafe. +func LoginWithProtoClient(storageDir string, password []byte, + protoClientJSON []byte, newBaseNdf string, authCallbacks auth.Callbacks, + parameters e2e.Params) (*Client, error) { jww.INFO.Printf("LoginWithProtoClient()") - // Parse the NDF def, err := parseNDF(newBaseNdf) if err != nil { return nil, err } - //Open the client - err = NewProtoClient_Unsafe(newBaseNdf, storageDir, password, protoClientJSON) + err = NewProtoClient_Unsafe(newBaseNdf, storageDir, password, + protoClientJSON) if err != nil { return nil, err } - //Open the client c, err := OpenClient(storageDir, password, parameters) if err != nil { return nil, err } - //initialize comms err = c.initComms() if err != nil { return nil, err } - //store the updated base NDF c.storage.SetNDF(def) err = c.initPermissioning(def) @@ -460,17 +486,28 @@ func LoginWithProtoClient(storageDir string, password []byte, protoClientJSON [] return nil, err } - // Initialize network and link it to context - c.network, err = cmix.NewClient(c.storage, c.switchboard, c.rng, - c.events, c.comms, parameters, def) + c.network, err = cmix.NewClient(parameters.Network, c.comms, c.storage, + c.storage.GetNDF(), c.rng, c.events) if err != nil { return nil, err } - // initialize the auth tracker - c.auth = auth.NewManager(c.switchboard, c.storage, c.network, c.rng, - c.backup.TriggerBackup, parameters.ReplayRequests) + c.e2e, err = e2e.Load(c.storage.GetKV(), c.network, + c.GetUser().ReceptionID, c.storage.GetE2EGroup(), + c.rng, c.events) + if err != nil { + return nil, err + } + // FIXME: The callbacks need to be set, so I suppose we would need to + // either set them via a special type or add them + // to the login call? + authParams := auth.GetDefaultParams() + c.auth, err = auth.NewState(c.storage.GetKV(), c.network, c.e2e, c.rng, + c.events, authParams, authCallbacks, c.backup.TriggerBackup) + if err != nil { + return nil, err + } err = c.registerFollower() if err != nil { return nil, err @@ -483,14 +520,16 @@ func (c *Client) initComms() error { var err error //get the user from session - u := c.storage.User() - cryptoUser := u.GetCryptographicIdentity() + u := c.userState + cryptoUser := u.CryptographicIdentity + + privKey := cryptoUser.GetTransmissionRSA() + pubPEM := rsa.CreatePublicKeyPem(privKey.GetPublic()) + privPEM := rsa.CreatePrivateKeyPem(privKey) //start comms c.comms, err = client.NewClientComms(cryptoUser.GetTransmissionID(), - rsa.CreatePublicKeyPem(cryptoUser.GetTransmissionRSA().GetPublic()), - rsa.CreatePrivateKeyPem(cryptoUser.GetTransmissionRSA()), - cryptoUser.GetTransmissionSalt()) + pubPEM, privPEM, cryptoUser.GetTransmissionSalt()) if err != nil { return errors.WithMessage(err, "failed to load client") } @@ -508,18 +547,22 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { //register with registration if necessary if c.storage.GetRegistrationStatus() == storage.KeyGenComplete { - jww.INFO.Printf("Client has not registered yet, attempting registration") + jww.INFO.Printf("Client has not registered yet, " + + "attempting registration") err = c.registerWithPermissioning() if err != nil { - jww.ERROR.Printf("Client has failed registration: %s", err) + jww.ERROR.Printf("Client has failed registration: %s", + err) return errors.WithMessage(err, "failed to load client") } - jww.INFO.Printf("Client successfully registered with the network") + jww.INFO.Printf("Client successfully registered " + + "with the network") } return nil } -// registerFollower adds the follower processes to the client's follower service list. +// registerFollower adds the follower processes to the client's +// follower service list. // This should only ever be called once func (c *Client) registerFollower() error { //build the error callback @@ -531,11 +574,12 @@ func (c *Client) registerFollower() error { Trace: trace, }: default: - jww.WARN.Printf("Failed to notify about ClientError from %s: %s", source, message) + jww.WARN.Printf("Failed to notify about ClientError "+ + "from %s: %s", source, message) } } - err := c.followerServices.add(c.events.eventService) + err := c.followerServices.add(c.events.EventService) if err != nil { return errors.WithMessage(err, "Couldn't start event reporting") } @@ -549,26 +593,14 @@ func (c *Client) registerFollower() error { "the network") } - //register the incremental key upgrade service - err = c.followerServices.add(c.auth.StartProcesses) - if err != nil { - return errors.WithMessage(err, "Failed to start following "+ - "the network") - } - - //register the key exchange service - keyXchange := func() (stoppable.Stoppable, error) { - return keyExchange2.Start(c.switchboard, c.storage, c.network, c.parameters.Rekey) - } - err = c.followerServices.add(keyXchange) - return nil } // ----- Client Functions ----- // GetErrorsChannel returns a channel which passess errors from the -// long running threads controlled by StartNetworkFollower and StopNetworkFollower +// long running threads controlled by StartNetworkFollower and +// StopNetworkFollower func (c *Client) GetErrorsChannel() <-chan interfaces.ClientError { return c.clientErrorChannel } @@ -584,16 +616,20 @@ func (c *Client) GetErrorsChannel() <-chan interfaces.ClientError { // - Network Follower (/network/follow.go) // tracks the network events and hands them off to workers for handling // - Historical Round Retrieval (/network/rounds/historical.go) -// Retrieves data about rounds which are too old to be stored by the client +// Retrieves data about rounds which are too old to be +// stored by the client // - Message Retrieval Worker Group (/network/rounds/retrieve.go) -// Requests all messages in a given round from the gateway of the last nodes +// Requests all messages in a given round from the +// gateway of the last nodes // - Message Handling Worker Group (/network/message/handle.go) -// Decrypts and partitions messages when signals via the Switchboard +// Decrypts and partitions messages when signals via the +// Switchboard // - health Tracker (/network/health) // Via the network instance tracks the state of the network // - Garbled Messages (/network/message/garbled.go) -// Can be signaled to check all recent messages which could be be decoded -// Uses a message store on disk for persistence +// Can be signaled to check all recent messages which +// could be be decoded Uses a message store on disk for +// persistence // - Critical Messages (/network/message/critical.go) // Ensures all protocol layer mandatory messages are sent // Uses a message store on disk for persistence @@ -640,20 +676,15 @@ func (c *Client) HasRunningProcessies() bool { // Returns the health tracker for registration and polling func (c *Client) GetHealth() interfaces.HealthTracker { jww.INFO.Printf("GetHealth()") - return c.network.GetHealthTracker() -} - -// Returns the switchboard for Registration -func (c *Client) GetSwitchboard() interfaces.Switchboard { - jww.INFO.Printf("GetSwitchboard()") - return c.switchboard + return c.GetHealth() } // RegisterRoundEventsCb registers a callback for round // events. func (c *Client) GetRoundEvents() interfaces.RoundEvents { jww.INFO.Printf("GetRoundEvents()") - jww.WARN.Printf("GetRoundEvents does not handle Client Errors edge case!") + jww.WARN.Printf("GetRoundEvents does not handle Client Errors " + + "edge case!") return c.network.GetInstance().GetRoundEvents() } @@ -667,7 +698,7 @@ func (c *Client) AddService(sp Service) error { // can be serialized into a byte stream for out-of-band sharing. func (c *Client) GetUser() user.Info { jww.INFO.Printf("GetUser()") - return c.storage.GetUser() + return c.GetUser() } // GetComms returns the client comms object @@ -681,28 +712,21 @@ func (c *Client) GetRng() *fastRNG.StreamGenerator { } // GetStorage returns the client storage object -func (c *Client) GetStorage() *storage.Session { +func (c *Client) GetStorage() storage.Session { return c.storage } // GetNetworkInterface returns the client Network Interface -func (c *Client) GetNetworkInterface() interfaces.NetworkManager { +func (c *Client) GetNetworkInterface() cmix.Client { return c.network } // GetBackup returns a pointer to the backup container so that the backup can be // set and triggered. -func (c *Client) GetBackup() *interfaces.BackupContainer { +func (c *Client) GetBackup() *backup.Backup { return c.backup } -// GetRateLimitParams retrieves the rate limiting parameters. -func (c *Client) GetRateLimitParams() (uint32, uint32, int64) { - rateLimitParams := c.storage.GetBucketParams().Get() - return rateLimitParams.Capacity, rateLimitParams.LeakedTokens, - rateLimitParams.LeakDuration.Nanoseconds() -} - // GetNodeRegistrationStatus gets the current state of nodes registration. It // returns the total number of nodes in the NDF and the number of those which // are currently registers with. An error is returned if the network is not @@ -710,27 +734,25 @@ func (c *Client) GetRateLimitParams() (uint32, uint32, int64) { func (c *Client) GetNodeRegistrationStatus() (int, int, error) { // Return an error if the network is not healthy if !c.GetHealth().IsHealthy() { - return 0, 0, errors.New("Cannot get number of nodes registrations when " + - "network is not healthy") + return 0, 0, errors.New("Cannot get number of nodes " + + "registrations when network is not healthy") } - nodes := c.GetNetworkInterface().GetInstance().GetPartialNdf().Get().Nodes - - cmixStore := c.storage.Cmix() + nodes := c.network.GetInstance().GetFullNdf().Get().Nodes var numRegistered int var numStale = 0 for i, n := range nodes { nid, err := id.Unmarshal(n.ID) if err != nil { - return 0, 0, errors.Errorf("Failed to unmarshal nodes ID %v "+ - "(#%d): %s", n.ID, i, err.Error()) + return 0, 0, errors.Errorf("Failed to unmarshal nodes "+ + "ID %v (#%d): %s", n.ID, i, err.Error()) } if n.Status == ndf.Stale { numStale += 1 continue } - if cmixStore.Has(nid) { + if c.network.HasNode(nid) { numRegistered++ } } @@ -744,124 +766,50 @@ func (c *Client) GetNodeRegistrationStatus() (int, int, error) { // partner ID an error will be returned. func (c *Client) DeleteRequest(partnerId *id.ID) error { jww.DEBUG.Printf("Deleting request for partner ID: %s", partnerId) - return c.GetStorage().Auth().DeleteRequest(partnerId) + return c.auth.DeleteRequest(partnerId) } // DeleteAllRequests clears all requests from client's auth storage. func (c *Client) DeleteAllRequests() error { jww.DEBUG.Printf("Deleting all requests") - return c.GetStorage().Auth().DeleteAllRequests() + return c.auth.DeleteAllRequests() } // DeleteSentRequests clears sent requests from client's auth storage. func (c *Client) DeleteSentRequests() error { jww.DEBUG.Printf("Deleting all sent requests") - return c.GetStorage().Auth().DeleteSentRequests() + return c.auth.DeleteSentRequests() } // DeleteReceiveRequests clears receive requests from client's auth storage. func (c *Client) DeleteReceiveRequests() error { jww.DEBUG.Printf("Deleting all received requests") - return c.GetStorage().Auth().DeleteReceiveRequests() + return c.auth.DeleteReceiveRequests() } // DeleteContact is a function which removes a partner from Client's storage func (c *Client) DeleteContact(partnerId *id.ID) error { jww.DEBUG.Printf("Deleting contact with ID %s", partnerId) - // get the partner so that they can be removed from preimage store - partner, err := c.storage.E2e().GetPartner(partnerId) + + _, err := c.e2e.GetPartner(partnerId) if err != nil { return errors.WithMessagef(err, "Could not delete %s because "+ "they could not be found", partnerId) } - e2ePreimage := partner.GetE2EPreimage() - rekeyPreimage := partner.GetSilentPreimage() - fileTransferPreimage := partner.GetFileTransferPreimage() - groupRequestPreimage := partner.GetGroupRequestPreimage() - //delete the partner - if err = c.storage.E2e().DeletePartner(partnerId); err != nil { + if err = c.e2e.DeletePartner(partnerId); err != nil { return err } - // Trigger backup c.backup.TriggerBackup("contact deleted") - //delete the preimages - if err = c.storage.GetEdge().Remove(edge.Preimage{ - Data: e2ePreimage, - Type: catalog.E2e, - Source: partnerId[:], - }, c.storage.GetUser().ReceptionID); err != nil { - jww.WARN.Printf("Failed delete the preimage for e2e "+ - "from %s on contact deletion: %+v", partnerId, err) - } - - if err = c.storage.GetEdge().Remove(edge.Preimage{ - Data: rekeyPreimage, - Type: catalog.Silent, - Source: partnerId[:], - }, c.storage.GetUser().ReceptionID); err != nil { - jww.WARN.Printf("Failed delete the preimage for rekey "+ - "from %s on contact deletion: %+v", partnerId, err) - } - - if err = c.storage.GetEdge().Remove(edge.Preimage{ - Data: fileTransferPreimage, - Type: catalog.EndFT, - Source: partnerId[:], - }, c.storage.GetUser().ReceptionID); err != nil { - jww.WARN.Printf("Failed delete the preimage for file transfer "+ - "from %s on contact deletion: %+v", partnerId, err) - } - - if err = c.storage.GetEdge().Remove(edge.Preimage{ - Data: groupRequestPreimage, - Type: catalog.GroupRq, - Source: partnerId[:], - }, c.storage.GetUser().ReceptionID); err != nil { - jww.WARN.Printf("Failed delete the preimage for group request "+ - "from %s on contact deletion: %+v", partnerId, err) - } - - //delete conversations - c.storage.Conversations().Delete(partnerId) + // FIXME: Do we need this? + // c.e2e.Conversations().Delete(partnerId) // call delete requests to make sure nothing is lingering. // this is for saftey to ensure the contact can be readded // in the future - _ = c.storage.Auth().Delete(partnerId) - - return nil -} - -// SetProxiedBins updates the host pool filter that filters out gateways that -// are not in one of the specified bins. -func (c *Client) SetProxiedBins(binStrings []string) error { - // Convert each region string into a region.GeoBin and place in a map for - // easy lookup - bins := make(map[region.GeoBin]bool, len(binStrings)) - for i, binStr := range binStrings { - bin, err := region.GetRegion(binStr) - if err != nil { - return errors.Errorf("failed to parse geographic bin #%d: %+v", i, err) - } - - bins[bin] = true - } - - // Create filter func - f := func(m map[id.ID]int, netDef *ndf.NetworkDefinition) map[id.ID]int { - prunedList := make(map[id.ID]int, len(m)) - for gwID, i := range m { - if bins[netDef.Gateways[i].Bin] { - prunedList[gwID] = i - } - } - return prunedList - } - - c.network.SetPoolFilter(f) + _ = c.auth.DeleteRequest(partnerId) return nil } @@ -872,8 +820,8 @@ func (c *Client) GetPreferredBins(countryCode string) ([]string, error) { // get the bin that the country is in bin, exists := region.GetCountryBin(countryCode) if !exists { - return nil, errors.Errorf("failed to find geographic bin for country %q", - countryCode) + return nil, errors.Errorf("failed to find geographic bin "+ + "for country %q", countryCode) } // Add bin to list of geographic bins @@ -884,17 +832,24 @@ func (c *Client) GetPreferredBins(countryCode string) ([]string, error) { case region.SouthAndCentralAmerica: bins = append(bins, region.NorthAmerica.String()) case region.MiddleEast: - bins = append(bins, region.EasternEurope.String(), region.CentralEurope.String(), region.WesternAsia.String()) + bins = append(bins, region.EasternEurope.String(), + region.CentralEurope.String(), + region.WesternAsia.String()) case region.NorthernAfrica: - bins = append(bins, region.WesternEurope.String(), region.CentralEurope.String()) + bins = append(bins, region.WesternEurope.String(), + region.CentralEurope.String()) case region.SouthernAfrica: - bins = append(bins, region.WesternEurope.String(), region.CentralEurope.String()) + bins = append(bins, region.WesternEurope.String(), + region.CentralEurope.String()) case region.EasternAsia: - bins = append(bins, region.WesternAsia.String(), region.Oceania.String(), region.NorthAmerica.String()) + bins = append(bins, region.WesternAsia.String(), + region.Oceania.String(), region.NorthAmerica.String()) case region.WesternAsia: - bins = append(bins, region.EasternAsia.String(), region.Russia.String(), region.MiddleEast.String()) + bins = append(bins, region.EasternAsia.String(), + region.Russia.String(), region.MiddleEast.String()) case region.Oceania: - bins = append(bins, region.EasternAsia.String(), region.NorthAmerica.String()) + bins = append(bins, region.EasternAsia.String(), + region.NorthAmerica.String()) } return bins, nil @@ -932,23 +887,25 @@ func decodeGroups(ndf *ndf.NetworkDefinition) (cmixGrp, e2eGrp *cyclic.Group) { return cmixGrp, e2eGrp } -// checkVersionAndSetupStorage is common code shared by NewClient, NewPrecannedClient and NewVanityClient -// it checks client version and creates a new storage for user data +// checkVersionAndSetupStorage is common code shared by NewClient, +// NewPrecannedClient and NewVanityClient it checks client version and +// creates a new storage for user data func checkVersionAndSetupStorage(def *ndf.NetworkDefinition, storageDir string, password []byte, protoUser user.Info, cmixGrp, e2eGrp *cyclic.Group, rngStreamGen *fastRNG.StreamGenerator, - isPrecanned bool, registrationCode string) (*storage.Session, error) { + isPrecanned bool, registrationCode string) (storage.Session, error) { // get current client version currentVersion, err := version.ParseVersion(SEMVER) if err != nil { - return nil, errors.WithMessage(err, "Could not parse version string.") + return nil, errors.WithMessage(err, + "Could not parse version string.") } // Create Storage passwordStr := string(password) storageSess, err := storage.New(storageDir, passwordStr, protoUser, - currentVersion, cmixGrp, e2eGrp, rngStreamGen, def.RateLimits) + currentVersion, cmixGrp, e2eGrp) if err != nil { return nil, err } @@ -960,19 +917,15 @@ func checkVersionAndSetupStorage(def *ndf.NetworkDefinition, //store the registration code for later use storageSess.SetRegCode(registrationCode) //move the registration state to keys generated - err = storageSess.ForwardRegistrationStatus(storage.KeyGenComplete) + err = storageSess.ForwardRegistrationStatus( + storage.KeyGenComplete) } else { - //move the registration state to indicate registered with registration - err = storageSess.ForwardRegistrationStatus(storage.PermissioningComplete) + //move the registration state to indicate registered + // with registration + err = storageSess.ForwardRegistrationStatus( + storage.PermissioningComplete) } - //add the request preiamge - storageSess.GetEdge().Add(edge.Preimage{ - Data: preimage.GenerateRequest(protoUser.ReceptionID), - Type: catalog.Request, - Source: protoUser.ReceptionID[:], - }, protoUser.ReceptionID) - if err != nil { return nil, errors.WithMessage(err, "Failed to denote state "+ "change in session") diff --git a/api/notifications.go b/api/notifications.go index 58e97190e469a7d21d9175b9b84ddaff89b75e2c..60355a48034572902905c876591004e039b2b0b0 100644 --- a/api/notifications.go +++ b/api/notifications.go @@ -34,20 +34,26 @@ func (c *Client) RegisterForNotifications(token string) error { if err != nil { return err } + + privKey := c.userState.CryptographicIdentity.GetTransmissionRSA() + pubPEM := rsa.CreatePublicKeyPem(privKey.GetPublic()) + regSig := c.userState.GetTransmissionRegistrationValidationSignature() + regTS := c.GetUser().RegistrationTimestamp + // Send the register message _, err = c.comms.RegisterForNotifications(notificationBotHost, &mixmessages.NotificationRegisterRequest{ Token: token, IntermediaryId: intermediaryReceptionID, - TransmissionRsa: rsa.CreatePublicKeyPem(c.GetStorage().User().GetCryptographicIdentity().GetTransmissionRSA().GetPublic()), + TransmissionRsa: pubPEM, TransmissionSalt: c.GetUser().TransmissionSalt, - TransmissionRsaSig: c.GetStorage().User().GetTransmissionRegistrationValidationSignature(), + TransmissionRsaSig: regSig, IIDTransmissionRsaSig: sig, - RegistrationTimestamp: c.GetUser().RegistrationTimestamp, + RegistrationTimestamp: regTS, }) if err != nil { - err := errors.Errorf( - "RegisterForNotifications: Unable to register for notifications! %s", err) + err := errors.Errorf("RegisterForNotifications: Unable to "+ + "register for notifications! %s", err) return err } @@ -96,7 +102,9 @@ func (c *Client) getIidAndSig() ([]byte, []byte, error) { stream := c.rng.GetStream() c.GetUser() - sig, err := rsa.Sign(stream, c.storage.User().GetCryptographicIdentity().GetTransmissionRSA(), hash.CMixHash, h.Sum(nil), nil) + sig, err := rsa.Sign(stream, + c.userState.GetTransmissionRSA(), + hash.CMixHash, h.Sum(nil), nil) if err != nil { return nil, nil, errors.WithMessage(err, "RegisterForNotifications: Failed to sign intermediary ID") } diff --git a/api/permissioning.go b/api/permissioning.go index 2dd62ec87dd8a95f43bce9f93df31db105edc4ca..b4060d994c84622f3395193c145a1e257559c815 100644 --- a/api/permissioning.go +++ b/api/permissioning.go @@ -9,17 +9,18 @@ package api import ( "encoding/json" + "github.com/pkg/errors" - "gitlab.com/elixxir/client/interfaces/user" "gitlab.com/elixxir/client/storage" + "gitlab.com/elixxir/client/storage/user" ) // Returns an error if registration fails. func (c *Client) registerWithPermissioning() error { - userData := c.storage.User() + userData := c.userState.CryptographicIdentity //get the users public key - transmissionPubKey := userData.GetCryptographicIdentity().GetTransmissionRSA().GetPublic() - receptionPubKey := userData.GetCryptographicIdentity().GetReceptionRSA().GetPublic() + transmissionPubKey := userData.GetTransmissionRSA().GetPublic() + receptionPubKey := userData.GetReceptionRSA().GetPublic() //load the registration code regCode, err := c.storage.GetRegCode() @@ -30,16 +31,19 @@ func (c *Client) registerWithPermissioning() error { //register with registration transmissionRegValidationSignature, receptionRegValidationSignature, - registrationTimestamp, err := c.permissioning.Register(transmissionPubKey, receptionPubKey, regCode) + registrationTimestamp, err := c.permissioning.Register( + transmissionPubKey, receptionPubKey, regCode) if err != nil { return errors.WithMessage(err, "failed to register with "+ "permissioning") } //store the signature - userData.SetTransmissionRegistrationValidationSignature(transmissionRegValidationSignature) - userData.SetReceptionRegistrationValidationSignature(receptionRegValidationSignature) - userData.SetRegistrationTimestamp(registrationTimestamp) + c.storage.SetTransmissionRegistrationValidationSignature( + transmissionRegValidationSignature) + c.storage.SetReceptionRegistrationValidationSignature( + receptionRegValidationSignature) + c.storage.SetRegistrationTimestamp(registrationTimestamp) //update the registration state err = c.storage.ForwardRegistrationStatus(storage.PermissioningComplete) @@ -50,9 +54,9 @@ func (c *Client) registerWithPermissioning() error { return nil } -// ConstructProtoUerFile is a helper function which is used for proto client testing. -// This is used for development testing. -func (c *Client) ConstructProtoUerFile() ([]byte, error) { +// ConstructProtoUserFile is a helper function which is used for proto +// client testing. This is used for development testing. +func (c *Client) ConstructProtoUserFile() ([]byte, error) { //load the registration code regCode, err := c.storage.GetRegCode() @@ -71,10 +75,10 @@ func (c *Client) ConstructProtoUerFile() ([]byte, error) { Precanned: c.GetUser().Precanned, RegistrationTimestamp: c.GetUser().RegistrationTimestamp, RegCode: regCode, - TransmissionRegValidationSig: c.storage.User().GetTransmissionRegistrationValidationSignature(), - ReceptionRegValidationSig: c.storage.User().GetReceptionRegistrationValidationSignature(), - E2eDhPrivateKey: c.GetStorage().E2e().GetDHPrivateKey(), - E2eDhPublicKey: c.GetStorage().E2e().GetDHPublicKey(), + TransmissionRegValidationSig: c.storage.GetTransmissionRegistrationValidationSignature(), + ReceptionRegValidationSig: c.storage.GetReceptionRegistrationValidationSignature(), + E2eDhPrivateKey: c.e2e.GetHistoricalDHPrivkey(), + E2eDhPublicKey: c.e2e.GetHistoricalDHPubkey(), } jsonBytes, err := json.Marshal(Usr) diff --git a/api/send.go b/api/send.go index 5af5704d82b3696693997a7c6911fd08955440dc..5959d81946ada40f94d56f40b06f0ead130ae3d5 100644 --- a/api/send.go +++ b/api/send.go @@ -8,15 +8,18 @@ package api import ( + "time" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" - "gitlab.com/elixxir/crypto/e2e" + "gitlab.com/elixxir/client/catalog" + "gitlab.com/elixxir/client/cmix" + "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/e2e" + e2eCrypto "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" - "time" ) //This holds all functions to send messages over the network @@ -24,11 +27,12 @@ import ( // SendE2E sends an end-to-end payload to the provided recipient with // the provided msgType. Returns the list of rounds in which parts of // the message were sent or an error if it fails. -func (c *Client) SendE2E(m message.Send, param params.E2E) ([]id.Round, - e2e.MessageID, time.Time, error) { - jww.INFO.Printf("SendE2E(%s, %d. %v)", m.Recipient, - m.MessageType, m.Payload) - return c.network.SendE2E(m, param, nil) +func (c *Client) SendE2E(mt catalog.MessageType, recipient *id.ID, + payload []byte, param e2e.Params) ([]id.Round, + e2eCrypto.MessageID, time.Time, error) { + jww.INFO.Printf("SendE2E(%s, %d. %v)", recipient, + mt, payload) + return c.e2e.SendE2E(mt, recipient, payload, param) } // SendUnsafe sends an unencrypted payload to the provided recipient @@ -36,11 +40,12 @@ func (c *Client) SendE2E(m message.Send, param params.E2E) ([]id.Round, // of the message were sent or an error if it fails. // NOTE: Do not use this function unless you know what you are doing. // This function always produces an error message in client logging. -func (c *Client) SendUnsafe(m message.Send, param params.Unsafe) ([]id.Round, +func (c *Client) SendUnsafe(mt catalog.MessageType, recipient *id.ID, + payload []byte, param e2e.Params) ([]id.Round, time.Time, error) { - jww.INFO.Printf("SendUnsafe(%s, %d. %v)", m.Recipient, - m.MessageType, m.Payload) - return c.network.SendUnsafe(m, param) + jww.INFO.Printf("SendUnsafe(%s, %d. %v)", recipient, + mt, payload) + return c.e2e.SendUnsafe(mt, recipient, payload, param) } // SendCMIX sends a "raw" CMIX message payload to the provided @@ -48,24 +53,26 @@ func (c *Client) SendUnsafe(m message.Send, param params.Unsafe) ([]id.Round, // Returns the round ID of the round the payload was sent or an error // if it fails. func (c *Client) SendCMIX(msg format.Message, recipientID *id.ID, - param params.CMIX) (id.Round, ephemeral.Id, error) { + param cmix.CMIXParams) (id.Round, ephemeral.Id, error) { jww.INFO.Printf("Send(%s)", string(msg.GetContents())) - return c.network.SendCMIX(msg, recipientID, param) + return c.network.Send(recipientID, msg.GetKeyFP(), + message.GetDefaultService(recipientID), + msg.GetContents(), msg.GetMac(), param) } // SendManyCMIX sends many "raw" CMIX message payloads to each of the // provided recipients. Used for group chat functionality. Returns the // round ID of the round the payload was sent or an error if it fails. -func (c *Client) SendManyCMIX(messages []message.TargetedCmixMessage, - params params.CMIX) (id.Round, []ephemeral.Id, error) { - return c.network.SendManyCMIX(messages, params) +func (c *Client) SendManyCMIX(messages []cmix.TargetedCmixMessage, + params cmix.CMIXParams) (id.Round, []ephemeral.Id, error) { + return c.network.SendMany(messages, params) } // NewCMIXMessage Creates a new cMix message with the right properties // for the current cMix network. // FIXME: this is weird and shouldn't be necessary, but it is. func (c *Client) NewCMIXMessage(contents []byte) (format.Message, error) { - primeSize := len(c.storage.Cmix().GetGroup().GetPBytes()) + primeSize := len(c.storage.GetCmixGroup().GetPBytes()) msg := format.NewMessage(primeSize) if len(contents) > msg.ContentsSize() { return format.Message{}, errors.New("Contents to long for cmix") diff --git a/api/user.go b/api/user.go index 5b64f0b0a2df63408cf0b6f3baa0c2a391132caf..7ed520a0ffd2d19b12c1af772713273f7fbd7210 100644 --- a/api/user.go +++ b/api/user.go @@ -9,19 +9,20 @@ package api import ( "encoding/binary" + "math/rand" + "regexp" + "runtime" + "strings" + "sync" + jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces/user" + "gitlab.com/elixxir/client/storage/user" "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/crypto/signature/rsa" "gitlab.com/xx_network/crypto/xx" "gitlab.com/xx_network/primitives/id" - "math/rand" - "regexp" - "runtime" - "strings" - "sync" ) const ( @@ -30,7 +31,8 @@ const ( ) // createNewUser generates an identity for cMix -func createNewUser(rng *fastRNG.StreamGenerator, cmix, e2e *cyclic.Group) user.Info { +func createNewUser(rng *fastRNG.StreamGenerator, cmix, + e2e *cyclic.Group) user.Info { // CMIX Keygen var transmissionRsaKey, receptionRsaKey *rsa.PrivateKey @@ -65,12 +67,14 @@ func createNewUser(rng *fastRNG.StreamGenerator, cmix, e2e *cyclic.Group) user.I stream.Close() - transmissionID, err := xx.NewID(transmissionRsaKey.GetPublic(), transmissionSalt, id.User) + transmissionID, err := xx.NewID(transmissionRsaKey.GetPublic(), + transmissionSalt, id.User) if err != nil { jww.FATAL.Panicf(err.Error()) } - receptionID, err := xx.NewID(receptionRsaKey.GetPublic(), receptionSalt, id.User) + receptionID, err := xx.NewID(receptionRsaKey.GetPublic(), + receptionSalt, id.User) if err != nil { jww.FATAL.Panicf(err.Error()) } @@ -98,13 +102,12 @@ func createDhKeys(rng *fastRNG.StreamGenerator, go func() { defer wg.Done() var err error - // DH Keygen - // FIXME: Why 256 bits? -- this is spec but not explained, it has - // to do with optimizing operations on one side and still preserves - // decent security -- cite this. Why valid for BOTH e2e and cmix? - stream := rng.GetStream() - e2eKeyBytes, err = csprng.GenerateInGroup(e2e.GetPBytes(), 256, stream) - stream.Close() + rngStream := rng.GetStream() + prime := e2e.GetPBytes() + keyLen := len(prime) + e2eKeyBytes, err = csprng.GenerateInGroup(prime, keyLen, + rngStream) + rngStream.Close() if err != nil { jww.FATAL.Panicf(err.Error()) } @@ -115,7 +118,8 @@ func createDhKeys(rng *fastRNG.StreamGenerator, defer wg.Done() var err error stream := rng.GetStream() - transmissionRsaKey, err = rsa.GenerateKey(stream, rsa.DefaultRSABitLen) + transmissionRsaKey, err = rsa.GenerateKey(stream, + rsa.DefaultRSABitLen) stream.Close() if err != nil { jww.FATAL.Panicf(err.Error()) @@ -126,7 +130,8 @@ func createDhKeys(rng *fastRNG.StreamGenerator, defer wg.Done() var err error stream := rng.GetStream() - receptionRsaKey, err = rsa.GenerateKey(stream, rsa.DefaultRSABitLen) + receptionRsaKey, err = rsa.GenerateKey(stream, + rsa.DefaultRSABitLen) stream.Close() if err != nil { jww.FATAL.Panicf(err.Error()) @@ -140,13 +145,13 @@ func createDhKeys(rng *fastRNG.StreamGenerator, // TODO: Add precanned user code structures here. // creates a precanned user -func createPrecannedUser(precannedID uint, rng csprng.Source, cmix, e2e *cyclic.Group) user.Info { +func createPrecannedUser(precannedID uint, rng csprng.Source, cmix, + e2e *cyclic.Group) user.Info { // DH Keygen - // FIXME: Why 256 bits? -- this is spec but not explained, it has - // to do with optimizing operations on one side and still preserves - // decent security -- cite this. Why valid for BOTH e2e and cmix? prng := rand.New(rand.NewSource(int64(precannedID))) - e2eKeyBytes, err := csprng.GenerateInGroup(e2e.GetPBytes(), 256, prng) + prime := e2e.GetPBytes() + keyLen := len(prime) + e2eKeyBytes, err := csprng.GenerateInGroup(prime, keyLen, prng) if err != nil { jww.FATAL.Panicf(err.Error()) } @@ -178,12 +183,12 @@ func createPrecannedUser(precannedID uint, rng csprng.Source, cmix, e2e *cyclic. // createNewVanityUser generates an identity for cMix // The identity's ReceptionID is not random but starts with the supplied prefix -func createNewVanityUser(rng csprng.Source, cmix, e2e *cyclic.Group, prefix string) user.Info { +func createNewVanityUser(rng csprng.Source, cmix, + e2e *cyclic.Group, prefix string) user.Info { // DH Keygen - // FIXME: Why 256 bits? -- this is spec but not explained, it has - // to do with optimizing operations on one side and still preserves - // decent security -- cite this. Why valid for BOTH e2e and cmix? - e2eKeyBytes, err := csprng.GenerateInGroup(e2e.GetPBytes(), 256, rng) + prime := e2e.GetPBytes() + keyLen := len(prime) + e2eKeyBytes, err := csprng.GenerateInGroup(prime, keyLen, rng) if err != nil { jww.FATAL.Panicf(err.Error()) } @@ -203,7 +208,8 @@ func createNewVanityUser(rng csprng.Source, cmix, e2e *cyclic.Group, prefix stri if n != SaltSize { jww.FATAL.Panicf("transmissionSalt size too small: %d", n) } - transmissionID, err := xx.NewID(transmissionRsaKey.GetPublic(), transmissionSalt, id.User) + transmissionID, err := xx.NewID(transmissionRsaKey.GetPublic(), + transmissionSalt, id.User) if err != nil { jww.FATAL.Panicf(err.Error()) } @@ -213,7 +219,9 @@ func createNewVanityUser(rng csprng.Source, cmix, e2e *cyclic.Group, prefix stri jww.FATAL.Panicf(err.Error()) } - var mu sync.Mutex // just in case more than one go routine tries to access receptionSalt and receptionID + // just in case more than one go routine tries to access + // receptionSalt and receptionID + var mu sync.Mutex done := make(chan struct{}) found := make(chan bool) wg := &sync.WaitGroup{} @@ -234,7 +242,8 @@ func createNewVanityUser(rng csprng.Source, cmix, e2e *cyclic.Group, prefix stri if match == false { jww.FATAL.Panicf("Prefix contains non-Base64 characters") } - jww.INFO.Printf("Vanity userID generation started. Prefix: %s Ignore-Case: %v NumCPU: %d", pref, ignoreCase, cores) + jww.INFO.Printf("Vanity userID generation started. Prefix: %s "+ + "Ignore-Case: %v NumCPU: %d", pref, ignoreCase, cores) for w := 0; w < cores; w++ { wg.Add(1) go func() { @@ -245,14 +254,20 @@ func createNewVanityUser(rng csprng.Source, cmix, e2e *cyclic.Group, prefix stri defer wg.Done() return default: - n, err = csprng.NewSystemRNG().Read(rSalt) + n, err = csprng.NewSystemRNG().Read( + rSalt) if err != nil { jww.FATAL.Panicf(err.Error()) } if n != SaltSize { - jww.FATAL.Panicf("receptionSalt size too small: %d", n) + jww.FATAL.Panicf( + "receptionSalt size "+ + "too small: %d", + n) } - rID, err := xx.NewID(receptionRsaKey.GetPublic(), rSalt, id.User) + rID, err := xx.NewID( + receptionRsaKey.GetPublic(), + rSalt, id.User) if err != nil { jww.FATAL.Panicf(err.Error()) } @@ -273,7 +288,8 @@ func createNewVanityUser(rng csprng.Source, cmix, e2e *cyclic.Group, prefix stri } }() } - // wait for a solution then close the done channel to signal the workers to exit + // wait for a solution then close the done channel to signal + // the workers to exit <-found close(done) wg.Wait() diff --git a/api/utilsInterfaces_test.go b/api/utilsInterfaces_test.go index 29b5f8eb9abbc2217c1713cc25764472087701fa..617bce08b96810b920fb6ceb9aaf9ef3990f8757 100644 --- a/api/utilsInterfaces_test.go +++ b/api/utilsInterfaces_test.go @@ -7,21 +7,21 @@ package api import ( + "time" + + "gitlab.com/elixxir/client/cmix" "gitlab.com/elixxir/client/cmix/gateway" - "gitlab.com/elixxir/client/event" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/cmix/identity" + "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/cmix/rounds" "gitlab.com/elixxir/client/stoppable" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" - cE2e "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" - "time" ) // Mock comm struct which returns no historical round data @@ -82,46 +82,45 @@ func (ht *historicalRounds) GetHost(hostId *id.ID) (*connect.Host, bool) { // Contains a test implementation of the networkManager interface. type testNetworkManagerGeneric struct { instance *network.Instance - sender *gateway.Sender + sender gateway.Sender } type dummyEventMgr struct{} func (d *dummyEventMgr) Report(p int, a, b, c string) {} -func (t *testNetworkManagerGeneric) GetEventManager() event.Manager { - return &dummyEventMgr{} +func (d *dummyEventMgr) EventService() (stoppable.Stoppable, error) { + return nil, nil } /* Below methods built for interface adherence */ -func (t *testNetworkManagerGeneric) GetHealthTracker() interfaces.HealthTracker { - return nil -} -func (t *testNetworkManagerGeneric) Follow(report interfaces.ClientErrorReport) (stoppable.Stoppable, error) { +func (t *testNetworkManagerGeneric) Follow(report cmix.ClientErrorReport) (stoppable.Stoppable, error) { return nil, nil } -func (t *testNetworkManagerGeneric) CheckGarbledMessages() { +func (t *testNetworkManagerGeneric) GetMaxMessageLength() int { return 0 } + +func (t *testNetworkManagerGeneric) CheckInProgressMessages() { return } func (t *testNetworkManagerGeneric) GetVerboseRounds() string { return "" } -func (t *testNetworkManagerGeneric) SendE2E(message.Send, params.E2E, *stoppable.Single) ( - []id.Round, cE2e.MessageID, time.Time, error) { - rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} - return rounds, cE2e.MessageID{}, time.Time{}, nil -} -func (t *testNetworkManagerGeneric) SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) { - return nil, nil +func (t *testNetworkManagerGeneric) AddFingerprint(identity *id.ID, fingerprint format.Fingerprint, mp message.Processor) error { + return nil } -func (t *testNetworkManagerGeneric) SendCMIX(message format.Message, rid *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) { + +func (t *testNetworkManagerGeneric) Send(*id.ID, format.Fingerprint, + message.Service, []byte, []byte, cmix.CMIXParams) (id.Round, + ephemeral.Id, error) { return id.Round(0), ephemeral.Id{}, nil } -func (t *testNetworkManagerGeneric) SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) (id.Round, []ephemeral.Id, error) { +func (t *testNetworkManagerGeneric) SendMany(messages []cmix.TargetedCmixMessage, + p cmix.CMIXParams) (id.Round, []ephemeral.Id, error) { return 0, []ephemeral.Id{}, nil } func (t *testNetworkManagerGeneric) GetInstance() *network.Instance { return t.instance } -func (t *testNetworkManagerGeneric) RegisterWithPermissioning(string) ([]byte, error) { +func (t *testNetworkManagerGeneric) RegisterWithPermissioning(string) ( + []byte, error) { return nil, nil } func (t *testNetworkManagerGeneric) GetRemoteVersion() (string, error) { @@ -135,7 +134,7 @@ func (t *testNetworkManagerGeneric) InProgressRegistrations() int { return 0 } -func (t *testNetworkManagerGeneric) GetSender() *gateway.Sender { +func (t *testNetworkManagerGeneric) GetSender() gateway.Sender { return t.sender } @@ -147,3 +146,66 @@ func (t *testNetworkManagerGeneric) RegisterAddressSizeNotification(string) (cha func (t *testNetworkManagerGeneric) UnregisterAddressSizeNotification(string) {} func (t *testNetworkManagerGeneric) SetPoolFilter(gateway.Filter) {} +func (t *testNetworkManagerGeneric) AddHealthCallback(f func(bool)) uint64 { + return 0 +} +func (t *testNetworkManagerGeneric) AddIdentity(id *id.ID, + validUntil time.Time, persistent bool) { +} +func (t *testNetworkManagerGeneric) RemoveIdentity(id *id.ID) {} +func (t *testNetworkManagerGeneric) AddService(clientID *id.ID, + newService message.Service, response message.Processor) { +} +func (t *testNetworkManagerGeneric) DeleteService(clientID *id.ID, + toDelete message.Service, processor message.Processor) { +} +func (t *testNetworkManagerGeneric) DeleteClientService(clientID *id.ID) { +} +func (t *testNetworkManagerGeneric) DeleteFingerprint(identity *id.ID, + fingerprint format.Fingerprint) { +} +func (t *testNetworkManagerGeneric) DeleteClientFingerprints(identity *id.ID) { +} +func (t *testNetworkManagerGeneric) GetAddressSpace() uint8 { return 0 } +func (t *testNetworkManagerGeneric) GetHostParams() connect.HostParams { + return connect.GetDefaultHostParams() +} +func (t *testNetworkManagerGeneric) GetIdentity(get *id.ID) ( + identity.TrackedID, error) { + return identity.TrackedID{}, nil +} +func (t *testNetworkManagerGeneric) GetRoundResults(timeout time.Duration, + roundCallback cmix.RoundEventCallback, roundList ...id.Round) error { + return nil +} +func (t *testNetworkManagerGeneric) HasNode(nid *id.ID) bool { return false } +func (t *testNetworkManagerGeneric) IsHealthy() bool { return true } +func (t *testNetworkManagerGeneric) WasHealthy() bool { return true } +func (t *testNetworkManagerGeneric) LookupHistoricalRound(rid id.Round, + callback rounds.RoundResultCallback) error { + return nil +} +func (t *testNetworkManagerGeneric) NumRegisteredNodes() int { return 0 } +func (t *testNetworkManagerGeneric) RegisterAddressSpaceNotification( + tag string) (chan uint8, error) { + return nil, nil +} +func (t *testNetworkManagerGeneric) RemoveHealthCallback(uint64) {} +func (t *testNetworkManagerGeneric) SendToAny( + sendFunc func(host *connect.Host) (interface{}, error), + stop *stoppable.Single) (interface{}, error) { + return nil, nil +} +func (t *testNetworkManagerGeneric) SendToPreferred(targets []*id.ID, + sendFunc gateway.SendToPreferredFunc, stop *stoppable.Single, + timeout time.Duration) (interface{}, error) { + return nil, nil +} +func (t *testNetworkManagerGeneric) SetGatewayFilter(f gateway.Filter) {} +func (t *testNetworkManagerGeneric) TrackServices( + tracker message.ServicesTracker) { +} +func (t *testNetworkManagerGeneric) TriggerNodeRegistration(nid *id.ID) {} +func (t *testNetworkManagerGeneric) UnregisterAddressSpaceNotification( + tag string) { +} diff --git a/api/utils_test.go b/api/utils_test.go index 2171ee98fb7645daef91c2fa2cb27c378c6e98a6..4496905d51eab7344b8d2fcde840a2cd3c862d11 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -9,12 +9,13 @@ package api import ( "bytes" - "gitlab.com/elixxir/client/cmix/gateway" "testing" + "gitlab.com/elixxir/client/cmix/gateway" + "gitlab.com/elixxir/client/e2e" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces/params" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" "gitlab.com/elixxir/comms/testkeys" @@ -32,7 +33,8 @@ func newTestingClient(face interface{}) (*Client, error) { case *testing.T, *testing.M, *testing.B, *testing.PB: break default: - jww.FATAL.Panicf("InitTestingSession is restricted to testing only. Got %T", face) + jww.FATAL.Panicf("InitTestingSession is restricted to testing "+ + "only. Got %T", face) } def := getNDF(face) @@ -41,12 +43,14 @@ func newTestingClient(face interface{}) (*Client, error) { password := []byte("hunter2") err := NewClient(string(marshalledDef), storageDir, password, "AAAA") if err != nil { - return nil, errors.Errorf("Could not construct a mock client: %v", err) + return nil, errors.Errorf( + "Could not construct a mock client: %v", err) } - c, err := OpenClient(storageDir, password, params.GetDefaultNetwork()) + c, err := OpenClient(storageDir, password, e2e.GetDefaultParams()) if err != nil { - return nil, errors.Errorf("Could not open a mock client: %v", err) + return nil, errors.Errorf("Could not open a mock client: %v", + err) } commsManager := connect.NewManagerTesting(face) @@ -56,19 +60,22 @@ func newTestingClient(face interface{}) (*Client, error) { jww.FATAL.Panicf("Failed to create new test instance: %v", err) } - commsManager.AddHost(&id.Permissioning, "", cert, connect.GetDefaultHostParams()) + commsManager.AddHost(&id.Permissioning, "", cert, + connect.GetDefaultHostParams()) instanceComms := &connect.ProtoComms{ Manager: commsManager, } - thisInstance, err := network.NewInstanceTesting(instanceComms, def, def, nil, nil, face) + thisInstance, err := network.NewInstanceTesting(instanceComms, def, + def, nil, nil, face) if err != nil { return nil, err } p := gateway.DefaultPoolParams() p.MaxPoolSize = 1 - sender, err := gateway.NewSender(p, c.rng, def, commsManager, c.storage, nil) + sender, err := gateway.NewSender(p, c.rng, def, commsManager, + c.storage, nil) if err != nil { return nil, err } diff --git a/auth/confirm.go b/auth/confirm.go index e523715c74a421773354216c9e335a5d9157b2fa..fc6c5b69c90fa6c5533fd26e570ad31c10e8e4a6 100644 --- a/auth/confirm.go +++ b/auth/confirm.go @@ -9,6 +9,7 @@ package auth import ( "fmt" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/auth/store" @@ -126,7 +127,7 @@ func (s *state) confirm(partner contact.Contact, serviceTag string) ( s.event.Report(10, "Auth", "SendConfirmError", em) } - //todo: s.backupTrigger("confirmed authenticated channel") + s.backupTrigger("confirmed authenticated channel") jww.INFO.Printf("Confirming Auth from %s to %s, msgDigest: %s", partner.ID, s.e2e.GetReceptionID(), @@ -152,7 +153,7 @@ func (s *state) confirm(partner contact.Contact, serviceTag string) ( } func sendAuthConfirm(net cmixClient, partner *id.ID, - fp format.Fingerprint, payload, mac []byte, event event.Manager, + fp format.Fingerprint, payload, mac []byte, event event.Reporter, serviceTag string) ( id.Round, error) { svc := message.Service{ diff --git a/auth/interface.go b/auth/interface.go index f44e221acf77e303b0e19559d5a3410335fa7a56..b8ea8865c78f89ba7ffacce3476e46aeb3ff0f7f 100644 --- a/auth/interface.go +++ b/auth/interface.go @@ -1,6 +1,7 @@ package auth import ( + "gitlab.com/elixxir/client/e2e" "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/primitives/fact" "gitlab.com/xx_network/primitives/id" @@ -60,4 +61,26 @@ type State interface { // CallAllReceivedRequests will iterate through all pending contact requests // and replay them on the callbacks. CallAllReceivedRequests() + + // DeleteRequest deletes sent or received requests for a + // specific partner ID. + DeleteRequest(partnerID *id.ID) error + + // DeleteAllRequests clears all requests from client's auth storage. + DeleteAllRequests() error + + // DeleteSentRequests clears all sent requests from client's auth + // storage. + DeleteSentRequests() error + + // DeleteReceiveRequests clears all received requests from client's auth + // storage. + DeleteReceiveRequests() error + + // GetReceivedRequest returns a contact if there's a received + // request for it. + GetReceivedRequest(partner *id.ID) (contact.Contact, error) + + // VerifyOwnership checks if the received ownership proof is valid + VerifyOwnership(received, verified contact.Contact, e2e e2e.Handler) bool } diff --git a/auth/receivedConfirm.go b/auth/receivedConfirm.go index 1990fd5bdca5eeb826573b6c06c9151ce96ce25c..07e12df9beae5cb55889d9eadd29a3361b68b2b3 100644 --- a/auth/receivedConfirm.go +++ b/auth/receivedConfirm.go @@ -5,9 +5,9 @@ import ( "fmt" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/auth/store" - "gitlab.com/elixxir/client/cmix/rounds" "gitlab.com/elixxir/client/cmix/identity/receptionID" "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/cmix/rounds" "gitlab.com/elixxir/client/e2e/ratchet/partner/session" "gitlab.com/elixxir/crypto/contact" cAuth "gitlab.com/elixxir/crypto/e2e/auth" @@ -95,7 +95,7 @@ func (rcs *receivedConfirmService) Process(msg format.Message, "%s : %+v", rcs.GetPartner(), receptionID.Source, err) } - //todo: trigger backup + rcs.s.backupTrigger("received confirmation from request") // remove the service used for notifications of the confirm state.net.DeleteService(receptionID.Source, rcs.notificationsService, nil) diff --git a/auth/request.go b/auth/request.go index 3e6b8c874b8243c313182d39437aca0d1d412153..a37839fd8eb065258a61867bc45a671bc3b55dc7 100644 --- a/auth/request.go +++ b/auth/request.go @@ -9,11 +9,15 @@ package auth import ( "fmt" + "io" + "strings" + "github.com/cloudflare/circl/dh/sidh" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/cmix" "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/e2e" "gitlab.com/elixxir/client/e2e/ratchet" util "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/crypto/contact" @@ -23,8 +27,6 @@ import ( "gitlab.com/elixxir/primitives/fact" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" - "io" - "strings" ) const terminator = ";" @@ -199,3 +201,28 @@ func createRequestAuth(sender *id.ID, payload, ownership []byte, myDHPriv, return &baseFmt, mac, nil } + +func (s *state) GetReceivedRequest(partner *id.ID) (contact.Contact, error) { + return s.store.GetReceivedRequest(partner) +} + +func (s *state) VerifyOwnership(received, verified contact.Contact, + e2e e2e.Handler) bool { + return VerifyOwnership(received, verified, e2e) +} + +func (s *state) DeleteRequest(partnerID *id.ID) error { + return s.store.DeleteRequest(partnerID) +} + +func (s *state) DeleteAllRequests() error { + return s.store.DeleteAllRequests() +} + +func (s *state) DeleteSentRequests() error { + return s.store.DeleteSentRequests() +} + +func (s *state) DeleteReceiveRequests() error { + return s.store.DeleteReceiveRequests() +} diff --git a/auth/state.go b/auth/state.go index 7ced7eb8d69ca96ef321fd5b3ac5faa625c6dab7..a3e3d59df1ff43bf8eb6ace2582a009aaedab191 100644 --- a/auth/state.go +++ b/auth/state.go @@ -9,6 +9,7 @@ package auth import ( "encoding/base64" + "github.com/cloudflare/circl/dh/sidh" "github.com/pkg/errors" "gitlab.com/elixxir/client/auth/store" @@ -33,16 +34,16 @@ import ( type state struct { callbacks Callbacks - // net cmix.Client net cmixClient - // e2e e2e.Handler e2e e2eHandler rng *fastRNG.StreamGenerator store *store.Store - event event.Manager + event event.Reporter params Param + + backupTrigger func(reason string) } type cmixClient interface { @@ -100,10 +101,11 @@ type Callbacks interface { // with a memory only versioned.KV) as well as a memory only versioned.KV for // NewState and use GetDefaultTemporaryParams() for the parameters func NewState(kv *versioned.KV, net cmix.Client, e2e e2e.Handler, - rng *fastRNG.StreamGenerator, event event.Manager, params Param, - callbacks Callbacks) (State, error) { + rng *fastRNG.StreamGenerator, event event.Reporter, params Param, + callbacks Callbacks, backupTrigger func(reason string)) (State, error) { kv = kv.Prefix(makeStorePrefix(e2e.GetReceptionID())) - return NewStateLegacy(kv, net, e2e, rng, event, params, callbacks) + return NewStateLegacy( + kv, net, e2e, rng, event, params, callbacks, backupTrigger) } // NewStateLegacy loads the auth state or creates new auth state if one cannot be @@ -112,18 +114,17 @@ func NewState(kv *versioned.KV, net cmix.Client, e2e e2e.Handler, // Does not modify the kv prefix for backwards compatibility // Otherwise, acts the same as NewState func NewStateLegacy(kv *versioned.KV, net cmix.Client, e2e e2e.Handler, - rng *fastRNG.StreamGenerator, event event.Manager, params Param, - callbacks Callbacks) (State, error) { + rng *fastRNG.StreamGenerator, event event.Reporter, params Param, + callbacks Callbacks, backupTrigger func(reason string)) (State, error) { s := &state{ - callbacks: callbacks, - - net: net, - e2e: e2e, - rng: rng, - - params: params, - event: event, + callbacks: callbacks, + net: net, + e2e: e2e, + rng: rng, + event: event, + params: params, + backupTrigger: backupTrigger, } // create the store @@ -131,7 +132,7 @@ func NewStateLegacy(kv *versioned.KV, net cmix.Client, e2e e2e.Handler, s.store, err = store.NewOrLoadStore(kv, e2e.GetGroup(), &sentRequestHandler{s: s}) - //register services + // register services net.AddService(e2e.GetReceptionID(), message.Service{ Identifier: e2e.GetReceptionID()[:], Tag: params.RequestTag, diff --git a/backup/backup.go b/backup/backup.go index ea9667c0391b495b6c1c814cf5cf27f19341002e..291348fa6aa3bbbee8b49e6937013c4ea77b51f6 100644 --- a/backup/backup.go +++ b/backup/backup.go @@ -5,34 +5,30 @@ // LICENSE file // //////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - package backup import ( + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/primitives/fact" + "gitlab.com/xx_network/primitives/id" "sync" + "time" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/api" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/crypto/backup" "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/crypto/signature/rsa" ) // Error messages. const ( - // initializeBackup + // InitializeBackup errSavePassword = "failed to save password: %+v" errSaveKeySaltParams = "failed to save key, salt, and params: %+v" - // resumeBackup + // ResumeBackup errLoadPassword = "backup not initialized: load user password failed: %+v" // Backup.StopBackup @@ -46,15 +42,43 @@ type Backup struct { // Callback that is called with the encrypted backup when triggered updateBackupCb UpdateBackupFn - mux sync.RWMutex + container *Container + + jsonParams string // Client structures - client *api.Client - store *storage.Session - backupContainer *interfaces.BackupContainer - rng *fastRNG.StreamGenerator + e2e E2e + session Session + ud UserDiscovery + kv *versioned.KV + rng *fastRNG.StreamGenerator - jsonParams string + mux sync.RWMutex +} + +// E2e is a subset of functions from the interface e2e.Handler. +type E2e interface { + GetAllPartnerIDs() []*id.ID + GetHistoricalDHPubkey() *cyclic.Int + GetHistoricalDHPrivkey() *cyclic.Int +} + +// Session is a subset of functions from the interface storage.Session. +type Session interface { + GetRegCode() (string, error) + GetTransmissionID() *id.ID + GetTransmissionSalt() []byte + GetReceptionID() *id.ID + GetReceptionSalt() []byte + GetReceptionRSA() *rsa.PrivateKey + GetTransmissionRSA() *rsa.PrivateKey + GetTransmissionRegistrationValidationSignature() []byte + GetReceptionRegistrationValidationSignature() []byte + GetRegistrationTimestamp() time.Time +} + +type UserDiscovery interface { + GetFacts() fact.FactList } // UpdateBackupFn is the callback that encrypted backup data is returned on @@ -68,27 +92,20 @@ type UpdateBackupFn func(encryptedBackup []byte) // Call this to turn on backups for the first time or to replace the user's // password. func InitializeBackup(password string, updateBackupCb UpdateBackupFn, - c *api.Client) (*Backup, error) { - return initializeBackup( - password, updateBackupCb, c, c.GetStorage(), c.GetBackup(), c.GetRng()) -} - -// initializeBackup is a helper function that takes in all the fields for Backup -// as parameters for easier testing. -func initializeBackup(password string, updateBackupCb UpdateBackupFn, - c *api.Client, store *storage.Session, - backupContainer *interfaces.BackupContainer, rng *fastRNG.StreamGenerator) ( - *Backup, error) { + container *Container, e2e E2e, session Session, ud UserDiscovery, + kv *versioned.KV, rng *fastRNG.StreamGenerator) (*Backup, error) { b := &Backup{ - updateBackupCb: updateBackupCb, - client: c, - store: store, - backupContainer: backupContainer, - rng: rng, + updateBackupCb: updateBackupCb, + container: container, + e2e: e2e, + session: session, + ud: ud, + kv: kv, + rng: rng, } // Save password to storage - err := savePassword(password, b.store.GetKV()) + err := savePassword(password, b.kv) if err != nil { return nil, errors.Errorf(errSavePassword, err) } @@ -105,15 +122,15 @@ func initializeBackup(password string, updateBackupCb UpdateBackupFn, key := backup.DeriveKey(password, salt, params) // Save key, salt, and parameters to storage - err = saveBackup(key, salt, params, b.store.GetKV()) + err = saveBackup(key, salt, params, b.kv) if err != nil { return nil, errors.Errorf(errSaveKeySaltParams, err) } // Setting backup trigger in client - b.backupContainer.SetBackup(b.TriggerBackup) + b.container.SetBackup(b.TriggerBackup) - b.TriggerBackup("initializeBackup") + b.TriggerBackup("InitializeBackup") jww.INFO.Print("Initialized backup with new user key.") return b, nil @@ -122,32 +139,27 @@ func initializeBackup(password string, updateBackupCb UpdateBackupFn, // ResumeBackup resumes a backup by restoring the Backup object and registering // a new callback. Call this to resume backups that have already been // initialized. Returns an error if backups have not already been initialized. -func ResumeBackup(updateBackupCb UpdateBackupFn, c *api.Client) (*Backup, error) { - return resumeBackup( - updateBackupCb, c, c.GetStorage(), c.GetBackup(), c.GetRng()) -} - -// resumeBackup is a helper function that takes in all the fields for Backup as -// parameters for easier testing. -func resumeBackup(updateBackupCb UpdateBackupFn, c *api.Client, - store *storage.Session, backupContainer *interfaces.BackupContainer, +func ResumeBackup(updateBackupCb UpdateBackupFn, container *Container, + e2e E2e, session Session, ud UserDiscovery, kv *versioned.KV, rng *fastRNG.StreamGenerator) (*Backup, error) { - _, err := loadPassword(store.GetKV()) + _, err := loadPassword(kv) if err != nil { return nil, errors.Errorf(errLoadPassword, err) } b := &Backup{ - updateBackupCb: updateBackupCb, - client: c, - store: store, - backupContainer: backupContainer, - rng: rng, - jsonParams: loadJson(store.GetKV()), + updateBackupCb: updateBackupCb, + container: container, + jsonParams: loadJson(kv), + e2e: e2e, + session: session, + ud: ud, + kv: kv, + rng: rng, } // Setting backup trigger in client - b.backupContainer.SetBackup(b.TriggerBackup) + b.container.SetBackup(b.TriggerBackup) jww.INFO.Print("Resumed backup with password loaded from storage.") @@ -181,7 +193,7 @@ func (b *Backup) TriggerBackup(reason string) { b.mux.RLock() defer b.mux.RUnlock() - key, salt, params, err := loadBackup(b.store.GetKV()) + key, salt, params, err := loadBackup(b.kv) if err != nil { jww.ERROR.Printf("Backup Failed: could not load key, salt, and "+ "parameters for encrypting backup from storage: %+v", err) @@ -217,7 +229,7 @@ func (b *Backup) AddJson(newJson string) { if newJson != b.jsonParams { b.jsonParams = newJson - if err := storeJson(newJson, b.store.GetKV()); err != nil { + if err := storeJson(newJson, b.kv); err != nil { jww.FATAL.Panicf("Failed to store json: %+v", err) } go b.TriggerBackup("New Json") @@ -231,12 +243,12 @@ func (b *Backup) StopBackup() error { defer b.mux.Unlock() b.updateBackupCb = nil - err := deletePassword(b.store.GetKV()) + err := deletePassword(b.kv) if err != nil { return errors.Errorf(errDeletePassword, err) } - err = deleteBackup(b.store.GetKV()) + err = deleteBackup(b.kv) if err != nil { return errors.Errorf(errDeleteCrypto, err) } @@ -268,42 +280,38 @@ func (b *Backup) assembleBackup() backup.Backup { Contacts: backup.Contacts{}, } - // get user and storage user - u := b.store.GetUser() - su := b.store.User() - - // get registration timestamp - bu.RegistrationTimestamp = u.RegistrationTimestamp + // Get registration timestamp + bu.RegistrationTimestamp = b.session.GetRegistrationTimestamp().UnixNano() - // get registration code; ignore the error because if there is no + // Get registration code; ignore the error because if there is no // registration, then an empty string is returned - bu.RegistrationCode, _ = b.store.GetRegCode() + bu.RegistrationCode, _ = b.session.GetRegCode() - // get transmission identity + // Get transmission identity bu.TransmissionIdentity = backup.TransmissionIdentity{ - RSASigningPrivateKey: u.TransmissionRSA, - RegistrarSignature: su.GetTransmissionRegistrationValidationSignature(), - Salt: u.TransmissionSalt, - ComputedID: u.TransmissionID, + RSASigningPrivateKey: b.session.GetTransmissionRSA(), + RegistrarSignature: b.session.GetTransmissionRegistrationValidationSignature(), + Salt: b.session.GetTransmissionSalt(), + ComputedID: b.session.GetTransmissionID(), } - // get reception identity + // Get reception identity bu.ReceptionIdentity = backup.ReceptionIdentity{ - RSASigningPrivateKey: u.ReceptionRSA, - RegistrarSignature: su.GetReceptionRegistrationValidationSignature(), - Salt: u.ReceptionSalt, - ComputedID: u.ReceptionID, - DHPrivateKey: u.E2eDhPrivateKey, - DHPublicKey: u.E2eDhPublicKey, + RSASigningPrivateKey: b.session.GetReceptionRSA(), + RegistrarSignature: b.session.GetReceptionRegistrationValidationSignature(), + Salt: b.session.GetReceptionSalt(), + ComputedID: b.session.GetReceptionID(), + DHPrivateKey: b.e2e.GetHistoricalDHPrivkey(), + DHPublicKey: b.e2e.GetHistoricalDHPubkey(), } - // get facts - bu.UserDiscoveryRegistration.FactList = b.store.GetUd().GetFacts() + // Get facts + bu.UserDiscoveryRegistration.FactList = b.ud.GetFacts() - // get contacts - bu.Contacts.Identities = b.store.E2e().GetPartners() + // Get contacts + bu.Contacts.Identities = b.e2e.GetAllPartnerIDs() - //add the memoized json params + // Add the memoized json params bu.JSONParams = b.jsonParams return bu diff --git a/backup/backup_test.go b/backup/backup_test.go index 434c76e78665692524849952b599652be428600a..076fc2d8c7451dbe039d60fc1c8cd9edc85204e4 100644 --- a/backup/backup_test.go +++ b/backup/backup_test.go @@ -9,29 +9,30 @@ package backup import ( "bytes" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" "reflect" "strings" "testing" "time" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/crypto/backup" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/xx_network/crypto/csprng" ) -// Tests that Backup.initializeBackup returns a new Backup with a copy of the +// Tests that Backup.InitializeBackup returns a new Backup with a copy of the // key and the callback. -func Test_initializeBackup(t *testing.T) { +func Test_InitializeBackup(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + rngGen := fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG) cbChan := make(chan []byte, 2) cb := func(encryptedBackup []byte) { cbChan <- encryptedBackup } expectedPassword := "MySuperSecurePassword" - b, err := initializeBackup(expectedPassword, cb, nil, - storage.InitTestingSession(t), &interfaces.BackupContainer{}, - fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG)) + b, err := InitializeBackup(expectedPassword, cb, &Container{}, newMockE2e(t), + newMockSession(t), newMockUserDiscovery(), kv, rngGen) if err != nil { - t.Errorf("initializeBackup returned an error: %+v", err) + t.Errorf("InitializeBackup returned an error: %+v", err) } select { @@ -41,7 +42,7 @@ func Test_initializeBackup(t *testing.T) { } // Check that the correct password is in storage - loadedPassword, err := loadPassword(b.store.GetKV()) + loadedPassword, err := loadPassword(b.kv) if err != nil { t.Errorf("Failed to load password: %+v", err) } @@ -51,7 +52,7 @@ func Test_initializeBackup(t *testing.T) { } // Check that the key, salt, and params were saved to storage - key, salt, p, err := loadBackup(b.store.GetKV()) + key, salt, p, err := loadBackup(b.kv) if err != nil { t.Errorf("Failed to load key, salt, and params: %+v", err) } @@ -80,17 +81,17 @@ func Test_initializeBackup(t *testing.T) { } } -// Initialises a new backup and then tests that Backup.resumeBackup overwrites -// the callback but keeps the password. -func Test_resumeBackup(t *testing.T) { +// Initialises a new backup and then tests that ResumeBackup overwrites the +// callback but keeps the password. +func Test_ResumeBackup(t *testing.T) { // Start the first backup + kv := versioned.NewKV(make(ekv.Memstore)) + rngGen := fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG) cbChan1 := make(chan []byte) cb1 := func(encryptedBackup []byte) { cbChan1 <- encryptedBackup } - s := storage.InitTestingSession(t) expectedPassword := "MySuperSecurePassword" - b, err := initializeBackup(expectedPassword, cb1, nil, s, - &interfaces.BackupContainer{}, - fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG)) + b, err := InitializeBackup(expectedPassword, cb1, &Container{}, + newMockE2e(t), newMockSession(t), newMockUserDiscovery(), kv, rngGen) if err != nil { t.Errorf("Failed to initialize new Backup: %+v", err) } @@ -102,7 +103,7 @@ func Test_resumeBackup(t *testing.T) { } // get key and salt to compare to later - key1, salt1, _, err := loadBackup(b.store.GetKV()) + key1, salt1, _, err := loadBackup(b.kv) if err != nil { t.Errorf("Failed to load key, salt, and params from newly "+ "initialized backup: %+v", err) @@ -111,14 +112,14 @@ func Test_resumeBackup(t *testing.T) { // Resume the backup with a new callback cbChan2 := make(chan []byte) cb2 := func(encryptedBackup []byte) { cbChan2 <- encryptedBackup } - b2, err := resumeBackup(cb2, nil, s, &interfaces.BackupContainer{}, - fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG)) + b2, err := ResumeBackup(cb2, &Container{}, newMockE2e(t), newMockSession(t), + newMockUserDiscovery(), kv, rngGen) if err != nil { - t.Errorf("resumeBackup returned an error: %+v", err) + t.Errorf("ResumeBackup returned an error: %+v", err) } // Check that the correct password is in storage - loadedPassword, err := loadPassword(b.store.GetKV()) + loadedPassword, err := loadPassword(b.kv) if err != nil { t.Errorf("Failed to load password: %+v", err) } @@ -128,7 +129,7 @@ func Test_resumeBackup(t *testing.T) { } // get key, salt, and parameters of resumed backup - key2, salt2, _, err := loadBackup(b.store.GetKV()) + key2, salt2, _, err := loadBackup(b.kv) if err != nil { t.Errorf("Failed to load key, salt, and params from resumed "+ "backup: %+v", err) @@ -158,14 +159,16 @@ func Test_resumeBackup(t *testing.T) { } } -// Error path: Tests that Backup.resumeBackup returns an error if no password is +// Error path: Tests that ResumeBackup returns an error if no password is // present in storage. -func Test_resumeBackup_NoKeyError(t *testing.T) { +func Test_ResumeBackup_NoKeyError(t *testing.T) { expectedErr := strings.Split(errLoadPassword, "%")[0] - s := storage.InitTestingSession(t) - _, err := resumeBackup(nil, nil, s, &interfaces.BackupContainer{}, nil) + kv := versioned.NewKV(make(ekv.Memstore)) + rngGen := fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG) + _, err := ResumeBackup(nil, &Container{}, newMockE2e(t), newMockSession(t), + newMockUserDiscovery(), kv, rngGen) if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("resumeBackup did not return the expected error when no "+ + t.Errorf("ResumeBackup did not return the expected error when no "+ "password is present.\nexpected: %s\nreceived: %+v", expectedErr, err) } } @@ -178,7 +181,7 @@ func TestBackup_TriggerBackup(t *testing.T) { b := newTestBackup("MySuperSecurePassword", cb, t) // get password - password, err := loadPassword(b.store.GetKV()) + password, err := loadPassword(b.kv) if err != nil { t.Errorf("Failed to load password from storage: %+v", err) } @@ -215,7 +218,7 @@ func TestBackup_TriggerBackup_NoKey(t *testing.T) { t.Errorf("backup not called") } - err := deleteBackup(b.store.GetKV()) + err := deleteBackup(b.kv) if err != nil { t.Errorf("Failed to delete key, salt, and params: %+v", err) } @@ -260,13 +263,13 @@ func TestBackup_StopBackup(t *testing.T) { } // Make sure password is deleted - password, err := loadPassword(b.store.GetKV()) + password, err := loadPassword(b.kv) if err == nil || len(password) != 0 { t.Errorf("Loaded password that should be deleted: %q", password) } // Make sure key, salt, and params are deleted - key, salt, p, err := loadBackup(b.store.GetKV()) + key, salt, p, err := loadBackup(b.kv) if err == nil || len(key) != 0 || len(salt) != 0 || p != (backup.Params{}) { t.Errorf("Loaded key, salt, and params that should be deleted.") } @@ -296,77 +299,79 @@ func TestBackup_IsBackupRunning(t *testing.T) { func TestBackup_AddJson(t *testing.T) { b := newTestBackup("MySuperSecurePassword", nil, t) - s := b.store + s := b.session.(*mockSession) + e2e := b.e2e.(*mockE2e) json := "{'data': {'one': 1}}" - expectedCollatedBackup := backup.Backup{ - RegistrationTimestamp: s.GetUser().RegistrationTimestamp, + expected := backup.Backup{ + RegistrationCode: s.regCode, + RegistrationTimestamp: s.registrationTimestamp.UnixNano(), TransmissionIdentity: backup.TransmissionIdentity{ - RSASigningPrivateKey: s.GetUser().TransmissionRSA, - RegistrarSignature: s.User().GetTransmissionRegistrationValidationSignature(), - Salt: s.GetUser().TransmissionSalt, - ComputedID: s.GetUser().TransmissionID, + RSASigningPrivateKey: s.transmissionRSA, + RegistrarSignature: s.transmissionRegistrationValidationSignature, + Salt: s.transmissionSalt, + ComputedID: s.transmissionID, }, ReceptionIdentity: backup.ReceptionIdentity{ - RSASigningPrivateKey: s.GetUser().ReceptionRSA, - RegistrarSignature: s.User().GetReceptionRegistrationValidationSignature(), - Salt: s.GetUser().ReceptionSalt, - ComputedID: s.GetUser().ReceptionID, - DHPrivateKey: s.GetUser().E2eDhPrivateKey, - DHPublicKey: s.GetUser().E2eDhPublicKey, + RSASigningPrivateKey: s.receptionRSA, + RegistrarSignature: s.receptionRegistrationValidationSignature, + Salt: s.receptionSalt, + ComputedID: s.receptionID, + DHPrivateKey: e2e.historicalDHPrivkey, + DHPublicKey: e2e.historicalDHPubkey, }, UserDiscoveryRegistration: backup.UserDiscoveryRegistration{ - FactList: s.GetUd().GetFacts(), + FactList: b.ud.(*mockUserDiscovery).facts, }, - Contacts: backup.Contacts{Identities: s.E2e().GetPartners()}, + Contacts: backup.Contacts{Identities: e2e.partnerIDs}, JSONParams: json, } b.AddJson(json) collatedBackup := b.assembleBackup() - if !reflect.DeepEqual(expectedCollatedBackup, collatedBackup) { + if !reflect.DeepEqual(expected, collatedBackup) { t.Errorf("Collated backup does not match expected."+ - "\nexpected: %+v\nreceived: %+v", - expectedCollatedBackup, collatedBackup) + "\nexpected: %+v\nreceived: %+v", expected, collatedBackup) } } func TestBackup_AddJson_badJson(t *testing.T) { b := newTestBackup("MySuperSecurePassword", nil, t) - s := b.store + s := b.session.(*mockSession) + e2e := b.e2e.(*mockE2e) json := "abc{'i'm a bad json: 'one': 1'''}}" - expectedCollatedBackup := backup.Backup{ - RegistrationTimestamp: s.GetUser().RegistrationTimestamp, + expected := backup.Backup{ + RegistrationCode: s.regCode, + RegistrationTimestamp: s.registrationTimestamp.UnixNano(), TransmissionIdentity: backup.TransmissionIdentity{ - RSASigningPrivateKey: s.GetUser().TransmissionRSA, - RegistrarSignature: s.User().GetTransmissionRegistrationValidationSignature(), - Salt: s.GetUser().TransmissionSalt, - ComputedID: s.GetUser().TransmissionID, + RSASigningPrivateKey: s.transmissionRSA, + RegistrarSignature: s.transmissionRegistrationValidationSignature, + Salt: s.transmissionSalt, + ComputedID: s.transmissionID, }, ReceptionIdentity: backup.ReceptionIdentity{ - RSASigningPrivateKey: s.GetUser().ReceptionRSA, - RegistrarSignature: s.User().GetReceptionRegistrationValidationSignature(), - Salt: s.GetUser().ReceptionSalt, - ComputedID: s.GetUser().ReceptionID, - DHPrivateKey: s.GetUser().E2eDhPrivateKey, - DHPublicKey: s.GetUser().E2eDhPublicKey, + RSASigningPrivateKey: s.receptionRSA, + RegistrarSignature: s.receptionRegistrationValidationSignature, + Salt: s.receptionSalt, + ComputedID: s.receptionID, + DHPrivateKey: e2e.historicalDHPrivkey, + DHPublicKey: e2e.historicalDHPubkey, }, UserDiscoveryRegistration: backup.UserDiscoveryRegistration{ - FactList: s.GetUd().GetFacts(), + FactList: b.ud.(*mockUserDiscovery).facts, }, - Contacts: backup.Contacts{Identities: s.E2e().GetPartners()}, + Contacts: backup.Contacts{Identities: e2e.partnerIDs}, JSONParams: json, } b.AddJson(json) collatedBackup := b.assembleBackup() - if !reflect.DeepEqual(expectedCollatedBackup, collatedBackup) { + if !reflect.DeepEqual(expected, collatedBackup) { t.Errorf("Collated backup does not match expected."+ - "\nexpected: %+v\nreceived: %+v", - expectedCollatedBackup, collatedBackup) + "\nexpected: %+v\nreceived: %+v", expected, collatedBackup) } } @@ -374,47 +379,51 @@ func TestBackup_AddJson_badJson(t *testing.T) { // results. func TestBackup_assembleBackup(t *testing.T) { b := newTestBackup("MySuperSecurePassword", nil, t) - s := b.store + s := b.session.(*mockSession) + e2e := b.e2e.(*mockE2e) - expectedCollatedBackup := backup.Backup{ - RegistrationTimestamp: s.GetUser().RegistrationTimestamp, + expected := backup.Backup{ + RegistrationCode: s.regCode, + RegistrationTimestamp: s.registrationTimestamp.UnixNano(), TransmissionIdentity: backup.TransmissionIdentity{ - RSASigningPrivateKey: s.GetUser().TransmissionRSA, - RegistrarSignature: s.User().GetTransmissionRegistrationValidationSignature(), - Salt: s.GetUser().TransmissionSalt, - ComputedID: s.GetUser().TransmissionID, + RSASigningPrivateKey: s.transmissionRSA, + RegistrarSignature: s.transmissionRegistrationValidationSignature, + Salt: s.transmissionSalt, + ComputedID: s.transmissionID, }, ReceptionIdentity: backup.ReceptionIdentity{ - RSASigningPrivateKey: s.GetUser().ReceptionRSA, - RegistrarSignature: s.User().GetReceptionRegistrationValidationSignature(), - Salt: s.GetUser().ReceptionSalt, - ComputedID: s.GetUser().ReceptionID, - DHPrivateKey: s.GetUser().E2eDhPrivateKey, - DHPublicKey: s.GetUser().E2eDhPublicKey, + RSASigningPrivateKey: s.receptionRSA, + RegistrarSignature: s.receptionRegistrationValidationSignature, + Salt: s.receptionSalt, + ComputedID: s.receptionID, + DHPrivateKey: e2e.historicalDHPrivkey, + DHPublicKey: e2e.historicalDHPubkey, }, UserDiscoveryRegistration: backup.UserDiscoveryRegistration{ - FactList: s.GetUd().GetFacts(), + FactList: b.ud.(*mockUserDiscovery).facts, }, - Contacts: backup.Contacts{Identities: s.E2e().GetPartners()}, + Contacts: backup.Contacts{Identities: e2e.partnerIDs}, } collatedBackup := b.assembleBackup() - if !reflect.DeepEqual(expectedCollatedBackup, collatedBackup) { + if !reflect.DeepEqual(expected, collatedBackup) { t.Errorf("Collated backup does not match expected."+ "\nexpected: %+v\nreceived: %+v", - expectedCollatedBackup, collatedBackup) + expected, collatedBackup) } } // newTestBackup creates a new Backup for testing. func newTestBackup(password string, cb UpdateBackupFn, t *testing.T) *Backup { - b, err := initializeBackup( + b, err := InitializeBackup( password, cb, - nil, - storage.InitTestingSession(t), - &interfaces.BackupContainer{}, + &Container{}, + newMockE2e(t), + newMockSession(t), + newMockUserDiscovery(), + versioned.NewKV(make(ekv.Memstore)), fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), ) if err != nil { diff --git a/interfaces/backup.go b/backup/container.go similarity index 83% rename from interfaces/backup.go rename to backup/container.go index 559b4b0f8756ee772aba9064757601d2ade417d6..fd456ea96c11337c4709991c97f6b84458adfe11 100644 --- a/interfaces/backup.go +++ b/backup/container.go @@ -5,14 +5,14 @@ // LICENSE file // //////////////////////////////////////////////////////////////////////////////// -package interfaces +package backup import "sync" type TriggerBackup func(reason string) -// BackupContainer contains the trigger to call to initiate a backup. -type BackupContainer struct { +// Container contains the trigger to call to initiate a backup. +type Container struct { triggerBackup TriggerBackup mux sync.RWMutex } @@ -22,7 +22,7 @@ type BackupContainer struct { // should be in the paste tense. For example, if a contact is deleted, the // reason can be "contact deleted" and the log will show: // Triggering backup: contact deleted -func (bc *BackupContainer) TriggerBackup(reason string) { +func (bc *Container) TriggerBackup(reason string) { bc.mux.RLock() defer bc.mux.RUnlock() if bc.triggerBackup != nil { @@ -32,7 +32,7 @@ func (bc *BackupContainer) TriggerBackup(reason string) { // SetBackup sets the backup trigger function which will cause a backup to start // on the next event that triggers is. -func (bc *BackupContainer) SetBackup(triggerBackup TriggerBackup) { +func (bc *Container) SetBackup(triggerBackup TriggerBackup) { bc.mux.Lock() defer bc.mux.Unlock() diff --git a/backup/utils_test.go b/backup/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c16969ff89da1f08b29574dba9f20e9da2e4fb44 --- /dev/null +++ b/backup/utils_test.go @@ -0,0 +1,156 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package backup + +import ( + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/elixxir/primitives/fact" + "gitlab.com/xx_network/crypto/large" + "gitlab.com/xx_network/crypto/signature/rsa" + "gitlab.com/xx_network/primitives/id" + "testing" + "time" +) + +// Adheres to the E2e interface. +type mockE2e struct { + partnerIDs []*id.ID + historicalDHPubkey *cyclic.Int + historicalDHPrivkey *cyclic.Int +} + +func newMockE2e(t *testing.T) *mockE2e { + grp := cyclic.NewGroup(large.NewInt(173), large.NewInt(0)) + return &mockE2e{ + partnerIDs: []*id.ID{ + id.NewIdFromString("partner1", id.User, t), + id.NewIdFromString("partner2", id.User, t), + id.NewIdFromString("partner3", id.User, t), + }, + historicalDHPubkey: grp.NewInt(45), + historicalDHPrivkey: grp.NewInt(46), + } +} +func (m *mockE2e) GetAllPartnerIDs() []*id.ID { return m.partnerIDs } +func (m *mockE2e) GetHistoricalDHPubkey() *cyclic.Int { return m.historicalDHPubkey } +func (m *mockE2e) GetHistoricalDHPrivkey() *cyclic.Int { return m.historicalDHPrivkey } + +// Adheres to the Session interface. +type mockSession struct { + regCode string + transmissionID *id.ID + transmissionSalt []byte + receptionID *id.ID + receptionSalt []byte + receptionRSA *rsa.PrivateKey + transmissionRSA *rsa.PrivateKey + transmissionRegistrationValidationSignature []byte + receptionRegistrationValidationSignature []byte + registrationTimestamp time.Time +} + +func newMockSession(t *testing.T) *mockSession { + receptionRSA, _ := rsa.LoadPrivateKeyFromPem([]byte(privKey)) + transmissionRSA, _ := rsa.LoadPrivateKeyFromPem([]byte(privKey)) + + return &mockSession{ + regCode: "regCode", + transmissionID: id.NewIdFromString("transmission", id.User, t), + transmissionSalt: []byte("transmissionSalt"), + receptionID: id.NewIdFromString("reception", id.User, t), + receptionSalt: []byte("receptionSalt"), + receptionRSA: receptionRSA, + transmissionRSA: transmissionRSA, + transmissionRegistrationValidationSignature: []byte("transmissionSig"), + receptionRegistrationValidationSignature: []byte("receptionSig"), + registrationTimestamp: time.Date(2012, 12, 21, 22, 8, 41, 0, time.UTC), + } + +} +func (m mockSession) GetRegCode() (string, error) { return m.regCode, nil } +func (m mockSession) GetTransmissionID() *id.ID { return m.transmissionID } +func (m mockSession) GetTransmissionSalt() []byte { return m.transmissionSalt } +func (m mockSession) GetReceptionID() *id.ID { return m.receptionID } +func (m mockSession) GetReceptionSalt() []byte { return m.receptionSalt } +func (m mockSession) GetReceptionRSA() *rsa.PrivateKey { return m.receptionRSA } +func (m mockSession) GetTransmissionRSA() *rsa.PrivateKey { return m.transmissionRSA } +func (m mockSession) GetTransmissionRegistrationValidationSignature() []byte { + return m.transmissionRegistrationValidationSignature +} +func (m mockSession) GetReceptionRegistrationValidationSignature() []byte { + return m.receptionRegistrationValidationSignature +} +func (m mockSession) GetRegistrationTimestamp() time.Time { return m.registrationTimestamp } + +// Adheres to the UserDiscovery interface. +type mockUserDiscovery struct { + facts fact.FactList +} + +func newMockUserDiscovery() *mockUserDiscovery { + return &mockUserDiscovery{facts: fact.FactList{ + {"myUserName", fact.Username}, + {"hello@example.com", fact.Email}, + {"6175555212", fact.Phone}, + {"name", fact.Nickname}, + }} +} +func (m mockUserDiscovery) GetFacts() fact.FactList { return m.facts } + +const privKey = `-----BEGIN PRIVATE KEY----- +MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQC7Dkb6VXFn4cdp +U0xh6ji0nTDQUyT9DSNW9I3jVwBrWfqMc4ymJuonMZbuqK+cY2l+suS2eugevWZr +tzujFPBRFp9O14Jl3fFLfvtjZvkrKbUMHDHFehascwzrp3tXNryiRMmCNQV55TfI +TVCv8CLE0t1ibiyOGM9ZWYB2OjXt59j76lPARYww5qwC46vS6+3Cn2Yt9zkcrGes +kWEFa2VttHqF910TP+DZk2R5C7koAh6wZYK6NQ4S83YQurdHAT51LKGrbGehFKXq +6/OAXCU1JLi3kW2PovTb6MZuvxEiRmVAONsOcXKu7zWCmFjuZZwfRt2RhnpcSgzf +rarmsGM0LZh6JY3MGJ9YdPcVGSz+Vs2E4zWbNW+ZQoqlcGeMKgsIiQ670g0xSjYI +Cqldpt79gaET9PZsoXKEmKUaj6pq1d4qXDk7s63HRQazwVLGBdJQK8qX41eCdR8V +MKbrCaOkzD5zgnEu0jBBAwdMtcigkMIk1GRv91j7HmqwryOBHryLi6NWBY3tjb4S +o9AppDQB41SH3SwNenAbNO1CXeUqN0hHX6I1bE7OlbjqI7tXdrTllHAJTyVVjenP +el2ApMXp+LVRdDbKtwBiuM6+n+z0I7YYerxN1gfvpYgcXm4uye8dfwotZj6H2J/u +SALsU2v9UHBzprdrLSZk2YpozJb+CQIDAQABAoICAARjDFUYpeU6zVNyCauOM7BA +s4FfQdHReg+zApTfWHosDQ04NIc9CGbM6e5E9IFlb3byORzyevkllf5WuMZVWmF8 +d1YBBeTftKYBn2Gwa42Ql9dl3eD0wQ1gUWBBeEoOVZQ0qskr9ynpr0o6TfciWZ5m +F50UWmUmvc4ppDKhoNwogNU/pKEwwF3xOv2CW2hB8jyLQnk3gBZlELViX3UiFKni +/rCfoYYvDFXt+ABCvx/qFNAsQUmerurQ3Ob9igjXRaC34D7F9xQ3CMEesYJEJvc9 +Gjvr5DbnKnjx152HS56TKhK8gp6vGHJz17xtWECXD3dIUS/1iG8bqXuhdg2c+2aW +m3MFpa5jgpAawUWc7c32UnqbKKf+HI7/x8J1yqJyNeU5SySyYSB5qtwTShYzlBW/ +yCYD41edeJcmIp693nUcXzU+UAdtpt0hkXS59WSWlTrB/huWXy6kYXLNocNk9L7g +iyx0cOmkuxREMHAvK0fovXdVyflQtJYC7OjJxkzj2rWO+QtHaOySXUyinkuTb5ev +xNhs+ROWI/HAIE9buMqXQIpHx6MSgdKOL6P6AEbBan4RAktkYA6y5EtH/7x+9V5E +QTIz4LrtI6abaKb4GUlZkEsc8pxrkNwCqOAE/aqEMNh91Na1TOj3f0/a6ckGYxYH +pyrvwfP2Ouu6e5FhDcCBAoIBAQDcN8mK99jtrH3q3Q8vZAWFXHsOrVvnJXyHLz9V +1Rx/7TnMUxvDX1PIVxhuJ/tmHtxrNIXOlps80FCZXGgxfET/YFrbf4H/BaMNJZNP +ag1wBV5VQSnTPdTR+Ijice+/ak37S2NKHt8+ut6yoZjD7sf28qiO8bzNua/OYHkk +V+RkRkk68Uk2tFMluQOSyEjdsrDNGbESvT+R1Eotupr0Vy/9JRY/TFMc4MwJwOoy +s7wYr9SUCq/cYn7FIOBTI+PRaTx1WtpfkaErDc5O+nLLEp1yOrfktl4LhU/r61i7 +fdtafUACTKrXG2qxTd3w++mHwTwVl2MwhiMZfxvKDkx0L2gxAoIBAQDZcxKwyZOy +s6Aw7igw1ftLny/dpjPaG0p6myaNpeJISjTOU7HKwLXmlTGLKAbeRFJpOHTTs63y +gcmcuE+vGCpdBHQkaCev8cve1urpJRcxurura6+bYaENO6ua5VzF9BQlDYve0YwY +lbJiRKmEWEAyULjbIebZW41Z4UqVG3MQI750PRWPW4WJ2kDhksFXN1gwSnaM46KR +PmVA0SL+RCPcAp/VkImCv0eqv9exsglY0K/QiJfLy3zZ8QvAn0wYgZ3AvH3lr9rJ +T7pg9WDb+OkfeEQ7INubqSthhaqCLd4zwbMRlpyvg1cMSq0zRvrFpwVlSY85lW4F +g/tgjJ99W9VZAoIBAH3OYRVDAmrFYCoMn+AzA/RsIOEBqL8kaz/Pfh9K4D01CQ/x +aqryiqqpFwvXS4fLmaClIMwkvgq/90ulvuCGXeSG52D+NwW58qxQCxgTPhoA9yM9 +VueXKz3I/mpfLNftox8sskxl1qO/nfnu15cXkqVBe4ouD+53ZjhAZPSeQZwHi05h +CbJ20gl66M+yG+6LZvXE96P8+ZQV80qskFmGdaPozAzdTZ3xzp7D1wegJpTz3j20 +3ULKAiIb5guZNU0tEZz5ikeOqsQt3u6/pVTeDZR0dxnyFUf/oOjmSorSG75WT3sA +0ZiR0SH5mhFR2Nf1TJ4JHmFaQDMQqo+EG6lEbAECggEAA7kGnuQ0lSCiI3RQV9Wy +Aa9uAFtyE8/XzJWPaWlnoFk04jtoldIKyzHOsVU0GOYOiyKeTWmMFtTGANre8l51 +izYiTuVBmK+JD/2Z8/fgl8dcoyiqzvwy56kX3QUEO5dcKO48cMohneIiNbB7PnrM +TpA3OfkwnJQGrX0/66GWrLYP8qmBDv1AIgYMilAa40VdSyZbNTpIdDgfP6bU9Ily +G7gnyF47HHPt5Cx4ouArbMvV1rof7ytCrfCEhP21Lc46Ryxy81W5ZyzoQfSxfdKb +GyDR+jkryVRyG69QJf5nCXfNewWbFR4ohVtZ78DNVkjvvLYvr4qxYYLK8PI3YMwL +sQKCAQB9lo7JadzKVio+C18EfNikOzoriQOaIYowNaaGDw3/9KwIhRsKgoTs+K5O +gt/gUoPRGd3M2z4hn5j4wgeuFi7HC1MdMWwvgat93h7R1YxiyaOoCTxH1klbB/3K +4fskdQRxuM8McUebebrp0qT5E0xs2l+ABmt30Dtd3iRrQ5BBjnRc4V//sQiwS1aC +Yi5eNYCQ96BSAEo1dxJh5RI/QxF2HEPUuoPM8iXrIJhyg9TEEpbrEJcxeagWk02y +OMEoUbWbX07OzFVvu+aJaN/GlgiogMQhb6IiNTyMlryFUleF+9OBA8xGHqGWA6nR +OaRA5ZbdE7g7vxKRV36jT3wvD7W+ +-----END PRIVATE KEY-----` diff --git a/bindings/fileTransfer.go b/bindings/fileTransfer.go index 74d911483de71b990ae05ee5b3b0a281ecc528f1..9e2d76169cd393422039b06de483905bd3a0de76 100644 --- a/bindings/fileTransfer.go +++ b/bindings/fileTransfer.go @@ -10,7 +10,6 @@ package bindings import ( "encoding/json" ft "gitlab.com/elixxir/client/fileTransfer" - "gitlab.com/elixxir/client/interfaces" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/xx_network/primitives/id" "time" @@ -18,23 +17,23 @@ import ( // FileTransfer contains the file transfer manager. type FileTransfer struct { - m *ft.Manager + m ft.FileTransfer } // FileTransferSentProgressFunc contains a function callback that tracks the -// progress of sending a file. It is called when a file part is sent, a file -// part arrives, the transfer completes, or on error. +// progress of sending a file. It is called when a file part arrives, the +// transfer completes, or on error. type FileTransferSentProgressFunc interface { - SentProgressCallback(completed bool, sent, arrived, total int, - t *FilePartTracker, err error) + SentProgressCallback( + completed bool, arrived, total int, t *FilePartTracker, err error) } // FileTransferReceivedProgressFunc contains a function callback that tracks the // progress of receiving a file. It is called when a file part is received, the // transfer completes, or on error. type FileTransferReceivedProgressFunc interface { - ReceivedProgressCallback(completed bool, received, total int, - t *FilePartTracker, err error) + ReceivedProgressCallback( + completed bool, received, total int, t *FilePartTracker, err error) } // FileTransferReceiveFunc contains a function callback that notifies the @@ -49,7 +48,7 @@ type FileTransferReceiveFunc interface { // sending and receiving threads. The receiveFunc is called everytime a new file // transfer is received. // The parameters string contains file transfer network configuration options -// and is a JSON formatted string of the fileTransfer.Params object. If it is +// and is a JSON formatted string of the fileTransfer2.Params object. If it is // left empty, then defaults are used. It is highly recommended that defaults // are used. If it is set, it must match the following format: // {"MaxThroughput":150000,"SendTimeout":500000000} @@ -57,7 +56,7 @@ type FileTransferReceiveFunc interface { func NewFileTransferManager(client *Client, receiveFunc FileTransferReceiveFunc, parameters string) (*FileTransfer, error) { - receiveCB := func(tid ftCrypto.TransferID, fileName, fileType string, + receiveCB := func(tid *ftCrypto.TransferID, fileName, fileType string, sender *id.ID, size uint32, preview []byte) { receiveFunc.ReceiveCallback( tid.Bytes(), fileName, fileType, sender.Bytes(), int(size), preview) @@ -73,7 +72,8 @@ func NewFileTransferManager(client *Client, receiveFunc FileTransferReceiveFunc, } // Create new file transfer manager - m, err := ft.NewManager(&client.api, receiveCB, p) + // TODO: Fix NewManager parameters + m, err := ft.NewManager(receiveCB, p, nil, nil, nil, nil, nil) if err != nil { return nil, err } @@ -107,10 +107,10 @@ func (f *FileTransfer) Send(fileName, fileType string, fileData []byte, progressFunc FileTransferSentProgressFunc, periodMS int) ([]byte, error) { // Create SentProgressCallback - progressCB := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { + progressCB := func(completed bool, arrived, total uint16, + t ft.FilePartTracker, err error) { progressFunc.SentProgressCallback( - completed, int(sent), int(arrived), int(total), &FilePartTracker{t}, err) + completed, int(arrived), int(total), &FilePartTracker{t}, err) } // Convert recipient ID bytes to id.ID @@ -149,16 +149,16 @@ func (f *FileTransfer) RegisterSendProgressCallback(transferID []byte, tid := ftCrypto.UnmarshalTransferID(transferID) // Create SentProgressCallback - progressCB := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { + progressCB := func(completed bool, arrived, total uint16, + t ft.FilePartTracker, err error) { progressFunc.SentProgressCallback( - completed, int(sent), int(arrived), int(total), &FilePartTracker{t}, err) + completed, int(arrived), int(total), &FilePartTracker{t}, err) } // Convert period to time.Duration period := time.Duration(periodMS) * time.Millisecond - return f.m.RegisterSentProgressCallback(tid, progressCB, period) + return f.m.RegisterSentProgressCallback(&tid, progressCB, period) } // Resend resends a file if sending fails. This function should only be called @@ -178,7 +178,7 @@ func (f *FileTransfer) CloseSend(transferID []byte) error { // Unmarshal transfer ID tid := ftCrypto.UnmarshalTransferID(transferID) - return f.m.CloseSend(tid) + return f.m.CloseSend(&tid) } // Receive returns the fully assembled file on the completion of the transfer. @@ -189,7 +189,7 @@ func (f *FileTransfer) Receive(transferID []byte) ([]byte, error) { // Unmarshal transfer ID tid := ftCrypto.UnmarshalTransferID(transferID) - return f.m.Receive(tid) + return f.m.Receive(&tid) } // RegisterReceiveProgressCallback allows for the registration of a callback to @@ -209,7 +209,7 @@ func (f *FileTransfer) RegisterReceiveProgressCallback(transferID []byte, // Create ReceivedProgressCallback progressCB := func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { + t ft.FilePartTracker, err error) { progressFunc.ReceivedProgressCallback( completed, int(received), int(total), &FilePartTracker{t}, err) } @@ -217,7 +217,7 @@ func (f *FileTransfer) RegisterReceiveProgressCallback(transferID []byte, // Convert period to time.Duration period := time.Duration(periodMS) * time.Millisecond - return f.m.RegisterReceivedProgressCallback(tid, progressCB, period) + return f.m.RegisterReceivedProgressCallback(&tid, progressCB, period) } //////////////////////////////////////////////////////////////////////////////// @@ -253,7 +253,7 @@ func (f *FileTransfer) GetMaxFileSize() int { // FilePartTracker contains the interfaces.FilePartTracker. type FilePartTracker struct { - m interfaces.FilePartTracker + m ft.FilePartTracker } // GetPartStatus returns the status of the file part with the given part number. diff --git a/cmd/fileTransfer.go b/cmd/fileTransfer.go index 6892a9cb9538673ee7351bcd0c74ed439d11bbf1..3241d67cad45472c324226ca2f58678bbe8d8dc0 100644 --- a/cmd/fileTransfer.go +++ b/cmd/fileTransfer.go @@ -14,7 +14,6 @@ import ( "github.com/spf13/viper" "gitlab.com/elixxir/client/api" ft "gitlab.com/elixxir/client/fileTransfer" - "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/crypto/contact" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/xx_network/primitives/id" @@ -117,7 +116,7 @@ var ftCmd = &cobra.Command{ // receivedFtResults is used to return received new file transfer results on a // channel from a callback. type receivedFtResults struct { - tid ftCrypto.TransferID + tid *ftCrypto.TransferID fileName string fileType string sender *id.ID @@ -129,11 +128,11 @@ type receivedFtResults struct { // reception callback. Returns the file transfer manager and the channel that // will be triggered when the callback is called. func initFileTransferManager(client *api.Client, maxThroughput int) ( - *ft.Manager, chan receivedFtResults) { + ft.FileTransfer, chan receivedFtResults) { // Create interfaces.ReceiveCallback that returns the results on a channel receiveChan := make(chan receivedFtResults, 100) - receiveCB := func(tid ftCrypto.TransferID, fileName, fileType string, + receiveCB := func(tid *ftCrypto.TransferID, fileName, fileType string, sender *id.ID, size uint32, preview []byte) { receiveChan <- receivedFtResults{ tid, fileName, fileType, sender, size, preview} @@ -147,7 +146,8 @@ func initFileTransferManager(client *api.Client, maxThroughput int) ( } // Create new manager - manager, err := ft.NewManager(client, receiveCB, p) + // TODO: Fix NewManager parameters + manager, err := ft.NewManager(receiveCB, p, nil, nil, nil, nil, nil) if err != nil { jww.FATAL.Panicf( "[FT] Failed to create new file transfer manager: %+v", err) @@ -164,7 +164,7 @@ func initFileTransferManager(client *api.Client, maxThroughput int) ( // sendFile sends the file to the recipient and prints the progress. func sendFile(filePath, fileType, filePreviewPath, filePreviewString, - recipientContactPath string, retry float32, m *ft.Manager, + recipientContactPath string, retry float32, m ft.FileTransfer, done chan struct{}) { // Get file from path @@ -200,16 +200,16 @@ func sendFile(filePath, fileType, filePreviewPath, filePreviewString, var sendStart time.Time // Create sent progress callback that prints the results - progressCB := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { + progressCB := func(completed bool, arrived, total uint16, + t ft.FilePartTracker, err error) { jww.INFO.Printf("[FT] Sent progress callback for %q "+ - "{completed: %t, sent: %d, arrived: %d, total: %d, err: %v}", - fileName, completed, sent, arrived, total, err) - if (sent == 0 && arrived == 0) || (arrived == total) || completed || + "{completed: %t, arrived: %d, total: %d, err: %v}", + fileName, completed, arrived, total, err) + if arrived == 0 || (arrived == total) || completed || err != nil { fmt.Printf("Sent progress callback for %q "+ - "{completed: %t, sent: %d, arrived: %d, total: %d, err: %v}\n", - fileName, completed, sent, arrived, total, err) + "{completed: %t, arrived: %d, total: %d, err: %v}\n", + fileName, completed, arrived, total, err) } if completed { @@ -247,7 +247,7 @@ func sendFile(filePath, fileType, filePreviewPath, filePreviewString, // receiveNewFileTransfers waits to receive new file transfers and prints its // information to the log. func receiveNewFileTransfers(receive chan receivedFtResults, done, - quit chan struct{}, m *ft.Manager) { + quit chan struct{}, m ft.FileTransfer) { jww.INFO.Print("[FT] Starting thread waiting to receive NewFileTransfer " + "E2E message.") for { @@ -276,11 +276,11 @@ func receiveNewFileTransfers(receive chan receivedFtResults, done, // newReceiveProgressCB creates a new reception progress callback that prints // the results to the log. -func newReceiveProgressCB(tid ftCrypto.TransferID, fileName string, +func newReceiveProgressCB(tid *ftCrypto.TransferID, fileName string, done chan struct{}, receiveStart time.Time, - m *ft.Manager) interfaces.ReceivedProgressCallback { + m ft.FileTransfer) ft.ReceivedProgressCallback { return func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { + t ft.FilePartTracker, err error) { jww.INFO.Printf("[FT] Receive progress callback for transfer %s "+ "{completed: %t, received: %d, total: %d, err: %v}", tid, completed, received, total, err) @@ -333,7 +333,7 @@ func getContactFromFile(path string) contact.Contact { // init initializes commands and flags for Cobra. func init() { ftCmd.Flags().String("sendFile", "", - "Sends a file to a recipient with with the contact file at this path.") + "Sends a file to a recipient with the contact file at this path.") bindPFlagCheckErr("sendFile") ftCmd.Flags().String("filePath", "", diff --git a/cmix/client.go b/cmix/client.go index 76fd6377fdee71f9cbe2ffc6fcdb6fdcd8197640..1dd7cff35662ce48e696bd66e06c4ce4b8db4185 100644 --- a/cmix/client.go +++ b/cmix/client.go @@ -11,6 +11,11 @@ package cmix // and intra-client state are accessible through the context object. import ( + "math" + "strconv" + "sync/atomic" + "time" + "github.com/pkg/errors" "gitlab.com/elixxir/client/cmix/address" "gitlab.com/elixxir/client/cmix/gateway" @@ -30,10 +35,6 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/ndf" - "math" - "strconv" - "sync/atomic" - "time" ) // fakeIdentityRange indicates the range generated between 0 (most current) and @@ -81,7 +82,7 @@ type client struct { verboseRounds *RoundTracker // Event reporting API - events event.Manager + events event.Reporter // Storage of the max message length maxMsgLen int @@ -90,11 +91,12 @@ type client struct { // NewClient builds a new reception client object using inputted key fields. func NewClient(params Params, comms *commClient.Comms, session storage.Session, ndf *ndf.NetworkDefinition, rng *fastRNG.StreamGenerator, - events event.Manager) (Client, error) { + events event.Reporter) (Client, error) { // Start network instance instance, err := commNetwork.NewInstance( - comms.ProtoComms, ndf, nil, nil, commNetwork.None, params.FastPolling) + comms.ProtoComms, ndf, nil, nil, commNetwork.None, + params.FastPolling) if err != nil { return nil, errors.WithMessage( err, "failed to create network client") diff --git a/cmix/interface.go b/cmix/interface.go index 89320029719c5f9be6465c26254c0d44869a5f8a..f6788924f5a1064b235ec3c6b99da5ff7f6ead5b 100644 --- a/cmix/interface.go +++ b/cmix/interface.go @@ -172,7 +172,7 @@ type Client interface { // TrackServices registers a callback that will get called every time a // service is added or removed. It will receive the triggers list every time // it is modified. It will only get callbacks while the network follower is - // running. Multiple trackTriggers can be registered + // running. Multiple trackTriggers can be registered. TrackServices(tracker message.ServicesTracker) /* === In inProcess ===================================================== */ diff --git a/cmix/message/pickup.go b/cmix/message/pickup.go index acba1c0382c1190c8f7524dbdba40852dfa7e415..29fbb95951aea7f4cffbd1c96ba9919ab0d27ee0 100644 --- a/cmix/message/pickup.go +++ b/cmix/message/pickup.go @@ -8,11 +8,12 @@ package message import ( + "strconv" + "gitlab.com/elixxir/client/event" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" - "strconv" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/stoppable" @@ -47,13 +48,13 @@ type handler struct { inProcess *MeteredCmixMessageBuffer - events event.Manager + events event.Reporter FingerprintsManager ServicesManager } -func NewHandler(param Params, kv *versioned.KV, events event.Manager, +func NewHandler(param Params, kv *versioned.KV, events event.Reporter, standardID *id.ID) Handler { garbled, err := NewOrLoadMeteredCmixMessageBuffer(kv, inProcessKey) diff --git a/cmix/nodes/register.go b/cmix/nodes/register.go index 9cd4502f04a1e9fd16269df09c07f0eb31a1d727..af18976e4402845ea513aae0e8ae400f6683bbb5 100644 --- a/cmix/nodes/register.go +++ b/cmix/nodes/register.go @@ -10,6 +10,10 @@ package nodes import ( "crypto/sha256" "encoding/hex" + "strconv" + "sync" + "time" + "github.com/golang/protobuf/proto" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" @@ -31,9 +35,6 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/ndf" "gitlab.com/xx_network/primitives/netTime" - "strconv" - "sync" - "time" ) func registerNodes(r *registrar, s storage.Session, stop *stoppable.Single, @@ -155,10 +156,9 @@ func requestKey(sender gateway.Sender, comms RegisterNodeCommsInterface, grp := r.session.GetCmixGroup() - // FIXME: Why 256 bits? -- this is spec but not explained, it has to do with - // optimizing operations on one side and still preserves decent security -- - // cite this. - dhPrivBytes, err := csprng.GenerateInGroup(grp.GetPBytes(), 256, rng) + prime := grp.GetPBytes() + keyLen := len(prime) + dhPrivBytes, err := csprng.GenerateInGroup(prime, keyLen, rng) if err != nil { return nil, nil, 0, err } diff --git a/cmix/params.go b/cmix/params.go index bfd0cd1a5011fabd45c06d971da6ddd7112f7527..d823632e199e021606a3414bc2403babcc3d17b9 100644 --- a/cmix/params.go +++ b/cmix/params.go @@ -3,13 +3,14 @@ package cmix import ( "encoding/base64" "encoding/json" + "time" + "gitlab.com/elixxir/client/cmix/message" "gitlab.com/elixxir/client/cmix/pickup" "gitlab.com/elixxir/client/cmix/rounds" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/primitives/excludedRounds" "gitlab.com/xx_network/primitives/id" - "time" ) type Params struct { diff --git a/cmix/rounds/historical.go b/cmix/rounds/historical.go index 98522670faf6a2635e37d300f5df25db4c451b75..6b4ffbcd727c9b2e8c70b53769fb44c46761dca7 100644 --- a/cmix/rounds/historical.go +++ b/cmix/rounds/historical.go @@ -40,7 +40,7 @@ type manager struct { comms Comms sender gateway.Sender - events event.Manager + events event.Reporter c chan roundRequest } @@ -63,7 +63,7 @@ type roundRequest struct { } func NewRetriever(param Params, comms Comms, sender gateway.Sender, - events event.Manager) Retriever { + events event.Reporter) Retriever { return &manager{ params: param, comms: comms, @@ -200,7 +200,7 @@ func (m *manager) processHistoricalRounds(comm Comms, stop *stoppable.Single) { } func processHistoricalRoundsResponse(response *pb.HistoricalRoundsResponse, - roundRequests []roundRequest, maxRetries uint, events event.Manager) ( + roundRequests []roundRequest, maxRetries uint, events event.Reporter) ( []uint64, []roundRequest) { retries := make([]roundRequest, 0) rids := make([]uint64, 0) diff --git a/cmix/sendCmix.go b/cmix/sendCmix.go index 8b1737e0f34b1404220c8a4c4915193536956121..e7267460c916203c529387cb0d7ceb5eb59c4095 100644 --- a/cmix/sendCmix.go +++ b/cmix/sendCmix.go @@ -9,6 +9,9 @@ package cmix import ( "fmt" + "strings" + "time" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/cmix/gateway" @@ -28,8 +31,6 @@ import ( "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/netTime" "gitlab.com/xx_network/primitives/rateLimiting" - "strings" - "time" ) // Send sends a "raw" cMix message payload to the provided recipient. @@ -95,7 +96,7 @@ func (c *client) Send(recipient *id.ID, fingerprint format.Fingerprint, // status. func sendCmixHelper(sender gateway.Sender, msg format.Message, recipient *id.ID, cmixParams CMIXParams, instance *network.Instance, grp *cyclic.Group, - nodes nodes.Registrar, rng *fastRNG.StreamGenerator, events event.Manager, + nodes nodes.Registrar, rng *fastRNG.StreamGenerator, events event.Reporter, senderId *id.ID, comms SendCmixCommsInterface) (id.Round, ephemeral.Id, error) { timeStart := netTime.Now() diff --git a/cmix/sendManyCmix.go b/cmix/sendManyCmix.go index 4d634b8ca10f3101a85bf69a7680e5427d729511..01e39d34e881167e224ae78fe65a8bbc18aacee0 100644 --- a/cmix/sendManyCmix.go +++ b/cmix/sendManyCmix.go @@ -9,6 +9,9 @@ package cmix import ( "fmt" + "strings" + "time" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/cmix/gateway" @@ -27,8 +30,6 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/netTime" - "strings" - "time" ) // TargetedCmixMessage defines a recipient target pair in a sendMany cMix @@ -110,7 +111,7 @@ type assembledCmixMessage struct { func sendManyCmixHelper(sender gateway.Sender, msgs []assembledCmixMessage, param CMIXParams, instance *network.Instance, grp *cyclic.Group, registrar nodes.Registrar, - rng *fastRNG.StreamGenerator, events event.Manager, + rng *fastRNG.StreamGenerator, events event.Reporter, senderId *id.ID, comms SendCmixCommsInterface) ( id.Round, []ephemeral.Id, error) { diff --git a/dummy/utils_test.go b/dummy/utils_test.go index c6ecaf36b9906501bf1b1c8276c7dd464e2c85ef..73d3c5d0cc28b047ecea26d55e66f44c907069cb 100644 --- a/dummy/utils_test.go +++ b/dummy/utils_test.go @@ -8,6 +8,12 @@ package dummy import ( + "io" + "math/rand" + "sync" + "testing" + "time" + "github.com/pkg/errors" "gitlab.com/elixxir/client/cmix/gateway" "gitlab.com/elixxir/client/event" @@ -25,11 +31,6 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/ndf" - "io" - "math/rand" - "sync" - "testing" - "time" ) //////////////////////////////////////////////////////////////////////////////// @@ -143,7 +144,7 @@ func (tnm *testNetworkManager) SendManyCMIX([]message.TargetedCmixMessage, param type dummyEventMgr struct{} func (d *dummyEventMgr) Report(int, string, string, string) {} -func (tnm *testNetworkManager) GetEventManager() event.Manager { +func (tnm *testNetworkManager) GetEventManager() event.Reporter { return &dummyEventMgr{} } diff --git a/e2e/manager.go b/e2e/manager.go index 4fa9e46b21dec8359c5603d72b270410050ba1fc..853303bbe20633da76ff329ee6f6f4a99ec769e1 100644 --- a/e2e/manager.go +++ b/e2e/manager.go @@ -3,6 +3,7 @@ package e2e import ( "encoding/json" "fmt" + "strings" "time" "github.com/pkg/errors" @@ -29,7 +30,7 @@ type manager struct { net cmix.Client myID *id.ID rng *fastRNG.StreamGenerator - events event.Manager + events event.Reporter grp *cyclic.Group crit *critical rekeyParams rekey.Params @@ -69,7 +70,7 @@ func initE2E(kv *versioned.KV, myID *id.ID, privKey *cyclic.Int, // You can use a memkv for an ephemeral e2e id func Load(kv *versioned.KV, net cmix.Client, myID *id.ID, grp *cyclic.Group, rng *fastRNG.StreamGenerator, - events event.Manager) (Handler, error) { + events event.Reporter) (Handler, error) { kv = kv.Prefix(makeE2ePrefix(myID)) return loadE2E(kv, net, myID, grp, rng, events) } @@ -82,7 +83,7 @@ func Load(kv *versioned.KV, net cmix.Client, myID *id.ID, // You can use a memkv for an ephemeral e2e id func LoadLegacy(kv *versioned.KV, net cmix.Client, myID *id.ID, grp *cyclic.Group, rng *fastRNG.StreamGenerator, - events event.Manager, params rekey.Params) (Handler, error) { + events event.Reporter, params rekey.Params) (Handler, error) { // Marshal the passed params data rekeyParamsData, err := json.Marshal(params) @@ -93,7 +94,7 @@ func LoadLegacy(kv *versioned.KV, net cmix.Client, myID *id.ID, // Check if values are already written. If they exist on disk/memory already, // this would be a case where LoadLegacy is most likely not the correct // code-path the caller should be following. - if _, err := kv.Get(e2eRekeyParamsKey, e2eRekeyParamsVer); err != nil { + if _, err := kv.Get(e2eRekeyParamsKey, e2eRekeyParamsVer); err != nil && !strings.Contains(err.Error(), "object not found") { fmt.Printf("err: %v", err) return nil, errors.New("E2E rekey params are already on disk, " + "LoadLegacy should not be called") @@ -113,7 +114,7 @@ func LoadLegacy(kv *versioned.KV, net cmix.Client, myID *id.ID, func loadE2E(kv *versioned.KV, net cmix.Client, myDefaultID *id.ID, grp *cyclic.Group, rng *fastRNG.StreamGenerator, - events event.Manager) (Handler, error) { + events event.Reporter) (Handler, error) { m := &manager{ Switchboard: receive.New(), diff --git a/e2e/manager_test.go b/e2e/manager_test.go index e040ddd8fab4e6bf08195fa168487a5e3bc48999..5367f9b3fd89dd7f562ada8fa2be9dbeb04b0999 100644 --- a/e2e/manager_test.go +++ b/e2e/manager_test.go @@ -3,6 +3,7 @@ package e2e import ( "bytes" "github.com/cloudflare/circl/dh/sidh" + "gitlab.com/elixxir/client/e2e/ratchet" "gitlab.com/elixxir/client/e2e/rekey" util "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" @@ -95,12 +96,17 @@ func TestLoadLegacy(t *testing.T) { } // Construct kv with legacy data - fs, err := ekv.NewFilestore("/home/josh/src/client/e2e/legacyEkv", "hello") + //fs, err := ekv.NewFilestore("/home/josh/src/client/e2e/legacyEkv", "hello") + //if err != nil { + // t.Fatalf( + // "Failed to create storage session: %+v", err) + //} + kv := versioned.NewKV(ekv.Memstore{}) + + err := ratchet.New(kv, myId, myPrivKey, grp) if err != nil { - t.Fatalf( - "Failed to create storage session: %+v", err) + t.Errorf("Failed to init ratchet: %+v", err) } - kv := versioned.NewKV(fs) rng := fastRNG.NewStreamGenerator(12, 3, csprng.NewSystemRNG) diff --git a/e2e/params.go b/e2e/params.go index ef54d50ee7cd1148a553055f04c53c4c2cbc9f92..52dd2d7b59f7ad6dffe999f56af84b75ed252bba 100644 --- a/e2e/params.go +++ b/e2e/params.go @@ -27,6 +27,9 @@ type Params struct { // system within e2e will be used which preserves privacy CMIX cmix.CMIXParams + // cMix network params + Network cmix.Params + //Authorizes the message to use a key reserved for rekeying. Do not use //unless sending a rekey Rekey bool diff --git a/e2e/ratchet/partner/relationship_test.go b/e2e/ratchet/partner/relationship_test.go index 911a5f711d43eb1af403872d65d6e4ad6a1d2ad7..9591ecea3193a0696b7910a64c130e537550a18c 100644 --- a/e2e/ratchet/partner/relationship_test.go +++ b/e2e/ratchet/partner/relationship_test.go @@ -75,7 +75,7 @@ func TestLoadRelationship(t *testing.T) { } // Shows that a deleted Relationship can no longer be pulled from store -func TestDeleteRelationship(t *testing.T) { +func Test_deleteRelationship(t *testing.T) { mgr, kv := makeTestRelationshipManager(t) // Generate send relationship @@ -90,7 +90,7 @@ func TestDeleteRelationship(t *testing.T) { t.Fatal(err) } - err := mgr.DeleteRelationship() + err := mgr.deleteRelationships() if err != nil { t.Fatalf("DeleteRelationship error: Could not delete manager: %v", err) } diff --git a/e2e/ratchet/partner/utils.go b/e2e/ratchet/partner/utils.go new file mode 100644 index 0000000000000000000000000000000000000000..55a25ab9b2f9da82f655eed4dcf85b2331fffa9a --- /dev/null +++ b/e2e/ratchet/partner/utils.go @@ -0,0 +1,113 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2022 Privategrity Corporation / +// / +// All rights reserved. / +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2022 Privategrity Corporation / +// / +// All rights reserved. / +//////////////////////////////////////////////////////////////////////////////// + +package partner + +import ( + "github.com/cloudflare/circl/dh/sidh" + "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/e2e/ratchet/partner/session" + "gitlab.com/elixxir/crypto/contact" + "gitlab.com/elixxir/crypto/cyclic" + "gitlab.com/xx_network/primitives/id" + "testing" +) + +// Test implementation of the Manager interface +type testManager struct { + partnerId *id.ID + grp *cyclic.Group + partnerPubKey, myPrivKey *cyclic.Int +} + +// NewTestManager allows creation of a Manager interface object for testing purposes +// Backwards compatibility must be maintained if you make changes here +// Currently used for: Group chat testing +func NewTestManager(partnerId *id.ID, partnerPubKey, myPrivKey *cyclic.Int, t *testing.T) Manager { + return &testManager{partnerId: partnerId, partnerPubKey: partnerPubKey, myPrivKey: myPrivKey} +} + +func (p *testManager) GetPartnerID() *id.ID { + return p.partnerId +} + +func (p *testManager) GetMyID() *id.ID { + panic("implement me") +} + +func (p *testManager) GetMyOriginPrivateKey() *cyclic.Int { + return p.myPrivKey +} + +func (p *testManager) GetPartnerOriginPublicKey() *cyclic.Int { + return p.partnerPubKey +} + +func (p *testManager) GetSendRelationshipFingerprint() []byte { + panic("implement me") +} + +func (p *testManager) GetReceiveRelationshipFingerprint() []byte { + panic("implement me") +} + +func (p *testManager) GetConnectionFingerprintBytes() []byte { + panic("implement me") +} + +func (p *testManager) GetConnectionFingerprint() string { + panic("implement me") +} + +func (p *testManager) GetContact() contact.Contact { + panic("implement me") +} + +func (p *testManager) PopSendCypher() (*session.Cypher, error) { + panic("implement me") +} + +func (p *testManager) PopRekeyCypher() (*session.Cypher, error) { + panic("implement me") +} + +func (p *testManager) NewReceiveSession(partnerPubKey *cyclic.Int, partnerSIDHPubKey *sidh.PublicKey, e2eParams session.Params, source *session.Session) (*session.Session, bool) { + panic("implement me") +} + +func (p *testManager) NewSendSession(myDHPrivKey *cyclic.Int, mySIDHPrivateKey *sidh.PrivateKey, e2eParams session.Params, source *session.Session) *session.Session { + panic("implement me") +} + +func (p *testManager) GetSendSession(sid session.SessionID) *session.Session { + panic("implement me") +} + +func (p *testManager) GetReceiveSession(sid session.SessionID) *session.Session { + panic("implement me") +} + +func (p *testManager) Confirm(sid session.SessionID) error { + panic("implement me") +} + +func (p *testManager) TriggerNegotiations() []*session.Session { + panic("implement me") +} + +func (p *testManager) MakeService(tag string) message.Service { + panic("implement me") +} + +func (p *testManager) Delete() error { + panic("implement me") +} diff --git a/e2e/rekey/rekey.go b/e2e/rekey/rekey.go index 91209bd741f62d039fc9a46276d13ab7727a2a92..8de7ef2daff692cc41914f67c3d8ea7974bb84c9 100644 --- a/e2e/rekey/rekey.go +++ b/e2e/rekey/rekey.go @@ -27,7 +27,7 @@ import ( ) func CheckKeyExchanges(instance *commsNetwork.Instance, grp *cyclic.Group, - sendE2E E2eSender, events event.Manager, manager partner.Manager, + sendE2E E2eSender, events event.Reporter, manager partner.Manager, param Params, sendTimeout time.Duration) { //get all sessions that may need a key exchange @@ -45,7 +45,7 @@ func CheckKeyExchanges(instance *commsNetwork.Instance, grp *cyclic.Group, // session. They run the same negotiation, the former does it on a newly created // session while the latter on an extant session func trigger(instance *commsNetwork.Instance, grp *cyclic.Group, sendE2E E2eSender, - events event.Manager, manager partner.Manager, inputSession *session.Session, + events event.Reporter, manager partner.Manager, inputSession *session.Session, sendTimeout time.Duration, params Params) { var negotiatingSession *session.Session diff --git a/event/event.go b/event/event.go index 4a15a56fb8aea3968f2d39ff1592d749f8427ee5..1b017009e991c089bd9b7be930759bd2a1d3ffec 100644 --- a/event/event.go +++ b/event/event.go @@ -9,10 +9,11 @@ package event import ( "fmt" + "sync" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/stoppable" - "sync" ) // ReportableEvent is used to surface events to client users. @@ -30,20 +31,20 @@ func (e reportableEvent) String() string { } // Holds state for the event reporting system -type eventManager struct { +type Manager struct { eventCh chan reportableEvent eventCbs sync.Map } -func newEventManager() *eventManager { - return &eventManager{ +func NewEventManager() *Manager { + return &Manager{ eventCh: make(chan reportableEvent, 1000), } } // Report reports an event from the client to api users, providing a // priority, category, eventType, and details -func (e *eventManager) Report(priority int, category, evtType, details string) { +func (e *Manager) Report(priority int, category, evtType, details string) { re := reportableEvent{ Priority: priority, Category: category, @@ -61,7 +62,7 @@ func (e *eventManager) Report(priority int, category, evtType, details string) { // RegisterEventCallback records the given function to receive // ReportableEvent objects. It returns the internal index // of the callback so that it can be deleted later. -func (e *eventManager) RegisterEventCallback(name string, +func (e *Manager) RegisterEventCallback(name string, myFunc Callback) error { _, existsAlready := e.eventCbs.LoadOrStore(name, myFunc) if existsAlready { @@ -73,18 +74,18 @@ func (e *eventManager) RegisterEventCallback(name string, // UnregisterEventCallback deletes the callback identified by the // index. It returns an error if it fails. -func (e *eventManager) UnregisterEventCallback(name string) { +func (e *Manager) UnregisterEventCallback(name string) { e.eventCbs.Delete(name) } -func (e *eventManager) eventService() (stoppable.Stoppable, error) { +func (e *Manager) EventService() (stoppable.Stoppable, error) { stop := stoppable.NewSingle("EventReporting") go e.reportEventsHandler(stop) return stop, nil } // reportEventsHandler reports events to every registered event callback -func (e *eventManager) reportEventsHandler(stop *stoppable.Single) { +func (e *Manager) reportEventsHandler(stop *stoppable.Single) { jww.DEBUG.Print("reportEventsHandler routine started") for { select { diff --git a/event/event_test.go b/event/event_test.go index c960f1374b134f5635984a3571e154fb4e901291..d657206bac2cbf2ed761ab454e6728484931a6f6 100644 --- a/event/event_test.go +++ b/event/event_test.go @@ -25,8 +25,8 @@ func TestEventReporting(t *testing.T) { evts = append(evts, evt) } - evtMgr := newEventManager() - stop, _ := evtMgr.eventService() + evtMgr := NewEventManager() + stop, _ := evtMgr.EventService() // Register a callback err := evtMgr.RegisterEventCallback("test", myCb) if err != nil { diff --git a/event/interface.go b/event/interface.go index 47e8d9795d7cc63d30b04715b92f0ec842479cf5..867e4dd9c11760474240c27a3661b910b8fe061d 100644 --- a/event/interface.go +++ b/event/interface.go @@ -10,7 +10,7 @@ package event // Callback defines the callback functions for client event reports type Callback func(priority int, category, evtType, details string) -// Manager reporting api (used internally) -type Manager interface { +// Reporter reporting api (used internally) +type Reporter interface { Report(priority int, category, evtType, details string) } diff --git a/fileTransfer/batchBuilder.go b/fileTransfer/batchBuilder.go new file mode 100644 index 0000000000000000000000000000000000000000..005c7f93827ecbc8d4178f7c0b356d184c621ccf --- /dev/null +++ b/fileTransfer/batchBuilder.go @@ -0,0 +1,82 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "encoding/binary" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/fileTransfer/store" + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/elixxir/crypto/fastRNG" + "gitlab.com/xx_network/crypto/csprng" + "go.uber.org/ratelimit" + "time" +) + +const ( + // Duration to wait before adding a partially filled part packet to the send + // channel. + unfilledPacketTimeout = 100 * time.Millisecond +) + +// batchBuilderThread creates batches of file parts as they become available and +// buffer them to send. Also rate limits adding to the buffer. +func (m *manager) batchBuilderThread(stop *stoppable.Single) { + // Calculate the average amount of data sent via SendManyCMIX + avgNumMessages := (minPartsSendPerRound + maxPartsSendPerRound) / 2 + avgSendSize := avgNumMessages * 8192 + + // Calculate rate and make rate limiter + rate := m.params.MaxThroughput / avgSendSize + rl := ratelimit.New(rate, ratelimit.WithoutSlack) + + for { + numParts := generateRandomPacketSize(m.rng) + packet := make([]store.Part, 0, numParts) + delayedTimer := NewDelayedTimer(unfilledPacketTimeout) + loop: + for cap(packet) > len(packet) { + select { + case <-stop.Quit(): + delayedTimer.Stop() + jww.DEBUG.Printf("[FT] Stopping file part packing thread " + + "while packing: stoppable triggered.") + stop.ToStopped() + return + case <-*delayedTimer.C: + break loop + case p := <-m.batchQueue: + packet = append(packet, p) + delayedTimer.Start() + } + } + + // Rate limiter + rl.Take() + m.sendQueue <- packet + } +} + +// generateRandomPacketSize returns a random number between minPartsSendPerRound +// and maxPartsSendPerRound, inclusive. +func generateRandomPacketSize(rngGen *fastRNG.StreamGenerator) int { + rng := rngGen.GetStream() + defer rng.Close() + + // Generate random bytes + b, err := csprng.Generate(8, rng) + if err != nil { + jww.FATAL.Panicf(getRandomNumPartsRandPanic, err) + } + + // Convert bytes to integer + num := binary.LittleEndian.Uint64(b) + + // Return random number that is minPartsSendPerRound <= num <= max + return int((num % (maxPartsSendPerRound)) + minPartsSendPerRound) +} diff --git a/fileTransfer/callbackTracker/callbackTracker.go b/fileTransfer/callbackTracker/callbackTracker.go new file mode 100644 index 0000000000000000000000000000000000000000..6cf7a641c10b2fccf96b46097a5c18dbabea3655 --- /dev/null +++ b/fileTransfer/callbackTracker/callbackTracker.go @@ -0,0 +1,99 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package callbackTracker + +import ( + "gitlab.com/elixxir/client/stoppable" + "gitlab.com/xx_network/primitives/netTime" + "sync" + "time" +) + +type callback func(err error) + +// callbackTracker tracks the fileTransfer.SentProgressCallback and +// information on when to call it. The callback will be called on each send, +// unless the time since the lastCall is smaller than the period. In that case, +// a callback is marked as scheduled and waits to be called at the end of the +// period. A callback is called once every period, regardless of the number of +// sends that occur. +type callbackTracker struct { + period time.Duration // How often to call the callback + lastCall time.Time // Timestamp of the last call + scheduled bool // Denotes if callback call is scheduled + complete bool // Denotes if the callback should not be called + stop *stoppable.Single // Stops the scheduled callback from triggering + cb callback + mux sync.RWMutex +} + +// newCallbackTracker creates a new and unused sentCallbackTracker. +func newCallbackTracker( + cb callback, period time.Duration, stop *stoppable.Single) *callbackTracker { + return &callbackTracker{ + period: period, + lastCall: time.Time{}, + scheduled: false, + complete: false, + stop: stop, + cb: cb, + } +} + +// call triggers the progress callback with the most recent progress from the +// sentProgressTracker. If a callback has been called within the last period, +// then a new call is scheduled to occur at the beginning of the next period. If +// a call is already scheduled, then nothing happens; when the callback is +// finally called, it will do so with the most recent changes. +func (ct *callbackTracker) call(err error) { + ct.mux.RLock() + // Exit if a callback is already scheduled + if (ct.scheduled || ct.complete) && err == nil { + ct.mux.RUnlock() + return + } + + ct.mux.RUnlock() + ct.mux.Lock() + defer ct.mux.Unlock() + + if (ct.scheduled || ct.complete) && err == nil { + return + } + + // Mark callback complete if an error is passed + ct.complete = err != nil + + // Check if a callback has occurred within the last period + timeSinceLastCall := netTime.Since(ct.lastCall) + if timeSinceLastCall > ct.period { + + // If no callback occurred, then trigger the callback now + go ct.cb(err) + ct.lastCall = netTime.Now() + } else { + // If a callback did occur, then schedule a new callback to occur at the + // start of the next period + ct.scheduled = true + go func() { + timer := time.NewTimer(ct.period - timeSinceLastCall) + select { + case <-ct.stop.Quit(): + timer.Stop() + ct.stop.ToStopped() + return + case <-timer.C: + ct.mux.Lock() + go ct.cb(err) + ct.lastCall = netTime.Now() + ct.scheduled = false + ct.mux.Unlock() + } + }() + } +} diff --git a/fileTransfer/callbackTracker/callbackTracker_test.go b/fileTransfer/callbackTracker/callbackTracker_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f14d368ee0fb6fbaa0ef9bdd1c2ac4b17c511755 --- /dev/null +++ b/fileTransfer/callbackTracker/callbackTracker_test.go @@ -0,0 +1,148 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package callbackTracker + +import ( + "github.com/pkg/errors" + "gitlab.com/elixxir/client/stoppable" + "reflect" + "testing" + "time" +) + +// Tests that newCallbackTracker returns a new callbackTracker with all the +// expected values. +func Test_newCallbackTracker(t *testing.T) { + expected := &callbackTracker{ + period: time.Millisecond, + lastCall: time.Time{}, + scheduled: false, + complete: false, + stop: stoppable.NewSingle("Test_newCallbackTracker"), + } + + newCT := newCallbackTracker(nil, expected.period, expected.stop) + newCT.cb = nil + + if !reflect.DeepEqual(expected, newCT) { + t.Errorf("New callbackTracker does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, newCT) + } +} + +// Tests four test cases of callbackTracker.call: +// 1. An initial call is not scheduled. +// 2. A second call within the periods is only called after the period. +// 3. An error sets the callback to complete. +// 4. No more callbacks will be called after set to complete. +func Test_callbackTracker_call(t *testing.T) { + cbChan := make(chan error, 10) + cb := func(err error) { cbChan <- err } + stop := stoppable.NewSingle("Test_callbackTracker_call") + ct := newCallbackTracker(cb, 250*time.Millisecond, stop) + + // Test that the initial call is unscheduled and is called before the period + go ct.call(nil) + + select { + case r := <-cbChan: + if r != nil { + t.Errorf("Received error: %+v", r) + } + case <-time.After(25 * time.Millisecond): + t.Error("Timed out waiting for callback.") + } + + // Test that another call within the period is called only after the period + // is reached + go ct.call(nil) + + select { + case <-cbChan: + t.Error("Callback called too soon.") + + case <-time.After(25 * time.Millisecond): + ct.mux.RLock() + if !ct.scheduled { + t.Error("Callback is not scheduled when it should be.") + } + ct.mux.RUnlock() + select { + case r := <-cbChan: + if r != nil { + t.Errorf("Received error: %+v", r) + } + case <-time.After(ct.period): + t.Errorf("Callback not called after period %s.", ct.period) + + if ct.scheduled { + t.Error("Callback is scheduled when it should not be.") + } + } + } + + // Test that calling with an error sets the callback to complete + expectedErr := errors.New("test error") + go ct.call(expectedErr) + + select { + case r := <-cbChan: + if r != expectedErr { + t.Errorf("Received incorrect error.\nexpected: %v\nreceived: %v", + expectedErr, r) + } + if !ct.complete { + t.Error("Callback is not marked complete when it should be.") + } + case <-time.After(ct.period + 15*time.Millisecond): + t.Errorf("Callback not called after period %s.", + ct.period+15*time.Millisecond) + } + + // Tests that all callback calls after an error are blocked + go ct.call(nil) + + select { + case r := <-cbChan: + t.Errorf("Received callback when it should have been completed: %+v", r) + case <-time.After(ct.period): + } +} + +// Tests that callbackTracker.call does not call on the callback when the +// stoppable is triggered. +func Test_callbackTracker_call_stop(t *testing.T) { + cbChan := make(chan error, 10) + cb := func(err error) { cbChan <- err } + stop := stoppable.NewSingle("Test_callbackTracker_call") + ct := newCallbackTracker(cb, 250*time.Millisecond, stop) + + go ct.call(nil) + + select { + case r := <-cbChan: + if r != nil { + t.Errorf("Received error: %+v", r) + } + case <-time.After(25 * time.Millisecond): + t.Error("Timed out waiting for callback.") + } + + go ct.call(nil) + + err := stop.Close() + if err != nil { + t.Errorf("Failed closing stoppable: %+v", err) + } + + select { + case <-cbChan: + t.Error("Callback called.") + case <-time.After(ct.period * 2): + } +} diff --git a/fileTransfer/callbackTracker/manager.go b/fileTransfer/callbackTracker/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..e6c4cf437ad0d625454667950d2a136baa8f20d2 --- /dev/null +++ b/fileTransfer/callbackTracker/manager.go @@ -0,0 +1,99 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package callbackTracker + +import ( + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "strconv" + "sync" + "time" +) + +// Manager tracks the callbacks for each transfer. +type Manager struct { + // Map of transfers and their list of callbacks + callbacks map[ftCrypto.TransferID][]*callbackTracker + + // List of multi stoppables used to stop callback trackers; each multi + // stoppable contains a single stoppable for each callback. + stops map[ftCrypto.TransferID]*stoppable.Multi + + mux sync.RWMutex +} + +// NewManager initializes a new callback tracker Manager. +func NewManager() *Manager { + m := &Manager{ + callbacks: make(map[ftCrypto.TransferID][]*callbackTracker), + stops: make(map[ftCrypto.TransferID]*stoppable.Multi), + } + + return m +} + +// AddCallback adds a callback to the list of callbacks for the given transfer +// ID and calls it regardless of the callback tracker status. +func (m *Manager) AddCallback( + tid *ftCrypto.TransferID, cb callback, period time.Duration) { + m.mux.Lock() + defer m.mux.Unlock() + + // Create new entries for this transfer ID if none exist + if _, exists := m.callbacks[*tid]; !exists { + m.callbacks[*tid] = []*callbackTracker{} + m.stops[*tid] = stoppable.NewMulti("FileTransfer/" + tid.String()) + } + + // Generate the stoppable and add it to the transfer's multi stoppable + stop := stoppable.NewSingle(makeStoppableName(tid, len(m.callbacks[*tid]))) + m.stops[*tid].Add(stop) + + // Create new callback tracker and add to the map + ct := newCallbackTracker(cb, period, stop) + m.callbacks[*tid] = append(m.callbacks[*tid], ct) + + // Call the callback + go cb(nil) +} + +// Call triggers each callback for the given transfer ID and passes along the +// given error. +func (m *Manager) Call(tid *ftCrypto.TransferID, err error) { + m.mux.Lock() + defer m.mux.Unlock() + + for _, cb := range m.callbacks[*tid] { + go cb.call(err) + } +} + +// Delete stops all scheduled stoppables for the given transfer and deletes the +// callbacks from the map. +func (m *Manager) Delete(tid *ftCrypto.TransferID) { + m.mux.Lock() + defer m.mux.Unlock() + + // Stop the stoppable if the stoppable still exists + stop, exists := m.stops[*tid] + if exists { + if err := stop.Close(); err != nil { + jww.ERROR.Printf("[FT] Failed to stop progress callbacks: %+v", err) + } + } + + // Delete callbacks and stoppables + delete(m.callbacks, *tid) + delete(m.stops, *tid) +} + +// makeStoppableName generates a unique name for the callback stoppable. +func makeStoppableName(tid *ftCrypto.TransferID, callbackNum int) string { + return tid.String() + "/" + strconv.Itoa(callbackNum) +} diff --git a/fileTransfer/callbackTracker/manager_test.go b/fileTransfer/callbackTracker/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4226807f18a728345fc4907758bd27ae21b6dcc1 --- /dev/null +++ b/fileTransfer/callbackTracker/manager_test.go @@ -0,0 +1,176 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package callbackTracker + +import ( + "errors" + "gitlab.com/elixxir/client/stoppable" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/crypto/csprng" + "io" + "math/rand" + "reflect" + "sync" + "testing" + "time" +) + +// Tests that NewManager returns the expected Manager. +func TestNewManager(t *testing.T) { + expected := &Manager{ + callbacks: make(map[ftCrypto.TransferID][]*callbackTracker), + stops: make(map[ftCrypto.TransferID]*stoppable.Multi), + } + + newManager := NewManager() + + if !reflect.DeepEqual(expected, newManager) { + t.Errorf("New Manager does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, newManager) + } +} + +// Tests that Manager.AddCallback adds the callback to the list, creates and +// adds a stoppable to the list, and that the callback is called. +func TestManager_AddCallback(t *testing.T) { + m := NewManager() + + cbChan := make(chan error, 10) + cb := func(err error) { cbChan <- err } + tid := &ftCrypto.TransferID{5} + m.AddCallback(tid, cb, 0) + + // Check that the callback was called + select { + case <-cbChan: + case <-time.After(25 * time.Millisecond): + t.Error("Timed out waiting for callback to be called.") + } + + // Check that the callback was added + if _, exists := m.callbacks[*tid]; !exists { + t.Errorf("No callback list found for transfer ID %s.", tid) + } else if len(m.callbacks[*tid]) != 1 { + t.Errorf("Incorrect number of callbacks.\nexpected: %d\nreceived: %d", + 1, len(m.callbacks[*tid])) + } + + // Check that the stoppable was added + if _, exists := m.stops[*tid]; !exists { + t.Errorf("No stoppable list found for transfer ID %s.", tid) + } +} + +// Tests that Manager.Call calls al the callbacks associated with the transfer +// ID. +func TestManager_Call(t *testing.T) { + m := NewManager() + tid := &ftCrypto.TransferID{5} + n := 10 + cbChans := make([]chan error, n) + cbs := make([]func(err error), n) + for i := range cbChans { + cbChan := make(chan error, 10) + cbs[i] = func(err error) { cbChan <- err } + cbChans[i] = cbChan + } + + // Add callbacks + for i := range cbs { + m.AddCallback(tid, cbs[i], 0) + + // Receive channel from first call + select { + case <-cbChans[i]: + case <-time.After(25 * time.Millisecond): + t.Errorf("Callback #%d never called.", i) + } + } + + // Call callbacks + m.Call(tid, errors.New("test")) + + // Check to make sure callbacks were called + var wg sync.WaitGroup + for i := range cbs { + wg.Add(1) + go func(i int) { + select { + case r := <-cbChans[i]: + if r == nil { + t.Errorf("Callback #%d did not receive an error.", i) + } + case <-time.After(25 * time.Millisecond): + t.Errorf("Callback #%d never called.", i) + } + + wg.Done() + }(i) + } + + wg.Wait() +} + +// Tests that Manager.Delete removes all callbacks and stoppables from the list. +func TestManager_Delete(t *testing.T) { + m := NewManager() + + cbChan := make(chan error, 10) + cb := func(err error) { cbChan <- err } + tid := &ftCrypto.TransferID{5} + m.AddCallback(tid, cb, 0) + + m.Delete(tid) + + // Check that the callback was deleted + if _, exists := m.callbacks[*tid]; exists { + t.Errorf("Callback list found for transfer ID %s.", tid) + } + + // Check that the stoppable was deleted + if _, exists := m.stops[*tid]; exists { + t.Errorf("Stoppable list found for transfer ID %s.", tid) + } +} + +// Consistency test of makeStoppableName. +func Test_makeStoppableName_Consistency(t *testing.T) { + rng := NewPrng(42) + expectedValues := []string{ + "U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=/0", + "39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g=/1", + "CD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw=/2", + "uoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44=/3", + "GwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM=/4", + "rnvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA=/5", + "ceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE=/6", + "SYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE=/7", + "NhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI=/8", + "kM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog=/9", + } + + for i, expected := range expectedValues { + tid, err := ftCrypto.NewTransferID(rng) + if err != nil { + t.Errorf("Failed to generated transfer ID #%d: %+v", i, err) + } + + name := makeStoppableName(&tid, i) + if expected != name { + t.Errorf("Stoppable name does not match expected."+ + "\nexpected: %q\nreceived: %q", expected, name) + } + } +} + +// Prng is a PRNG that satisfies the csprng.Source interface. +type Prng struct{ prng io.Reader } + +func NewPrng(seed int64) csprng.Source { return &Prng{rand.New(rand.NewSource(seed))} } +func (s *Prng) Read(b []byte) (int, error) { return s.prng.Read(b) } +func (s *Prng) SetSeed([]byte) error { return nil } diff --git a/fileTransfer/delayedTimer.go b/fileTransfer/delayedTimer.go new file mode 100644 index 0000000000000000000000000000000000000000..e07af67d67bf92fcb8807ca0cee48bcdd631beb1 --- /dev/null +++ b/fileTransfer/delayedTimer.go @@ -0,0 +1,49 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import "time" + +// The DelayedTimer type represents a single event manually started. +// When the DelayedTimer expires, the current time will be sent on C. +// A DelayedTimer must be created with NewDelayedTimer. +type DelayedTimer struct { + d time.Duration + t *time.Timer + C *<-chan time.Time +} + +// NewDelayedTimer creates a new DelayedTimer that will send the current time on +// its channel after at least duration d once it is started. +func NewDelayedTimer(d time.Duration) *DelayedTimer { + c := make(<-chan time.Time) + return &DelayedTimer{ + d: d, + C: &c, + } +} + +// Start starts the timer that will send the current time on its channel after +// at least duration d. If it is already running or stopped, it does nothing. +func (dt *DelayedTimer) Start() { + if dt.t == nil { + dt.t = time.NewTimer(dt.d) + dt.C = &dt.t.C + } +} + +// Stop prevents the Timer from firing. +// It returns true if the call stops the timer, false if the timer has already +// expired, been stopped, or was never started. +func (dt *DelayedTimer) Stop() bool { + if dt.t == nil { + return false + } + + return dt.t.Stop() +} diff --git a/fileTransfer/ftMessages.proto b/fileTransfer/ftMessages.proto index dae0d2f2a4d72784965924a547f7e5d599386172..22b2235a451b54a56b994018cc5f97f8c738299a 100644 --- a/fileTransfer/ftMessages.proto +++ b/fileTransfer/ftMessages.proto @@ -13,12 +13,12 @@ option go_package = "fileTransfer/"; // NewFileTransfer is transmitted first on the initialization of a file transfer // to inform the receiver about the incoming file. message NewFileTransfer { - string fileName = 1; // Name of the file; max 48 characters - string fileType = 2; // Type of file; max 8 characters - bytes transferKey = 3; // 256 bit encryption key to identify the transfer - bytes transferMac = 4; // 256 bit MAC of the entire file - uint32 numParts = 5; // Number of file parts - uint32 size = 6; // The size of the file; max of 250 kB - float retry = 7; // Used to determine how many times to retry sending - bytes preview = 8; // A preview of the file; max of 4 kB + string fileName = 1; // Name of the file; max 48 characters + string fileType = 2; // Type of file; max 8 characters + bytes transferKey = 3; // 256 bit encryption key to identify the transfer + bytes transferMac = 4; // 256 bit MAC of the entire file + uint32 numParts = 5; // Number of file parts + uint32 size = 6; // The size of the file; max of 250 kB + float retry = 7; // Used to determine how many times to retry sending + bytes preview = 8; // A preview of the file; max of 4 kB } \ No newline at end of file diff --git a/fileTransfer/interface.go b/fileTransfer/interface.go new file mode 100644 index 0000000000000000000000000000000000000000..d31dbb99be99f532fab1a7bc9b5edaf8cc0817e5 --- /dev/null +++ b/fileTransfer/interface.go @@ -0,0 +1,217 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/stoppable" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "strconv" + "time" +) + +// SentProgressCallback is a callback function that tracks the progress of +// sending a file. +type SentProgressCallback func(completed bool, arrived, total uint16, + t FilePartTracker, err error) + +// ReceivedProgressCallback is a callback function that tracks the progress of +// receiving a file. +type ReceivedProgressCallback func(completed bool, received, total uint16, + t FilePartTracker, err error) + +// ReceiveCallback is a callback function that notifies the receiver of an +// incoming file transfer. +type ReceiveCallback func(tid *ftCrypto.TransferID, fileName, fileType string, + sender *id.ID, size uint32, preview []byte) + +// FileTransfer facilities the sending and receiving of large file transfers. +// It allows for progress tracking of both inbound and outbound transfers. +type FileTransfer interface { + + // StartProcesses starts the listening for new file transfer messages and + // starts the sending threads that wait for transfers to send. + StartProcesses() (stoppable.Stoppable, error) + + // MaxFileNameLen returns the max number of bytes allowed for a file name. + MaxFileNameLen() int + + // MaxFileTypeLen returns the max number of bytes allowed for a file type. + MaxFileTypeLen() int + + // MaxFileSize returns the max number of bytes allowed for a file. + MaxFileSize() int + + // MaxPreviewSize returns the max number of bytes allowed for a file + // preview. + MaxPreviewSize() int + + /* === Sending ========================================================== */ + /* The processes of sending a file involves three main steps: + 1. Sending the file using Send + 2. Receiving transfer progress + 3. Closing a finished send using CloseSend + + Once the file is sent, it is broken into individual, equal-length parts + and sent to the recipient. Every time one of these parts arrives, it is + reported on all registered SentProgressCallbacks for that transfer. + + A SentProgressCallback is registered on the initial send. However, if the + client is closed and reopened, the callback must be registered again + using RegisterSentProgressCallback, otherwise the continued progress of + the transfer will not be reported. + + Once the SentProgressCallback returns that the file has completed + sending, the file can be closed using CloseSend. If the callback reports + an error, then the file should also be closed using CloseSend. + */ + + // Send initiates the sending of a file to the recipient and returns a + // transfer ID that uniquely identifies this file transfer. + // In-progress transfers are restored when closing and reopening; however, a + // SentProgressCallback must be registered again. + // fileName - Human-readable file name. Max length defined by + // MaxFileNameLen. + // fileType - Shorthand that identifies the type of file. Max length + // defined by MaxFileTypeLen. + // fileData - File contents. Max size defined by MaxFileSize. + // recipient - ID of the receiver of the file transfer. The sender must + // have an E2E relationship with the recipient. + // retry - The number of sending retries allowed on send failure (e.g. + // a retry of 2.0 with 6 parts means 12 total possible sends). + // preview - A preview of the file data (e.g. a thumbnail). Max size + // defined by MaxPreviewSize. + // progressCB - A callback that reports the progress of the file transfer. + // The callback is called once on initialization, on every progress + // update (or less if restricted by the period), or on fatal error. + // period - A progress callback will be limited from triggering only once + // per period. + Send(fileName, fileType string, fileData []byte, recipient *id.ID, + retry float32, preview []byte, progressCB SentProgressCallback, + period time.Duration) (*ftCrypto.TransferID, error) + + // RegisterSentProgressCallback allows for the registration of a callback to + // track the progress of an individual sent file transfer. + // SentProgressCallback is auto registered on Send; this function should be + // called when resuming clients or registering extra callbacks. + // + // The callback will be called immediately when added to report the current + // progress of the transfer. It will then call every time a file part + // arrives, the transfer completes, or a fatal error occurs. It is called at + // most once every period regardless of the number of progress updates. + // + // In the event that the client is closed and resumed, this function must be + // used to re-register any callbacks previously registered with this + // function or Send. + RegisterSentProgressCallback(tid *ftCrypto.TransferID, + progressCB SentProgressCallback, period time.Duration) error + + // CloseSend deletes a file from the internal storage once a transfer has + // completed or reached the retry limit. Returns an error if the transfer + // has not run out of retries. + // + // This function should be called once a transfer completes or errors out + // (as reported by the progress callback). + CloseSend(tid *ftCrypto.TransferID) error + + /* === Receiving ======================================================== */ + /* The processes of receiving a file involves three main steps: + 1. Receiving a new file transfer on ReceiveCallback + 2. Receiving transfer progress + 3. Receiving the complete file using Receive + + Once the file transfer manager has started, it will call the + ReceiveCallback for every new file transfer that is received. Once that + happens, a ReceivedProgressCallback must be registered using + RegisterReceivedProgressCallback to get progress updates on the transfer. + + When the progress callback reports that the transfer is complete, the + full file can be retrieved using Receive. + */ + + // RegisterReceivedProgressCallback allows for the registration of a + // callback to track the progress of an individual received file transfer. + // This should be done when a new transfer is received on the + // ReceiveCallback. + // + // The callback will be called immediately when added to report the current + // progress of the transfer. It will then call every time a file part is + // received, the transfer completes, or a fatal error occurs. It is called + // at most once every period regardless of the number of progress updates. + // + // In the event that the client is closed and resumed, this function must be + // used to re-register any callbacks previously registered. + // + // Once the callback reports that the transfer has completed, the recipient + // can get the full file by calling Receive. + RegisterReceivedProgressCallback(tid *ftCrypto.TransferID, + progressCB ReceivedProgressCallback, period time.Duration) error + + // Receive returns the full file on the completion of the transfer. + // It deletes internal references to the data and unregisters any attached + // progress callback. Returns an error if the transfer is not complete, the + // full file cannot be verified, or if the transfer cannot be found. + // + // Receive can only be called once the progress callback returns that the + // file transfer is complete. + Receive(tid *ftCrypto.TransferID) ([]byte, error) +} + +// FilePartTracker tracks the status of each file part in a sent or received +// file transfer. +type FilePartTracker interface { + // GetPartStatus returns the status of the file part with the given part + // number. The possible values for the status are: + // 0 < Part does not exist + // 0 = unsent + // 1 = arrived (sender has sent a part, and it has arrived) + // 2 = received (receiver has received a part) + GetPartStatus(partNum uint16) FpStatus + + // GetNumParts returns the total number of file parts in the transfer. + GetNumParts() uint16 +} + +// FpStatus is the file part status and indicates the status of individual file +// parts in a file transfer. +type FpStatus int + +// Possible values for FpStatus. +const ( + // FpUnsent indicates that the file part has not been sent + FpUnsent FpStatus = iota + + // FpSent indicates that the file part has been sent (sender has sent a + // part, but it has not arrived) + FpSent + + // FpArrived indicates that the file part has arrived (sender has sent a + // part, and it has arrived) + FpArrived + + // FpReceived indicates that the file part has been received (receiver has + // received a part) + FpReceived +) + +// String returns the string representing of the FpStatus. This functions +// satisfies the fmt.Stringer interface. +func (fps FpStatus) String() string { + switch fps { + case FpUnsent: + return "unsent" + case FpSent: + return "sent" + case FpArrived: + return "arrived" + case FpReceived: + return "received" + default: + return "INVALID FpStatus: " + strconv.Itoa(int(fps)) + } +} diff --git a/fileTransfer/listener.go b/fileTransfer/listener.go new file mode 100644 index 0000000000000000000000000000000000000000..23f4202489ab262516203a53c42514d411db6c66 --- /dev/null +++ b/fileTransfer/listener.go @@ -0,0 +1,93 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/golang/protobuf/proto" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/e2e/receive" + "gitlab.com/elixxir/client/fileTransfer/store" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" +) + +// Error messages. +const ( + errProtoUnmarshal = "[FT] Failed to proto unmarshal new file transfer request: %+v" + errNewRtTransferID = "[FT] Failed to generate transfer ID for new received file transfer: %+v" + errAddNewRt = "[FT] Failed to add new file transfer %s (%q): %+v" +) + +// Name of listener (used for debugging) +const listenerName = "NewFileTransferListener" + +// fileTransferListener waits for new file transfer messages to get ready to +// receive the file transfer parts. Adheres to the receive.Listener interface. +type fileTransferListener struct { + *manager +} + +// Hear is called when a new file transfer is received. It creates a new +// internal received file transfer and starts waiting to receive file part +// messages. +func (ftl *fileTransferListener) Hear(msg receive.Message) { + // Unmarshal the request message + newFT := &NewFileTransfer{} + err := proto.Unmarshal(msg.Payload, newFT) + if err != nil { + jww.ERROR.Printf(errProtoUnmarshal, err) + return + } + // Generate new transfer ID + rng := ftl.rng.GetStream() + tid, err := ftCrypto.NewTransferID(rng) + if err != nil { + jww.ERROR.Printf(errNewRtTransferID, err) + return + } + rng.Close() + + key := ftCrypto.UnmarshalTransferKey(newFT.TransferKey) + numParts := uint16(newFT.GetNumParts()) + numFps := calcNumberOfFingerprints(int(numParts), newFT.GetRetry()) + + rt, err := ftl.received.AddTransfer(&key, &tid, newFT.GetFileName(), + newFT.GetTransferMac(), numParts, numFps, newFT.GetSize()) + if err != nil { + jww.ERROR.Printf(errAddNewRt, tid, newFT.GetFileName(), err) + } + + ftl.addFingerprints(rt) + + // Call the reception callback + go ftl.receiveCB(&tid, newFT.GetFileName(), newFT.GetFileType(), msg.Sender, + newFT.GetSize(), newFT.GetPreview()) +} + +// addFingerprints adds all fingerprints for unreceived parts in the received +// transfer. +func (m *manager) addFingerprints(rt *store.ReceivedTransfer) { + // Build processor for each file part and add its fingerprint to receive on + for _, c := range rt.GetUnusedCyphers() { + p := &processor{ + Cypher: c, + ReceivedTransfer: rt, + manager: m, + } + + err := m.cmix.AddFingerprint(m.myID, c.GetFingerprint(), p) + if err != nil { + jww.ERROR.Printf("[FT] Failed to add fingerprint for transfer "+ + "%s: %+v", rt.TransferID(), err) + } + } +} + +// Name returns a name used for debugging. +func (ftl *fileTransferListener) Name() string { + return listenerName +} diff --git a/fileTransfer/manager.go b/fileTransfer/manager.go index 2235b98795bb4a63f3f5d908cf5f4bfcbd1e1017..c22b8bbaaa28a93bfa6ef3d4c9b65eeed5711103 100644 --- a/fileTransfer/manager.go +++ b/fileTransfer/manager.go @@ -8,27 +8,28 @@ package fileTransfer import ( + "bytes" "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/api" + "gitlab.com/elixxir/client/catalog" "gitlab.com/elixxir/client/cmix" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/e2e" + "gitlab.com/elixxir/client/e2e/receive" + "gitlab.com/elixxir/client/fileTransfer/callbackTracker" + "gitlab.com/elixxir/client/fileTransfer/store" + "gitlab.com/elixxir/client/fileTransfer/store/fileMessage" "gitlab.com/elixxir/client/stoppable" - "gitlab.com/elixxir/client/storage" - ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" "gitlab.com/elixxir/client/storage/versioned" + e2eCrypto "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/crypto/fastRNG" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/id/ephemeral" "time" ) const ( - // PreviewMaxSize is the maximum size, in bytes, for a file preview. - // Currently, it is set to 4 kB. - PreviewMaxSize = 4_000 - // FileNameMaxLen is the maximum size, in bytes, for a file name. Currently, // it is set to 48 bytes. FileNameMaxLen = 48 @@ -41,402 +42,463 @@ const ( // it is set to 250 kB. FileMaxSize = 250_000 + // PreviewMaxSize is the maximum size, in bytes, for a file preview. + // Currently, it is set to 4 kB. + PreviewMaxSize = 4_000 + // minPartsSendPerRound is the minimum number of file parts sent each round. minPartsSendPerRound = 1 // maxPartsSendPerRound is the maximum number of file parts sent each round. maxPartsSendPerRound = 11 - // Size of the buffered channel that queues file parts to send - sendQueueBuffLen = 10_000 + // Size of the buffered channel that queues file parts to package + batchQueueBuffLen = 10_000 - // Size of the buffered channel that reports if the network is healthy - networkHealthBuffLen = 10 + // Size of the buffered channel that queues file packets to send + sendQueueBuffLen = 10_000 ) -// Error messages. +// Stoppable and listener values. const ( - // newManager - newManagerSentErr = "failed to load or create new list of sent file transfers: %+v" - newManagerReceivedErr = "failed to load or create new list of received file transfers: %+v" - - // Manager.Send - sendNetworkHealthErr = "cannot initiate file transfer of %q when network is not healthy." - fileNameSizeErr = "length of filename (%d) greater than max allowed length (%d)" - fileTypeSizeErr = "length of file type (%d) greater than max allowed length (%d)" - fileSizeErr = "size of file (%d bytes) greater than max allowed size (%d bytes)" - previewSizeErr = "size of preview (%d bytes) greater than max allowed size (%d bytes)" - getPartSizeErr = "failed to get file part size: %+v" - sendInitMsgErr = "failed to send initial file transfer message: %+v" - - // Manager.Resend - transferNotFailedErr = "transfer %s has not failed" - - // Manager.CloseSend - transferInProgressErr = "transfer %s has not completed or failed" + fileTransferStoppable = "FileTransfer" + workerPoolStoppable = "FilePartSendingWorkerPool" + batchBuilderThreadStoppable = "BatchBuilderThread" ) -// Stoppable and listener values. +// Error messages. const ( - rawMessageBuffSize = 10_000 - sendStoppableName = "FileTransferSend" - newFtStoppableName = "FileTransferNew" - newFtListenerName = "FileTransferNewListener" - filePartStoppableName = "FilePart" - filePartListenerName = "FilePartListener" - fileTransferStoppableName = "FileTransfer" + errNoSentTransfer = "could not find sent transfer with ID %s" + errNoReceivedTransfer = "could not find received transfer with ID %s" + + // NewManager + errNewOrLoadSent = "failed to load or create new list of sent file transfers: %+v" + errNewOrLoadReceived = "failed to load or create new list of received file transfers: %+v" + + // manager.Send + errFileNameSize = "length of filename (%d) greater than max allowed length (%d)" + errFileTypeSize = "length of file type (%d) greater than max allowed length (%d)" + errFileSize = "size of file (%d bytes) greater than max allowed size (%d bytes)" + errPreviewSize = "size of preview (%d bytes) greater than max allowed size (%d bytes)" + errSendNetworkHealth = "cannot initiate file transfer of %q when network is not healthy." + errNewKey = "could not generate new transfer key: %+v" + errNewID = "could not generate new transfer ID: %+v" + errSendNewMsg = "failed to send initial file transfer message: %+v" + errAddSentTransfer = "failed to add transfer: %+v" + + // manager.CloseSend + errDeleteIncompleteTransfer = "cannot delete transfer %s that has not completed or failed" + errDeleteSentTransfer = "could not delete sent transfer %s: %+v" + errRemoveSentTransfer = "could not remove transfer %s from list: %+v" + + // manager.Receive + errIncompleteFile = "cannot get incomplete file: missing %d of %d parts" + errDeleteReceivedTransfer = "could not delete received transfer %s: %+v" + errRemoveReceivedTransfer = "could not remove transfer %s from list: %+v" ) -// Manager is used to manage the sending and receiving of all file transfers. -type Manager struct { - // Callback that is called every time a new file transfer is received - receiveCB interfaces.ReceiveCallback - +// manager handles the sending and receiving of file, their storage, and their +// callbacks. +type manager struct { // Storage-backed structure for tracking sent file transfers - sent *ftStorage.SentFileTransfersStore + sent *store.Sent // Storage-backed structure for tracking received file transfers - received *ftStorage.ReceivedFileTransfersStore + received *store.Received + + // Callback that is called every time a new file transfer is received + receiveCB ReceiveCallback - // Queue of parts to send - sendQueue chan queuedPart + // Progress callback tracker + callbacks *callbackTracker.Manager - // Indicates if old transfers saved to storage have been recovered after - // file transfer is closed and reopened; this is an atomic - oldTransfersRecovered *uint32 + // Queue of parts to batch and send + batchQueue chan store.Part + + // Queue of batches of parts to send + sendQueue chan []store.Part // File transfer parameters - p Params - - // Client interfaces - client *api.Client - store *storage.Session - swb interfaces.Switchboard - net interfaces.NetworkManager - rng *fastRNG.StreamGenerator - getRoundResults getRoundResultsFunc + params Params + + myID *id.ID + cmix Cmix + e2e E2e + kv *versioned.KV + rng *fastRNG.StreamGenerator } -// getRoundResultsFunc is a function that matches client.GetRoundResults. It is -// used to pass in an alternative function for testing. -type getRoundResultsFunc func(roundList []id.Round, timeout time.Duration, - roundCallback cmix.RoundEventCallback) error +// MaxFileNameLen returns the max number of bytes allowed for a file name. +func (m *manager) MaxFileNameLen() int { + return FileNameMaxLen +} + +// MaxFileTypeLen returns the max number of bytes allowed for a file type. +func (m *manager) MaxFileTypeLen() int { + return FileTypeMaxLen +} + +// MaxFileSize returns the max number of bytes allowed for a file. +func (m *manager) MaxFileSize() int { + return FileMaxSize +} -// queuedPart contains the unique information identifying a file part. -type queuedPart struct { - tid ftCrypto.TransferID - partNum uint16 +// MaxPreviewSize returns the max number of bytes allowed for a file preview. +func (m *manager) MaxPreviewSize() int { + return PreviewMaxSize } -// NewManager produces a new empty file transfer Manager. Does not start sending -// and receiving services. -func NewManager(client *api.Client, receiveCB interfaces.ReceiveCallback, - p Params) (*Manager, error) { - return newManager(client, client.GetStorage(), client.GetSwitchboard(), - client.GetNetworkInterface(), client.GetRng(), client.GetRoundResults, - client.GetStorage().GetKV(), receiveCB, p) +type Cmix interface { + GetMaxMessageLength() int + SendMany(messages []cmix.TargetedCmixMessage, p cmix.CMIXParams) (id.Round, + []ephemeral.Id, error) + AddFingerprint(identity *id.ID, fingerprint format.Fingerprint, + mp message.Processor) error + DeleteFingerprint(identity *id.ID, fingerprint format.Fingerprint) + IsHealthy() bool + AddHealthCallback(f func(bool)) uint64 + RemoveHealthCallback(uint64) + GetRoundResults(timeout time.Duration, roundCallback cmix.RoundEventCallback, + roundList ...id.Round) error } -// newManager builds the manager from fields explicitly passed in. This function -// is a helper function for NewManager to make it easier to test. -func newManager(client *api.Client, store *storage.Session, - swb interfaces.Switchboard, net interfaces.NetworkManager, - rng *fastRNG.StreamGenerator, getRoundResults getRoundResultsFunc, - kv *versioned.KV, receiveCB interfaces.ReceiveCallback, p Params) ( - *Manager, error) { - - // Create a new list of sent file transfers or load one if it exists in - // storage - sent, err := ftStorage.NewOrLoadSentFileTransfersStore(kv) +type E2e interface { + SendE2E(mt catalog.MessageType, recipient *id.ID, payload []byte, + params e2e.Params) ([]id.Round, e2eCrypto.MessageID, time.Time, error) + RegisterListener(senderID *id.ID, messageType catalog.MessageType, + newListener receive.Listener) receive.ListenerID +} + +// NewManager creates a new file transfer manager object. If sent or received +// transfers already existed, they are loaded from storage and queued to resume +// once manager.startProcesses is called. +func NewManager(receiveCB ReceiveCallback, params Params, myID *id.ID, + cmix Cmix, e2e E2e, kv *versioned.KV, + rng *fastRNG.StreamGenerator) (FileTransfer, error) { + + // Create a new list of sent file transfers or load one if it exists + sent, unsentParts, err := store.NewOrLoadSent(kv) if err != nil { - return nil, errors.Errorf(newManagerSentErr, err) + return nil, errors.Errorf(errNewOrLoadSent, err) } - // Create a new list of received file transfers or load one if it exists in - // storage - received, err := ftStorage.NewOrLoadReceivedFileTransfersStore(kv) + // Create a new list of received file transfers or load one if it exists + received, incompleteTransfers, err := store.NewOrLoadReceived(kv) if err != nil { - return nil, errors.Errorf(newManagerReceivedErr, err) + return nil, errors.Errorf(errNewOrLoadReceived, err) } - jww.DEBUG.Printf(""+ - "[FT] Created new file transfer manager with params: %+v", p) - - oldTransfersRecovered := uint32(0) - - return &Manager{ - receiveCB: receiveCB, - sent: sent, - received: received, - sendQueue: make(chan queuedPart, sendQueueBuffLen), - oldTransfersRecovered: &oldTransfersRecovered, - p: p, - client: client, - store: store, - swb: swb, - net: net, - rng: rng, - getRoundResults: getRoundResults, - }, nil -} + // Construct manager + m := &manager{ + sent: sent, + received: received, + receiveCB: receiveCB, + callbacks: callbackTracker.NewManager(), + batchQueue: make(chan store.Part, batchQueueBuffLen), + sendQueue: make(chan []store.Part, sendQueueBuffLen), + params: params, + myID: myID, + cmix: cmix, + e2e: e2e, + kv: kv, + rng: rng, + } -// StartProcesses starts the processes needed to send and receive file parts. It -// starts three threads that (1) receives the initial NewFileTransfer E2E -// message; (2) receives each file part; and (3) sends file parts. It also -// registers the network health channel. -func (m *Manager) StartProcesses() (stoppable.Stoppable, error) { - // Create the two reception channels - newFtChan := make(chan message.Receive, rawMessageBuffSize) - filePartChan := make(chan message.Receive, rawMessageBuffSize) + // Add all unsent file parts to queue + for _, p := range unsentParts { + m.batchQueue <- p + } + + // Add all fingerprints for unreceived parts + for _, rt := range incompleteTransfers { + m.addFingerprints(rt) + } - return m.startProcesses(newFtChan, filePartChan) + return m, nil } -// startProcesses starts the sending and receiving processes with the provided -// channels. -func (m *Manager) startProcesses(newFtChan, filePartChan chan message.Receive) ( - stoppable.Stoppable, error) { - - // Register network health channel that is used by the sending thread to - // ensure the network is healthy before sending - healthyRecover := make(chan bool, networkHealthBuffLen) - healthyRecoverID := m.net.GetHealthTracker().AddChannel(healthyRecover) - healthySend := make(chan bool, networkHealthBuffLen) - healthySendID := m.net.GetHealthTracker().AddChannel(healthySend) - - // Recover unsent parts from storage - m.oldTransferRecovery(healthyRecover, healthyRecoverID) - - // Start the new file transfer message reception thread - newFtStop := stoppable.NewSingle(newFtStoppableName) - m.swb.RegisterChannel(newFtListenerName, &id.ID{}, - message.NewFileTransfer, newFtChan) - go m.receiveNewFileTransfer(newFtChan, newFtStop) - - // Start the file part message reception thread - filePartStop := stoppable.NewSingle(filePartStoppableName) - m.swb.RegisterChannel(filePartListenerName, &id.ID{}, message.Raw, - filePartChan) - go m.receive(filePartChan, filePartStop) - - // Start the file part sending thread - sendStop := stoppable.NewSingle(sendStoppableName) - go m.sendThread(sendStop, healthySend, healthySendID, getRandomNumParts) +// StartProcesses starts the sending threads. Adheres to the api.Service type. +func (m *manager) StartProcesses() (stoppable.Stoppable, error) { + // Register listener to receive new file transfers + m.e2e.RegisterListener( + m.myID, catalog.NewFileTransfer, &fileTransferListener{m}) + + // Construct stoppables + multiStop := stoppable.NewMulti(workerPoolStoppable) + batchBuilderStop := stoppable.NewSingle(batchBuilderThreadStoppable) + + // Start sending threads + go m.startSendingWorkerPool(multiStop) + go m.batchBuilderThread(batchBuilderStop) // Create a multi stoppable - multiStoppable := stoppable.NewMulti(fileTransferStoppableName) - multiStoppable.Add(newFtStop) - multiStoppable.Add(filePartStop) - multiStoppable.Add(sendStop) + multiStoppable := stoppable.NewMulti(fileTransferStoppable) + multiStoppable.Add(multiStop) + multiStoppable.Add(batchBuilderStop) return multiStoppable, nil } -// Send starts the sending of a file transfer to the recipient. It sends the -// initial NewFileTransfer E2E message to the recipient to inform them of the -// incoming file parts. It partitions the file, puts it into storage, and queues -// each file for sending. Returns a unique ID identifying the file transfer. -// Returns an error if the network is not healthy. -func (m Manager) Send(fileName, fileType string, fileData []byte, +// Send partitions the given file into cMix message sized chunks and sends them +// via cmix.SendMany. +func (m *manager) Send(fileName, fileType string, fileData []byte, recipient *id.ID, retry float32, preview []byte, - progressCB interfaces.SentProgressCallback, period time.Duration) ( - ftCrypto.TransferID, error) { - - // Return an error if the network is not healthy - if !m.net.GetHealthTracker().IsHealthy() { - return ftCrypto.TransferID{}, - errors.Errorf(sendNetworkHealthErr, fileName) - } + progressCB SentProgressCallback, period time.Duration) ( + *ftCrypto.TransferID, error) { // Return an error if the file name is too long if len(fileName) > FileNameMaxLen { - return ftCrypto.TransferID{}, errors.Errorf( - fileNameSizeErr, len(fileName), FileNameMaxLen) + return nil, errors.Errorf(errFileNameSize, len(fileName), FileNameMaxLen) } // Return an error if the file type is too long if len(fileType) > FileTypeMaxLen { - return ftCrypto.TransferID{}, errors.Errorf( - fileTypeSizeErr, len(fileType), FileTypeMaxLen) + return nil, errors.Errorf(errFileTypeSize, len(fileType), FileTypeMaxLen) } // Return an error if the file is too large if len(fileData) > FileMaxSize { - return ftCrypto.TransferID{}, errors.Errorf( - fileSizeErr, len(fileData), FileMaxSize) + return nil, errors.Errorf(errFileSize, len(fileData), FileMaxSize) } // Return an error if the preview is too large if len(preview) > PreviewMaxSize { - return ftCrypto.TransferID{}, errors.Errorf( - previewSizeErr, len(preview), PreviewMaxSize) + return nil, errors.Errorf(errPreviewSize, len(preview), PreviewMaxSize) } - // Generate new transfer key + // Return an error if the network is not healthy + if !m.cmix.IsHealthy() { + return nil, errors.Errorf(errSendNetworkHealth, fileName) + } + + // Generate new transfer key and transfer ID rng := m.rng.GetStream() - transferKey, err := ftCrypto.NewTransferKey(rng) + key, err := ftCrypto.NewTransferKey(rng) if err != nil { rng.Close() - return ftCrypto.TransferID{}, err + return nil, errors.Errorf(errNewKey, err) } - rng.Close() - - // Get the size of each file part - partSize, err := m.getPartSize() + tid, err := ftCrypto.NewTransferID(rng) if err != nil { - return ftCrypto.TransferID{}, errors.Errorf(getPartSizeErr, err) + rng.Close() + return nil, errors.Errorf(errNewID, err) } + rng.Close() // Generate transfer MAC - mac := ftCrypto.CreateTransferMAC(fileData, transferKey) + mac := ftCrypto.CreateTransferMAC(fileData, key) - // Partition the file into parts - parts := partitionFile(fileData, partSize) + // Get size of each part and partition file into equal length parts + partMessage := fileMessage.NewPartMessage(m.cmix.GetMaxMessageLength()) + parts := partitionFile(fileData, partMessage.GetPartSize()) numParts := uint16(len(parts)) fileSize := uint32(len(fileData)) // Send the initial file transfer message over E2E - err = m.sendNewFileTransfer(recipient, fileName, fileType, transferKey, mac, + err = m.sendNewFileTransferMessage(recipient, fileName, fileType, &key, mac, numParts, fileSize, retry, preview) if err != nil { - return ftCrypto.TransferID{}, errors.Errorf(sendInitMsgErr, err) + return nil, errors.Errorf(errSendNewMsg, err) } // Calculate the number of fingerprints to generate - numFps := calcNumberOfFingerprints(numParts, retry) + numFps := calcNumberOfFingerprints(len(parts), retry) - // Add the transfer to storage - rng = m.rng.GetStream() - tid, err := m.sent.AddTransfer( - recipient, transferKey, parts, numFps, progressCB, period, rng) + // Create new sent transfer + st, err := m.sent.AddTransfer(recipient, &key, &tid, fileName, parts, numFps) if err != nil { - return ftCrypto.TransferID{}, err + return nil, errors.Errorf(errAddSentTransfer, err) } - rng.Close() - jww.DEBUG.Printf("[FT] Sending new file transfer %s to %s {name: %s, "+ - "type: %q, size: %d, parts: %d, numFps: %d, retry: %f}", - tid, recipient, fileName, fileType, fileSize, numParts, numFps, retry) + // Add all parts to the send queue + for _, p := range st.GetUnsentParts() { + m.batchQueue <- p + } - // Add all parts to queue - m.queueParts(tid, makeListOfPartNums(numParts)) + // Register the progress callback + m.registerSentProgressCallback(st, progressCB, period) - return tid, nil + return &tid, nil } -// RegisterSentProgressCallback adds the sent progress callback to the sent -// transfer so that it will be called when updates for the transfer occur. The -// progress callback is called when initially added and on transfer updates, at -// most once per period. -func (m Manager) RegisterSentProgressCallback(tid ftCrypto.TransferID, - progressCB interfaces.SentProgressCallback, period time.Duration) error { - // Get the transfer for the given ID - transfer, err := m.sent.GetTransfer(tid) - if err != nil { - return err +// RegisterSentProgressCallback adds the given callback to the callback manager +// for the given transfer ID. Returns an error if the transfer cannot be found. +func (m *manager) RegisterSentProgressCallback(tid *ftCrypto.TransferID, + progressCB SentProgressCallback, period time.Duration) error { + st, exists := m.sent.GetTransfer(tid) + if !exists { + return errors.Errorf(errNoSentTransfer, tid) } - // Add the progress callback - transfer.AddProgressCB(progressCB, period) + m.registerSentProgressCallback(st, progressCB, period) return nil } -// Resend resends a file if sending fails. Returns an error if CloseSend -// was already called or if the transfer did not run out of retries. This -// function should only be called if the interfaces.SentProgressCallback returns -// an error. -// TODO: Need to implement Resend but there are some unanswered questions. -// - Can you resend? -// - Can you reuse fingerprints? -// - What to do if sendE2E fails? -func (m Manager) Resend(tid ftCrypto.TransferID) error { - // Get the transfer for the given ID - transfer, err := m.sent.GetTransfer(tid) - if err != nil { - return err +// registerSentProgressCallback creates a callback for the sent transfer that +// will get the most recent progress and send it on the progress callback. +func (m *manager) registerSentProgressCallback(st *store.SentTransfer, + progressCB SentProgressCallback, period time.Duration) { + if progressCB == nil { + return } - // Check if the transfer has run out of fingerprints, which occurs when the - // retry limit is reached - if transfer.GetNumAvailableFps() > 0 { - return errors.Errorf(transferNotFailedErr, tid) + // Build callback + cb := func(err error) { + // Get transfer progress + arrived, total := st.NumArrived(), st.NumParts() + completed := arrived == total + + // If the transfer is completed, send last message informing recipient + if completed { + m.sendEndFileTransferMessage(st.Recipient()) + } + + // Build part tracker from copy of part statuses vector + tracker := &sentFilePartTracker{st.CopyPartStatusVector()} + + // Call the progress callback + progressCB(completed, arrived, total, tracker, err) } - return nil + // Add the callback to the callback tracker + m.callbacks.AddCallback(st.TransferID(), cb, period) } -// CloseSend deletes a sent file transfer from the sent transfer map and from -// storage once a transfer has completed or reached the retry limit. Returns an -// error if the transfer has not run out of retries. -func (m Manager) CloseSend(tid ftCrypto.TransferID) error { - // Get the transfer for the given ID - st, err := m.sent.GetTransfer(tid) +// CloseSend deletes the sent transfer from storage and the sent transfer list. +// Also stops any scheduled progress callbacks and deletes them from the manager +// to prevent any further calls. Deletion only occurs if the transfer has either +// completed or failed. +func (m *manager) CloseSend(tid *ftCrypto.TransferID) error { + st, exists := m.sent.GetTransfer(tid) + if !exists { + return errors.Errorf(errNoSentTransfer, tid) + } + + // Check that the transfer is either completed or failed + if st.Status() != store.Completed && st.Status() != store.Failed { + return errors.Errorf(errDeleteIncompleteTransfer, tid) + } + + // Delete from storage + err := st.Delete() if err != nil { - return err + return errors.Errorf(errDeleteSentTransfer, tid, err) } - // Check if the transfer has completed or run out of fingerprints, which - // occurs when the retry limit is reached - completed, _, _, _, _ := st.GetProgress() - if st.GetNumAvailableFps() > 0 && !completed { - return errors.Errorf(transferInProgressErr, tid) + // Delete from transfers list + err = m.sent.RemoveTransfer(tid) + if err != nil { + return errors.Errorf(errRemoveSentTransfer, tid, err) } - jww.DEBUG.Printf("[FT] Closing file transfer %s to %s {completed: %t, "+ - "parts: %d, numFps: %d/%d,}", tid, st.GetRecipient(), completed, - st.GetNumParts(), st.GetNumFps()-st.GetNumAvailableFps(), st.GetNumFps()) + // Stop and delete all progress callbacks + m.callbacks.Delete(tid) - // Delete the transfer from storage - return m.sent.DeleteTransfer(tid) + return nil } -// Receive returns the fully assembled file on the completion of the transfer. -// It deletes the transfer from the received transfer map and from storage. -// Returns an error if the transfer is not complete, the full file cannot be -// verified, or if the transfer cannot be found. -func (m Manager) Receive(tid ftCrypto.TransferID) ([]byte, error) { - // Get the transfer for the given ID - rt, err := m.received.GetTransfer(tid) +// Receive concatenates the received file and returns it. Only returns the file +// if all file parts have been received and returns an error otherwise. Also +// deletes the transfer from storage. Once Receive has been called on a file, it +// cannot be received again. +func (m *manager) Receive(tid *ftCrypto.TransferID) ([]byte, error) { + rt, exists := m.received.GetTransfer(tid) + if !exists { + return nil, errors.Errorf(errNoReceivedTransfer, tid) + } + + // Return an error if the transfer is not complete + if rt.NumReceived() != rt.NumParts() { + return nil, errors.Errorf( + errIncompleteFile, rt.NumParts()-rt.NumReceived(), rt.NumParts()) + } + + // Get the file + file := rt.GetFile() + + // Delete all unused fingerprints + for _, c := range rt.GetUnusedCyphers() { + m.cmix.DeleteFingerprint(m.myID, c.GetFingerprint()) + } + + // Delete from storage + err := rt.Delete() if err != nil { - return nil, err + return nil, errors.Errorf(errDeleteReceivedTransfer, tid, err) } - // Get the file from the transfer - file, err := rt.GetFile() + // Delete from transfers list + err = m.received.RemoveTransfer(tid) if err != nil { - return nil, err + return nil, errors.Errorf(errRemoveReceivedTransfer, tid, err) } - jww.DEBUG.Printf("[FT] Receiver completed transfer %s {size: %d, "+ - "parts: %d, numFps: %d/%d}", tid, rt.GetFileSize(), rt.GetNumParts(), - rt.GetNumFps()-rt.GetNumAvailableFps(), rt.GetNumFps()) + // Stop and delete all progress callbacks + m.callbacks.Delete(tid) - // Return the file and delete the transfer from storage - return file, m.received.DeleteTransfer(tid) + return file, nil } -// RegisterReceivedProgressCallback adds the reception progress callback to the -// received transfer so that it will be called when updates for the transfer -// occur. The progress callback is called when initially added and on transfer -// updates, at most once per period. -func (m Manager) RegisterReceivedProgressCallback(tid ftCrypto.TransferID, - progressCB interfaces.ReceivedProgressCallback, period time.Duration) error { - // Get the transfer for the given ID - transfer, err := m.received.GetTransfer(tid) - if err != nil { - return err +// RegisterReceivedProgressCallback adds the given callback to the callback +// manager for the given transfer ID. Returns an error if the transfer cannot be +// found. +func (m *manager) RegisterReceivedProgressCallback(tid *ftCrypto.TransferID, + progressCB ReceivedProgressCallback, period time.Duration) error { + rt, exists := m.received.GetTransfer(tid) + if !exists { + return errors.Errorf(errNoReceivedTransfer, tid) } - // Add the progress callback - transfer.AddProgressCB(progressCB, period) + m.registerReceivedProgressCallback(rt, progressCB, period) return nil } +// registerReceivedProgressCallback creates a callback for the received transfer +// that will get the most recent progress and send it on the progress callback. +func (m *manager) registerReceivedProgressCallback(rt *store.ReceivedTransfer, + progressCB ReceivedProgressCallback, period time.Duration) { + if progressCB == nil { + return + } + + // Build callback + cb := func(err error) { + // Get transfer progress + received, total := rt.NumReceived(), rt.NumParts() + completed := received == total + + // Build part tracker from copy of part statuses vector + tracker := &receivedFilePartTracker{rt.CopyPartStatusVector()} + + // Call the progress callback + progressCB(completed, received, total, tracker, err) + } + + // Add the callback to the callback tracker + m.callbacks.AddCallback(rt.TransferID(), cb, period) +} + +// partitionFile splits the file into parts of the specified part size. +func partitionFile(file []byte, partSize int) [][]byte { + // Initialize part list to the correct size + numParts := (len(file) + partSize - 1) / partSize + parts := make([][]byte, 0, numParts) + buff := bytes.NewBuffer(file) + + for n := buff.Next(partSize); len(n) > 0; n = buff.Next(partSize) { + newPart := make([]byte, partSize) + copy(newPart, n) + parts = append(parts, newPart) + } + + return parts +} + // calcNumberOfFingerprints is the formula used to calculate the number of // fingerprints to generate, which is based off the number of file parts and the // retry float. -func calcNumberOfFingerprints(numParts uint16, retry float32) uint16 { +func calcNumberOfFingerprints(numParts int, retry float32) uint16 { return uint16(float32(numParts) * (1 + retry)) } diff --git a/fileTransfer/manager_test.go b/fileTransfer/manager_test.go index 63952f154487962fc6e504c56d58f6d0d00cda02..c92e79481516b1966344ff9c3a29b3a7316cd145 100644 --- a/fileTransfer/manager_test.go +++ b/fileTransfer/manager_test.go @@ -9,574 +9,51 @@ package fileTransfer import ( "bytes" - "errors" - "fmt" - "github.com/cloudflare/circl/dh/sidh" - "github.com/golang/protobuf/proto" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" - ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" - util "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/catalog" + "gitlab.com/elixxir/client/e2e/receive" "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/crypto/diffieHellman" + "gitlab.com/elixxir/crypto/fastRNG" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" "gitlab.com/elixxir/ekv" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "math/rand" "reflect" - "strings" "sync" "testing" "time" ) -// Tests that newManager does not return errors, that the sent and received -// transfer lists are new, and that the callback works. -func Test_newManager(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - - cbChan := make(chan bool) - cb := func(ftCrypto.TransferID, string, string, *id.ID, uint32, []byte) { - cbChan <- true - } - - m, err := newManager(nil, nil, nil, nil, nil, nil, kv, cb, DefaultParams()) - if err != nil { - t.Errorf("newManager returned an error: %+v", err) - } - - // Check that the SentFileTransfersStore is new and correct - expectedSent, _ := ftStorage.NewSentFileTransfersStore(kv) - if !reflect.DeepEqual(expectedSent, m.sent) { - t.Errorf("SentFileTransfersStore in manager incorrect."+ - "\nexpected: %+v\nreceived: %+v", expectedSent, m.sent) - } - - // Check that the ReceivedFileTransfersStore is new and correct - expectedReceived, _ := ftStorage.NewReceivedFileTransfersStore(kv) - if !reflect.DeepEqual(expectedReceived, m.received) { - t.Errorf("ReceivedFileTransfersStore in manager incorrect."+ - "\nexpected: %+v\nreceived: %+v", expectedReceived, m.received) - } - - // Check that the callback is called - go m.receiveCB(ftCrypto.TransferID{}, "", "", nil, 0, nil) - select { - case <-cbChan: - case <-time.NewTimer(time.Millisecond).C: - t.Error("Timed out waiting for callback to be called") - } +// Tests that manager adheres to the FileTransfer interface. +func Test_manager_FileTransferInterface(t *testing.T) { + var _ FileTransfer = (*manager)(nil) } -// Tests that Manager.Send adds a new sent transfer, sends the NewFileTransfer -// E2E message, and adds all the file parts to the queue. -func TestManager_Send(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - fileName := "testFile" - fileType := "txt" - numParts := uint16(16) - partSize, _ := m.getPartSize() - fileData, _ := newFile(numParts, partSize, prng, t) - preview := []byte("filePreview") - retry := float32(1.5) - numFps := calcNumberOfFingerprints(numParts, retry) - - rng := csprng.NewSystemRNG() - dhKey := m.store.E2e().GetGroup().NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, m.store.E2e().GetGroup()) - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhB, rng) - p := params.GetDefaultE2ESessionParams() - - err := m.store.E2e().AddPartner(recipient, pubKey, dhKey, - mySidhPriv, theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) - } +// Tests that partitionFile partitions the given file into the expected parts. +func Test_partitionFile(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + partSize := 96 + fileData, expectedParts := newFile(24, partSize, prng, t) - tid, err := m.Send( - fileName, fileType, fileData, recipient, retry, preview, nil, 0) - if err != nil { - t.Errorf("Send returned an error: %+v", err) - } - - //// - // Check if the transfer exists - //// - transfer, err := m.sent.GetTransfer(tid) - if err != nil { - t.Errorf("Failed to get transfer %s: %+v", tid, err) - } + receivedParts := partitionFile(fileData, partSize) - if !recipient.Cmp(transfer.GetRecipient()) { - t.Errorf("New transfer has incorrect recipient."+ - "\nexpected: %s\nreceoved: %s", recipient, transfer.GetRecipient()) - } - if transfer.GetNumParts() != numParts { - t.Errorf("New transfer has incorrect number of parts."+ - "\nexpected: %d\nreceived: %d", numParts, transfer.GetNumParts()) - } - if transfer.GetNumFps() != numFps { - t.Errorf("New transfer has incorrect number of fingerprints."+ - "\nexpected: %d\nreceived: %d", numFps, transfer.GetNumFps()) + if !reflect.DeepEqual(expectedParts, receivedParts) { + t.Errorf("File parts do not match expected."+ + "\nexpected: %q\nreceived: %q", expectedParts, receivedParts) } - //// - // Get NewFileTransfer E2E message - //// - sendMsg := m.net.(*testNetworkManager).GetE2eMsg(0) - if sendMsg.MessageType != message.NewFileTransfer { - t.Errorf("E2E message has wrong MessageType.\nexpected: %d\nreceived: %d", - message.NewFileTransfer, sendMsg.MessageType) - } - if !sendMsg.Recipient.Cmp(recipient) { - t.Errorf("E2E message has wrong Recipient.\nexpected: %s\nreceived: %s", - recipient, sendMsg.Recipient) - } - receivedNFT := &NewFileTransfer{} - err = proto.Unmarshal(sendMsg.Payload, receivedNFT) - if err != nil { - t.Errorf("Failed to unmarshal received NewFileTransfer: %+v", err) - } - expectedNFT := &NewFileTransfer{ - FileName: fileName, - FileType: fileType, - TransferKey: transfer.GetTransferKey().Bytes(), - TransferMac: ftCrypto.CreateTransferMAC(fileData, transfer.GetTransferKey()), - NumParts: uint32(numParts), - Size: uint32(len(fileData)), - Retry: retry, - Preview: preview, - } - if !proto.Equal(expectedNFT, receivedNFT) { - t.Errorf("Received NewFileTransfer message does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedNFT, receivedNFT) - } - - //// - // Check queued parts - //// - if len(m.sendQueue) != int(numParts) { - t.Errorf("Failed to add all file parts to queue."+ - "\nexpected: %d\nreceived: %d", numParts, len(m.sendQueue)) - } -} - -// Error path: tests that Manager.Send returns the expected error when the -// network is not healthy. -func TestManager_Send_NetworkHealthError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - fileName := "MySentFile" - expectedErr := fmt.Sprintf(sendNetworkHealthErr, fileName) - - m.net.(*testNetworkManager).health.healthy = false - - recipient := id.NewIdFromString("recipient", id.User, t) - _, err := m.Send(fileName, "", nil, recipient, 0, nil, nil, 0) - if err == nil || err.Error() != expectedErr { - t.Errorf("Send did not return the expected error when the network is "+ - "not healthy.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that Manager.Send returns the expected error when the -// provided file name is longer than FileNameMaxLen. -func TestManager_Send_FileNameLengthError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - fileName := strings.Repeat("A", FileNameMaxLen+1) - expectedErr := fmt.Sprintf(fileNameSizeErr, len(fileName), FileNameMaxLen) - - _, err := m.Send(fileName, "", nil, nil, 0, nil, nil, 0) - if err == nil || err.Error() != expectedErr { - t.Errorf("Send did not return the expected error when the file name "+ - "is too long.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that Manager.Send returns the expected error when the -// provided file type is longer than FileTypeMaxLen. -func TestManager_Send_FileTypeLengthError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - fileType := strings.Repeat("A", FileTypeMaxLen+1) - expectedErr := fmt.Sprintf(fileTypeSizeErr, len(fileType), FileTypeMaxLen) - - _, err := m.Send("", fileType, nil, nil, 0, nil, nil, 0) - if err == nil || err.Error() != expectedErr { - t.Errorf("Send did not return the expected error when the file type "+ - "is too long.\nexpected: %s\nreceived: %+v", expectedErr, err) + fullFile := bytes.Join(receivedParts, nil) + if !bytes.Equal(fileData, fullFile) { + t.Errorf("Full file does not match expected."+ + "\nexpected: %q\nreceived: %q", fileData, fullFile) } } -// Error path: tests that Manager.Send returns the expected error when the -// provided file is larger than FileMaxSize. -func TestManager_Send_FileSizeError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - fileData := make([]byte, FileMaxSize+1) - expectedErr := fmt.Sprintf(fileSizeErr, len(fileData), FileMaxSize) - - _, err := m.Send("", "", fileData, nil, 0, nil, nil, 0) - if err == nil || err.Error() != expectedErr { - t.Errorf("Send did not return the expected error when the file data "+ - "is too large.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that Manager.Send returns the expected error when the -// provided preview is larger than PreviewMaxSize. -func TestManager_Send_PreviewSizeError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - previewData := make([]byte, PreviewMaxSize+1) - expectedErr := fmt.Sprintf(previewSizeErr, len(previewData), PreviewMaxSize) - - _, err := m.Send("", "", nil, nil, 0, previewData, nil, 0) - if err == nil || err.Error() != expectedErr { - t.Errorf("Send did not return the expected error when the preview "+ - "data is too large.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that Manager.Send returns the expected error when the E2E -// message fails to send. -func TestManager_Send_SendE2eError(t *testing.T) { - m := newTestManager(true, nil, nil, nil, nil, t) - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - fileName := "testFile" - fileType := "bytes" - numParts := uint16(16) - partSize, _ := m.getPartSize() - fileData, _ := newFile(numParts, partSize, prng, t) - preview := []byte("filePreview") - retry := float32(1.5) - - rng := csprng.NewSystemRNG() - dhKey := m.store.E2e().GetGroup().NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, m.store.E2e().GetGroup()) - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhB, rng) - p := params.GetDefaultE2ESessionParams() - - err := m.store.E2e().AddPartner(recipient, pubKey, dhKey, - mySidhPriv, theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) - } - - expectedErr := fmt.Sprintf(newFtSendE2eErr, recipient, "") - - _, err = m.Send( - fileName, fileType, fileData, recipient, retry, preview, nil, 0) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("Send did not return the expected error when the E2E message "+ - "failed to send.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that Manager.RegisterSentProgressCallback calls the callback when it is -// added to the transfer and that the callback is associated with the expected -// transfer and is called when calling from the transfer. -func TestManager_RegisterSentProgressCallback(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, false, nil, nil, nil, t) - expectedErr := errors.New("CallbackError") - - // Create new callback and channel for the callback to trigger - cbChan := make(chan sentProgressResults, 6) - cb := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- sentProgressResults{completed, sent, arrived, total, tr, err} - } - - // Start thread waiting for callback to be called - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(20 * time.Millisecond).C: - t.Errorf("Timed out waiting for callback call #%d.", i) - case r := <-cbChan: - switch i { - case 0: - err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, 0, 0, sti[0].numParts) - if err != nil { - t.Errorf("%d: %+v", i, err) - } - if r.err != nil { - t.Errorf("Callback returned an error (%d): %+v", i, r.err) - } - done0 <- true - case 1: - if r.err == nil || r.err != expectedErr { - t.Errorf("Callback did not return the expected error (%d)."+ - "\nexpected: %v\nreceived: %+v", i, expectedErr, r.err) - } - done1 <- true - } - } - } - }() - - err := m.RegisterSentProgressCallback(sti[0].tid, cb, 1*time.Millisecond) - if err != nil { - t.Errorf("RegisterSentProgressCallback returned an error: %+v", err) - } - <-done0 - - transfer, _ := m.sent.GetTransfer(sti[0].tid) - - transfer.CallProgressCB(expectedErr) - - <-done1 -} - -// Error path: tests that Manager.RegisterSentProgressCallback returns an error -// when no transfer with the ID exists. -func TestManager_RegisterSentProgressCallback_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - - err := m.RegisterSentProgressCallback(tid, nil, 0) - if err == nil { - t.Error("RegisterSentProgressCallback did not return an error when " + - "no transfer with the ID exists.") - } -} - -func TestManager_Resend(t *testing.T) { - -} - -// Error path: tests that Manager.Resend returns an error when no transfer with -// the ID exists. -func TestManager_Resend_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - - err := m.Resend(tid) - if err == nil { - t.Error("Resend did not return an error when no transfer with the " + - "ID exists.") - } -} - -// Error path: tests that Manager.Resend returns the error when the transfer has -// not run out of fingerprints. -func TestManager_Resend_NoFingerprints(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{16}, false, false, nil, nil, nil, t) - expectedErr := fmt.Sprintf(transferNotFailedErr, sti[0].tid) - // Delete the transfer - err := m.Resend(sti[0].tid) - if err == nil || err.Error() != expectedErr { - t.Errorf("Resend did not return the expected error when the transfer "+ - "has not run out of fingerprints.\nexpected: %s\nreceived: %v", - expectedErr, err) - } -} - -// Tests that Manager.CloseSend deletes the transfer when it has run out of -// fingerprints but is not complete. -func TestManager_CloseSend_NoFingerprints(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{16}, false, false, nil, nil, nil, t) - partSize, _ := m.getPartSize() - - // Use up all the fingerprints in the transfer - transfer, _ := m.sent.GetTransfer(sti[0].tid) - for fpNum := uint16(0); fpNum < sti[0].numFps; fpNum++ { - partNum := fpNum % sti[0].numParts - _, _, _, err := transfer.GetEncryptedPart(partNum, partSize+2) - if err != nil { - t.Errorf("Failed to encrypt part %d (%d): %+v", partNum, fpNum, err) - } - } - - // Delete the transfer - err := m.CloseSend(sti[0].tid) - if err != nil { - t.Errorf("CloseSend returned an error: %+v", err) - } - - // Check that the transfer was deleted - _, err = m.sent.GetTransfer(sti[0].tid) - if err == nil { - t.Errorf("Failed to delete transfer %s.", sti[0].tid) - } -} - -// Tests that Manager.CloseSend deletes the transfer when it completed but has -// fingerprints. -func TestManager_CloseSend_Complete(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{3}, false, false, nil, nil, nil, t) - - // Set all parts to finished - transfer, _ := m.sent.GetTransfer(sti[0].tid) - _, err := transfer.SetInProgress(0, 0, 1, 2) - if err != nil { - t.Errorf("Failed to set parts to in-progress: %+v", err) - } - complete, err := transfer.FinishTransfer(0) - if err != nil { - t.Errorf("Failed to set parts to finished: %+v", err) - } - - // Ensure that FinishTransfer reported the transfer as complete - if !complete { - t.Error("FinishTransfer did not report the transfer as complete.") - } - - // Delete the transfer - err = m.CloseSend(sti[0].tid) - if err != nil { - t.Errorf("CloseSend returned an error: %+v", err) - } - - // Check that the transfer was deleted - _, err = m.sent.GetTransfer(sti[0].tid) - if err == nil { - t.Errorf("Failed to delete transfer %s.", sti[0].tid) - } -} - -// Error path: tests that Manager.CloseSend returns an error when no transfer -// with the ID exists. -func TestManager_CloseSend_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - - err := m.CloseSend(tid) - if err == nil { - t.Error("CloseSend did not return an error when no transfer with the " + - "ID exists.") - } -} - -// Error path: tests that Manager.CloseSend returns an error when the transfer -// has not run out of fingerprints and is not complete -func TestManager_CloseSend_NotCompleteErr(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{16}, false, false, nil, nil, nil, t) - expectedErr := fmt.Sprintf(transferInProgressErr, sti[0].tid) - - err := m.CloseSend(sti[0].tid) - if err == nil || err.Error() != expectedErr { - t.Errorf("CloseSend did not return the expected error when the transfer"+ - "is not complete.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that Manager.Receive returns an error when no transfer with -// the ID exists. -func TestManager_Receive_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - - _, err := m.Receive(tid) - if err == nil { - t.Error("Receive did not return an error when no transfer with the ID " + - "exists.") - } -} - -// Error path: tests that Manager.Receive returns an error when the file is -// incomplete. -func TestManager_Receive_GetFileError(t *testing.T) { - m, _, rti := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, false, nil, nil, nil, t) - - _, err := m.Receive(rti[0].tid) - if err == nil || !strings.Contains(err.Error(), "missing") { - t.Error("Receive did not return the expected error when no transfer " + - "with the ID exists.") - } -} - -// Tests that Manager.RegisterReceivedProgressCallback calls the callback when -// it is added to the transfer and that the callback is associated with the -// expected transfer and is called when calling from the transfer. -func TestManager_RegisterReceivedProgressCallback(t *testing.T) { - m, _, rti := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, false, nil, nil, nil, t) - expectedErr := errors.New("CallbackError") - - // Create new callback and channel for the callback to trigger - cbChan := make(chan receivedProgressResults, 6) - cb := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- receivedProgressResults{completed, received, total, tr, err} - } - - // Start thread waiting for callback to be called - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(20 * time.Millisecond).C: - t.Errorf("Timed out waiting for callback call #%d.", i) - case r := <-cbChan: - switch i { - case 0: - err := checkReceivedProgress(r.completed, r.received, - r.total, false, 0, rti[0].numParts) - if err != nil { - t.Errorf("%d: %+v", i, err) - } - if r.err != nil { - t.Errorf("Callback returned an error (%d): %+v", i, r.err) - } - done0 <- true - case 1: - if r.err == nil || r.err != expectedErr { - t.Errorf("Callback did not return the expected error (%d)."+ - "\nexpected: %v\nreceived: %+v", i, expectedErr, r.err) - } - done1 <- true - } - } - } - }() - - err := m.RegisterReceivedProgressCallback(rti[0].tid, cb, time.Millisecond) - if err != nil { - t.Errorf("RegisterReceivedProgressCallback returned an error: %+v", err) - } - <-done0 - - transfer, _ := m.received.GetTransfer(rti[0].tid) - - transfer.CallProgressCB(expectedErr) - - <-done1 -} - -// Error path: tests that Manager.RegisterReceivedProgressCallback returns an -// error when no transfer with the ID exists. -func TestManager_RegisterReceivedProgressCallback_NoTransferError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidID")) - - err := m.RegisterReceivedProgressCallback(tid, nil, 0) - if err == nil { - t.Error("RegisterReceivedProgressCallback did not return an error " + - "when no transfer with the ID exists.") - } -} - -// Tests that calcNumberOfFingerprints matches some manually calculated -// results. +// Tests that calcNumberOfFingerprints matches some manually calculated results. func Test_calcNumberOfFingerprints(t *testing.T) { testValues := []struct { - numParts uint16 + numParts int retry float32 result uint16 }{ @@ -598,204 +75,175 @@ func Test_calcNumberOfFingerprints(t *testing.T) { } } -// Tests that Manager satisfies the interfaces.FileTransfer interface. -func TestManager_FileTransferInterface(t *testing.T) { - var _ interfaces.FileTransfer = Manager{} -} +// Smoke test of the entire file transfer system. +func Test_FileTransfer_Smoke(t *testing.T) { + // Set up cMix and E2E message handlers + cMixHandler := newMockCmixHandler() + e2eHandler := newMockE2eHandler() + rngGen := fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG) + params := DefaultParams() -// Sets up a mock file transfer with two managers that sends one file from one -// to another. -func Test_FileTransfer(t *testing.T) { - var wg sync.WaitGroup + type receiveCbValues struct { + tid *ftCrypto.TransferID + fileName string + fileType string + sender *id.ID + size uint32 + preview []byte + } - // Create callback with channel for receiving new file transfer - receiveNewCbChan := make(chan receivedFtResults, 100) - receiveNewCB := func(tid ftCrypto.TransferID, fileName, fileType string, + // Set up the first client + receiveCbChan1 := make(chan receiveCbValues, 10) + receiveCB1 := func(tid *ftCrypto.TransferID, fileName, fileType string, sender *id.ID, size uint32, preview []byte) { - receiveNewCbChan <- receivedFtResults{ + receiveCbChan1 <- receiveCbValues{ tid, fileName, fileType, sender, size, preview} } - - // Create reception channels for both managers - newFtChan1 := make(chan message.Receive, rawMessageBuffSize) - filePartChan1 := make(chan message.Receive, rawMessageBuffSize) - newFtChan2 := make(chan message.Receive, rawMessageBuffSize) - filePartChan2 := make(chan message.Receive, rawMessageBuffSize) - - // Generate sending and receiving managers - m1 := newTestManager(false, filePartChan2, newFtChan2, nil, nil, t) - m2 := newTestManager(false, filePartChan1, newFtChan1, receiveNewCB, nil, t) - - // Add partner - dhKey := m1.store.E2e().GetGroup().NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, m1.store.E2e().GetGroup()) - p := params.GetDefaultE2ESessionParams() - recipient := id.NewIdFromString("recipient", id.User, t) - - rng := csprng.NewSystemRNG() - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, - rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair( - sidh.KeyVariantSidhB, rng) - - err := m1.store.E2e().AddPartner(recipient, pubKey, dhKey, - mySidhPriv, theirSidhPub, p, p) + myID1 := id.NewIdFromString("myID1", id.User, t) + kv1 := versioned.NewKV(make(ekv.Memstore)) + endE2eChan1 := make(chan receive.Message, 3) + e2e1 := newMockE2e(myID1, e2eHandler) + e2e1.RegisterListener( + myID1, catalog.EndFileTransfer, newMockListener(endE2eChan1)) + ftm1, err := NewManager(receiveCB1, params, myID1, + newMockCmix(myID1, cMixHandler), e2e1, kv1, rngGen) if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) + t.Errorf("Failed to create new file transfer manager 1: %+v", err) } + m1 := ftm1.(*manager) - stop1, err := m1.startProcesses(newFtChan1, filePartChan1) + stop1, err := m1.StartProcesses() if err != nil { - t.Errorf("Failed to start processes for sending manager: %+v", err) + t.Errorf("Failed to start processes for manager 1: %+v", err) } - stop2, err := m2.startProcesses(newFtChan2, filePartChan2) - if err != nil { - t.Errorf("Failed to start processes for receving manager: %+v", err) + // Set up the second client + receiveCbChan2 := make(chan receiveCbValues, 10) + receiveCB2 := func(tid *ftCrypto.TransferID, fileName, fileType string, + sender *id.ID, size uint32, preview []byte) { + receiveCbChan2 <- receiveCbValues{ + tid, fileName, fileType, sender, size, preview} } - - // Create progress tracker for sending - sentCbChan := make(chan sentProgressResults, 20) - sentCb := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - sentCbChan <- sentProgressResults{completed, sent, arrived, total, tr, err} + myID2 := id.NewIdFromString("myID2", id.User, t) + kv2 := versioned.NewKV(make(ekv.Memstore)) + endE2eChan2 := make(chan receive.Message, 3) + e2e2 := newMockE2e(myID1, e2eHandler) + e2e2.RegisterListener( + myID2, catalog.EndFileTransfer, newMockListener(endE2eChan2)) + ftm2, err := NewManager(receiveCB2, params, myID2, + newMockCmix(myID2, cMixHandler), e2e2, kv2, rngGen) + if err != nil { + t.Errorf("Failed to create new file transfer manager 2: %+v", err) } + m2 := ftm2.(*manager) - // Start threads that tracks sent progress until complete - wg.Add(1) - go func() { - defer wg.Done() - for i := 0; i < 20; i++ { - select { - case <-time.NewTimer(250 * time.Millisecond).C: - t.Errorf("Timed out waiting for sent progress callback %d.", i) - case r := <-sentCbChan: - if r.completed { - return - } - } - } - t.Error("Sent progress callback never reported file finishing to send.") - }() - - // Create file and parameters - prng := NewPrng(42) - partSize, _ := m1.getPartSize() - fileName := "testFile" - fileType := "file" - file, parts := newFile(32, partSize, prng, t) - preview := parts[0] - - // Send file - sendTid, err := m1.Send(fileName, fileType, file, recipient, 0.5, preview, - sentCb, time.Millisecond) + stop2, err := m2.StartProcesses() if err != nil { - t.Errorf("Send returned an error: %+v", err) + t.Errorf("Failed to start processes for manager 2: %+v", err) } - // Wait for the receiving manager to get E2E message and call callback to - // get transfer ID of received transfer - var receiveTid ftCrypto.TransferID - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for receive callback: ") - case r := <-receiveNewCbChan: - if !bytes.Equal(preview, r.preview) { - t.Errorf("File preview received from callback incorrect."+ - "\nexpected: %q\nreceived: %q", preview, r.preview) - } - if len(file) != int(r.size) { - t.Errorf("File size received from callback incorrect."+ - "\nexpected: %d\nreceived: %d", len(file), r.size) - } - receiveTid = r.tid - } + // Wait group prevents the test from quiting before the file has completed + // sending and receiving + var wg sync.WaitGroup - // Register progress callback with receiving manager - receiveCbChan := make(chan receivedProgressResults, 100) - receiveCb := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - receiveCbChan <- receivedProgressResults{completed, received, total, tr, err} - } + // Define details of file to send + fileName, fileType := "myFile", "txt" + fileData := []byte(loreumIpsum) + preview := []byte("Lorem ipsum dolor sit amet") + retry := float32(2.0) - // Start threads that tracks received progress until complete + // Create go func that waits for file transfer to be received to register + // a progress callback that then checks that the file received is correct + // when done wg.Add(1) + var called bool + timeReceived := make(chan time.Time) go func() { - defer wg.Done() - for i := 0; i < 20; i++ { - select { - case <-time.NewTimer(450 * time.Millisecond).C: - t.Errorf("Timed out waiting for receive progress callback %d.", i) - case r := <-receiveCbChan: - if r.completed { - // Count the number of parts marked as received - count := 0 - for j := uint16(0); j < r.total; j++ { - if r.tracker.GetPartStatus(j) == interfaces.FpReceived { - count++ - } + select { + case r := <-receiveCbChan2: + receiveProgressCB := func(completed bool, received, total uint16, + fpt FilePartTracker, err error) { + if completed && !called { + timeReceived <- netTime.Now() + receivedFile, err2 := m2.Receive(r.tid) + if err2 != nil { + t.Errorf("Failed to receive file: %+v", err2) } - // Ensure that the number of parts received reported by the - // callback matches the number marked received - if count != int(r.received) { - t.Errorf("Number of parts marked received does not "+ - "match number reported by callback."+ - "\nmarked: %d\ncallback: %d", count, r.received) + if !bytes.Equal(fileData, receivedFile) { + t.Errorf("Received file does not match sent."+ + "\nsent: %q\nreceived: %q", + fileData, receivedFile) } - - return + wg.Done() } } + err3 := m2.RegisterReceivedProgressCallback( + r.tid, receiveProgressCB, 0) + if err3 != nil { + t.Errorf( + "Failed to Rregister received progress callback: %+v", err3) + } + case <-time.After(2100 * time.Millisecond): + t.Errorf("Timed out waiting to receive new file transfer.") + wg.Done() } - t.Error("Receive progress callback never reported file finishing to receive.") }() - err = m2.RegisterReceivedProgressCallback( - receiveTid, receiveCb, time.Millisecond) + // Define sent progress callback + wg.Add(1) + sentProgressCb1 := func(completed bool, arrived, total uint16, + fpt FilePartTracker, err error) { + if completed { + wg.Done() + } + } + + // Send file. + sendStart := netTime.Now() + tid1, err := m1.Send( + fileName, fileType, fileData, myID2, retry, preview, sentProgressCb1, 0) if err != nil { - t.Errorf("Failed to register receive progress callback: %+v", err) + t.Errorf("Failed to send file: %+v", err) } + go func() { + select { + case tr := <-timeReceived: + fileSize := len(fileData) + sendTime := tr.Sub(sendStart) + fileSizeKb := float32(fileSize) * .001 + speed := fileSizeKb * float32(time.Second) / (float32(sendTime)) + t.Logf("Completed sending file %q in %s (%.2f kb @ %.2f kb/s).", + fileName, sendTime, fileSizeKb, speed) + } + }() + + // Wait for file to be sent and received wg.Wait() - // Check that the file can be received - receivedFile, err := m2.Receive(receiveTid) + err = m1.CloseSend(tid1) if err != nil { - t.Errorf("Failed to receive file: %+v", err) - } - - // Check that the received file matches the sent file - if !bytes.Equal(file, receivedFile) { - t.Errorf("Received file does not match sent."+ - "\nexpected: %q\nrecevied: %q", file, receivedFile) + t.Errorf("Failed to close transfer: %+v", err) } - // Check that the received transfer was deleted - _, err = m2.received.GetTransfer(receiveTid) - if err == nil { - t.Error("Failed to delete received file transfer once file has been " + - "received.") + err = stop1.Close() + if err != nil { + t.Errorf("Failed to close processes for manager 1: %+v", err) } - // Close the transfer on the sending manager - err = m1.CloseSend(sendTid) + err = stop2.Close() if err != nil { - t.Errorf("Failed to close the send: %+v", err) + t.Errorf("Failed to close processes for manager 2: %+v", err) } +} - // Check that the sent transfer was deleted - _, err = m1.sent.GetTransfer(sendTid) - if err == nil { - t.Error("Failed to delete sent file transfer once file has been sent " + - "and closed.") - } +const loreumIpsum = `Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed sit amet urna venenatis, rutrum magna maximus, tempor orci. Cras sit amet nulla id dolor blandit commodo. Suspendisse potenti. Praesent gravida porttitor metus vel aliquam. Maecenas rutrum velit at lobortis auctor. Mauris porta blandit tempor. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Morbi volutpat posuere maximus. Nunc in augue molestie ante mattis tempor. - if err = stop1.Close(); err != nil { - t.Errorf("Failed to close sending manager threads: %+v", err) - } +Phasellus placerat elit eu fringilla pharetra. Vestibulum consectetur pulvinar nunc, vestibulum tincidunt felis rhoncus sit amet. Duis non dolor eleifend nibh luctus eleifend. Nunc urna odio, euismod sit amet feugiat ut, dapibus vel elit. Nulla est mauris, posuere eget enim cursus, vehicula viverra est. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque mattis, nisi quis consectetur semper, neque enim rhoncus dolor, ut aliquam leo orci sed dolor. Integer ullamcorper pulvinar turpis, a sollicitudin nunc posuere et. Nullam orci nibh, facilisis ac massa eu, bibendum bibendum sapien. Sed tincidunt nunc mauris, nec ullamcorper enim lacinia nec. Nulla dapibus sapien ut odio bibendum, tempus ornare sapien lacinia. - if err = stop2.Close(); err != nil { - t.Errorf("Failed to close receiving manager threads: %+v", err) - } -} +Duis ac hendrerit augue. Nullam porttitor feugiat finibus. Nam enim urna, maximus et ligula eu, aliquet convallis turpis. Vestibulum luctus quam in dictum efficitur. Vestibulum ac pulvinar ipsum. Vivamus consectetur augue nec tellus mollis, at iaculis magna efficitur. Nunc dictum convallis sem, at vehicula nulla accumsan non. Nullam blandit orci vel turpis convallis, mollis porttitor felis accumsan. Sed non posuere leo. Proin ultricies varius nulla at ultricies. Phasellus et pharetra justo. Quisque eu orci odio. Pellentesque pharetra tempor tempor. Aliquam ac nulla lorem. Sed dignissim ligula sit amet nibh fermentum facilisis. + +Donec facilisis rhoncus ante. Duis nec nisi et dolor congue semper vel id ligula. Mauris non eleifend libero, et sodales urna. Nullam pharetra gravida velit non mollis. Integer vel ultrices libero, at ultrices magna. Duis semper risus a leo vulputate consectetur. Cras sit amet convallis sapien. Sed blandit, felis et porttitor fringilla, urna tellus commodo metus, at pharetra nibh urna sed sem. Nam ex dui, posuere id mi et, egestas tincidunt est. Nullam elementum pulvinar diam in maximus. Maecenas vel augue vitae nunc consectetur vestibulum in aliquet lacus. Nullam nec lectus dapibus, dictum nisi nec, congue quam. Suspendisse mollis vel diam nec dapibus. Mauris neque justo, scelerisque et suscipit non, imperdiet eget leo. Vestibulum leo turpis, dapibus ac lorem a, mollis pulvinar quam. + +Sed sed mauris a neque dignissim aliquet. Aliquam congue gravida velit in efficitur. Integer elementum feugiat est, ac lacinia libero bibendum sed. Sed vestibulum suscipit dignissim. Nunc scelerisque, turpis quis varius tristique, enim lacus vehicula lacus, id vestibulum velit erat eu odio. Donec tincidunt nunc sit amet sapien varius ornare. Phasellus semper venenatis ligula eget euismod. Mauris sodales massa tempor, cursus velit a, feugiat neque. Sed odio justo, rhoncus eu fermentum non, tristique a quam. In vehicula in tortor nec iaculis. Cras ligula sem, sollicitudin at nulla eget, placerat lacinia massa. Mauris tempus quam sit amet leo efficitur egestas. Proin iaculis, velit in blandit egestas, felis odio sollicitudin ipsum, eget interdum leo odio tempor nisi. Curabitur sed mauris id turpis tempor finibus ut mollis lectus. Curabitur neque libero, aliquam facilisis lobortis eget, posuere in augue. In sodales urna sit amet elit euismod rhoncus.` diff --git a/fileTransfer/oldTransferRecovery.go b/fileTransfer/oldTransferRecovery.go deleted file mode 100644 index f47e9ba0d004272b8605cf6b5a23332df9e3c67d..0000000000000000000000000000000000000000 --- a/fileTransfer/oldTransferRecovery.go +++ /dev/null @@ -1,135 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/primitives/id" - "sync/atomic" -) - -// Error messages. -const ( - oldTransfersRoundResultsErr = "[FT] failed to recover round information " + - "for %d rounds for old file transfers after %d attempts" -) - -// roundResultsMaxAttempts is the maximum number of attempts to get round -// results via api.RoundEventCallback before stopping to try -const roundResultsMaxAttempts = 5 - -// oldTransferRecovery adds all unsent file parts back into the queue and -// updates the in-progress file parts by getting round updates. -func (m Manager) oldTransferRecovery(healthyChan chan bool, chanID uint64) { - - // Exit if old transfers have already been recovered - // TODO: move GetUnsentPartsAndSentRounds to manager creation and remove the - // atomic - if !atomic.CompareAndSwapUint32(m.oldTransfersRecovered, 0, 1) { - jww.DEBUG.Printf("[FT] Old file transfer recovery thread not " + - "starting: none to recover (app was not closed)") - return - } - - // Get list of unsent parts and rounds that parts were sent on - unsentParts, sentRounds, err := m.sent.GetUnsentPartsAndSentRounds() - - jww.DEBUG.Printf("[FT] Adding unsent parts from %d recovered transfers: %v", - len(unsentParts), unsentParts) - - // Add all unsent parts to the queue - for tid, partNums := range unsentParts { - m.queueParts(tid, partNums) - } - - if err != nil { - jww.ERROR.Printf("[FT] Failed to get sent rounds: %+v", err) - m.net.GetHealthTracker().RemoveChannel(chanID) - return - } - - // Return if there are no parts to recover - if len(sentRounds) == 0 { - jww.DEBUG.Print( - "[FT] No in-progress rounds from old transfers to recover.") - return - } - - // Update parts that were sent by looking up the status of the rounds they - // were sent on - go func(healthyChan chan bool, chanID uint64, - sentRounds map[id.Round][]ftCrypto.TransferID) { - err := m.updateSentRounds(healthyChan, sentRounds) - if err != nil { - jww.ERROR.Print(err) - } - - // Remove channel from tacker once done with it - m.net.GetHealthTracker().RemoveChannel(chanID) - }(healthyChan, chanID, sentRounds) -} - -// updateSentRounds looks up the status of each round that parts were sent on -// but never arrived. It updates the status of each part depending on if the -// round failed or succeeded. -func (m Manager) updateSentRounds(healthyChan chan bool, - sentRounds map[id.Round][]ftCrypto.TransferID) error { - // Tracks the number of attempts to get round results - var getRoundResultsAttempts int - - jww.DEBUG.Print("[FT] Starting old file transfer recovery thread.") - - // Wait for network to be healthy to attempt to get round states - for getRoundResultsAttempts < roundResultsMaxAttempts { - select { - case healthy := <-healthyChan: - // If the network is unhealthy, wait until it becomes healthy - if !healthy { - jww.DEBUG.Print("[FT] Suspending old file transfer recovery " + - "thread: network is unhealthy.") - } - for !healthy { - healthy = <-healthyChan - } - jww.DEBUG.Print("[FT] Old file transfer recovery thread: " + - "network is healthy.") - - // Register callback to get Round results and retry on error - roundList := roundIdMapToList(sentRounds) - err := m.getRoundResults(roundList, roundResultsTimeout, - m.makeRoundEventCallback(sentRounds)) - if err != nil { - jww.WARN.Printf("[FT] Failed to get round results for old "+ - "transfers for rounds %d (attempt %d/%d): %+v", - getRoundResultsAttempts, roundResultsMaxAttempts, - roundList, err) - } else { - jww.INFO.Printf( - "[FT] Successfully recovered old file transfers: %v", - sentRounds) - - return nil - } - getRoundResultsAttempts++ - } - } - - return errors.Errorf( - oldTransfersRoundResultsErr, len(sentRounds), getRoundResultsAttempts) -} - -// roundIdMapToList returns a list of all round IDs in the map. -func roundIdMapToList(roundMap map[id.Round][]ftCrypto.TransferID) []id.Round { - roundSlice := make([]id.Round, 0, len(roundMap)) - for rid := range roundMap { - roundSlice = append(roundSlice, rid) - } - return roundSlice -} diff --git a/fileTransfer/oldTransferRecovery_test.go b/fileTransfer/oldTransferRecovery_test.go deleted file mode 100644 index dcc51eb6c0b6b6edf5964005337582e3bfe9ca92..0000000000000000000000000000000000000000 --- a/fileTransfer/oldTransferRecovery_test.go +++ /dev/null @@ -1,382 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "fmt" - "github.com/pkg/errors" - "gitlab.com/elixxir/client/cmix" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" - "gitlab.com/xx_network/primitives/id" - "math/rand" - "reflect" - "sort" - "sync" - "sync/atomic" - "testing" - "time" -) - -// Tests that Manager.oldTransferRecovery adds all unsent parts to the queue. -func TestManager_oldTransferRecovery(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - m, sti, _ := newTestManagerWithTransfers( - []uint16{6, 12, 18}, false, true, nil, nil, kv, t) - - finishedRounds := make(map[id.Round][]ftCrypto.TransferID) - expectedStatus := make( - map[ftCrypto.TransferID]map[uint16]interfaces.FpStatus, len(sti)) - numCbCalls := make(map[ftCrypto.TransferID]int, len(sti)) - var numUnsent int - - for i, st := range sti { - transfer, err := m.sent.GetTransfer(st.tid) - if err != nil { - t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) - } - - expectedStatus[st.tid] = make(map[uint16]interfaces.FpStatus, st.numParts) - - // Loop through each part and set it individually - for j, k := uint16(0), 0; j < transfer.GetNumParts(); j++ { - rid := id.Round(j) - switch j % 3 { - case 0: - // Part is sent (in-progress) - _, _ = transfer.SetInProgress(rid, j) - if k%2 == 0 { - finishedRounds[rid] = append(finishedRounds[rid], st.tid) - expectedStatus[st.tid][j] = 2 - } else { - expectedStatus[st.tid][j] = 0 - numUnsent++ - } - numCbCalls[st.tid]++ - k++ - case 1: - // Part is sent and arrived (finished) - _, _ = transfer.SetInProgress(rid, j) - _, _ = transfer.FinishTransfer(rid) - finishedRounds[rid] = append(finishedRounds[rid], st.tid) - expectedStatus[st.tid][j] = 2 - case 2: - // Part is unsent (neither in-progress nor arrived) - expectedStatus[st.tid][j] = 0 - numUnsent++ - } - } - } - - // Returns an error on function and round failure on callback if sendErr is - // set; otherwise, it reports round successes and returns nil - rr := func(rIDs []id.Round, _ time.Duration, cb cmix.RoundEventCallback) error { - rounds := make(map[id.Round]cmix.RoundLookupStatus, len(rIDs)) - for _, rid := range rIDs { - if finishedRounds[rid] != nil { - rounds[rid] = cmix.Succeeded - } else { - rounds[rid] = cmix.Failed - } - } - cb(true, false, rounds) - - return nil - } - - // Load new manager from the original manager's storage - net := newTestNetworkManager(false, nil, nil, t) - loadedManager, err := newManager( - nil, nil, nil, net, nil, rr, kv, nil, DefaultParams()) - if err != nil { - t.Errorf("Failed to create new manager from KV: %+v", err) - } - - // Create new progress callbacks with channels - cbChans := make([]chan sentProgressResults, len(sti)) - numCbCalls2 := make(map[ftCrypto.TransferID]int, len(sti)) - for i, st := range sti { - // Create sent progress callback and channel - cbChan := make(chan sentProgressResults, 32) - numCbCalls2[st.tid] = 0 - tid := st.tid - numCalls, maxNumCalls := int64(0), int64(numCbCalls[tid]) - cb := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - if atomic.CompareAndSwapInt64(&numCalls, maxNumCalls, maxNumCalls) { - cbChan <- sentProgressResults{ - completed, sent, arrived, total, tr, err} - } - atomic.AddInt64(&numCalls, 1) - } - cbChans[i] = cbChan - - err = loadedManager.RegisterSentProgressCallback(st.tid, cb, 0) - if err != nil { - t.Errorf("Failed to register SentProgressCallback for transfer "+ - "%d %s: %+v", i, st.tid, err) - } - } - - // Wait until callbacks have been called to know the transfers have been - // recovered - var wg sync.WaitGroup - for i, st := range sti { - wg.Add(1) - go func(i, callNum int, cbChan chan sentProgressResults, st sentTransferInfo) { - defer wg.Done() - select { - case <-time.NewTimer(150 * time.Millisecond).C: - case <-cbChan: - } - }(i, numCbCalls[st.tid], cbChans[i], st) - } - - // Create health chan - healthyRecover := make(chan bool, networkHealthBuffLen) - chanID := net.GetHealthTracker().AddChannel(healthyRecover) - healthyRecover <- true - - loadedManager.oldTransferRecovery(healthyRecover, chanID) - - wg.Wait() - - // Check the status of each round in each transfer - for i, st := range sti { - transfer, err := loadedManager.sent.GetTransfer(st.tid) - if err != nil { - t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) - } - - _, _, _, _, track := transfer.GetProgress() - for j := uint16(0); j < track.GetNumParts(); j++ { - if track.GetPartStatus(j) != expectedStatus[st.tid][j] { - t.Errorf("Unexpected part #%d status for transfer #%d %s."+ - "\nexpected: %d\nreceived: %d", j, i, st.tid, - expectedStatus[st.tid][j], track.GetPartStatus(j)) - } - } - } - - // Check that each item in the queue is unsent - var queueCount int - for done := false; !done; { - select { - case <-time.NewTimer(5 * time.Millisecond).C: - done = true - case p := <-loadedManager.sendQueue: - queueCount++ - if expectedStatus[p.tid][p.partNum] != 0 { - t.Errorf("Part #%d for transfer %s not expected in qeueu."+ - "\nexpected: %d\nreceived: %d", p.partNum, p.tid, - expectedStatus[p.tid][p.partNum], 0) - } - } - } - - // Check that the number of items in the queue is correct - if queueCount != numUnsent { - t.Errorf("Number of items incorrect.\nexpected: %d\nreceived: %d", - numUnsent, queueCount) - } -} - -// Tests that Manager.updateSentRounds updates the status of each round -// correctly by using the part tracker and checks that all the correct parts -// were added to the queue. -func TestManager_updateSentRounds(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - m, sti, _ := newTestManagerWithTransfers( - []uint16{6, 12, 18}, false, true, nil, nil, kv, t) - - finishedRounds := make(map[id.Round][]ftCrypto.TransferID) - expectedStatus := make( - map[ftCrypto.TransferID]map[uint16]interfaces.FpStatus, len(sti)) - var numUnsent int - - for i, st := range sti { - transfer, err := m.sent.GetTransfer(st.tid) - if err != nil { - t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) - } - - expectedStatus[st.tid] = make( - map[uint16]interfaces.FpStatus, st.numParts) - - // Loop through each part and set it individually - for j, k := uint16(0), 0; j < transfer.GetNumParts(); j++ { - rid := id.Round(j) - switch j % 3 { - case 0: - // Part is sent (in-progress) - _, _ = transfer.SetInProgress(rid, j) - if k%2 == 0 { - finishedRounds[rid] = append(finishedRounds[rid], st.tid) - expectedStatus[st.tid][j] = 2 - } else { - expectedStatus[st.tid][j] = 0 - numUnsent++ - } - k++ - case 1: - // Part is sent and arrived (finished) - _, _ = transfer.SetInProgress(rid, j) - _, _ = transfer.FinishTransfer(rid) - finishedRounds[rid] = append(finishedRounds[rid], st.tid) - expectedStatus[st.tid][j] = 2 - case 2: - // Part is unsent (neither in-progress nor arrived) - expectedStatus[st.tid][j] = 0 - } - } - } - - // Returns an error on function and round failure on callback if sendErr is - // set; otherwise, it reports round successes and returns nil - rr := func(rIDs []id.Round, _ time.Duration, cb cmix.RoundEventCallback) error { - rounds := make(map[id.Round]cmix.RoundLookupStatus, len(rIDs)) - for _, rid := range rIDs { - if finishedRounds[rid] != nil { - rounds[rid] = cmix.Succeeded - } else { - rounds[rid] = cmix.Failed - } - } - cb(true, false, rounds) - - return nil - } - - loadedManager, err := newManager( - nil, nil, nil, nil, nil, rr, kv, nil, DefaultParams()) - if err != nil { - t.Errorf("Failed to create new manager from KV: %+v", err) - } - - // Create health chan - healthyRecover := make(chan bool, networkHealthBuffLen) - healthyRecover <- true - - // Get list of rounds that parts were sent on - _, loadedSentRounds, _ := m.sent.GetUnsentPartsAndSentRounds() - - err = loadedManager.updateSentRounds(healthyRecover, loadedSentRounds) - if err != nil { - t.Errorf("updateSentRounds returned an error: %+v", err) - } - - // Check the status of each round in each transfer - for i, st := range sti { - transfer, err := loadedManager.sent.GetTransfer(st.tid) - if err != nil { - t.Fatalf("Failed to get transfer #%d %s: %+v", i, st.tid, err) - } - - _, _, _, _, track := transfer.GetProgress() - for j := uint16(0); j < track.GetNumParts(); j++ { - if track.GetPartStatus(j) != expectedStatus[st.tid][j] { - t.Errorf("Unexpected part #%d status for transfer #%d %s."+ - "\nexpected: %d\nreceived: %d", j, i, st.tid, - expectedStatus[st.tid][j], track.GetPartStatus(j)) - } - } - } - - // Check that each item in the queue is unsent - var queueCount int - for done := false; !done; { - select { - case <-time.NewTimer(5 * time.Millisecond).C: - done = true - case p := <-loadedManager.sendQueue: - queueCount++ - if expectedStatus[p.tid][p.partNum] != 0 { - t.Errorf("Part #%d for transfer %s not expected in qeueu."+ - "\nexpected: %d\nreceived: %d", p.partNum, p.tid, - expectedStatus[p.tid][p.partNum], 0) - } - } - } - - // Check that the number of items in the queue is correct - if queueCount != numUnsent { - t.Errorf("Number of items incorrect.\nexpected: %d\nreceived: %d", - numUnsent, queueCount) - } -} - -// Error path: tests that Manager.updateSentRounds returns the expected error -// when getRoundResults returns only errors. -func TestManager_updateSentRounds_Error(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - m, _, _ := newTestManagerWithTransfers( - []uint16{6, 12, 18}, false, true, nil, nil, kv, t) - - // Returns an error on function and round failure on callback if sendErr is - // set; otherwise, it reports round successes and returns nil - m.getRoundResults = func( - []id.Round, time.Duration, cmix.RoundEventCallback) error { - return errors.Errorf("GetRoundResults error") - } - - // Create health chan - healthyRecover := make(chan bool, roundResultsMaxAttempts) - for i := 0; i < roundResultsMaxAttempts; i++ { - healthyRecover <- true - } - - sentRounds := map[id.Round][]ftCrypto.TransferID{ - 0: {{1}, {2}, {3}}, - 5: {{4}, {2}, {6}}, - 9: {{3}, {9}, {8}}, - } - - expectedErr := fmt.Sprintf( - oldTransfersRoundResultsErr, len(sentRounds), roundResultsMaxAttempts) - err := m.updateSentRounds(healthyRecover, sentRounds) - if err == nil || err.Error() != expectedErr { - t.Errorf("updateSentRounds did not return the expected error when "+ - "getRoundResults returns only errors.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } - -} - -// Tests that roundIdMapToList returns all the round IDs in the map. -func Test_roundIdMapToList(t *testing.T) { - n := 10 - roundMap := make(map[id.Round][]ftCrypto.TransferID, n) - expectedRoundList := make([]id.Round, n) - - csPrng := NewPrng(42) - prng := rand.New(rand.NewSource(42)) - - for i := 0; i < n; i++ { - rid := id.Round(i) - roundMap[rid] = make([]ftCrypto.TransferID, prng.Intn(10)) - for j := range roundMap[rid] { - roundMap[rid][j], _ = ftCrypto.NewTransferID(csPrng) - } - - expectedRoundList[i] = rid - } - - receivedRoundList := roundIdMapToList(roundMap) - - sort.SliceStable(receivedRoundList, func(i, j int) bool { - return receivedRoundList[i] < receivedRoundList[j] - }) - - if !reflect.DeepEqual(expectedRoundList, receivedRoundList) { - t.Errorf("Round list does not match expected."+ - "\nexpected: %v\nreceived: %v", - expectedRoundList, receivedRoundList) - } -} diff --git a/fileTransfer/partTracker.go b/fileTransfer/partTracker.go new file mode 100644 index 0000000000000000000000000000000000000000..1288c7a16482fecec864a5cfa236b373c1020d3e --- /dev/null +++ b/fileTransfer/partTracker.go @@ -0,0 +1,68 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "gitlab.com/elixxir/client/storage/utility" +) + +// sentFilePartTracker contains utility.StateVector that tracks which parts have +// arrived. It adheres to the FilePartTracker interface. +type sentFilePartTracker struct { + *utility.StateVector +} + +// GetPartStatus returns the status of the sent file part with the given part +// number. +func (s *sentFilePartTracker) GetPartStatus(partNum uint16) FpStatus { + if uint32(partNum) >= s.GetNumKeys() { + return -1 + } + + switch s.Used(uint32(partNum)) { + case true: + return FpArrived + case false: + return FpUnsent + default: + return -1 + } +} + +// GetNumParts returns the total number of file parts in the transfer. +func (s *sentFilePartTracker) GetNumParts() uint16 { + return uint16(s.GetNumKeys()) +} + +// receivedFilePartTracker contains utility.StateVector that tracks which parts +// have been received. It adheres to the FilePartTracker interface. +type receivedFilePartTracker struct { + *utility.StateVector +} + +// GetPartStatus returns the status of the received file part with the given +// part number. +func (r *receivedFilePartTracker) GetPartStatus(partNum uint16) FpStatus { + if uint32(partNum) >= r.GetNumKeys() { + return -1 + } + + switch r.Used(uint32(partNum)) { + case true: + return FpReceived + case false: + return FpUnsent + default: + return -1 + } +} + +// GetNumParts returns the total number of file parts in the transfer. +func (r *receivedFilePartTracker) GetNumParts() uint16 { + return uint16(r.GetNumKeys()) +} diff --git a/fileTransfer/processor.go b/fileTransfer/processor.go new file mode 100644 index 0000000000000000000000000000000000000000..f4c2432c3277c61588c7699cf8544ca8346b1d54 --- /dev/null +++ b/fileTransfer/processor.go @@ -0,0 +1,64 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/cmix/identity/receptionID" + "gitlab.com/elixxir/client/cmix/rounds" + "gitlab.com/elixxir/client/fileTransfer/store" + "gitlab.com/elixxir/client/fileTransfer/store/cypher" + "gitlab.com/elixxir/client/fileTransfer/store/fileMessage" + "gitlab.com/elixxir/primitives/format" +) + +// Error messages. +const ( + // processor.Process + errDecryptPart = "[FT] Failed to decrypt file part for transfer %s (%q) on round %d: %+v" + errUnmarshalPart = "[FT] Failed to unmarshal decrypted file part for transfer %s (%q) on round %d: %+v" + errAddPart = "[FT] Failed to add part #%d to transfer transfer %s (%q): %+v" +) + +// processor manages the reception of file transfer messages. Adheres to the +// message.Processor interface. +type processor struct { + cypher.Cypher + *store.ReceivedTransfer + *manager +} + +// Process decrypts and hands off the file part message and adds it to the +// correct file transfer. +func (p *processor) Process(msg format.Message, + _ receptionID.EphemeralIdentity, round rounds.Round) { + + decryptedPart, err := p.Decrypt(msg) + if err != nil { + jww.ERROR.Printf( + errDecryptPart, p.TransferID(), p.FileName(), round.ID, err) + return + } + + partMsg, err := fileMessage.UnmarshalPartMessage(decryptedPart) + if err != nil { + jww.ERROR.Printf( + errUnmarshalPart, p.TransferID(), p.FileName(), round.ID, err) + return + } + + err = p.AddPart(partMsg.GetPart(), int(partMsg.GetPartNum())) + if err != nil { + jww.WARN.Printf( + errAddPart, partMsg.GetPartNum(), p.TransferID(), p.FileName(), err) + return + } + + // Call callback with updates + p.callbacks.Call(p.TransferID(), nil) +} diff --git a/fileTransfer/receive.go b/fileTransfer/receive.go deleted file mode 100644 index c122541ee778c2a7943af2ecb69268f1f94ca7ab..0000000000000000000000000000000000000000 --- a/fileTransfer/receive.go +++ /dev/null @@ -1,78 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/stoppable" - "gitlab.com/elixxir/primitives/format" - "strings" -) - -// receive runs a loop that receives file message parts and stores them in their -// appropriate transfer. -func (m *Manager) receive(rawMsgs chan message.Receive, stop *stoppable.Single) { - jww.DEBUG.Print("[FT] Starting file part reception thread.") - - for { - select { - case <-stop.Quit(): - jww.DEBUG.Print("[FT] Stopping file part reception thread: stoppable " + - "triggered.") - stop.ToStopped() - return - case receiveMsg := <-rawMsgs: - cMixMsg, err := m.readMessage(receiveMsg) - if err != nil { - // Print error as warning unless the fingerprint does not match, - // which means this message is not of the correct type and will - // be ignored - if strings.Contains(err.Error(), "fingerprint") { - jww.TRACE.Printf("[FT] %v", err) - } else { - jww.WARN.Printf("[FT] %v", err) - } - continue - } - - // Denote that the message is a file part - m.store.GetGarbledMessages().Remove(cMixMsg) - } - } -} - -// readMessage unmarshal the payload in the message.Receive and stores it with -// the appropriate received transfer. The cMix message is returned so that, on -// error, it can be either marked as used not used. -func (m *Manager) readMessage(msg message.Receive) (format.Message, error) { - // Unmarshal payload into cMix message - cMixMsg, err := format.Unmarshal(msg.Payload) - if err != nil { - return cMixMsg, err - } - - // Add part to received transfer - rt, tid, completed, err := m.received.AddPart(cMixMsg) - if err != nil { - return cMixMsg, err - } - - // Print debug message on completion - if completed { - jww.DEBUG.Printf("[FT] Received last part for file transfer %s from "+ - "%s {size: %d, parts: %d, numFps: %d/%d}", tid, msg.Sender, - rt.GetFileSize(), rt.GetNumParts(), - rt.GetNumFps()-rt.GetNumAvailableFps(), rt.GetNumFps()) - } - - // Call callback with updates - rt.CallProgressCB(nil) - - return cMixMsg, nil -} diff --git a/fileTransfer/receiveNew.go b/fileTransfer/receiveNew.go deleted file mode 100644 index 67f3ce3e44570f4f6647c69b7ba098f965b19c62..0000000000000000000000000000000000000000 --- a/fileTransfer/receiveNew.go +++ /dev/null @@ -1,106 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "github.com/golang/protobuf/proto" - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/stoppable" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/primitives/id" -) - -// Error messages. -const ( - receiveMessageTypeErr = "received message is not of type NewFileTransfer" - protoUnmarshalErr = "failed to unmarshal request: %+v" -) - -// receiveNewFileTransfer starts a thread that waits for new file transfer -// messages. -func (m *Manager) receiveNewFileTransfer(rawMsgs chan message.Receive, - stop *stoppable.Single) { - jww.DEBUG.Print("[FT] Starting new file transfer message reception thread.") - - for { - select { - case <-stop.Quit(): - jww.DEBUG.Print("[FT] Stopping new file transfer message " + - "reception thread: stoppable triggered") - stop.ToStopped() - return - case receivedMsg := <-rawMsgs: - jww.TRACE.Print( - "[FT] New file transfer message thread received message.") - - tid, fileName, fileType, sender, size, preview, err := - m.readNewFileTransferMessage(receivedMsg) - if err != nil { - if err.Error() == receiveMessageTypeErr { - jww.DEBUG.Printf("[FT] Failed to read message as new file "+ - "transfer message: %+v", err) - } else { - jww.WARN.Printf("[FT] Failed to read message as new file "+ - "transfer message: %+v", err) - } - continue - } - - // Call the reception callback - go m.receiveCB(tid, fileName, fileType, sender, size, preview) - - // Trigger a resend of all garbled messages - m.net.CheckGarbledMessages() - } - } -} - -// readNewFileTransferMessage reads the received message and adds it to the -// received transfer list. Returns the transfer ID, sender ID, file size, and -// file preview. -func (m *Manager) readNewFileTransferMessage(msg message.Receive) ( - tid ftCrypto.TransferID, fileName, fileType string, sender *id.ID, - fileSize uint32, preview []byte, err error) { - - // Return an error if the message is not a NewFileTransfer - if msg.MessageType != message.NewFileTransfer { - err = errors.New(receiveMessageTypeErr) - return - } - - // Unmarshal the request message - newFT := &NewFileTransfer{} - err = proto.Unmarshal(msg.Payload, newFT) - if err != nil { - err = errors.Errorf(protoUnmarshalErr, err) - return - } - - // Get RNG from stream - rng := m.rng.GetStream() - defer rng.Close() - - // Add the transfer to the list of receiving transfers - key := ftCrypto.UnmarshalTransferKey(newFT.TransferKey) - numParts := uint16(newFT.NumParts) - numFps := calcNumberOfFingerprints(numParts, newFT.Retry) - tid, err = m.received.AddTransfer( - key, newFT.TransferMac, newFT.Size, numParts, numFps, rng) - if err != nil { - return - } - - jww.DEBUG.Printf("[FT] Received new file transfer %s from %s {name: %q, "+ - "type: %q, size: %d, parts: %d, numFps: %d, retry: %f}", tid, msg.Sender, - newFT.FileName, newFT.FileType, newFT.Size, numParts, numFps, newFT.Retry) - - return tid, newFT.FileName, newFT.FileType, msg.Sender, newFT.Size, - newFT.Preview, nil -} diff --git a/fileTransfer/receiveNew_test.go b/fileTransfer/receiveNew_test.go deleted file mode 100644 index 4aea311faa6e0fedadf6f627e8f4b43eb544753e..0000000000000000000000000000000000000000 --- a/fileTransfer/receiveNew_test.go +++ /dev/null @@ -1,294 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "github.com/golang/protobuf/proto" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/stoppable" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/primitives/id" - "strings" - "testing" - "time" -) - -// Tests that Manager.receiveNewFileTransfer receives the sent message and that -// it reports the correct data to the callback. -func TestManager_receiveNewFileTransfer(t *testing.T) { - // Create new ReceiveCallback that sends the results on a channel - receiveChan := make(chan receivedFtResults) - receiveCB := func(tid ftCrypto.TransferID, fileName, fileType string, - sender *id.ID, size uint32, preview []byte) { - receiveChan <- receivedFtResults{ - tid, fileName, fileType, sender, size, preview} - } - - // Create new manager, stoppable, and channel to receive messages - m := newTestManager(false, nil, nil, receiveCB, nil, t) - stop := stoppable.NewSingle(newFtStoppableName) - rawMsgs := make(chan message.Receive, rawMessageBuffSize) - - // Start receiving thread - go m.receiveNewFileTransfer(rawMsgs, stop) - - // Create new message.Receive with marshalled NewFileTransfer - key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - protoMsg := &NewFileTransfer{ - FileName: "testFile", - TransferKey: key.Bytes(), - TransferMac: []byte("transferMac"), - NumParts: 16, - Size: 256, - Retry: 1.5, - Preview: []byte("filePreview"), - } - marshalledMsg, err := proto.Marshal(protoMsg) - if err != nil { - t.Errorf("Failed to Marshal proto message: %+v", err) - } - receiveMsg := message.Receive{ - Payload: marshalledMsg, - MessageType: message.NewFileTransfer, - Sender: id.NewIdFromString("sender", id.User, t), - } - - // Add message to channel - rawMsgs <- receiveMsg - - // Wait to receive message - select { - case <-time.NewTimer(100 * time.Millisecond).C: - t.Error("Timed out waiting to receive message.") - case r := <-receiveChan: - if !receiveMsg.Sender.Cmp(r.sender) { - t.Errorf("Received sender ID does not match expected."+ - "\nexpected: %s\nreceived: %s", receiveMsg.Sender, r.sender) - } - if protoMsg.Size != r.size { - t.Errorf("Received file size does not match expected."+ - "\nexpected: %d\nreceived: %d", protoMsg.Size, r.size) - } - if !bytes.Equal(protoMsg.Preview, r.preview) { - t.Errorf("Received preview does not match expected."+ - "\nexpected: %q\nreceived: %q", protoMsg.Preview, r.preview) - } - } - - // Stop thread - err = stop.Close() - if err != nil { - t.Errorf("Failed to stop stoppable: %+v", err) - } -} - -// Tests that Manager.receiveNewFileTransfer stops receiving messages when the -// stoppable is triggered. -func TestManager_receiveNewFileTransfer_Stop(t *testing.T) { - // Create new ReceiveCallback that sends the results on a channel - receiveChan := make(chan receivedFtResults) - receiveCB := func(tid ftCrypto.TransferID, fileName, fileType string, - sender *id.ID, size uint32, preview []byte) { - receiveChan <- receivedFtResults{ - tid, fileName, fileType, sender, size, preview} - } - - // Create new manager, stoppable, and channel to receive messages - m := newTestManager(false, nil, nil, receiveCB, nil, t) - stop := stoppable.NewSingle(newFtStoppableName) - rawMsgs := make(chan message.Receive, rawMessageBuffSize) - - // Start receiving thread - go m.receiveNewFileTransfer(rawMsgs, stop) - - // Create new message.Receive with marshalled NewFileTransfer - key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - protoMsg := &NewFileTransfer{ - FileName: "testFile", - TransferKey: key.Bytes(), - TransferMac: []byte("transferMac"), - NumParts: 16, - Size: 256, - Retry: 1.5, - Preview: []byte("filePreview"), - } - marshalledMsg, err := proto.Marshal(protoMsg) - if err != nil { - t.Errorf("Failed to Marshal proto message: %+v", err) - } - receiveMsg := message.Receive{ - Payload: marshalledMsg, - MessageType: message.NewFileTransfer, - Sender: id.NewIdFromString("sender", id.User, t), - } - - // Stop thread - err = stop.Close() - if err != nil { - t.Errorf("Failed to stop stoppable: %+v", err) - } - - for !stop.IsStopped() { - } - - // Add message to channel - rawMsgs <- receiveMsg - - // Wait to receive message - select { - case <-time.NewTimer(time.Millisecond).C: - case r := <-receiveChan: - t.Errorf("Callback called when the thread should have quit."+ - "\nreceived: %+v", r) - } -} - -// Tests that Manager.receiveNewFileTransfer does not report on the callback -// when the received message is of the wrong type. -func TestManager_receiveNewFileTransfer_InvalidMessageError(t *testing.T) { - // Create new ReceiveCallback that sends the results on a channel - receiveChan := make(chan receivedFtResults) - receiveCB := func(tid ftCrypto.TransferID, fileName, fileType string, - sender *id.ID, size uint32, preview []byte) { - receiveChan <- receivedFtResults{ - tid, fileName, fileType, sender, size, preview} - } - - // Create new manager, stoppable, and channel to receive messages - m := newTestManager(false, nil, nil, receiveCB, nil, t) - stop := stoppable.NewSingle(newFtStoppableName) - rawMsgs := make(chan message.Receive, rawMessageBuffSize) - - // Start receiving thread - go m.receiveNewFileTransfer(rawMsgs, stop) - - // Create new message.Receive with wrong type - receiveMsg := message.Receive{ - MessageType: message.NoType, - } - - // Add message to channel - rawMsgs <- receiveMsg - - // Wait to receive message - select { - case <-time.NewTimer(time.Millisecond).C: - case r := <-receiveChan: - t.Errorf("Callback called when the message is of the wrong type."+ - "\nreceived: %+v", r) - } - - // Stop thread - err := stop.Close() - if err != nil { - t.Errorf("Failed to stop stoppable: %+v", err) - } -} - -// Tests that Manager.readNewFileTransferMessage returns the expected sender ID, -// file size, and preview. -func TestManager_readNewFileTransferMessage(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - // Create new message.Send containing marshalled NewFileTransfer - recipient := id.NewIdFromString("recipient", id.User, t) - expectedFileName := "testFile" - expectedFileType := "txt" - key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - mac := []byte("transferMac") - numParts, expectedFileSize, retry := uint16(16), uint32(256), float32(1.5) - expectedPreview := []byte("filePreview") - sendMsg, err := newNewFileTransferE2eMessage( - recipient, expectedFileName, expectedFileType, key, mac, numParts, - expectedFileSize, retry, expectedPreview) - if err != nil { - t.Errorf("Failed to create new Send message: %+v", err) - } - - // Create message.Receive with marshalled NewFileTransfer - receiveMsg := message.Receive{ - Payload: sendMsg.Payload, - MessageType: message.NewFileTransfer, - Sender: id.NewIdFromString("sender", id.User, t), - } - - // Read the message - _, fileName, fileType, sender, fileSize, preview, err := - m.readNewFileTransferMessage(receiveMsg) - if err != nil { - t.Errorf("readNewFileTransferMessage returned an error: %+v", err) - } - - if expectedFileName != fileName { - t.Errorf("Returned file name does not match expected."+ - "\nexpected: %q\nreceived: %q", expectedFileName, fileName) - } - - if expectedFileType != fileType { - t.Errorf("Returned file type does not match expected."+ - "\nexpected: %q\nreceived: %q", expectedFileType, fileType) - } - - if !receiveMsg.Sender.Cmp(sender) { - t.Errorf("Returned sender ID does not match expected."+ - "\nexpected: %s\nreceived: %s", receiveMsg.Sender, sender) - } - - if expectedFileSize != fileSize { - t.Errorf("Returned file size does not match expected."+ - "\nexpected: %d\nreceived: %d", expectedFileSize, fileSize) - } - - if !bytes.Equal(expectedPreview, preview) { - t.Errorf("Returned preview does not match expected."+ - "\nexpected: %q\nreceived: %q", expectedPreview, preview) - } -} - -// Error path: tests that Manager.readNewFileTransferMessage returns the -// expected error when the message.Receive has the wrong MessageType. -func TestManager_readNewFileTransferMessage_MessageTypeError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - expectedErr := receiveMessageTypeErr - - // Create message.Receive with marshalled NewFileTransfer - receiveMsg := message.Receive{ - MessageType: message.NoType, - } - - // Read the message - _, _, _, _, _, _, err := m.readNewFileTransferMessage(receiveMsg) - if err == nil || err.Error() != expectedErr { - t.Errorf("readNewFileTransferMessage did not return the expected "+ - "error when the message.Receive has the wrong MessageType."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that Manager.readNewFileTransferMessage returns the -// expected error when the payload of the message.Receive cannot be -// unmarshalled. -func TestManager_readNewFileTransferMessage_ProtoUnmarshalError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - expectedErr := strings.Split(protoUnmarshalErr, "%")[0] - - // Create message.Receive with marshalled NewFileTransfer - receiveMsg := message.Receive{ - Payload: []byte("invalidPayload"), - MessageType: message.NewFileTransfer, - } - - // Read the message - _, _, _, _, _, _, err := m.readNewFileTransferMessage(receiveMsg) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("readNewFileTransferMessage did not return the expected "+ - "error when the payload could not be unmarshalled."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} diff --git a/fileTransfer/receive_test.go b/fileTransfer/receive_test.go deleted file mode 100644 index cc28fc24f1b910128ceba15d1486b1ef9e120966..0000000000000000000000000000000000000000 --- a/fileTransfer/receive_test.go +++ /dev/null @@ -1,345 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/stoppable" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/primitives/id" - "reflect" - "testing" - "time" -) - -// Tests that Manager.receive returns the correct progress on the callback when -// receiving a single message. -func TestManager_receive(t *testing.T) { - // Build a manager for sending and a manger for receiving - m1 := newTestManager(false, nil, nil, nil, nil, t) - m2 := newTestManager(false, nil, nil, nil, nil, t) - - // Create transfer components - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - numParts := uint16(16) - numFps := calcNumberOfFingerprints(numParts, 0.5) - partSize, _ := m1.getPartSize() - file, parts := newFile(numParts, partSize, prng, t) - fileSize := uint32(len(file)) - mac := ftCrypto.CreateTransferMAC(file, key) - - // Add transfer to sending manager - stID, err := m1.sent.AddTransfer( - recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add new sent transfer: %+v", err) - } - - // Add transfer to receiving manager - rtID, err := m2.received.AddTransfer( - key, mac, fileSize, numParts, numFps, prng) - if err != nil { - t.Errorf("Failed to add new received transfer: %+v", err) - } - - // Generate receive callback that should be called when a message is read - cbChan := make(chan receivedProgressResults) - cb := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- receivedProgressResults{completed, received, total, tr, err} - } - - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - switch i { - case 0: - done0 <- true - case 1: - err = checkReceivedProgress(r.completed, r.received, r.total, false, - 1, numParts) - if r.err != nil { - t.Errorf("Callback returned an error: %+v", err) - } - if err != nil { - t.Error(err) - } - done1 <- true - } - } - } - }() - - rt, err := m2.received.GetTransfer(rtID) - if err != nil { - t.Errorf("Failed to get received transfer %s: %+v", rtID, err) - } - rt.AddProgressCB(cb, time.Millisecond) - - <-done0 - - // Build message.Receive with cMix message that has file part - st, err := m1.sent.GetTransfer(stID) - if err != nil { - t.Errorf("Failed to get sent transfer %s: %+v", stID, err) - } - cMixMsg, err := m1.newCmixMessage(st, 0) - if err != nil { - t.Errorf("Failed to create new cMix message: %+v", err) - } - receiveMsg := message.Receive{Payload: cMixMsg.Marshal()} - - // Start reception thread - rawMsgs := make(chan message.Receive, rawMessageBuffSize) - stop := stoppable.NewSingle(filePartStoppableName) - go m2.receive(rawMsgs, stop) - - // Send message on channel; - rawMsgs <- receiveMsg - - <-done1 - - // Create cMix message with wrong fingerprint - cMixMsg.SetKeyFP(format.NewFingerprint([]byte("invalidFP"))) - receiveMsg = message.Receive{Payload: cMixMsg.Marshal()} - - done := make(chan bool) - go func() { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - case r := <-cbChan: - t.Errorf("Callback should not be called for invalid message: %+v", r) - } - done <- true - }() - - // Send message on channel; - rawMsgs <- receiveMsg - - <-done -} - -// Tests that Manager.receive the progress callback is not called when the -// stoppable is triggered. -func TestManager_receive_Stop(t *testing.T) { - // Build a manager for sending and a manger for receiving - m1 := newTestManager(false, nil, nil, nil, nil, t) - m2 := newTestManager(false, nil, nil, nil, nil, t) - - // Create transfer components - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - numParts := uint16(16) - numFps := calcNumberOfFingerprints(numParts, 0.5) - partSize, _ := m1.getPartSize() - file, parts := newFile(numParts, partSize, prng, t) - fileSize := uint32(len(file)) - mac := ftCrypto.CreateTransferMAC(file, key) - - // Add transfer to sending manager - stID, err := m1.sent.AddTransfer( - recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add new sent transfer: %+v", err) - } - - // Add transfer to receiving manager - rtID, err := m2.received.AddTransfer( - key, mac, fileSize, numParts, numFps, prng) - if err != nil { - t.Errorf("Failed to add new received transfer: %+v", err) - } - - // Generate receive callback that should be called when a message is read - cbChan := make(chan receivedProgressResults) - cb := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- receivedProgressResults{completed, received, total, tr, err} - } - - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(20 * time.Millisecond).C: - done1 <- true - case r := <-cbChan: - switch i { - case 0: - done0 <- true - case 1: - t.Errorf("Callback should not have been called: %+v", r) - done1 <- true - } - } - } - }() - - rt, err := m2.received.GetTransfer(rtID) - if err != nil { - t.Errorf("Failed to get received transfer %s: %+v", rtID, err) - } - rt.AddProgressCB(cb, time.Millisecond) - - <-done0 - - // Build message.Receive with cMix message that has file part - st, err := m1.sent.GetTransfer(stID) - if err != nil { - t.Errorf("Failed to get sent transfer %s: %+v", stID, err) - } - cMixMsg, err := m1.newCmixMessage(st, 0) - if err != nil { - t.Errorf("Failed to create new cMix message: %+v", err) - } - receiveMsg := message.Receive{ - Payload: cMixMsg.Marshal(), - } - - // Start reception thread - rawMsgs := make(chan message.Receive, rawMessageBuffSize) - stop := stoppable.NewSingle(filePartStoppableName) - go m2.receive(rawMsgs, stop) - - // Trigger stoppable - err = stop.Close() - if err != nil { - t.Errorf("Failed to close stoppable: %+v", err) - } - - for stop.IsStopping() { - - } - - // Send message on channel; - rawMsgs <- receiveMsg - - <-done1 -} - -// Tests that Manager.readMessage reads the message without errors and that it -// reports the correct progress on the callback. It also gets the file and -// checks that the part is where it should be. -func TestManager_readMessage(t *testing.T) { - - // Build a manager for sending and a manger for receiving - m1 := newTestManager(false, nil, nil, nil, nil, t) - m2 := newTestManager(false, nil, nil, nil, nil, t) - - // Create transfer components - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - numParts := uint16(16) - numFps := calcNumberOfFingerprints(numParts, 0.5) - partSize, _ := m1.getPartSize() - file, parts := newFile(numParts, partSize, prng, t) - fileSize := uint32(len(file)) - mac := ftCrypto.CreateTransferMAC(file, key) - - // Add transfer to sending manager - stID, err := m1.sent.AddTransfer( - recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add new sent transfer: %+v", err) - } - - // Add transfer to receiving manager - rtID, err := m2.received.AddTransfer( - key, mac, fileSize, numParts, numFps, prng) - if err != nil { - t.Errorf("Failed to add new received transfer: %+v", err) - } - - // Generate receive callback that should be called when a message is read - cbChan := make(chan receivedProgressResults, 2) - cb := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- receivedProgressResults{completed, received, total, tr, err} - } - - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - switch i { - case 0: - done0 <- true - case 1: - err := checkReceivedProgress(r.completed, r.received, r.total, false, - 1, numParts) - if r.err != nil { - t.Errorf("Callback returned an error: %+v", err) - } - if err != nil { - t.Error(err) - } - done1 <- true - } - } - } - }() - - rt, err := m2.received.GetTransfer(rtID) - if err != nil { - t.Errorf("Failed to get received transfer %s: %+v", rtID, err) - } - rt.AddProgressCB(cb, time.Millisecond) - - <-done0 - - // Build message.Receive with cMix message that has file part - st, err := m1.sent.GetTransfer(stID) - if err != nil { - t.Errorf("Failed to get sent transfer %s: %+v", stID, err) - } - cMixMsg, err := m1.newCmixMessage(st, 0) - if err != nil { - t.Errorf("Failed to create new cMix message: %+v", err) - } - receiveMsg := message.Receive{ - Payload: cMixMsg.Marshal(), - } - - // Read receive message - receivedCMixMsg, err := m2.readMessage(receiveMsg) - if err != nil { - t.Errorf("readMessage returned an error: %+v", err) - } - - if !reflect.DeepEqual(cMixMsg, receivedCMixMsg) { - t.Errorf("Received cMix message does not match sent."+ - "\nexpected: %+v\nreceived: %+v", cMixMsg, receivedCMixMsg) - } - - <-done1 - - // Get the file and check that the part was added to it - fileData, err := rt.GetFile() - if err == nil { - t.Error("GetFile did not return an error when parts are missing.") - } - - if !bytes.Equal(parts[0], fileData[:partSize]) { - t.Errorf("Part failed to be added to part store."+ - "\nexpected: %q\nreceived: %q", parts[0], fileData[:partSize]) - } -} diff --git a/fileTransfer/send.go b/fileTransfer/send.go index 4a795336f93d5723ee99bc0acbd9c74ccbda7d0b..ca66ae68118f6b11f6f433b5ee15658d9ce89432 100644 --- a/fileTransfer/send.go +++ b/fileTransfer/send.go @@ -8,638 +8,176 @@ package fileTransfer import ( - "bytes" - "encoding/binary" "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/cmix" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/fileTransfer/sentRoundTracker" + "gitlab.com/elixxir/client/fileTransfer/store" "gitlab.com/elixxir/client/stoppable" - ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" - ds "gitlab.com/elixxir/comms/network/dataStructures" ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/crypto/shuffle" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/elixxir/primitives/states" - "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/netTime" + "strconv" "time" ) // Error messages. const ( - // Manager.sendParts - sendManyCmixWarn = "[FT] Failed to send %d file parts %v via SendManyCMIX: %+v" - setInProgressErr = "[FT] Failed to set parts %v to in-progress for transfer %s: %+v" - getRoundResultsErr = "[FT] Failed to get round results for round %d for file transfers %v: %+v" - - // Manager.buildMessages - noSentTransferWarn = "[FT] Could not get transfer %s for part %d: %+v" - maxRetriesErr = "Stopping message transfer: %+v" - newCmixMessageErr = "[FT] Failed to assemble cMix message for file part %d on transfer %s: %+v" - - // Manager.makeRoundEventCallback - finishPassNoTransferErr = "[FT] Failed to mark in-progress parts as finished on success of round %d for transfer %s: %+v" - finishTransferErr = "[FT] Failed to set part(s) to finished for transfer %s: %+v" - finishedEndE2eMsfErr = "[FT] Failed to send E2E message to %s on completion of file transfer %s: %+v" - roundFailureWarn = "[FT] Failed to send file parts for file transfers %v on round %d: round %s" - finishFailNoTransferErr = "[FT] Failed to requeue in-progress parts on failure of round %d for transfer %s: %+v" - unsetInProgressErr = "[FT] Failed to remove parts from in-progress list for transfer %s: round %s: %+v" - - // Manager.sendEndE2eMessage - endE2eGetPartnerErr = "failed to get file transfer partner %s: %+v" - endE2eHealthTimeout = "waiting for network to become healthy timed out after %s." - endE2eSendErr = "failed to send end file transfer message via E2E to recipient %s: %+v" - - // getRandomNumParts + // generateRandomPacketSize getRandomNumPartsRandPanic = "[FT] Failed to generate random number of file parts to send: %+v" + + // manager.sendCmix + errNoMoreRetries = "file transfer failed: ran our of retries." ) const ( // Duration to wait for round to finish before timing out. roundResultsTimeout = 15 * time.Second - // Duration to wait for send batch to fill before sending partial batch. - pollSleepDuration = 100 * time.Millisecond - // Age when rounds that files were sent from are deleted from the tracker. clearSentRoundsAge = 10 * time.Second - // Duration to wait for network to become healthy to send end E2E message - // before timing out. - sendEndE2eHealthTimeout = 5 * time.Second + // Number of concurrent sending threads + workerPoolThreads = 4 // Tag that prints with cMix sending logs. cMixDebugTag = "FT.Part" -) - -// sendThread waits on the sendQueue channel for parts to send. Once its -// receives a random number between 1 and 11 of file parts, they are encrypted, -// put into cMix messages, and sent to their recipients. Failed messages are -// added to the end of the queue. -func (m *Manager) sendThread(stop *stoppable.Single, healthChan chan bool, - healthChanID uint64, getNumParts getRngNum) { - jww.DEBUG.Print("[FT] Starting file part sending thread.") - - // Calculate the average amount of data sent via SendManyCMIX - avgNumMessages := (minPartsSendPerRound + maxPartsSendPerRound) / 2 - avgSendSize := avgNumMessages * (8192 / 8) - // Calculate the delay needed to reach max throughput - delay := time.Duration((int(time.Second) * avgSendSize) / m.p.MaxThroughput) - - // Batch of parts read from the queue to be sent - var partList []queuedPart - - // Create new sent round tracker that tracks which recent rounds file parts - // were sent on so that they can be avoided on subsequent sends - sentRounds := newSentRoundTracker(clearSentRoundsAge) - - // The size of each batch - var numParts int + // Prefix used for the name of a stoppable used for a sending thread + sendThreadStoppableName = "FilePartSendingThread#" +) - // Timer triggers sending of unfilled batch to prevent hanging when the - // file part queue has fewer items then the batch size - timer := time.NewTimer(pollSleepDuration) +// startSendingWorkerPool initialises a worker pool of file part sending +// threads. +func (m *manager) startSendingWorkerPool(multiStop *stoppable.Multi) { + // Set up cMix sending parameters + params := cmix.GetDefaultCMIXParams() + params.SendTimeout = m.params.SendTimeout + params.ExcludedRounds = sentRoundTracker.NewManager(clearSentRoundsAge) + params.DebugTag = cMixDebugTag - // Tracks time that the last send completed - var lastSend time.Time + for i := 0; i < workerPoolThreads; i++ { + stop := stoppable.NewSingle(sendThreadStoppableName + strconv.Itoa(i)) + multiStop.Add(stop) + go m.sendingThread(params, stop) + } +} - // Loop forever polling the sendQueue channel for new file parts to send. If - // the channel is empty, then polling is suspended for pollSleepDuration. If - // the network is not healthy, then polling is suspended until the network - // becomes healthy. +// sendingThread sends part packets that become available oin the send queue. +func (m *manager) sendingThread(cMixParams cmix.CMIXParams, stop *stoppable.Single) { + healthChan := make(chan bool, 10) + healthChanID := m.cmix.AddHealthCallback(func(b bool) { healthChan <- b }) for { - timer = time.NewTimer(pollSleepDuration) select { case <-stop.Quit(): - timer.Stop() - - // Close the thread when the stoppable is triggered - m.closeSendThread(partList, stop, healthChanID) - + jww.DEBUG.Printf("[FT] Stopping file part sending thread: " + + "stoppable triggered.") + m.cmix.RemoveHealthCallback(healthChanID) + stop.ToStopped() return case healthy := <-healthChan: - var wasNotHealthy bool - // If the network is unhealthy, wait until it becomes healthy - if !healthy { - jww.TRACE.Print("[FT] Suspending file part sending thread: " + - "network is unhealthy.") - wasNotHealthy = true - } for !healthy { healthy = <-healthChan } - if wasNotHealthy { - jww.TRACE.Print("[FT] File part sending thread: " + - "network is healthy.") - } - case part := <-m.sendQueue: - // When a part is received from the queue, add it to the list of - // parts to be sent - - // If the batch is empty (a send just occurred), start a new batch - if partList == nil { - rng := m.rng.GetStream() - numParts = getNumParts(rng) - rng.Close() - partList = make([]queuedPart, 0, numParts) - } - - partList = append(partList, part) - - // If the batch is full, then send the parts - if len(partList) == numParts { - quit := m.handleSend( - &partList, &lastSend, delay, stop, healthChanID, sentRounds) - if quit { - timer.Stop() - return - } - } - case <-timer.C: - // If the timeout is reached, send an incomplete batch - - // Skip if there are no parts to send - if len(partList) == 0 { - continue - } - - quit := m.handleSend( - &partList, &lastSend, delay, stop, healthChanID, sentRounds) - if quit { - return - } + case packet := <-m.sendQueue: + m.sendCmix(packet, cMixParams) } } } -// closeSendThread safely stops the sending thread by saving unsent parts to the -// queue and setting the stoppable to stopped. -func (m *Manager) closeSendThread(partList []queuedPart, stop *stoppable.Single, - healthChanID uint64) { - // Exit the thread if the stoppable is triggered - jww.DEBUG.Print("[FT] Stopping file part sending thread: stoppable " + - "triggered.") - - // Add all the unsent parts back in the queue - for _, part := range partList { - m.sendQueue <- part - } - - // Unregister network health channel - m.net.GetHealthTracker().RemoveChannel(healthChanID) - - // Mark stoppable as stopped - stop.ToStopped() -} - -// handleSend handles the sending of parts with bandwidth limitations. On a -// successful send, the partList is cleared and lastSend is updated to the send -// timestamp. When the stoppable is triggered, closing is automatically handled. -// Returns true if the stoppable has been triggered and the sending thread -// should quit. -func (m *Manager) handleSend(partList *[]queuedPart, lastSend *time.Time, - delay time.Duration, stop *stoppable.Single, healthChanID uint64, - sentRounds *sentRoundTracker) bool { - // Bandwidth limiter: wait to send until the delay has been reached so that - // the bandwidth is limited to the maximum throughput - if netTime.Since(*lastSend) < delay { - waitingTime := delay - netTime.Since(*lastSend) - jww.TRACE.Printf("[FT] Suspending file part sending (%d parts): "+ - "bandwidth limit reached; waiting %s to send.", - len(*partList), waitingTime) - - waitingTimer := time.NewTimer(waitingTime) - select { - case <-stop.Quit(): - waitingTimer.Stop() - - // Close the thread when the stoppable is triggered - m.closeSendThread(*partList, stop, healthChanID) - - return true - case <-waitingTimer.C: - jww.TRACE.Printf("[FT] Resuming file part sending (%d parts) "+ - "after waiting %s for bandwidth limiting.", - len(*partList), waitingTime) - } - } - - // Send all the messages - go func(partList []queuedPart, sentRounds *sentRoundTracker) { - err := m.sendParts(partList, sentRounds) +// sendCmix sends the parts in the packet via Client.SendMany. +func (m *manager) sendCmix(packet []store.Part, cMixParams cmix.CMIXParams) { + // validParts will contain all parts in the original packet excluding those + // that return an error from GetEncryptedPart + validParts := make([]store.Part, 0, len(packet)) + + // Encrypt each part and to a TargetedCmixMessage + messages := make([]cmix.TargetedCmixMessage, 0, len(packet)) + for _, p := range packet { + encryptedPart, mac, fp, err := + p.GetEncryptedPart(m.cmix.GetMaxMessageLength()) if err != nil { - jww.ERROR.Print(err) + jww.ERROR.Printf("[FT] File transfer %s (%q) failed: %+v", + p.TransferID(), p.FileName(), err) + m.callbacks.Call(p.TransferID(), errors.New(errNoMoreRetries)) + continue } - }(copyPartList(*partList), sentRounds) - - // Update the timestamp of the send - *lastSend = netTime.Now() - // Clear partList once done - *partList = nil + validParts = append(validParts, p) - return false -} - -// copyPartList makes a copy of the list of queuedPart. -func copyPartList(partList []queuedPart) []queuedPart { - newPartList := make([]queuedPart, len(partList)) - copy(newPartList, partList) - return newPartList -} - -// sendParts handles the composing and sending of a cMix message for each part -// in the list. All errors returned are fatal errors. -func (m *Manager) sendParts(partList []queuedPart, - sentRounds *sentRoundTracker) error { - - // Build cMix messages - messages, transfers, groupedParts, partsToResend, err := - m.buildMessages(partList) - if err != nil { - return err - } - - // Exit if there are no parts to send - if len(messages) == 0 { - return nil + messages = append(messages, cmix.TargetedCmixMessage{ + Recipient: p.Recipient(), + Payload: encryptedPart, + Fingerprint: fp, + Service: message.Service{}, + Mac: mac, + }) } // Clear all old rounds from the sent rounds list - sentRounds.removeOldRounds() - - // Create cMix parameters with round exclusion list - p := params.GetDefaultCMIX() - p.SendTimeout = m.p.SendTimeout - p.ExcludedRounds = sentRounds - p.DebugTag = cMixDebugTag + cMixParams.ExcludedRounds.(*sentRoundTracker.Manager).RemoveOldRounds() - jww.TRACE.Printf("[FT] Sending %d file parts via SendManyCMIX with "+ - "parameters %+v", len(messages), p) + jww.DEBUG.Printf("[FT] Sending %d file parts via SendManyCMIX", + len(messages)) - // Send parts - rid, _, err := m.net.SendManyCMIX(messages, p) + rid, _, err := m.cmix.SendMany(messages, cMixParams) if err != nil { - // If an error occurs, then print a warning and add the file parts back - // to the queue to try sending again - jww.WARN.Printf(sendManyCmixWarn, len(messages), groupedParts, err) + jww.WARN.Printf("[FT] Failed to send %d file parts via "+ + "SendManyCMIX: %+v", len(messages), err) - // Add parts back to queue - for _, partIndex := range partsToResend { - m.sendQueue <- partList[partIndex] - } - - return nil - } - - // Create list for transfer IDs to watch with the round results callback - tIDs := make([]ftCrypto.TransferID, 0, len(transfers)) - - // Set all parts to in-progress - for tid, transfer := range transfers { - exists, err := transfer.SetInProgress(rid, groupedParts[tid]...) - if err != nil { - return errors.Errorf(setInProgressErr, groupedParts[tid], tid, err) - } - - transfer.CallProgressCB(nil) - - // Add transfer ID to list to be tracked; skip if the tracker has - // already been launched for this transfer and round ID - if !exists { - tIDs = append(tIDs, tid) + for _, p := range validParts { + m.batchQueue <- p } } - // Set up tracker waiting for the round to end to update state and update - // progress - roundResultCB := m.makeRoundEventCallback( - map[id.Round][]ftCrypto.TransferID{rid: tIDs}) - err = m.getRoundResults( - []id.Round{rid}, roundResultsTimeout, roundResultCB) - if err != nil { - return errors.Errorf(getRoundResultsErr, rid, tIDs, err) - } - - return nil + err = m.cmix.GetRoundResults( + roundResultsTimeout, m.roundResultsCallback(validParts), rid) } -// buildMessages builds the list of cMix messages to send via SendManyCmix. Also -// returns three separate lists used for later progress tracking. The first, a -// map that contains each unique transfer for each part in the list. The second, -// a map of part numbers being sent grouped by their transfer. The last a list -// of partList index of parts that will be sent. Any part that encounters a -// non-fatal error will be skipped and will not be included in an of the lists. -// All errors returned are fatal errors. -func (m *Manager) buildMessages(partList []queuedPart) ( - []message.TargetedCmixMessage, map[ftCrypto.TransferID]*ftStorage.SentTransfer, - map[ftCrypto.TransferID][]uint16, []int, error) { - messages := make([]message.TargetedCmixMessage, 0, len(partList)) - transfers := map[ftCrypto.TransferID]*ftStorage.SentTransfer{} - groupedParts := map[ftCrypto.TransferID][]uint16{} - partsToResend := make([]int, 0, len(partList)) - - rng := m.rng.GetStream() - defer rng.Close() - - for i, part := range partList { - // Lookup the transfer by the ID; if the transfer does not exist, then - // print a warning and skip this message - st, err := m.sent.GetTransfer(part.tid) - if err != nil { - jww.WARN.Printf(noSentTransferWarn, part.tid, part.partNum, err) - continue +// roundResultsCallback generates a network.RoundEventCallback that handles +// all parts in the packet once the round succeeds or fails. +func (m *manager) roundResultsCallback(packet []store.Part) cmix.RoundEventCallback { + // Group file parts by transfer + grouped := map[ftCrypto.TransferID][]store.Part{} + for _, p := range packet { + if _, exists := grouped[*p.TransferID()]; exists { + grouped[*p.TransferID()] = append(grouped[*p.TransferID()], p) + } else { + grouped[*p.TransferID()] = []store.Part{p} } - - // Generate new cMix message with encrypted file part - cmixMsg, err := m.newCmixMessage(st, part.partNum) - if err == ftStorage.MaxRetriesErr { - jww.DEBUG.Printf("[FT] File transfer %s sent to %s ran out of "+ - "retries {parts: %d, numFps: %d/%d}", - part.tid, st.GetRecipient(), st.GetNumParts(), - st.GetNumFps()-st.GetNumAvailableFps(), st.GetNumFps()) - - // If the max number of retries has been reached, then report the - // error on the callback, delete the transfer, and skip to the next - // message - go st.CallProgressCB(errors.Errorf(maxRetriesErr, err)) - continue - } else if err != nil { - // For all other errors, return an error - return nil, nil, nil, nil, - errors.Errorf(newCmixMessageErr, part.partNum, part.tid, err) - } - - // Construct TargetedCmixMessage - msg := message.TargetedCmixMessage{ - Recipient: st.GetRecipient(), - Message: cmixMsg, - } - - // Add to list of messages to send - messages = append(messages, msg) - transfers[part.tid] = st - groupedParts[part.tid] = append(groupedParts[part.tid], part.partNum) - partsToResend = append(partsToResend, i) - } - - return messages, transfers, groupedParts, partsToResend, nil -} - -// newCmixMessage creates a new cMix message with an encrypted file part, its -// MAC, and fingerprint. -func (m *Manager) newCmixMessage(transfer *ftStorage.SentTransfer, - partNum uint16) (format.Message, error) { - // Create new empty cMix message - cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) - - // Get encrypted file part, file part MAC, nonce (nonce), and fingerprint - encPart, mac, fp, err := transfer.GetEncryptedPart(partNum, cmixMsg.ContentsSize()) - if err != nil { - return format.Message{}, err } - // Construct cMix message - cmixMsg.SetContents(encPart) - cmixMsg.SetKeyFP(fp) - cmixMsg.SetMac(mac) - - return cmixMsg, nil -} - -// makeRoundEventCallback returns an api.RoundEventCallback that is called once -// the round that file parts were sent on either succeeds or fails. If the round -// succeeds, then all file parts for each transfer are marked as finished and -// the progress callback is called with the current progress. If the round -// fails, then each part for each transfer is removed from the in-progress list, -// added to the end of the sending queue, and the callback called with an error. -func (m *Manager) makeRoundEventCallback( - sentRounds map[id.Round][]ftCrypto.TransferID) cmix.RoundEventCallback { - - return func(allSucceeded, timedOut bool, rounds map[id.Round]cmix.RoundLookupStatus) { - for rid, roundResult := range rounds { - if roundResult == cmix.Succeeded { - // If the round succeeded, then set all parts for each transfer - // for this round to finished and call the progress callback - for _, tid := range sentRounds[rid] { - st, err := m.sent.GetTransfer(tid) - if err != nil { - jww.ERROR.Printf(finishPassNoTransferErr, rid, tid, err) - continue - } - - // Mark as finished - completed, err := st.FinishTransfer(rid) - if err != nil { - jww.ERROR.Printf(finishTransferErr, tid, err) - continue - } - - // Call progress callback after change in progress - st.CallProgressCB(nil) + return func( + allRoundsSucceeded, _ bool, rounds map[id.Round]cmix.RoundResult) { + // Get round ID + var rid id.Round + for rid = range rounds { + break + } - // If the transfer is complete, send an E2E message to the - // recipient informing them - if completed { - jww.DEBUG.Printf("[FT] Finished sending file "+ - "transfer %s to %s {parts: %d, numFps: %d/%d}", - tid, st.GetRecipient(), st.GetNumParts(), - st.GetNumFps()-st.GetNumAvailableFps(), - st.GetNumFps()) + if allRoundsSucceeded { + jww.DEBUG.Printf("[FT] %d file parts delivered on round %d (%v)", + len(packet), rid, grouped) - go func(tid ftCrypto.TransferID, recipient *id.ID) { - err = m.sendEndE2eMessage(recipient) - if err != nil { - jww.ERROR.Printf(finishedEndE2eMsfErr, - recipient, tid, err) - } - }(tid, st.GetRecipient()) - } + // If the round succeeded, then mark all parts as arrived and report + // each transfer's progress on its progress callback + for tid, parts := range grouped { + for _, p := range parts { + p.MarkArrived() } - } else { - jww.WARN.Printf(roundFailureWarn, sentRounds[rid], rid, roundResult) - - // If the round failed, then remove all parts for each transfer - // for this round from the in-progress list, call the progress - // callback with an error, and add the parts back into the queue - for _, tid := range sentRounds[rid] { - st, err := m.sent.GetTransfer(tid) - if err != nil { - jww.ERROR.Printf(finishFailNoTransferErr, rid, tid, err) - continue - } - - // Remove parts from in-progress list - partsToResend, err := st.UnsetInProgress(rid) - if err != nil { - jww.ERROR.Printf( - unsetInProgressErr, tid, roundResult, err) - } - - // Call progress callback after change in progress - st.CallProgressCB(nil) - - // Add all the unsent parts back in the queue - m.queueParts(tid, partsToResend) - } + // Call the progress callback after all parts have been marked + // so that the progress reported included all parts in the batch + m.callbacks.Call(&tid, nil) } - } - } -} - -// sendEndE2eMessage sends an E2E message to the recipient once the transfer -// complete information them that all file parts have been sent. -func (m *Manager) sendEndE2eMessage(recipient *id.ID) error { - // Get the partner - partner, err := m.store.E2e().GetPartner(recipient) - if err != nil { - return errors.Errorf(endE2eGetPartnerErr, recipient, err) - } - - // Build the message - sendMsg := message.Send{ - Recipient: recipient, - MessageType: message.EndFileTransfer, - } - - // Send the message under file transfer preimage - e2eParams := params.GetDefaultE2E() - e2eParams.IdentityPreimage = partner.GetFileTransferPreimage() - e2eParams.DebugTag = "FT.End" + } else { + jww.DEBUG.Printf("[FT] %d file parts failed on round %d (%v)", + len(packet), rid, grouped) - // Store the message in the critical messages buffer first to ensure it is - // present if the send fails - m.store.GetCriticalMessages().AddProcessing(sendMsg, e2eParams) - - // Register health channel and wait for network to become healthy - healthChan := make(chan bool, networkHealthBuffLen) - healthChanID := m.net.GetHealthTracker().AddChannel(healthChan) - defer m.net.GetHealthTracker().RemoveChannel(healthChanID) - isHealthy := m.net.GetHealthTracker().IsHealthy() - healthCheckTimer := time.NewTimer(sendEndE2eHealthTimeout) - for !isHealthy { - select { - case isHealthy = <-healthChan: - case <-healthCheckTimer.C: - return errors.Errorf(endE2eHealthTimeout, sendEndE2eHealthTimeout) + // If the round failed, then add each part into the send queue + for _, p := range packet { + m.batchQueue <- p + } } } - healthCheckTimer.Stop() - - // Send E2E message - rounds, e2eMsgID, _, err := m.net.SendE2E(sendMsg, e2eParams, nil) - if err != nil { - return errors.Errorf(endE2eSendErr, recipient, err) - } - - // Register the event for all rounds - sendResults := make(chan ds.EventReturn, len(rounds)) - roundEvents := m.net.GetInstance().GetRoundEvents() - for _, r := range rounds { - roundEvents.AddRoundEventChan(r, sendResults, 10*time.Second, - states.COMPLETED, states.FAILED) - } - - // Wait until the result tracking responds - success, numTimeOut, numRoundFail := cmix.TrackResults( - sendResults, len(rounds)) - - // If a single partition of the end file transfer message does not transmit, - // then the partner will not be able to read the confirmation - if !success { - jww.ERROR.Printf("[FT] Sending E2E message %s to end file transfer "+ - "with %s failed to transmit %d/%d partitions: %d round failures, "+ - "%d timeouts", recipient, e2eMsgID, numRoundFail+numTimeOut, - len(rounds), numRoundFail, numTimeOut) - m.store.GetCriticalMessages().Failed(sendMsg, e2eParams) - return nil - } - - // Otherwise, the transmission is a success and this should be denoted in - // the session and the log - m.store.GetCriticalMessages().Succeeded(sendMsg, e2eParams) - jww.INFO.Printf("[FT] Sending of message %s informing %s that a transfer "+ - "completed successfully.", e2eMsgID, recipient) - - return nil -} - -// queueParts adds an entry for each file part in the list into the sendQueue -// channel in a random order. -func (m *Manager) queueParts(tid ftCrypto.TransferID, partNums []uint16) { - // Shuffle the list - shuffle.Shuffle16(&partNums) - - // Add each part to the queue - for _, partNum := range partNums { - m.sendQueue <- queuedPart{tid, partNum} - } -} - -// makeListOfPartNums returns a list of number of file part, from 0 to numParts. -func makeListOfPartNums(numParts uint16) []uint16 { - partNumList := make([]uint16, numParts) - for i := range partNumList { - partNumList[i] = uint16(i) - } - - return partNumList -} - -// getPartSize determines the maximum size for each file part in bytes. The size -// is calculated based on the content size of a cMix message. Returns an error -// if a file part message cannot fit into a cMix message payload. -func (m *Manager) getPartSize() (int, error) { - // Create new empty cMix message - cmixMsg := format.NewMessage(m.store.Cmix().GetGroup().GetP().ByteLen()) - - // Create new empty file part message of size equal to the available payload - // size in the cMix message - partMsg, err := ftStorage.NewPartMessage(cmixMsg.ContentsSize()) - if err != nil { - return 0, err - } - - return partMsg.GetPartSize(), nil -} - -// partitionFile splits the file into parts of the specified part size. -func partitionFile(file []byte, partSize int) [][]byte { - // Initialize part list to the correct size - numParts := (len(file) + partSize - 1) / partSize - parts := make([][]byte, 0, numParts) - buff := bytes.NewBuffer(file) - - for n := buff.Next(partSize); len(n) > 0; n = buff.Next(partSize) { - newPart := make([]byte, partSize) - copy(newPart, n) - parts = append(parts, newPart) - } - - return parts -} - -// getRngNum takes in a PRNG source and returns a random number. This type makes -// it easier to test by allowing custom functions that return expected values. -type getRngNum func(rng csprng.Source) int - -// getRandomNumParts returns a random number between minPartsSendPerRound and -// maxPartsSendPerRound, inclusive. -func getRandomNumParts(rng csprng.Source) int { - // Generate random bytes - b, err := csprng.Generate(8, rng) - if err != nil { - jww.FATAL.Panicf(getRandomNumPartsRandPanic, err) - } - - // Convert bytes to integer - num := binary.LittleEndian.Uint64(b) - - // Return random number that is minPartsSendPerRound <= num <= max - return int((num % (maxPartsSendPerRound)) + minPartsSendPerRound) } diff --git a/fileTransfer/sendE2e.go b/fileTransfer/sendE2e.go new file mode 100644 index 0000000000000000000000000000000000000000..647f3d905fdea3f172a94bfca8c2e7638e846cbd --- /dev/null +++ b/fileTransfer/sendE2e.go @@ -0,0 +1,101 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package fileTransfer + +import ( + "github.com/golang/protobuf/proto" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/catalog" + "gitlab.com/elixxir/client/e2e" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" +) + +// Error messages. +const ( + // manager.sendNewFileTransferMessage + errProtoMarshal = "failed to proto marshal NewFileTransfer: %+v" + errNewFtSendE2e = "failed to send initial file transfer message via E2E: %+v" + + // manager.sendEndFileTransferMessage + errEndFtSendE2e = "[FT] Failed to send ending file transfer message via E2E: %+v" +) + +const ( + // Tag that is used for log printing in SendE2E when sending the initial + // message + initialMessageDebugTag = "FT.New" + + // Tag that is used for log printing in SendE2E when sending the ending + // message + lastMessageDebugTag = "FT.End" +) + +// sendNewFileTransferMessage sends an E2E message to the recipient informing +// them of the incoming file transfer. +func (m *manager) sendNewFileTransferMessage(recipient *id.ID, fileName, + fileType string, key *ftCrypto.TransferKey, mac []byte, numParts uint16, + fileSize uint32, retry float32, preview []byte) error { + + // Construct NewFileTransfer message + protoMsg := &NewFileTransfer{ + FileName: fileName, + FileType: fileType, + TransferKey: key.Bytes(), + TransferMac: mac, + NumParts: uint32(numParts), + Size: fileSize, + Retry: retry, + Preview: preview, + } + + // Marshal the message + payload, err := proto.Marshal(protoMsg) + if err != nil { + return errors.Errorf(errProtoMarshal, err) + } + + // Get E2E parameters + params := e2e.GetDefaultParams() + params.ServiceTag = catalog.Silent + params.LastServiceTag = catalog.Silent + params.CMIX.DebugTag = initialMessageDebugTag + + _, _, _, err = m.e2e.SendE2E( + catalog.NewFileTransfer, recipient, payload, params) + if err != nil { + return errors.Errorf(errNewFtSendE2e, err) + } + + return nil +} + +// sendEndFileTransferMessage sends an E2E message to the recipient informing +// them that all file parts have arrived once the network is healthy. +func (m *manager) sendEndFileTransferMessage(recipient *id.ID) { + callbackID := make(chan uint64, 1) + callbackID <- m.cmix.AddHealthCallback( + func(healthy bool) { + if healthy { + params := e2e.GetDefaultParams() + params.LastServiceTag = catalog.EndFT + params.CMIX.DebugTag = lastMessageDebugTag + + _, _, _, err := m.e2e.SendE2E( + catalog.EndFileTransfer, recipient, nil, params) + if err != nil { + jww.ERROR.Printf(errEndFtSendE2e, err) + } + + cbID := <-callbackID + m.cmix.RemoveHealthCallback(cbID) + } + }, + ) +} diff --git a/fileTransfer/sendNew.go b/fileTransfer/sendNew.go deleted file mode 100644 index 91c979dc806324d5a4077a902605c239239c9926..0000000000000000000000000000000000000000 --- a/fileTransfer/sendNew.go +++ /dev/null @@ -1,89 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "github.com/golang/protobuf/proto" - "github.com/pkg/errors" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/primitives/id" -) - -// Error messages. -const ( - newFtProtoMarshalErr = "failed to form new file transfer message: %+v" - newFtSendE2eErr = "failed to send new file transfer message via E2E to recipient %s: %+v" -) - -// sendNewFileTransfer sends the initial file transfer message over E2E. -func (m *Manager) sendNewFileTransfer(recipient *id.ID, fileName, - fileType string, key ftCrypto.TransferKey, mac []byte, numParts uint16, - fileSize uint32, retry float32, preview []byte) error { - - // Create Send message with marshalled NewFileTransfer - sendMsg, err := newNewFileTransferE2eMessage(recipient, fileName, fileType, - key, mac, numParts, fileSize, retry, preview) - if err != nil { - return errors.Errorf(newFtProtoMarshalErr, err) - } - - // Get partner relationship so that the silent preimage can be generated - relationship, err := m.store.E2e().GetPartner(recipient) - if err != nil { - return err - } - - // Sends as a silent message to avoid a notification - p := params.GetDefaultE2E() - p.CMIX.IdentityPreimage = relationship.GetSilentPreimage() - p.DebugTag = "FT.New" - - // Send E2E message - _, _, _, err = m.net.SendE2E(sendMsg, p, nil) - if err != nil { - return errors.Errorf(newFtSendE2eErr, recipient, err) - } - - return nil -} - -// newNewFileTransferE2eMessage generates the message.Send for the given -// recipient containing the marshalled NewFileTransfer message. -func newNewFileTransferE2eMessage(recipient *id.ID, fileName, fileType string, - key ftCrypto.TransferKey, mac []byte, numParts uint16, fileSize uint32, - retry float32, preview []byte) (message.Send, error) { - - // Construct NewFileTransfer message - protoMsg := &NewFileTransfer{ - FileName: fileName, - FileType: fileType, - TransferKey: key.Bytes(), - TransferMac: mac, - NumParts: uint32(numParts), - Size: fileSize, - Retry: retry, - Preview: preview, - } - - // Marshal the message - marshalledMsg, err := proto.Marshal(protoMsg) - if err != nil { - return message.Send{}, err - } - - // Create message.Send of the type NewFileTransfer - sendMsg := message.Send{ - Recipient: recipient, - Payload: marshalledMsg, - MessageType: message.NewFileTransfer, - } - - return sendMsg, nil -} diff --git a/fileTransfer/sendNew_test.go b/fileTransfer/sendNew_test.go deleted file mode 100644 index 8a4da00a6d666001813c728311496ba39fcf9a1b..0000000000000000000000000000000000000000 --- a/fileTransfer/sendNew_test.go +++ /dev/null @@ -1,152 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "fmt" - "github.com/cloudflare/circl/dh/sidh" - "github.com/golang/protobuf/proto" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" - util "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/crypto/diffieHellman" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/crypto/csprng" - "gitlab.com/xx_network/primitives/id" - "reflect" - "strings" - "testing" -) - -// Tests that the E2E message sent via Manager.sendNewFileTransfer matches -// expected. -func TestManager_sendNewFileTransfer(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - recipient := id.NewIdFromString("recipient", id.User, t) - fileName := "testFile" - fileType := "txt" - key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - mac := []byte("transferMac") - numParts, fileSize, retry := uint16(16), uint32(256), float32(1.5) - preview := []byte("filePreview") - - rng := csprng.NewSystemRNG() - dhKey := m.store.E2e().GetGroup().NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, m.store.E2e().GetGroup()) - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhB, rng) - p := params.GetDefaultE2ESessionParams() - - err := m.store.E2e().AddPartner(recipient, pubKey, dhKey, - mySidhPriv, theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) - } - - expected, err := newNewFileTransferE2eMessage(recipient, fileName, fileType, - key, mac, numParts, fileSize, retry, preview) - if err != nil { - t.Errorf("Failed to create new Send message: %+v", err) - } - - err = m.sendNewFileTransfer(recipient, fileName, fileType, key, mac, - numParts, fileSize, retry, preview) - if err != nil { - t.Errorf("sendNewFileTransfer returned an error: %+v", err) - } - - received := m.net.(*testNetworkManager).GetE2eMsg(0) - - if !reflect.DeepEqual(expected, received) { - t.Errorf("Received E2E message does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expected, received) - } -} - -// Error path: tests that Manager.sendNewFileTransfer returns the expected error -// when SendE2E fails. -func TestManager_sendNewFileTransfer_E2eError(t *testing.T) { - // Create new test manager with a SendE2E error triggered - m := newTestManager(true, nil, nil, nil, nil, t) - - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - - rng := csprng.NewSystemRNG() - dhKey := m.store.E2e().GetGroup().NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, m.store.E2e().GetGroup()) - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhB, rng) - p := params.GetDefaultE2ESessionParams() - - err := m.store.E2e().AddPartner(recipient, pubKey, dhKey, - mySidhPriv, theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) - } - - expectedErr := fmt.Sprintf(newFtSendE2eErr, recipient, "") - err = m.sendNewFileTransfer(recipient, "", "", key, nil, 16, 256, 1.5, nil) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("sendNewFileTransfer di dnot return the expected error when "+ - "SendE2E failed.\nexpected: %s\nreceived: %+v", expectedErr, err) - } - - if len(m.net.(*testNetworkManager).e2eMessages) > 0 { - t.Errorf("%d E2E messeage(s) found when SendE2E should have failed.", - len(m.net.(*testNetworkManager).e2eMessages)) - } -} - -// Tests that newNewFileTransferE2eMessage returns a message.Send with the -// correct recipient and message type and that the payload can be unmarshalled -// into the correct NewFileTransfer. -func Test_newNewFileTransferE2eMessage(t *testing.T) { - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(NewPrng(42)) - expected := &NewFileTransfer{ - FileName: "testFile", - FileType: "txt", - TransferKey: key.Bytes(), - TransferMac: []byte("transferMac"), - NumParts: 16, - Size: 256, - Retry: 1.5, - Preview: []byte("filePreview"), - } - - sendMsg, err := newNewFileTransferE2eMessage(recipient, expected.FileName, - expected.FileType, key, expected.TransferMac, uint16(expected.NumParts), - expected.Size, expected.Retry, expected.Preview) - if err != nil { - t.Errorf("newNewFileTransferE2eMessage returned an error: %+v", err) - } - - if sendMsg.MessageType != message.NewFileTransfer { - t.Errorf("Send message has wrong MessageType."+ - "\nexpected: %d\nreceived: %d", message.NewFileTransfer, - sendMsg.MessageType) - } - - if !sendMsg.Recipient.Cmp(recipient) { - t.Errorf("Send message has wrong Recipient."+ - "\nexpected: %s\nreceived: %s", recipient, sendMsg.Recipient) - } - - received := &NewFileTransfer{} - err = proto.Unmarshal(sendMsg.Payload, received) - if err != nil { - t.Errorf("Failed to unmarshal received NewFileTransfer: %+v", err) - } - - if !proto.Equal(expected, received) { - t.Errorf("Received NewFileTransfer does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expected, received) - } -} diff --git a/fileTransfer/send_test.go b/fileTransfer/send_test.go deleted file mode 100644 index 3572ac7c9f9df16752df11dab9050d2b4b10e4c8..0000000000000000000000000000000000000000 --- a/fileTransfer/send_test.go +++ /dev/null @@ -1,1076 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "errors" - "fmt" - "github.com/cloudflare/circl/dh/sidh" - "gitlab.com/elixxir/client/cmix" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" - "gitlab.com/elixxir/client/stoppable" - ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" - util "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/crypto/diffieHellman" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/crypto/csprng" - "gitlab.com/xx_network/primitives/id" - "math/rand" - "reflect" - "sort" - "strings" - "sync" - "testing" - "time" -) - -// Tests that Manager.sendThread successfully sends the parts and reports their -// progress on the callback. -func TestManager_sendThread(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, true, nil, nil, nil, t) - - // Add three transfers - partsToSend := [][]uint16{ - {0, 1, 3, 5, 6, 7}, - {1, 2, 3, 0}, - {0}, - } - - var wg sync.WaitGroup - for i, st := range sti { - wg.Add(1) - go func(i int, st sentTransferInfo) { - for j := 0; j < 2; j++ { - select { - case <-time.NewTimer(160 * time.Millisecond).C: - t.Errorf("Timed out waiting for callback #%d", i) - case r := <-st.cbChan: - if j > 0 { - err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, uint16(len(partsToSend[i])), 0, - st.numParts) - if err != nil { - t.Errorf("%d: %+v", i, err) - } - if r.err != nil { - t.Errorf("Callback returned an error (%d): %+v", i, r.err) - } - } - } - } - wg.Done() - }(i, st) - } - - // Create queued part list, add parts from each transfer, and shuffle - queuedParts := make([]queuedPart, 0, 11) - for i, sendingParts := range partsToSend { - for _, part := range sendingParts { - queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) - } - } - rand.Shuffle(len(queuedParts), func(i, j int) { - queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] - }) - - // Create custom RNG function that always returns 11 - getNumParts := func(rng csprng.Source) int { - return len(queuedParts) - } - - // Start sending thread - stop := stoppable.NewSingle("testSendThreadStoppable") - healthyChan := make(chan bool, 8) - healthyChan <- true - go m.sendThread(stop, healthyChan, 0, getNumParts) - - // Add parts to queue - for _, part := range queuedParts { - m.sendQueue <- part - } - - wg.Wait() - - err := stop.Close() - if err != nil { - t.Errorf("Failed to stop stoppable: %+v", err) - } - - if len(m.net.(*testNetworkManager).GetMsgList(0)) != len(queuedParts) { - t.Errorf("Not all messages were received.\nexpected: %d\nreceived: %d", - len(queuedParts), len(m.net.(*testNetworkManager).GetMsgList(0))) - } -} - -// Tests that Manager.sendThread successfully sends the parts and reports their -// progress on the callback. -func TestManager_sendThread_NetworkNotHealthy(t *testing.T) { - m, _, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, true, nil, nil, nil, t) - - sendingChan := make(chan bool, 4) - getNumParts := func(csprng.Source) int { - sendingChan <- true - return 0 - } - - // Start sending thread - stop := stoppable.NewSingle("testSendThreadStoppable") - healthyChan := make(chan bool, 8) - go m.sendThread(stop, healthyChan, 0, getNumParts) - - for i := 0; i < 15; i++ { - healthyChan <- false - } - m.sendQueue <- queuedPart{ftCrypto.TransferID{5}, 0} - - select { - case <-time.NewTimer(150 * time.Millisecond).C: - healthyChan <- true - case r := <-sendingChan: - t.Errorf("sendThread tried to send even though the network is "+ - "unhealthy. %t", r) - } - - select { - case <-time.NewTimer(150 * time.Millisecond).C: - t.Errorf("Timed out waiting for sending to start.") - case <-sendingChan: - } - - err := stop.Close() - if err != nil { - t.Errorf("Failed to stop stoppable: %+v", err) - } -} - -// Tests that Manager.sendThread successfully sends a partially filled batch -// of the correct length when its times out waiting for messages. -func TestManager_sendThread_Timeout(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, false, nil, nil, nil, t) - - // Add three transfers - partsToSend := [][]uint16{ - {0, 1, 3, 5, 6, 7}, - {1, 2, 3, 0}, - {0}, - } - - var wg sync.WaitGroup - for i, st := range sti[:1] { - wg.Add(1) - go func(i int, st sentTransferInfo) { - for j := 0; j < 2; j++ { - select { - case <-time.NewTimer(80*time.Millisecond + pollSleepDuration).C: - t.Errorf("Timed out waiting for callback #%d", i) - case r := <-st.cbChan: - if j > 0 { - err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, 5, 0, - st.numParts) - if err != nil { - t.Errorf("%d: %+v", i, err) - } - if r.err != nil { - t.Errorf("Callback returned an error (%d): %+v", i, r.err) - } - } - } - } - wg.Done() - }(i, st) - } - - // Create queued part list, add parts from each transfer, and shuffle - queuedParts := make([]queuedPart, 0, 11) - for i, sendingParts := range partsToSend { - for _, part := range sendingParts { - queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) - } - } - - // Crate custom getRngNum function that always returns 11 - getNumParts := func(rng csprng.Source) int { - return len(queuedParts) - } - - // Start sending thread - stop := stoppable.NewSingle("testSendThreadStoppable") - go m.sendThread(stop, make(chan bool), 0, getNumParts) - - // Add parts to queue - for _, part := range queuedParts[:5] { - m.sendQueue <- part - } - - time.Sleep(pollSleepDuration) - - wg.Wait() - - err := stop.Close() - if err != nil { - t.Errorf("Failed to stop stoppable: %+v", err) - } - - if len(m.net.(*testNetworkManager).GetMsgList(0)) != len(queuedParts[:5]) { - t.Errorf("Not all messages were received.\nexpected: %d\nreceived: %d", - len(queuedParts[:5]), len(m.net.(*testNetworkManager).GetMsgList(0))) - } -} - -// Tests that Manager.sendParts sends all the correct cMix messages and calls -// the progress callbacks with the correct values. -func TestManager_sendParts(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, true, nil, nil, nil, t) - - // Add three transfers - partsToSend := [][]uint16{ - {0, 1, 3, 5, 6, 7}, - {1, 2, 3, 0}, - {0}, - } - - var wg sync.WaitGroup - for i, st := range sti { - wg.Add(1) - go func(i int, st sentTransferInfo) { - for j := 0; j < 2; j++ { - select { - case <-time.NewTimer(20 * time.Millisecond).C: - t.Errorf("Timed out waiting for callback #%d", i) - case r := <-st.cbChan: - if j > 0 { - err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, uint16(len(partsToSend[i])), 0, - st.numParts) - if err != nil { - t.Errorf("%d: %+v", i, err) - } - if r.err != nil { - t.Errorf("Callback returned an error (%d): %+v", i, r.err) - } - } - } - } - wg.Done() - }(i, st) - } - - // Create queued part list, add parts from each transfer, and shuffle - queuedParts := make([]queuedPart, 0, 11) - for i, sendingParts := range partsToSend { - for _, part := range sendingParts { - queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) - } - } - rand.Shuffle(len(queuedParts), func(i, j int) { - queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] - }) - - err := m.sendParts(queuedParts, newSentRoundTracker(clearSentRoundsAge)) - if err != nil { - t.Errorf("sendParts returned an error: %+v", err) - } - - m.net.(*testNetworkManager).GetMsgList(0) - - // Check that each recipient is connected to the correct transfer and that - // the fingerprint is correct - for i, tcm := range m.net.(*testNetworkManager).GetMsgList(0) { - index := 0 - for ; !sti[index].recipient.Cmp(tcm.Recipient); index++ { - - } - transfer, err := m.sent.GetTransfer(sti[index].tid) - if err != nil { - t.Errorf("Failed to get transfer %s: %+v", sti[index].tid, err) - } - - var fpFound bool - for _, fp := range ftCrypto.GenerateFingerprints( - transfer.GetTransferKey(), 15) { - if fp == tcm.Message.GetKeyFP() { - fpFound = true - } - } - if !fpFound { - t.Errorf("Fingeprint %s not found (%d).", tcm.Message.GetKeyFP(), i) - } - } - - wg.Wait() -} - -// Error path: tests that, on SendManyCMIX failure, Manager.sendParts adds the -// parts back into the queue, does not call the callback, and does not update -// the progress. -func TestManager_sendParts_SendManyCmixError(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, true, false, nil, nil, nil, t) - partsToSend := [][]uint16{ - {0, 1, 3, 5, 6, 7}, - {1, 2, 3, 0}, - {0}, - } - - var wg sync.WaitGroup - for i, st := range sti { - wg.Add(1) - go func(i int, st sentTransferInfo) { - for j := 0; j < 2; j++ { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - if j < 1 { - t.Errorf("Timed out waiting for callback #%d (%d)", i, j) - } - case r := <-st.cbChan: - if j > 0 { - t.Errorf("Callback called on send failure: %+v", r) - } - } - } - wg.Done() - }(i, st) - } - - // Create queued part list, add parts from each transfer, and shuffle - queuedParts := make([]queuedPart, 0, 11) - for i, sendingParts := range partsToSend { - for _, part := range sendingParts { - queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) - } - } - - err := m.sendParts(queuedParts, newSentRoundTracker(clearSentRoundsAge)) - if err != nil { - t.Errorf("sendParts returned an error: %+v", err) - } - - if len(m.net.(*testNetworkManager).GetMsgList(0)) > 0 { - t.Errorf("Sent %d cMix message(s) when sending should have failed.", - len(m.net.(*testNetworkManager).GetMsgList(0))) - } - - if len(m.sendQueue) != len(queuedParts) { - t.Errorf("Failed to add all parts to queue after send failure."+ - "\nexpected: %d\nreceived: %d", len(queuedParts), len(m.sendQueue)) - } - - wg.Wait() -} - -// Error path: tests that Manager.sendParts returns the expected error whe -// getRoundResults returns an error. -func TestManager_sendParts_RoundResultsError(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12}, false, true, nil, nil, nil, t) - - grrErr := errors.New("GetRoundResultsError") - m.getRoundResults = - func([]id.Round, time.Duration, cmix.RoundEventCallback) error { - return grrErr - } - - // Add three transfers - partsToSend := [][]uint16{ - {0, 1, 3, 5, 6, 7}, - } - - // Create queued part list, add parts from each transfer, and shuffle - queuedParts := make([]queuedPart, 0, 11) - tIDs := make([]ftCrypto.TransferID, 0, len(sti)) - for i, sendingParts := range partsToSend { - for _, part := range sendingParts { - queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) - } - tIDs = append(tIDs, sti[i].tid) - } - - expectedErr := fmt.Sprintf(getRoundResultsErr, 0, tIDs, grrErr) - err := m.sendParts(queuedParts, newSentRoundTracker(clearSentRoundsAge)) - if err == nil || err.Error() != expectedErr { - t.Errorf("sendParts did not return the expected error when "+ - "GetRoundResults should have returned an error."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that Manager.buildMessages returns the expected values for a group -// of 11 file parts from three different transfers. -func TestManager_buildMessages(t *testing.T) { - m, sti, _ := newTestManagerWithTransfers( - []uint16{12, 4, 1}, false, false, nil, nil, nil, t) - partsToSend := [][]uint16{ - {0, 1, 3, 5, 6, 7}, - {1, 2, 3, 0}, - {0}, - } - idMap := map[ftCrypto.TransferID]int{ - sti[0].tid: 0, - sti[1].tid: 1, - sti[2].tid: 2, - } - - // Create queued part list, add parts from each transfer, and shuffle - queuedParts := make([]queuedPart, 0, 11) - for i, sendingParts := range partsToSend { - for _, part := range sendingParts { - queuedParts = append(queuedParts, queuedPart{sti[i].tid, part}) - } - } - rand.Shuffle(len(queuedParts), func(i, j int) { - queuedParts[i], queuedParts[j] = queuedParts[j], queuedParts[i] - }) - - // Build the messages - messages, transfers, groupedParts, partsToResend, err := - m.buildMessages(queuedParts) - if err != nil { - t.Errorf("buildMessages returned an error: %+v", err) - } - - // Check that the transfer map has all the transfers - for i, st := range sti { - _, exists := transfers[st.tid] - if !exists { - t.Errorf("Transfer %s (#%d) not in transfers map.", st.tid, i) - } - } - - // Check that partsToResend contains all parts because none should have - // errored - if len(partsToResend) != len(queuedParts) { - sort.SliceStable(partsToResend, func(i, j int) bool { - return partsToResend[i] < partsToResend[j] - }) - for i, j := range partsToResend { - if i != j { - t.Errorf("Part at index %d not found. Found %d.", i, j) - } - } - } - - // Check that all the parts are present and grouped correctly - for tid, partNums := range groupedParts { - sort.SliceStable(partNums, func(i, j int) bool { - return partNums[i] < partNums[j] - }) - sort.SliceStable(partsToSend[idMap[tid]], func(i, j int) bool { - return partsToSend[idMap[tid]][i] < partsToSend[idMap[tid]][j] - }) - if !reflect.DeepEqual(partsToSend[idMap[tid]], partNums) { - t.Errorf("Incorrect parts for transfer %s."+ - "\nexpected: %v\nreceived: %v", - tid, partsToSend[idMap[tid]], partNums) - } - } - - // Check that each recipient is connected to the correct transfer and that - // the fingerprint is correct - for i, tcm := range messages { - index := 0 - for ; !sti[index].recipient.Cmp(tcm.Recipient); index++ { - - } - transfer, err := m.sent.GetTransfer(sti[index].tid) - if err != nil { - t.Errorf("Failed to get transfer %s: %+v", sti[index].tid, err) - } - - var fpFound bool - for _, fp := range ftCrypto.GenerateFingerprints( - transfer.GetTransferKey(), 15) { - if fp == tcm.Message.GetKeyFP() { - fpFound = true - } - } - if !fpFound { - t.Errorf("Fingeprint %s not found (%d).", tcm.Message.GetKeyFP(), i) - } - } -} - -// Tests that Manager.buildMessages skips file parts with deleted transfers or -// transfers that have run out of fingerprints. -func TestManager_buildMessages_MessageBuildFailureError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - callbackChan := make(chan sentProgressResults, 10) - progressCB := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - callbackChan <- sentProgressResults{ - completed, sent, arrived, total, tr, err} - } - - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(20 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") - case r := <-callbackChan: - switch i { - case 0: - done0 <- true - case 1: - expectedErr := fmt.Sprintf( - maxRetriesErr, ftStorage.MaxRetriesErr) - if i > 0 && (r.err == nil || r.err.Error() != expectedErr) { - t.Errorf("Callback received unexpected error when max "+ - "retries should have been reached."+ - "\nexpected: %s\nreceived: %v", expectedErr, r.err) - } - done1 <- true - } - } - } - }() - - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - _, parts := newFile(4, 64, prng, t) - tid, err := m.sent.AddTransfer( - recipient, key, parts, 3, progressCB, time.Millisecond, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - <-done0 - - // Create queued part list add parts - queuedParts := []queuedPart{ - {tid, 0}, - {tid, 1}, - {tid, 2}, - {tid, 3}, - {ftCrypto.UnmarshalTransferID([]byte("invalidID")), 3}, - } - - // Build the messages - prng = NewPrng(46) - messages, transfers, groupedParts, partsToResend, err := m.buildMessages( - queuedParts) - if err != nil { - t.Errorf("buildMessages returned an error: %+v", err) - } - - <-done1 - - // Check that partsToResend contains all parts because none should have - // errored - if len(partsToResend) != len(queuedParts)-2 { - sort.SliceStable(partsToResend, func(i, j int) bool { - return partsToResend[i] < partsToResend[j] - }) - for i, j := range partsToResend { - if i != j { - t.Errorf("Part at index %d not found. Found %d.", i, j) - } - } - } - - if len(messages) != 3 { - t.Errorf("Length of messages incorrect."+ - "\nexpected: %d\nreceived: %d", 3, len(messages)) - } - - if len(transfers) != 1 { - t.Errorf("Length of transfers incorrect."+ - "\nexpected: %d\nreceived: %d", 1, len(transfers)) - } - - if len(groupedParts) != 1 { - t.Errorf("Length of grouped parts incorrect."+ - "\nexpected: %d\nreceived: %d", 1, len(groupedParts)) - } -} - -// Tests that Manager.buildMessages returns the expected error when a queued -// part has an invalid part number. -func TestManager_buildMessages_NewCmixMessageError(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - // Add transfer - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - _, parts := newFile(12, 405, prng, t) - tid, err := m.sent.AddTransfer(recipient, key, parts, 15, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Create queued part list and add a single part with an invalid part number - queuedParts := []queuedPart{{tid, uint16(len(parts))}} - - // Build the messages - expectedErr := fmt.Sprintf( - newCmixMessageErr, queuedParts[0].partNum, tid, "") - _, _, _, _, err = m.buildMessages(queuedParts) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("buildMessages did not return the expected error when the "+ - "queuedPart part number is invalid.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } - -} - -// Tests that Manager.newCmixMessage returns a format.Message with the correct -// MAC, fingerprint, and contents. -func TestManager_newCmixMessage(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - prng := NewPrng(42) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - recipient := id.NewIdFromString("recipient", id.User, t) - partSize, _ := m.getPartSize() - _, parts := newFile(16, partSize, prng, t) - numFps := calcNumberOfFingerprints(uint16(len(parts)), 1.5) - kv := versioned.NewKV(make(ekv.Memstore)) - - transfer, err := ftStorage.NewSentTransfer(recipient, tid, key, parts, - numFps, nil, 0, kv) - if err != nil { - t.Errorf("Failed to create a new SentTransfer: %+v", err) - } - - cmixMsg, err := m.newCmixMessage(transfer, 0) - if err != nil { - t.Errorf("newCmixMessage returned an error: %+v", err) - } - - fp := ftCrypto.GenerateFingerprints(key, numFps)[0] - - if cmixMsg.GetKeyFP() != fp { - t.Errorf("cMix message has wrong fingerprint."+ - "\nexpected: %s\nrecieved: %s", fp, cmixMsg.GetKeyFP()) - } - - decrPart, err := ftCrypto.DecryptPart(key, cmixMsg.GetContents(), - cmixMsg.GetMac(), 0, cmixMsg.GetKeyFP()) - if err != nil { - t.Errorf("Failed to decrypt file part: %+v", err) - } - - partMsg, err := ftStorage.UnmarshalPartMessage(decrPart) - if err != nil { - t.Errorf("Failed to unmarshal part message: %+v", err) - } - - if !bytes.Equal(partMsg.GetPart(), parts[0]) { - t.Errorf("Decrypted part does not match expected."+ - "\nexpected: %q\nreceived: %q", parts[0], partMsg.GetPart()) - } -} - -// Tests that Manager.makeRoundEventCallback returns a callback that calls the -// progress callback when a round succeeds. -func TestManager_makeRoundEventCallback(t *testing.T) { - sendE2eChan := make(chan message.Receive, 100) - m := newTestManager(false, nil, sendE2eChan, nil, nil, t) - - callbackChan := make(chan sentProgressResults, 100) - progressCB := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - callbackChan <- sentProgressResults{ - completed, sent, arrived, total, tr, err} - } - - // Add recipient as partner - recipient := id.NewIdFromString("recipient", id.User, t) - grp := m.store.E2e().GetGroup() - dhKey := grp.NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, grp) - p := params.GetDefaultE2ESessionParams() - - rng := csprng.NewSystemRNG() - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, - rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair( - sidh.KeyVariantSidhB, rng) - - err := m.store.E2e().AddPartner(recipient, pubKey, dhKey, mySidhPriv, - theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) - } - - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Errorf("Timed out waiting for callback (%d).", i) - case r := <-callbackChan: - switch i { - case 0: - done0 <- true - case 1: - if r.err != nil { - t.Errorf("Callback received error: %v", r.err) - } - if !r.completed { - t.Error("File not marked as completed.") - } - done1 <- true - } - } - } - }() - - prng := NewPrng(42) - key, _ := ftCrypto.NewTransferKey(prng) - _, parts := newFile(4, 64, prng, t) - tid, err := m.sent.AddTransfer( - recipient, key, parts, 6, progressCB, time.Millisecond, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - <-done0 - - // Create queued part list add parts - queuedParts := []queuedPart{ - {tid, 0}, - {tid, 1}, - {tid, 2}, - {tid, 3}, - } - - rid := id.Round(42) - - _, transfers, groupedParts, _, err := m.buildMessages(queuedParts) - - // Set all parts to in-progress - for tid, transfer := range transfers { - _, _ = transfer.SetInProgress(rid, groupedParts[tid]...) - } - - roundEventCB := m.makeRoundEventCallback( - map[id.Round][]ftCrypto.TransferID{rid: {tid}}) - - roundEventCB(true, false, map[id.Round]cmix.RoundLookupStatus{rid: cmix.Succeeded}) - - <-done1 - - select { - case <-time.NewTimer(50 * time.Millisecond).C: - t.Errorf("Timed out waiting for end E2E message.") - case msg := <-sendE2eChan: - if msg.MessageType != message.EndFileTransfer { - t.Errorf("E2E message has wrong type.\nexpected: %d\nreceived: %d", - message.EndFileTransfer, msg.MessageType) - } else if !msg.RecipientID.Cmp(recipient) { - t.Errorf("E2E message has wrong recipient."+ - "\nexpected: %d\nreceived: %d", recipient, msg.RecipientID) - } - } -} - -// Tests that Manager.makeRoundEventCallback returns a callback that calls the -// progress callback with no parts sent on round failure. Also checks that the -// file parts were added back into the queue. -func TestManager_makeRoundEventCallback_RoundFailure(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - rid := id.Round(42) - - callbackChan := make(chan sentProgressResults, 10) - progressCB := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - callbackChan <- sentProgressResults{ - completed, sent, arrived, total, tr, err} - } - - prng := NewPrng(42) - recipient := id.NewIdFromString("recipient", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - _, parts := newFile(4, 64, prng, t) - tid, err := m.sent.AddTransfer( - recipient, key, parts, 6, progressCB, time.Millisecond, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - partsToSend := []uint16{0, 1, 2, 3} - - done0, done1 := make(chan bool), make(chan bool) - go func() { - for i := 0; i < 2; i++ { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") - case r := <-callbackChan: - switch i { - case 0: - done0 <- true - case 1: - expectedResult := sentProgressResults{ - false, 0, 0, uint16(len(partsToSend)), r.tracker, nil} - if !reflect.DeepEqual(expectedResult, r) { - t.Errorf("Callback returned unexpected values."+ - "\nexpected: %+v\nreceived: %+v", expectedResult, r) - } - done1 <- true - } - } - } - }() - - // Create queued part list add parts - partsMap := make(map[uint16]queuedPart, len(partsToSend)) - queuedParts := make([]queuedPart, len(partsToSend)) - for i := range queuedParts { - queuedParts[i] = queuedPart{tid, uint16(i)} - partsMap[uint16(i)] = queuedParts[i] - } - - _, transfers, groupedParts, _, err := m.buildMessages(queuedParts) - - <-done0 - - // Set all parts to in-progress - for tid, transfer := range transfers { - _, _ = transfer.SetInProgress(rid, groupedParts[tid]...) - } - - roundEventCB := m.makeRoundEventCallback( - map[id.Round][]ftCrypto.TransferID{rid: {tid}}) - - roundEventCB(false, false, map[id.Round]cmix.RoundLookupStatus{rid: cmix.Failed}) - - <-done1 - - // Check that the parts were added to the queue - for i := range partsToSend { - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Errorf("Timed out waiting for part %d.", i) - case r := <-m.sendQueue: - if partsMap[r.partNum] != r { - t.Errorf("Incorrect part in queue (%d)."+ - "\nexpected: %+v\nreceived: %+v", i, partsMap[r.partNum], r) - } else { - delete(partsMap, r.partNum) - } - } - } -} - -// Tests that Manager.sendEndE2eMessage sends an E2E message with the expected -// recipient and message type. This does not test round tracking or critical -// messages. -func TestManager_sendEndE2eMessage(t *testing.T) { - sendE2eChan := make(chan message.Receive, 10) - m := newTestManager(false, nil, sendE2eChan, nil, nil, t) - - // Add recipient as partner - recipient := id.NewIdFromString("recipient", id.User, t) - grp := m.store.E2e().GetGroup() - dhKey := grp.NewInt(42) - pubKey := diffieHellman.GeneratePublicKey(dhKey, grp) - p := params.GetDefaultE2ESessionParams() - - rng := csprng.NewSystemRNG() - _, mySidhPriv := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhA, rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair(sidh.KeyVariantSidhB, rng) - - err := m.store.E2e().AddPartner( - recipient, pubKey, dhKey, mySidhPriv, theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner %s: %+v", recipient, err) - } - - go func() { - err = m.sendEndE2eMessage(recipient) - if err != nil { - t.Errorf("sendEndE2eMessage returned an error: %+v", err) - } - }() - - select { - case <-time.NewTimer(50 * time.Millisecond).C: - t.Errorf("Timed out waiting for end E2E message.") - case msg := <-sendE2eChan: - if msg.MessageType != message.EndFileTransfer { - t.Errorf("E2E message has wrong type.\nexpected: %d\nreceived: %d", - message.EndFileTransfer, msg.MessageType) - } else if !msg.RecipientID.Cmp(recipient) { - t.Errorf("E2E message has wrong recipient."+ - "\nexpected: %d\nreceived: %d", recipient, msg.RecipientID) - } - } -} - -// Tests that Manager.queueParts adds all the expected parts to the sendQueue -// channel. -func TestManager_queueParts(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - csPrng := NewPrng(42) - prng := rand.New(rand.NewSource(42)) - - // Create map of expected parts - n := 26 - parts := make(map[ftCrypto.TransferID]map[uint16]bool, n) - for i := 0; i < n; i++ { - tid, _ := ftCrypto.NewTransferID(csPrng) - o := uint16(prng.Int31n(64)) - parts[tid] = make(map[uint16]bool, o) - for j := uint16(0); j < o; j++ { - parts[tid][j] = true - } - } - - // Queue all parts - for tid, partList := range parts { - partNums := makeListOfPartNums(uint16(len(partList))) - m.queueParts(tid, partNums) - } - - // Read through all parts in channel ensuring none are missing or duplicate - for len(parts) > 0 { - select { - case part := <-m.sendQueue: - partList, exists := parts[part.tid] - if !exists { - t.Errorf("Could not find transfer %s", part.tid) - } else { - _, exists := partList[part.partNum] - if !exists { - t.Errorf("Could not find part number %d for transfer %s", - part.partNum, part.tid) - } - - delete(parts[part.tid], part.partNum) - - if len(parts[part.tid]) == 0 { - delete(parts, part.tid) - } - } - case <-time.NewTimer(time.Millisecond).C: - if len(parts) != 0 { - t.Errorf("Timed out reading parts from channel. Failed to "+ - "read parts for %d/%d transfers.", len(parts), n) - } - return - default: - if len(parts) != 0 { - t.Errorf("Failed to read parts for %d/%d transfers.", - len(parts), n) - } - } - } -} - -// Tests that makeListOfPartNums returns a list with all the part numbers. -func Test_makeListOfPartNums(t *testing.T) { - n := uint16(100) - numList := makeListOfPartNums(n) - - if len(numList) != int(n) { - t.Errorf("Length of list incorrect.\nexpected: %d\nreceived: %d", - n, len(numList)) - } - - for i := uint16(0); i < n; i++ { - if numList[i] != i { - t.Errorf("Part number at index %d incorrect."+ - "\nexpected: %d\nreceived: %d", i, i, numList[i]) - } - } -} - -// Tests that the part size returned by Manager.GetPartSize matches the manually -// calculated part size. -func TestManager_getPartSize(t *testing.T) { - m := newTestManager(false, nil, nil, nil, nil, t) - - // Calculate the expected part size - primeByteLen := m.store.Cmix().GetGroup().GetP().ByteLen() - cmixMsgUsedLen := format.AssociatedDataSize - filePartMsgUsedLen := ftStorage.FmMinSize - expected := 2*primeByteLen - cmixMsgUsedLen - filePartMsgUsedLen - 1 - - // Get the part size - partSize, err := m.getPartSize() - if err != nil { - t.Errorf("GetPartSize returned an error: %+v", err) - } - - if expected != partSize { - t.Errorf("Returned part size does not match expected."+ - "\nexpected: %d\nreceived: %d", expected, partSize) - } -} - -// Tests that partitionFile partitions the given file into the expected parts. -func Test_partitionFile(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - partSize := 96 - fileData, expectedParts := newFile(24, partSize, prng, t) - - receivedParts := partitionFile(fileData, partSize) - - if !reflect.DeepEqual(expectedParts, receivedParts) { - t.Errorf("File parts do not match expected."+ - "\nexpected: %q\nreceived: %q", expectedParts, receivedParts) - } - - fullFile := bytes.Join(receivedParts, nil) - if !bytes.Equal(fileData, fullFile) { - t.Errorf("Full file does not match expected."+ - "\nexpected: %q\nreceived: %q", fileData, fullFile) - } -} - -// Tests that getRandomNumParts returns values between minPartsSendPerRound and -// maxPartsSendPerRound. -func Test_getRandomNumParts(t *testing.T) { - prng := NewPrng(42) - - for i := 0; i < 100; i++ { - num := getRandomNumParts(prng) - if num < minPartsSendPerRound { - t.Errorf("Number %d is smaller than minimum %d (%d)", - num, minPartsSendPerRound, i) - } else if num > maxPartsSendPerRound { - t.Errorf("Number %d is greater than maximum %d (%d)", - num, maxPartsSendPerRound, i) - } - } -} - -// Tests that getRandomNumParts panics for a PRNG that errors. -func Test_getRandomNumParts_PrngPanic(t *testing.T) { - prng := NewPrngErr() - - defer func() { - // Error if a panic does not occur - if err := recover(); err == nil { - t.Error("Failed to panic for broken PRNG.") - } - }() - - _ = getRandomNumParts(prng) -} - -// Tests that getRandomNumParts satisfies the getRngNum type. -func Test_getRandomNumParts_GetRngNumType(t *testing.T) { - var _ getRngNum = getRandomNumParts -} diff --git a/fileTransfer/sentRoundTracker.go b/fileTransfer/sentRoundTracker/sentRoundTracker.go similarity index 54% rename from fileTransfer/sentRoundTracker.go rename to fileTransfer/sentRoundTracker/sentRoundTracker.go index 75a06eb13e593f205b470b30f2dadc76a1c2c869..68511497dad8fbb2d7f3a3417c4d3873997a7c58 100644 --- a/fileTransfer/sentRoundTracker.go +++ b/fileTransfer/sentRoundTracker/sentRoundTracker.go @@ -5,14 +5,7 @@ // LICENSE file // //////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer +package sentRoundTracker import ( "gitlab.com/xx_network/primitives/id" @@ -21,24 +14,25 @@ import ( "time" ) -// sentRoundTracker keeps track of rounds that file parts were sent on and when +// Manager keeps track of rounds that file parts were sent on and when // those rounds occurred. Rounds past the given age can be deleted manually. -type sentRoundTracker struct { +// Adheres to excludedRounds.ExcludedRounds. +type Manager struct { rounds map[id.Round]time.Time age time.Duration mux sync.RWMutex } -// newSentRoundTracker returns an empty sentRoundTracker. -func newSentRoundTracker(interval time.Duration) *sentRoundTracker { - return &sentRoundTracker{ +// NewManager creates a new sent round tracker Manager. +func NewManager(interval time.Duration) *Manager { + return &Manager{ rounds: make(map[id.Round]time.Time), age: interval, } } -// removeOldRounds removes any rounds that are older than the max round age. -func (srt *sentRoundTracker) removeOldRounds() { +// RemoveOldRounds removes any rounds that are older than the max round age. +func (srt *Manager) RemoveOldRounds() { srt.mux.Lock() defer srt.mux.Unlock() deleteBefore := netTime.Now().Add(-srt.age) @@ -51,7 +45,7 @@ func (srt *sentRoundTracker) removeOldRounds() { } // Has indicates if the round ID is in the tracker. -func (srt *sentRoundTracker) Has(rid id.Round) bool { +func (srt *Manager) Has(rid id.Round) bool { srt.mux.RLock() defer srt.mux.RUnlock() @@ -61,7 +55,7 @@ func (srt *sentRoundTracker) Has(rid id.Round) bool { // Insert adds the round to the tracker with the current time. Returns true if // the round was added. -func (srt *sentRoundTracker) Insert(rid id.Round) bool { +func (srt *Manager) Insert(rid id.Round) bool { timeNow := netTime.Now() srt.mux.Lock() defer srt.mux.Unlock() @@ -76,30 +70,16 @@ func (srt *sentRoundTracker) Insert(rid id.Round) bool { } // Remove deletes a round ID from the tracker. -func (srt *sentRoundTracker) Remove(rid id.Round) { +func (srt *Manager) Remove(rid id.Round) { srt.mux.Lock() defer srt.mux.Unlock() delete(srt.rounds, rid) } // Len returns the number of round IDs in the tracker. -func (srt *sentRoundTracker) Len() int { +func (srt *Manager) Len() int { srt.mux.RLock() defer srt.mux.RUnlock() return len(srt.rounds) } - -// GetRoundIDs returns a list of all round IDs in the tracker. -func (srt *sentRoundTracker) GetRoundIDs() []id.Round { - srt.mux.RLock() - defer srt.mux.RUnlock() - - roundIDs := make([]id.Round, 0, len(srt.rounds)) - - for rid := range srt.rounds { - roundIDs = append(roundIDs, rid) - } - - return roundIDs -} diff --git a/fileTransfer/sentRoundTracker_test.go b/fileTransfer/sentRoundTracker/sentRoundTracker_test.go similarity index 67% rename from fileTransfer/sentRoundTracker_test.go rename to fileTransfer/sentRoundTracker/sentRoundTracker_test.go index 451ec1eff967d387e3c904791c5f3f8bf82abffb..90876842d6408cee151bf9d8cdb38598e276135c 100644 --- a/fileTransfer/sentRoundTracker_test.go +++ b/fileTransfer/sentRoundTracker/sentRoundTracker_test.go @@ -5,14 +5,7 @@ // LICENSE file // //////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer +package sentRoundTracker import ( "gitlab.com/xx_network/primitives/id" @@ -21,26 +14,26 @@ import ( "time" ) -// Tests that newSentRoundTracker returns the expected new sentRoundTracker. -func Test_newSentRoundTracker(t *testing.T) { +// Tests that NewManager returns the expected new Manager. +func Test_NewSentRoundTracker(t *testing.T) { interval := 10 * time.Millisecond - expected := &sentRoundTracker{ + expected := &Manager{ rounds: make(map[id.Round]time.Time), age: interval, } - srt := newSentRoundTracker(interval) + srt := NewManager(interval) if !reflect.DeepEqual(expected, srt) { - t.Errorf("New sentRoundTracker does not match expected."+ + t.Errorf("New Manager does not match expected."+ "\nexpected: %+v\nreceived: %+v", expected, srt) } } -// Tests that sentRoundTracker.removeOldRounds removes only old rounds and not +// Tests that Manager.RemoveOldRounds removes only old rounds and not // newer rounds. -func Test_sentRoundTracker_removeOldRounds(t *testing.T) { - srt := newSentRoundTracker(50 * time.Millisecond) +func TestManager_RemoveOldRounds(t *testing.T) { + srt := NewManager(50 * time.Millisecond) // Add odd round to tracker for rid := id.Round(0); rid < 100; rid++ { @@ -59,7 +52,7 @@ func Test_sentRoundTracker_removeOldRounds(t *testing.T) { } // Remove all old rounds (should be all odd rounds) - srt.removeOldRounds() + srt.RemoveOldRounds() // Check that only even rounds exist for rid := id.Round(0); rid < 100; rid++ { @@ -73,10 +66,10 @@ func Test_sentRoundTracker_removeOldRounds(t *testing.T) { } } -// Tests that sentRoundTracker.Has returns true for all the even rounds and +// Tests that Manager.Has returns true for all the even rounds and // false for all odd rounds. -func Test_sentRoundTracker_Has(t *testing.T) { - srt := newSentRoundTracker(0) +func TestManager_Has(t *testing.T) { + srt := NewManager(0) // Insert even rounds into the tracker for rid := id.Round(0); rid < 100; rid++ { @@ -97,14 +90,17 @@ func Test_sentRoundTracker_Has(t *testing.T) { } } -// Tests that sentRoundTracker.Insert adds all the expected rounds. -func Test_sentRoundTracker_Insert(t *testing.T) { - srt := newSentRoundTracker(0) +// Tests that Manager.Insert adds all the expected rounds to the map and that it +// returns true when the round does not already exist and false otherwise. +func TestManager_Insert(t *testing.T) { + srt := NewManager(0) // Insert even rounds into the tracker for rid := id.Round(0); rid < 100; rid++ { if rid%2 == 0 { - srt.Insert(rid) + if !srt.Insert(rid) { + t.Errorf("Did not insert round %d.", rid) + } } } @@ -119,11 +115,16 @@ func Test_sentRoundTracker_Insert(t *testing.T) { t.Errorf("Round %d does not exist.", rid) } } + + // Check that adding a round that already exists returns false + if srt.Insert(0) { + t.Errorf("Inserted round %d.", 0) + } } -// Tests that sentRoundTracker.Remove removes all even rounds. -func Test_sentRoundTracker_Remove(t *testing.T) { - srt := newSentRoundTracker(0) +// Tests that Manager.Remove removes all even rounds. +func TestManager_Remove(t *testing.T) { + srt := NewManager(0) // Add all round to tracker for rid := id.Round(0); rid < 100; rid++ { @@ -150,10 +151,10 @@ func Test_sentRoundTracker_Remove(t *testing.T) { } } -// Tests that sentRoundTracker.Len returns the expected length when the tracker +// Tests that Manager.Len returns the expected length when the tracker // is empty, filled, and then modified. -func Test_sentRoundTracker_Len(t *testing.T) { - srt := newSentRoundTracker(0) +func TestManager_Len(t *testing.T) { + srt := NewManager(0) if srt.Len() != 0 { t.Errorf("Length of tracker incorrect.\nexpected: %d\nreceived: %d", diff --git a/fileTransfer/store/cypher/cypher.go b/fileTransfer/store/cypher/cypher.go new file mode 100644 index 0000000000000000000000000000000000000000..c1926c9be6f9c79b9af0b7b05d8f9010d53efdcb --- /dev/null +++ b/fileTransfer/store/cypher/cypher.go @@ -0,0 +1,52 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package cypher + +import ( + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" +) + +// Cypher contains the cryptographic and identifying information needed to +// decrypt a file part and associate it with the correct file. +type Cypher struct { + *Manager + fpNum uint16 +} + +// Encrypt encrypts the file part contents and returns them along with a MAC and +// fingerprint. +func (c Cypher) Encrypt(contents []byte) ( + cipherText, mac []byte, fp format.Fingerprint) { + + // Generate fingerprint + fp = ftCrypto.GenerateFingerprint(*c.key, c.fpNum) + + // Encrypt part and get MAC + cipherText, mac = ftCrypto.EncryptPart(*c.key, contents, c.fpNum, fp) + + return cipherText, mac, fp +} + +// Decrypt decrypts the content of the message. +func (c Cypher) Decrypt(msg format.Message) ([]byte, error) { + filePart, err := ftCrypto.DecryptPart( + *c.key, msg.GetContents(), msg.GetMac(), c.fpNum, msg.GetKeyFP()) + if err != nil { + return nil, err + } + + c.fpVector.Use(uint32(c.fpNum)) + + return filePart, nil +} + +// GetFingerprint generates and returns the fingerprints. +func (c Cypher) GetFingerprint() format.Fingerprint { + return ftCrypto.GenerateFingerprint(*c.key, c.fpNum) +} diff --git a/fileTransfer/store/cypher/cypher_test.go b/fileTransfer/store/cypher/cypher_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b3700090322106e817e6c30c1bcceb69eca1df4e --- /dev/null +++ b/fileTransfer/store/cypher/cypher_test.go @@ -0,0 +1,97 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package cypher + +import ( + "bytes" + "gitlab.com/elixxir/primitives/format" + "testing" +) + +// Tests that contents that are encrypted with Cypher.Encrypt match the +// decrypted contents of Cypher.Decrypt. +func TestCypher_Encrypt_Decrypt(t *testing.T) { + m, _ := newTestManager(16, t) + numPrimeBytes := 512 + + // Create contents of the right size + contents := make([]byte, format.NewMessage(numPrimeBytes).ContentsSize()) + copy(contents, "This is some message contents.") + + c, err := m.PopCypher() + if err != nil { + t.Errorf("Failed to pop cypher: %+v", err) + } + + // Encrypt contents + cipherText, mac, fp := c.Encrypt(contents) + + // Create message to decrypt + msg := format.NewMessage(numPrimeBytes) + msg.SetContents(cipherText) + msg.SetMac(mac) + msg.SetKeyFP(fp) + + // Decrypt message + decryptedContents, err := c.Decrypt(msg) + if err != nil { + t.Errorf("Decrypt returned an error: %+v", err) + } + + // Tests that the decrypted contents match the original + if !bytes.Equal(contents, decryptedContents) { + t.Errorf("Decrypted contents do not match original."+ + "\nexpected: %q\nreceived: %q", contents, decryptedContents) + } +} + +// Tests that Cypher.Decrypt returns an error when the contents are the wrong +// size. +func TestCypher_Decrypt_MacError(t *testing.T) { + m, _ := newTestManager(16, t) + + // Create contents of the wrong size + contents := []byte("This is some message contents.") + + c, err := m.PopCypher() + if err != nil { + t.Errorf("Failed to pop cypher: %+v", err) + } + + // Encrypt contents + cipherText, mac, fp := c.Encrypt(contents) + + // Create message to decrypt + msg := format.NewMessage(512) + msg.SetContents(cipherText) + msg.SetMac(mac) + msg.SetKeyFP(fp) + + // Decrypt message + _, err = c.Decrypt(msg) + if err == nil { + t.Error("Failed to receive an error when the contents are the wrong " + + "length.") + } +} + +// Tests that Cypher.GetFingerprint returns unique fingerprints. +func TestCypher_GetFingerprint(t *testing.T) { + m, _ := newTestManager(16, t) + fpMap := make(map[format.Fingerprint]bool, m.fpVector.GetNumKeys()) + + for c, err := m.PopCypher(); err == nil; c, err = m.PopCypher() { + fp := c.GetFingerprint() + + if fpMap[fp] { + t.Errorf("Fingerprint %s already exists.", fp) + } else { + fpMap[fp] = true + } + } +} diff --git a/fileTransfer/store/cypher/manager.go b/fileTransfer/store/cypher/manager.go new file mode 100644 index 0000000000000000000000000000000000000000..be1662d2da1dd4eb7062e990b31dc341fa3eae92 --- /dev/null +++ b/fileTransfer/store/cypher/manager.go @@ -0,0 +1,179 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package cypher + +import ( + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/netTime" +) + +// Storage keys and versions. +const ( + cypherManagerPrefix = "CypherManagerStore" + cypherManagerFpVectorKey = "CypherManagerFingerprintVector" + cypherManagerKeyStoreKey = "CypherManagerKey" + cypherManagerKeyStoreVersion = 0 +) + +// Error messages. +const ( + // NewManager + errNewFpVector = "failed to create new state vector for fingerprints: %+v" + errSaveKey = "failed to save transfer key: %+v" + + // LoadManager + errLoadKey = "failed to load transfer key: %+v" + errLoadFpVector = "failed to load state vector: %+v" + + // Manager.PopCypher + errGetNextFp = "used all %d fingerprints" + + // Manager.Delete + errDeleteKey = "failed to delete transfer key: %+v" + errDeleteFpVector = "failed to delete fingerprint state vector: %+v" +) + +// Manager the creation +type Manager struct { + // The transfer key is a randomly generated key created by the sender and + // used to generate MACs and fingerprints + key *ftCrypto.TransferKey + + // Stores the state of a fingerprint (used/unused) in a bitstream format + // (has its own storage backend) + fpVector *utility.StateVector + + kv *versioned.KV +} + +// NewManager returns a new cypher Manager initialised with the given number of +// fingerprints. +func NewManager(key *ftCrypto.TransferKey, numFps uint16, kv *versioned.KV) ( + *Manager, error) { + + kv = kv.Prefix(cypherManagerPrefix) + + fpVector, err := utility.NewStateVector( + kv, cypherManagerFpVectorKey, uint32(numFps)) + if err != nil { + return nil, errors.Errorf(errNewFpVector, err) + } + + err = saveKey(key, kv) + if err != nil { + return nil, errors.Errorf(errSaveKey, err) + } + + tfp := &Manager{ + key: key, + fpVector: fpVector, + kv: kv, + } + + return tfp, nil +} + +// PopCypher returns a new Cypher with next available fingerprint number. This +// marks the fingerprint as used. Returns false if no more fingerprints are +// available. +func (m *Manager) PopCypher() (Cypher, error) { + fpNum, err := m.fpVector.Next() + if err != nil { + return Cypher{}, errors.Errorf(errGetNextFp, m.fpVector.GetNumKeys()) + } + + c := Cypher{ + Manager: m, + fpNum: uint16(fpNum), + } + + return c, nil +} + +// GetUnusedCyphers returns a list of cyphers with unused fingerprints numbers. +func (m *Manager) GetUnusedCyphers() []Cypher { + fpNums := m.fpVector.GetUnusedKeyNums() + cypherList := make([]Cypher, len(fpNums)) + + for i, fpNum := range fpNums { + cypherList[i] = Cypher{ + Manager: m, + fpNum: uint16(fpNum), + } + } + + return cypherList +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// LoadManager loads the Manager from storage. +func LoadManager(kv *versioned.KV) (*Manager, error) { + kv = kv.Prefix(cypherManagerPrefix) + key, err := loadKey(kv) + if err != nil { + return nil, errors.Errorf(errLoadKey, err) + } + + fpVector, err := utility.LoadStateVector(kv, cypherManagerFpVectorKey) + if err != nil { + return nil, errors.Errorf(errLoadFpVector, err) + } + + tfp := &Manager{ + key: key, + fpVector: fpVector, + kv: kv, + } + + return tfp, nil +} + +// Delete removes all saved entries from storage. +func (m *Manager) Delete() error { + // Delete transfer key + err := m.kv.Delete(cypherManagerKeyStoreKey, cypherManagerKeyStoreVersion) + if err != nil { + return errors.Errorf(errDeleteKey, err) + } + + // Delete StateVector + err = m.fpVector.Delete() + if err != nil { + return errors.Errorf(errDeleteFpVector, err) + } + + return nil +} + +// saveKey saves the transfer key to storage. +func saveKey(key *ftCrypto.TransferKey, kv *versioned.KV) error { + obj := &versioned.Object{ + Version: cypherManagerKeyStoreVersion, + Timestamp: netTime.Now(), + Data: key.Bytes(), + } + + return kv.Set(cypherManagerKeyStoreKey, cypherManagerKeyStoreVersion, obj) +} + +// loadKey loads the transfer key from storage. +func loadKey(kv *versioned.KV) (*ftCrypto.TransferKey, error) { + obj, err := kv.Get(cypherManagerKeyStoreKey, cypherManagerKeyStoreVersion) + if err != nil { + return nil, err + } + + key := ftCrypto.UnmarshalTransferKey(obj.Data) + return &key, nil +} diff --git a/fileTransfer/store/cypher/manager_test.go b/fileTransfer/store/cypher/manager_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ea86d1baba7eece93b3ed26116efa252e54b1e78 --- /dev/null +++ b/fileTransfer/store/cypher/manager_test.go @@ -0,0 +1,170 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package cypher + +import ( + "fmt" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "reflect" + "testing" +) + +// Tests that NewManager returns a new Manager that matches the expected +// manager. +func TestNewManager(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + numFps := uint16(64) + fpv, _ := utility.NewStateVector(kv.Prefix(cypherManagerPrefix), + cypherManagerFpVectorKey, uint32(numFps)) + expected := &Manager{ + key: &ftCrypto.TransferKey{1, 2, 3}, + fpVector: fpv, + kv: kv.Prefix(cypherManagerPrefix), + } + + manager, err := NewManager(expected.key, numFps, kv) + if err != nil { + t.Errorf("NewManager returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, manager) { + t.Errorf("New manager does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, manager) + } +} + +// Tests that Manager.PopCypher returns cyphers with correct fingerprint numbers +// and that trying to pop after the last pop results in an error. +func TestManager_PopCypher(t *testing.T) { + m, _ := newTestManager(64, t) + + for i := uint16(0); i < uint16(m.fpVector.GetNumKeys()); i++ { + c, err := m.PopCypher() + if err != nil { + t.Errorf("Failed to pop cypher #%d: %+v", i, err) + } + + if c.fpNum != i { + t.Errorf("Fingerprint number does not match expected."+ + "\nexpected: %d\nreceived: %d", i, c.fpNum) + } + + if c.Manager != m { + t.Errorf("Cypher has wrong manager.\nexpected: %v\nreceived: %v", + m, c.Manager) + } + } + + // Test that an error is returned when popping a cypher after all + // fingerprints have been used + expectedErr := fmt.Sprintf(errGetNextFp, m.fpVector.GetNumKeys()) + _, err := m.PopCypher() + if err == nil || (err.Error() != expectedErr) { + t.Errorf("PopCypher did not return the expected error when all "+ + "fingerprints should be used.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Tests Manager.GetUnusedCyphers +func TestManager_GetUnusedCyphers(t *testing.T) { + m, _ := newTestManager(64, t) + + // Use every other key + for i := uint32(0); i < m.fpVector.GetNumKeys(); i += 2 { + m.fpVector.Use(i) + } + + // Check that every other key is in the list + for i, c := range m.GetUnusedCyphers() { + if c.fpNum != uint16(2*i)+1 { + t.Errorf("Fingerprint number #%d incorrect."+ + "\nexpected: %d\nreceived: %d", i, 2*i+1, c.fpNum) + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that a Manager loaded via LoadManager matches the original. +func TestLoadManager(t *testing.T) { + m, kv := newTestManager(64, t) + + // Use every other key + for i := uint32(0); i < m.fpVector.GetNumKeys(); i += 2 { + m.fpVector.Use(i) + } + + newManager, err := LoadManager(kv) + if err != nil { + t.Errorf("Failed to load manager: %+v", err) + } + + if !reflect.DeepEqual(m, newManager) { + t.Errorf("Loaded manager does not match original."+ + "\nexpected: %+v\nreceived: %+v", m, newManager) + } +} + +// Tests that Manager.Delete deletes the storage by trying to load the manager. +func TestManager_Delete(t *testing.T) { + m, _ := newTestManager(64, t) + + err := m.Delete() + if err != nil { + t.Errorf("Failed to delete manager: %+v", err) + } + + _, err = LoadManager(m.kv) + if err == nil { + t.Error("Failed to receive error when loading manager that was deleted.") + } +} + +// Tests that a transfer key saved via saveKey can be loaded via loadKey. +func Test_saveKey_loadKey(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := &ftCrypto.TransferKey{42} + + err := saveKey(key, kv) + if err != nil { + t.Errorf("Error when saving key: %+v", err) + } + + loadedKey, err := loadKey(kv) + if err != nil { + t.Errorf("Error when loading key: %+v", err) + } + + if *key != *loadedKey { + t.Errorf("Loaded key does not match original."+ + "\nexpected: %s\nreceived: %s", key, loadedKey) + } +} + +// newTestManager creates a new Manager for testing. +func newTestManager(numFps uint16, t *testing.T) (*Manager, *versioned.KV) { + key, err := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + if err != nil { + t.Errorf("Failed to generate transfer key: %+v", err) + } + + kv := versioned.NewKV(make(ekv.Memstore)) + m, err := NewManager(&key, numFps, kv) + if err != nil { + t.Errorf("Failed to make new Manager: %+v", err) + } + + return m, kv +} diff --git a/storage/fileTransfer/fileMessage.go b/fileTransfer/store/fileMessage/fileMessage.go similarity index 67% rename from storage/fileTransfer/fileMessage.go rename to fileTransfer/store/fileMessage/fileMessage.go index 90fd6eb30345f2aa5eda31d313058aaadd437ee6..d1d0e107b33d6fe9c08e685461d1269e0bcc7753 100644 --- a/storage/fileTransfer/fileMessage.go +++ b/fileTransfer/store/fileMessage/fileMessage.go @@ -1,37 +1,38 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// -package fileTransfer +package fileMessage import ( "encoding/binary" "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" ) // Size constants. const ( partNumLen = 2 // The length of the part number in bytes - FmMinSize = partNumLen // Minimum size for the PartMessage + fmMinSize = partNumLen // Minimum size for the PartMessage ) // Error messages. const ( - newFmSizeErr = "size of external payload (%d) must be greater than %d" + errNewFmSize = "[FT] Could not create file part message: size of payload (%d) must be greater than %d" unmarshalFmSizeErr = "size of passed in bytes (%d) must be greater than %d" - setFileFmErr = "length of part bytes (%d) must be smaller than maximum payload size %d" + errSetFileFm = "[FT] Could not set file part message payload: length of part bytes (%d) must be smaller than maximum payload size %d" ) /* -+-----------------------------------------+ -| CMIX Message Contents | -+---------+-------------+-----------------+ -| Padding | Part Number | File Data | -| 8 bytes | 2 bytes | remaining space | -+---------+-------------+-----------------+ ++-------------------------------+ +| CMIX Message Contents | ++-------------+-----------------+ +| Part Number | File Data | +| 2 bytes | remaining space | ++-------------+-----------------+ */ // PartMessage contains part of the data being transferred and 256-bit nonce @@ -45,18 +46,17 @@ type PartMessage struct { // NewPartMessage generates a new part message that fits into the specified // external payload size. An error is returned if the external payload size is // too small to fit the part message. -func NewPartMessage(externalPayloadSize int) (PartMessage, error) { - if externalPayloadSize < FmMinSize { - return PartMessage{}, - errors.Errorf(newFmSizeErr, externalPayloadSize, FmMinSize) +func NewPartMessage(externalPayloadSize int) PartMessage { + if externalPayloadSize < fmMinSize { + jww.FATAL.Panicf(errNewFmSize, externalPayloadSize, fmMinSize) } - return MapPartMessage(make([]byte, externalPayloadSize)), nil + return mapPartMessage(make([]byte, externalPayloadSize)) } -// MapPartMessage maps the data to the components of a PartMessage. It is mapped +// mapPartMessage maps the data to the components of a PartMessage. It is mapped // by reference; a copy is not made. -func MapPartMessage(data []byte) PartMessage { +func mapPartMessage(data []byte) PartMessage { return PartMessage{ data: data, partNum: data[:partNumLen], @@ -67,12 +67,12 @@ func MapPartMessage(data []byte) PartMessage { // UnmarshalPartMessage converts the bytes into a PartMessage. An error is // returned if the size of the data is too small for a PartMessage. func UnmarshalPartMessage(b []byte) (PartMessage, error) { - if len(b) < FmMinSize { + if len(b) < fmMinSize { return PartMessage{}, - errors.Errorf(unmarshalFmSizeErr, len(b), FmMinSize) + errors.Errorf(unmarshalFmSizeErr, len(b), fmMinSize) } - return MapPartMessage(b), nil + return mapPartMessage(b), nil } // Marshal returns the byte representation of the PartMessage. @@ -103,14 +103,12 @@ func (m PartMessage) GetPart() []byte { // SetPart sets the PartMessage part to the given bytes. An error is returned if // the size of the provided part data is too large to store. -func (m PartMessage) SetPart(b []byte) error { +func (m PartMessage) SetPart(b []byte) { if len(b) > len(m.part) { - return errors.Errorf(setFileFmErr, len(b), len(m.part)) + jww.FATAL.Panicf(errSetFileFm, len(b), len(m.part)) } copy(m.part, b) - - return nil } // GetPartSize returns the number of bytes available to store part data. diff --git a/storage/fileTransfer/fileMessage_test.go b/fileTransfer/store/fileMessage/fileMessage_test.go similarity index 76% rename from storage/fileTransfer/fileMessage_test.go rename to fileTransfer/store/fileMessage/fileMessage_test.go index ad565931e6d44d60f887552e9cb4c1a954fe5cf0..b7ac5e0e60d50f4f5be80bf2522db1e010e2de4f 100644 --- a/storage/fileTransfer/fileMessage_test.go +++ b/fileTransfer/store/fileMessage/fileMessage_test.go @@ -1,11 +1,11 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// -package fileTransfer +package fileMessage import ( "bytes" @@ -19,10 +19,7 @@ import ( func Test_newPartMessage(t *testing.T) { externalPayloadSize := 256 - fm, err := NewPartMessage(externalPayloadSize) - if err != nil { - t.Errorf("NewPartMessage returned an error: %+v", err) - } + fm := NewPartMessage(externalPayloadSize) if len(fm.data) != externalPayloadSize { t.Errorf("Size of PartMessage data does not match payload size."+ @@ -33,25 +30,28 @@ func Test_newPartMessage(t *testing.T) { // Error path: tests that NewPartMessage returns the expected error when the // external payload size is too small. func Test_newPartMessage_SmallPayloadSizeError(t *testing.T) { - externalPayloadSize := FmMinSize - 1 - expectedErr := fmt.Sprintf(newFmSizeErr, externalPayloadSize, FmMinSize) - - _, err := NewPartMessage(externalPayloadSize) - if err == nil || err.Error() != expectedErr { - t.Errorf("NewPartMessage did not return the expected error when the "+ - "given external payload size is too small."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } + externalPayloadSize := fmMinSize - 1 + expectedErr := fmt.Sprintf(errNewFmSize, externalPayloadSize, fmMinSize) + + defer func() { + if r := recover(); r == nil || r != expectedErr { + t.Errorf("NewPartMessage did not return the expected error when "+ + "the given external payload size is too small."+ + "\nexpected: %s\nreceived: %+v", expectedErr, r) + } + }() + + NewPartMessage(externalPayloadSize) } -// Tests that MapPartMessage maps the data to the correct parts of the +// Tests that mapPartMessage maps the data to the correct parts of the // PartMessage. func Test_mapPartMessage(t *testing.T) { // Generate expected values _, expectedData, expectedPartNum, expectedFile := newRandomFileMessage() - fm := MapPartMessage(expectedData) + fm := mapPartMessage(expectedData) if !bytes.Equal(expectedData, fm.data) { t.Errorf("Incorrect data.\nexpected: %q\nreceived: %q", @@ -101,8 +101,8 @@ func Test_unmarshalPartMessage(t *testing.T) { // Error path: tests that UnmarshalPartMessage returns the expected error when // the provided data is too small to be unmarshalled into a PartMessage. func Test_unmarshalPartMessage_SizeError(t *testing.T) { - data := make([]byte, FmMinSize-1) - expectedErr := fmt.Sprintf(unmarshalFmSizeErr, len(data), FmMinSize) + data := make([]byte, fmMinSize-1) + expectedErr := fmt.Sprintf(unmarshalFmSizeErr, len(data), fmMinSize) _, err := UnmarshalPartMessage(data) if err == nil || err.Error() != expectedErr { @@ -139,10 +139,7 @@ func Test_fileMessage_getPartNum(t *testing.T) { // Tests that PartMessage.SetPartNum sets the correct part number. func Test_fileMessage_setPartNum(t *testing.T) { - fm, err := NewPartMessage(256) - if err != nil { - t.Errorf("Failed to create new PartMessage: %+v", err) - } + fm := NewPartMessage(256) expectedPartNum := make([]byte, partNumLen) rand.New(rand.NewSource(42)).Read(expectedPartNum) @@ -170,20 +167,14 @@ func Test_fileMessage_getFile(t *testing.T) { // Tests that PartMessage.SetPart sets the correct part data. func Test_fileMessage_setFile(t *testing.T) { - fm, err := NewPartMessage(256) - if err != nil { - t.Errorf("Failed to create new PartMessage: %+v", err) - } + fm := NewPartMessage(256) fileData := make([]byte, 64) rand.New(rand.NewSource(42)).Read(fileData) expectedFile := make([]byte, fm.GetPartSize()) copy(expectedFile, fileData) - err = fm.SetPart(expectedFile) - if err != nil { - t.Errorf("SetPart returned an error: %+v", err) - } + fm.SetPart(expectedFile) if !bytes.Equal(expectedFile, fm.GetPart()) { t.Errorf("Failed to set correct part data.\nexpected: %q\nreceived: %q", @@ -194,19 +185,19 @@ func Test_fileMessage_setFile(t *testing.T) { // Error path: tests that PartMessage.SetPart returns the expected error when // the provided part data is too large for the message. func Test_fileMessage_setFile_FileTooLargeError(t *testing.T) { - fm, err := NewPartMessage(FmMinSize + 1) - if err != nil { - t.Errorf("Failed to create new PartMessage: %+v", err) - } + fm := NewPartMessage(fmMinSize + 1) - expectedErr := fmt.Sprintf(setFileFmErr, fm.GetPartSize()+1, fm.GetPartSize()) + expectedErr := fmt.Sprintf(errSetFileFm, fm.GetPartSize()+1, fm.GetPartSize()) - err = fm.SetPart(make([]byte, fm.GetPartSize()+1)) - if err == nil || err.Error() != expectedErr { - t.Errorf("SetPart did not return the expected error when the given "+ - "part data is too large to fit in the PartMessage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } + defer func() { + if r := recover(); r == nil || r != expectedErr { + t.Errorf("SetPart did not return the expected error when the "+ + "given part data is too large to fit in the PartMessage."+ + "\nexpected: %s\nreceived: %+v", expectedErr, r) + } + }() + + fm.SetPart(make([]byte, fm.GetPartSize()+1)) } // Tests that PartMessage.GetPartSize returns the expected available space for @@ -214,10 +205,7 @@ func Test_fileMessage_setFile_FileTooLargeError(t *testing.T) { func Test_fileMessage_getFileSize(t *testing.T) { expectedSize := 256 - fm, err := NewPartMessage(FmMinSize + expectedSize) - if err != nil { - t.Errorf("Failed to create new PartMessage: %+v", err) - } + fm := NewPartMessage(fmMinSize + expectedSize) if expectedSize != fm.GetPartSize() { t.Errorf("File size incorrect.\nexpected: %d\nreceived: %d", @@ -235,7 +223,7 @@ func newRandomFileMessage() (PartMessage, []byte, []byte, []byte) { prng.Read(part) data := append(partNum, part...) - fm := MapPartMessage(data) + fm := mapPartMessage(data) return fm, data, partNum, part } diff --git a/fileTransfer/store/part.go b/fileTransfer/store/part.go new file mode 100644 index 0000000000000000000000000000000000000000..c382699695cda68c24ae417721660baae162757d --- /dev/null +++ b/fileTransfer/store/part.go @@ -0,0 +1,70 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "gitlab.com/elixxir/client/fileTransfer/store/cypher" + "gitlab.com/elixxir/client/fileTransfer/store/fileMessage" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/primitives/id" +) + +// Part contains information about a single file part and its parent transfer. +// Also contains cryptographic information needed to encrypt the part data. +type Part struct { + transfer *SentTransfer + cypherManager *cypher.Manager + partNum uint16 +} + +// GetEncryptedPart gets the specified part, encrypts it, and returns the +// encrypted part along with its MAC and fingerprint. An error is returned if no +// fingerprints are available. +func (p *Part) GetEncryptedPart(contentsSize int) ( + encryptedPart, mac []byte, fp format.Fingerprint, err error) { + // Create new empty file part message of the size provided + partMsg := fileMessage.NewPartMessage(contentsSize) + + // Add part number and part data to part message + partMsg.SetPartNum(p.partNum) + partMsg.SetPart(p.transfer.getPartData(p.partNum)) + + // Get next cypher + c, err := p.cypherManager.PopCypher() + if err != nil { + p.transfer.markTransferFailed() + return nil, nil, format.Fingerprint{}, err + } + + // Encrypt part and get MAC and fingerprint + encryptedPart, mac, fp = c.Encrypt(partMsg.Marshal()) + + return encryptedPart, mac, fp, nil +} + +// MarkArrived marks the part as arrived. This should be called after the round +// the part is sent on succeeds. +func (p *Part) MarkArrived() { + p.transfer.markArrived(p.partNum) +} + +// Recipient returns the recipient of the file transfer. +func (p *Part) Recipient() *id.ID { + return p.transfer.recipient +} + +// TransferID returns the ID of the file transfer. +func (p *Part) TransferID() *ftCrypto.TransferID { + return p.transfer.tid +} + +// FileName returns the name of the file. +func (p *Part) FileName() string { + return p.transfer.FileName() +} diff --git a/fileTransfer/store/part_test.go b/fileTransfer/store/part_test.go new file mode 100644 index 0000000000000000000000000000000000000000..ea71d299e44fbf66f0309e646dbba4f8e14a6596 --- /dev/null +++ b/fileTransfer/store/part_test.go @@ -0,0 +1,119 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "bytes" + "gitlab.com/elixxir/client/fileTransfer/store/fileMessage" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/primitives/format" + "testing" +) + +// Tests that the encrypted part returned by Part.GetEncryptedPart can be +// decrypted, unmarshalled, and that it matches the original. +func TestPart_GetEncryptedPart(t *testing.T) { + st, parts, key, _, _ := newTestSentTransfer(25, t) + partNum := 0 + part := st.GetUnsentParts()[partNum] + + encryptedPart, mac, fp, err := part.GetEncryptedPart( + format.NewMessage(numPrimeBytes).ContentsSize()) + if err != nil { + t.Errorf("GetEncryptedPart returned an error: %+v", err) + } + + decryptedPart, err := ftCrypto.DecryptPart( + *key, encryptedPart, mac, uint16(partNum), fp) + if err != nil { + t.Errorf("Failed to decrypt part: %+v", err) + } + + partMsg, err := fileMessage.UnmarshalPartMessage(decryptedPart) + if err != nil { + t.Errorf("Failed to unmarshal part message: %+v", err) + } + + if !bytes.Equal(parts[partNum], partMsg.GetPart()) { + t.Errorf("Decrypted part does not match original."+ + "\nexpected: %q\nreceived: %q", parts[partNum], partMsg.GetPart()) + } + + if int(partMsg.GetPartNum()) != partNum { + t.Errorf("Decrypted part does not have correct part number."+ + "\nexpected: %d\nreceived: %d", partNum, partMsg.GetPartNum()) + } +} + +// Tests that Part.GetEncryptedPart returns an error when the underlying cypher +// manager runs out of fingerprints. +func TestPart_GetEncryptedPart_OutOfFingerprints(t *testing.T) { + numParts := uint16(25) + st, _, _, numFps, _ := newTestSentTransfer(numParts, t) + part := st.GetUnsentParts()[0] + for i := uint16(0); i < numFps; i++ { + _, _, _, err := part.GetEncryptedPart( + format.NewMessage(numPrimeBytes).ContentsSize()) + if err != nil { + t.Errorf("Getting encrtypted part %d failed: %+v", i, err) + } + } + + _, _, _, err := part.GetEncryptedPart( + format.NewMessage(numPrimeBytes).ContentsSize()) + if err == nil { + t.Errorf("Failed to get an error when run out of fingerprints.") + } +} + +// Tests that Part.MarkArrived correctly marks the part's status in the +// SentTransfer's partStatus vector. +func TestPart_MarkArrived(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(25, t) + partNum := 0 + part := st.GetUnsentParts()[partNum] + + part.MarkArrived() + + if !st.partStatus.Used(uint32(partNum)) { + t.Errorf("Part #%d not marked as arrived.", partNum) + } +} + +// Tests that Part.Recipient returns the correct recipient ID. +func TestPart_Recipient(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(25, t) + part := st.GetUnsentParts()[0] + + if !part.Recipient().Cmp(st.Recipient()) { + t.Errorf("Recipient ID does not match expected."+ + "\nexpected: %s\nreceived: %s", st.Recipient(), part.Recipient()) + } +} + +// Tests that Part.TransferID returns the correct transfer ID. +func TestPart_TransferID(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(25, t) + part := st.GetUnsentParts()[0] + + if part.TransferID() != st.TransferID() { + t.Errorf("Transfer ID does not match expected."+ + "\nexpected: %s\nreceived: %s", st.TransferID(), part.TransferID()) + } +} + +// Tests that Part.FileName returns the correct file name. +func TestPart_FileName(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(25, t) + part := st.GetUnsentParts()[0] + + if part.FileName() != st.FileName() { + t.Errorf("File name does not match expected."+ + "\nexpected: %q\nreceived: %q", st.FileName(), part.FileName()) + } +} diff --git a/fileTransfer/store/received.go b/fileTransfer/store/received.go new file mode 100644 index 0000000000000000000000000000000000000000..0c8c5f9c090474a57fed38dce24a6b8690a9f240 --- /dev/null +++ b/fileTransfer/store/received.go @@ -0,0 +1,174 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "encoding/json" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/netTime" + "sync" +) + +// Storage keys and versions. +const ( + receivedTransfersStorePrefix = "ReceivedFileTransfersPrefix" + receivedTransfersStoreKey = "ReceivedFileTransfers" + receivedTransfersStoreVersion = 0 +) + +// Error messages. +const ( + // NewOrLoadReceived + errLoadReceived = "error loading received transfer list from storage: %+v" + errUnmarshalReceived = "could not unmarshal received transfer list: %+v" + warnLoadReceivedTransfer = "[FT] failed to load received transfer %d of %d with ID %s: %+v" + errLoadAllReceivedTransfer = "failed to load all %d transfers" + + // Received.AddTransfer + errAddExistingReceivedTransfer = "received transfer with ID %s already exists in map." +) + +// Received contains a list of all received transfers. +type Received struct { + transfers map[ftCrypto.TransferID]*ReceivedTransfer + + mux sync.RWMutex + kv *versioned.KV +} + +// NewOrLoadReceived attempts to load a Received from storage. Or if none exist, +// then a new Received is returned. Also returns a list of all transfers that +// have unreceived file parts so their fingerprints can be re-added. +func NewOrLoadReceived(kv *versioned.KV) (*Received, []*ReceivedTransfer, error) { + s := &Received{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + kv: kv.Prefix(receivedTransfersStorePrefix), + } + + obj, err := s.kv.Get(receivedTransfersStoreKey, receivedTransfersStoreVersion) + if err != nil { + if ekv.Exists(err) { + return nil, nil, errors.Errorf(errLoadReceived, err) + } else { + return s, nil, nil + } + } + + tidList, err := unmarshalTransferIdList(obj.Data) + if err != nil { + return nil, nil, errors.Errorf(errUnmarshalReceived, err) + } + + var errCount int + unfinishedTransfer := make([]*ReceivedTransfer, 0, len(tidList)) + for i := range tidList { + tid := tidList[i] + s.transfers[tid], err = loadReceivedTransfer(&tid, s.kv) + if err != nil { + jww.WARN.Print(warnLoadReceivedTransfer, i, len(tidList), tid, err) + errCount++ + } + + if s.transfers[tid].NumReceived() != s.transfers[tid].NumParts() { + unfinishedTransfer = append(unfinishedTransfer, s.transfers[tid]) + } + } + + // Return an error if all transfers failed to load + if errCount == len(tidList) { + return nil, nil, errors.Errorf(errLoadAllReceivedTransfer, len(tidList)) + } + + return s, unfinishedTransfer, nil +} + +// AddTransfer adds the ReceivedTransfer to the map keyed on its transfer ID. +func (r *Received) AddTransfer(key *ftCrypto.TransferKey, + tid *ftCrypto.TransferID, fileName string, transferMAC []byte, numParts, + numFps uint16, fileSize uint32) (*ReceivedTransfer, error) { + + r.mux.Lock() + defer r.mux.Unlock() + + _, exists := r.transfers[*tid] + if exists { + return nil, errors.Errorf(errAddExistingReceivedTransfer, tid) + } + + rt, err := newReceivedTransfer(key, tid, fileName, transferMAC, numParts, + numFps, fileSize, r.kv) + if err != nil { + return nil, err + } + + r.transfers[*tid] = rt + + return rt, r.save() +} + +// GetTransfer returns the ReceivedTransfer with the desiccated transfer ID or +// false if none exists. +func (r *Received) GetTransfer(tid *ftCrypto.TransferID) (*ReceivedTransfer, bool) { + r.mux.RLock() + defer r.mux.RUnlock() + + rt, exists := r.transfers[*tid] + return rt, exists +} + +// RemoveTransfer removes the transfer from the map. If no transfer exists, +// returns nil. Only errors due to saving to storage are returned. +func (r *Received) RemoveTransfer(tid *ftCrypto.TransferID) error { + r.mux.Lock() + defer r.mux.Unlock() + + _, exists := r.transfers[*tid] + if !exists { + return nil + } + + delete(r.transfers, *tid) + return r.save() +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// save stores a list of transfer IDs in the map to storage. +func (r *Received) save() error { + data, err := marshalReceivedTransfersMap(r.transfers) + if err != nil { + return err + } + + obj := &versioned.Object{ + Version: receivedTransfersStoreVersion, + Timestamp: netTime.Now(), + Data: data, + } + + return r.kv.Set(receivedTransfersStoreKey, receivedTransfersStoreVersion, obj) +} + +// marshalReceivedTransfersMap serialises the list of transfer IDs from a +// ReceivedTransfer map. +func marshalReceivedTransfersMap( + transfers map[ftCrypto.TransferID]*ReceivedTransfer) ([]byte, error) { + tidList := make([]ftCrypto.TransferID, 0, len(transfers)) + + for tid := range transfers { + tidList = append(tidList, tid) + } + + return json.Marshal(tidList) +} diff --git a/fileTransfer/store/receivedTransfer.go b/fileTransfer/store/receivedTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..434e32bf1d6bc66e1a6247e91f30db22bd6c2384 --- /dev/null +++ b/fileTransfer/store/receivedTransfer.go @@ -0,0 +1,370 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/fileTransfer/store/cypher" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/netTime" + "strconv" + "sync" +) + +// Storage keys and versions. +const ( + receivedTransferStorePrefix = "ReceivedFileTransferStore/" + receivedTransferStoreKey = "ReceivedTransfer" + receivedTransferStoreVersion = 0 + receivedTransferStatusKey = "ReceivedPartStatusVector" + receivedPartStoreKey = "receivedPart#" + receivedPartStoreVersion = 0 +) + +// Error messages. +const ( + // newReceivedTransfer + errRtNewCypherManager = "failed to create new cypher manager: %+v" + errRtNewPartStatusVectorErr = "failed to create new state vector for part statuses: %+v" + + // ReceivedTransfer.AddPart + errPartOutOfRange = "part number %d out of range of max %d" + errReceivedPartSave = "failed to save part #%d to storage: %+v" + + // loadReceivedTransfer + errRtLoadCypherManager = "failed to load cypher manager from storage: %+v" + errRtLoadFields = "failed to load transfer MAC, number of parts, and file size: %+v" + errRtUnmarshalFields = "failed to unmarshal transfer MAC, number of parts, and file size: %+v" + errRtLoadPartStatusVector = "failed to load state vector for part statuses: %+v" + errRtLoadPart = "[FT] Failed to load part #%d from storage: %+v" + + // ReceivedTransfer.Delete + errRtDeleteCypherManager = "failed to delete cypher manager: %+v" + errRtDeleteSentTransfer = "failed to delete transfer MAC, number of parts, and file size: %+v" + errRtDeletePartStatus = "failed to delete part status state vector: %+v" + + // ReceivedTransfer.save + errMarshalReceivedTransfer = "failed to marshal: %+v" +) + +// ReceivedTransfer contains information and progress data for a receiving or +// received file transfer. +type ReceivedTransfer struct { + // Tracks file part cyphers + cypherManager *cypher.Manager + + // The ID of the transfer + tid *ftCrypto.TransferID + + // User given name to file + fileName string + + // The MAC for the entire file; used to verify the integrity of all parts + transferMAC []byte + + // The number of file parts in the file + numParts uint16 + + // Size of the entire file in bytes + fileSize uint32 + + // Saves each part in order (has its own storage backend) + parts [][]byte + + // Stores the received status for each file part in a bitstream format + partStatus *utility.StateVector + + mux sync.RWMutex + kv *versioned.KV +} + +// newReceivedTransfer generates a ReceivedTransfer with the specified transfer +// key, transfer ID, and a number of parts. +func newReceivedTransfer(key *ftCrypto.TransferKey, tid *ftCrypto.TransferID, + fileName string, transferMAC []byte, numParts, numFps uint16, + fileSize uint32, kv *versioned.KV) (*ReceivedTransfer, error) { + kv = kv.Prefix(makeReceivedTransferPrefix(tid)) + + // Create new cypher manager + cypherManager, err := cypher.NewManager(key, numFps, kv) + if err != nil { + return nil, errors.Errorf(errRtNewCypherManager, err) + } + + // Create new state vector for storing statuses of received parts + partStatus, err := utility.NewStateVector( + kv, receivedTransferStatusKey, uint32(numParts)) + if err != nil { + return nil, errors.Errorf(errRtNewPartStatusVectorErr, err) + } + + rt := &ReceivedTransfer{ + cypherManager: cypherManager, + tid: tid, + fileName: fileName, + transferMAC: transferMAC, + numParts: numParts, + fileSize: fileSize, + parts: make([][]byte, numParts), + partStatus: partStatus, + kv: kv, + } + + return rt, rt.save() +} + +// AddPart adds the file part to the list of file parts at the index of partNum. +func (rt *ReceivedTransfer) AddPart(part []byte, partNum int) error { + rt.mux.Lock() + defer rt.mux.Unlock() + + if partNum > len(rt.parts)-1 { + return errors.Errorf(errPartOutOfRange, partNum, len(rt.parts)-1) + } + + // Save part + rt.parts[partNum] = part + err := savePart(part, partNum, rt.kv) + if err != nil { + return errors.Errorf(errReceivedPartSave, partNum, err) + } + + // Mark part as received + rt.partStatus.Use(uint32(partNum)) + + return nil +} + +// GetFile concatenates all file parts and returns it as a single complete file. +// Note that this function does not care for the completeness of the file and +// returns all parts it has. +func (rt *ReceivedTransfer) GetFile() []byte { + rt.mux.RLock() + defer rt.mux.RUnlock() + + file := bytes.Join(rt.parts, nil) + + // Strip off trailing padding from last part + if len(file) > int(rt.fileSize) { + file = file[:rt.fileSize] + } + + return file +} + +// GetUnusedCyphers returns a list of cyphers with unused fingerprint numbers. +func (rt *ReceivedTransfer) GetUnusedCyphers() []cypher.Cypher { + return rt.cypherManager.GetUnusedCyphers() +} + +// NumParts returns the total number of file parts in the transfer. +func (rt *ReceivedTransfer) NumParts() uint16 { + return rt.numParts +} + +// TransferID returns the transfer's ID. +func (rt *ReceivedTransfer) TransferID() *ftCrypto.TransferID { + return rt.tid +} + +// FileName returns the transfer's file name. +func (rt *ReceivedTransfer) FileName() string { + return rt.fileName +} + +// NumReceived returns the number of parts that have been received. +func (rt *ReceivedTransfer) NumReceived() uint16 { + rt.mux.RLock() + defer rt.mux.RUnlock() + return uint16(rt.partStatus.GetNumUsed()) +} + +// CopyPartStatusVector returns a copy of the part status vector that can be +// used to look up the current status of parts. Note that the statuses are from +// when this function is called and not realtime. +func (rt *ReceivedTransfer) CopyPartStatusVector() *utility.StateVector { + return rt.partStatus.DeepCopy() +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// loadReceivedTransfer loads the ReceivedTransfer with the given transfer ID +// from storage. +func loadReceivedTransfer(tid *ftCrypto.TransferID, kv *versioned.KV) ( + *ReceivedTransfer, error) { + kv = kv.Prefix(makeReceivedTransferPrefix(tid)) + + // Load cypher manager + cypherManager, err := cypher.LoadManager(kv) + if err != nil { + return nil, errors.Errorf(errRtLoadCypherManager, err) + } + + // Load transfer MAC, number of parts, and file size + obj, err := kv.Get(receivedTransferStoreKey, receivedTransferStoreVersion) + if err != nil { + return nil, errors.Errorf(errRtLoadFields, err) + } + + fileName, transferMAC, numParts, fileSize, err := + unmarshalReceivedTransfer(obj.Data) + if err != nil { + return nil, errors.Errorf(errRtUnmarshalFields, err) + } + + // Load StateVector for storing statuses of received parts + partStatus, err := utility.LoadStateVector(kv, receivedTransferStatusKey) + if err != nil { + return nil, errors.Errorf(errRtLoadPartStatusVector, err) + } + + // Load parts from storage + parts := make([][]byte, numParts) + for i := range parts { + if partStatus.Used(uint32(i)) { + parts[i], err = loadPart(i, kv) + if err != nil { + jww.ERROR.Printf(errRtLoadPart, i, err) + } + } + } + + rt := &ReceivedTransfer{ + cypherManager: cypherManager, + tid: tid, + fileName: fileName, + transferMAC: transferMAC, + numParts: numParts, + fileSize: fileSize, + parts: parts, + partStatus: partStatus, + kv: kv, + } + + return rt, nil +} + +// Delete deletes all data in the ReceivedTransfer from storage. +func (rt *ReceivedTransfer) Delete() error { + rt.mux.Lock() + defer rt.mux.Unlock() + + // Delete cypher manager + err := rt.cypherManager.Delete() + if err != nil { + return errors.Errorf(errRtDeleteCypherManager, err) + } + + // Delete transfer MAC, number of parts, and file size + err = rt.kv.Delete(receivedTransferStoreKey, receivedTransferStoreVersion) + if err != nil { + return errors.Errorf(errRtDeleteSentTransfer, err) + } + + // Delete part status state vector + err = rt.partStatus.Delete() + if err != nil { + return errors.Errorf(errRtDeletePartStatus, err) + } + + return nil +} + +// save stores all fields in ReceivedTransfer that do not have their own storage +// (transfer MAC, file size, and number of file parts) to storage. +func (rt *ReceivedTransfer) save() error { + data, err := rt.marshal() + if err != nil { + return errors.Errorf(errMarshalReceivedTransfer, err) + } + + // Create new versioned object for the ReceivedTransfer + vo := &versioned.Object{ + Version: receivedTransferStoreVersion, + Timestamp: netTime.Now(), + Data: data, + } + + // Save versioned object + return rt.kv.Set(receivedTransferStoreKey, receivedTransferStoreVersion, vo) +} + +// receivedTransferDisk structure is used to marshal and unmarshal +// ReceivedTransfer fields to/from storage. +type receivedTransferDisk struct { + FileName string + TransferMAC []byte + NumParts uint16 + FileSize uint32 +} + +// marshal serialises the ReceivedTransfer's fileName, transferMAC, numParts, +// and fileSize. +func (rt *ReceivedTransfer) marshal() ([]byte, error) { + disk := receivedTransferDisk{ + FileName: rt.fileName, + TransferMAC: rt.transferMAC, + NumParts: rt.numParts, + FileSize: rt.fileSize, + } + + return json.Marshal(disk) +} + +// unmarshalReceivedTransfer deserializes the data into the fileName, +// transferMAC, numParts, and fileSize. +func unmarshalReceivedTransfer(data []byte) (fileName string, + transferMAC []byte, numParts uint16, fileSize uint32, err error) { + var disk receivedTransferDisk + err = json.Unmarshal(data, &disk) + if err != nil { + return "", nil, 0, 0, err + } + + return disk.FileName, disk.TransferMAC, disk.NumParts, disk.FileSize, nil +} + +// savePart saves the given part to storage keying on its part number. +func savePart(part []byte, partNum int, kv *versioned.KV) error { + obj := &versioned.Object{ + Version: receivedPartStoreVersion, + Timestamp: netTime.Now(), + Data: part, + } + + return kv.Set(makeReceivedPartKey(partNum), receivedPartStoreVersion, obj) +} + +// loadPart loads the part with the given part number from storage. +func loadPart(partNum int, kv *versioned.KV) ([]byte, error) { + obj, err := kv.Get(makeReceivedPartKey(partNum), receivedPartStoreVersion) + if err != nil { + return nil, err + } + return obj.Data, nil +} + +// makeReceivedTransferPrefix generates the unique prefix used on the key value +// store to store received transfers for the given transfer ID. +func makeReceivedTransferPrefix(tid *ftCrypto.TransferID) string { + return receivedTransferStorePrefix + + base64.StdEncoding.EncodeToString(tid.Bytes()) +} + +// makeReceivedPartKey generates a storage key for the given part number. +func makeReceivedPartKey(partNum int) string { + return receivedPartStoreKey + strconv.Itoa(partNum) +} diff --git a/fileTransfer/store/receivedTransfer_test.go b/fileTransfer/store/receivedTransfer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0f2c16cca43a883a2515e6c77363daded6c1bc3c --- /dev/null +++ b/fileTransfer/store/receivedTransfer_test.go @@ -0,0 +1,447 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/fileTransfer/store/cypher" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "reflect" + "testing" +) + +// Tests that newReceivedTransfer returns a new ReceivedTransfer with the +// expected values. +func Test_newReceivedTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + numFps := uint16(24) + parts := generateTestParts(16) + fileSize := uint32(len(parts) * len(parts[0])) + numParts := uint16(len(parts)) + rtKv := kv.Prefix(makeReceivedTransferPrefix(&tid)) + + cypherManager, err := cypher.NewManager(&key, numFps, rtKv) + if err != nil { + t.Errorf("Failed to make new cypher manager: %+v", err) + } + partStatus, err := utility.NewStateVector( + rtKv, receivedTransferStatusKey, uint32(numParts)) + if err != nil { + t.Errorf("Failed to make new state vector: %+v", err) + } + + expected := &ReceivedTransfer{ + cypherManager: cypherManager, + tid: &tid, + fileName: "fileName", + transferMAC: []byte("transferMAC"), + numParts: numParts, + fileSize: fileSize, + parts: make([][]byte, numParts), + partStatus: partStatus, + kv: rtKv, + } + + rt, err := newReceivedTransfer(&key, &tid, expected.fileName, + expected.transferMAC, numParts, numFps, fileSize, kv) + if err != nil { + t.Errorf("newReceivedTransfer returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, rt) { + t.Errorf("New ReceivedTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, rt) + } +} + +// Tests that ReceivedTransfer.AddPart adds the part to the part list and marks +// it as received +func TestReceivedTransfer_AddPart(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(16, t) + + part := []byte("Part") + partNum := 6 + + err := rt.AddPart(part, partNum) + if err != nil { + t.Errorf("Failed to add part: %+v", err) + } + + if !bytes.Equal(rt.parts[partNum], part) { + t.Errorf("Found incorrect part in list.\nexpected: %q\nreceived: %q", + part, rt.parts[partNum]) + } + + if !rt.partStatus.Used(uint32(partNum)) { + t.Errorf("Part #%d not marked as received.", partNum) + } +} + +// Tests that ReceivedTransfer.AddPart returns an error if the part number is +// not within the range of part numbers +func TestReceivedTransfer_AddPart_PartOutOfRangeError(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(16, t) + + expectedErr := fmt.Sprintf(errPartOutOfRange, rt.partStatus.GetNumKeys(), + rt.partStatus.GetNumKeys()-1) + + err := rt.AddPart([]byte("Part"), int(rt.partStatus.GetNumKeys())) + if err == nil || err.Error() != expectedErr { + t.Errorf("Failed to get expected error when part number is out of range."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that ReceivedTransfer.GetFile returns the expected file after all the +// parts are added to the transfer. +func TestReceivedTransfer_GetFile(t *testing.T) { + // Generate parts and make last file part smaller than the rest + parts := generateTestParts(16) + lastPartLen := 6 + rt, _, _, _ := newTestReceivedTransfer(uint16(len(parts)), t) + rt.fileSize = uint32((len(parts)-1)*len(parts[0]) + lastPartLen) + + for i, p := range parts { + err := rt.AddPart(p, i) + if err != nil { + t.Errorf("Failed to add part #%d: %+v", i, err) + } + } + + parts[len(parts)-1] = parts[len(parts)-1][:lastPartLen] + combinedParts := bytes.Join(parts, nil) + + file := rt.GetFile() + + if !bytes.Equal(file, combinedParts) { + t.Errorf("Received file does not match expected."+ + "\nexpected: %q\nreceived: %q", combinedParts, file) + } + +} + +// Tests that ReceivedTransfer.GetUnusedCyphers returns the correct number of +// unused cyphers. +func TestReceivedTransfer_GetUnusedCyphers(t *testing.T) { + numParts := uint16(10) + rt, _, numFps, _ := newTestReceivedTransfer(numParts, t) + + // Check that all cyphers are returned after initialisation + unsentCyphers := rt.GetUnusedCyphers() + if len(unsentCyphers) != int(numFps) { + t.Errorf("Number of unused cyphers does not match original number of "+ + "fingerprints when none have been used.\nexpected: %d\nreceived: %d", + numFps, len(unsentCyphers)) + } + + // Use every other part + for i := range unsentCyphers { + if i%2 == 0 { + _, _ = unsentCyphers[i].PopCypher() + } + } + + // Check that only have the number of parts is returned + unsentCyphers = rt.GetUnusedCyphers() + if len(unsentCyphers) != int(numFps)/2 { + t.Errorf("Number of unused cyphers is not half original number after "+ + "half have been marked as received.\nexpected: %d\nreceived: %d", + numFps/2, len(unsentCyphers)) + } + + // Use the rest of the parts + for i := range unsentCyphers { + _, _ = unsentCyphers[i].PopCypher() + } + + // Check that no sent parts are returned + unsentCyphers = rt.GetUnusedCyphers() + if len(unsentCyphers) != 0 { + t.Errorf("Number of unused cyphers is not zero after all have been "+ + "marked as received.\nexpected: %d\nreceived: %d", + 0, len(unsentCyphers)) + } +} + +// Tests that ReceivedTransfer.NumParts returns the correct number of parts. +func TestReceivedTransfer_NumParts(t *testing.T) { + numParts := uint16(16) + rt, _, _, _ := newTestReceivedTransfer(numParts, t) + + if rt.NumParts() != numParts { + t.Errorf("Incorrect number of parts.\nexpected: %d\nreceived: %d", + numParts, rt.NumParts()) + } +} + +// Tests that ReceivedTransfer.TransferID returns the correct transfer ID. +func TestReceivedTransfer_TransferID(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(16, t) + + if rt.TransferID() != rt.tid { + t.Errorf("Incorrect transfer ID.\nexpected: %s\nreceived: %s", + rt.tid, rt.TransferID()) + } +} + +// Tests that ReceivedTransfer.FileName returns the correct file name. +func TestReceivedTransfer_FileName(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(16, t) + + if rt.FileName() != rt.fileName { + t.Errorf("Incorrect transfer ID.\nexpected: %s\nreceived: %s", + rt.fileName, rt.FileName()) + } +} + +// Tests that ReceivedTransfer.NumReceived returns the correct number of +// received parts. +func TestReceivedTransfer_NumReceived(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(16, t) + + if rt.NumReceived() != 0 { + t.Errorf("Incorrect number of received parts."+ + "\nexpected: %d\nreceived: %d", 0, rt.NumReceived()) + } + + // Add all parts as received + for i := 0; i < int(rt.numParts); i++ { + _ = rt.AddPart(nil, i) + } + + if uint32(rt.NumReceived()) != rt.partStatus.GetNumKeys() { + t.Errorf("Incorrect number of received parts."+ + "\nexpected: %d\nreceived: %d", + uint32(rt.NumReceived()), rt.partStatus.GetNumKeys()) + } +} + +// Tests that the state vector returned by ReceivedTransfer.CopyPartStatusVector +// has the same values as the original but is a copy. +func TestReceivedTransfer_CopyPartStatusVector(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(64, t) + + // Check that the vectors have the same unused parts + partStatus := rt.CopyPartStatusVector() + if !reflect.DeepEqual( + partStatus.GetUnusedKeyNums(), rt.partStatus.GetUnusedKeyNums()) { + t.Errorf("Copied part status does not match original."+ + "\nexpected: %v\nreceived: %v", + rt.partStatus.GetUnusedKeyNums(), partStatus.GetUnusedKeyNums()) + } + + // Modify the state + _ = rt.AddPart([]byte("hello"), 5) + + // Check that the copied state is different + if reflect.DeepEqual( + partStatus.GetUnusedKeyNums(), rt.partStatus.GetUnusedKeyNums()) { + t.Errorf("Old copied part status matches new status."+ + "\nexpected: %v\nreceived: %v", + rt.partStatus.GetUnusedKeyNums(), partStatus.GetUnusedKeyNums()) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that a ReceivedTransfer loaded via loadReceivedTransfer matches the +// original. +func Test_loadReceivedTransfer(t *testing.T) { + parts := generateTestParts(16) + rt, _, _, kv := newTestReceivedTransfer(uint16(len(parts)), t) + + for i, p := range parts { + if i%2 == 0 { + + err := rt.AddPart(p, i) + if err != nil { + t.Errorf("Failed to add part #%d: %+v", i, err) + } + } + } + + loadedRt, err := loadReceivedTransfer(rt.tid, kv) + if err != nil { + t.Errorf("Failed to load ReceivedTransfer: %+v", err) + } + + if !reflect.DeepEqual(rt, loadedRt) { + t.Errorf("Loaded ReceivedTransfer does not match original."+ + "\nexpected: %+v\nreceived: %+v", rt, loadedRt) + } +} + +// Tests that ReceivedTransfer.Delete deletes the storage backend of the +// ReceivedTransfer and that it cannot be loaded again. +func TestReceivedTransfer_Delete(t *testing.T) { + rt, _, _, kv := newTestReceivedTransfer(64, t) + + err := rt.Delete() + if err != nil { + t.Errorf("Delete returned an error: %+v", err) + } + + _, err = loadSentTransfer(rt.tid, kv) + if err == nil { + t.Errorf("Loaded received transfer that was deleted.") + } +} + +// Tests that the fields saved by ReceivedTransfer.save can be loaded from +// storage. +func TestReceivedTransfer_save(t *testing.T) { + rt, _, _, _ := newTestReceivedTransfer(64, t) + + err := rt.save() + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + _, err = rt.kv.Get(receivedTransferStoreKey, receivedTransferStoreVersion) + if err != nil { + t.Errorf("Failed to load saved ReceivedTransfer: %+v", err) + } +} + +// newTestReceivedTransfer creates a new ReceivedTransfer for testing. +func newTestReceivedTransfer(numParts uint16, t *testing.T) ( + *ReceivedTransfer, *ftCrypto.TransferKey, uint16, *versioned.KV) { + kv := versioned.NewKV(make(ekv.Memstore)) + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + transferMAC := []byte("I am a transfer MAC") + numFps := 2 * numParts + fileName := "helloFile" + parts := generateTestParts(numParts) + fileSize := uint32(len(parts) * len(parts[0])) + + st, err := newReceivedTransfer( + &key, &tid, fileName, transferMAC, numParts, numFps, fileSize, kv) + if err != nil { + t.Errorf("Failed to make new SentTransfer: %+v", err) + } + + return st, &key, numFps, kv +} + +// Tests that a ReceivedTransfer marshalled via ReceivedTransfer.marshal and +// unmarshalled via unmarshalReceivedTransfer matches the original. +func TestReceivedTransfer_marshal_unmarshalReceivedTransfer(t *testing.T) { + rt := &ReceivedTransfer{ + fileName: "transferName", + transferMAC: []byte("I am a transfer MAC"), + numParts: 153, + fileSize: 735, + } + + data, err := rt.marshal() + if err != nil { + t.Errorf("marshal returned an error: %+v", err) + } + + fileName, transferMac, numParts, fileSize, err := + unmarshalReceivedTransfer(data) + if err != nil { + t.Errorf("Failed to unmarshal SentTransfer: %+v", err) + } + + if rt.fileName != fileName { + t.Errorf("Incorrect file name.\nexpected: %q\nreceived: %q", + rt.fileName, fileName) + } + + if !bytes.Equal(rt.transferMAC, transferMac) { + t.Errorf("Incorrect transfer MAC.\nexpected: %s\nreceived: %s", + rt.transferMAC, transferMac) + } + + if rt.numParts != numParts { + t.Errorf("Incorrect number of parts.\nexpected: %d\nreceived: %d", + rt.numParts, numParts) + } + + if rt.fileSize != fileSize { + t.Errorf("Incorrect file size.\nexpected: %d\nreceived: %d", + rt.fileSize, fileSize) + } +} + +// Tests that the part saved to storage via savePart can be loaded. +func Test_savePart_loadPart(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + part := []byte("I am a part.") + partNum := 18 + + err := savePart(part, partNum, kv) + if err != nil { + t.Errorf("Failed to save part: %+v", err) + } + + loadedPart, err := loadPart(partNum, kv) + if err != nil { + t.Errorf("Failed to load part: %+v", err) + } + + if !bytes.Equal(part, loadedPart) { + t.Errorf("Loaded part does not match original."+ + "\nexpected: %q\nreceived: %q", part, loadedPart) + } +} + +// Consistency test of makeReceivedTransferPrefix. +func Test_makeReceivedTransferPrefix_Consistency(t *testing.T) { + expectedPrefixes := []string{ + "ReceivedFileTransferStore/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/AQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/AwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/BAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/BQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/BwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/CAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "ReceivedFileTransferStore/CQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + } + + for i, expected := range expectedPrefixes { + tid := ftCrypto.TransferID{byte(i)} + prefix := makeReceivedTransferPrefix(&tid) + + if expected != prefix { + t.Errorf("Prefix #%d does not match expected."+ + "\nexpected: %q\nreceived: %q", i, expected, prefix) + } + } +} + +// Consistency test of makeReceivedPartKey. +func Test_makeReceivedPartKey_Consistency(t *testing.T) { + expectedKeys := []string{ + "receivedPart#0", "receivedPart#1", "receivedPart#2", "receivedPart#3", + "receivedPart#4", "receivedPart#5", "receivedPart#6", "receivedPart#7", + "receivedPart#8", "receivedPart#9", + } + + for i, expected := range expectedKeys { + key := makeReceivedPartKey(i) + + if expected != key { + t.Errorf("Key #%d does not match expected."+ + "\nexpected: %q\nreceived: %q", i, expected, key) + } + } +} diff --git a/fileTransfer/store/received_test.go b/fileTransfer/store/received_test.go new file mode 100644 index 0000000000000000000000000000000000000000..83134227235484bee889d6c4a1a1a433450caa3a --- /dev/null +++ b/fileTransfer/store/received_test.go @@ -0,0 +1,255 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "reflect" + "sort" + "strconv" + "testing" +) + +// Tests that NewOrLoadReceived returns a new Received when none exist in +// storage and that the list of incomplete transfers is nil. +func TestNewOrLoadReceived_New(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expected := &Received{ + transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), + kv: kv.Prefix(receivedTransfersStorePrefix), + } + + r, incompleteTransfers, err := NewOrLoadReceived(kv) + if err != nil { + t.Errorf("NewOrLoadReceived returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, r) { + t.Errorf("New Received does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, r) + } + + if incompleteTransfers != nil { + t.Errorf("List of incomplete transfers should be nil when not "+ + "loading: %+v", incompleteTransfers) + } +} + +// Tests that NewOrLoadReceived returns a loaded Received when one exist in +// storage and that the list of incomplete transfers is correct. +func TestNewOrLoadReceived_Load(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + r, _, err := NewOrLoadReceived(kv) + if err != nil { + t.Errorf("Failed to make new Received: %+v", err) + } + var expectedIncompleteTransfers []*ReceivedTransfer + + // Create and add transfers to map and save + for i := 0; i < 2; i++ { + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + rt, err2 := r.AddTransfer(&key, &tid, "file"+strconv.Itoa(i), + []byte("transferMAC"+strconv.Itoa(i)), 10, 20, 128) + if err2 != nil { + t.Errorf("Failed to add transfer #%d: %+v", i, err2) + } + expectedIncompleteTransfers = append(expectedIncompleteTransfers, rt) + } + if err = r.save(); err != nil { + t.Errorf("Failed to make save filled Receivced: %+v", err) + } + + // Load Received + loadedReceived, incompleteTransfers, err := NewOrLoadReceived(kv) + if err != nil { + t.Errorf("Failed to load Received: %+v", err) + } + + // Check that the loaded Received matches original + if !reflect.DeepEqual(r, loadedReceived) { + t.Errorf("Loaded Received does not match original."+ + "\nexpected: %#v\nreceived: %#v", r, loadedReceived) + } + + sort.Slice(incompleteTransfers, func(i, j int) bool { + return bytes.Compare(incompleteTransfers[i].TransferID()[:], + incompleteTransfers[j].TransferID()[:]) == -1 + }) + + sort.Slice(expectedIncompleteTransfers, func(i, j int) bool { + return bytes.Compare(expectedIncompleteTransfers[i].TransferID()[:], + expectedIncompleteTransfers[j].TransferID()[:]) == -1 + }) + + // Check that the incomplete transfers matches expected + if !reflect.DeepEqual(expectedIncompleteTransfers, incompleteTransfers) { + t.Errorf("Incorrect incomplete transfers.\nexpected: %v\nreceived: %v", + expectedIncompleteTransfers, incompleteTransfers) + } +} + +// Tests that Received.AddTransfer makes a new transfer and adds it to the list. +func TestReceived_AddTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + r, _, _ := NewOrLoadReceived(kv) + + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + + rt, err := r.AddTransfer( + &key, &tid, "file", []byte("transferMAC"), 10, 20, 128) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Check that the transfer was added + if _, exists := r.transfers[*rt.tid]; !exists { + t.Errorf("No transfer with ID %s exists.", rt.tid) + } +} + +// Tests that Received.AddTransfer returns an error when adding a transfer ID +// that already exists. +func TestReceived_AddTransfer_TransferAlreadyExists(t *testing.T) { + tid := ftCrypto.TransferID{0} + r := &Received{ + transfers: map[ftCrypto.TransferID]*ReceivedTransfer{tid: nil}, + } + + expectedErr := fmt.Sprintf(errAddExistingReceivedTransfer, tid) + _, err := r.AddTransfer(nil, &tid, "", nil, 0, 0, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("Received unexpected error when adding transfer that already "+ + "exists.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that Received.GetTransfer returns the expected transfer. +func TestReceived_GetTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + r, _, _ := NewOrLoadReceived(kv) + + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + + rt, err := r.AddTransfer( + &key, &tid, "file", []byte("transferMAC"), 10, 20, 128) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Check that the transfer was added + receivedRt, exists := r.GetTransfer(rt.tid) + if !exists { + t.Errorf("No transfer with ID %s exists.", rt.tid) + } + + if !reflect.DeepEqual(rt, receivedRt) { + t.Errorf("Received ReceivedTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", rt, receivedRt) + } +} + +// Tests that Sent.RemoveTransfer removes the transfer from the list. +func TestReceived_RemoveTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + r, _, _ := NewOrLoadReceived(kv) + + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + + rt, err := r.AddTransfer( + &key, &tid, "file", []byte("transferMAC"), 10, 20, 128) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Delete the transfer + err = r.RemoveTransfer(rt.tid) + if err != nil { + t.Errorf("RemoveTransfer returned an error: %+v", err) + } + + // Check that the transfer was deleted + _, exists := r.GetTransfer(rt.tid) + if exists { + t.Errorf("Transfer %s exists.", rt.tid) + } + + // Remove transfer that was already removed + err = r.RemoveTransfer(rt.tid) + if err != nil { + t.Errorf("RemoveTransfer returned an error: %+v", err) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that Received.save saves the transfer ID list to storage by trying to +// load it after a save. +func TestReceived_save(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + r, _, _ := NewOrLoadReceived(kv) + r.transfers = map[ftCrypto.TransferID]*ReceivedTransfer{ + ftCrypto.TransferID{0}: nil, ftCrypto.TransferID{1}: nil, + ftCrypto.TransferID{2}: nil, ftCrypto.TransferID{3}: nil, + } + + err := r.save() + if err != nil { + t.Errorf("Failed to save transfer ID list: %+v", err) + } + + _, err = r.kv.Get(receivedTransfersStoreKey, receivedTransfersStoreVersion) + if err != nil { + t.Errorf("Failed to load transfer ID list: %+v", err) + } +} + +// Tests that the transfer IDs keys in the map marshalled by +// marshalReceivedTransfersMap and unmarshalled by unmarshalTransferIdList match +// the original. +func Test_marshalReceivedTransfersMap_unmarshalTransferIdList(t *testing.T) { + // Build map of transfer IDs + transfers := make(map[ftCrypto.TransferID]*ReceivedTransfer, 10) + for i := 0; i < 10; i++ { + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + transfers[tid] = nil + } + + data, err := marshalReceivedTransfersMap(transfers) + if err != nil { + t.Errorf("marshalReceivedTransfersMap returned an error: %+v", err) + } + + tidList, err := unmarshalTransferIdList(data) + if err != nil { + t.Errorf("unmarshalSentTransfer returned an error: %+v", err) + } + + for _, tid := range tidList { + if _, exists := transfers[tid]; exists { + delete(transfers, tid) + } else { + t.Errorf("Transfer %s does not exist in list.", tid) + } + } + + if len(transfers) != 0 { + t.Errorf("%d transfers not in unmarshalled list: %v", + len(transfers), transfers) + } +} diff --git a/fileTransfer/store/sent.go b/fileTransfer/store/sent.go new file mode 100644 index 0000000000000000000000000000000000000000..9ca7f482e53540adff504c932ad39219225e39fa --- /dev/null +++ b/fileTransfer/store/sent.go @@ -0,0 +1,192 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "encoding/json" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "sync" +) + +// Storage keys and versions. +const ( + sentTransfersStorePrefix = "SentFileTransfersPrefix" + sentTransfersStoreKey = "SentFileTransfers" + sentTransfersStoreVersion = 0 +) + +// Error messages. +const ( + // NewOrLoadSent + errLoadSent = "error loading sent transfer list from storage: %+v" + errUnmarshalSent = "could not unmarshal sent transfer list: %+v" + warnLoadSentTransfer = "[FT] Failed to load sent transfer %d of %d with ID %s: %+v" + errLoadAllSentTransfer = "failed to load all %d transfers" + + // Sent.AddTransfer + errAddExistingSentTransfer = "sent transfer with ID %s already exists in map." + errNewSentTransfer = "failed to make new sent transfer: %+v" +) + +// Sent contains a list of all sent transfers. +type Sent struct { + transfers map[ftCrypto.TransferID]*SentTransfer + + mux sync.RWMutex + kv *versioned.KV +} + +// NewOrLoadSent attempts to load Sent from storage. Or if none exist, then a +// new Sent is returned. If running transfers were loaded from storage, a list +// of unsent parts is returned. +func NewOrLoadSent(kv *versioned.KV) (*Sent, []Part, error) { + s := &Sent{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentTransfersStorePrefix), + } + + obj, err := s.kv.Get(sentTransfersStoreKey, sentTransfersStoreVersion) + if err != nil { + if !ekv.Exists(err) { + // Return the new Sent if none exists in storage + return s, nil, nil + } else { + // Return other errors + return nil, nil, errors.Errorf(errLoadSent, err) + } + } + + // Load list of saved sent transfers from storage + tidList, err := unmarshalTransferIdList(obj.Data) + if err != nil { + return nil, nil, errors.Errorf(errUnmarshalSent, err) + } + + // Load sent transfers from storage + var errCount int + var unsentParts []Part + for i := range tidList { + tid := tidList[i] + s.transfers[tid], err = loadSentTransfer(&tid, s.kv) + if err != nil { + jww.WARN.Printf(warnLoadSentTransfer, i, len(tidList), tid, err) + errCount++ + continue + } + + if s.transfers[tid].Status() == Running { + unsentParts = + append(unsentParts, s.transfers[tid].GetUnsentParts()...) + } + } + + // Return an error if all transfers failed to load + if errCount == len(tidList) { + return nil, nil, errors.Errorf(errLoadAllSentTransfer, len(tidList)) + } + + return s, unsentParts, nil +} + +// AddTransfer creates a SentTransfer and adds it to the map keyed on its +// transfer ID. +func (s *Sent) AddTransfer(recipient *id.ID, key *ftCrypto.TransferKey, + tid *ftCrypto.TransferID, fileName string, parts [][]byte, numFps uint16) ( + *SentTransfer, error) { + s.mux.Lock() + defer s.mux.Unlock() + + _, exists := s.transfers[*tid] + if exists { + return nil, errors.Errorf(errAddExistingSentTransfer, tid) + } + + st, err := newSentTransfer(recipient, key, tid, fileName, parts, numFps, s.kv) + if err != nil { + return nil, errors.Errorf(errNewSentTransfer, tid) + } + + s.transfers[*tid] = st + + return st, s.save() +} + +// GetTransfer returns the SentTransfer with the desiccated transfer ID or false +// if none exists. +func (s *Sent) GetTransfer(tid *ftCrypto.TransferID) (*SentTransfer, bool) { + s.mux.RLock() + defer s.mux.RUnlock() + + st, exists := s.transfers[*tid] + return st, exists +} + +// RemoveTransfer removes the transfer from the map. If no transfer exists, +// returns nil. Only errors due to saving to storage are returned. +func (s *Sent) RemoveTransfer(tid *ftCrypto.TransferID) error { + s.mux.Lock() + defer s.mux.Unlock() + + _, exists := s.transfers[*tid] + if !exists { + return nil + } + + delete(s.transfers, *tid) + return s.save() +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// save stores a list of transfer IDs in the map to storage. +func (s *Sent) save() error { + data, err := marshalSentTransfersMap(s.transfers) + if err != nil { + return err + } + + obj := &versioned.Object{ + Version: sentTransfersStoreVersion, + Timestamp: netTime.Now(), + Data: data, + } + + return s.kv.Set(sentTransfersStoreKey, sentTransfersStoreVersion, obj) +} + +// marshalSentTransfersMap serialises the list of transfer IDs from a +// SentTransfer map. +func marshalSentTransfersMap(transfers map[ftCrypto.TransferID]*SentTransfer) ( + []byte, error) { + tidList := make([]ftCrypto.TransferID, 0, len(transfers)) + + for tid := range transfers { + tidList = append(tidList, tid) + } + + return json.Marshal(tidList) +} + +// unmarshalTransferIdList deserializes the data into a list of transfer IDs. +func unmarshalTransferIdList(data []byte) ([]ftCrypto.TransferID, error) { + var tidList []ftCrypto.TransferID + err := json.Unmarshal(data, &tidList) + if err != nil { + return nil, err + } + + return tidList, nil +} diff --git a/fileTransfer/store/sentTransfer.go b/fileTransfer/store/sentTransfer.go new file mode 100644 index 0000000000000000000000000000000000000000..dbd4033a85a0b6315e9d052d230557a822612a33 --- /dev/null +++ b/fileTransfer/store/sentTransfer.go @@ -0,0 +1,339 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "encoding/base64" + "encoding/json" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/fileTransfer/store/cypher" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/xx_network/primitives/id" + "gitlab.com/xx_network/primitives/netTime" + "sync" +) + +// Storage keys and versions. +const ( + sentTransferStorePrefix = "SentFileTransferStore/" + sentTransferStoreKey = "SentTransfer" + sentTransferStoreVersion = 0 + sentTransferStatusKey = "SentPartStatusVector" +) + +// Error messages. +const ( + // newSentTransfer + errStNewCypherManager = "failed to create new cypher manager: %+v" + errStNewPartStatusVector = "failed to create new state vector for part statuses: %+v" + + // SentTransfer.getPartData + errNoPartNum = "no part with part number %d exists in transfer %s (%q)" + + // loadSentTransfer + errStLoadCypherManager = "failed to load cypher manager from storage: %+v" + errStLoadFields = "failed to load recipient, status, and parts list: %+v" + errStUnmarshalFields = "failed to unmarshal recipient, status, and parts list: %+v" + errStLoadPartStatusVector = "failed to load state vector for part statuses: %+v" + + // SentTransfer.Delete + errStDeleteCypherManager = "failed to delete cypherManager: %+v" + errStDeleteSentTransfer = "failed to delete recipient ID, status, and file parts: %+v" + errStDeletePartStatus = "failed to delete part status multi state vector: %+v" + + // SentTransfer.save + errMarshalSentTransfer = "failed to marshal: %+v" +) + +// SentTransfer contains information and progress data for sending or sent file +// transfer. +type SentTransfer struct { + // Tracks cyphers for each part + cypherManager *cypher.Manager + + // The ID of the transfer + tid *ftCrypto.TransferID + + // User given name to file + fileName string + + // ID of the recipient of the file transfer + recipient *id.ID + + // The number of file parts in the file + numParts uint16 + + // Indicates the status of the transfer + status TransferStatus + + // List of all file parts in order to send + parts [][]byte + + // Stores the status of each part in a bitstream format + partStatus *utility.StateVector + + mux sync.RWMutex + kv *versioned.KV +} + +// newSentTransfer generates a new SentTransfer with the specified transfer key, +// transfer ID, and parts. +func newSentTransfer(recipient *id.ID, key *ftCrypto.TransferKey, + tid *ftCrypto.TransferID, fileName string, parts [][]byte, numFps uint16, + kv *versioned.KV) (*SentTransfer, error) { + kv = kv.Prefix(makeSentTransferPrefix(tid)) + + // Create new cypher manager + cypherManager, err := cypher.NewManager(key, numFps, kv) + if err != nil { + return nil, errors.Errorf(errStNewCypherManager, err) + } + + // Create new state vector for storing statuses of arrived parts + partStatus, err := utility.NewStateVector( + kv, sentTransferStatusKey, uint32(len(parts))) + if err != nil { + return nil, errors.Errorf(errStNewPartStatusVector, err) + } + + st := &SentTransfer{ + cypherManager: cypherManager, + tid: tid, + fileName: fileName, + recipient: recipient, + numParts: uint16(len(parts)), + status: Running, + parts: parts, + partStatus: partStatus, + kv: kv, + } + + return st, st.save() +} + +// GetUnsentParts builds a list of all unsent parts, each in a Part object. +func (st *SentTransfer) GetUnsentParts() []Part { + unusedPartNumbers := st.partStatus.GetUnusedKeyNums() + partList := make([]Part, len(unusedPartNumbers)) + + for i, partNum := range unusedPartNumbers { + partList[i] = Part{ + transfer: st, + cypherManager: st.cypherManager, + partNum: uint16(partNum), + } + } + + return partList +} + +// getPartData returns the part data from the given part number. +func (st *SentTransfer) getPartData(partNum uint16) []byte { + if int(partNum) > len(st.parts)-1 { + jww.FATAL.Panicf(errNoPartNum, partNum, st.tid, st.fileName) + } + + return st.parts[partNum] +} + +// markArrived marks the status of the given part numbers as arrived. When the +// last part is marked arrived, the transfer is marked as completed. +func (st *SentTransfer) markArrived(partNum uint16) { + st.mux.Lock() + defer st.mux.Unlock() + + st.partStatus.Use(uint32(partNum)) + + // Mark transfer completed if all parts arrived + if st.partStatus.GetNumUsed() == uint32(st.numParts) { + st.status = Completed + } +} + +// markTransferFailed sets the transfer as failed. Only call this if no more +// retries are available. +func (st *SentTransfer) markTransferFailed() { + st.mux.Lock() + defer st.mux.Unlock() + st.status = Failed +} + +// Status returns the status of the transfer. +func (st *SentTransfer) Status() TransferStatus { + st.mux.RLock() + defer st.mux.RUnlock() + return st.status +} + +// NumParts returns the total number of file parts in the transfer. +func (st *SentTransfer) NumParts() uint16 { + return st.numParts +} + +// TransferID returns the transfer's ID. +func (st *SentTransfer) TransferID() *ftCrypto.TransferID { + return st.tid +} + +// FileName returns the transfer's file name. +func (st *SentTransfer) FileName() string { + return st.fileName +} + +// Recipient returns the transfer's recipient ID. +func (st *SentTransfer) Recipient() *id.ID { + return st.recipient +} + +// NumArrived returns the number of parts that have arrived. +func (st *SentTransfer) NumArrived() uint16 { + return uint16(st.partStatus.GetNumUsed()) +} + +// CopyPartStatusVector returns a copy of the part status vector that can be +// used to look up the current status of parts. Note that the statuses are from +// when this function is called and not realtime. +func (st *SentTransfer) CopyPartStatusVector() *utility.StateVector { + return st.partStatus.DeepCopy() +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// loadSentTransfer loads the SentTransfer with the given transfer ID from +// storage. +func loadSentTransfer(tid *ftCrypto.TransferID, kv *versioned.KV) ( + *SentTransfer, error) { + kv = kv.Prefix(makeSentTransferPrefix(tid)) + + // Load cypher manager + cypherManager, err := cypher.LoadManager(kv) + if err != nil { + return nil, errors.Errorf(errStLoadCypherManager, err) + } + + // Load fileName, recipient ID, status, and file parts + obj, err := kv.Get(sentTransferStoreKey, sentTransferStoreVersion) + if err != nil { + return nil, errors.Errorf(errStLoadFields, err) + } + + fileName, recipient, status, parts, err := unmarshalSentTransfer(obj.Data) + if err != nil { + return nil, errors.Errorf(errStUnmarshalFields, err) + } + + // Load state vector for storing statuses of arrived parts + partStatus, err := utility.LoadStateVector(kv, sentTransferStatusKey) + if err != nil { + return nil, errors.Errorf(errStLoadPartStatusVector, err) + } + + st := &SentTransfer{ + cypherManager: cypherManager, + tid: tid, + fileName: fileName, + recipient: recipient, + numParts: uint16(len(parts)), + status: status, + parts: parts, + partStatus: partStatus, + kv: kv, + } + + return st, nil +} + +// Delete deletes all data in the SentTransfer from storage. +func (st *SentTransfer) Delete() error { + st.mux.Lock() + defer st.mux.Unlock() + + // Delete cypher manager + err := st.cypherManager.Delete() + if err != nil { + return errors.Errorf(errStDeleteCypherManager, err) + } + + // Delete recipient ID, status, and file parts + err = st.kv.Delete(sentTransferStoreKey, sentTransferStoreVersion) + if err != nil { + return errors.Errorf(errStDeleteSentTransfer, err) + } + + // Delete part status multi state vector + err = st.partStatus.Delete() + if err != nil { + return errors.Errorf(errStDeletePartStatus, err) + } + + return nil +} + +// save stores all fields in SentTransfer that do not have their own storage +// (recipient ID, status, and file parts) to storage. +func (st *SentTransfer) save() error { + data, err := st.marshal() + if err != nil { + return errors.Errorf(errMarshalSentTransfer, err) + } + + obj := &versioned.Object{ + Version: sentTransferStoreVersion, + Timestamp: netTime.Now(), + Data: data, + } + + return st.kv.Set(sentTransferStoreKey, sentTransferStoreVersion, obj) +} + +// sentTransferDisk structure is used to marshal and unmarshal SentTransfer +// fields to/from storage. +type sentTransferDisk struct { + FileName string + Recipient *id.ID + Status TransferStatus + Parts [][]byte +} + +// marshal serialises the SentTransfer's fileName, recipient, status, and parts +// list. +func (st *SentTransfer) marshal() ([]byte, error) { + disk := sentTransferDisk{ + FileName: st.fileName, + Recipient: st.recipient, + Status: st.status, + Parts: st.parts, + } + + return json.Marshal(disk) +} + +// unmarshalSentTransfer deserializes the data into a fileName, recipient, +// status, and parts list. +func unmarshalSentTransfer(data []byte) (fileName string, recipient *id.ID, + status TransferStatus, parts [][]byte, err error) { + var disk sentTransferDisk + err = json.Unmarshal(data, &disk) + if err != nil { + return "", nil, 0, nil, err + } + + return disk.FileName, disk.Recipient, disk.Status, disk.Parts, nil +} + +// makeSentTransferPrefix generates the unique prefix used on the key value +// store to store sent transfers for the given transfer ID. +func makeSentTransferPrefix(tid *ftCrypto.TransferID) string { + return sentTransferStorePrefix + + base64.StdEncoding.EncodeToString(tid.Bytes()) +} diff --git a/fileTransfer/store/sentTransfer_test.go b/fileTransfer/store/sentTransfer_test.go new file mode 100644 index 0000000000000000000000000000000000000000..e2999646b85cc334642da0a7289ed699c5a7eb5b --- /dev/null +++ b/fileTransfer/store/sentTransfer_test.go @@ -0,0 +1,473 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/fileTransfer/store/cypher" + "gitlab.com/elixxir/client/fileTransfer/store/fileMessage" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/elixxir/primitives/format" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "reflect" + "strconv" + "testing" +) + +// Tests that newSentTransfer returns a new SentTransfer with the expected +// values. +func Test_newSentTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + numFps := uint16(24) + parts := [][]byte{[]byte("hello"), []byte("hello"), []byte("hello")} + stKv := kv.Prefix(makeSentTransferPrefix(&tid)) + + cypherManager, err := cypher.NewManager(&key, numFps, stKv) + if err != nil { + t.Errorf("Failed to make new cypher manager: %+v", err) + } + partStatus, err := utility.NewStateVector( + stKv, sentTransferStatusKey, uint32(len(parts))) + if err != nil { + t.Errorf("Failed to make new state vector: %+v", err) + } + + expected := &SentTransfer{ + cypherManager: cypherManager, + tid: &tid, + fileName: "file", + recipient: id.NewIdFromString("user", id.User, t), + numParts: uint16(len(parts)), + status: Running, + parts: parts, + partStatus: partStatus, + kv: stKv, + } + + st, err := newSentTransfer( + expected.recipient, &key, &tid, expected.fileName, parts, numFps, kv) + if err != nil { + t.Errorf("newSentTransfer returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, st) { + t.Errorf("New SentTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, st) + } +} + +// Tests that SentTransfer.GetUnsentParts returns the correct number of unsent +// parts +func TestSentTransfer_GetUnsentParts(t *testing.T) { + numParts := uint16(10) + st, _, _, _, _ := newTestSentTransfer(numParts, t) + + // Check that all parts are returned after initialisation + unsentParts := st.GetUnsentParts() + if len(unsentParts) != int(numParts) { + t.Errorf("Number of unsent parts does not match original number of "+ + "parts when none have been sent.\nexpected: %d\nreceived: %d", + numParts, len(unsentParts)) + } + + // Ensure all parts have the proper part number + for i, p := range unsentParts { + if int(p.partNum) != i { + t.Errorf("Part has incorrect part number."+ + "\nexpected: %d\nreceived: %d", i, p.partNum) + } + } + + // Use every other part + for i := range unsentParts { + if i%2 == 0 { + unsentParts[i].MarkArrived() + } + } + + // Check that only have the number of parts is returned + unsentParts = st.GetUnsentParts() + if len(unsentParts) != int(numParts)/2 { + t.Errorf("Number of unsent parts is not half original number after "+ + "half have been marked as arrived.\nexpected: %d\nreceived: %d", + numParts/2, len(unsentParts)) + } + + // Ensure all parts have the proper part number + for i, p := range unsentParts { + if int(p.partNum) != i*2+1 { + t.Errorf("Part has incorrect part number."+ + "\nexpected: %d\nreceived: %d", i*2+1, p.partNum) + } + } + + // Use the rest of the parts + for i := range unsentParts { + unsentParts[i].MarkArrived() + } + + // Check that no sent parts are returned + unsentParts = st.GetUnsentParts() + if len(unsentParts) != 0 { + t.Errorf("Number of unsent parts is not zero after all have been "+ + "marked as arrived.\nexpected: %d\nreceived: %d", + 0, len(unsentParts)) + } +} + +// Tests that SentTransfer.getPartData returns all the correct parts at their +// expected indexes. +func TestSentTransfer_getPartData(t *testing.T) { + st, parts, _, _, _ := newTestSentTransfer(16, t) + + for i, part := range parts { + partData := st.getPartData(uint16(i)) + + if !bytes.Equal(part, partData) { + t.Errorf("Incorrect part #%d.\nexpected: %q\nreceived: %q", + i, part, partData) + } + } +} + +// Tests that SentTransfer.getPartData panics when the part number is not within +// the range of part numbers. +func TestSentTransfer_getPartData_OutOfRangePanic(t *testing.T) { + st, parts, _, _, _ := newTestSentTransfer(16, t) + + invalidPartNum := uint16(len(parts) + 1) + expectedErr := fmt.Sprintf(errNoPartNum, invalidPartNum, st.tid, st.fileName) + + defer func() { + r := recover() + if r == nil || r != expectedErr { + t.Errorf("getPartData did not return the expected error when the "+ + "part number %d is out of range.\nexpected: %s\nreceived: %+v", + invalidPartNum, expectedErr, r) + } + }() + + _ = st.getPartData(invalidPartNum) +} + +// Tests that after setting all parts as arrived via SentTransfer.markArrived, +// there are no unsent parts left and the transfer is marked as Completed. +func TestSentTransfer_markArrived(t *testing.T) { + st, parts, _, _, _ := newTestSentTransfer(16, t) + + // Mark all parts as arrived + for i := range parts { + st.markArrived(uint16(i)) + } + + // Check that all parts are marked as arrived + unsentParts := st.GetUnsentParts() + if len(unsentParts) != 0 { + t.Errorf("There are %d unsent parts.", len(unsentParts)) + } + + if st.status != Completed { + t.Errorf("Status not correctly marked.\nexpected: %s\nreceived: %s", + Completed, st.status) + } +} + +// Tests that SentTransfer.markTransferFailed changes the status of the transfer +// to Failed. +func TestSentTransfer_markTransferFailed(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(16, t) + + st.markTransferFailed() + + if st.status != Failed { + t.Errorf("Status not correctly marked.\nexpected: %s\nreceived: %s", + Failed, st.status) + } +} + +// Tests that SentTransfer.Status returns the correct status of the transfer. +func TestSentTransfer_Status(t *testing.T) { + st, parts, _, _, _ := newTestSentTransfer(16, t) + + // Check that it is Running + if st.Status() != Running { + t.Errorf("Status returned incorrect status.\nexpected: %s\nreceived: %s", + Running, st.Status()) + } + + // Mark all parts as arrived + for i := range parts { + st.markArrived(uint16(i)) + } + + // Check that it is Completed + if st.Status() != Completed { + t.Errorf("Status returned incorrect status.\nexpected: %s\nreceived: %s", + Completed, st.Status()) + } + + // Mark transfer failed + st.markTransferFailed() + + // Check that it is Failed + if st.Status() != Failed { + t.Errorf("Status returned incorrect status.\nexpected: %s\nreceived: %s", + Failed, st.Status()) + } +} + +// Tests that SentTransfer.NumParts returns the correct number of parts. +func TestSentTransfer_NumParts(t *testing.T) { + numParts := uint16(16) + st, _, _, _, _ := newTestSentTransfer(numParts, t) + + if st.NumParts() != numParts { + t.Errorf("Incorrect number of parts.\nexpected: %d\nreceived: %d", + numParts, st.NumParts()) + } +} + +// Tests that SentTransfer.TransferID returns the correct transfer ID. +func TestSentTransfer_TransferID(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(16, t) + + if st.TransferID() != st.tid { + t.Errorf("Incorrect transfer ID.\nexpected: %s\nreceived: %s", + st.tid, st.TransferID()) + } +} + +// Tests that SentTransfer.FileName returns the correct file name. +func TestSentTransfer_FileName(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(16, t) + + if st.FileName() != st.fileName { + t.Errorf("Incorrect transfer ID.\nexpected: %s\nreceived: %s", + st.fileName, st.FileName()) + } +} + +// Tests that SentTransfer.Recipient returns the correct recipient ID. +func TestSentTransfer_Recipient(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(16, t) + + if !st.Recipient().Cmp(st.recipient) { + t.Errorf("Incorrect recipient ID.\nexpected: %s\nreceived: %s", + st.recipient, st.Recipient()) + } +} + +// Tests that SentTransfer.NumArrived returns the correct number of arrived +// parts. +func TestSentTransfer_NumArrived(t *testing.T) { + st, parts, _, _, _ := newTestSentTransfer(16, t) + + if st.NumArrived() != 0 { + t.Errorf("Incorrect number of arrived parts."+ + "\nexpected: %d\nreceived: %d", 0, st.NumArrived()) + } + + // Mark all parts as arrived + for i := range parts { + st.markArrived(uint16(i)) + } + + if uint32(st.NumArrived()) != st.partStatus.GetNumKeys() { + t.Errorf("Incorrect number of arrived parts."+ + "\nexpected: %d\nreceived: %d", + uint32(st.NumArrived()), st.partStatus.GetNumKeys()) + } +} + +// Tests that the state vector returned by SentTransfer.CopyPartStatusVector +// has the same values as the original but is a copy. +func TestSentTransfer_CopyPartStatusVector(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(16, t) + + // Check that the vectors have the same unused parts + partStatus := st.CopyPartStatusVector() + if !reflect.DeepEqual( + partStatus.GetUnusedKeyNums(), st.partStatus.GetUnusedKeyNums()) { + t.Errorf("Copied part status does not match original."+ + "\nexpected: %v\nreceived: %v", + st.partStatus.GetUnusedKeyNums(), partStatus.GetUnusedKeyNums()) + } + + // Modify the state + st.markArrived(5) + + // Check that the copied state is different + if reflect.DeepEqual( + partStatus.GetUnusedKeyNums(), st.partStatus.GetUnusedKeyNums()) { + t.Errorf("Old copied part status matches new status."+ + "\nexpected: %v\nreceived: %v", + st.partStatus.GetUnusedKeyNums(), partStatus.GetUnusedKeyNums()) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that a SentTransfer loaded via loadSentTransfer matches the original. +func Test_loadSentTransfer(t *testing.T) { + st, _, _, _, kv := newTestSentTransfer(64, t) + + loadedSt, err := loadSentTransfer(st.tid, kv) + if err != nil { + t.Errorf("Failed to load SentTransfer: %+v", err) + } + + if !reflect.DeepEqual(st, loadedSt) { + t.Errorf("Loaded SentTransfer does not match original."+ + "\nexpected: %+v\nreceived: %+v", st, loadedSt) + } +} + +// Tests that SentTransfer.Delete deletes the storage backend of the +// SentTransfer and that it cannot be loaded again. +func TestSentTransfer_Delete(t *testing.T) { + st, _, _, _, kv := newTestSentTransfer(64, t) + + err := st.Delete() + if err != nil { + t.Errorf("Delete returned an error: %+v", err) + } + + _, err = loadSentTransfer(st.tid, kv) + if err == nil { + t.Errorf("Loaded sent transfer that was deleted.") + } +} + +// Tests that the fields saved by SentTransfer.save can be loaded from storage. +func TestSentTransfer_save(t *testing.T) { + st, _, _, _, _ := newTestSentTransfer(64, t) + + err := st.save() + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + _, err = st.kv.Get(sentTransferStoreKey, sentTransferStoreVersion) + if err != nil { + t.Errorf("Failed to load saved SentTransfer: %+v", err) + } +} + +// Tests that a SentTransfer marshalled via SentTransfer.marshal and +// unmarshalled via unmarshalSentTransfer matches the original. +func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { + st := &SentTransfer{ + fileName: "transferName", + recipient: id.NewIdFromString("user", id.User, t), + status: Failed, + parts: [][]byte{[]byte("Message"), []byte("Part")}, + } + + data, err := st.marshal() + if err != nil { + t.Errorf("marshal returned an error: %+v", err) + } + + fileName, recipient, status, parts, err := unmarshalSentTransfer(data) + if err != nil { + t.Errorf("Failed to unmarshal SentTransfer: %+v", err) + } + + if st.fileName != fileName { + t.Errorf("Incorrect file name.\nexpected: %q\nreceived: %q", + st.fileName, fileName) + } + + if !st.recipient.Cmp(recipient) { + t.Errorf("Incorrect recipient.\nexpected: %s\nreceived: %s", + st.recipient, recipient) + } + + if status != status { + t.Errorf("Incorrect status.\nexpected: %s\nreceived: %s", + status, status) + } + + if !reflect.DeepEqual(st.parts, parts) { + t.Errorf("Incorrect parts.\nexpected: %q\nreceived: %q", + st.parts, parts) + } +} + +// Consistency test of makeSentTransferPrefix. +func Test_makeSentTransferPrefix_Consistency(t *testing.T) { + expectedPrefixes := []string{ + "SentFileTransferStore/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/AQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/AwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/BAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/BQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/BgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/BwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/CAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "SentFileTransferStore/CQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + } + + for i, expected := range expectedPrefixes { + tid := ftCrypto.TransferID{byte(i)} + prefix := makeSentTransferPrefix(&tid) + + if expected != prefix { + t.Errorf("Prefix #%d does not match expected."+ + "\nexpected: %q\nreceived: %q", i, expected, prefix) + } + } +} + +const numPrimeBytes = 512 + +// newTestSentTransfer creates a new SentTransfer for testing. +func newTestSentTransfer(numParts uint16, t *testing.T) ( + *SentTransfer, [][]byte, *ftCrypto.TransferKey, uint16, *versioned.KV) { + kv := versioned.NewKV(make(ekv.Memstore)) + recipient := id.NewIdFromString("recipient", id.User, t) + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + numFps := 2 * numParts + fileName := "helloFile" + parts := generateTestParts(numParts) + + st, err := newSentTransfer(recipient, &key, &tid, fileName, parts, numFps, kv) + if err != nil { + t.Errorf("Failed to make new SentTransfer: %+v", err) + } + + return st, parts, &key, numFps, kv +} + +// generateTestParts generates a list of file parts of the correct size to be +// encrypted/decrypted. +func generateTestParts(numParts uint16) [][]byte { + // Calculate part size + partSize := fileMessage.NewPartMessage( + format.NewMessage(numPrimeBytes).ContentsSize()).GetPartSize() + + // Create list of parts and fill + parts := make([][]byte, numParts) + for i := range parts { + parts[i] = make([]byte, partSize) + copy(parts[i], "Hello "+strconv.Itoa(i)) + } + + return parts +} diff --git a/fileTransfer/store/sent_test.go b/fileTransfer/store/sent_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9a25c329c3398fa5c13b885bd8f9856e0883c0d2 --- /dev/null +++ b/fileTransfer/store/sent_test.go @@ -0,0 +1,272 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package store + +import ( + "bytes" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "gitlab.com/xx_network/primitives/id" + "reflect" + "sort" + "strconv" + "testing" +) + +// Tests that NewOrLoadSent returns a new Sent when none exist in storage and +// that the list of unsent parts is nil. +func TestNewOrLoadSent_New(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + expected := &Sent{ + transfers: make(map[ftCrypto.TransferID]*SentTransfer), + kv: kv.Prefix(sentTransfersStorePrefix), + } + + s, unsentParts, err := NewOrLoadSent(kv) + if err != nil { + t.Errorf("NewOrLoadSent returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, s) { + t.Errorf("New Sent does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, s) + } + + if unsentParts != nil { + t.Errorf("List of parts should be nil when not loading: %+v", + unsentParts) + } +} + +// Tests that NewOrLoadSent returns a loaded Sent when one exist in storage and +// that the list of unsent parts is correct. +func TestNewOrLoadSent_Load(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + s, _, err := NewOrLoadSent(kv) + if err != nil { + t.Errorf("Failed to make new Sent: %+v", err) + } + var expectedUnsentParts []Part + + // Create and add transfers to map and save + for i := 0; i < 10; i++ { + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + st, err2 := s.AddTransfer( + id.NewIdFromString("recipient"+strconv.Itoa(i), id.User, t), + &key, &tid, "file"+strconv.Itoa(i), + generateTestParts(uint16(10+i)), uint16(2*(10+i))) + if err2 != nil { + t.Errorf("Failed to add transfer #%d: %+v", i, err2) + } + expectedUnsentParts = append(expectedUnsentParts, st.GetUnsentParts()...) + } + if err = s.save(); err != nil { + t.Errorf("Failed to make save filled Sent: %+v", err) + } + + // Load Sent + loadedSent, unsentParts, err := NewOrLoadSent(kv) + if err != nil { + t.Errorf("Failed to load Sent: %+v", err) + } + + // Check that the loaded Sent matches original + if !reflect.DeepEqual(s, loadedSent) { + t.Errorf("Loaded Sent does not match original."+ + "\nexpected: %v\nreceived: %v", s, loadedSent) + } + + sort.Slice(unsentParts, func(i, j int) bool { + switch bytes.Compare(unsentParts[i].TransferID()[:], + unsentParts[j].TransferID()[:]) { + case -1: + return true + case 1: + return false + default: + return unsentParts[i].partNum < unsentParts[j].partNum + } + }) + + sort.Slice(expectedUnsentParts, func(i, j int) bool { + switch bytes.Compare(expectedUnsentParts[i].TransferID()[:], + expectedUnsentParts[j].TransferID()[:]) { + case -1: + return true + case 1: + return false + default: + return expectedUnsentParts[i].partNum < expectedUnsentParts[j].partNum + } + }) + + // Check that the unsent parts matches expected + if !reflect.DeepEqual(expectedUnsentParts, unsentParts) { + t.Errorf("Incorrect unsent parts.\nexpected: %v\nreceived: %v", + expectedUnsentParts, unsentParts) + } +} + +// Tests that Sent.AddTransfer makes a new transfer and adds it to the list. +func TestSent_AddTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + s, _, _ := NewOrLoadSent(kv) + + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + + st, err := s.AddTransfer(id.NewIdFromString("recipient", id.User, t), + &key, &tid, "file", generateTestParts(10), 20) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Check that the transfer was added + if _, exists := s.transfers[*st.tid]; !exists { + t.Errorf("No transfer with ID %s exists.", st.tid) + } +} + +// Tests that Sent.AddTransfer returns an error when adding a transfer ID that +// already exists. +func TestSent_AddTransfer_TransferAlreadyExists(t *testing.T) { + tid := ftCrypto.TransferID{0} + s := &Sent{ + transfers: map[ftCrypto.TransferID]*SentTransfer{tid: nil}, + } + + expectedErr := fmt.Sprintf(errAddExistingSentTransfer, tid) + _, err := s.AddTransfer(nil, nil, &tid, "", nil, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("Received unexpected error when adding transfer that already "+ + "exists.\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that Sent.GetTransfer returns the expected transfer. +func TestSent_GetTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + s, _, _ := NewOrLoadSent(kv) + + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + + st, err := s.AddTransfer(id.NewIdFromString("recipient", id.User, t), + &key, &tid, "file", generateTestParts(10), 20) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Check that the transfer was added + receivedSt, exists := s.GetTransfer(st.tid) + if !exists { + t.Errorf("No transfer with ID %s exists.", st.tid) + } + + if !reflect.DeepEqual(st, receivedSt) { + t.Errorf("Received SentTransfer does not match expected."+ + "\nexpected: %+v\nreceived: %+v", st, receivedSt) + } +} + +// Tests that Sent.RemoveTransfer removes the transfer from the list. +func TestSent_RemoveTransfer(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + s, _, _ := NewOrLoadSent(kv) + + key, _ := ftCrypto.NewTransferKey(csprng.NewSystemRNG()) + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + + st, err := s.AddTransfer(id.NewIdFromString("recipient", id.User, t), + &key, &tid, "file", generateTestParts(10), 20) + if err != nil { + t.Errorf("Failed to add new transfer: %+v", err) + } + + // Delete the transfer + err = s.RemoveTransfer(st.tid) + if err != nil { + t.Errorf("RemoveTransfer returned an error: %+v", err) + } + + // Check that the transfer was deleted + _, exists := s.GetTransfer(st.tid) + if exists { + t.Errorf("Transfer %s exists.", st.tid) + } + + // Remove transfer that was already removed + err = s.RemoveTransfer(st.tid) + if err != nil { + t.Errorf("RemoveTransfer returned an error: %+v", err) + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that Sent.save saves the transfer ID list to storage by trying to load +// it after a save. +func TestSent_save(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + s, _, _ := NewOrLoadSent(kv) + s.transfers = map[ftCrypto.TransferID]*SentTransfer{ + ftCrypto.TransferID{0}: nil, ftCrypto.TransferID{1}: nil, + ftCrypto.TransferID{2}: nil, ftCrypto.TransferID{3}: nil, + } + + err := s.save() + if err != nil { + t.Errorf("Failed to save transfer ID list: %+v", err) + } + + _, err = s.kv.Get(sentTransfersStoreKey, sentTransfersStoreVersion) + if err != nil { + t.Errorf("Failed to load transfer ID list: %+v", err) + } +} + +// Tests that the transfer IDs keys in the map marshalled by +// marshalSentTransfersMap and unmarshalled by unmarshalTransferIdList match the +// original. +func Test_marshalSentTransfersMap_unmarshalTransferIdList(t *testing.T) { + // Build map of transfer IDs + transfers := make(map[ftCrypto.TransferID]*SentTransfer, 10) + for i := 0; i < 10; i++ { + tid, _ := ftCrypto.NewTransferID(csprng.NewSystemRNG()) + transfers[tid] = nil + } + + data, err := marshalSentTransfersMap(transfers) + if err != nil { + t.Errorf("marshalSentTransfersMap returned an error: %+v", err) + } + + tidList, err := unmarshalTransferIdList(data) + if err != nil { + t.Errorf("unmarshalSentTransfer returned an error: %+v", err) + } + + for _, tid := range tidList { + if _, exists := transfers[tid]; exists { + delete(transfers, tid) + } else { + t.Errorf("Transfer %s does not exist in list.", tid) + } + } + + if len(transfers) != 0 { + t.Errorf("%d transfers not in unmarshalled list: %v", + len(transfers), transfers) + } +} diff --git a/storage/fileTransfer/transferStatus.go b/fileTransfer/store/transferStatus.go similarity index 80% rename from storage/fileTransfer/transferStatus.go rename to fileTransfer/store/transferStatus.go index ef4fe36dc5716c79be40fd21438101eb16307ae4..4055f622a58bb1daac519c103abf9315e439acc4 100644 --- a/storage/fileTransfer/transferStatus.go +++ b/fileTransfer/store/transferStatus.go @@ -5,7 +5,7 @@ // LICENSE file // //////////////////////////////////////////////////////////////////////////////// -package fileTransfer +package store import ( "encoding/binary" @@ -19,13 +19,11 @@ const ( // Running indicates that the transfer is in the processes of sending Running TransferStatus = iota - // Stopping indicates that the last part has been sent but the callback - // indicating a completed transfer has not been called - Stopping + // Completed indicates that all file parts have been sent and arrived + Completed - // Stopped indicates that the last part in the transfer has been sent and - // the last callback has been called - Stopped + // Failed indicates that the transfer has run out of sending retries + Failed ) const invalidTransferStatusStringErr = "INVALID TransferStatus: " @@ -36,10 +34,10 @@ func (ts TransferStatus) String() string { switch ts { case Running: return "running" - case Stopping: - return "stopping" - case Stopped: - return "stopped" + case Completed: + return "completed" + case Failed: + return "failed" default: return invalidTransferStatusStringErr + strconv.Itoa(int(ts)) } diff --git a/storage/fileTransfer/transferStatus_test.go b/fileTransfer/store/transferStatus_test.go similarity index 86% rename from storage/fileTransfer/transferStatus_test.go rename to fileTransfer/store/transferStatus_test.go index 4f8bcd0a583309698cc25ff969539c03b5b2365a..a54878295b992b7faefecd3f9e66f1c78561d976 100644 --- a/storage/fileTransfer/transferStatus_test.go +++ b/fileTransfer/store/transferStatus_test.go @@ -5,7 +5,7 @@ // LICENSE file // //////////////////////////////////////////////////////////////////////////////// -package fileTransfer +package store import ( "strconv" @@ -16,10 +16,10 @@ import ( // of TransferStatus. func Test_TransferStatus_String(t *testing.T) { testValues := map[TransferStatus]string{ - Running: "running", - Stopping: "stopping", - Stopped: "stopped", - 100: invalidTransferStatusStringErr + strconv.Itoa(100), + Running: "running", + Completed: "completed", + Failed: "failed", + 100: invalidTransferStatusStringErr + strconv.Itoa(100), } for status, expected := range testValues { @@ -32,7 +32,7 @@ func Test_TransferStatus_String(t *testing.T) { // Tests that a marshalled and unmarshalled TransferStatus matches the original. func Test_TransferStatus_Marshal_UnmarshalTransferStatus(t *testing.T) { - testValues := []TransferStatus{Running, Stopping, Stopped} + testValues := []TransferStatus{Running, Completed, Failed} for _, status := range testValues { marshalledStatus := status.Marshal() diff --git a/fileTransfer/utils_test.go b/fileTransfer/utils_test.go index 3ef8c13da9f22cc3b71369c377ff5ab21b447522..d54a1257f57b6c4625eaa2c819202bc8e35d49ff 100644 --- a/fileTransfer/utils_test.go +++ b/fileTransfer/utils_test.go @@ -10,35 +10,21 @@ package fileTransfer import ( "bytes" "encoding/binary" - "github.com/cloudflare/circl/dh/sidh" - "github.com/pkg/errors" - network2 "gitlab.com/elixxir/client/cmix" - "gitlab.com/elixxir/client/cmix/gateway" - "gitlab.com/elixxir/client/event" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" - "gitlab.com/elixxir/client/stoppable" - "gitlab.com/elixxir/client/storage" - ftStorage "gitlab.com/elixxir/client/storage/fileTransfer" - util "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/client/switchboard" - "gitlab.com/elixxir/comms/network" - "gitlab.com/elixxir/crypto/diffieHellman" - "gitlab.com/elixxir/crypto/e2e" - "gitlab.com/elixxir/crypto/fastRNG" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/catalog" + "gitlab.com/elixxir/client/cmix" + "gitlab.com/elixxir/client/cmix/identity/receptionID" + "gitlab.com/elixxir/client/cmix/message" + "gitlab.com/elixxir/client/cmix/rounds" + "gitlab.com/elixxir/client/e2e" + "gitlab.com/elixxir/client/e2e/receive" + e2eCrypto "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/comms/connect" - "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" - "gitlab.com/xx_network/primitives/ndf" + "gitlab.com/xx_network/primitives/netTime" "io" "math/rand" - "strconv" "sync" "testing" "time" @@ -93,525 +79,169 @@ func RandStringBytes(n int, prng *rand.Rand) string { return string(b) } -// checkReceivedProgress compares the output of ReceivedTransfer.GetProgress to -// expected values. -func checkReceivedProgress(completed bool, received, total uint16, - eCompleted bool, eReceived, eTotal uint16) error { - if eCompleted != completed || eReceived != received || eTotal != total { - return errors.Errorf("Returned progress does not match expected."+ - "\n completed received total"+ - "\nexpected: %5t %3d %3d"+ - "\nreceived: %5t %3d %3d", - eCompleted, eReceived, eTotal, - completed, received, total) - } - - return nil -} - -// checkSentProgress compares the output of SentTransfer.GetProgress to expected -// values. -func checkSentProgress(completed bool, sent, arrived, total uint16, - eCompleted bool, eSent, eArrived, eTotal uint16) error { - if eCompleted != completed || eSent != sent || eArrived != arrived || - eTotal != total { - return errors.Errorf("Returned progress does not match expected."+ - "\n completed sent arrived total"+ - "\nexpected: %5t %3d %3d %3d"+ - "\nreceived: %5t %3d %3d %3d", - eCompleted, eSent, eArrived, eTotal, - completed, sent, arrived, total) - } - - return nil -} - //////////////////////////////////////////////////////////////////////////////// -// PRNG // +// Mock cMix Client // //////////////////////////////////////////////////////////////////////////////// -// Prng is a PRNG that satisfies the csprng.Source interface. -type Prng struct{ prng io.Reader } - -func NewPrng(seed int64) csprng.Source { return &Prng{rand.New(rand.NewSource(seed))} } -func (s *Prng) Read(b []byte) (int, error) { return s.prng.Read(b) } -func (s *Prng) SetSeed([]byte) error { return nil } - -// PrngErr is a PRNG that satisfies the csprng.Source interface. However, it -// always returns an error -type PrngErr struct{} - -func NewPrngErr() csprng.Source { return &PrngErr{} } -func (s *PrngErr) Read([]byte) (int, error) { return 0, errors.New("ReadFailure") } -func (s *PrngErr) SetSeed([]byte) error { return errors.New("SetSeedFailure") } - -//////////////////////////////////////////////////////////////////////////////// -// Test Managers // -//////////////////////////////////////////////////////////////////////////////// - -// newTestManager creates a new Manager that has groups stored for testing. One -// of the groups in the list is also returned. -func newTestManager(sendErr bool, sendChan, sendE2eChan chan message.Receive, - receiveCB interfaces.ReceiveCallback, kv *versioned.KV, t *testing.T) *Manager { - - if kv == nil { - kv = versioned.NewKV(make(ekv.Memstore)) - } - sent, err := ftStorage.NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to createw new SentFileTransfersStore: %+v", err) - } - received, err := ftStorage.NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to createw new ReceivedFileTransfersStore: %+v", err) - } - - net := newTestNetworkManager(sendErr, sendChan, sendE2eChan, t) - - // Returns an error on function and round failure on callback if sendErr is - // set; otherwise, it reports round successes and returns nil - rr := func(rIDs []id.Round, _ time.Duration, cb network2.RoundEventCallback) error { - rounds := make(map[id.Round]network2.RoundLookupStatus, len(rIDs)) - for _, rid := range rIDs { - if sendErr { - rounds[rid] = network2.Failed - } else { - rounds[rid] = network2.Succeeded - } - } - cb(!sendErr, false, rounds) - if sendErr { - return errors.New("SendError") - } - - return nil - } - - p := DefaultParams() - avgNumMessages := (minPartsSendPerRound + maxPartsSendPerRound) / 2 - avgSendSize := avgNumMessages * (8192 / 8) - p.MaxThroughput = int(time.Second) * avgSendSize - - oldTransfersRecovered := uint32(0) - - m := &Manager{ - receiveCB: receiveCB, - sent: sent, - received: received, - sendQueue: make(chan queuedPart, sendQueueBuffLen), - oldTransfersRecovered: &oldTransfersRecovered, - p: p, - store: storage.InitTestingSession(t), - swb: switchboard.New(), - net: net, - rng: fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), - getRoundResults: rr, - } - - return m +type mockCmixHandler struct { + sync.Mutex + processorMap map[format.Fingerprint]message.Processor } -// newTestManagerWithTransfers creates a new test manager with transfers added -// to it. -func newTestManagerWithTransfers(numParts []uint16, sendErr, addPartners bool, - sendE2eChan chan message.Receive, receiveCB interfaces.ReceiveCallback, - kv *versioned.KV, t *testing.T) (*Manager, []sentTransferInfo, - []receivedTransferInfo) { - m := newTestManager(sendErr, sendE2eChan, nil, receiveCB, kv, t) - sti := make([]sentTransferInfo, len(numParts)) - rti := make([]receivedTransferInfo, len(numParts)) - var err error - - partSize, err := m.getPartSize() - if err != nil { - t.Errorf("Failed to get part size: %+v", err) - } - - // Add sent transfers to manager and populate the sentTransferInfo list - for i := range sti { - // Generate PRNG, the file and its parts, and the transfer key - prng := NewPrng(int64(42 + i)) - file, parts := newFile(numParts[i], partSize, prng, t) - key, _ := ftCrypto.NewTransferKey(prng) - recipient := id.NewIdFromString("recipient"+strconv.Itoa(i), id.User, t) - - // Create a sentTransferInfo with all the transfer information - sti[i] = sentTransferInfo{ - recipient: recipient, - key: key, - parts: parts, - file: file, - numParts: numParts[i], - numFps: calcNumberOfFingerprints(numParts[i], 0.5), - retry: 0.5, - period: time.Millisecond, - prng: prng, - } - - // Create sent progress callback and channel - cbChan := make(chan sentProgressResults, 8) - cb := func(completed bool, sent, arrived, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- sentProgressResults{completed, sent, arrived, total, tr, err} - } - - // Add callback and channel to the sentTransferInfo - sti[i].cbChan = cbChan - sti[i].cb = cb - - // Add the transfer to the manager - sti[i].tid, err = m.sent.AddTransfer(recipient, sti[i].key, - sti[i].parts, sti[i].numFps, sti[i].cb, sti[i].period, sti[i].prng) - if err != nil { - t.Errorf("Failed to add sent transfer #%d: %+v", i, err) - } - - // Add recipient as partner - if addPartners { - grp := m.store.E2e().GetGroup() - dhKey := grp.NewInt(int64(i + 42)) - pubKey := diffieHellman.GeneratePublicKey(dhKey, grp) - p := params.GetDefaultE2ESessionParams() - rng := csprng.NewSystemRNG() - _, mySidhPriv := util.GenerateSIDHKeyPair( - sidh.KeyVariantSidhA, rng) - theirSidhPub, _ := util.GenerateSIDHKeyPair( - sidh.KeyVariantSidhB, rng) - err = m.store.E2e().AddPartner(recipient, pubKey, dhKey, - mySidhPriv, theirSidhPub, p, p) - if err != nil { - t.Errorf("Failed to add partner #%d %s: %+v", i, recipient, err) - } - } +func newMockCmixHandler() *mockCmixHandler { + return &mockCmixHandler{ + processorMap: make(map[format.Fingerprint]message.Processor), } +} - // Add received transfers to manager and populate the receivedTransferInfo - // list - for i := range rti { - // Generate PRNG, the file and its parts, and the transfer key - prng := NewPrng(int64(42 + i)) - file, parts := newFile(numParts[i], partSize, prng, t) - key, _ := ftCrypto.NewTransferKey(prng) - - // Create a receivedTransferInfo with all the transfer information - rti[i] = receivedTransferInfo{ - key: key, - mac: ftCrypto.CreateTransferMAC(file, key), - parts: parts, - file: file, - fileSize: uint32(len(file)), - numParts: numParts[i], - numFps: calcNumberOfFingerprints(numParts[i], 0.5), - retry: 0.5, - period: time.Millisecond, - prng: prng, - } - - // Create received progress callback and channel - cbChan := make(chan receivedProgressResults, 8) - cb := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - cbChan <- receivedProgressResults{completed, received, total, tr, err} - } - - // Add callback and channel to the receivedTransferInfo - rti[i].cbChan = cbChan - rti[i].cb = cb +type mockCmix struct { + myID *id.ID + numPrimeBytes int + health bool + handler *mockCmixHandler + healthCBs map[uint64]func(b bool) + healthIndex uint64 + sync.Mutex +} - // Add the transfer to the manager - rti[i].tid, err = m.received.AddTransfer(rti[i].key, rti[i].mac, - rti[i].fileSize, rti[i].numParts, rti[i].numFps, rti[i].prng) - if err != nil { - t.Errorf("Failed to add received transfer #%d: %+v", i, err) - } +func newMockCmix(myID *id.ID, handler *mockCmixHandler) *mockCmix { + return &mockCmix{ + myID: myID, + numPrimeBytes: 4096, + health: true, + handler: handler, + healthCBs: make(map[uint64]func(b bool)), + healthIndex: 0, } - - return m, sti, rti } -// receivedFtResults is used to return received new file transfer results on a -// channel from a callback. -type receivedFtResults struct { - tid ftCrypto.TransferID - fileName string - fileType string - sender *id.ID - size uint32 - preview []byte +func (m *mockCmix) GetMaxMessageLength() int { + msg := format.NewMessage(m.numPrimeBytes) + return msg.ContentsSize() } -// sentProgressResults is used to return sent progress results on a channel from -// a callback. -type sentProgressResults struct { - completed bool - sent, arrived, total uint16 - tracker interfaces.FilePartTracker - err error +func (m *mockCmix) SendMany(messages []cmix.TargetedCmixMessage, + _ cmix.CMIXParams) (id.Round, []ephemeral.Id, error) { + m.handler.Lock() + for _, targetedMsg := range messages { + msg := format.NewMessage(m.numPrimeBytes) + msg.SetContents(targetedMsg.Payload) + msg.SetMac(targetedMsg.Mac) + msg.SetKeyFP(targetedMsg.Fingerprint) + m.handler.processorMap[targetedMsg.Fingerprint].Process(msg, + receptionID.EphemeralIdentity{Source: targetedMsg.Recipient}, + rounds.Round{ID: 42}) + } + m.handler.Unlock() + return 42, []ephemeral.Id{}, nil } -// sentTransferInfo contains information on a sent transfer. -type sentTransferInfo struct { - recipient *id.ID - key ftCrypto.TransferKey - tid ftCrypto.TransferID - parts [][]byte - file []byte - numParts uint16 - numFps uint16 - retry float32 - cb interfaces.SentProgressCallback - cbChan chan sentProgressResults - period time.Duration - prng csprng.Source +func (m *mockCmix) AddFingerprint(_ *id.ID, fp format.Fingerprint, mp message.Processor) error { + m.Lock() + defer m.Unlock() + m.handler.processorMap[fp] = mp + return nil } -// receivedProgressResults is used to return received progress results on a -// channel from a callback. -type receivedProgressResults struct { - completed bool - received, total uint16 - tracker interfaces.FilePartTracker - err error +func (m *mockCmix) DeleteFingerprint(_ *id.ID, fp format.Fingerprint) { + m.handler.Lock() + delete(m.handler.processorMap, fp) + m.handler.Unlock() } -// receivedTransferInfo contains information on a received transfer. -type receivedTransferInfo struct { - key ftCrypto.TransferKey - tid ftCrypto.TransferID - mac []byte - parts [][]byte - file []byte - fileSize uint32 - numParts uint16 - numFps uint16 - retry float32 - cb interfaces.ReceivedProgressCallback - cbChan chan receivedProgressResults - period time.Duration - prng csprng.Source +func (m *mockCmix) IsHealthy() bool { + return m.health } -//////////////////////////////////////////////////////////////////////////////// -// Test Network Manager // -//////////////////////////////////////////////////////////////////////////////// - -func newTestNetworkManager(sendErr bool, sendChan, - sendE2eChan chan message.Receive, t *testing.T) interfaces.NetworkManager { - instanceComms := &connect.ProtoComms{ - Manager: connect.NewManagerTesting(t), - } +func (m *mockCmix) WasHealthy() bool { return true } - thisInstance, err := network.NewInstanceTesting(instanceComms, getNDF(), - getNDF(), nil, nil, t) - if err != nil { - t.Fatalf("Failed to create new test instance: %v", err) - } +func (m *mockCmix) AddHealthCallback(f func(bool)) uint64 { + m.Lock() + defer m.Unlock() + m.healthIndex++ + m.healthCBs[m.healthIndex] = f + go f(true) + return m.healthIndex +} - return &testNetworkManager{ - instance: thisInstance, - rid: 0, - messages: make(map[id.Round][]message.TargetedCmixMessage), - sendErr: sendErr, - health: newTestHealthTracker(), - sendChan: sendChan, - sendE2eChan: sendE2eChan, +func (m *mockCmix) RemoveHealthCallback(healthID uint64) { + m.Lock() + defer m.Unlock() + if _, exists := m.healthCBs[healthID]; !exists { + jww.FATAL.Panicf("No health callback with ID %d exists.", healthID) } + delete(m.healthCBs, healthID) } -// testNetworkManager is a test implementation of NetworkManager interface. -type testNetworkManager struct { - instance *network.Instance - updateRid bool - rid id.Round - messages map[id.Round][]message.TargetedCmixMessage - e2eMessages []message.Send - sendErr bool - health testHealthTracker - sendChan chan message.Receive - sendE2eChan chan message.Receive - sync.RWMutex +func (m *mockCmix) GetRoundResults(_ time.Duration, + roundCallback cmix.RoundEventCallback, _ ...id.Round) error { + go roundCallback(true, false, map[id.Round]cmix.RoundResult{42: {}}) + return nil } -func (tnm *testNetworkManager) GetMsgList(rid id.Round) []message.TargetedCmixMessage { - tnm.RLock() - defer tnm.RUnlock() - return tnm.messages[rid] +//////////////////////////////////////////////////////////////////////////////// +// Mock E2E Handler // +//////////////////////////////////////////////////////////////////////////////// +func newMockListener(hearChan chan receive.Message) *mockListener { + return &mockListener{hearChan: hearChan} } -func (tnm *testNetworkManager) GetE2eMsg(i int) message.Send { - tnm.RLock() - defer tnm.RUnlock() - return tnm.e2eMessages[i] +func (l *mockListener) Hear(item receive.Message) { + l.hearChan <- item } -func (tnm *testNetworkManager) SendE2E(msg message.Send, _ params.E2E, _ *stoppable.Single) ( - []id.Round, e2e.MessageID, time.Time, error) { - tnm.Lock() - defer tnm.Unlock() - - if tnm.sendErr { - return nil, e2e.MessageID{}, time.Time{}, errors.New("SendE2E error") - } - - tnm.e2eMessages = append(tnm.e2eMessages, msg) - - if tnm.sendE2eChan != nil { - tnm.sendE2eChan <- message.Receive{ - Payload: msg.Payload, - MessageType: msg.MessageType, - Sender: &id.ID{}, - RecipientID: msg.Recipient, - } - } - - return []id.Round{0, 1, 2, 3}, e2e.MessageID{}, time.Time{}, nil -} - -func (tnm *testNetworkManager) SendUnsafe(message.Send, params.Unsafe) ([]id.Round, error) { - return []id.Round{}, nil +func (l *mockListener) Name() string { + return "mockListener" } -func (tnm *testNetworkManager) SendCMIX(format.Message, *id.ID, params.CMIX) (id.Round, ephemeral.Id, error) { - return 0, ephemeral.Id{}, nil +type mockE2eHandler struct { + msgMap map[id.ID]map[catalog.MessageType][][]byte + listeners map[id.ID]map[catalog.MessageType]receive.Listener } -func (tnm *testNetworkManager) SendManyCMIX(messages []message.TargetedCmixMessage, _ params.CMIX) ( - id.Round, []ephemeral.Id, error) { - tnm.Lock() - defer func() { - // Increment the round every two calls to SendManyCMIX - if tnm.updateRid { - tnm.rid++ - tnm.updateRid = false - } else { - tnm.updateRid = true - } - tnm.Unlock() - }() - - if tnm.sendErr { - return 0, nil, errors.New("SendManyCMIX error") +func newMockE2eHandler() *mockE2eHandler { + return &mockE2eHandler{ + msgMap: make(map[id.ID]map[catalog.MessageType][][]byte), + listeners: make(map[id.ID]map[catalog.MessageType]receive.Listener), } - - tnm.messages[tnm.rid] = messages - - if tnm.sendChan != nil { - for _, msg := range messages { - tnm.sendChan <- message.Receive{ - Payload: msg.Message.Marshal(), - Sender: &id.ID{0}, - RoundId: tnm.rid, - } - } - } - - return tnm.rid, nil, nil } -type dummyEventMgr struct{} - -func (d *dummyEventMgr) Report(int, string, string, string) {} -func (tnm *testNetworkManager) GetEventManager() event.Manager { - return &dummyEventMgr{} +type mockE2e struct { + myID *id.ID + handler *mockE2eHandler } -func (tnm *testNetworkManager) GetInstance() *network.Instance { return tnm.instance } -func (tnm *testNetworkManager) GetHealthTracker() interfaces.HealthTracker { return tnm.health } -func (tnm *testNetworkManager) Follow(interfaces.ClientErrorReport) (stoppable.Stoppable, error) { - return nil, nil -} -func (tnm *testNetworkManager) CheckGarbledMessages() {} -func (tnm *testNetworkManager) InProgressRegistrations() int { return 0 } -func (tnm *testNetworkManager) GetSender() *gateway.Sender { return nil } -func (tnm *testNetworkManager) GetAddressSize() uint8 { return 0 } -func (tnm *testNetworkManager) RegisterAddressSizeNotification(string) (chan uint8, error) { - return nil, nil -} -func (tnm *testNetworkManager) UnregisterAddressSizeNotification(string) {} -func (tnm *testNetworkManager) SetPoolFilter(gateway.Filter) {} -func (tnm *testNetworkManager) GetVerboseRounds() string { return "" } - -type testHealthTracker struct { - chIndex, fnIndex uint64 - channels map[uint64]chan bool - funcs map[uint64]func(bool) - healthy bool +type mockListener struct { + hearChan chan receive.Message } -//////////////////////////////////////////////////////////////////////////////// -// Test Health Tracker // -//////////////////////////////////////////////////////////////////////////////// - -func newTestHealthTracker() testHealthTracker { - return testHealthTracker{ - chIndex: 0, - fnIndex: 0, - channels: make(map[uint64]chan bool), - funcs: make(map[uint64]func(bool)), - healthy: true, +func newMockE2e(myID *id.ID, handler *mockE2eHandler) *mockE2e { + return &mockE2e{ + myID: myID, + handler: handler, } } -func (tht testHealthTracker) AddChannel(c chan bool) uint64 { - tht.channels[tht.chIndex] = c - tht.chIndex++ - return tht.chIndex - 1 -} +// SendE2E adds the message to the e2e handler map. +func (m *mockE2e) SendE2E(mt catalog.MessageType, recipient *id.ID, payload []byte, + _ e2e.Params) ([]id.Round, e2eCrypto.MessageID, time.Time, error) { -func (tht testHealthTracker) RemoveChannel(chanID uint64) { delete(tht.channels, chanID) } + m.handler.listeners[*recipient][mt].Hear(receive.Message{ + MessageType: mt, + Payload: payload, + Sender: m.myID, + RecipientID: recipient, + }) -func (tht testHealthTracker) AddFunc(f func(bool)) uint64 { - tht.funcs[tht.fnIndex] = f - tht.fnIndex++ - return tht.fnIndex - 1 + return []id.Round{42}, e2eCrypto.MessageID{}, netTime.Now(), nil } -func (tht testHealthTracker) RemoveFunc(funcID uint64) { delete(tht.funcs, funcID) } -func (tht testHealthTracker) IsHealthy() bool { return tht.healthy } -func (tht testHealthTracker) WasHealthy() bool { return tht.healthy } - -//////////////////////////////////////////////////////////////////////////////// -// NDF Primes // -//////////////////////////////////////////////////////////////////////////////// - -func getNDF() *ndf.NetworkDefinition { - return &ndf.NetworkDefinition{ - E2E: ndf.Group{ - Prime: "E2EE983D031DC1DB6F1A7A67DF0E9A8E5561DB8E8D49413394C049B7A" + - "8ACCEDC298708F121951D9CF920EC5D146727AA4AE535B0922C688B55B3D" + - "D2AEDF6C01C94764DAB937935AA83BE36E67760713AB44A6337C20E78615" + - "75E745D31F8B9E9AD8412118C62A3E2E29DF46B0864D0C951C394A5CBBDC" + - "6ADC718DD2A3E041023DBB5AB23EBB4742DE9C1687B5B34FA48C3521632C" + - "4A530E8FFB1BC51DADDF453B0B2717C2BC6669ED76B4BDD5C9FF558E88F2" + - "6E5785302BEDBCA23EAC5ACE92096EE8A60642FB61E8F3D24990B8CB12EE" + - "448EEF78E184C7242DD161C7738F32BF29A841698978825B4111B4BC3E1E" + - "198455095958333D776D8B2BEEED3A1A1A221A6E37E664A64B83981C46FF" + - "DDC1A45E3D5211AAF8BFBC072768C4F50D7D7803D2D4F278DE8014A47323" + - "631D7E064DE81C0C6BFA43EF0E6998860F1390B5D3FEACAF1696015CB79C" + - "3F9C2D93D961120CD0E5F12CBB687EAB045241F96789C38E89D796138E63" + - "19BE62E35D87B1048CA28BE389B575E994DCA755471584A09EC723742DC3" + - "5873847AEF49F66E43873", - Generator: "2", - }, - CMIX: ndf.Group{ - Prime: "9DB6FB5951B66BB6FE1E140F1D2CE5502374161FD6538DF1648218642" + - "F0B5C48C8F7A41AADFA187324B87674FA1822B00F1ECF8136943D7C55757" + - "264E5A1A44FFE012E9936E00C1D3E9310B01C7D179805D3058B2A9F4BB6F" + - "9716BFE6117C6B5B3CC4D9BE341104AD4A80AD6C94E005F4B993E14F091E" + - "B51743BF33050C38DE235567E1B34C3D6A5C0CEAA1A0F368213C3D19843D" + - "0B4B09DCB9FC72D39C8DE41F1BF14D4BB4563CA28371621CAD3324B6A2D3" + - "92145BEBFAC748805236F5CA2FE92B871CD8F9C36D3292B5509CA8CAA77A" + - "2ADFC7BFD77DDA6F71125A7456FEA153E433256A2261C6A06ED3693797E7" + - "995FAD5AABBCFBE3EDA2741E375404AE25B", - Generator: "5C7FF6B06F8F143FE8288433493E4769C4D988ACE5BE25A0E2480" + - "9670716C613D7B0CEE6932F8FAA7C44D2CB24523DA53FBE4F6EC3595892D" + - "1AA58C4328A06C46A15662E7EAA703A1DECF8BBB2D05DBE2EB956C142A33" + - "8661D10461C0D135472085057F3494309FFA73C611F78B32ADBB5740C361" + - "C9F35BE90997DB2014E2EF5AA61782F52ABEB8BD6432C4DD097BC5423B28" + - "5DAFB60DC364E8161F4A2A35ACA3A10B1C4D203CC76A470A33AFDCBDD929" + - "59859ABD8B56E1725252D78EAC66E71BA9AE3F1DD2487199874393CD4D83" + - "2186800654760E1E34C09E4D155179F9EC0DC4473F996BDCE6EED1CABED8" + - "B6F116F7AD9CF505DF0F998E34AB27514B0FFE7", - }, +func (m *mockE2e) RegisterListener(senderID *id.ID, mt catalog.MessageType, + listener receive.Listener) receive.ListenerID { + if _, exists := m.handler.listeners[*senderID]; !exists { + m.handler.listeners[*senderID] = map[catalog.MessageType]receive.Listener{mt: listener} + } else if _, exists = m.handler.listeners[*senderID][mt]; !exists { + m.handler.listeners[*senderID][mt] = listener } + return receive.ListenerID{} } diff --git a/go.mod b/go.mod index 276521a57d4678dbc98cc161ba0218cd55c38370..957e6ea2dd5991d7d82fbbcb4a505889bb4da15c 100644 --- a/go.mod +++ b/go.mod @@ -12,13 +12,14 @@ require ( github.com/spf13/jwalterweatherman v1.1.0 github.com/spf13/viper v1.7.1 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 - gitlab.com/elixxir/comms v0.0.4-0.20220308183624-c2183e687a03 + gitlab.com/elixxir/comms v0.0.4-0.20220323190139-9ed75f3a8b2c gitlab.com/elixxir/crypto v0.0.7-0.20220406193349-d25222ea3c6e gitlab.com/elixxir/ekv v0.1.6 gitlab.com/elixxir/primitives v0.0.3-0.20220330212736-cce83b5f948f - gitlab.com/xx_network/comms v0.0.4-0.20220311192415-d95fe8906580 - gitlab.com/xx_network/crypto v0.0.5-0.20220222212031-750f7e8a01f4 - gitlab.com/xx_network/primitives v0.0.4-0.20220222211843-901fa4a2d72b + gitlab.com/xx_network/comms v0.0.4-0.20220315161313-76acb14429ac + gitlab.com/xx_network/crypto v0.0.5-0.20220317171841-084640957d71 + gitlab.com/xx_network/primitives v0.0.4-0.20220324193139-b292d1ae6e7e + go.uber.org/ratelimit v0.2.0 golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 google.golang.org/grpc v1.42.0 @@ -26,6 +27,7 @@ require ( ) require ( + github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 // indirect github.com/badoux/checkmail v1.2.1 // indirect github.com/elliotchance/orderedmap v1.4.0 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect @@ -49,6 +51,7 @@ require ( gitlab.com/xx_network/ring v0.0.3-0.20220222211904-da613960ad93 // indirect golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac // indirect golang.org/x/text v0.3.6 // indirect + golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 // indirect google.golang.org/genproto v0.0.0-20210105202744-fe13368bc0e1 // indirect gopkg.in/ini.v1 v1.62.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 428a24b569408319d2adbf1839eb006b8056d1f8..9b08ec215353ae39af5fc18f32821d63e6a88863 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 h1:MzBOUgng9orim59UnfUTLRjMpd09C5uEVQ6RPGeCaVI= +github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129/go.mod h1:rFgpPQZYZ8vdbc+48xibu8ALc3yeyd64IhHS+PU6Yyg= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= @@ -274,6 +276,8 @@ gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228 h1:Gi6rj4mAlK0 gitlab.com/elixxir/bloomfilter v0.0.0-20200930191214-10e9ac31b228/go.mod h1:H6jztdm0k+wEV2QGK/KYA+MY9nj9Zzatux/qIvDDv3k= gitlab.com/elixxir/comms v0.0.4-0.20220308183624-c2183e687a03 h1:4eNjO3wCyHgxpGeq2zgDb5SsdTcQaG5IZjBOuEL6KgM= gitlab.com/elixxir/comms v0.0.4-0.20220308183624-c2183e687a03/go.mod h1:4yMdU+Jee5W9lqkZGHJAuipEhW7FloT0eyVEFUJza+E= +gitlab.com/elixxir/comms v0.0.4-0.20220323190139-9ed75f3a8b2c h1:ajjTw08YjRjl3HvtBNGtoCWhOg8k8upqmTweH18wkC4= +gitlab.com/elixxir/comms v0.0.4-0.20220323190139-9ed75f3a8b2c/go.mod h1:tlHSrtSliKWUxsck8z/Ql/VJkMdSONV2BeWaUAAXzgk= gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo+K0N2GwRRpaNr+tKXMIOVWzmyUD0SbOu2c= gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA= gitlab.com/elixxir/crypto v0.0.7-0.20220222221347-95c7ae58da6b/go.mod h1:tD6XjtQh87T2nKZL5I/pYPck5M2wLpkZ1Oz7H/LqO10= @@ -301,6 +305,8 @@ gitlab.com/elixxir/primitives v0.0.0-20200804182913-788f47bded40/go.mod h1:tzdFF gitlab.com/elixxir/primitives v0.0.1/go.mod h1:kNp47yPqja2lHSiS4DddTvFpB/4D9dB2YKnw5c+LJCE= gitlab.com/elixxir/primitives v0.0.3-0.20220222212109-d412a6e46623 h1:NzJ06KdJd3fVJee0QvGhNr3CO+Ki8Ea1PeakZsm+rZM= gitlab.com/elixxir/primitives v0.0.3-0.20220222212109-d412a6e46623/go.mod h1:MtFIyJUQn9P7djzVlBpEYkPNnnWFTjZvw89swoXY+QM= +gitlab.com/elixxir/primitives v0.0.3-0.20220323183834-b98f255361b8 h1:U3Ahbg2N6QL5uwPyccWWN4ZUBFBLgCsuq5sQxOI2VCw= +gitlab.com/elixxir/primitives v0.0.3-0.20220323183834-b98f255361b8/go.mod h1:MtFIyJUQn9P7djzVlBpEYkPNnnWFTjZvw89swoXY+QM= gitlab.com/elixxir/primitives v0.0.3-0.20220325212708-5e2a7a385db7 h1:2yIEkxkPJboftkk/xiCONVP/FSl7Y3zOPpekvQ2dhO0= gitlab.com/elixxir/primitives v0.0.3-0.20220325212708-5e2a7a385db7/go.mod h1:MtFIyJUQn9P7djzVlBpEYkPNnnWFTjZvw89swoXY+QM= gitlab.com/elixxir/primitives v0.0.3-0.20220330212736-cce83b5f948f h1:bOX9nsG+ihvZAUP2OimpM/0UNIcfz5tYlvQfS2IFUoU= @@ -309,15 +315,21 @@ gitlab.com/xx_network/comms v0.0.0-20200805174823-841427dd5023/go.mod h1:owEcxTR gitlab.com/xx_network/comms v0.0.4-0.20220223205228-7c4974139569/go.mod h1:isHnwem0v4rTcwwHP455FhVlFyPcHkHiVz+N3s/uCSI= gitlab.com/xx_network/comms v0.0.4-0.20220311192415-d95fe8906580 h1:IV0gDwdTxtCpc9Vkx7IeSStSqvG+0ZpF57X+OhTQDIM= gitlab.com/xx_network/comms v0.0.4-0.20220311192415-d95fe8906580/go.mod h1:isHnwem0v4rTcwwHP455FhVlFyPcHkHiVz+N3s/uCSI= +gitlab.com/xx_network/comms v0.0.4-0.20220315161313-76acb14429ac h1:+ykw0JqLH/qMprPEKazGHNH8gUoHGA78EIr4ienxnw4= +gitlab.com/xx_network/comms v0.0.4-0.20220315161313-76acb14429ac/go.mod h1:isHnwem0v4rTcwwHP455FhVlFyPcHkHiVz+N3s/uCSI= gitlab.com/xx_network/crypto v0.0.3/go.mod h1:DF2HYvvCw9wkBybXcXAgQMzX+MiGbFPjwt3t17VRqRE= gitlab.com/xx_network/crypto v0.0.4/go.mod h1:+lcQEy+Th4eswFgQDwT0EXKp4AXrlubxalwQFH5O0Mk= gitlab.com/xx_network/crypto v0.0.5-0.20220222212031-750f7e8a01f4 h1:95dZDMn/hpLNwsgZO9eyQgGKaSDyh6F6+WygqZIciww= gitlab.com/xx_network/crypto v0.0.5-0.20220222212031-750f7e8a01f4/go.mod h1:6apvsoHCQJDjO0J4E3uhR3yO9tTz/Mq5be5rjB3tQPU= +gitlab.com/xx_network/crypto v0.0.5-0.20220317171841-084640957d71 h1:N2+Jja4xNg66entu6rGvzRcf3Vc785xgiaHeDPYnBvg= +gitlab.com/xx_network/crypto v0.0.5-0.20220317171841-084640957d71/go.mod h1:/SJf+R75E+QepdTLh0H1/udsovxx2Q5ru34q1v0umKk= gitlab.com/xx_network/primitives v0.0.0-20200803231956-9b192c57ea7c/go.mod h1:wtdCMr7DPePz9qwctNoAUzZtbOSHSedcK++3Df3psjA= gitlab.com/xx_network/primitives v0.0.0-20200804183002-f99f7a7284da/go.mod h1:OK9xevzWCaPO7b1wiluVJGk7R5ZsuC7pHY5hteZFQug= gitlab.com/xx_network/primitives v0.0.2/go.mod h1:cs0QlFpdMDI6lAo61lDRH2JZz+3aVkHy+QogOB6F/qc= gitlab.com/xx_network/primitives v0.0.4-0.20220222211843-901fa4a2d72b h1:shZZ3xZNNKYVpEp4Mqu/G9+ZR+J8QA8mmPk/4Cit8+Y= gitlab.com/xx_network/primitives v0.0.4-0.20220222211843-901fa4a2d72b/go.mod h1:9imZHvYwNFobxueSvVtHneZLk9wTK7HQTzxPm+zhFhE= +gitlab.com/xx_network/primitives v0.0.4-0.20220324193139-b292d1ae6e7e h1:F651DdbU9n5qOLN8rdyAIluXWtm5Wl4jwH5D6ED3bSI= +gitlab.com/xx_network/primitives v0.0.4-0.20220324193139-b292d1ae6e7e/go.mod h1:AXVVFt7dDAeIUpOGPiStCcUIKsBXLWbmV/BgZ4T+tOo= gitlab.com/xx_network/ring v0.0.3-0.20220222211904-da613960ad93 h1:eJZrXqHsMmmejEPWw8gNAt0I8CGAMNO/7C339Zco3TM= gitlab.com/xx_network/ring v0.0.3-0.20220222211904-da613960ad93/go.mod h1:aLzpP2TiZTQut/PVHR40EJAomzugDdHXetbieRClXIM= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= @@ -325,7 +337,10 @@ go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/ratelimit v0.2.0 h1:UQE2Bgi7p2B85uP5dC2bbRtig0C+OeNRnNEafLjsLPA= +go.uber.org/ratelimit v0.2.0/go.mod h1:YYBV4e4naJvhpitQrWJu1vCpgB7CboMe0qhltKt6mUg= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -425,6 +440,8 @@ golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65 h1:M73Iuj3xbbb9Uk1DYhzydthsj6oOd6l9bpuFcNoUvTs= +golang.org/x/time v0.0.0-20220224211638-0e9765cccd65/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/groupChat/makeGroup.go b/groupChat/makeGroup.go index d8f6dd8c32d3f01c97724d497c6b3beafcc44ee0..cc0d22da3119d16a05b7f5dade00cc928ff1cf58 100644 --- a/groupChat/makeGroup.go +++ b/groupChat/makeGroup.go @@ -29,7 +29,6 @@ const ( makeMembershipErr = "failed to assemble group chat membership: %+v" newIdPreimageErr = "failed to create group ID preimage: %+v" newKeyPreimageErr = "failed to create group key preimage: %+v" - addGroupErr = "failed to save new group: %+v" ) // MaxInitMessageSize is the maximum allowable length of the initial message @@ -81,16 +80,12 @@ func (m Manager) MakeGroup(membership []*id.ID, name, msg []byte) (gs.Group, // Create new group and add to manager g := gs.NewGroup( name, groupID, groupKey, idPreimage, keyPreimage, msg, created, mem, dkl) - if err = m.gs.Add(g); err != nil { - return gs.Group{}, nil, NotSent, errors.Errorf(addGroupErr, err) - } jww.DEBUG.Printf("Created new group %q with ID %s and %d members %s", g.Name, g.ID, len(g.Members), g.Members) // Send all group requests roundIDs, status, err := m.sendRequests(g) - if err == nil { err = m.JoinGroup(g) } diff --git a/groupChat/makeGroup_test.go b/groupChat/makeGroup_test.go index eb0676859a739945e6dd2821a97ac4d31da64586..a9541b84471d9d35db12787731916d761f5294e7 100644 --- a/groupChat/makeGroup_test.go +++ b/groupChat/makeGroup_test.go @@ -121,7 +121,7 @@ func TestManager_MakeGroup_AddGroupError(t *testing.T) { prng := rand.New(rand.NewSource(42)) m, _ := newTestManagerWithStore(prng, gs.MaxGroupChats, 0, nil, nil, t) memberIDs, _, _ := addPartners(m, t) - expectedErr := strings.SplitN(addGroupErr, "%", 2)[0] + expectedErr := strings.SplitN(joinGroupErr, "%", 2)[0] _, _, _, err := m.MakeGroup(memberIDs, []byte{}, []byte{}) if err == nil || !strings.Contains(err.Error(), expectedErr) { diff --git a/groupChat/manager.go b/groupChat/manager.go index c67aae9bcb9c677d05709634ae8aca95e000d95e..19c4de8bbca37c62fb7284e9c52701b865f7abba 100644 --- a/groupChat/manager.go +++ b/groupChat/manager.go @@ -44,6 +44,7 @@ type GroupCmix interface { response message.Processor) DeleteService(clientID *id.ID, toDelete message.Service, processor message.Processor) + GetMaxMessageLength() int } // GroupE2e is a subset of the e2e.Handler interface containing only the methods needed by GroupChat diff --git a/groupChat/manager_test.go b/groupChat/manager_test.go index 9cf51c8bcded7fd73381fff79d05a2cf57027604..5a9c30cdc4db6e31d486bdd54a7a5243b846a30d 100644 --- a/groupChat/manager_test.go +++ b/groupChat/manager_test.go @@ -31,7 +31,7 @@ func Test_newManager(t *testing.T) { requestFunc := func(g gs.Group) { requestChan <- g } receiveChan := make(chan MessageReceive) receiveFunc := func(msg MessageReceive) { receiveChan <- msg } - m, err := NewManager(nil, nil, user.ID, nil, nil, kv, requestFunc, receiveFunc) + m, err := NewManager(nil, newTestE2eManager(user.DhKey), user.ID, nil, nil, kv, requestFunc, receiveFunc) if err != nil { t.Errorf("newManager() returned an error: %+v", err) } @@ -84,7 +84,7 @@ func Test_newManager_LoadStorage(t *testing.T) { } } - m, err := NewManager(nil, nil, user.ID, nil, nil, kv, nil, nil) + m, err := NewManager(newTestNetworkManager(0, t), newTestE2eManager(user.DhKey), user.ID, nil, nil, kv, nil, nil) if err != nil { t.Errorf("newManager() returned an error: %+v", err) } @@ -118,7 +118,7 @@ func Test_newManager_LoadError(t *testing.T) { expectedErr := strings.SplitN(newGroupStoreErr, "%", 2)[0] - _, err = NewManager(nil, nil, user.ID, nil, nil, kv, nil, nil) + _, err = NewManager(nil, newTestE2eManager(user.DhKey), user.ID, nil, nil, kv, nil, nil) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("newManager() did not return the expected error."+ "\nexpected: %s\nreceived: %+v", expectedErr, err) diff --git a/groupChat/receive.go b/groupChat/receive.go index b559a0b0d4776cd70425f86c0744b3b6eac2b5f4..503a52fa79434510348680b442596c1a6b3dff5c 100644 --- a/groupChat/receive.go +++ b/groupChat/receive.go @@ -38,7 +38,6 @@ type receptionProcessor struct { // Process incoming group chat messages func (p *receptionProcessor) Process(message format.Message, receptionID receptionID.EphemeralIdentity, round rounds.Round) { jww.TRACE.Print("Group message reception received cMix message.") - // Attempt to read the message roundTimestamp := round.Timestamps[states.QUEUED] @@ -117,7 +116,6 @@ func getCryptKey(key group.Key, salt [group.SaltLen]byte, mac, payload []byte, // Compute the current epoch epoch := group.ComputeEpoch(roundTimestamp) - for _, dhKey := range dhKeys { // Create a key with the correct epoch diff --git a/groupChat/receiveRequest.go b/groupChat/receiveRequest.go index 127cd9a37ad2297d931974936dc74e279fcf9a53..3a8741d6b0a7e494c137ac53c6f1f7de9834a389 100644 --- a/groupChat/receiveRequest.go +++ b/groupChat/receiveRequest.go @@ -89,7 +89,7 @@ func (m *Manager) readRequest(msg receive.Message) (gs.Group, error) { // Generate the DH keys with each group member privKey := partner.GetMyOriginPrivateKey() - dkl := gs.GenerateDhKeyList(m.gs.GetUser().ID, privKey, membership, m.grp) + dkl := gs.GenerateDhKeyList(m.receptionId, privKey, membership, m.grp) // Restore the original public key for the leader so that the membership // digest generated later is correct diff --git a/groupChat/receiveRequest_test.go b/groupChat/receiveRequest_test.go index 265309f6e92fe1057e495143029550b4f9360325..398de2f495a348b2f34323cadee1d060074834e0 100644 --- a/groupChat/receiveRequest_test.go +++ b/groupChat/receiveRequest_test.go @@ -66,7 +66,7 @@ func TestRequestListener_Hear(t *testing.T) { _, _ = m.e2e.AddPartner( g.Members[0].ID, g.Members[0].DhKey, - m.grp.NewInt(2), + m.e2e.GetHistoricalDHPrivkey(), theirSIDHPubKey, mySIDHPrivKey, session.GetDefaultParams(), session.GetDefaultParams(), @@ -165,7 +165,7 @@ func TestManager_readRequest(t *testing.T) { _, _ = m.e2e.AddPartner( g.Members[0].ID, g.Members[0].DhKey, - m.grp.NewInt(2), + m.e2e.GetHistoricalDHPrivkey(), theirSIDHPubKey, mySIDHPrivKey, session.GetDefaultParams(), session.GetDefaultParams(), diff --git a/groupChat/send.go b/groupChat/send.go index dfcc522ef11627cd3ccb32648e8b7549f1a45858..64796dd35111424ff438b5ce2b6d7f721cb2083d 100644 --- a/groupChat/send.go +++ b/groupChat/send.go @@ -27,7 +27,6 @@ import ( const ( newCmixMsgErr = "failed to generate cMix messages for group chat: %+v" sendManyCmixErr = "failed to send group chat message from member %s to group %s: %+v" - newCmixErr = "failed to generate cMix message for member %d with ID %s in group %s: %+v" messageLenErr = "message length %d is greater than maximum message space %d" newNoGroupErr = "failed to create message for group %s that cannot be found" newKeyErr = "failed to generate key for encrypting group payload" @@ -94,7 +93,7 @@ func (m *Manager) newMessages(g gs.Group, msg []byte, timestamp time.Time) ( } // Add cMix message to list - cMixMsg, err := newCmixMsg(g, msg, timestamp, member, rng, m.receptionId, m.grp) + cMixMsg, err := newCmixMsg(g, msg, timestamp, member, rng, m.receptionId, m.services.GetMaxMessageLength()) if err != nil { return nil, err } @@ -106,7 +105,7 @@ func (m *Manager) newMessages(g gs.Group, msg []byte, timestamp time.Time) ( // newCmixMsg generates a new cMix message to be sent to a group member. func newCmixMsg(g gs.Group, msg []byte, timestamp time.Time, - mem group.Member, rng io.Reader, senderId *id.ID, grp *cyclic.Group) (cmix.TargetedCmixMessage, error) { + mem group.Member, rng io.Reader, senderId *id.ID, maxCmixMessageSize int) (cmix.TargetedCmixMessage, error) { // Initialize targeted message cmixMsg := cmix.TargetedCmixMessage{ @@ -119,7 +118,7 @@ func newCmixMsg(g gs.Group, msg []byte, timestamp time.Time, } // Create three message layers - pubMsg, intlMsg, err := newMessageParts(grp.GetP().ByteLen()) + pubMsg, intlMsg, err := newMessageParts(maxCmixMessageSize) if err != nil { return cmixMsg, err } diff --git a/groupChat/sendRequests.go b/groupChat/sendRequests.go index f87b53bb9b3d3c761145897d2c70c2dcd01f02e0..3bc201d41ecd77df369a85ba44a47056c456e22d 100644 --- a/groupChat/sendRequests.go +++ b/groupChat/sendRequests.go @@ -118,7 +118,7 @@ func (m Manager) sendRequest(memberID *id.ID, request []byte) ([]id.Round, error p.LastServiceTag = catalog.GroupRq p.CMIX.DebugTag = "group.Request" - rounds, _, _, err := m.e2e.SendE2E(catalog.GroupCreationRequest, m.receptionId, request, p) + rounds, _, _, err := m.e2e.SendE2E(catalog.GroupCreationRequest, memberID, request, p) if err != nil { return nil, errors.Errorf(sendE2eErr, memberID, err) } diff --git a/groupChat/send_test.go b/groupChat/send_test.go index ec257fd4e11dd2bbadc9f614fcb5b1afc12c4972..72d98232909c7d10f401e7805b0be4ef52393765 100644 --- a/groupChat/send_test.go +++ b/groupChat/send_test.go @@ -10,11 +10,12 @@ package groupChat import ( "bytes" "encoding/base64" - "gitlab.com/elixxir/client/cmix/rounds" "gitlab.com/elixxir/client/cmix/identity/receptionID" + "gitlab.com/elixxir/client/cmix/rounds" gs "gitlab.com/elixxir/client/groupChat/groupStore" "gitlab.com/elixxir/crypto/group" "gitlab.com/elixxir/primitives/format" + "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/netTime" "math/rand" @@ -50,8 +51,10 @@ func TestManager_Send(t *testing.T) { t.Error("No group cMix messages received.") } + timestamps := make(map[states.Round]time.Time) + timestamps[states.QUEUED] = netTime.Now().Round(0) for _, msg := range messages { - reception.Process(msg, receptionID.EphemeralIdentity{}, rounds.Round{ID: roundId}) + reception.Process(msg, receptionID.EphemeralIdentity{}, rounds.Round{ID: roundId, Timestamps: timestamps}) select { case result := <-receiveChan: if !result.SenderID.Cmp(m.receptionId) { @@ -70,26 +73,13 @@ func TestManager_Send(t *testing.T) { } } -// Error path: an error is returned when Manager.neCmixMsg returns an error. -func TestGroup_newMessages_NewCmixMsgError(t *testing.T) { - expectedErr := strings.SplitN(newCmixErr, "%", 2)[0] - prng := rand.New(rand.NewSource(42)) - m, g := newTestManager(prng, t) - - _, err := m.newMessages(g, make([]byte, 1000), netTime.Now()) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("newMessages() failed to return the expected error."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - // Error path: reader returns an error. func TestGroup_newCmixMsg_SaltReaderError(t *testing.T) { expectedErr := strings.SplitN(saltReadErr, "%", 2)[0] - m := &Manager{} + m, _ := newTestManager(rand.New(rand.NewSource(42)), t) - _, err := newCmixMsg(gs.Group{}, - []byte{}, time.Time{}, group.Member{}, strings.NewReader(""), m.receptionId, m.grp) + _, err := newCmixMsg(gs.Group{ID: id.NewIdFromString("test", id.User, t)}, + []byte{}, time.Time{}, group.Member{}, strings.NewReader(""), m.receptionId, m.services.GetMaxMessageLength()) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("newCmixMsg() failed to return the expected error"+ "\nexpected: %s\nreceived: %+v", expectedErr, err) @@ -102,15 +92,15 @@ func TestGroup_newCmixMsg_InternalMsgSizeError(t *testing.T) { // Create new test Manager and Group prng := rand.New(rand.NewSource(42)) - m, g := newTestManager(prng, t) + m, g := newTestManagerWithStore(prng, 10, 0, nil, nil, t) // Create test parameters - testMsg := make([]byte, 341) + testMsg := make([]byte, 1500) mem := group.Member{ID: id.NewIdFromString("memberID", id.User, t)} // Create cMix message prng = rand.New(rand.NewSource(42)) - _, err := newCmixMsg(g, testMsg, netTime.Now(), mem, prng, m.receptionId, m.grp) + _, err := newCmixMsg(g, testMsg, netTime.Now(), mem, prng, m.receptionId, m.services.GetMaxMessageLength()) if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("newCmixMsg() failed to return the expected error"+ "\nexpected: %s\nreceived: %+v", expectedErr, err) diff --git a/groupChat/utils_test.go b/groupChat/utils_test.go index 4854e56500b3964a01628475f8e9a17d69a86264..002dbb17f28fb5170d511b66126b926280beeb53 100644 --- a/groupChat/utils_test.go +++ b/groupChat/utils_test.go @@ -9,6 +9,11 @@ package groupChat import ( "encoding/base64" + "math/rand" + "sync" + "testing" + "time" + "github.com/cloudflare/circl/dh/sidh" "github.com/pkg/errors" "gitlab.com/elixxir/client/catalog" @@ -34,10 +39,6 @@ import ( "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/ndf" "gitlab.com/xx_network/primitives/netTime" - "math/rand" - "sync" - "testing" - "time" ) // newTestManager creates a new Manager for testing. @@ -47,19 +48,15 @@ func newTestManager(rng *rand.Rand, t *testing.T) (*Manager, gs.Group) { rng: fastRNG.NewStreamGenerator(1000, 10, csprng.NewSystemRNG), grp: getGroup(), services: newTestNetworkManager(0, t), - e2e: &testE2eManager{ - e2eMessages: []testE2eMessage{}, - errSkip: 0, - sendErr: 0, - }, + e2e: newTestE2eManager(randCycInt(rng)), } user := group.Member{ ID: m.receptionId, - DhKey: randCycInt(rng), + DhKey: m.e2e.GetHistoricalDHPubkey(), } g := newTestGroupWithUser(m.grp, user.ID, user.DhKey, - randCycInt(rng), rng, t) + m.e2e.GetHistoricalDHPrivkey(), rng, t) gStore, err := gs.NewStore(versioned.NewKV(make(ekv.Memstore)), user) if err != nil { t.Fatalf("Failed to create new group store: %+v", err) @@ -84,14 +81,15 @@ func newTestManagerWithStore(rng *rand.Rand, numGroups int, sendErr int, services: newTestNetworkManager(sendErr, t), e2e: &testE2eManager{ e2eMessages: []testE2eMessage{}, - errSkip: 0, sendErr: sendErr, + grp: getGroup(), + dhPubKey: randCycInt(rng), + partners: make(map[id.ID]partner.Manager), }, } - user := group.Member{ ID: m.receptionId, - DhKey: randCycInt(rng), + DhKey: m.e2e.GetHistoricalDHPubkey(), } gStore, err := gs.NewStore(versioned.NewKV(make(ekv.Memstore)), user) @@ -111,6 +109,16 @@ func newTestManagerWithStore(rng *rand.Rand, numGroups int, sendErr int, return m, g } +func newTestE2eManager(dhPubKey *cyclic.Int) *testE2eManager { + return &testE2eManager{ + e2eMessages: []testE2eMessage{}, + errSkip: 0, + grp: getGroup(), + dhPubKey: dhPubKey, + partners: make(map[id.ID]partner.Manager), + } +} + // getMembership returns a Membership with random members for testing. func getMembership(size int, uid *id.ID, pubKey *cyclic.Int, grp *cyclic.Group, prng *rand.Rand, t *testing.T) group.Membership { @@ -219,6 +227,7 @@ func newTestNetworkManager(sendErr int, t *testing.T) GroupCmix { return &testNetworkManager{ receptionMessages: [][]format.Message{}, sendMessages: [][]cmix.TargetedCmixMessage{}, + grp: getGroup(), sendErr: sendErr, } } @@ -226,8 +235,11 @@ func newTestNetworkManager(sendErr int, t *testing.T) GroupCmix { // testE2eManager is a test implementation of NetworkManager interface. type testE2eManager struct { e2eMessages []testE2eMessage + partners map[id.ID]partner.Manager errSkip int sendErr int + dhPubKey *cyclic.Int + grp *cyclic.Group sync.RWMutex } @@ -237,20 +249,27 @@ type testE2eMessage struct { } func (tnm *testE2eManager) AddPartner(partnerID *id.ID, partnerPubKey, myPrivKey *cyclic.Int, - partnerSIDHPubKey *sidh.PublicKey, mySIDHPrivKey *sidh.PrivateKey, sendParams, receiveParams session.Params) (partner.Manager, error) { - return nil, nil + partnerSIDHPubKey *sidh.PublicKey, mySIDHPrivKey *sidh.PrivateKey, + sendParams, receiveParams session.Params) (partner.Manager, error) { + + testPartner := partner.NewTestManager(partnerID, partnerPubKey, myPrivKey, &testing.T{}) + tnm.partners[*partnerID] = testPartner + return testPartner, nil } func (tnm *testE2eManager) GetPartner(partnerID *id.ID) (partner.Manager, error) { - return nil, nil + if partner, ok := tnm.partners[*partnerID]; ok { + return partner, nil + } + return nil, errors.New("Unable to find partner") } func (tnm *testE2eManager) GetHistoricalDHPubkey() *cyclic.Int { - panic("implement me") + return tnm.dhPubKey } func (tnm *testE2eManager) GetHistoricalDHPrivkey() *cyclic.Int { - panic("implement me") + return tnm.dhPubKey } func (tnm *testE2eManager) SendE2E(mt catalog.MessageType, recipient *id.ID, payload []byte, params clientE2E.Params) ([]id.Round, e2e.MessageID, time.Time, error) { @@ -273,11 +292,11 @@ func (tnm *testE2eManager) SendE2E(mt catalog.MessageType, recipient *id.ID, pay } func (*testE2eManager) RegisterListener(user *id.ID, messageType catalog.MessageType, newListener receive.Listener) receive.ListenerID { - panic("implement me") + return receive.ListenerID{} } func (*testE2eManager) AddService(tag string, processor message.Processor) error { - panic("implement me") + return nil } func (*testE2eManager) GetDefaultHistoricalDHPubkey() *cyclic.Int { @@ -300,9 +319,14 @@ type testNetworkManager struct { sendMessages [][]cmix.TargetedCmixMessage errSkip int sendErr int + grp *cyclic.Group sync.RWMutex } +func (tnm *testNetworkManager) GetMaxMessageLength() int { + return format.NewMessage(tnm.grp.GetP().ByteLen()).ContentsSize() +} + func (tnm *testNetworkManager) SendMany(messages []cmix.TargetedCmixMessage, p cmix.CMIXParams) (id.Round, []ephemeral.Id, error) { if tnm.sendErr == 1 { return 0, nil, errors.New("SendManyCMIX error") @@ -315,7 +339,7 @@ func (tnm *testNetworkManager) SendMany(messages []cmix.TargetedCmixMessage, p c receiveMessages := []format.Message{} for _, msg := range messages { - receiveMsg := format.Message{} + receiveMsg := format.NewMessage(tnm.grp.GetP().ByteLen()) receiveMsg.SetMac(msg.Mac) receiveMsg.SetContents(msg.Payload) receiveMsg.SetKeyFP(msg.Fingerprint) @@ -326,17 +350,17 @@ func (tnm *testNetworkManager) SendMany(messages []cmix.TargetedCmixMessage, p c } func (*testNetworkManager) AddService(clientID *id.ID, newService message.Service, response message.Processor) { - panic("implement me") + return } func (*testNetworkManager) DeleteService(clientID *id.ID, toDelete message.Service, processor message.Processor) { - panic("implement me") + return } type dummyEventMgr struct{} func (d *dummyEventMgr) Report(int, string, string, string) {} -func (tnm *testNetworkManager) GetEventManager() event.Manager { +func (tnm *testNetworkManager) GetEventManager() event.Reporter { return &dummyEventMgr{} } diff --git a/interfaces/fileTransfer.go b/interfaces/fileTransfer.go deleted file mode 100644 index dcd5b29b6246bfcef6a43db20b44d2bcfe30f068..0000000000000000000000000000000000000000 --- a/interfaces/fileTransfer.go +++ /dev/null @@ -1,140 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package interfaces - -import ( - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/primitives/id" - "strconv" - "time" -) - -// SentProgressCallback is a callback function that tracks the progress of -// sending a file. -type SentProgressCallback func(completed bool, sent, arrived, total uint16, - t FilePartTracker, err error) - -// ReceivedProgressCallback is a callback function that tracks the progress of -// receiving a file. -type ReceivedProgressCallback func(completed bool, received, total uint16, - t FilePartTracker, err error) - -// ReceiveCallback is a callback function that notifies the receiver of an -// incoming file transfer. -type ReceiveCallback func(tid ftCrypto.TransferID, fileName, fileType string, - sender *id.ID, size uint32, preview []byte) - -// FileTransfer facilities the sending and receiving of large file transfers. -// It allows for progress tracking of both inbound and outbound transfers. -type FileTransfer interface { - // Send sends a file to the recipient. The sender must have an E2E - // relationship with the recipient. - // The retry float is the total amount of data to send relative to the data - // size. Data will be resent on error and will resend up to [(1 + retry) * - // fileSize]. - // The preview stores a preview of the data (such as a thumbnail) and is - // capped at 4 kB in size. - // Returns a unique transfer ID used to identify the transfer. - Send(fileName, fileType string, fileData []byte, recipient *id.ID, - retry float32, preview []byte, progressCB SentProgressCallback, - period time.Duration) (ftCrypto.TransferID, error) - - // RegisterSentProgressCallback allows for the registration of a callback to - // track the progress of an individual sent file transfer. The callback will - // be called immediately when added to report the current status of the - // transfer. It will then call every time a file part is sent, a file part - // arrives, the transfer completes, or an error occurs. It is called at most - // once ever period, which means if events occur faster than the period, - // then they will not be reported and instead the progress will be reported - // once at the end of the period. - RegisterSentProgressCallback(tid ftCrypto.TransferID, - progressCB SentProgressCallback, period time.Duration) error - - // Resend resends a file if sending fails. Returns an error if CloseSend - // was already called or if the transfer did not run out of retries. - Resend(tid ftCrypto.TransferID) error - - // CloseSend deletes a file from the internal storage once a transfer has - // completed or reached the retry limit. Returns an error if the transfer - // has not run out of retries. - CloseSend(tid ftCrypto.TransferID) error - - // Receive returns the full file on the completion of the transfer as - // reported by a registered ReceivedProgressCallback. It deletes internal - // references to the data and unregisters any attached progress callback. - // Returns an error if the transfer is not complete, the full file cannot be - // verified, or if the transfer cannot be found. - Receive(tid ftCrypto.TransferID) ([]byte, error) - - // RegisterReceivedProgressCallback allows for the registration of a - // callback to track the progress of an individual received file transfer. - // The callback will be called immediately when added to report the current - // status of the transfer. It will then call every time a file part is - // received, the transfer completes, or an error occurs. It is called at - // most once ever period, which means if events occur faster than the - // period, then they will not be reported and instead the progress will be - // reported once at the end of the period. - // Once the callback reports that the transfer has completed, the recipient - // can get the full file by calling Receive. - RegisterReceivedProgressCallback(tid ftCrypto.TransferID, - progressCB ReceivedProgressCallback, period time.Duration) error -} - -// FilePartTracker tracks the status of each file part in a sent or received -// file transfer. -type FilePartTracker interface { - // GetPartStatus returns the status of the file part with the given part - // number. The possible values for the status are: - // 0 = unsent - // 1 = sent (sender has sent a part, but it has not arrived) - // 2 = arrived (sender has sent a part, and it has arrived) - // 3 = received (receiver has received a part) - GetPartStatus(partNum uint16) FpStatus - - // GetNumParts returns the total number of file parts in the transfer. - GetNumParts() uint16 -} - -// FpStatus is the file part status and indicates the status of individual file -// parts in a file transfer. -type FpStatus int - -// Possible values for FpStatus. -const ( - // FpUnsent indicates that the file part has not been sent - FpUnsent FpStatus = iota - - // FpSent indicates that the file part has been sent (sender has sent a - // part, but it has not arrived) - FpSent - - // FpArrived indicates that the file part has arrived (sender has sent a - // part, and it has arrived) - FpArrived - - // FpReceived indicates that the file part has been received (receiver has - // received a part) - FpReceived -) - -// String returns the string representing of the FpStatus. This functions -// satisfies the fmt.Stringer interface. -func (fps FpStatus) String() string { - switch fps { - case FpUnsent: - return "unsent" - case FpSent: - return "sent" - case FpArrived: - return "arrived" - case FpReceived: - return "received" - default: - return "INVALID FpStatus: " + strconv.Itoa(int(fps)) - } -} diff --git a/interfaces/networkManager.go b/interfaces/networkManager.go index b5104d1018151632d4d1ae41767378333990d1bd..1315c7d710571d694f6ca77d7b0222e2e50aa0f9 100644 --- a/interfaces/networkManager.go +++ b/interfaces/networkManager.go @@ -8,13 +8,14 @@ package interfaces import ( + "time" + "gitlab.com/elixxir/comms/network" "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/primitives/ndf" - "time" - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/cmix" + "gitlab.com/elixxir/client/cmix/message" "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/primitives/format" @@ -29,24 +30,26 @@ type NetworkManager interface { // Only one follower may run at a time. Follow(report ClientErrorReport) (stoppable.Stoppable, error) - /*===Sending==============================================================*/ + /*===Sending==========================================================*/ - // SendCMIX sends a "raw" CMIX message payload to the provided recipient. - // Returns the round ID of the round the payload was sent or an error - // if it fails. - SendCMIX(message format.Message, recipient *id.ID, p params.CMIX) ( + // SendCMIX sends a "raw" CMIX message payload to the provided + // recipient. Returns the round ID of the round the payload + // was sent or an error if it fails. + SendCMIX(message format.Message, recipient *id.ID, p cmix.Params) ( id.Round, ephemeral.Id, error) - // SendManyCMIX sends many "raw" cMix message payloads to each of the provided - // recipients. Used to send messages in group chats. Metadata is NOT as well - // protected with this call and can leak data about yourself. Should be - // replaced with multiple uses of SendCmix in most cases. Returns the round - // ID of the round the payload was sent or an error if it fails. + // SendManyCMIX sends many "raw" cMix message payloads to each + // of the provided recipients. Used to send messages in group + // chats. Metadata is NOT as well protected with this call and + // can leak data about yourself. Should be replaced with + // multiple uses of SendCmix in most cases. Returns the round + // ID of the round the payload was sent or an error if it + // fails. // WARNING: Potentially Unsafe - SendManyCMIX(messages []message.TargetedCmixMessage, p params.CMIX) ( + SendManyCMIX(messages []cmix.TargetedCmixMessage, p cmix.Params) ( id.Round, []ephemeral.Id, error) - /*===Message Reception====================================================*/ + /*===Message Reception================================================*/ /* Identities are all network identites which the client is currently trying to pick up message on. An identity must be added to receive messages, fake ones will be used to poll the network @@ -60,23 +63,24 @@ type NetworkManager interface { // RemoveIdentity removes a currently tracked identity. RemoveIdentity(id *id.ID) - /* Fingerprints are the primary mechanism of identifying a picked up - message over cMix. They are a unique one time use 255 bit vector generally - associated with a specific encryption key, but can be used for an - alternative protocol.When registering a fingerprint, a MessageProcessor - is registered to handle the message.*/ + /* Fingerprints are the primary mechanism of identifying a + picked up message over cMix. They are a unique one time use + 255 bit vector generally associated with a specific encryption + key, but can be used for an alternative protocol.When + registering a fingerprint, a MessageProcessor is registered to + handle the message.*/ // AddFingerprint - Adds a fingerprint which will be handled by a // specific processor for messages received by the given identity AddFingerprint(identity *id.ID, fingerprint format.Fingerprint, - mp MessageProcessor) error + mp message.Processor) error - // DeleteFingerprint deletes a single fingerprint associated with the given - // identity if it exists + // DeleteFingerprint deletes a single fingerprint associated + // with the given identity if it exists DeleteFingerprint(identity *id.ID, fingerprint format.Fingerprint) - // DeleteClientFingerprints deletes al fingerprint associated with the given - // identity if it exists + // DeleteClientFingerprints deletes al fingerprint associated + // with the given identity if it exists DeleteClientFingerprints(identity *id.ID) /* trigger - predefined hash based tags appended to all cMix messages @@ -87,60 +91,72 @@ type NetworkManager interface { notifications system, or can be used to implement custom non fingerprint processing of payloads. I.E. key negotiation, broadcast negotiation - A tag is appended to the message of the format tag = H(H(messageContents), - preimage) and trial hashing is used to determine if a message adheres to a - tag. - WARNING: If a preimage is known by an adversary, they can determine which - messages are for the client on reception (which is normally hidden due to - collision between ephemeral IDs. + A tag is appended to the message of the format tag = + H(H(messageContents), preimage) and trial hashing is used to + determine if a message adheres to a tag. - Due to the extra overhead of trial hashing, triggers are processed after fingerprints. - If a fingerprint match occurs on the message, triggers will not be handled. + WARNING: If a preimage is known by an adversary, they can + determine which messages are for the client on reception + (which is normally hidden due to collision between ephemeral + IDs. - Triggers are address to the session. When starting a new client, all triggers must be - re-added before StartNetworkFollower is called. + Due to the extra overhead of trial hashing, triggers are + processed after fingerprints. If a fingerprint match occurs + on the message, triggers will not be handled. + + Triggers are address to the session. When starting a new + client, all triggers must be re-added before + StartNetworkFollower is called. */ - // AddTrigger - Adds a trigger which can call a message handing function or - // be used for notifications. Multiple triggers can be registered for the - // same preimage. + // AddTrigger - Adds a trigger which can call a message + // handing function or be used for notifications. Multiple + // triggers can be registered for the same preimage. // preimage - the preimage which is triggered on - // type - a descriptive string of the trigger. Generally used in notifications - // source - a byte buffer of related data. Generally used in notifications. + // type - a descriptive string of the trigger. Generally + // used in notifications + // source - a byte buffer of related data. Generally used in + // notifications. // Example: Sender ID - AddTrigger(identity *id.ID, newTrigger Trigger, response MessageProcessor) + AddTrigger(identity *id.ID, newTrigger message.Service, + response message.Processor) // DeleteTrigger - If only a single response is associated with the // preimage, the entire preimage is removed. If there is more than one // response, only the given response is removed if nil is passed in for // response, all triggers for the preimage will be removed - DeleteTrigger(identity *id.ID, preimage Preimage, response MessageProcessor) error + DeleteTrigger(identity *id.ID, preimage Preimage, + response message.Processor) error - // DeleteClientTriggers - deletes all triggers assoseated with the given identity + // DeleteClientTriggers - deletes all triggers assoseated with + // the given identity DeleteClientTriggers(identity *id.ID) - // TrackTriggers - Registers a callback which will get called every time triggers change. + // TrackTriggers - Registers a callback which will get called + // every time triggers change. // It will receive the triggers list every time it is modified. // Will only get callbacks while the Network Follower is running. // Multiple trackTriggers can be registered - TrackTriggers(TriggerTracker) + TrackServices(message.ServicesTracker) /* In inProcess */ - // it is possible to receive a message over cMix before the fingerprints or - // triggers are registered. As a result, when handling fails, messages are - // put in the inProcess que for a set number of retries. + // it is possible to receive a message over cMix before the + // fingerprints or triggers are registered. As a result, when + // handling fails, messages are put in the inProcess que for a + // set number of retries. // CheckInProgressMessages - retry processing all messages in check in // progress messages. Call this after adding fingerprints or triggers //while the follower is running. CheckInProgressMessages() - /*===Nodes================================================================*/ - /* Keys must be registed with nodes in order to send messages throug them. - this process is in general automatically handled by the Network Manager*/ + /*===Nodes============================================================*/ + /* Keys must be registed with nodes in order to send messages + throug them. this process is in general automatically handled + by the Network Manager*/ - // HasNode can be used to determine if a keying relationship exists with a - // node. + // HasNode can be used to determine if a keying relationship + // exists with a node. HasNode(nid *id.ID) bool // NumRegisteredNodes Returns the total number of nodes we have a keying @@ -150,7 +166,7 @@ type NetworkManager interface { // Triggers the generation of a keying relationship with a given node TriggerNodeRegistration(nid *id.ID) - /*===Historical Rounds====================================================*/ + /*===Historical Rounds================================================*/ /* A complete set of round info is not kept on the client, and sometimes the network will need to be queried to get round info. Historical rounds is the system internal to the Network Manager to do this. @@ -158,52 +174,59 @@ type NetworkManager interface { // LookupHistoricalRound - looks up the passed historical round on the // network - LookupHistoricalRound(rid id.Round, callback func(info *mixmessages.RoundInfo, - success bool)) error + LookupHistoricalRound(rid id.Round, + callback func(info *mixmessages.RoundInfo, + success bool)) error - /*===Sender===============================================================*/ - /* The sender handles sending comms to the network. It tracks connections to - gateways and handles proxying to gateways for targeted comms. It can be - used externally to contact gateway directly, bypassing the majority of - the network package*/ + /*===Sender===========================================================*/ + /* The sender handles sending comms to the network. It tracks + connections to gateways and handles proxying to gateways for + targeted comms. It can be used externally to contact gateway + directly, bypassing the majority of the network package*/ // SendToAny can be used to send the comm to any gateway in the network. - SendToAny(sendFunc func(host *connect.Host) (interface{}, error), stop *stoppable.Single) (interface{}, error) + SendToAny(sendFunc func(host *connect.Host) (interface{}, error), + stop *stoppable.Single) (interface{}, error) // SendToPreferred sends to a specific gateway, doing so through another // gateway as a proxy if not directly connected. SendToPreferred(targets []*id.ID, sendFunc func(host *connect.Host, target *id.ID, timeout time.Duration) (interface{}, error), - stop *stoppable.Single, timeout time.Duration) (interface{}, error) + stop *stoppable.Single, timeout time.Duration) (interface{}, + error) - // SetGatewayFilter sets a function which will be used to filter gateways - // before connecting. + // SetGatewayFilter sets a function which will be used to + // filter gateways before connecting. SetGatewayFilter(f func(map[id.ID]int, *ndf.NetworkDefinition) map[id.ID]int) - // GetHostParams - returns the host params used when connectign to gateways + // GetHostParams - returns the host params used when + // connectign to gateways GetHostParams() connect.HostParams - /*===Address Space========================================================*/ - // The network compasses identities into a smaller address space to cause - // collisions and hide the actual recipient of messages. These functions - // allow for the tracking of this addresses space. In general, address space - // issues are completely handled by the network package + /*===Address Space====================================================*/ + // The network compasses identities into a smaller address + // space to cause collisions and hide the actual recipient of + // messages. These functions allow for the tracking of this + // addresses space. In general, address space issues are + // completely handled by the network package - // GetAddressSpace GetAddressSize returns the current address size of IDs. Blocks until an - // address size is known. + // GetAddressSpace GetAddressSize returns the current address + // size of IDs. Blocks until an address size is known. GetAddressSpace() uint8 - // RegisterAddressSpaceNotification returns a channel that will trigger for - // every address space size update. The provided tag is the unique ID for - // the channel. Returns an error if the tag is already used. + // RegisterAddressSpaceNotification returns a channel that + // will trigger for every address space size update. The + // provided tag is the unique ID for the channel. Returns an + // error if the tag is already used. RegisterAddressSpaceNotification(tag string) (chan uint8, error) - // UnregisterAddressSpaceNotification stops broadcasting address space size - // updates on the channel with the specified tag. + // UnregisterAddressSpaceNotification stops broadcasting + // address space size updates on the channel with the + // specified tag. UnregisterAddressSpaceNotification(tag string) - /*===Accessors============================================================*/ + /*===Accessors========================================================*/ // GetInstance returns the network instance object, which tracks the // state of the network @@ -218,8 +241,3 @@ type NetworkManager interface { } type Preimage [32]byte - -type ClientErrorReport func(source, message, trace string) - -//for use in key exchange which needs to be callable inside of network -///type SendE2E func(m message.Send, p params.E2E, stop *stoppable.Single) ([]id.Round, e2e.MessageID, time.Time, error) diff --git a/interfaces/switchboard.go b/interfaces/switchboard.go deleted file mode 100644 index 2566bfb51de98fe6f5ce6abf0b1038b752f78a74..0000000000000000000000000000000000000000 --- a/interfaces/switchboard.go +++ /dev/null @@ -1,75 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package interfaces - -import ( - "gitlab.com/elixxir/client/interfaces/message" - "gitlab.com/elixxir/client/switchboard" - "gitlab.com/xx_network/primitives/id" -) - -// public switchboard interface which only allows registration and does not -// allow speaking messages -type Switchboard interface { - // Registers a new listener. Returns the ID of the new listener. - // Keep this around if you want to be able to delete the listener later. - // - // name is used for debug printing and not checked for uniqueness - // - // user: 0 for all, or any user ID to listen for messages from a particular - // user. 0 can be id.ZeroUser or id.ZeroID - // messageType: 0 for all, or any message type to listen for messages of - // that type. 0 can be Receive.AnyType - // newListener: something implementing the Listener interface. Do not - // pass nil to this. - // - // If a message matches multiple listeners, all of them will hear the - // message. - RegisterListener(user *id.ID, messageType message.Type, - newListener switchboard.Listener) switchboard.ListenerID - - // Registers a new listener built around the passed function. - // Returns the ID of the new listener. - // Keep this around if you want to be able to delete the listener later. - // - // name is used for debug printing and not checked for uniqueness - // - // user: 0 for all, or any user ID to listen for messages from a particular - // user. 0 can be id.ZeroUser or id.ZeroID - // messageType: 0 for all, or any message type to listen for messages of - // that type. 0 can be Receive.AnyType - // newListener: a function implementing the ListenerFunc function type. - // Do not pass nil to this. - // - // If a message matches multiple listeners, all of them will hear the - // message. - RegisterFunc(name string, user *id.ID, messageType message.Type, - newListener switchboard.ListenerFunc) switchboard.ListenerID - - // Registers a new listener built around the passed channel. - // Returns the ID of the new listener. - // Keep this around if you want to be able to delete the listener later. - // - // name is used for debug printing and not checked for uniqueness - // - // user: 0 for all, or any user ID to listen for messages from a particular - // user. 0 can be id.ZeroUser or id.ZeroID - // messageType: 0 for all, or any message type to listen for messages of - // that type. 0 can be Receive.AnyType - // newListener: an item channel. - // Do not pass nil to this. - // - // If a message matches multiple listeners, all of them will hear the - // message. - RegisterChannel(name string, user *id.ID, messageType message.Type, - newListener chan message.Receive) switchboard.ListenerID - - // Unregister removes the listener with the specified ID so it will no - // longer get called - Unregister(listenerID switchboard.ListenerID) -} diff --git a/interfaces/user/proto.go b/interfaces/user/proto.go deleted file mode 100644 index 657c2e714d70f06b88c6d0866ec2bbca5015f24e..0000000000000000000000000000000000000000 --- a/interfaces/user/proto.go +++ /dev/null @@ -1,29 +0,0 @@ -package user - -import ( - "gitlab.com/elixxir/crypto/cyclic" - "gitlab.com/xx_network/crypto/signature/rsa" - "gitlab.com/xx_network/primitives/id" -) - -type Proto struct { - //General Identity - TransmissionID *id.ID - TransmissionSalt []byte - TransmissionRSA *rsa.PrivateKey - ReceptionID *id.ID - ReceptionSalt []byte - ReceptionRSA *rsa.PrivateKey - Precanned bool - // Timestamp in which user has registered with the network - RegistrationTimestamp int64 - - RegCode string - - TransmissionRegValidationSig []byte - ReceptionRegValidationSig []byte - - //e2e Identity - E2eDhPrivateKey *cyclic.Int - E2eDhPublicKey *cyclic.Int -} diff --git a/storage/fileTransfer/partInfo.go b/storage/fileTransfer/partInfo.go deleted file mode 100644 index 1502921d10395f76a67b1a36397dd1dfb08f8fc0..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/partInfo.go +++ /dev/null @@ -1,75 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "encoding/binary" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "strconv" -) - -// partInfo contains the transfer ID and fingerprint number for a file part. -type partInfo struct { - id ftCrypto.TransferID - fpNum uint16 -} - -// newPartInfo generates a new partInfo with the specified transfer ID and -// fingerprint number for a file part. -func newPartInfo(tid ftCrypto.TransferID, fpNum uint16) *partInfo { - pi := &partInfo{ - id: tid, - fpNum: fpNum, - } - - return pi -} - -// marshal serializes the partInfo into a byte slice. -func (pi *partInfo) marshal() []byte { - // Construct the buffer - buff := bytes.NewBuffer(nil) - buff.Grow(ftCrypto.TransferIdLength + 2) - - // Write the transfer ID to the buffer - buff.Write(pi.id.Bytes()) - - // Write the fingerprint number to the buffer - b := make([]byte, 2) - binary.LittleEndian.PutUint16(b, pi.fpNum) - buff.Write(b) - - // Return the serialized data - return buff.Bytes() -} - -// unmarshalPartInfo deserializes the byte slice into a partInfo. -func unmarshalPartInfo(b []byte) *partInfo { - buff := bytes.NewBuffer(b) - - // Read transfer ID from the buffer - transferIDBytes := buff.Next(ftCrypto.TransferIdLength) - transferID := ftCrypto.UnmarshalTransferID(transferIDBytes) - - // Read the fingerprint number from the buffer - fpNumBytes := buff.Next(2) - fpNum := binary.LittleEndian.Uint16(fpNumBytes) - - // Return the reconstructed partInfo - return &partInfo{ - id: transferID, - fpNum: fpNum, - } -} - -// String prints a string representation of partInfo. This functions satisfies -// the fmt.Stringer interface. -func (pi *partInfo) String() string { - return "{id:" + pi.id.String() + " fpNum:" + strconv.Itoa(int(pi.fpNum)) + "}" -} diff --git a/storage/fileTransfer/partInfo_test.go b/storage/fileTransfer/partInfo_test.go deleted file mode 100644 index b0ce83a3e10a4c5c9f309313b69b9a4602904cfc..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/partInfo_test.go +++ /dev/null @@ -1,74 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// -package fileTransfer - -import ( - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "reflect" - "testing" -) - -// Tests that newPartInfo creates the expected partInfo. -func Test_newPartInfo(t *testing.T) { - // Created expected partInfo - expected := &partInfo{ - id: ftCrypto.UnmarshalTransferID([]byte("TestTransferID")), - fpNum: 16, - } - - // Create new partInfo - received := newPartInfo(expected.id, expected.fpNum) - - if !reflect.DeepEqual(expected, received) { - t.Fatalf("New partInfo does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expected, received) - } -} - -// Tests that a partInfo that is marshalled with partInfo.marshal and -// unmarshalled with unmarshalPartInfo matches the original. -func Test_partInfo_marshal_unmarshalPartInfo(t *testing.T) { - expectedPI := newPartInfo( - ftCrypto.UnmarshalTransferID([]byte("TestTransferID")), 25) - - piBytes := expectedPI.marshal() - receivedPI := unmarshalPartInfo(piBytes) - - if !reflect.DeepEqual(expectedPI, receivedPI) { - t.Errorf("Marshalled and unmarshalled partInfo does not match original."+ - "\nexpected: %+v\nreceived: %+v", expectedPI, receivedPI) - } -} - -// Consistency test of partInfo.String. -func Test_partInfo_String(t *testing.T) { - prng := NewPrng(42) - expectedStrings := []string{ - "{id:U4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI= fpNum:0}", - "{id:39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g= fpNum:1}", - "{id:CD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw= fpNum:2}", - "{id:uoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44= fpNum:3}", - "{id:GwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM= fpNum:4}", - "{id:rnvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA= fpNum:5}", - "{id:ceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE= fpNum:6}", - "{id:SYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE= fpNum:7}", - "{id:NhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI= fpNum:8}", - "{id:kM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog= fpNum:9}", - "{id:XTJg8d6XgoPUoJo2+WwglBdG4+1NpkaprotPp7T8OiA= fpNum:10}", - "{id:uvoade0yeoa4sMOa8c/Ss7USGep5Uzq/RI0sR50yYHU= fpNum:11}", - } - - for i, expected := range expectedStrings { - tid, _ := ftCrypto.NewTransferID(prng) - pi := newPartInfo(tid, uint16(i)) - - if expected != pi.String() { - t.Errorf("partInfo #%d string does not match expected."+ - "\nexpected: %s\nreceived: %s", i, expected, pi.String()) - } - } -} diff --git a/storage/fileTransfer/partStore.go b/storage/fileTransfer/partStore.go deleted file mode 100644 index 59018ab63b506c8fb0f7d046aeea900ec3d41c6a..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/partStore.go +++ /dev/null @@ -1,285 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "encoding/binary" - "github.com/pkg/errors" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/xx_network/primitives/netTime" - "strconv" - "sync" -) - -// Storage keys and versions. -const ( - partsStoreVersion = 0 - partsStoreKey = "FileTransferPart" - partsListVersion = 0 - partsListKey = "FileTransferList" -) - -// Error messages. -const ( - loadPartListErr = "failed to get parts list from storage: %+v" - loadPartsErr = "failed to load part #%d from storage: %+v" - savePartsErr = "failed to save part #%d to storage: %+v" -) - -// partStore stores the file parts in memory/storage. -type partStore struct { - parts map[uint16][]byte // File parts, keyed on their number in order - numParts uint16 // Number of parts in full file - kv *versioned.KV - mux sync.RWMutex -} - -// newPartStore generates a new empty partStore and saves it to storage. -func newPartStore(kv *versioned.KV, numParts uint16) (*partStore, error) { - // Construct empty partStore of the specified size - ps := &partStore{ - parts: make(map[uint16][]byte, numParts), - numParts: numParts, - kv: kv, - } - - // Save to storage - return ps, ps.save() -} - -// newPartStore generates a new empty partStore and saves it to storage. -func newPartStoreFromParts(kv *versioned.KV, parts ...[]byte) (*partStore, - error) { - - // Construct empty partStore of the specified size - ps := &partStore{ - parts: partSliceToMap(parts...), - numParts: uint16(len(parts)), - kv: kv, - } - - // Save to storage - return ps, ps.save() -} - -// addPart adds a file part to the list of parts, saves it to storage, and -// regenerates the list of all file parts and saves it to storage. -func (ps *partStore) addPart(part []byte, partNum uint16) error { - ps.mux.Lock() - defer ps.mux.Unlock() - - ps.parts[partNum] = make([]byte, len(part)) - copy(ps.parts[partNum], part) - - err := ps.savePart(partNum) - if err != nil { - return err - } - - return ps.saveList() -} - -// getPart returns the part at the given part number. -func (ps *partStore) getPart(partNum uint16) ([]byte, bool) { - ps.mux.Lock() - defer ps.mux.Unlock() - - part, exists := ps.parts[partNum] - newPart := make([]byte, len(part)) - copy(newPart, part) - - return newPart, exists -} - -// getFile returns all file parts concatenated into a single file. Returns the -// entire file as a byte slice and the number of parts missing. If the int is 0, -// then no parts are missing and the returned file is complete. -func (ps *partStore) getFile() ([]byte, int) { - ps.mux.Lock() - defer ps.mux.Unlock() - - // get the length of one of the parts (all parts should be the same size) - partLength := 0 - for _, part := range ps.parts { - partLength = len(part) - break - } - - // Create new empty buffer of the size of the whole file - buff := bytes.NewBuffer(nil) - buff.Grow(int(ps.numParts) * partLength) - - // Loop through the map in order and add each file part that exists - var missingParts int - for i := uint16(0); i < ps.numParts; i++ { - part, exists := ps.parts[i] - if exists { - buff.Write(part) - } else { - missingParts++ - } - } - - return buff.Bytes(), missingParts -} - -// len returns the number of parts stored. -func (ps *partStore) len() int { - return len(ps.parts) -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// loadPartStore loads all the file parts from storage into memory. -func loadPartStore(kv *versioned.KV) (*partStore, error) { - // get list of saved file parts - vo, err := kv.Get(partsListKey, partsListVersion) - if err != nil { - return nil, errors.Errorf(loadPartListErr, err) - } - - // Unmarshal saved data into a list - numParts, list := unmarshalPartList(vo.Data) - - // Initialize part map - ps := &partStore{ - parts: make(map[uint16][]byte, numParts), - numParts: numParts, - kv: kv, - } - - // Load each part from storage and add to the map - for _, partNum := range list { - vo, err = kv.Get(makePartsKey(partNum), partsStoreVersion) - if err != nil { - return nil, errors.Errorf(loadPartsErr, partNum, err) - } - - ps.parts[partNum] = vo.Data - } - - return ps, nil -} - -// save stores a list of all file parts and the individual parts in storage. -func (ps *partStore) save() error { - ps.mux.Lock() - defer ps.mux.Unlock() - - // Save the individual file parts to storage - for partNum := range ps.parts { - err := ps.savePart(partNum) - if err != nil { - return errors.Errorf(savePartsErr, partNum, err) - } - } - - // Save the part list to storage - return ps.saveList() -} - -// saveList stores the list of all file parts in storage. -func (ps *partStore) saveList() error { - obj := &versioned.Object{ - Version: partsStoreVersion, - Timestamp: netTime.Now(), - Data: ps.marshalList(), - } - - return ps.kv.Set(partsListKey, partsListVersion, obj) -} - -// savePart stores an individual file part to storage. -func (ps *partStore) savePart(partNum uint16) error { - obj := &versioned.Object{ - Version: partsStoreVersion, - Timestamp: netTime.Now(), - Data: ps.parts[partNum], - } - - return ps.kv.Set(makePartsKey(partNum), partsStoreVersion, obj) -} - -// delete removes all the file parts and file list from storage. -func (ps *partStore) delete() error { - ps.mux.Lock() - defer ps.mux.Unlock() - - for partNum := range ps.parts { - err := ps.kv.Delete(makePartsKey(partNum), partsStoreVersion) - if err != nil { - return err - } - } - - return ps.kv.Delete(partsListKey, partsListVersion) -} - -// marshalList creates a list of part numbers that are currently stored and -// returns them as a byte list to be saved to storage. -func (ps *partStore) marshalList() []byte { - // Create new buffer of the correct size - // (numParts (2 bytes) + (2*length of parts)) - buff := bytes.NewBuffer(nil) - buff.Grow(2 + (2 * len(ps.parts))) - - // Write numParts to buffer - b := make([]byte, 2) - binary.LittleEndian.PutUint16(b, ps.numParts) - buff.Write(b) - - for partNum := range ps.parts { - b = make([]byte, 2) - binary.LittleEndian.PutUint16(b, partNum) - buff.Write(b) - } - - return buff.Bytes() -} - -// unmarshalPartList unmarshalls a byte slice into a list of part numbers. -func unmarshalPartList(b []byte) (uint16, []uint16) { - buff := bytes.NewBuffer(b) - - // Read numParts from the buffer - numParts := binary.LittleEndian.Uint16(buff.Next(2)) - - // Initialize the list to the number of saved parts - list := make([]uint16, 0, len(b)/2) - - // Read each uint16 from the buffer and save into the list - for next := buff.Next(2); len(next) == 2; next = buff.Next(2) { - part := binary.LittleEndian.Uint16(next) - list = append(list, part) - } - - return numParts, list -} - -// partSliceToMap converts a slice of file parts, in order, to a map of file -// parts keyed on their part number. -func partSliceToMap(parts ...[]byte) map[uint16][]byte { - // Initialise map to the correct size - partMap := make(map[uint16][]byte, len(parts)) - - // Add each file part to the map - for partNum, part := range parts { - partMap[uint16(partNum)] = make([]byte, len(part)) - copy(partMap[uint16(partNum)], part) - } - - return partMap -} - -// makePartsKey generates the key used to save a part to storage. -func makePartsKey(partNum uint16) string { - return partsStoreKey + strconv.Itoa(int(partNum)) -} diff --git a/storage/fileTransfer/partStore_test.go b/storage/fileTransfer/partStore_test.go deleted file mode 100644 index 597fea4b20a6871cf44310e97cbd3c4f06c8aa59..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/partStore_test.go +++ /dev/null @@ -1,566 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "encoding/binary" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/ekv" - "gitlab.com/elixxir/primitives/format" - "io" - "math/rand" - "reflect" - "sort" - "strings" - "testing" -) - -// Tests that newPartStore produces the expected new partStore and that an empty -// parts list is saved to storage. -func Test_newPartStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - - expectedPS := &partStore{ - parts: make(map[uint16][]byte, numParts), - numParts: numParts, - kv: kv, - } - - ps, err := newPartStore(kv, numParts) - if err != nil { - t.Errorf("newPartStore returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedPS, ps) { - t.Errorf("Returned incorrect partStore.\nexpected: %+v\nreceived: %+v", - expectedPS, ps) - } - - _, err = kv.Get(partsListKey, partsListVersion) - if err != nil { - t.Errorf("Failed to load part list from storage: %+v", err) - } -} - -// Tests that newPartStoreFromParts produces the expected partStore filled with -// the given parts. -func Test_newPartStoreFromParts(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - - // Generate part slice and part map filled with the same data - partSlice := make([][]byte, numParts) - partMap := make(map[uint16][]byte, numParts) - for i := uint16(0); i < numParts; i++ { - b := make([]byte, 32) - prng.Read(b) - - partSlice[i] = b - partMap[i] = b - } - - expectedPS := &partStore{ - parts: partMap, - numParts: uint16(len(partMap)), - kv: kv, - } - - ps, err := newPartStoreFromParts(kv, partSlice...) - if err != nil { - t.Errorf("newPartStoreFromParts returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedPS, ps) { - t.Errorf("Returned incorrect partStore.\nexpected: %+v\nreceived: %+v", - expectedPS, ps) - } - - loadedPS, err := loadPartStore(kv) - if err != nil { - t.Errorf("Failed to load partStore from storage: %+v", err) - } - - if !reflect.DeepEqual(expectedPS, loadedPS) { - t.Errorf("Loaded partStore does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedPS, loadedPS) - } -} - -// Tests that a part added via partStore.addPart can be loaded from memory and -// storage. -func Test_partStore_addPart(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - ps, _ := newRandomPartStore(numParts, kv, prng, t) - - expectedPart := []byte("part data") - partNum := uint16(17) - - err := ps.addPart(expectedPart, partNum) - if err != nil { - t.Errorf("addPart returned an error: %+v", err) - } - - // Check if part is in memory - part, exists := ps.parts[partNum] - if !exists || !bytes.Equal(expectedPart, part) { - t.Errorf("Failed to get part #%d from memory."+ - "\nexpected: %+v\nreceived: %+v", partNum, expectedPart, part) - } - - // Load part from storage - vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) - if err != nil { - t.Errorf("Failed to load part from storage: %+v", err) - } - - // Check that the part loaded from storage is correct - if !bytes.Equal(expectedPart, vo.Data) { - t.Errorf("Part saved to storage unexpected"+ - "\nexpected: %+v\nreceived: %+v", expectedPart, vo.Data) - } -} - -// Tests that partStore.getPart returns the expected part. -func Test_partStore_getPart(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - ps, _ := newRandomPartStore(numParts, kv, prng, t) - - expectedPart := []byte("part data") - partNum := uint16(17) - - err := ps.addPart(expectedPart, partNum) - if err != nil { - t.Errorf("addPart returned an error: %+v", err) - } - - // Check if part is in memory - part, exists := ps.getPart(partNum) - if !exists || !bytes.Equal(expectedPart, part) { - t.Errorf("Failed to get part #%d from memory."+ - "\nexpected: %+v\nreceived: %+v", partNum, expectedPart, part) - } -} - -// Tests that partStore.getFile returns all parts concatenated in order. -func Test_partStore_getFile(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - ps, expectedFile := newRandomPartStore(16, kv, prng, t) - - // Pull data from file - receivedFile, missingParts := ps.getFile() - if missingParts != 0 { - t.Errorf("File has missing parts.\nexpected: %d\nreceived: %d", - 0, missingParts) - } - - // Check correctness of reconstructions - if !bytes.Equal(receivedFile, expectedFile) { - t.Fatalf("Full reconstructed file does not match expected."+ - "\nexpected: %v\nreceived: %v", expectedFile, receivedFile) - } -} - -// Tests that partStore.getFile returns all parts concatenated in order. -func Test_partStore_getFile_MissingPartsError(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - ps, _ := newRandomPartStore(numParts, kv, prng, t) - - // Delete half the parts - for partNum := uint16(0); partNum < numParts; partNum++ { - if partNum%2 == 0 { - delete(ps.parts, partNum) - } - } - - // Pull data from file - _, missingParts := ps.getFile() - if missingParts != 8 { - t.Errorf("Missing incorrect number of parts."+ - "\nexpected: %d\nreceived: %d", 0, missingParts) - } -} - -// Tests that partStore.len returns 0 for a new map and the correct value when -// parts are added -func Test_partStore_len(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - ps, _ := newPartStore(kv, numParts) - - if ps.len() != 0 { - t.Errorf("Length of new partStore is not 0."+ - "\nexpected: %d\nreceived: %d", 0, ps.len()) - } - - addedParts := 5 - for i := 0; i < addedParts; i++ { - _ = ps.addPart([]byte("test"), uint16(i)) - } - - if ps.len() != addedParts { - t.Errorf("Length of new partStore incorrect."+ - "\nexpected: %d\nreceived: %d", addedParts, ps.len()) - } -} - -// Tests that loadPartStore gets the expected partStore from storage. -func Test_loadPartStore(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - expectedPS, _ := newRandomPartStore(16, kv, prng, t) - - err := expectedPS.save() - if err != nil { - t.Errorf("Failed to save parts to storage: %+v", err) - } - - ps, err := loadPartStore(kv) - if err != nil { - t.Errorf("loadPartStore returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedPS, ps) { - t.Errorf("Loaded partStore does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedPS, ps) - } -} - -// Error path: tests that loadPartStore returns the expected error when no part -// list is saved to storage. -func Test_loadPartStore_NoSavedListErr(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadPartListErr, ":")[0] - _, err := loadPartStore(kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadPartStore failed to return the expected error when no "+ - "object is saved in storage.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } -} - -// Error path: tests that loadPartStore returns the expected error when no parts -// are saved to storage. -func Test_loadPartStore_NoPartErr(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadPartsErr, "#")[0] - ps, _ := newRandomPartStore(16, kv, prng, t) - - err := ps.saveList() - if err != nil { - t.Errorf("Failed to save part list to storage: %+v", err) - } - - _, err = loadPartStore(kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadPartStore failed to return the expected error when no "+ - "object is saved in storage.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } -} - -// Tests that partStore.save correctly stores the part list to storage by -// reading it from storage and ensuring all part numbers are present. It then -// makes sure that all the parts are stored in storage. -func Test_partStore_save(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - - numParts := uint16(16) - ps, _ := newRandomPartStore(numParts, kv, prng, t) - - err := ps.save() - if err != nil { - t.Errorf("save returned an error: %+v", err) - } - - vo, err := kv.Get(partsListKey, partsListVersion) - if err != nil { - t.Errorf("Failed to load part list from storage: %+v", err) - } - - numList := make(map[uint16]bool, numParts) - for i := uint16(0); i < numParts; i++ { - numList[i] = true - } - - buff := bytes.NewBuffer(vo.Data[2:]) - for next := buff.Next(2); len(next) == 2; next = buff.Next(2) { - partNum := binary.LittleEndian.Uint16(next) - if !numList[partNum] { - t.Errorf("Part number %d is not a correct part number.", partNum) - } else { - delete(numList, partNum) - } - } - - if len(numList) != 0 { - t.Errorf("File part numbers missing from list: %+v", numList) - } - - loadedNumParts, list := unmarshalPartList(vo.Data) - - if numParts != loadedNumParts { - t.Errorf("Loaded numParts does not match expected."+ - "\nexpected: %d\nrecieved: %d", numParts, loadedNumParts) - } - - for _, partNum := range list { - vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) - if err != nil { - t.Errorf("Failed to load part #%d from storage: %+v", partNum, err) - } - - if !bytes.Equal(ps.parts[partNum], vo.Data) { - t.Errorf("Part data #%d loaded from storage unexpected."+ - "\nexpected: %+v\nreceived: %+v", - partNum, ps.parts[partNum], vo.Data) - } - } -} - -// Tests that partStore.saveList saves the expected list to storage. -func Test_partStore_saveList(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - expectedPS, _ := newRandomPartStore(numParts, kv, prng, t) - - err := expectedPS.saveList() - if err != nil { - t.Errorf("saveList returned an error: %+v", err) - } - - vo, err := kv.Get(partsListKey, partsListVersion) - if err != nil { - t.Errorf("Failed to load part list from storage: %+v", err) - } - - numList := make(map[uint16]bool, numParts) - for i := uint16(0); i < numParts; i++ { - numList[i] = true - } - - buff := bytes.NewBuffer(vo.Data[2:]) - for next := buff.Next(2); len(next) == 2; next = buff.Next(2) { - partNum := binary.LittleEndian.Uint16(next) - if !numList[partNum] { - t.Errorf("Part number %d is not a correct part number.", partNum) - } else { - delete(numList, partNum) - } - } - - if len(numList) != 0 { - t.Errorf("File part numbers missing from list: %+v", numList) - } -} - -// Tests that a part saved via partStore.savePart can be loaded from storage. -func Test_partStore_savePart(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - ps, err := newPartStore(kv, 16) - if err != nil { - t.Fatalf("Failed to create new partStore: %+v", err) - } - - expectedPart := []byte("part data") - partNum := uint16(5) - ps.parts[partNum] = expectedPart - - err = ps.savePart(partNum) - if err != nil { - t.Errorf("savePart returned an error: %+v", err) - } - - // Load part from storage - vo, err := kv.Get(makePartsKey(partNum), partsStoreVersion) - if err != nil { - t.Errorf("Failed to load part from storage: %+v", err) - } - - // Check that the part loaded from storage is correct - if !bytes.Equal(expectedPart, vo.Data) { - t.Errorf("Part saved to storage unexpected"+ - "\nexpected: %+v\nreceived: %+v", expectedPart, vo.Data) - } -} - -// Tests that partStore.delete deletes all stored items. -func Test_partStore_delete(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - ps, _ := newRandomPartStore(numParts, kv, prng, t) - - err := ps.delete() - if err != nil { - t.Errorf("delete returned an error: %+v", err) - } - - _, err = kv.Get(partsListKey, partsListVersion) - if err == nil { - t.Error("Able to load part list from storage when it should have been " + - "deleted.") - } - - for partNum := range ps.parts { - _, err := kv.Get(makePartsKey(partNum), partsStoreVersion) - if err == nil { - t.Errorf("Loaded part #%d from storage when it should have been "+ - "deleted.", partNum) - } - } -} - -// Tests that a list marshalled via partStore.marshalList can be unmarshalled -// with unmarshalPartList to get the original list. -func Test_partStore_marshalList_unmarshalPartList(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - ps, _ := newRandomPartStore(numParts, kv, prng, t) - - expected := make([]uint16, 0, 16) - for partNum := range ps.parts { - expected = append(expected, partNum) - } - - byteList := ps.marshalList() - - loadedNumParts, list := unmarshalPartList(byteList) - - if numParts != loadedNumParts { - t.Errorf("Loaded numParts does not match expected."+ - "\nexpected: %d\nrecieved: %d", numParts, loadedNumParts) - } - - sort.SliceStable(list, func(i, j int) bool { return list[i] < list[j] }) - sort.SliceStable(expected, - func(i, j int) bool { return expected[i] < expected[j] }) - - if !reflect.DeepEqual(expected, list) { - t.Errorf("Failed to marshal and unmarshal part list."+ - "\nexpected: %+v\nreceived: %+v", expected, list) - } -} - -// Tests that partSliceToMap correctly maps all the parts in a slice of parts -// to a map of parts keyed on their part number. -func Test_partSliceToMap(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - - // Create list of file parts with random data - partSlice := make([][]byte, 8) - for i := range partSlice { - partSlice[i] = make([]byte, 16) - prng.Read(partSlice[i]) - } - - // Convert the slice of parts to a map - partMap := partSliceToMap(partSlice...) - - // Check that each part in the map matches a part in the slice - for partNum, slicePart := range partSlice { - // Check that the part exists in the map - mapPart, exists := partMap[uint16(partNum)] - if !exists { - t.Errorf("Part number %d does not exist in the map.", partNum) - } - - // Check that the part in the map is correct - if !bytes.Equal(slicePart, mapPart) { - t.Errorf("Part found in map does not match expected part in slice."+ - "\nexpected: %+v\nreceived: %+v", slicePart, mapPart) - } - - delete(partMap, uint16(partNum)) - } - - // Make sure there are no extra parts in the map - if len(partMap) != 0 { - t.Errorf("Part map contains %d extra parts not in the slice."+ - "\nparts: %+v", len(partMap), partMap) - } -} - -// Tests the consistency of makePartsKey. -func Test_makePartsKey_consistency(t *testing.T) { - expectedStrings := []string{ - partsStoreKey + "0", - partsStoreKey + "1", - partsStoreKey + "2", - partsStoreKey + "3", - } - - for i, expected := range expectedStrings { - key := makePartsKey(uint16(i)) - if key != expected { - t.Errorf("Key #%d does not match expected."+ - "\nexpected: %q\nreceived: %q", i, expected, key) - } - } -} - -// newRandomPartStore creates a new partStore filled with random data. Returns -// the random partStore and a slice of the original file. -func newRandomPartStore(numParts uint16, kv *versioned.KV, prng io.Reader, - t *testing.T) (*partStore, []byte) { - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partSize := partData.GetPartSize() - - ps, err := newPartStore(kv, numParts) - if err != nil { - t.Fatalf("Failed to create new partStore: %+v", err) - } - - fileBuff := bytes.NewBuffer(nil) - fileBuff.Grow(int(numParts) * partSize) - - for partNum := uint16(0); partNum < numParts; partNum++ { - ps.parts[partNum] = make([]byte, partSize) - _, err := prng.Read(ps.parts[partNum]) - if err != nil { - t.Errorf("Failed to generate random part (%d): %+v", partNum, err) - } - fileBuff.Write(ps.parts[partNum]) - } - - return ps, fileBuff.Bytes() -} - -// newRandomPartSlice returns a list of file parts and the file in one piece. -func newRandomPartSlice(numParts uint16, prng io.Reader, t *testing.T) ( - [][]byte, []byte) { - partSize := 64 - fileBuff := bytes.NewBuffer(make([]byte, 0, int(numParts)*partSize)) - partList := make([][]byte, numParts) - for partNum := range partList { - partList[partNum] = make([]byte, partSize) - _, err := prng.Read(partList[partNum]) - if err != nil { - t.Errorf("Failed to generate random part (%d): %+v", partNum, err) - } - fileBuff.Write(partList[partNum]) - } - - return partList, fileBuff.Bytes() -} diff --git a/storage/fileTransfer/receiveFileTransfers.go b/storage/fileTransfer/receiveFileTransfers.go deleted file mode 100644 index 8dcc8649c56f75e4e82727258de0c993fdfe2b39..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receiveFileTransfers.go +++ /dev/null @@ -1,360 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/crypto/csprng" - "gitlab.com/xx_network/primitives/netTime" - "sync" -) - -const ( - receivedFileTransfersStorePrefix = "FileTransferReceivedFileTransfersStore" - receivedFileTransfersStoreKey = "ReceivedFileTransfers" - receivedFileTransfersStoreVersion = 0 -) - -// Error messages for ReceivedFileTransfersStore. -const ( - saveReceivedTransfersListErr = "failed to save list of received items in transfer map to storage: %+v" - loadReceivedTransfersListErr = "failed to load list of received items in transfer map from storage: %+v" - loadReceivedFileTransfersErr = "[FT] Failed to load received transfers from storage: %+v" - - newReceivedTransferErr = "failed to create new received transfer: %+v" - getReceivedTransferErr = "received transfer with ID %s not found" - addTransferNewIdErr = "could not generate new transfer ID: %+v" - noFingerprintErr = "no part found with fingerprint %s" - addPartErr = "failed to add part to transfer %s: %+v" - deleteReceivedTransferErr = "failed to delete received transfer with ID %s from store: %+v" - - // ReceivedFileTransfersStore.load - loadReceivedTransferWarn = "[FT] Failed to load received file transfer %d of %d with ID %s: %v" - loadReceivedTransfersAllErr = "failed to load all %d transfers" -) - -// ReceivedFileTransfersStore contains information for tracking a received -// ReceivedTransfer to its transfer ID. It also maps a received part, partInfo, -// to the message fingerprint. -type ReceivedFileTransfersStore struct { - transfers map[ftCrypto.TransferID]*ReceivedTransfer - info map[format.Fingerprint]*partInfo - mux sync.Mutex - kv *versioned.KV -} - -// NewReceivedFileTransfersStore creates a new ReceivedFileTransfersStore with -// empty maps. -func NewReceivedFileTransfersStore(kv *versioned.KV) ( - *ReceivedFileTransfersStore, error) { - rft := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersStorePrefix), - } - - return rft, rft.saveTransfersList() -} - -// AddTransfer creates a new empty ReceivedTransfer, adds it to the transfers -// map, and adds an entry for each file part fingerprint to the fingerprint map. -func (rft *ReceivedFileTransfersStore) AddTransfer(key ftCrypto.TransferKey, - transferMAC []byte, fileSize uint32, numParts, numFps uint16, - rng csprng.Source) (ftCrypto.TransferID, error) { - - rft.mux.Lock() - defer rft.mux.Unlock() - - // Generate new transfer ID - tid, err := ftCrypto.NewTransferID(rng) - if err != nil { - return tid, errors.Errorf(addTransferNewIdErr, err) - } - - // Generate a new ReceivedTransfer and add it to the map - rft.transfers[tid], err = NewReceivedTransfer( - tid, key, transferMAC, fileSize, numParts, numFps, rft.kv) - if err != nil { - return tid, errors.Errorf(newReceivedTransferErr, err) - } - - // Add part info for each file part to the map - rft.addFingerprints(key, tid, numFps) - - // Update list of transfers in storage - err = rft.saveTransfersList() - if err != nil { - return tid, errors.Errorf(saveReceivedTransfersListErr, err) - } - - return tid, nil -} - -// addFingerprints generates numFps fingerprints, creates a new partInfo -// for each, and adds each to the info map. -func (rft *ReceivedFileTransfersStore) addFingerprints(key ftCrypto.TransferKey, - tid ftCrypto.TransferID, numFps uint16) { - - // Generate list of fingerprints - fps := ftCrypto.GenerateFingerprints(key, numFps) - - // Add fingerprints to map - for fpNum, fp := range fps { - rft.info[fp] = newPartInfo(tid, uint16(fpNum)) - } -} - -// GetTransfer returns the ReceivedTransfer with the given transfer ID. An error -// is returned if no corresponding transfer is found. -func (rft *ReceivedFileTransfersStore) GetTransfer(tid ftCrypto.TransferID) ( - *ReceivedTransfer, error) { - rft.mux.Lock() - defer rft.mux.Unlock() - - rt, exists := rft.transfers[tid] - if !exists { - return nil, errors.Errorf(getReceivedTransferErr, tid) - } - - return rt, nil -} - -// DeleteTransfer removes the ReceivedTransfer with the associated transfer ID -// from memory and storage. -func (rft *ReceivedFileTransfersStore) DeleteTransfer(tid ftCrypto.TransferID) error { - rft.mux.Lock() - defer rft.mux.Unlock() - - // Return an error if the transfer does not exist - rt, exists := rft.transfers[tid] - if !exists { - return errors.Errorf(getReceivedTransferErr, tid) - } - - // Cancel any scheduled callbacks - err := rt.stopScheduledProgressCB() - if err != nil { - jww.WARN.Print(errors.Errorf(cancelCallbackErr, tid, err)) - } - - // Remove all unused fingerprints from map - for n, err := rt.fpVector.Next(); err == nil; n, err = rt.fpVector.Next() { - // Generate fingerprint - fp := ftCrypto.GenerateFingerprint(rt.key, uint16(n)) - - // Delete fingerprint from map - delete(rft.info, fp) - } - - // Delete all data the transfer saved to storage - err = rft.transfers[tid].delete() - if err != nil { - return errors.Errorf(deleteReceivedTransferErr, tid, err) - } - - // Delete the transfer from memory - delete(rft.transfers, tid) - - // Update the transfers list for the removed transfer - err = rft.saveTransfersList() - if err != nil { - return errors.Errorf(saveReceivedTransfersListErr, err) - } - - return nil -} - -// AddPart adds the file part to its corresponding transfer. The fingerprint -// number and transfer ID are looked up using the fingerprint. Then the part is -// added to the transfer with the corresponding transfer ID. Returns the -// transfer that the part was added to so that a progress callback can be -// called. Returns the transfer ID so that it can be used for logging. Also -// returns of the transfer is complete after adding the part. -func (rft *ReceivedFileTransfersStore) AddPart(cmixMsg format.Message) (*ReceivedTransfer, - ftCrypto.TransferID, bool, error) { - rft.mux.Lock() - defer rft.mux.Unlock() - - keyfp := cmixMsg.GetKeyFP() - - // Lookup the part info for the given fingerprint - info, exists := rft.info[cmixMsg.GetKeyFP()] - if !exists { - return nil, ftCrypto.TransferID{}, false, - errors.Errorf(noFingerprintErr, keyfp) - } - - // Lookup the transfer with the ID in the part info - transfer, exists := rft.transfers[info.id] - if !exists { - return nil, info.id, false, - errors.Errorf(getReceivedTransferErr, info.id) - } - - // Add the part to the transfer - completed, err := transfer.AddPart(cmixMsg, info.fpNum) - if err != nil { - return transfer, info.id, false, errors.Errorf( - addPartErr, info.id, err) - } - - // Remove the part info from the map - delete(rft.info, keyfp) - - return transfer, info.id, completed, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// LoadReceivedFileTransfersStore loads all ReceivedFileTransfersStore from -// storage. -func LoadReceivedFileTransfersStore(kv *versioned.KV) ( - *ReceivedFileTransfersStore, error) { - rft := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersStorePrefix), - } - - // Get the list of transfer IDs corresponding to each received transfer from - // storage - transfersList, err := rft.loadTransfersList() - if err != nil { - return nil, errors.Errorf(loadReceivedTransfersListErr, err) - } - - // Load transfers and fingerprints into the maps - err = rft.load(transfersList) - if err != nil { - return nil, errors.Errorf(loadReceivedFileTransfersErr, err) - } - - return rft, nil -} - -// NewOrLoadReceivedFileTransfersStore loads all ReceivedFileTransfersStore from -// storage, if they exist. Otherwise, a new ReceivedFileTransfersStore is -// returned. -func NewOrLoadReceivedFileTransfersStore(kv *versioned.KV) ( - *ReceivedFileTransfersStore, error) { - rft := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersStorePrefix), - } - - // If the transfer list cannot be loaded from storage, then create a new - // ReceivedFileTransfersStore - vo, err := rft.kv.Get( - receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion) - if err != nil { - return NewReceivedFileTransfersStore(kv) - } - - // Unmarshal data into list of saved transfer IDs - transfersList := unmarshalTransfersList(vo.Data) - - // Load transfers and fingerprints into the maps - err = rft.load(transfersList) - if err != nil { - jww.ERROR.Printf(loadReceivedFileTransfersErr, err) - return NewReceivedFileTransfersStore(kv) - } - - return rft, nil -} - -// saveTransfersList saves a list of items in the transfers map to storage. -func (rft *ReceivedFileTransfersStore) saveTransfersList() error { - // Create new versioned object with a list of items in the transfers map - obj := &versioned.Object{ - Version: receivedFileTransfersStoreVersion, - Timestamp: netTime.Now(), - Data: rft.marshalTransfersList(), - } - - // Save list of items in the transfers map to storage - return rft.kv.Set( - receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion, obj) -} - -// loadTransfersList gets the list of transfer IDs corresponding to each saved -// received transfer from storage. -func (rft *ReceivedFileTransfersStore) loadTransfersList() ( - []ftCrypto.TransferID, error) { - // Get transfers list from storage - vo, err := rft.kv.Get( - receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion) - if err != nil { - return nil, err - } - - // Unmarshal data into list of saved transfer IDs - return unmarshalTransfersList(vo.Data), nil -} - -// load gets each ReceivedTransfer in the list from storage and adds them to the -// map. Also adds all unused fingerprints in each ReceivedTransfer to the info -// map. -func (rft *ReceivedFileTransfersStore) load(list []ftCrypto.TransferID) error { - var errCount int - - // Load each sentTransfer from storage into the map - for i, tid := range list { - // Load the transfer with the given transfer ID from storage - rt, err := loadReceivedTransfer(tid, rft.kv) - if err != nil { - jww.WARN.Printf(loadReceivedTransferWarn, i, len(list), tid, err) - errCount++ - continue - } - - // Add transfer to transfer map - rft.transfers[tid] = rt - - // Load all unused fingerprints into the info map - for n := uint32(0); n < rt.fpVector.GetNumKeys(); n++ { - if !rt.fpVector.Used(n) { - fpNum := uint16(n) - - // Generate fingerprint - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - - // Add to map - rft.info[fp] = &partInfo{tid, fpNum} - } - } - } - - // Return an error if all transfers failed to load - if errCount == len(list) { - return errors.Errorf(loadReceivedTransfersAllErr, len(list)) - } - - return nil -} - -// marshalTransfersList creates a list of all transfer IDs in the transfers map -// and serialises it. -func (rft *ReceivedFileTransfersStore) marshalTransfersList() []byte { - buff := bytes.NewBuffer(nil) - buff.Grow(ftCrypto.TransferIdLength * len(rft.transfers)) - - for tid := range rft.transfers { - buff.Write(tid.Bytes()) - } - - return buff.Bytes() -} diff --git a/storage/fileTransfer/receiveFileTransfers_test.go b/storage/fileTransfer/receiveFileTransfers_test.go deleted file mode 100644 index 1eee9480f9c28fa0c82dc6c8a46a4e32edee1e6e..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receiveFileTransfers_test.go +++ /dev/null @@ -1,861 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "fmt" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/primitives/netTime" - "reflect" - "sort" - "strings" - "testing" -) - -// Tests that NewReceivedFileTransfersStore creates a new object with empty maps -// and that it is saved to storage -func TestNewReceivedFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedRFT := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersStorePrefix), - } - - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Errorf("NewReceivedFileTransfersStore returned an error: %+v", err) - } - - // Check that the new ReceivedFileTransfersStore matches the expected - if !reflect.DeepEqual(expectedRFT, rft) { - t.Errorf("New ReceivedFileTransfersStore does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedRFT, rft) - } - - // Ensure that the transfer list is saved to storage - _, err = expectedRFT.kv.Get( - receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion) - if err != nil { - t.Errorf("Failed to load transfer list from storage: %+v", err) - } -} - -// Tests that ReceivedFileTransfersStore.AddTransfer adds a new transfer and -// adds all the fingerprints to the info map. -func TestReceivedFileTransfersStore_AddTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Generate info for new transfer - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - fileSize := uint32(256) - numParts, numFps := uint16(16), uint16(24) - - // Add the transfer - tid, err := rft.AddTransfer(key, mac, fileSize, numParts, numFps, prng) - if err != nil { - t.Errorf("AddTransfer returned an error: %+v", err) - } - - // Check that the transfer was added to the map - transfer, exists := rft.transfers[tid] - if !exists { - t.Errorf("Transfer with ID %s not found.", tid) - } else { - if transfer.GetFileSize() != fileSize { - t.Errorf("New transfer has incorrect file size."+ - "\nexpected: %d\nreceived: %d", fileSize, transfer.GetFileSize()) - } - if transfer.GetNumParts() != numParts { - t.Errorf("New transfer has incorrect number of parts."+ - "\nexpected: %d\nreceived: %d", numParts, transfer.GetNumParts()) - } - if transfer.GetTransferKey() != key { - t.Errorf("New transfer has incorrect transfer key."+ - "\nexpected: %s\nreceived: %s", key, transfer.GetTransferKey()) - } - if transfer.GetNumFps() != numFps { - t.Errorf("New transfer has incorrect number of fingerprints."+ - "\nexpected: %d\nreceived: %d", numFps, transfer.GetNumFps()) - } - } - - // Check that the transfer was added to storage - _, err = loadReceivedTransfer(tid, rft.kv) - if err != nil { - t.Errorf("Transfer with ID %s not found in storage: %+v", tid, err) - } - - // Check that all the fingerprints are in the info map - for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { - info, exists := rft.info[fp] - if !exists { - t.Errorf("Part fingerprint %s (#%d) not found.", fp, fpNum) - } - - if int(info.fpNum) != fpNum { - t.Errorf("Fingerprint %s has incorrect fingerprint number."+ - "\nexpected: %d\nreceived: %d", fp, fpNum, info.fpNum) - } - - if info.id != tid { - t.Errorf("Fingerprint %s has incorrect transfer ID."+ - "\nexpected: %s\nreceived: %s", fp, tid, info.id) - } - } -} - -// Error path: tests that ReceivedFileTransfersStore.AddTransfer returns the -// expected error when the PRNG returns an error. -func TestReceivedFileTransfersStore_AddTransfer_NewTransferIdRngError(t *testing.T) { - prng := NewPrngErr() - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Add the transfer - expectedErr := strings.Split(addTransferNewIdErr, "%")[0] - _, err = rft.AddTransfer(ftCrypto.TransferKey{}, nil, 0, 0, 0, prng) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("AddTransfer did not return the expected error when the PRNG "+ - "should have errored.\nexpected: %s\nrecieved: %+v", expectedErr, err) - } - -} - -// Tests that ReceivedFileTransfersStore.addFingerprints adds all the -// fingerprints to the map -func TestReceivedFileTransfersStore_addFingerprints(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - key, _ := ftCrypto.NewTransferKey(prng) - tid, _ := ftCrypto.NewTransferID(prng) - numFps := uint16(24) - - rft.addFingerprints(key, tid, numFps) - - // Check that all the fingerprints are in the info map - for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { - info, exists := rft.info[fp] - if !exists { - t.Errorf("Part fingerprint %s (#%d) not found.", fp, fpNum) - } - - if int(info.fpNum) != fpNum { - t.Errorf("Fingerprint %s has incorrect fingerprint number."+ - "\nexpected: %d\nreceived: %d", fp, fpNum, info.fpNum) - } - - if info.id != tid { - t.Errorf("Fingerprint %s has incorrect transfer ID."+ - "\nexpected: %s\nreceived: %s", fp, tid, info.id) - } - } -} - -// Tests that ReceivedFileTransfersStore.GetTransfer returns the newly added -// transfer. -func TestReceivedFileTransfersStore_GetTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Generate random info for new transfer - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - fileSize := uint32(256) - numParts, numFps := uint16(16), uint16(24) - - // Add the transfer - tid, err := rft.AddTransfer(key, mac, fileSize, numParts, numFps, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Get the transfer - transfer, err := rft.GetTransfer(tid) - if err != nil { - t.Errorf("GetTransfer returned an error: %+v", err) - } - - if transfer.GetFileSize() != fileSize { - t.Errorf("New transfer has incorrect file size."+ - "\nexpected: %d\nreceived: %d", fileSize, transfer.GetFileSize()) - } - - if transfer.GetNumParts() != numParts { - t.Errorf("New transfer has incorrect number of parts."+ - "\nexpected: %d\nreceived: %d", numParts, transfer.GetNumParts()) - } - - if transfer.GetNumFps() != numFps { - t.Errorf("New transfer has incorrect number of fingerprints."+ - "\nexpected: %d\nreceived: %d", numFps, transfer.GetNumFps()) - } - - if transfer.GetTransferKey() != key { - t.Errorf("New transfer has incorrect transfer key."+ - "\nexpected: %s\nreceived: %s", key, transfer.GetTransferKey()) - } -} - -// Error path: tests that ReceivedFileTransfersStore.GetTransfer returns the -// expected error when the provided transfer ID does not correlate to any saved -// transfer. -func TestReceivedFileTransfersStore_GetTransfer_NoTransferError(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Generate random info for new transfer - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - - // Add the transfer - _, err = rft.AddTransfer(key, mac, 256, 16, 24, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Get the transfer - invalidTid, _ := ftCrypto.NewTransferID(prng) - expectedErr := fmt.Sprintf(getReceivedTransferErr, invalidTid) - _, err = rft.GetTransfer(invalidTid) - if err == nil || err.Error() != expectedErr { - t.Errorf("GetTransfer did not return the expected error when no "+ - "transfer for the ID exists.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } -} - -// Tests that DeleteTransfer removed a transfer from memory and storage. -func TestReceivedFileTransfersStore_DeleteTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Add the transfer - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - numFps := uint16(24) - tid, err := rft.AddTransfer(key, mac, 256, 16, numFps, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Delete the transfer - err = rft.DeleteTransfer(tid) - if err != nil { - t.Errorf("DeleteTransfer returned an error: %+v", err) - } - - // Check that the transfer was deleted from the map - _, exists := rft.transfers[tid] - if exists { - t.Errorf("Transfer with ID %s found in map when it should have been "+ - "deleted.", tid) - } - - // Check that the transfer was deleted from storage - _, err = loadReceivedTransfer(tid, rft.kv) - if err == nil { - t.Errorf("Transfer with ID %s found in storage when it should have "+ - "been deleted.", tid) - } - - // Check that all the fingerprints in the info map were deleted - for fpNum, fp := range ftCrypto.GenerateFingerprints(key, numFps) { - _, exists = rft.info[fp] - if exists { - t.Errorf("Part fingerprint %s (#%d) found in map when it should "+ - "have been deleted.", fp, fpNum) - } - } -} - -// Error path: tests that ReceivedFileTransfersStore.DeleteTransfer returns the -// expected error when the provided transfer ID does not correlate to any saved -// transfer. -func TestReceivedFileTransfersStore_DeleteTransfer_NoTransferError(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Delete the transfer - invalidTid, _ := ftCrypto.NewTransferID(prng) - expectedErr := fmt.Sprintf(getReceivedTransferErr, invalidTid) - err = rft.DeleteTransfer(invalidTid) - if err == nil || err.Error() != expectedErr { - t.Errorf("DeleteTransfer did not return the expected error when no "+ - "transfer for the ID exists.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } -} - -// Tests that ReceivedFileTransfersStore.AddPart modifies the expected transfer -// in memory. -func TestReceivedFileTransfersStore_AddPart(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - prng := NewPrng(42) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - - tid, err := rft.AddTransfer(key, mac, 256, 16, 24, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Create encrypted part - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - expectedData := []byte("test") - - partNum, fpNum := uint16(1), uint16(1) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(expectedData) - - fp := ftCrypto.GenerateFingerprint(key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - rt, _, _, err := rft.AddPart(cmixMsg) - if err != nil { - t.Errorf("AddPart returned an error: %+v", err) - } - - // Make sure its fingerprint was removed from the map - _, exists := rft.info[fp] - if exists { - t.Errorf("Fingerprints %s for added part found when it should have "+ - "been deleted.", fp) - } - - // Check that the transfer is correct - expectedRT := rft.transfers[tid] - if !reflect.DeepEqual(expectedRT, rt) { - t.Errorf("Returned transfer does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedRT, rt) - } - - // Check that the correct part was stored - receivedPart := expectedRT.receivedParts.parts[partNum] - if !bytes.Equal(receivedPart[:len(expectedData)], expectedData) { - t.Errorf("Part in memory is not expected."+ - "\nexpected: %q\nreceived: %q", expectedData, receivedPart[:len(expectedData)]) - } -} - -// Error path: tests that ReceivedFileTransfersStore.AddPart returns the -// expected error when the provided fingerprint does not correlate to any part. -func TestReceivedFileTransfersStore_AddPart_NoFingerprintError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Create encrypted part - fp := format.NewFingerprint([]byte("invalidTransferKey")) - - msg := format.NewMessage(1000) - msg.SetKeyFP(fp) - - // Add encrypted part - expectedErr := fmt.Sprintf(noFingerprintErr, fp) - _, _, _, err = rft.AddPart(msg) - if err == nil || err.Error() != expectedErr { - t.Errorf("AddPart did not return the expected error when no part for "+ - "the fingerprint exists.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } -} - -// Error path: tests that ReceivedFileTransfersStore.AddPart returns the -// expected error when the provided transfer ID does not correlate to any saved -// transfer. -func TestReceivedFileTransfersStore_AddPart_NoTransferError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - prng := NewPrng(42) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - - _, err = rft.AddTransfer(key, mac, 256, 16, 24, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Create encrypted part - fp := ftCrypto.GenerateFingerprint(key, 1) - invalidTid, _ := ftCrypto.NewTransferID(prng) - rft.info[fp].id = invalidTid - - msg := format.NewMessage(1000) - msg.SetKeyFP(fp) - - // Add encrypted part - expectedErr := fmt.Sprintf(getReceivedTransferErr, invalidTid) - _, _, _, err = rft.AddPart(msg) - if err == nil || err.Error() != expectedErr { - t.Errorf("AddPart did not return the expected error when no transfer "+ - "for the ID exists.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that ReceivedFileTransfersStore.AddPart returns the -// expected error when the encrypted part data, MAC, and padding are invalid. -func TestReceivedFileTransfersStore_AddPart_AddPartError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - prng := NewPrng(42) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - numParts := uint16(16) - - tid, err := rft.AddTransfer(key, mac, 256, numParts, 24, prng) - if err != nil { - t.Errorf("Failed to add new transfer: %+v", err) - } - - // Create encrypted part - partNum, fpNum := uint16(1), uint16(1) - part := []byte("invalidPart") - mac = make([]byte, format.MacLen) - fp := ftCrypto.GenerateFingerprint(key, fpNum) - - // Add encrypted part - expectedErr := fmt.Sprintf(addPartErr, tid, "") - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(part) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(partData.Marshal()) - cmixMsg.SetMac(mac) - - _, _, _, err = rft.AddPart(cmixMsg) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("AddPart did not return the expected error when the "+ - "encrypted part, padding, and MAC are invalid."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// Tests that the ReceivedFileTransfersStore loaded from storage by -// LoadReceivedFileTransfersStore matches the original in memory. -func TestLoadReceivedFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to make new ReceivedFileTransfersStore: %+v", err) - } - - // Add 10 transfers to map in memory - list := make([]ftCrypto.TransferID, 10) - for i := range list { - tid, rt, _ := newRandomReceivedTransfer(16, 24, rft.kv, t) - - // Add to transfer - rft.transfers[tid] = rt - - // Add unused fingerprints - for n := uint32(0); n < rt.fpVector.GetNumKeys(); n++ { - if !rt.fpVector.Used(n) { - fp := ftCrypto.GenerateFingerprint(rt.key, uint16(n)) - rft.info[fp] = &partInfo{tid, uint16(n)} - } - } - - // Add ID to list - list[i] = tid - } - - // Save list to storage - if err = rft.saveTransfersList(); err != nil { - t.Errorf("Faileds to save transfers list: %+v", err) - } - - // Load ReceivedFileTransfersStore from storage - loadedRFT, err := LoadReceivedFileTransfersStore(kv) - if err != nil { - t.Errorf("LoadReceivedFileTransfersStore returned an error: %+v", err) - } - - if !reflect.DeepEqual(rft, loadedRFT) { - t.Errorf("Loaded ReceivedFileTransfersStore does not match original"+ - "in memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) - } -} - -// Error path: tests that ReceivedFileTransfersStore returns the expected error -// when the transfer list cannot be loaded from storage. -func TestLoadReceivedFileTransfersStore_NoListInStorageError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadReceivedTransfersListErr, "%")[0] - - // Load ReceivedFileTransfersStore from storage - _, err := LoadReceivedFileTransfersStore(kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadReceivedFileTransfersStore did not return the expected "+ - "error when there is no transfer list saved in storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that ReceivedFileTransfersStore returns the expected error -// when the first transfer loaded from storage does not exist. -func TestLoadReceivedFileTransfersStore_NoTransferInStorageError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadReceivedFileTransfersErr, "%")[0] - - // Save list of one transfer ID to storage - obj := &versioned.Object{ - Version: receivedFileTransfersStoreVersion, - Timestamp: netTime.Now(), - Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), - } - err := kv.Prefix(receivedFileTransfersStorePrefix).Set( - receivedFileTransfersStoreKey, receivedFileTransfersStoreVersion, obj) - - // Load ReceivedFileTransfersStore from storage - _, err = LoadReceivedFileTransfersStore(kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadReceivedFileTransfersStore did not return the expected "+ - "error when there is no transfer saved in storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that the ReceivedFileTransfersStore loaded from storage by -// NewOrLoadReceivedFileTransfersStore matches the original in memory. -func TestNewOrLoadReceivedFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to make new ReceivedFileTransfersStore: %+v", err) - } - - // Add 10 transfers to map in memory - list := make([]ftCrypto.TransferID, 10) - for i := range list { - tid, rt, _ := newRandomReceivedTransfer(16, 24, rft.kv, t) - - // Add to transfer - rft.transfers[tid] = rt - - // Add unused fingerprints - for n := uint32(0); n < rt.fpVector.GetNumKeys(); n++ { - if !rt.fpVector.Used(n) { - fp := ftCrypto.GenerateFingerprint(rt.key, uint16(n)) - rft.info[fp] = &partInfo{tid, uint16(n)} - } - } - - // Add ID to list - list[i] = tid - } - - // Save list to storage - if err = rft.saveTransfersList(); err != nil { - t.Errorf("Faileds to save transfers list: %+v", err) - } - - // Load ReceivedFileTransfersStore from storage - loadedRFT, err := NewOrLoadReceivedFileTransfersStore(kv) - if err != nil { - t.Errorf("NewOrLoadReceivedFileTransfersStore returned an error: %+v", - err) - } - - if !reflect.DeepEqual(rft, loadedRFT) { - t.Errorf("Loaded ReceivedFileTransfersStore does not match original "+ - "in memory.\nexpected: %+v\nreceived: %+v", rft, loadedRFT) - } -} - -// Tests that NewOrLoadReceivedFileTransfersStore returns a new -// ReceivedFileTransfersStore when there is none in storage. -func TestNewOrLoadReceivedFileTransfersStore_NewReceivedFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - - // Load ReceivedFileTransfersStore from storage - loadedRFT, err := NewOrLoadReceivedFileTransfersStore(kv) - if err != nil { - t.Errorf("NewOrLoadReceivedFileTransfersStore returned an error: %+v", - err) - } - - newRFT, _ := NewReceivedFileTransfersStore(kv) - - if !reflect.DeepEqual(newRFT, loadedRFT) { - t.Errorf("Returned ReceivedFileTransfersStore does not match new."+ - "\nexpected: %+v\nreceived: %+v", newRFT, loadedRFT) - } -} - -// Tests that the list saved by ReceivedFileTransfersStore.saveTransfersList -// matches the list loaded by ReceivedFileTransfersStore.load. -func TestReceivedFileTransfersStore_saveTransfersList_loadTransfersList(t *testing.T) { - rft, err := NewReceivedFileTransfersStore(versioned.NewKV(make(ekv.Memstore))) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Fill map with transfers - expectedList := make([]ftCrypto.TransferID, 7) - for i := range expectedList { - prng := NewPrng(int64(i)) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - numParts := uint16(i) - numFps := uint16(float32(i) * 2.5) - - expectedList[i], err = rft.AddTransfer( - key, mac, 256, numParts, numFps, prng) - if err != nil { - t.Errorf("Failed to add new transfer #%d: %+v", i, err) - } - } - - // Save the list - err = rft.saveTransfersList() - if err != nil { - t.Errorf("saveTransfersList returned an error: %+v", err) - } - - // Load the list - loadedList, err := rft.loadTransfersList() - if err != nil { - t.Errorf("loadTransfersList returned an error: %+v", err) - } - - // Sort slices so they can be compared - sort.SliceStable(expectedList, func(i, j int) bool { - return bytes.Compare(expectedList[i].Bytes(), expectedList[j].Bytes()) == -1 - }) - sort.SliceStable(loadedList, func(i, j int) bool { - return bytes.Compare(loadedList[i].Bytes(), loadedList[j].Bytes()) == -1 - }) - - if !reflect.DeepEqual(expectedList, loadedList) { - t.Errorf("Loaded transfer list does not match expected."+ - "\nexpected: %v\nreceived: %v", expectedList, loadedList) - } -} - -// Tests that the list loaded by ReceivedFileTransfersStore.load matches the -// original saved to storage. -func TestReceivedFileTransfersStore_load(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Fill map with transfers - idList := make([]ftCrypto.TransferID, 7) - for i := range idList { - prng := NewPrng(int64(i)) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - - idList[i], err = rft.AddTransfer(key, mac, 256, 16, 24, prng) - if err != nil { - t.Errorf("Failed to add new transfer #%d: %+v", i, err) - } - } - - // Save the list - err = rft.saveTransfersList() - if err != nil { - t.Errorf("saveTransfersList returned an error: %+v", err) - } - - // Build new ReceivedFileTransfersStore - newRFT := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersStorePrefix), - } - - // Load saved transfers from storage - err = newRFT.load(idList) - if err != nil { - t.Errorf("load returned an error: %+v", err) - } - - // Check that all transfer were loaded from storage - for _, id := range idList { - transfer, exists := newRFT.transfers[id] - if !exists { - t.Errorf("Transfer %s not loaded from storage.", id) - } - - // Check that the loaded transfer matches the original in memory - if !reflect.DeepEqual(transfer, rft.transfers[id]) { - t.Errorf("Loaded transfer does not match original."+ - "\noriginal: %+v\nreceived: %+v", rft.transfers[id], transfer) - } - - // Make sure all fingerprints are present - for fpNum, fp := range ftCrypto.GenerateFingerprints(transfer.key, 24) { - info, exists := newRFT.info[fp] - if !exists { - t.Errorf("Fingerprint %d for transfer %s does not exist in "+ - "the map.", fpNum, id) - } - - if info.id != id { - t.Errorf("Fingerprint has wrong transfer ID."+ - "\nexpected: %s\nreceived: %s", id, info.id) - } - if int(info.fpNum) != fpNum { - t.Errorf("Fingerprint has wrong number."+ - "\nexpected: %d\nreceived: %d", fpNum, info.fpNum) - } - } - } -} - -// Error path: tests that ReceivedFileTransfersStore.load returns an error when -// all file transfers fail to load from storage. -func TestReceivedFileTransfersStore_load_AllFail(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - rft, err := NewReceivedFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new ReceivedFileTransfersStore: %+v", err) - } - - // Fill map with transfers - idList := make([]ftCrypto.TransferID, 7) - for i := range idList { - prng := NewPrng(int64(i)) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - - idList[i], err = rft.AddTransfer(key, mac, 256, 16, 24, prng) - if err != nil { - t.Errorf("Failed to add new transfer #%d: %+v", i, err) - } - - err = rft.DeleteTransfer(idList[i]) - if err != nil { - t.Errorf("Failed to delete transfer: %+v", err) - } - } - - // Save the list - err = rft.saveTransfersList() - if err != nil { - t.Errorf("saveTransfersList returned an error: %+v", err) - } - - // Build new ReceivedFileTransfersStore - newRFT := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - info: make(map[format.Fingerprint]*partInfo), - kv: kv.Prefix(receivedFileTransfersStorePrefix), - } - - expectedErr := fmt.Sprintf(loadReceivedTransfersAllErr, len(idList)) - - // Load saved transfers from storage - err = newRFT.load(idList) - if err == nil || err.Error() != expectedErr { - t.Errorf("load did not return the expected error when none of the "+ - "transfer could be loaded from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that a transfer list marshalled with -// ReceivedFileTransfersStore.marshalTransfersList and unmarshalled with -// unmarshalTransfersList matches the original. -func TestReceivedFileTransfersStore_marshalTransfersList_unmarshalTransfersList(t *testing.T) { - prng := NewPrng(42) - rft := &ReceivedFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*ReceivedTransfer), - } - - // Add 10 transfers to map in memory - for i := 0; i < 10; i++ { - tid, _ := ftCrypto.NewTransferID(prng) - rft.transfers[tid] = &ReceivedTransfer{} - } - - // Marshal into byte slice - marshalledBytes := rft.marshalTransfersList() - - // Unmarshal marshalled bytes into transfer ID list - list := unmarshalTransfersList(marshalledBytes) - - // Check that the list has all the transfer IDs in memory - for _, tid := range list { - if _, exists := rft.transfers[tid]; !exists { - t.Errorf("No transfer for ID %s exists.", tid) - } else { - delete(rft.transfers, tid) - } - } -} diff --git a/storage/fileTransfer/receiveTransfer.go b/storage/fileTransfer/receiveTransfer.go deleted file mode 100644 index b3e387a750d4cb6ba398c28195add01ec5b5763e..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receiveTransfer.go +++ /dev/null @@ -1,521 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "encoding/binary" - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "time" -) - -// Storage constants -const ( - receivedTransfersPrefix = "FileTransferReceivedTransferStore" - receivedTransferKey = "ReceivedTransfer" - receivedTransferVersion = 0 - receivedFpVectorKey = "ReceivedFingerprintVector" - receivedVectorKey = "ReceivedStatusVector" -) - -// Error messages for ReceivedTransfer -const ( - // NewReceivedTransfer - newReceivedTransferFpVectorErr = "failed to create new StateVector for fingerprints: %+v" - newReceivedTransferPartStoreErr = "failed to create new part store: %+v" - newReceivedVectorErr = "failed to create new state vector for received status: %+v" - - // ReceivedTransfer.GetFile - getFileErr = "missing %d/%d parts of the file" - getTransferMacErr = "failed to verify transfer MAC" - - // loadReceivedTransfer - loadReceivedStoreErr = "failed to load received transfer info from storage: %+v" - loadReceivePartStoreErr = "failed to load received part store from storage: %+v" - loadReceivedVectorErr = "failed to load new received status state vector from storage: %+v" - loadReceiveFpVectorErr = "failed to load received fingerprint vector from storage: %+v" - - // ReceivedTransfer.delete - deleteReceivedTransferInfoErr = "failed to delete received transfer info from storage: %+v" - deleteReceivedFpVectorErr = "failed to delete received fingerprint vector from storage: %+v" - deleteReceivedFilePartsErr = "failed to delete received file parts from storage: %+v" - deleteReceivedVectorErr = "failed to delete received status state vector from storage: %+v" - - // ReceivedTransfer.stopScheduledProgressCB - cancelReceivedCallbacksErr = "could not cancel %d out of %d received progress callbacks: %d" -) - -// ReceivedTransfer contains information and progress data for receiving an in- -// progress file transfer. -type ReceivedTransfer struct { - // The transfer key is a randomly generated key created by the sender and - // used to generate MACs and fingerprints - key ftCrypto.TransferKey - - // The MAC for the entire file; used to verify the integrity of all parts - transferMAC []byte - - // Size of the entire file in bytes - fileSize uint32 - - // The number of file parts in the file - numParts uint16 - - // The number `of fingerprints to generate (function of numParts and the - // retry rate) - numFps uint16 - - // Stores the state of a fingerprint (used/unused) in a bitstream format - // (has its own storage backend) - fpVector *utility.StateVector - - // Saves each part in order (has its own storage backend) - receivedParts *partStore - - // Stores the received status for each file part in a bitstream format - receivedStatus *utility.StateVector - - // List of callbacks to call for every send - progressCallbacks []*receivedCallbackTracker - - mux sync.RWMutex - kv *versioned.KV -} - -// NewReceivedTransfer generates a ReceivedTransfer with the specified -// transfer key, transfer ID, and a number of parts. -func NewReceivedTransfer(tid ftCrypto.TransferID, key ftCrypto.TransferKey, - transferMAC []byte, fileSize uint32, numParts, numFps uint16, - kv *versioned.KV) (*ReceivedTransfer, error) { - - // Create the ReceivedTransfer object - rt := &ReceivedTransfer{ - key: key, - transferMAC: transferMAC, - fileSize: fileSize, - numParts: numParts, - numFps: numFps, - progressCallbacks: []*receivedCallbackTracker{}, - kv: kv.Prefix(makeReceivedTransferPrefix(tid)), - } - - var err error - - // Create new StateVector for storing fingerprint usage - rt.fpVector, err = utility.NewStateVector( - rt.kv, receivedFpVectorKey, uint32(numFps)) - if err != nil { - return nil, errors.Errorf(newReceivedTransferFpVectorErr, err) - } - - // Create new part store - rt.receivedParts, err = newPartStore(rt.kv, numParts) - if err != nil { - return nil, errors.Errorf(newReceivedTransferPartStoreErr, err) - } - - // Create new StateVector for storing received status - rt.receivedStatus, err = utility.NewStateVector( - rt.kv, receivedVectorKey, uint32(rt.numParts)) - if err != nil { - return nil, errors.Errorf(newReceivedVectorErr, err) - } - - // Save all fields without their own storage to storage - return rt, rt.saveInfo() -} - -// GetTransferKey returns the transfer Key for this received transfer. -func (rt *ReceivedTransfer) GetTransferKey() ftCrypto.TransferKey { - rt.mux.Lock() - defer rt.mux.Unlock() - - return rt.key -} - -// GetTransferMAC returns the transfer MAC for this received transfer. -func (rt *ReceivedTransfer) GetTransferMAC() []byte { - rt.mux.Lock() - defer rt.mux.Unlock() - - return rt.transferMAC -} - -// GetNumParts returns the number of file parts in this transfer. -func (rt *ReceivedTransfer) GetNumParts() uint16 { - rt.mux.Lock() - defer rt.mux.Unlock() - - return rt.numParts -} - -// GetNumFps returns the number of fingerprints. -func (rt *ReceivedTransfer) GetNumFps() uint16 { - rt.mux.Lock() - defer rt.mux.Unlock() - - return rt.numFps -} - -// GetNumAvailableFps returns the number of unused fingerprints. -func (rt *ReceivedTransfer) GetNumAvailableFps() uint16 { - rt.mux.Lock() - defer rt.mux.Unlock() - - return uint16(rt.fpVector.GetNumAvailable()) -} - -// GetFileSize returns the file size in bytes. -func (rt *ReceivedTransfer) GetFileSize() uint32 { - rt.mux.Lock() - defer rt.mux.Unlock() - - return rt.fileSize -} - -// IsPartReceived returns true if the part has successfully been received. -// Returns false if the part has not been received or if the part number is -// invalid. -func (rt *ReceivedTransfer) IsPartReceived(partNum uint16) bool { - _, exists := rt.receivedParts.getPart(partNum) - return exists -} - -// GetProgress returns the current progress of the transfer. Completed is true -// when all parts have been received, received is the number of parts received, -// total is the total number of parts excepted to be received, and t is a part -// status tracker that can be used to get the status of individual file parts. -func (rt *ReceivedTransfer) GetProgress() (completed bool, received, - total uint16, t interfaces.FilePartTracker) { - rt.mux.RLock() - defer rt.mux.RUnlock() - - completed, received, total, t = rt.getProgress() - return completed, received, total, t -} - -// getProgress is the thread-unsafe helper function for GetProgress. -func (rt *ReceivedTransfer) getProgress() (completed bool, received, - total uint16, t interfaces.FilePartTracker) { - - received = uint16(rt.receivedStatus.GetNumUsed()) - total = rt.numParts - - if received == total { - completed = true - } - - return completed, received, total, newReceivedPartTracker(rt.receivedStatus) -} - -// CallProgressCB calls all the progress callbacks with the most recent progress -// information. -func (rt *ReceivedTransfer) CallProgressCB(err error) { - rt.mux.RLock() - defer rt.mux.RUnlock() - - for _, cb := range rt.progressCallbacks { - cb.call(rt, err) - } -} - -// stopScheduledProgressCB cancels all scheduled received progress callbacks -// calls. -func (rt *ReceivedTransfer) stopScheduledProgressCB() error { - rt.mux.Lock() - defer rt.mux.Unlock() - - // Tracks the index of callbacks that failed to stop - var failedCallbacks []int - - for i, cb := range rt.progressCallbacks { - err := cb.stopThread() - if err != nil { - failedCallbacks = append(failedCallbacks, i) - jww.WARN.Printf("[FT] %s", err) - } - } - - if len(failedCallbacks) > 0 { - return errors.Errorf(cancelReceivedCallbacksErr, len(failedCallbacks), - len(rt.progressCallbacks), failedCallbacks) - } - - return nil -} - -// AddProgressCB appends a new interfaces.ReceivedProgressCallback to the list -// of progress callbacks to be called and calls it. The period is how often the -// callback should be called when there are updates. -func (rt *ReceivedTransfer) AddProgressCB( - cb interfaces.ReceivedProgressCallback, period time.Duration) { - rt.mux.Lock() - - // Add callback - rct := newReceivedCallbackTracker(cb, period) - rt.progressCallbacks = append(rt.progressCallbacks, rct) - - rt.mux.Unlock() - - // Trigger the initial call - rct.callNow(true, rt, nil) -} - -// AddPart decrypts an encrypted file part, adds it to the list of received -// parts and marks its fingerprint as used. Returns true if the part added was -// the last in the transfer. -func (rt *ReceivedTransfer) AddPart(cmixMsg format.Message, - fpNum uint16) (bool, error) { - rt.mux.Lock() - defer rt.mux.Unlock() - - // Decrypt the encrypted file part - decryptedPart, err := ftCrypto.DecryptPart(rt.key, - cmixMsg.GetContents(), cmixMsg.GetMac(), fpNum, cmixMsg.GetKeyFP()) - if err != nil { - return false, err - } - - part, err := UnmarshalPartMessage(decryptedPart) - if err != nil { - return false, err - } - - // Add the part to the list of parts - err = rt.receivedParts.addPart(part.GetPart(), part.GetPartNum()) - if err != nil { - return false, err - } - - // Mark the fingerprint as used - rt.fpVector.Use(uint32(fpNum)) - - // Mark part as received - rt.receivedStatus.Use(uint32(part.GetPartNum())) - - if rt.receivedStatus.GetNumUsed() >= uint32(rt.numParts) { - return true, nil - } - - return false, nil -} - -// GetFile returns all the file parts combined into a single byte slice. An -// error is returned if parts are missing or if the MAC cannot be verified. The -// incomplete or invalid file is returned despite errors. -func (rt *ReceivedTransfer) GetFile() ([]byte, error) { - rt.mux.Lock() - defer rt.mux.Unlock() - - fileData, missingParts := rt.receivedParts.getFile() - if missingParts != 0 { - return fileData, errors.Errorf(getFileErr, missingParts, rt.numParts) - } - - // Remove extra data added when sending as parts - fileData = fileData[:rt.fileSize] - - if !ftCrypto.VerifyTransferMAC(fileData, rt.key, rt.transferMAC) { - return fileData, errors.New(getTransferMacErr) - } - - return fileData, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// loadReceivedTransfer loads the ReceivedTransfer with the given transfer ID -// from storage. -func loadReceivedTransfer(tid ftCrypto.TransferID, kv *versioned.KV) ( - *ReceivedTransfer, error) { - // Create the ReceivedTransfer object - rt := &ReceivedTransfer{ - progressCallbacks: []*receivedCallbackTracker{}, - kv: kv.Prefix(makeReceivedTransferPrefix(tid)), - } - - // Load transfer key and number of received parts from storage - err := rt.loadInfo() - if err != nil { - return nil, errors.Errorf(loadReceivedStoreErr, err) - } - - // Load the fingerprint vector from storage - rt.fpVector, err = utility.LoadStateVector(rt.kv, receivedFpVectorKey) - if err != nil { - return nil, errors.Errorf(loadReceiveFpVectorErr, err) - } - - // Load received part store from storage - rt.receivedParts, err = loadPartStore(rt.kv) - if err != nil { - return nil, errors.Errorf(loadReceivePartStoreErr, err) - } - - // Load the received status StateVector from storage - rt.receivedStatus, err = utility.LoadStateVector(rt.kv, receivedVectorKey) - if err != nil { - return nil, errors.Errorf(loadReceivedVectorErr, err) - } - - return rt, nil -} - -// saveInfo saves all fields in ReceivedTransfer that do not have their own -// storage (transfer key, transfer MAC, file size, number of file parts, and -// number of fingerprints) to storage. -func (rt *ReceivedTransfer) saveInfo() error { - rt.mux.Lock() - defer rt.mux.Unlock() - - // Create new versioned object for the ReceivedTransfer - vo := &versioned.Object{ - Version: receivedTransferVersion, - Timestamp: netTime.Now(), - Data: rt.marshal(), - } - - // Save versioned object - return rt.kv.Set(receivedTransferKey, receivedTransferVersion, vo) -} - -// loadInfo gets the transfer key, transfer MAC, file size, number of part, and -// number of fingerprints from storage and saves it to the ReceivedTransfer. -func (rt *ReceivedTransfer) loadInfo() error { - vo, err := rt.kv.Get(receivedTransferKey, receivedTransferVersion) - if err != nil { - return err - } - - // Unmarshal the transfer key and numParts - rt.key, rt.transferMAC, rt.fileSize, rt.numParts, rt.numFps = - unmarshalReceivedTransfer(vo.Data) - - return nil -} - -// delete deletes all data in the ReceivedTransfer from storage. -func (rt *ReceivedTransfer) delete() error { - rt.mux.Lock() - defer rt.mux.Unlock() - - // Delete received transfer info from storage - err := rt.deleteInfo() - if err != nil { - return errors.Errorf(deleteReceivedTransferInfoErr, err) - } - - // Delete fingerprint vector from storage - err = rt.fpVector.Delete() - if err != nil { - return errors.Errorf(deleteReceivedFpVectorErr, err) - } - - // Delete received file parts from storage - err = rt.receivedParts.delete() - if err != nil { - return errors.Errorf(deleteReceivedFilePartsErr, err) - } - - // Delete the received status StateVector from storage - err = rt.receivedStatus.Delete() - if err != nil { - return errors.Errorf(deleteReceivedVectorErr, err) - } - - return nil -} - -// deleteInfo removes received transfer info (transfer key, transfer MAC, file -// size, number of parts, and number of fingerprints) from storage. -func (rt *ReceivedTransfer) deleteInfo() error { - return rt.kv.Delete(receivedTransferKey, receivedTransferVersion) -} - -// marshal serializes all primitive fields in ReceivedTransfer (key, -// transferMAC, fileSize, numParts, and numFps). -func (rt *ReceivedTransfer) marshal() []byte { - // Construct the buffer to the correct size (size of key + transfer MAC - // length (2 bytes) + transfer MAC + fileSize (4 bytes) + numParts (2 bytes) - // + numFps (2 bytes)) - buff := bytes.NewBuffer(nil) - buff.Grow(ftCrypto.TransferKeyLength + 2 + len(rt.transferMAC) + 4 + 2 + 2) - - // Write the key to the buffer - buff.Write(rt.key.Bytes()) - - // Write the length of the transfer MAC to the buffer - b := make([]byte, 2) - binary.LittleEndian.PutUint16(b, uint16(len(rt.transferMAC))) - buff.Write(b) - - // Write the transfer MAC to the buffer - buff.Write(rt.transferMAC) - - // Write the file size to the buffer - b = make([]byte, 4) - binary.LittleEndian.PutUint32(b, rt.fileSize) - buff.Write(b) - - // Write the number of parts to the buffer - b = make([]byte, 2) - binary.LittleEndian.PutUint16(b, rt.numParts) - buff.Write(b) - - // Write the number of fingerprints to the buffer - b = make([]byte, 2) - binary.LittleEndian.PutUint16(b, rt.numFps) - buff.Write(b) - - // Return the serialized data - return buff.Bytes() -} - -// unmarshalReceivedTransfer deserializes a byte slice into the primitive fields -// of ReceivedTransfer (key, transferMAC, fileSize, numParts, and numFps). -func unmarshalReceivedTransfer(data []byte) (key ftCrypto.TransferKey, - transferMAC []byte, size uint32, numParts, numFps uint16) { - - buff := bytes.NewBuffer(data) - - // Read the transfer key from the buffer - key = ftCrypto.UnmarshalTransferKey(buff.Next(ftCrypto.TransferKeyLength)) - - // Read the size of the transfer MAC from the buffer - transferMacSize := binary.LittleEndian.Uint16(buff.Next(2)) - - // Read the transfer MAC from the buffer - transferMAC = buff.Next(int(transferMacSize)) - - // Read the file size from the buffer - size = binary.LittleEndian.Uint32(buff.Next(4)) - - // Read the number of part from the buffer - numParts = binary.LittleEndian.Uint16(buff.Next(2)) - - // Read the number of fingerprints from the buffer - numFps = binary.LittleEndian.Uint16(buff.Next(2)) - - return key, transferMAC, size, numParts, numFps -} - -// makeReceivedTransferPrefix generates the unique prefix used on the key value -// store to store received transfers for the given transfer ID. -func makeReceivedTransferPrefix(tid ftCrypto.TransferID) string { - return receivedTransfersPrefix + tid.String() -} diff --git a/storage/fileTransfer/receiveTransfer_test.go b/storage/fileTransfer/receiveTransfer_test.go deleted file mode 100644 index 0edd3cfe63112c519c45046a009ecce868c3cde3..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receiveTransfer_test.go +++ /dev/null @@ -1,1159 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "fmt" - "github.com/pkg/errors" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" - "gitlab.com/elixxir/primitives/format" - "reflect" - "strings" - "sync" - "sync/atomic" - "testing" - "time" -) - -// Tests that NewReceivedTransfer correctly creates a new ReceivedTransfer and -// that it is saved to storage. -func Test_NewReceivedTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - mac := []byte("transferMAC") - kvPrefixed := kv.Prefix(makeReceivedTransferPrefix(tid)) - fileSize := uint32(256) - numParts, numFps := uint16(16), uint16(24) - fpVector, _ := utility.NewStateVector( - kvPrefixed, receivedFpVectorKey, uint32(numFps)) - receivedVector, _ := utility.NewStateVector( - kvPrefixed, receivedVectorKey, uint32(numParts)) - - expected := &ReceivedTransfer{ - key: key, - transferMAC: mac, - fileSize: fileSize, - numParts: numParts, - numFps: numFps, - fpVector: fpVector, - receivedParts: &partStore{ - parts: make(map[uint16][]byte, numParts), - numParts: numParts, - kv: kvPrefixed, - }, - receivedStatus: receivedVector, - progressCallbacks: []*receivedCallbackTracker{}, - mux: sync.RWMutex{}, - kv: kvPrefixed, - } - - // Create new ReceivedTransfer - rt, err := NewReceivedTransfer(tid, key, mac, fileSize, numParts, numFps, kv) - if err != nil { - t.Fatalf("NewReceivedTransfer returned an error: %v", err) - } - - // Check that the new object matches the expected - if !reflect.DeepEqual(expected, rt) { - t.Errorf("New ReceivedTransfer does not match expected."+ - "\nexpected: %#v\nreceived: %#v", expected, rt) - } - - // Make sure it is saved to storage - _, err = kvPrefixed.Get(receivedTransferKey, receivedTransferVersion) - if err != nil { - t.Fatalf("Failed to load ReceivedTransfer from storage: %+v", err) - } - - // Check that the fingerprint vector has correct values - if rt.fpVector.GetNumAvailable() != uint32(numFps) { - t.Errorf("Incorrect number of available keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumAvailable()) - } - if rt.fpVector.GetNumKeys() != uint32(numFps) { - t.Errorf("Incorrect number of keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumKeys()) - } - if rt.fpVector.GetNumUsed() != 0 { - t.Errorf("Incorrect number of used keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", 0, rt.fpVector.GetNumUsed()) - } -} - -// Tests that ReceivedTransfer.GetTransferKey returns the expected key. -func TestReceivedTransfer_GetTransferKey(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - - tid, _ := ftCrypto.NewTransferID(prng) - mac := []byte("transferMAC") - expectedKey, _ := ftCrypto.NewTransferKey(prng) - - rt, err := NewReceivedTransfer(tid, expectedKey, mac, 256, 16, 1, kv) - if err != nil { - t.Errorf("Failed to create new ReceivedTransfer: %+v", err) - } - - if expectedKey != rt.GetTransferKey() { - t.Errorf("Failed to get expected transfer key."+ - "\nexpected: %s\nreceived: %s", expectedKey, rt.GetTransferKey()) - } -} - -// Tests that ReceivedTransfer.GetTransferMAC returns the expected bytes. -func TestReceivedTransfer_GetTransferMAC(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - - tid, _ := ftCrypto.NewTransferID(prng) - expectedMAC := []byte("transferMAC") - key, _ := ftCrypto.NewTransferKey(prng) - - rt, err := NewReceivedTransfer(tid, key, expectedMAC, 256, 16, 1, kv) - if err != nil { - t.Errorf("Failed to create new ReceivedTransfer: %+v", err) - } - - if !bytes.Equal(expectedMAC, rt.GetTransferMAC()) { - t.Errorf("Failed to get expected transfer MAC."+ - "\nexpected: %v\nreceived: %v", expectedMAC, rt.GetTransferMAC()) - } -} - -// Tests that ReceivedTransfer.GetNumParts returns the expected number of parts. -func TestReceivedTransfer_GetNumParts(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedNumParts := uint16(16) - _, rt, _ := newRandomReceivedTransfer(expectedNumParts, 20, kv, t) - - if expectedNumParts != rt.GetNumParts() { - t.Errorf("Failed to get expected number of parts."+ - "\nexpected: %d\nreceived: %d", expectedNumParts, rt.GetNumParts()) - } -} - -// Tests that ReceivedTransfer.GetNumFps returns the expected number of -// fingerprints. -func TestReceivedTransfer_GetNumFps(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedNumFps := uint16(20) - _, rt, _ := newRandomReceivedTransfer(16, expectedNumFps, kv, t) - - if expectedNumFps != rt.GetNumFps() { - t.Errorf("Failed to get expected number of fingerprints."+ - "\nexpected: %d\nreceived: %d", expectedNumFps, rt.GetNumFps()) - } -} - -// Tests that ReceivedTransfer.GetNumAvailableFps returns the expected number of -// available fingerprints. -func TestReceivedTransfer_GetNumAvailableFps(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts, numFps := uint16(16), uint16(24) - _, rt, _ := newRandomReceivedTransfer(numParts, numFps, kv, t) - - if numFps-numParts != rt.GetNumAvailableFps() { - t.Errorf("Failed to get expected number of available fingerprints."+ - "\nexpected: %d\nreceived: %d", - numFps-numParts, rt.GetNumAvailableFps()) - } -} - -// Tests that ReceivedTransfer.GetFileSize returns the expected file size. -func TestReceivedTransfer_GetFileSize(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, file := newRandomReceivedTransfer(16, 20, kv, t) - expectedFileSize := len(file) - - if expectedFileSize != int(rt.GetFileSize()) { - t.Errorf("Failed to get expected file size."+ - "\nexpected: %d\nreceived: %d", expectedFileSize, rt.GetFileSize()) - } -} - -// Tests that ReceivedTransfer.IsPartReceived returns false for unreceived file -// parts and true when the file part has been received. -func TestReceivedTransfer_IsPartReceived(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) - - partNum := uint16(5) - - if rt.IsPartReceived(partNum) { - t.Errorf("Part number %d received.", partNum) - } - - _ = rt.receivedParts.addPart([]byte("part"), partNum) - - if !rt.IsPartReceived(partNum) { - t.Errorf("Part number %d not received.", partNum) - } -} - -// checkReceivedProgress compares the output of ReceivedTransfer.GetProgress to -// expected values. -func checkReceivedProgress(completed bool, received, total uint16, - eCompleted bool, eReceived, eTotal uint16) error { - if eCompleted != completed || eReceived != received || eTotal != total { - return errors.Errorf("Returned progress does not match expected."+ - "\n completed received total"+ - "\nexpected: %5t %3d %3d"+ - "\nreceived: %5t %3d %3d", - eCompleted, eReceived, eTotal, - completed, received, total) - } - - return nil -} - -// checkReceivedTracker checks that the receivedPartTracker is reporting the -// correct values for each part. Also checks that -// receivedPartTracker.GetNumParts returns the expected value (make sure -// numParts comes from a correct source). -func checkReceivedTracker(track interfaces.FilePartTracker, numParts uint16, - received []uint16, t *testing.T) { - if track.GetNumParts() != numParts { - t.Errorf("Tracker reported incorrect number of parts."+ - "\nexpected: %d\nreceived: %d", numParts, track.GetNumParts()) - return - } - - for partNum := uint16(0); partNum < numParts; partNum++ { - var done bool - for _, receivedNum := range received { - if receivedNum == partNum { - if track.GetPartStatus(partNum) != interfaces.FpReceived { - t.Errorf("Part number %d has unexpected status."+ - "\nexpected: %d\nreceived: %d", partNum, - interfaces.FpReceived, track.GetPartStatus(partNum)) - } - done = true - break - } - } - if done { - continue - } - - if track.GetPartStatus(partNum) != interfaces.FpUnsent { - t.Errorf("Part number %d has incorrect status."+ - "\nexpected: %d\nreceived: %d", - partNum, interfaces.FpUnsent, track.GetPartStatus(partNum)) - } - } -} - -// Tests that ReceivedTransfer.GetProgress returns the expected progress metrics -// for various transfer states. -func TestReceivedTransfer_GetProgress(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - _, rt, _ := newEmptyReceivedTransfer(numParts, 20, kv, t) - - completed, received, total, track := rt.GetProgress() - err := checkReceivedProgress(completed, received, total, false, 0, numParts) - if err != nil { - t.Error(err) - } - checkReceivedTracker(track, rt.numParts, nil, t) - - _, _ = rt.fpVector.Next() - _, _ = rt.receivedStatus.Next() - - completed, received, total, track = rt.GetProgress() - err = checkReceivedProgress(completed, received, total, false, 1, numParts) - if err != nil { - t.Error(err) - } - checkReceivedTracker(track, rt.numParts, []uint16{0}, t) - - for i := 0; i < 4; i++ { - _, _ = rt.fpVector.Next() - _, _ = rt.receivedStatus.Next() - } - - completed, received, total, track = rt.GetProgress() - err = checkReceivedProgress(completed, received, total, false, 5, numParts) - if err != nil { - t.Error(err) - } - checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4}, t) - - for i := 0; i < 6; i++ { - _, _ = rt.fpVector.Next() - _, _ = rt.receivedStatus.Next() - } - - completed, received, total, track = rt.GetProgress() - err = checkReceivedProgress(completed, received, total, false, 11, numParts) - if err != nil { - t.Error(err) - } - checkReceivedTracker( - track, rt.numParts, []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, t) - - for i := 0; i < 4; i++ { - _, _ = rt.fpVector.Next() - _, _ = rt.receivedStatus.Next() - } - - completed, received, total, track = rt.GetProgress() - err = checkReceivedProgress(completed, received, total, false, 15, numParts) - if err != nil { - t.Error(err) - } - checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14}, t) - - _, _ = rt.fpVector.Next() - _, _ = rt.receivedStatus.Next() - - completed, received, total, track = rt.GetProgress() - err = checkReceivedProgress(completed, received, total, true, 16, numParts) - if err != nil { - t.Error(err) - } - checkReceivedTracker(track, rt.numParts, []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15}, t) -} - -// Tests that 5 different callbacks all receive the expected data when -// ReceivedTransfer.CallProgressCB is called at different stages of transfer. -func TestReceivedTransfer_CallProgressCB(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) - - type progressResults struct { - completed bool - received, total uint16 - tr interfaces.FilePartTracker - err error - } - - period := time.Millisecond - - wg := sync.WaitGroup{} - var step0, step1, step2, step3 uint64 - numCallbacks := 5 - - for i := 0; i < numCallbacks; i++ { - progressChan := make(chan progressResults) - - cbFunc := func(completed bool, received, total uint16, - tr interfaces.FilePartTracker, err error) { - progressChan <- progressResults{completed, received, total, tr, err} - } - wg.Add(1) - - go func(i int) { - defer wg.Done() - n := 0 - for { - select { - case <-time.NewTimer(time.Second).C: - t.Errorf("Timed out after %s waiting for callback (%d).", - period*5, i) - return - case r := <-progressChan: - switch n { - case 0: - if err := checkReceivedProgress(r.completed, r.received, - r.total, false, 0, rt.numParts); err != nil { - t.Errorf("%2d: %v", i, err) - } - atomic.AddUint64(&step0, 1) - case 1: - if err := checkReceivedProgress(r.completed, r.received, - r.total, false, 0, rt.numParts); err != nil { - t.Errorf("%2d: %v", i, err) - } - atomic.AddUint64(&step1, 1) - case 2: - if err := checkReceivedProgress(r.completed, r.received, - r.total, false, 4, rt.numParts); err != nil { - t.Errorf("%2d: %v", i, err) - } - atomic.AddUint64(&step2, 1) - case 3: - if err := checkReceivedProgress(r.completed, r.received, - r.total, true, 16, rt.numParts); err != nil { - t.Errorf("%2d: %v", i, err) - } - atomic.AddUint64(&step3, 1) - return - default: - t.Errorf("n (%d) is great than 3 (%d)", n, i) - return - } - n++ - } - } - }(i) - - rt.AddProgressCB(cbFunc, period) - } - - for !atomic.CompareAndSwapUint64(&step0, uint64(numCallbacks), 0) { - } - - rt.CallProgressCB(nil) - - for !atomic.CompareAndSwapUint64(&step1, uint64(numCallbacks), 0) { - } - - for i := 0; i < 4; i++ { - _, _ = rt.receivedStatus.Next() - } - - rt.CallProgressCB(nil) - - for !atomic.CompareAndSwapUint64(&step2, uint64(numCallbacks), 0) { - } - - for i := 0; i < 12; i++ { - _, _ = rt.receivedStatus.Next() - } - - rt.CallProgressCB(nil) - - wg.Wait() -} - -// Tests that ReceivedTransfer.stopScheduledProgressCB stops a scheduled -// callback from being triggered. -func TestReceivedTransfer_stopScheduledProgressCB(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) - - cbChan := make(chan struct{}, 5) - cbFunc := interfaces.ReceivedProgressCallback( - func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- struct{}{} - }) - rt.AddProgressCB(cbFunc, 150*time.Millisecond) - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") - case <-cbChan: - } - - rt.CallProgressCB(nil) - rt.CallProgressCB(nil) - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") - case <-cbChan: - } - - err := rt.stopScheduledProgressCB() - if err != nil { - t.Errorf("stopScheduledProgressCB returned an error: %+v", err) - } - - select { - case <-time.NewTimer(200 * time.Millisecond).C: - case <-cbChan: - t.Error("Callback called when it should have been stopped.") - } -} - -// Tests that ReceivedTransfer.AddProgressCB adds an item to the progress -// callback list. -func TestReceivedTransfer_AddProgressCB(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) - - type callbackResults struct { - completed bool - received, total uint16 - err error - } - cbChan := make(chan callbackResults) - cbFunc := interfaces.ReceivedProgressCallback( - func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- callbackResults{completed, received, total, err} - }) - - done := make(chan bool) - go func() { - select { - case <-time.NewTimer(time.Millisecond).C: - t.Error("Timed out waiting for progress callback to be called.") - case r := <-cbChan: - err := checkReceivedProgress( - r.completed, r.received, r.total, false, 0, 16) - if err != nil { - t.Error(err) - } - if r.err != nil { - t.Errorf("Callback returned an error: %+v", err) - } - } - done <- true - }() - - period := time.Millisecond - rt.AddProgressCB(cbFunc, period) - - if len(rt.progressCallbacks) != 1 { - t.Errorf("Callback list should only have one item."+ - "\nexpected: %d\nreceived: %d", 1, len(rt.progressCallbacks)) - } - - if rt.progressCallbacks[0].period != period { - t.Errorf("Callback has wrong lastCall.\nexpected: %s\nreceived: %s", - period, rt.progressCallbacks[0].period) - } - - if rt.progressCallbacks[0].lastCall != (time.Time{}) { - t.Errorf("Callback has wrong time.\nexpected: %s\nreceived: %s", - time.Time{}, rt.progressCallbacks[0].lastCall) - } - - if rt.progressCallbacks[0].scheduled { - t.Errorf("Callback has wrong scheduled.\nexpected: %t\nreceived: %t", - false, rt.progressCallbacks[0].scheduled) - } - <-done -} - -// Tests that ReceivedTransfer.AddPart adds a part in the correct place in the -// list of parts. -func TestReceivedTransfer_AddPart(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numFps := uint16(20) - _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - expectedData := []byte("test") - - partNum, fpNum := uint16(1), uint16(1) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(expectedData) - - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - complete, err := rt.AddPart(cmixMsg, fpNum) - if err != nil { - t.Errorf("AddPart returned an error: %+v", err) - } - - if complete { - t.Errorf("Transfer complete when it should not be.") - } - - receivedData, exists := rt.receivedParts.parts[partNum] - if !exists { - t.Errorf("Part #%d not found in part map.", partNum) - } else if !bytes.Equal(expectedData, receivedData[:len(expectedData)]) { - t.Fatalf("Part data in list does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedData, receivedData[:len(expectedData)]) - } - - // Check that the fingerprint vector has correct values - if rt.fpVector.GetNumAvailable() != uint32(numFps-1) { - t.Errorf("Incorrect number of available keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumAvailable()) - } - if rt.fpVector.GetNumKeys() != uint32(numFps) { - t.Errorf("Incorrect number of keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps, rt.fpVector.GetNumKeys()) - } - if rt.fpVector.GetNumUsed() != 1 { - t.Errorf("Incorrect number of used keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", 1, rt.fpVector.GetNumUsed()) - } - - // Check that the part was properly marked as received on the received - // status vector - if !rt.receivedStatus.Used(uint32(partNum)) { - t.Errorf("Part number %d not marked as used in received status vector.", - partNum) - } - if rt.receivedStatus.GetNumUsed() != 1 { - t.Errorf("Incorrect number of received parts in vector."+ - "\nexpected: %d\nreceived: %d", 1, rt.receivedStatus.GetNumUsed()) - } -} - -// Error path: tests that ReceivedTransfer.AddPart returns the expected error -// when the provided MAC is invalid. -func TestReceivedTransfer_AddPart_DecryptPartError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - expectedData := []byte("test") - - partNum, fpNum := uint16(1), uint16(1) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(expectedData) - - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - encryptedPart, _, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), fpNum, fp) - badMac := make([]byte, format.MacLen) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(badMac) - - // Add encrypted part - expectedErr := "reconstructed MAC from decrypting does not match MAC from sender" - _, err = rt.AddPart(cmixMsg, fpNum) - if err == nil || err.Error() != expectedErr { - t.Errorf("AddPart did not return the expected error when the MAC is "+ - "invalid.\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that ReceivedTransfer.GetFile returns the expected file data. -func TestReceivedTransfer_GetFile(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, expectedFile := newRandomReceivedTransfer(16, 20, kv, t) - - receivedFile, err := rt.GetFile() - if err != nil { - t.Errorf("GetFile returned an error: %+v", err) - } - - // Check that the received file is correct - if !bytes.Equal(expectedFile, receivedFile) { - t.Errorf("File does not match expected.\nexpected: %+v\nreceived: %+v", - expectedFile, receivedFile) - } -} - -// Error path: tests that ReceivedTransfer.GetFile returns the expected error -// when not all file parts are present. -func TestReceivedTransfer_GetFile_MissingPartsError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - _, rt, _ := newEmptyReceivedTransfer(numParts, 20, kv, t) - expectedErr := fmt.Sprintf(getFileErr, numParts, numParts) - - _, err := rt.GetFile() - if err == nil || err.Error() != expectedErr { - t.Errorf("GetFile failed to return the expected error when all file "+ - "parts are missing.\nexpected: %s\nreceived: %+v", - expectedErr, err) - } -} - -// Error path: tests that ReceivedTransfer.GetFile returns the expected error -// when the stored transfer MAC does not match the one generated from the file. -func TestReceivedTransfer_GetFile_InvalidMacError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) - - rt.transferMAC = []byte("invalidMAC") - - _, err := rt.GetFile() - if err == nil || err.Error() != getTransferMacErr { - t.Errorf("GetFile failed to return the expected error when the file "+ - "MAC cannot be verified.\nexpected: %s\nreceived: %+v", - getTransferMacErr, err) - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Function Testing // -//////////////////////////////////////////////////////////////////////////////// - -// Tests that loadReceivedTransfer returns a ReceivedTransfer that matches the -// original object in memory. -func Test_loadReceivedTransfer(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, expectedRT, _ := newRandomReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - expectedData := []byte("test") - partNum, fpNum := uint16(1), uint16(1) - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(expectedData) - - fp := ftCrypto.GenerateFingerprint(expectedRT.key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(expectedRT.key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - _, err = expectedRT.AddPart(cmixMsg, fpNum) - if err != nil { - t.Errorf("Failed to add test part: %+v", err) - } - - loadedRT, err := loadReceivedTransfer(tid, kv) - if err != nil { - t.Errorf("loadReceivedTransfer returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedRT, loadedRT) { - t.Errorf("Loaded ReceivedTransfer does not match expected"+ - ".\nexpected: %+v\nreceived: %+v", expectedRT, loadedRT) - } -} - -// Error path: tests that loadReceivedTransfer returns the expected error when -// no info is in storage to load. -func Test_loadReceivedTransfer_LoadInfoError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidTransferID")) - expectedErr := strings.Split(loadReceivedStoreErr, "%")[0] - - _, err := loadReceivedTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadReceivedTransfer did not return the expected error when "+ - "trying to load info from storage that does not exist."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadReceivedTransfer returns the expected error when -// the fingerprint state vector has been deleted from storage. -func Test_loadReceivedTransfer_LoadFpVectorError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - - data := []byte("test") - partNum, fpNum := uint16(1), uint16(1) - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(data) - - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - _, err = rt.AddPart(cmixMsg, fpNum) - if err != nil { - t.Errorf("Failed to add test part: %+v", err) - } - - // Delete fingerprint state vector from storage - err = rt.fpVector.Delete() - if err != nil { - t.Errorf("Failed to delete fingerprint vector: %+v", err) - } - - expectedErr := strings.Split(loadReceiveFpVectorErr, "%")[0] - _, err = loadReceivedTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadReceivedTransfer did not return the expected error when "+ - "the state vector was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadReceivedTransfer returns the expected error when -// the part store has been deleted from storage. -func Test_loadReceivedTransfer_LoadPartStoreError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - data := []byte("test") - partNum, fpNum := uint16(1), uint16(1) - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(data) - - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - _, err = rt.AddPart(cmixMsg, fpNum) - if err != nil { - t.Errorf("Failed to add test part: %+v", err) - } - - // Delete fingerprint state vector from storage - err = rt.receivedParts.delete() - if err != nil { - t.Errorf("Failed to delete part store: %+v", err) - } - - expectedErr := strings.Split(loadReceivePartStoreErr, "%")[0] - _, err = loadReceivedTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadReceivedTransfer did not return the expected error when "+ - "the part store was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadReceivedTransfer returns the expected error when -// the received status state vector has been deleted from storage. -func Test_loadReceivedTransfer_LoadReceivedVectorError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - data := []byte("test") - partNum, fpNum := uint16(1), uint16(1) - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(data) - - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - _, err = rt.AddPart(cmixMsg, fpNum) - if err != nil { - t.Errorf("Failed to add test part: %+v", err) - } - - // Delete fingerprint state vector from storage - err = rt.receivedStatus.Delete() - if err != nil { - t.Errorf("Failed to received status vector: %+v", err) - } - - expectedErr := strings.Split(loadReceivedVectorErr, "%")[0] - _, err = loadReceivedTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadReceivedTransfer did not return the expected error when "+ - "the received status vector was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that ReceivedTransfer.saveInfo saves the expected data to storage. -func TestReceivedTransfer_saveInfo(t *testing.T) { - rt := &ReceivedTransfer{ - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - transferMAC: []byte("transferMAC"), - numParts: 16, - kv: versioned.NewKV(make(ekv.Memstore)), - } - - err := rt.saveInfo() - if err != nil { - t.Fatalf("saveInfo returned an error: %v", err) - } - - vo, err := rt.kv.Get(receivedTransferKey, receivedTransferVersion) - if err != nil { - t.Fatalf("Failed to load ReceivedTransfer from storage: %+v", err) - } - - if !bytes.Equal(rt.marshal(), vo.Data) { - t.Errorf("Marshalled data loaded from storage does not match expected."+ - "\nexpected: %+v\nreceived: %+v", rt.marshal(), vo.Data) - } -} - -// Tests that ReceivedTransfer.loadInfo loads a saved ReceivedTransfer from -// storage. -func TestReceivedTransfer_loadInfo(t *testing.T) { - rt := &ReceivedTransfer{ - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - transferMAC: []byte("transferMAC"), - numParts: 16, - kv: versioned.NewKV(make(ekv.Memstore)), - } - - err := rt.saveInfo() - if err != nil { - t.Errorf("failed to save new ReceivedTransfer to storage: %+v", err) - } - - loadedRT := &ReceivedTransfer{kv: rt.kv} - err = loadedRT.loadInfo() - if err != nil { - t.Errorf("load returned an error: %+v", err) - } - - if !reflect.DeepEqual(rt, loadedRT) { - t.Errorf("Loaded ReceivedTransfer does not match expected."+ - "\nexpected: %+v\nreceived: %+v", rt, loadedRT) - } -} - -// Error path: tests that ReceivedTransfer.loadInfo returns an error when there -// is no object in storage to load -func TestReceivedTransfer_loadInfo_Error(t *testing.T) { - loadedRT := &ReceivedTransfer{kv: versioned.NewKV(make(ekv.Memstore))} - err := loadedRT.loadInfo() - if err == nil { - t.Errorf("Loaded object that should not be in storage: %+v", err) - } -} - -// Tests that ReceivedTransfer.delete removes all data from storage. -func TestReceivedTransfer_delete(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newRandomReceivedTransfer(16, 20, kv, t) - - // Create encrypted part - expectedData := []byte("test") - partNum, fpNum := uint16(1), uint16(1) - - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(expectedData) - - fp := ftCrypto.GenerateFingerprint(rt.key, fpNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), fpNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - // Add encrypted part - _, err = rt.AddPart(cmixMsg, fpNum) - if err != nil { - t.Fatalf("Failed to add test part: %+v", err) - } - - // Delete everything from storage - err = rt.delete() - if err != nil { - t.Errorf("delete returned an error: %+v", err) - } - - // Check that the SentTransfer info was deleted - err = rt.loadInfo() - if err == nil { - t.Error("Successfully loaded SentTransfer info from storage when it " + - "should have been deleted.") - } - - // Check that the parts store were deleted - _, err = loadPartStore(rt.kv) - if err == nil { - t.Error("Successfully loaded file parts from storage when it should " + - "have been deleted.") - } - - // Check that the fingerprint vector was deleted - _, err = utility.LoadStateVector(rt.kv, receivedFpVectorKey) - if err == nil { - t.Error("Successfully loaded fingerprint vector from storage when it " + - "should have been deleted.") - } - - // Check that the received status vector was deleted - _, err = utility.LoadStateVector(rt.kv, receivedVectorKey) - if err == nil { - t.Error("Successfully loaded received status vector from storage when " + - "it should have been deleted.") - } -} - -// Tests that ReceivedTransfer.deleteInfo removes the saved ReceivedTransfer -// data from storage. -func TestReceivedTransfer_deleteInfo(t *testing.T) { - rt := &ReceivedTransfer{ - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - transferMAC: []byte("transferMAC"), - numParts: 16, - kv: versioned.NewKV(make(ekv.Memstore)), - } - - // Save from storage - err := rt.saveInfo() - if err != nil { - t.Errorf("failed to save new ReceivedTransfer to storage: %+v", err) - } - - // Delete from storage - err = rt.deleteInfo() - if err != nil { - t.Errorf("deleteInfo returned an error: %+v", err) - } - - // Make sure deleted object cannot be loaded from storage - _, err = rt.kv.Get(receivedTransferKey, receivedTransferVersion) - if err == nil { - t.Error("Loaded object that should be deleted from storage.") - } -} - -// Tests that a ReceivedTransfer marshalled with ReceivedTransfer.marshal and -// then unmarshalled with unmarshalReceivedTransfer matches the original. -func TestReceivedTransfer_marshal_unmarshalReceivedTransfer(t *testing.T) { - rt := &ReceivedTransfer{ - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - transferMAC: []byte("transferMAC"), - fileSize: 256, - numParts: 16, - numFps: 20, - } - - marshaledData := rt.marshal() - key, mac, fileSize, numParts, numFps := unmarshalReceivedTransfer(marshaledData) - - if rt.key != key { - t.Errorf("Failed to get expected key.\nexpected: %s\nreceived: %s", - rt.key, key) - } - - if !bytes.Equal(rt.transferMAC, mac) { - t.Errorf("Failed to get expected transfer MAC."+ - "\nexpected: %v\nreceived: %v", rt.transferMAC, mac) - } - - if rt.fileSize != fileSize { - t.Errorf("Failed to get expected file size."+ - "\nexpected: %d\nreceived: %d", rt.fileSize, fileSize) - } - - if rt.numParts != numParts { - t.Errorf("Failed to get expected number of parts."+ - "\nexpected: %d\nreceived: %d", rt.numParts, numParts) - } - - if rt.numFps != numFps { - t.Errorf("Failed to get expected number of fingerprints."+ - "\nexpected: %d\nreceived: %d", rt.numFps, numFps) - } -} - -// Consistency test: tests that makeReceivedTransferPrefix returns the expected -// prefixes for the provided transfer IDs. -func Test_makeReceivedTransferPrefix_Consistency(t *testing.T) { - prng := NewPrng(42) - expectedPrefixes := []string{ - "FileTransferReceivedTransferStoreU4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=", - "FileTransferReceivedTransferStore39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g=", - "FileTransferReceivedTransferStoreCD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw=", - "FileTransferReceivedTransferStoreuoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44=", - "FileTransferReceivedTransferStoreGwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM=", - "FileTransferReceivedTransferStorernvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA=", - "FileTransferReceivedTransferStoreceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE=", - "FileTransferReceivedTransferStoreSYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE=", - "FileTransferReceivedTransferStoreNhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI=", - "FileTransferReceivedTransferStorekM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog=", - } - - for i, expected := range expectedPrefixes { - tid, _ := ftCrypto.NewTransferID(prng) - prefix := makeReceivedTransferPrefix(tid) - - if expected != prefix { - t.Errorf("New ReceivedTransfer prefix does not match expected (%d)."+ - "\nexpected: %s\nreceived: %s", i, expected, prefix) - } - } -} - -// newRandomReceivedTransfer generates a new ReceivedTransfer with random data. -// Returns the generated transfer ID, new ReceivedTransfer, and the full file. -func newRandomReceivedTransfer(numParts, numFps uint16, kv *versioned.KV, - t *testing.T) (ftCrypto.TransferID, *ReceivedTransfer, []byte) { - - prng := NewPrng(42) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - parts, fileData := newRandomPartStore(numParts, kv, prng, t) - fileSize := uint32(len(fileData)) - mac := ftCrypto.CreateTransferMAC(fileData, key) - - rt, err := NewReceivedTransfer(tid, key, mac, fileSize, numParts, numFps, kv) - if err != nil { - t.Errorf("Failed to create new ReceivedTransfer: %+v", err) - } - - for partNum, part := range parts.parts { - cmixMsg := format.NewMessage(format.MinimumPrimeSize) - - partData, _ := NewPartMessage(cmixMsg.ContentsSize()) - partData.SetPartNum(partNum) - _ = partData.SetPart(part) - - fp := ftCrypto.GenerateFingerprint(rt.key, partNum) - encryptedPart, mac, err := ftCrypto.EncryptPart(rt.key, partData.Marshal(), partNum, fp) - - cmixMsg.SetKeyFP(fp) - cmixMsg.SetContents(encryptedPart) - cmixMsg.SetMac(mac) - - _, err = rt.AddPart(cmixMsg, partNum) - if err != nil { - t.Errorf("Failed to add part #%d: %+v", partNum, err) - } - } - - return tid, rt, fileData -} - -// newRandomReceivedTransfer generates a new empty ReceivedTransfer. Returns the -// generated transfer ID, new ReceivedTransfer, and the full file. -func newEmptyReceivedTransfer(numParts, numFps uint16, kv *versioned.KV, - t *testing.T) (ftCrypto.TransferID, *ReceivedTransfer, []byte) { - - prng := NewPrng(42) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - _, fileData := newRandomPartStore(numParts, kv, prng, t) - fileSize := uint32(len(fileData)) - - mac := ftCrypto.CreateTransferMAC(fileData, key) - - rt, err := NewReceivedTransfer(tid, key, mac, fileSize, numParts, numFps, kv) - if err != nil { - t.Errorf("Failed to create new ReceivedTransfer: %+v", err) - } - - return tid, rt, fileData -} diff --git a/storage/fileTransfer/receivedCallbackTracker.go b/storage/fileTransfer/receivedCallbackTracker.go deleted file mode 100644 index 61e850c5b918ab97e427fe5dc6e75181a6508efb..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receivedCallbackTracker.go +++ /dev/null @@ -1,137 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/stoppable" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "sync/atomic" - "time" -) - -// receivedCallbackTrackerStoppable is the name used for the tracker stoppable. -const receivedCallbackTrackerStoppable = "receivedCallbackTrackerStoppable" - -// receivedCallbackTracker tracks the interfaces.ReceivedProgressCallback and -// information on when to call it. The callback will be called on each file part -// reception, unless the time since the lastCall is smaller than the period. In -// that case, a callback is marked as scheduled and waits to be called at the -// end of the period. A callback is called once every period, regardless of the -// number of receptions that occur. -type receivedCallbackTracker struct { - period time.Duration // How often to call the callback - lastCall time.Time // Timestamp of the last call - scheduled bool // Denotes if callback call is scheduled - completed uint64 // Atomic that tells if transfer is completed - stop *stoppable.Single // Stops the scheduled callback from triggering - cb interfaces.ReceivedProgressCallback - mux sync.RWMutex -} - -// newReceivedCallbackTracker creates a new and unused receivedCallbackTracker. -func newReceivedCallbackTracker(cb interfaces.ReceivedProgressCallback, - period time.Duration) *receivedCallbackTracker { - return &receivedCallbackTracker{ - period: period, - lastCall: time.Time{}, - scheduled: false, - completed: 0, - stop: stoppable.NewSingle(receivedCallbackTrackerStoppable), - cb: cb, - } -} - -// call triggers the progress callback with the most recent progress from the -// receivedProgressTracker. If a callback has been called within the last -// period, then a new call is scheduled to occur at the beginning of the next -// period. If a call is already scheduled, then nothing happens; when the -// callback is finally called, it will do so with the most recent changes. -func (rct *receivedCallbackTracker) call(tracker receivedProgressTracker, err error) { - rct.mux.RLock() - // Exit if a callback is already scheduled - if rct.scheduled || atomic.LoadUint64(&rct.completed) == 1 { - rct.mux.RUnlock() - return - } - - rct.mux.RUnlock() - rct.mux.Lock() - defer rct.mux.Unlock() - - if rct.scheduled { - return - } - - // Check if a callback has occurred within the last period - timeSinceLastCall := netTime.Since(rct.lastCall) - if timeSinceLastCall > rct.period { - // If no callback occurred, then trigger the callback now - rct.callNowUnsafe(false, tracker, err) - rct.lastCall = netTime.Now() - } else { - // If a callback did occur, then schedule a new callback to occur at the - // start of the next period - rct.scheduled = true - go func() { - select { - case <-rct.stop.Quit(): - rct.stop.ToStopped() - return - case <-time.NewTimer(rct.period - timeSinceLastCall).C: - rct.mux.Lock() - rct.callNow(false, tracker, err) - rct.lastCall = netTime.Now() - rct.scheduled = false - rct.mux.Unlock() - } - }() - } -} - -// stopThread stops all scheduled callbacks. -func (rct *receivedCallbackTracker) stopThread() error { - return rct.stop.Close() -} - -// callNow calls the callback immediately regardless of the schedule or period. -func (rct *receivedCallbackTracker) callNow(skipCompletedCheck bool, - tracker receivedProgressTracker, err error) { - completed, received, total, t := tracker.GetProgress() - if skipCompletedCheck || !completed || - atomic.CompareAndSwapUint64(&rct.completed, 0, 1) { - go rct.cb(completed, received, total, t, err) - } -} - -// callNowUnsafe calls the callback immediately regardless of the schedule or -// period without taking a thread lock. This function should be used if a lock -// is already taken on the receivedProgressTracker. -func (rct *receivedCallbackTracker) callNowUnsafe(skipCompletedCheck bool, - tracker receivedProgressTracker, err error) { - completed, received, total, t := tracker.getProgress() - if skipCompletedCheck || !completed || - atomic.CompareAndSwapUint64(&rct.completed, 0, 1) { - go rct.cb(completed, received, total, t, err) - } -} - -// receivedProgressTracker interface tracks the progress of a transfer. -type receivedProgressTracker interface { - // GetProgress returns the received transfer progress in a thread-safe - // manner. - GetProgress() ( - completed bool, received, total uint16, t interfaces.FilePartTracker) - - // getProgress returns the received transfer progress in a thread-unsafe - // manner. This function should be used if a lock is already taken on the - // sent transfer. - getProgress() ( - completed bool, received, total uint16, t interfaces.FilePartTracker) -} diff --git a/storage/fileTransfer/receivedCallbackTracker_test.go b/storage/fileTransfer/receivedCallbackTracker_test.go deleted file mode 100644 index ed5f23390ec4fe4ec01bb16a0125a3070d586111..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receivedCallbackTracker_test.go +++ /dev/null @@ -1,202 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "reflect" - "testing" - "time" -) - -// Tests that newReceivedCallbackTracker returns the expected -// receivedCallbackTracker and that the callback triggers correctly. -func Test_newReceivedCallbackTracker(t *testing.T) { - type cbFields struct { - completed bool - received, total uint16 - err error - } - - cbChan := make(chan cbFields) - cbFunc := func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{completed, received, total, err} - } - - expectedRCT := &receivedCallbackTracker{ - period: time.Millisecond, - lastCall: time.Time{}, - scheduled: false, - cb: cbFunc, - } - - receivedRCT := newReceivedCallbackTracker(expectedRCT.cb, expectedRCT.period) - - go receivedRCT.cb(false, 0, 0, nil, nil) - - select { - case <-time.NewTimer(time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkReceivedProgress(r.completed, r.received, r.total, false, 0, 0) - if err != nil { - t.Error(err) - } - } - - // Nil the callbacks so that DeepEqual works - receivedRCT.cb = nil - expectedRCT.cb = nil - - receivedRCT.stop = expectedRCT.stop - - if !reflect.DeepEqual(expectedRCT, receivedRCT) { - t.Errorf("New receivedCallbackTracker does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedRCT, receivedRCT) - } -} - -// Tests that receivedCallbackTracker.call calls the tracker immediately when -// no other calls are scheduled and that it schedules a call to the tracker when -// one has been called recently. -func Test_receivedCallbackTracker_call(t *testing.T) { - type cbFields struct { - completed bool - received, total uint16 - err error - } - - cbChan := make(chan cbFields) - cbFunc := func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{completed, received, total, err} - } - - rct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) - - tracker := testReceiveTrack{false, 1, 3, receivedPartTracker{}} - rct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkReceivedProgress(r.completed, r.received, r.total, false, 1, 3) - if err != nil { - t.Error(err) - } - } - - tracker = testReceiveTrack{true, 3, 3, receivedPartTracker{}} - rct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - if !rct.scheduled { - t.Error("Callback should be scheduled.") - } - case r := <-cbChan: - t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", rct.period, r) - } - - rct.call(tracker, nil) - - select { - case <-time.NewTimer(60 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkReceivedProgress(r.completed, r.received, r.total, true, 3, 3) - if err != nil { - t.Error(err) - } - } -} - -// Tests that receivedCallbackTracker.stopThread prevents a scheduled call to -// the tracker from occurring. -func Test_receivedCallbackTracker_stopThread(t *testing.T) { - type cbFields struct { - completed bool - received, total uint16 - err error - } - - cbChan := make(chan cbFields) - cbFunc := func(completed bool, received, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{completed, received, total, err} - } - - rct := newReceivedCallbackTracker(cbFunc, 50*time.Millisecond) - - tracker := testReceiveTrack{false, 1, 3, receivedPartTracker{}} - rct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkReceivedProgress(r.completed, r.received, r.total, false, 1, 3) - if err != nil { - t.Error(err) - } - } - - tracker = testReceiveTrack{true, 3, 3, receivedPartTracker{}} - rct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - if !rct.scheduled { - t.Error("Callback should be scheduled.") - } - case r := <-cbChan: - t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", rct.period, r) - } - - rct.call(tracker, nil) - - err := rct.stopThread() - if err != nil { - t.Errorf("stopThread returned an error: %+v", err) - } - - select { - case <-time.NewTimer(60 * time.Millisecond).C: - case r := <-cbChan: - t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", rct.period, r) - } -} - -// Tests that ReceivedTransfer satisfies the receivedProgressTracker interface. -func TestReceivedTransfer_ReceivedProgressTrackerInterface(t *testing.T) { - var _ receivedProgressTracker = &ReceivedTransfer{} -} - -// testReceiveTrack is a test structure that satisfies the -// receivedProgressTracker interface. -type testReceiveTrack struct { - completed bool - received, total uint16 - t receivedPartTracker -} - -func (trt testReceiveTrack) getProgress() (completed bool, received, - total uint16, t interfaces.FilePartTracker) { - return trt.completed, trt.received, trt.total, trt.t -} - -// GetProgress returns the values in the testTrack. -func (trt testReceiveTrack) GetProgress() (completed bool, received, - total uint16, t interfaces.FilePartTracker) { - return trt.completed, trt.received, trt.total, trt.t -} diff --git a/storage/fileTransfer/receivedPartTracker.go b/storage/fileTransfer/receivedPartTracker.go deleted file mode 100644 index b509b39479dc2396742334af72346971915a5416..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receivedPartTracker.go +++ /dev/null @@ -1,48 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/utility" -) - -// receivedPartTracker tracks the status of individual received file parts. -type receivedPartTracker struct { - // The number of file parts in the file - numParts uint16 - - // Stores the received status for each file part in a bitstream format - receivedStatus *utility.StateVector -} - -// newReceivedPartTracker creates a new receivedPartTracker with copies of the -// received status state vectors. -func newReceivedPartTracker(received *utility.StateVector) receivedPartTracker { - return receivedPartTracker{ - numParts: uint16(received.GetNumKeys()), - receivedStatus: received.DeepCopy(), - } -} - -// GetPartStatus returns the status of the received file part with the given -// part number. The possible values for the status are: -// 0 = unreceived -// 3 = received (receiver has received a part) -func (rpt receivedPartTracker) GetPartStatus(partNum uint16) interfaces.FpStatus { - if rpt.receivedStatus.Used(uint32(partNum)) { - return interfaces.FpReceived - } else { - return interfaces.FpUnsent - } -} - -// GetNumParts returns the total number of file parts in the transfer. -func (rpt receivedPartTracker) GetNumParts() uint16 { - return rpt.numParts -} diff --git a/storage/fileTransfer/receivedPartTracker_test.go b/storage/fileTransfer/receivedPartTracker_test.go deleted file mode 100644 index 64f390058bda927f67d743d78e58cb15cb52f923..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/receivedPartTracker_test.go +++ /dev/null @@ -1,90 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/ekv" - "math/rand" - "reflect" - "testing" -) - -// Tests that receivedPartTracker satisfies the interfaces.FilePartTracker -// interface. -func TestReceivedPartTracker_FilePartTrackerInterface(t *testing.T) { - var _ interfaces.FilePartTracker = receivedPartTracker{} -} - -// Tests that newReceivedPartTracker returns a new receivedPartTracker with the -// expected values. -func Test_newReceivedPartTracker(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newRandomReceivedTransfer(16, 24, kv, t) - - expected := receivedPartTracker{ - numParts: rt.numParts, - receivedStatus: rt.receivedStatus.DeepCopy(), - } - - newRPT := newReceivedPartTracker(rt.receivedStatus) - - if !reflect.DeepEqual(expected, newRPT) { - t.Errorf("New receivedPartTracker does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expected, newRPT) - } -} - -// Tests that receivedPartTracker.GetPartStatus returns the expected status for -// each part loaded from a preconfigured ReceivedTransfer. -func TestReceivedPartTracker_GetPartStatus(t *testing.T) { - // Create new ReceivedTransfer - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 24, kv, t) - - // Set statuses of parts in the ReceivedTransfer and a map randomly - prng := rand.New(rand.NewSource(42)) - partStatuses := make(map[uint16]interfaces.FpStatus, rt.numParts) - for partNum := uint16(0); partNum < rt.numParts; partNum++ { - partStatuses[partNum] = - interfaces.FpStatus(prng.Intn(2)) * interfaces.FpReceived - - if partStatuses[partNum] == interfaces.FpReceived { - rt.receivedStatus.Use(uint32(partNum)) - } - } - - // Create a new receivedPartTracker from the ReceivedTransfer - rpt := newReceivedPartTracker(rt.receivedStatus) - - // Check that the statuses for each part matches the map - for partNum := uint16(0); partNum < rt.numParts; partNum++ { - if rpt.GetPartStatus(partNum) != partStatuses[partNum] { - t.Errorf("Part number %d does not have expected status."+ - "\nexpected: %d\nreceived: %d", - partNum, partStatuses[partNum], rpt.GetPartStatus(partNum)) - } - } -} - -// Tests that receivedPartTracker.GetNumParts returns the same number of parts -// as the receivedPartTracker it was created from. -func TestReceivedPartTracker_GetNumParts(t *testing.T) { - // Create new ReceivedTransfer - kv := versioned.NewKV(make(ekv.Memstore)) - _, rt, _ := newEmptyReceivedTransfer(16, 24, kv, t) - - // Create a new receivedPartTracker from the ReceivedTransfer - rpt := newReceivedPartTracker(rt.receivedStatus) - - if rpt.GetNumParts() != rt.GetNumParts() { - t.Errorf("Number of parts incorrect.\nexpected: %d\nreceived: %d", - rt.GetNumParts(), rpt.GetNumParts()) - } -} diff --git a/storage/fileTransfer/sentCallbackTracker.go b/storage/fileTransfer/sentCallbackTracker.go deleted file mode 100644 index 7381422c0cf98f053e47896ae00831bd8eff1cb3..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentCallbackTracker.go +++ /dev/null @@ -1,136 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/stoppable" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "sync/atomic" - "time" -) - -// sentCallbackTrackerStoppable is the name used for the tracker stoppable. -const sentCallbackTrackerStoppable = "sentCallbackTrackerStoppable" - -// sentCallbackTracker tracks the interfaces.SentProgressCallback and -// information on when to call it. The callback will be called on each send, -// unless the time since the lastCall is smaller than the period. In that case, -// a callback is marked as scheduled and waits to be called at the end of the -// period. A callback is called once every period, regardless of the number of -// sends that occur. -type sentCallbackTracker struct { - period time.Duration // How often to call the callback - lastCall time.Time // Timestamp of the last call - scheduled bool // Denotes if callback call is scheduled - completed uint64 // Atomic that tells if transfer is completed - stop *stoppable.Single // Stops the scheduled callback from triggering - cb interfaces.SentProgressCallback - mux sync.RWMutex -} - -// newSentCallbackTracker creates a new and unused sentCallbackTracker. -func newSentCallbackTracker(cb interfaces.SentProgressCallback, - period time.Duration) *sentCallbackTracker { - return &sentCallbackTracker{ - period: period, - lastCall: time.Time{}, - scheduled: false, - completed: 0, - stop: stoppable.NewSingle(sentCallbackTrackerStoppable), - cb: cb, - } -} - -// call triggers the progress callback with the most recent progress from the -// sentProgressTracker. If a callback has been called within the last period, -// then a new call is scheduled to occur at the beginning of the next period. If -// a call is already scheduled, then nothing happens; when the callback is -// finally called, it will do so with the most recent changes. -func (sct *sentCallbackTracker) call(tracker sentProgressTracker, err error) { - sct.mux.RLock() - // Exit if a callback is already scheduled - if sct.scheduled || atomic.LoadUint64(&sct.completed) == 1 { - sct.mux.RUnlock() - return - } - - sct.mux.RUnlock() - sct.mux.Lock() - defer sct.mux.Unlock() - - if sct.scheduled { - return - } - - // Check if a callback has occurred within the last period - timeSinceLastCall := netTime.Since(sct.lastCall) - if timeSinceLastCall > sct.period { - // If no callback occurred, then trigger the callback now - sct.callNowUnsafe(false, tracker, err) - sct.lastCall = netTime.Now() - } else { - // If a callback did occur, then schedule a new callback to occur at the - // start of the next period - sct.scheduled = true - go func() { - select { - case <-sct.stop.Quit(): - sct.stop.ToStopped() - return - case <-time.NewTimer(sct.period - timeSinceLastCall).C: - sct.mux.Lock() - sct.callNow(false, tracker, err) - sct.lastCall = netTime.Now() - sct.scheduled = false - sct.mux.Unlock() - } - }() - } -} - -// stopThread stops all scheduled callbacks. -func (sct *sentCallbackTracker) stopThread() error { - return sct.stop.Close() -} - -// callNow calls the callback immediately regardless of the schedule or period. -func (sct *sentCallbackTracker) callNow(skipCompletedCheck bool, - tracker sentProgressTracker, err error) { - completed, sent, arrived, total, t := tracker.GetProgress() - if skipCompletedCheck || !completed || - atomic.CompareAndSwapUint64(&sct.completed, 0, 1) { - go sct.cb(completed, sent, arrived, total, t, err) - } -} - -// callNowUnsafe calls the callback immediately regardless of the schedule or -// period without taking a thread lock. This function should be used if a lock -// is already taken on the sentProgressTracker. -func (sct *sentCallbackTracker) callNowUnsafe(skipCompletedCheck bool, - tracker sentProgressTracker, err error) { - completed, sent, arrived, total, t := tracker.getProgress() - if skipCompletedCheck || !completed || - atomic.CompareAndSwapUint64(&sct.completed, 0, 1) { - go sct.cb(completed, sent, arrived, total, t, err) - } -} - -// sentProgressTracker interface tracks the progress of a transfer. -type sentProgressTracker interface { - // GetProgress returns the sent transfer progress in a thread-safe manner. - GetProgress() ( - completed bool, sent, arrived, total uint16, t interfaces.FilePartTracker) - - // getProgress returns the sent transfer progress in a thread-unsafe manner. - // This function should be used if a lock is already taken on the sent - // transfer. - getProgress() ( - completed bool, sent, arrived, total uint16, t interfaces.FilePartTracker) -} diff --git a/storage/fileTransfer/sentCallbackTracker_test.go b/storage/fileTransfer/sentCallbackTracker_test.go deleted file mode 100644 index fc445d11475cf95d41b843e01f64381829577892..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentCallbackTracker_test.go +++ /dev/null @@ -1,208 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/stoppable" - "reflect" - "testing" - "time" -) - -// Tests that newSentCallbackTracker returns the expected sentCallbackTracker -// and that the callback triggers correctly. -func Test_newSentCallbackTracker(t *testing.T) { - type cbFields struct { - completed bool - sent, arrived, total uint16 - err error - } - - cbChan := make(chan cbFields) - cbFunc := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{completed, sent, arrived, total, err} - } - - expectedSCT := &sentCallbackTracker{ - period: time.Millisecond, - lastCall: time.Time{}, - scheduled: false, - stop: stoppable.NewSingle(sentCallbackTrackerStoppable), - cb: cbFunc, - } - - receivedSCT := newSentCallbackTracker(expectedSCT.cb, expectedSCT.period) - - go receivedSCT.cb(false, 0, 0, 0, sentPartTracker{}, nil) - - select { - case <-time.NewTimer(time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkSentProgress( - r.completed, r.sent, r.arrived, r.total, false, 0, 0, 0) - if err != nil { - t.Error(err) - } - } - - // Nil the callbacks so that DeepEqual works - receivedSCT.cb = nil - expectedSCT.cb = nil - - receivedSCT.stop = expectedSCT.stop - - if !reflect.DeepEqual(expectedSCT, receivedSCT) { - t.Errorf("New sentCallbackTracker does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedSCT, receivedSCT) - } -} - -// Tests that sentCallbackTracker.call calls the tracker immediately when no -// other calls are scheduled and that it schedules a call to the tracker when -// one has been called recently. -func Test_sentCallbackTracker_call(t *testing.T) { - type cbFields struct { - completed bool - sent, arrived, total uint16 - err error - } - - cbChan := make(chan cbFields) - cbFunc := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{completed, sent, arrived, total, err} - } - - sct := newSentCallbackTracker(cbFunc, 50*time.Millisecond) - - tracker := testSentTrack{false, 1, 2, 3, sentPartTracker{}} - sct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkSentProgress( - r.completed, r.sent, r.arrived, r.total, false, 1, 2, 3) - if err != nil { - t.Error(err) - } - } - - tracker = testSentTrack{false, 1, 2, 3, sentPartTracker{}} - sct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - if !sct.scheduled { - t.Error("Callback should be scheduled.") - } - case r := <-cbChan: - t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", sct.period, r) - } - - sct.call(tracker, nil) - - select { - case <-time.NewTimer(60 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkSentProgress( - r.completed, r.sent, r.arrived, r.total, false, 1, 2, 3) - if err != nil { - t.Error(err) - } - } -} - -// Tests that sentCallbackTracker.stopThread prevents a scheduled call to the -// tracker from occurring. -func Test_sentCallbackTracker_stopThread(t *testing.T) { - type cbFields struct { - completed bool - sent, arrived, total uint16 - err error - } - - cbChan := make(chan cbFields) - cbFunc := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{completed, sent, arrived, total, err} - } - - sct := newSentCallbackTracker(cbFunc, 50*time.Millisecond) - - tracker := testSentTrack{false, 1, 2, 3, sentPartTracker{}} - sct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback to be called.") - case r := <-cbChan: - err := checkSentProgress( - r.completed, r.sent, r.arrived, r.total, false, 1, 2, 3) - if err != nil { - t.Error(err) - } - } - - tracker = testSentTrack{false, 1, 2, 3, sentPartTracker{}} - sct.call(tracker, nil) - - select { - case <-time.NewTimer(10 * time.Millisecond).C: - if !sct.scheduled { - t.Error("Callback should be scheduled.") - } - case r := <-cbChan: - t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", sct.period, r) - } - - sct.call(tracker, nil) - - err := sct.stopThread() - if err != nil { - t.Errorf("stopThread returned an error: %+v", err) - } - - select { - case <-time.NewTimer(60 * time.Millisecond).C: - case r := <-cbChan: - t.Errorf("Received message when period of %s should not have been "+ - "reached: %+v", sct.period, r) - } -} - -// Tests that SentTransfer satisfies the sentProgressTracker interface. -func TestSentTransfer_SentProgressTrackerInterface(t *testing.T) { - var _ sentProgressTracker = &SentTransfer{} -} - -// testSentTrack is a test structure that satisfies the sentProgressTracker -// interface. -type testSentTrack struct { - completed bool - sent, arrived, total uint16 - t sentPartTracker -} - -func (tst testSentTrack) getProgress() (completed bool, sent, arrived, - total uint16, t interfaces.FilePartTracker) { - return tst.completed, tst.sent, tst.arrived, tst.total, tst.t -} - -// GetProgress returns the values in the testTrack. -func (tst testSentTrack) GetProgress() (completed bool, sent, arrived, - total uint16, t interfaces.FilePartTracker) { - return tst.completed, tst.sent, tst.arrived, tst.total, tst.t -} diff --git a/storage/fileTransfer/sentFileTransfers.go b/storage/fileTransfer/sentFileTransfers.go deleted file mode 100644 index eb14f0e8b0a33e090e9ebed9810ac9db58a1aa1a..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentFileTransfers.go +++ /dev/null @@ -1,354 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/xx_network/crypto/csprng" - "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "time" -) - -// Storage keys and versions. -const ( - sentFileTransfersStorePrefix = "SentFileTransfersStore" - sentFileTransfersStoreKey = "SentFileTransfers" - sentFileTransfersStoreVersion = 0 -) - -// Error messages. -const ( - saveSentTransfersListErr = "failed to save list of sent items in transfer map to storage: %+v" - loadSentTransfersListErr = "failed to load list of sent items in transfer map from storage: %+v" - loadSentTransfersErr = "[FT] Failed to load sent transfers from storage: %+v" - - newSentTransferErr = "failed to create new sent transfer: %+v" - getSentTransferErr = "sent file transfer not found" - cancelCallbackErr = "[FT] Transfer with ID %s: %+v" - deleteSentTransferErr = "failed to delete sent transfer with ID %s from store: %+v" - - // SentFileTransfersStore.loadTransfers - loadSentTransferWarn = "[FT] Failed to load sent file transfer %d of %d with ID %s: %v" - loadSentTransfersAllErr = "failed to load all %d transfers" -) - -// SentFileTransfersStore contains information for tracking sent file transfers. -type SentFileTransfersStore struct { - transfers map[ftCrypto.TransferID]*SentTransfer - mux sync.Mutex - kv *versioned.KV -} - -// NewSentFileTransfersStore creates a new SentFileTransfersStore with an empty -// map. -func NewSentFileTransfersStore(kv *versioned.KV) (*SentFileTransfersStore, error) { - sft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - - return sft, sft.saveTransfersList() -} - -// AddTransfer creates a new empty SentTransfer and adds it to the transfers -// map. -func (sft *SentFileTransfersStore) AddTransfer(recipient *id.ID, - key ftCrypto.TransferKey, parts [][]byte, numFps uint16, - progressCB interfaces.SentProgressCallback, period time.Duration, - rng csprng.Source) (ftCrypto.TransferID, error) { - - sft.mux.Lock() - defer sft.mux.Unlock() - - // Generate new transfer ID - tid, err := ftCrypto.NewTransferID(rng) - if err != nil { - return tid, errors.Errorf(addTransferNewIdErr, err) - } - - // Generate a new SentTransfer and add it to the map - sft.transfers[tid], err = NewSentTransfer( - recipient, tid, key, parts, numFps, progressCB, period, sft.kv) - if err != nil { - return tid, errors.Errorf(newSentTransferErr, err) - } - - // Update list of transfers in storage - err = sft.saveTransfersList() - if err != nil { - return tid, errors.Errorf(saveSentTransfersListErr, err) - } - - return tid, nil -} - -// GetTransfer returns the SentTransfer with the given transfer ID. An error is -// returned if no corresponding transfer is found. -func (sft *SentFileTransfersStore) GetTransfer(tid ftCrypto.TransferID) ( - *SentTransfer, error) { - sft.mux.Lock() - defer sft.mux.Unlock() - - rt, exists := sft.transfers[tid] - if !exists { - return nil, errors.New(getSentTransferErr) - } - - return rt, nil -} - -// DeleteTransfer removes the SentTransfer with the associated transfer ID -// from memory and storage. -func (sft *SentFileTransfersStore) DeleteTransfer(tid ftCrypto.TransferID) error { - sft.mux.Lock() - defer sft.mux.Unlock() - - // Return an error if the transfer does not exist - st, exists := sft.transfers[tid] - if !exists { - return errors.New(getSentTransferErr) - } - - // Cancel any scheduled callbacks - err := st.stopScheduledProgressCB() - if err != nil { - jww.WARN.Print(errors.Errorf(cancelCallbackErr, tid, err)) - } - - // Delete all data the transfer saved to storage - err = st.delete() - if err != nil { - return errors.Errorf(deleteSentTransferErr, tid, err) - } - - // Delete the transfer from memory - delete(sft.transfers, tid) - - // Update the transfers list for the removed transfer - err = sft.saveTransfersList() - if err != nil { - return errors.Errorf(saveSentTransfersListErr, err) - } - - return nil -} - -// GetUnsentParts returns a map of all transfers and a list of their parts that -// have not been sent (parts that were never marked as in-progress). -func (sft *SentFileTransfersStore) GetUnsentParts() ( - map[ftCrypto.TransferID][]uint16, error) { - sft.mux.Lock() - defer sft.mux.Unlock() - unsentParts := map[ftCrypto.TransferID][]uint16{} - - // Get list of unsent part numbers for each transfer - for tid, st := range sft.transfers { - unsentPartNums, err := st.GetUnsentPartNums() - if err != nil { - return nil, err - } - unsentParts[tid] = unsentPartNums - } - - return unsentParts, nil -} - -// GetSentRounds returns a map of all round IDs and which transfers have parts -// sent on those rounds (parts marked in-progress). -func (sft *SentFileTransfersStore) GetSentRounds() map[id.Round][]ftCrypto.TransferID { - sft.mux.Lock() - defer sft.mux.Unlock() - sentRounds := map[id.Round][]ftCrypto.TransferID{} - - // Get list of round IDs that transfers have in-progress rounds on - for tid, st := range sft.transfers { - for _, rid := range st.GetSentRounds() { - sentRounds[rid] = append(sentRounds[rid], tid) - } - } - - return sentRounds -} - -// GetUnsentPartsAndSentRounds returns two maps. The first is a map of all -// transfers and a list of their parts that have not been sent (parts that were -// never marked as in-progress). The second is a map of all round IDs and which -// transfers have parts sent on those rounds (parts marked in-progress). This -// function performs the same operations as GetUnsentParts and GetSentRounds but -// in a single loop. -func (sft *SentFileTransfersStore) GetUnsentPartsAndSentRounds() ( - map[ftCrypto.TransferID][]uint16, map[id.Round][]ftCrypto.TransferID, error) { - sft.mux.Lock() - defer sft.mux.Unlock() - - unsentParts := map[ftCrypto.TransferID][]uint16{} - sentRounds := map[id.Round][]ftCrypto.TransferID{} - - for tid, st := range sft.transfers { - // Get list of unsent part numbers for each transfer - stUnsentParts, err := st.GetUnsentPartNums() - if err != nil { - return nil, nil, err - } - if len(stUnsentParts) > 0 { - unsentParts[tid] = stUnsentParts - } - - // Get list of round IDs that transfers have in-progress rounds on - for _, rid := range st.GetSentRounds() { - sentRounds[rid] = append(sentRounds[rid], tid) - } - } - - return unsentParts, sentRounds, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// LoadSentFileTransfersStore loads all SentFileTransfersStore from storage. -// Returns a list of unsent file parts. -func LoadSentFileTransfersStore(kv *versioned.KV) (*SentFileTransfersStore, error) { - sft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - - // Get the list of transfer IDs corresponding to each sent transfer from - // storage - transfersList, err := sft.loadTransfersList() - if err != nil { - return nil, errors.Errorf(loadSentTransfersListErr, err) - } - - // Load each transfer in the list from storage into the map - err = sft.loadTransfers(transfersList) - if err != nil { - return nil, errors.Errorf(loadSentTransfersErr, err) - } - - return sft, nil -} - -// NewOrLoadSentFileTransfersStore loads all SentFileTransfersStore from storage -// and returns a list of unsent file parts, if they exist. Otherwise, a new -// SentFileTransfersStore is returned. -func NewOrLoadSentFileTransfersStore(kv *versioned.KV) (*SentFileTransfersStore, - error) { - sft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - - // If the transfer list cannot be loaded from storage, then create a new - // SentFileTransfersStore - vo, err := sft.kv.Get( - sentFileTransfersStoreKey, sentFileTransfersStoreVersion) - if err != nil { - return NewSentFileTransfersStore(kv) - } - - // Unmarshal data into list of saved transfer IDs - transfersList := unmarshalTransfersList(vo.Data) - - // Load each transfer in the list from storage into the map - err = sft.loadTransfers(transfersList) - if err != nil { - jww.WARN.Printf(loadSentTransfersErr, err) - return NewSentFileTransfersStore(kv) - } - - return sft, nil -} - -// saveTransfersList saves a list of items in the transfers map to storage. -func (sft *SentFileTransfersStore) saveTransfersList() error { - // Create new versioned object with a list of items in the transfers map - obj := &versioned.Object{ - Version: sentFileTransfersStoreVersion, - Timestamp: netTime.Now(), - Data: sft.marshalTransfersList(), - } - - // Save list of items in the transfers map to storage - return sft.kv.Set( - sentFileTransfersStoreKey, sentFileTransfersStoreVersion, obj) -} - -// loadTransfersList gets the list of transfer IDs corresponding to each saved -// sent transfer from storage. -func (sft *SentFileTransfersStore) loadTransfersList() ([]ftCrypto.TransferID, - error) { - // Get transfers list from storage - vo, err := sft.kv.Get( - sentFileTransfersStoreKey, sentFileTransfersStoreVersion) - if err != nil { - return nil, err - } - - // Unmarshal data into list of saved transfer IDs - return unmarshalTransfersList(vo.Data), nil -} - -// loadTransfers loads each SentTransfer from the list and adds them to the map. -// Returns a map of all transfers and their unsent file part numbers to be used -// to add them back into the queue. -func (sft *SentFileTransfersStore) loadTransfers(list []ftCrypto.TransferID) error { - var err error - var errCount int - - // Load each sentTransfer from storage into the map - for i, tid := range list { - sft.transfers[tid], err = loadSentTransfer(tid, sft.kv) - if err != nil { - jww.WARN.Printf(loadSentTransferWarn, i, len(list), tid, err) - errCount++ - } - } - - // Return an error if all transfers failed to load - if errCount == len(list) { - return errors.Errorf(loadSentTransfersAllErr, len(list)) - } - - return nil -} - -// marshalTransfersList creates a list of all transfer IDs in the transfers map -// and serialises it. -func (sft *SentFileTransfersStore) marshalTransfersList() []byte { - buff := bytes.NewBuffer(nil) - buff.Grow(ftCrypto.TransferIdLength * len(sft.transfers)) - - for tid := range sft.transfers { - buff.Write(tid.Bytes()) - } - - return buff.Bytes() -} - -// unmarshalTransfersList deserializes a byte slice into a list of transfer IDs. -func unmarshalTransfersList(b []byte) []ftCrypto.TransferID { - buff := bytes.NewBuffer(b) - list := make([]ftCrypto.TransferID, 0, buff.Len()/ftCrypto.TransferIdLength) - - const size = ftCrypto.TransferIdLength - for n := buff.Next(size); len(n) == size; n = buff.Next(size) { - list = append(list, ftCrypto.UnmarshalTransferID(n)) - } - - return list -} diff --git a/storage/fileTransfer/sentFileTransfers_test.go b/storage/fileTransfer/sentFileTransfers_test.go deleted file mode 100644 index 336cbb4ca17b577a13e0728f3c499fd2a924fdd3..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentFileTransfers_test.go +++ /dev/null @@ -1,726 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "fmt" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" - "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/netTime" - "reflect" - "sort" - "strings" - "testing" -) - -// Tests that NewSentFileTransfersStore creates a new object with empty maps and -// that it is saved to storage -func TestNewSentFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedSFT := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - - sft, err := NewSentFileTransfersStore(kv) - - // Check that the new SentFileTransfersStore matches the expected - if !reflect.DeepEqual(expectedSFT, sft) { - t.Errorf("New SentFileTransfersStore does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedSFT, sft) - } - - // Ensure that the transfer list is saved to storage - _, err = expectedSFT.kv.Get( - sentFileTransfersStoreKey, sentFileTransfersStoreVersion) - if err != nil { - t.Errorf("Failed to load transfer list from storage: %+v", err) - } -} - -// Tests that SentFileTransfersStore.AddTransfer adds a new transfer to the map -// and that its ID is saved to the list in storage. -func TestSentFileTransfersStore_AddTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - // Generate info for new transfer - recipient := id.NewIdFromString("recipientID", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - parts, _ := newRandomPartSlice(16, prng, t) - numFps := uint16(24) - - // Add the transfer - tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("AddTransfer returned an error: %+v", err) - } - - _, exists := sft.transfers[tid] - if !exists { - t.Errorf("New transfer %s does not exist in map.", tid) - } - - list, err := sft.loadTransfersList() - if err != nil { - t.Errorf("Failed to load transfer list from storage: %+v", err) - } - - if list[0] != tid { - t.Errorf("Transfer ID saved to storage does not match ID in memory."+ - "\nexpected: %s\nreceived: %s", tid, list[0]) - } -} - -// Error path: tests that SentFileTransfersStore.AddTransfer returns the -// expected error when the PRNG returns an error. -func TestSentFileTransfersStore_AddTransfer_NewTransferIdRngError(t *testing.T) { - prng := NewPrngErr() - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - // Add the transfer - expectedErr := strings.Split(addTransferNewIdErr, "%")[0] - _, err = sft.AddTransfer(nil, ftCrypto.TransferKey{}, nil, 0, nil, 0, prng) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("AddTransfer did not return the expected error when the PRNG "+ - "should have errored.\nexpected: %s\nrecieved: %+v", expectedErr, err) - } -} - -// Tests that SentFileTransfersStore.GetTransfer returns the expected transfer -// from the map. -func TestSentFileTransfersStore_GetTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - // Generate info for new transfer - recipient := id.NewIdFromString("recipientID", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - parts, _ := newRandomPartSlice(16, prng, t) - numFps := uint16(24) - - tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("AddTransfer returned an error: %+v", err) - } - - transfer, err := sft.GetTransfer(tid) - if err != nil { - t.Errorf("GetTransfer returned an error: %+v", err) - } - - if !reflect.DeepEqual(sft.transfers[tid], transfer) { - t.Errorf("Received transfer does not match expected."+ - "\nexpected: %+v\nreceived: %+v", sft.transfers[tid], transfer) - } -} - -// Error path: tests that SentFileTransfersStore.GetTransfer returns the -// expected error when the map is empty/there is no transfer with the given -// transfer ID. -func TestSentFileTransfersStore_GetTransfer_NoTransferError(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - tid, _ := ftCrypto.NewTransferID(prng) - - _, err = sft.GetTransfer(tid) - if err == nil || err.Error() != getSentTransferErr { - t.Errorf("GetTransfer did not return the expected error when it is "+ - "empty.\nexpected: %s\nreceived: %+v", getSentTransferErr, err) - } -} - -// Tests that SentFileTransfersStore.DeleteTransfer removes the transfer from -// the map in memory and from the list in storage. -func TestSentFileTransfersStore_DeleteTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - // Generate info for new transfer - recipient := id.NewIdFromString("recipientID", id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - parts, _ := newRandomPartSlice(16, prng, t) - numFps := uint16(24) - - tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("AddTransfer returned an error: %+v", err) - } - - err = sft.DeleteTransfer(tid) - if err != nil { - t.Errorf("DeleteTransfer returned an error: %+v", err) - } - - transfer, err := sft.GetTransfer(tid) - if err == nil { - t.Errorf("No error getting transfer that should be deleted: %+v", - transfer) - } - - list, err := sft.loadTransfersList() - if err != nil { - t.Errorf("Failed to load transfer list from storage: %+v", err) - } - - if len(list) > 0 { - t.Errorf("Transfer ID list in storage not empty: %+v", list) - } -} - -// Error path: tests that SentFileTransfersStore.DeleteTransfer returns the -// expected error when the map is empty/there is no transfer with the given -// transfer ID. -func TestSentFileTransfersStore_DeleteTransfer_NoTransferError(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - tid, _ := ftCrypto.NewTransferID(prng) - - err = sft.DeleteTransfer(tid) - if err == nil || err.Error() != getSentTransferErr { - t.Errorf("DeleteTransfer did not return the expected error when it is "+ - "empty.\nexpected: %s\nreceived: %+v", getSentTransferErr, err) - } -} - -// Tests that SentFileTransfersStore.GetUnsentParts returns the expected unsent -// parts for each transfer. Transfers are created with increasing number of -// parts. Each part of each transfer is set as either in-progress, finished, or -// unsent. -func TestSentFileTransfersStore_GetUnsentParts(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - n := uint16(3) - expectedParts := make(map[ftCrypto.TransferID][]uint16, n) - - // Add new transfers - for i := uint16(0); i < n; i++ { - recipient := id.NewIdFromUInt(uint64(i), id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - parts, _ := newRandomPartSlice((i+1)*6, prng, t) - numParts := uint16(len(parts)) - numFps := numParts * 3 / 2 - - tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add transfer %d: %+v", i, err) - } - - // Loop through each part and set it individually - for j := uint16(0); j < numParts; j++ { - switch ((j + i) % numParts) % 3 { - case 0: - // Part is sent (in-progress) - _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) - case 1: - // Part is sent and arrived (finished) - _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) - _, _ = sft.transfers[tid].FinishTransfer(id.Round(j)) - case 2: - // Part is unsent (neither in-progress nor arrived) - expectedParts[tid] = append(expectedParts[tid], j) - } - } - } - - unsentParts, err := sft.GetUnsentParts() - if err != nil { - t.Errorf("GetUnsentParts returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedParts, unsentParts) { - t.Errorf("Unexpected unsent parts map.\nexpected: %+v\nreceived: %+v", - expectedParts, unsentParts) - } -} - -// Tests that SentFileTransfersStore.GetSentRounds returns the expected transfer -// ID for each unfinished round. Transfers are created with increasing number of -// parts. Each part of each transfer is set as either in-progress, finished, or -// unsent. -func TestSentFileTransfersStore_GetSentRounds(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - n := uint16(3) - expectedRounds := make(map[id.Round][]ftCrypto.TransferID) - - // Add new transfers - for i := uint16(0); i < n; i++ { - recipient := id.NewIdFromUInt(uint64(i), id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - parts, _ := newRandomPartSlice((i+1)*6, prng, t) - numParts := uint16(len(parts)) - numFps := numParts * 3 / 2 - - tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add transfer %d: %+v", i, err) - } - - // Loop through each part and set it individually - for j := uint16(0); j < numParts; j++ { - rid := id.Round(j) - switch j % 3 { - case 0: - // Part is sent (in-progress) - _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) - expectedRounds[rid] = append(expectedRounds[rid], tid) - case 1: - // Part is sent and arrived (finished) - _, _ = sft.transfers[tid].SetInProgress(id.Round(j), j) - _, _ = sft.transfers[tid].FinishTransfer(id.Round(j)) - case 2: - // Part is unsent (neither in-progress nor arrived) - } - } - } - - // Sort expected rounds map transfer IDs - for _, tIDs := range expectedRounds { - sort.Slice(tIDs, - func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) - } - - sentRounds := sft.GetSentRounds() - - // Sort sent rounds map transfer IDs - for _, tIDs := range sentRounds { - sort.Slice(tIDs, - func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) - } - - if !reflect.DeepEqual(expectedRounds, sentRounds) { - t.Errorf("Unexpected sent rounds map.\nexpected: %+v\nreceived: %+v", - expectedRounds, sentRounds) - } -} - -// Tests that SentFileTransfersStore.GetUnsentPartsAndSentRounds -func TestSentFileTransfersStore_GetUnsentPartsAndSentRounds(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to create new SentFileTransfersStore: %+v", err) - } - - n := uint16(3) - expectedParts := make(map[ftCrypto.TransferID][]uint16, n) - expectedRounds := make(map[id.Round][]ftCrypto.TransferID) - - // Add new transfers - for i := uint16(0); i < n; i++ { - recipient := id.NewIdFromUInt(uint64(i), id.User, t) - key, _ := ftCrypto.NewTransferKey(prng) - parts, _ := newRandomPartSlice((i+1)*6, prng, t) - numParts := uint16(len(parts)) - numFps := numParts * 3 / 2 - - tid, err := sft.AddTransfer(recipient, key, parts, numFps, nil, 0, prng) - if err != nil { - t.Errorf("Failed to add transfer %d: %+v", i, err) - } - - // Loop through each part and set it individually - for j := uint16(0); j < numParts; j++ { - rid := id.Round(j) - switch j % 3 { - case 0: - // Part is sent (in-progress) - _, _ = sft.transfers[tid].SetInProgress(rid, j) - expectedRounds[rid] = append(expectedRounds[rid], tid) - case 1: - // Part is sent and arrived (finished) - _, _ = sft.transfers[tid].SetInProgress(rid, j) - _, _ = sft.transfers[tid].FinishTransfer(rid) - case 2: - // Part is unsent (neither in-progress nor arrived) - expectedParts[tid] = append(expectedParts[tid], j) - } - } - } - - // Sort expected rounds map transfer IDs - for _, tIDs := range expectedRounds { - sort.Slice(tIDs, - func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) - } - - unsentParts, sentRounds, err := sft.GetUnsentPartsAndSentRounds() - if err != nil { - t.Errorf("GetUnsentPartsAndSentRounds returned an error: %+v", err) - } - - // Sort sent rounds map transfer IDs - for _, tIDs := range sentRounds { - sort.Slice(tIDs, - func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) - } - - if !reflect.DeepEqual(expectedParts, unsentParts) { - t.Errorf("Unexpected unsent parts map.\nexpected: %+v\nreceived: %+v", - expectedParts, unsentParts) - } - - if !reflect.DeepEqual(expectedRounds, sentRounds) { - t.Errorf("Unexpected sent rounds map.\nexpected: %+v\nreceived: %+v", - expectedRounds, sentRounds) - } - - unsentParts2, err := sft.GetUnsentParts() - if err != nil { - t.Errorf("GetUnsentParts returned an error: %+v", err) - } - - if !reflect.DeepEqual(unsentParts, unsentParts2) { - t.Errorf("Unsent parts from GetUnsentParts and "+ - "GetUnsentPartsAndSentRounds do not match."+ - "\nGetUnsentParts: %+v"+ - "\nGetUnsentPartsAndSentRounds: %+v", - unsentParts, unsentParts2) - } - - sentRounds2 := sft.GetSentRounds() - - // Sort sent rounds map transfer IDs - for _, tIDs := range sentRounds2 { - sort.Slice(tIDs, - func(i, j int) bool { return tIDs[i].String() < tIDs[j].String() }) - } - - if !reflect.DeepEqual(sentRounds, sentRounds2) { - t.Errorf("Sent rounds map from GetSentRounds and "+ - "GetUnsentPartsAndSentRounds do not match."+ - "\nGetSentRounds: %+v"+ - "\nGetUnsentPartsAndSentRounds: %+v", - sentRounds, sentRounds2) - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// Tests that the SentFileTransfersStore loaded from storage by -// LoadSentFileTransfersStore matches the original in memory. -func TestLoadSentFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to make new SentFileTransfersStore: %+v", err) - } - - // Add 10 transfers to map in memory - list := make([]ftCrypto.TransferID, 10) - for i := range list { - tid, st := newRandomSentTransfer(16, 24, sft.kv, t) - sft.transfers[tid] = st - list[i] = tid - } - - // Save list to storage - if err = sft.saveTransfersList(); err != nil { - t.Errorf("Faileds to save transfers list: %+v", err) - } - - // Load SentFileTransfersStore from storage - loadedSFT, err := LoadSentFileTransfersStore(kv) - if err != nil { - t.Errorf("LoadSentFileTransfersStore returned an error: %+v", err) - } - - // Equalize all progressCallbacks because reflect.DeepEqual does not seem to - // work on function pointers - for _, tid := range list { - loadedSFT.transfers[tid].progressCallbacks = - sft.transfers[tid].progressCallbacks - } - - if !reflect.DeepEqual(sft, loadedSFT) { - t.Errorf("Loaded SentFileTransfersStore does not match original in "+ - "memory.\nexpected: %+v\nreceived: %+v", sft, loadedSFT) - } -} - -// Error path: tests that LoadSentFileTransfersStore returns the expected error -// when the transfer list cannot be loaded from storage. -func TestLoadSentFileTransfersStore_NoListInStorageError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadSentTransfersListErr, "%")[0] - - // Load SentFileTransfersStore from storage - _, err := LoadSentFileTransfersStore(kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadSentFileTransfersStore did not return the expected "+ - "error when there is no transfer list saved in storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that LoadSentFileTransfersStore returns the expected error -// when the first transfer loaded from storage does not exist. -func TestLoadSentFileTransfersStore_NoTransferInStorageError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadSentTransfersErr, "%")[0] - - // Save list of one transfer ID to storage - obj := &versioned.Object{ - Version: sentFileTransfersStoreVersion, - Timestamp: netTime.Now(), - Data: ftCrypto.UnmarshalTransferID([]byte("testID_01")).Bytes(), - } - err := kv.Prefix(sentFileTransfersStorePrefix).Set( - sentFileTransfersStoreKey, sentFileTransfersStoreVersion, obj) - - // Load SentFileTransfersStore from storage - _, err = LoadSentFileTransfersStore(kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("LoadSentFileTransfersStore did not return the expected "+ - "error when there is no transfer saved in storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that the SentFileTransfersStore loaded from storage by -// NewOrLoadSentFileTransfersStore matches the original in memory. -func TestNewOrLoadSentFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - sft, err := NewSentFileTransfersStore(kv) - if err != nil { - t.Fatalf("Failed to make new SentFileTransfersStore: %+v", err) - } - - // Add 10 transfers to map in memory - list := make([]ftCrypto.TransferID, 10) - for i := range list { - tid, st := newRandomSentTransfer(16, 24, sft.kv, t) - sft.transfers[tid] = st - list[i] = tid - } - - // Save list to storage - if err = sft.saveTransfersList(); err != nil { - t.Errorf("Faileds to save transfers list: %+v", err) - } - - // Load SentFileTransfersStore from storage - loadedSFT, err := NewOrLoadSentFileTransfersStore(kv) - if err != nil { - t.Errorf("NewOrLoadSentFileTransfersStore returned an error: %+v", err) - } - - // Equalize all progressCallbacks because reflect.DeepEqual does not seem - // to work on function pointers - for _, tid := range list { - loadedSFT.transfers[tid].progressCallbacks = - sft.transfers[tid].progressCallbacks - } - - if !reflect.DeepEqual(sft, loadedSFT) { - t.Errorf("Loaded SentFileTransfersStore does not match original in "+ - "memory.\nexpected: %+v\nreceived: %+v", sft, loadedSFT) - } -} - -// Tests that NewOrLoadSentFileTransfersStore returns a new -// SentFileTransfersStore when there is none in storage. -func TestNewOrLoadSentFileTransfersStore_NewSentFileTransfersStore(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - // Load SentFileTransfersStore from storage - loadedSFT, err := NewOrLoadSentFileTransfersStore(kv) - if err != nil { - t.Errorf("NewOrLoadSentFileTransfersStore returned an error: %+v", err) - } - - newSFT, _ := NewSentFileTransfersStore(kv) - - if !reflect.DeepEqual(newSFT, loadedSFT) { - t.Errorf("Returned SentFileTransfersStore does not match new."+ - "\nexpected: %+v\nreceived: %+v", newSFT, loadedSFT) - } -} - -// Tests that SentFileTransfersStore.saveTransfersList saves all the transfer -// IDs to storage by loading them from storage via -// SentFileTransfersStore.loadTransfersList and comparing the list to the list -// in memory. -func TestSentFileTransfersStore_saveTransfersList_loadTransfersList(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - sft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - - // Add 10 transfers to map in memory - for i := 0; i < 10; i++ { - tid, st := newRandomSentTransfer(16, 24, sft.kv, t) - sft.transfers[tid] = st - } - - // Save transfer ID list to storage - err := sft.saveTransfersList() - if err != nil { - t.Errorf("saveTransfersList returned an error: %+v", err) - } - - // Get list from storage - list, err := sft.loadTransfersList() - if err != nil { - t.Errorf("loadTransfersList returned an error: %+v", err) - } - - // Check that the list has all the transfer IDs in memory - for _, tid := range list { - if _, exists := sft.transfers[tid]; !exists { - t.Errorf("No transfer for ID %s exists.", tid) - } else { - delete(sft.transfers, tid) - } - } -} - -// Tests that the transfer loaded by SentFileTransfersStore.loadTransfers from -// storage matches the original in memory. -func TestSentFileTransfersStore_loadTransfers(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - sft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - - // Add 10 transfers to map in memory - list := make([]ftCrypto.TransferID, 10) - for i := range list { - tid, st := newRandomSentTransfer(16, 24, sft.kv, t) - sft.transfers[tid] = st - list[i] = tid - } - - // Load the transfers into a new SentFileTransfersStore - loadedSft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv.Prefix(sentFileTransfersStorePrefix), - } - err := loadedSft.loadTransfers(list) - if err != nil { - t.Errorf("loadTransfers returned an error: %+v", err) - } - - // Equalize all progressCallbacks because reflect.DeepEqual does not seem - // to work on function pointers - for _, tid := range list { - loadedSft.transfers[tid].progressCallbacks = - sft.transfers[tid].progressCallbacks - } - - if !reflect.DeepEqual(sft.transfers, loadedSft.transfers) { - t.Errorf("Transfers loaded from storage does not match transfers in "+ - "memory.\nexpected: %+v\nreceived: %+v", - sft.transfers, loadedSft.transfers) - } -} - -// Error path: tests that SentFileTransfersStore.loadTransfers returns an error -// when all file transfers fail to load from storage. -func TestSentFileTransfersStore_loadTransfers_AllErrors(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)).Prefix(sentFileTransfersStorePrefix) - - // Add 10 transfers to map in memory - list := make([]ftCrypto.TransferID, 10) - for i := range list { - tid, st := newRandomSentTransfer(16, 24, kv, t) - list[i] = tid - err := st.delete() - if err != nil { - t.Errorf("Failed to delete transfer: %+v", err) - } - } - - expectedErr := fmt.Sprintf(loadSentTransfersAllErr, len(list)) - - // Load the transfers into a new SentFileTransfersStore - loadedSft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - kv: kv, - } - err := loadedSft.loadTransfers(list) - if err == nil || err.Error() != expectedErr { - t.Errorf("loadTransfers did not return the expected error when none "+ - "of the transfer could be loaded from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that a transfer list marshalled with -// SentFileTransfersStore.marshalTransfersList and unmarshalled with -// unmarshalTransfersList matches the original. -func TestSentFileTransfersStore_marshalTransfersList_unmarshalTransfersList(t *testing.T) { - prng := NewPrng(42) - sft := &SentFileTransfersStore{ - transfers: make(map[ftCrypto.TransferID]*SentTransfer), - } - - // Add 10 transfers to map in memory - for i := 0; i < 10; i++ { - tid, _ := ftCrypto.NewTransferID(prng) - sft.transfers[tid] = &SentTransfer{} - } - - // Marshal into byte slice - marshalledBytes := sft.marshalTransfersList() - - // Unmarshal marshalled bytes into transfer ID list - list := unmarshalTransfersList(marshalledBytes) - - // Check that the list has all the transfer IDs in memory - for _, tid := range list { - if _, exists := sft.transfers[tid]; !exists { - t.Errorf("No transfer for ID %s exists.", tid) - } else { - delete(sft.transfers, tid) - } - } -} diff --git a/storage/fileTransfer/sentPartTracker.go b/storage/fileTransfer/sentPartTracker.go deleted file mode 100644 index b5d30cf6e8960e404f8171b6ba00ed0fa508d8e6..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentPartTracker.go +++ /dev/null @@ -1,56 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/utility" -) - -// Error messages. -const ( - // sentPartTracker.GetPartStatus - getInvalidPartErr = "[FT] Failed to get status for part %d: %+v" -) - -// sentPartTracker tracks the status of individual sent file parts. -type sentPartTracker struct { - // The number of file parts in the file - numParts uint16 - - // Stores the status for each file part in a bitstream format - partStats *utility.MultiStateVector -} - -// newSentPartTracker creates a new sentPartTracker with copies of the -// in-progress and finished status state vectors. -func newSentPartTracker(partStats *utility.MultiStateVector) sentPartTracker { - return sentPartTracker{ - numParts: partStats.GetNumKeys(), - partStats: partStats.DeepCopy(), - } -} - -// GetPartStatus returns the status of the sent file part with the given part -// number. The possible values for the status are: -// 0 = unsent -// 1 = sent (sender has sent a part, but it has not arrived) -// 2 = arrived (sender has sent a part, and it has arrived) -func (spt sentPartTracker) GetPartStatus(partNum uint16) interfaces.FpStatus { - status, err := spt.partStats.Get(partNum) - if err != nil { - jww.FATAL.Fatalf(getInvalidPartErr, partNum, err) - } - return interfaces.FpStatus(status) -} - -// GetNumParts returns the total number of file parts in the transfer. -func (spt sentPartTracker) GetNumParts() uint16 { - return spt.numParts -} diff --git a/storage/fileTransfer/sentPartTracker_test.go b/storage/fileTransfer/sentPartTracker_test.go deleted file mode 100644 index af50953b889415c3c11bf4969199ee6d318408fb..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentPartTracker_test.go +++ /dev/null @@ -1,105 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/ekv" - "math/rand" - "reflect" - "testing" -) - -// Tests that sentPartTracker satisfies the interfaces.FilePartTracker -// interface. -func Test_sentPartTracker_FilePartTrackerInterface(t *testing.T) { - var _ interfaces.FilePartTracker = sentPartTracker{} -} - -// Tests that newSentPartTracker returns a new sentPartTracker with the expected -// values. -func Test_newSentPartTracker(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - expected := sentPartTracker{ - numParts: st.numParts, - partStats: st.partStats.DeepCopy(), - } - - newSPT := newSentPartTracker(st.partStats) - - if !reflect.DeepEqual(expected, newSPT) { - t.Errorf("New sentPartTracker does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expected, newSPT) - } -} - -// Tests that sentPartTracker.GetPartStatus returns the expected status for each -// part loaded from a preconfigured SentTransfer. -func Test_sentPartTracker_GetPartStatus(t *testing.T) { - // Create new SentTransfer - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - // Set statuses of parts in the SentTransfer and a map randomly - prng := rand.New(rand.NewSource(42)) - partStatuses := make(map[uint16]interfaces.FpStatus, st.numParts) - for partNum := uint16(0); partNum < st.numParts; partNum++ { - partStatuses[partNum] = interfaces.FpStatus(prng.Intn(3)) - - switch partStatuses[partNum] { - case interfaces.FpSent: - err := st.partStats.Set(partNum, inProgress) - if err != nil { - t.Errorf( - "Failed to set part %d to in-progress: %+v", partNum, err) - } - case interfaces.FpArrived: - err := st.partStats.Set(partNum, inProgress) - if err != nil { - t.Errorf( - "Failed to set part %d to in-progress: %+v", partNum, err) - } - err = st.partStats.Set(partNum, finished) - if err != nil { - t.Errorf("Failed to set part %d to finished: %+v", partNum, err) - } - } - } - - // Create a new sentPartTracker from the SentTransfer - spt := newSentPartTracker(st.partStats) - - // Check that the statuses for each part matches the map - for partNum := uint16(0); partNum < st.numParts; partNum++ { - status := spt.GetPartStatus(partNum) - if status != partStatuses[partNum] { - t.Errorf("Part number %d does not have expected status."+ - "\nexpected: %d\nreceived: %d", - partNum, partStatuses[partNum], status) - } - } -} - -// Tests that sentPartTracker.GetNumParts returns the same number of parts as -// the SentTransfer it was created from. -func Test_sentPartTracker_GetNumParts(t *testing.T) { - // Create new SentTransfer - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - // Create a new sentPartTracker from the SentTransfer - spt := newSentPartTracker(st.partStats) - - if spt.GetNumParts() != st.GetNumParts() { - t.Errorf("Number of parts incorrect.\nexpected: %d\nreceived: %d", - st.GetNumParts(), spt.GetNumParts()) - } -} diff --git a/storage/fileTransfer/sentTransfer.go b/storage/fileTransfer/sentTransfer.go deleted file mode 100644 index 768aae51e5c2b672d04c2024d072c328f5c2f887..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentTransfer.go +++ /dev/null @@ -1,819 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "encoding/binary" - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "time" -) - -// Storage keys and versions. -const ( - sentTransferPrefix = "FileTransferSentTransferStore" - sentTransferKey = "SentTransfer" - sentTransferVersion = 0 - sentFpVectorKey = "SentFingerprintVector" - sentPartStatsVectorKey = "SentPartStatsVector" - sentInProgressVectorKey = "SentInProgressStatusVector" - sentFinishedVectorKey = "SentFinishedStatusVector" -) - -// Error messages. -const ( - // NewSentTransfer - newSentTransferFpVectorErr = "failed to create new state vector for fingerprints: %+v" - newSentTransferPartStoreErr = "failed to create new part store: %+v" - newInProgressTransfersErr = "failed to create new in-progress transfers bundle: %+v" - newFinishedTransfersErr = "failed to create new finished transfers bundle: %+v" - newSentPartStatusVectorErr = "failed to create new multi state vector for part statuses: %+v" - - // SentTransfer.ReInit - reInitSentTransferFpVectorErr = "failed to overwrite fingerprint state vector with new vector: %+v" - reInitInProgressTransfersErr = "failed to overwrite in-progress transfers bundle: %+v" - reInitFinishedTransfersErr = "failed to overwrite finished transfers bundle: %+v" - reInitSentPartStatusVectorErr = "failed to overwrite multi state vector for part statuses: %+v" - - // SentTransfer.IsPartInProgress and SentTransfer.IsPartFinished - getStatusErr = "failed to get status of part %d: %+v" - - // SentTransfer.stopScheduledProgressCB - cancelSentCallbacksErr = "could not cancel %d out of %d sent progress callbacks: %d" - - // SentTransfer.GetUnsentPartNums - getUnsentPartsErr = "cannot get unsent parts: %+v" - - // loadSentTransfer - loadSentStoreErr = "failed to load sent transfer info from storage: %+v" - loadSentFpVectorErr = "failed to load sent fingerprint vector from storage: %+v" - loadSentPartStoreErr = "failed to load sent part store from storage: %+v" - loadInProgressTransfersErr = "failed to load in-progress transfers bundle from storage: %+v" - loadFinishedTransfersErr = "failed to load finished transfers bundle from storage: %+v" - loadSentPartStatusVectorErr = "failed to load multi state vector for part statuses from storage: %+v" - - // SentTransfer.delete - deleteSentTransferInfoErr = "failed to delete sent transfer info from storage: %+v" - deleteSentFpVectorErr = "failed to delete sent fingerprint vector from storage: %+v" - deleteSentFilePartsErr = "failed to delete sent file parts from storage: %+v" - deleteInProgressTransfersErr = "failed to delete in-progress transfers from storage: %+v" - deleteFinishedTransfersErr = "failed to delete finished transfers from storage: %+v" - deleteSentPartStatusVectorErr = "failed to delete multi state vector for part statuses from storage: %+v" - - // SentTransfer.FinishTransfer - noPartsForRoundErr = "no file parts in-progress on round %d" - deleteInProgressPartsErr = "failed to remove file parts on round %d from in-progress: %+v" - - // SentTransfer.GetEncryptedPart - noPartNumErr = "no part with part number %d exists" - maxRetriesErr = "maximum number of retries reached" - fingerprintErr = "could not get fingerprint: %+v" - encryptPartErr = "failed to encrypt file part #%d: %+v" -) - -// MaxRetriesErr is returned as an error when number of file part sending -// retries runs out. This occurs when all the fingerprints in a transfer have -// been used. -var MaxRetriesErr = errors.New(maxRetriesErr) - -// States for parts in the partStats MultiStateVector. -const ( - unsent = iota - inProgress - finished - numStates // The number of part states (for initialisation of the vector) -) - -// sentTransferStateMap prevents illegal state changes for part statuses. -var sentTransferStateMap = [][]bool{ - {false, true, false}, - {true, false, true}, - {false, false, false}, -} - -// SentTransfer contains information and progress data for sending and in- -// progress file transfer. -type SentTransfer struct { - // ID of the recipient of the file transfer - recipient *id.ID - - // The transfer key is a randomly generated key created by the sender and - // used to generate MACs and fingerprints - key ftCrypto.TransferKey - - // The number of file parts in the file - numParts uint16 - - // The number of fingerprints to generate (function of numParts and the - // retry rate) - numFps uint16 - - // Stores the state of a fingerprint (used/unused) in a bitstream format - // (has its own storage backend) - fpVector *utility.StateVector - - // List of all file parts in order to send (has its own storage backend) - sentParts *partStore - - // List of parts per round that are currently transferring - inProgressTransfers *transferredBundle - - // List of parts per round that finished transferring - finishedTransfers *transferredBundle - - // Stores the status of each part in a bitstream format - partStats *utility.MultiStateVector - - // List of callbacks to call for every send - progressCallbacks []*sentCallbackTracker - - // status indicates that the transfer is either done or errored out and - // that no more callbacks should be called - status TransferStatus - - mux sync.RWMutex - kv *versioned.KV -} - -// NewSentTransfer generates a new SentTransfer with the specified transfer key, -// transfer ID, and number of parts. -func NewSentTransfer(recipient *id.ID, tid ftCrypto.TransferID, - key ftCrypto.TransferKey, parts [][]byte, numFps uint16, - progressCB interfaces.SentProgressCallback, period time.Duration, - kv *versioned.KV) (*SentTransfer, error) { - - // Create the SentTransfer object - st := &SentTransfer{ - recipient: recipient, - key: key, - numParts: uint16(len(parts)), - numFps: numFps, - progressCallbacks: []*sentCallbackTracker{}, - status: Running, - kv: kv.Prefix(makeSentTransferPrefix(tid)), - } - - var err error - - // Create new StateVector for storing fingerprint usage - st.fpVector, err = utility.NewStateVector( - st.kv, sentFpVectorKey, uint32(numFps)) - if err != nil { - return nil, errors.Errorf(newSentTransferFpVectorErr, err) - } - - // Create new part store - st.sentParts, err = newPartStoreFromParts(st.kv, parts...) - if err != nil { - return nil, errors.Errorf(newSentTransferPartStoreErr, err) - } - - // Create new in-progress transfer bundle - st.inProgressTransfers, err = newTransferredBundle(inProgressKey, st.kv) - if err != nil { - return nil, errors.Errorf(newInProgressTransfersErr, err) - } - - // Create new finished transfer bundle - st.finishedTransfers, err = newTransferredBundle(finishedKey, st.kv) - if err != nil { - return nil, errors.Errorf(newFinishedTransfersErr, err) - } - - // Create new MultiStateVector for storing part statuses - st.partStats, err = utility.NewMultiStateVector(st.numParts, numStates, - sentTransferStateMap, sentPartStatsVectorKey, st.kv) - if err != nil { - return nil, errors.Errorf(newSentPartStatusVectorErr, err) - } - - // Add first progress callback - if progressCB != nil { - st.AddProgressCB(progressCB, period) - } - - return st, st.saveInfo() -} - -// ReInit resets the SentTransfer to its initial state so that sending can -// restart from the beginning. ReInit is used when the sent transfer runs out of -// retries and a user wants to attempt to resend the entire file again. -func (st *SentTransfer) ReInit(numFps uint16, - progressCB interfaces.SentProgressCallback, period time.Duration) error { - st.mux.Lock() - defer st.mux.Unlock() - - var err error - - // Mark the status as running - st.status = Running - - // Update number of fingerprints and overwrite old fingerprint vector - st.numFps = numFps - st.fpVector, err = utility.NewStateVector( - st.kv, sentFpVectorKey, uint32(numFps)) - if err != nil { - return errors.Errorf(reInitSentTransferFpVectorErr, err) - } - - // Overwrite in-progress transfer bundle - st.inProgressTransfers, err = newTransferredBundle(inProgressKey, st.kv) - if err != nil { - return errors.Errorf(reInitInProgressTransfersErr, err) - } - - // Overwrite finished transfer bundle - st.finishedTransfers, err = newTransferredBundle(finishedKey, st.kv) - if err != nil { - return errors.Errorf(reInitFinishedTransfersErr, err) - } - - // Overwrite new part status MultiStateVector - st.partStats, err = utility.NewMultiStateVector(st.numParts, numStates, - sentTransferStateMap, sentPartStatsVectorKey, st.kv) - if err != nil { - return errors.Errorf(reInitSentPartStatusVectorErr, err) - } - - // Clear callbacks - st.progressCallbacks = []*sentCallbackTracker{} - - // Add first progress callback - if progressCB != nil { - // Add callback - sct := newSentCallbackTracker(progressCB, period) - st.progressCallbacks = append(st.progressCallbacks, sct) - - // Trigger the initial call - sct.callNowUnsafe(true, st, nil) - } - - return nil -} - -// GetRecipient returns the ID of the recipient of the transfer. -func (st *SentTransfer) GetRecipient() *id.ID { - st.mux.RLock() - defer st.mux.RUnlock() - - return st.recipient -} - -// GetTransferKey returns the transfer Key for this sent transfer. -func (st *SentTransfer) GetTransferKey() ftCrypto.TransferKey { - st.mux.RLock() - defer st.mux.RUnlock() - - return st.key -} - -// GetNumParts returns the number of file parts in this transfer. -func (st *SentTransfer) GetNumParts() uint16 { - st.mux.RLock() - defer st.mux.RUnlock() - - return st.numParts -} - -// GetNumFps returns the number of fingerprints. -func (st *SentTransfer) GetNumFps() uint16 { - st.mux.RLock() - defer st.mux.RUnlock() - - return st.numFps -} - -// GetNumAvailableFps returns the number of unused fingerprints. -func (st *SentTransfer) GetNumAvailableFps() uint16 { - st.mux.RLock() - defer st.mux.RUnlock() - - return uint16(st.fpVector.GetNumAvailable()) -} - -// GetStatus returns the status of the sent transfer. -func (st *SentTransfer) GetStatus() TransferStatus { - st.mux.RLock() - defer st.mux.RUnlock() - - return st.status -} - -// IsPartInProgress returns true if the part has successfully been sent. Returns -// false if the part is unsent or finished sending or if the part number is -// invalid. -func (st *SentTransfer) IsPartInProgress(partNum uint16) (bool, error) { - status, err := st.partStats.Get(partNum) - if err != nil { - return false, errors.Errorf(getStatusErr, partNum, err) - } - return status == inProgress, nil -} - -// IsPartFinished returns true if the part has successfully arrived. Returns -// false if the part is unsent or in the process of sending or if the part -// number is invalid. -func (st *SentTransfer) IsPartFinished(partNum uint16) (bool, error) { - status, err := st.partStats.Get(partNum) - if err != nil { - return false, errors.Errorf(getStatusErr, partNum, err) - } - return status == finished, nil -} - -// GetProgress returns the current progress of the transfer. Completed is true -// when all parts have arrived, sent is the number of in-progress parts, arrived -// is the number of finished parts, total is the total number of parts being -// sent, and t is a part status tracker that can be used to get the status of -// individual file parts. -func (st *SentTransfer) GetProgress() (completed bool, sent, arrived, - total uint16, t interfaces.FilePartTracker) { - st.mux.RLock() - defer st.mux.RUnlock() - - completed, sent, arrived, total, t = st.getProgress() - return completed, sent, arrived, total, t -} - -// getProgress is the thread-unsafe helper function for GetProgress. -func (st *SentTransfer) getProgress() (completed bool, sent, arrived, - total uint16, t interfaces.FilePartTracker) { - arrived, _ = st.partStats.GetCount(finished) - sent, _ = st.partStats.GetCount(inProgress) - total = st.numParts - - if sent == 0 && arrived == total { - completed = true - } - - partTracker := newSentPartTracker(st.partStats) - - return completed, sent, arrived, total, partTracker -} - -// CallProgressCB calls all the progress callbacks with the most recent progress -// information. -func (st *SentTransfer) CallProgressCB(err error) { - st.mux.Lock() - - if st.status == Stopping { - st.status = Stopped - } - - st.mux.Unlock() - st.mux.RLock() - defer st.mux.RUnlock() - - for _, cb := range st.progressCallbacks { - cb.call(st, err) - } -} - -// stopScheduledProgressCB cancels all scheduled sent progress callbacks calls. -func (st *SentTransfer) stopScheduledProgressCB() error { - st.mux.Lock() - defer st.mux.Unlock() - - // Tracks the index of callbacks that failed to stop - var failedCallbacks []int - - for i, cb := range st.progressCallbacks { - err := cb.stopThread() - if err != nil { - failedCallbacks = append(failedCallbacks, i) - jww.WARN.Printf("[FT] %s", err) - } - } - - if len(failedCallbacks) > 0 { - return errors.Errorf(cancelSentCallbacksErr, len(failedCallbacks), - len(st.progressCallbacks), failedCallbacks) - } - - return nil -} - -// AddProgressCB appends a new interfaces.SentProgressCallback to the list of -// progress callbacks to be called and calls it. The period is how often the -// callback should be called when there are updates. -func (st *SentTransfer) AddProgressCB(cb interfaces.SentProgressCallback, - period time.Duration) { - st.mux.Lock() - - // Add callback - sct := newSentCallbackTracker(cb, period) - st.progressCallbacks = append(st.progressCallbacks, sct) - - st.mux.Unlock() - - // Trigger the initial call - sct.callNow(true, st, nil) -} - -// GetEncryptedPart gets the specified part, encrypts it, and returns the -// encrypted part along with its MAC, padding, and fingerprint. -func (st *SentTransfer) GetEncryptedPart(partNum uint16, contentsSize int) (encPart, mac []byte, - fp format.Fingerprint, err error) { - st.mux.Lock() - defer st.mux.Unlock() - - // Create new empty file part message of size equal to the available payload - // size in the cMix message - partMsg, err := NewPartMessage(contentsSize) - if err != nil { - return nil, nil, format.Fingerprint{}, err - } - - partMsg.SetPartNum(partNum) - - // Lookup part - part, exists := st.sentParts.getPart(partNum) - if !exists { - return nil, nil, format.Fingerprint{}, - errors.Errorf(noPartNumErr, partNum) - } - - if err = partMsg.SetPart(part); err != nil { - return nil, nil, format.Fingerprint{}, - err - } - - // If all fingerprints have been used but parts still remain, then change - // the status to stopping and return an error specifying that all the - // retries have been used - if st.fpVector.GetNumAvailable() < 1 { - st.status = Stopping - return nil, nil, format.Fingerprint{}, MaxRetriesErr - } - - // get next unused fingerprint number and mark it as used - nextKey, err := st.fpVector.Next() - if err != nil { - return nil, nil, format.Fingerprint{}, - errors.Errorf(fingerprintErr, err) - } - fpNum := uint16(nextKey) - - // Generate fingerprint - fp = ftCrypto.GenerateFingerprint(st.key, fpNum) - - // Encrypt the file part and generate the file part MAC and padding (nonce) - encPart, mac, err = ftCrypto.EncryptPart(st.key, partMsg.Marshal(), fpNum, fp) - if err != nil { - return nil, nil, format.Fingerprint{}, - errors.Errorf(encryptPartErr, partNum, err) - } - - return encPart, mac, fp, err -} - -// SetInProgress adds the specified file part numbers to the in-progress -// transfers for the given round ID. Returns whether the round already exists in -// the list. -func (st *SentTransfer) SetInProgress(rid id.Round, partNums ...uint16) (bool, - error) { - st.mux.Lock() - defer st.mux.Unlock() - - // Check if there is already a round in-progress - _, exists := st.inProgressTransfers.getPartNums(rid) - - // Set parts as in-progress in part status vector - err := st.partStats.SetMany(partNums, inProgress) - if err != nil { - return false, err - } - - return exists, st.inProgressTransfers.addPartNums(rid, partNums...) -} - -// GetInProgress returns a list of all part number in the in-progress transfers -// list. -func (st *SentTransfer) GetInProgress(rid id.Round) ([]uint16, bool) { - st.mux.Lock() - defer st.mux.Unlock() - - return st.inProgressTransfers.getPartNums(rid) -} - -// UnsetInProgress removed the file part numbers from the in-progress transfers -// for the given round ID. Returns the list of part numbers that were removed -// from the list. -func (st *SentTransfer) UnsetInProgress(rid id.Round) ([]uint16, error) { - st.mux.Lock() - defer st.mux.Unlock() - - // get the list of part numbers to be removed from list - partNums, _ := st.inProgressTransfers.getPartNums(rid) - - // The part status is set in partStats before the parts and round ID so that - // in the event of recovery after a crash, the parts will be resent on a new - // round and the parts in the inProgressTransfers will be left until deleted - // with the rest of the storage on transfer completion. The side effect is - // that on recovery, the status of the round will be looked up again and the - // progress callback will be called for an event that has already been - // called on the callback. - - // Set parts as unsent in part status vector - err := st.partStats.SetMany(partNums, unsent) - if err != nil { - return nil, err - } - - return partNums, st.inProgressTransfers.deletePartNums(rid) -} - -// FinishTransfer moves the in-progress file parts for the given round to the -// finished list. Returns true if all file parts have been marked as finished -// and false otherwise. -func (st *SentTransfer) FinishTransfer(rid id.Round) (bool, error) { - st.mux.Lock() - defer st.mux.Unlock() - - // get the parts in-progress for the round ID or return an error if none - // exist - partNums, exists := st.inProgressTransfers.getPartNums(rid) - if !exists { - return false, errors.Errorf(noPartsForRoundErr, rid) - } - - // Delete the parts from the in-progress list - err := st.inProgressTransfers.deletePartNums(rid) - if err != nil { - return false, errors.Errorf(deleteInProgressPartsErr, rid, err) - } - - // Add the parts to the finished list - err = st.finishedTransfers.addPartNums(rid, partNums...) - if err != nil { - return false, err - } - - // Set parts as finished in part status vector - err = st.partStats.SetMany(partNums, finished) - if err != nil { - return false, err - } - - // If all parts have been moved to the finished list, then set the status - // to stopping - if st.finishedTransfers.getNumParts() == st.numParts && - st.inProgressTransfers.getNumParts() == 0 { - st.status = Stopping - return true, nil - } - - return false, nil -} - -// GetUnsentPartNums returns a list of part numbers that have not been sent. -func (st *SentTransfer) GetUnsentPartNums() ([]uint16, error) { - st.mux.RLock() - defer st.mux.RUnlock() - - // get list of parts with a status of unsent - unsentPartNums, err := st.partStats.GetKeys(unsent) - if err != nil { - return nil, errors.Errorf(getUnsentPartsErr, err) - } - - return unsentPartNums, nil -} - -// GetSentRounds returns a list of round IDs that parts were sent on (in- -// progress parts) that were never marked as finished. -func (st *SentTransfer) GetSentRounds() []id.Round { - sentRounds := make([]id.Round, 0, len(st.inProgressTransfers.list)) - - for rid := range st.inProgressTransfers.list { - sentRounds = append(sentRounds, rid) - } - - return sentRounds -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// loadSentTransfer loads the SentTransfer with the given transfer ID from -// storage. -func loadSentTransfer(tid ftCrypto.TransferID, kv *versioned.KV) (*SentTransfer, - error) { - st := &SentTransfer{ - kv: kv.Prefix(makeSentTransferPrefix(tid)), - } - - // Load transfer key and number of sent parts from storage - err := st.loadInfo() - if err != nil { - return nil, errors.Errorf(loadSentStoreErr, err) - } - - // Load the fingerprint vector from storage - st.fpVector, err = utility.LoadStateVector(st.kv, sentFpVectorKey) - if err != nil { - return nil, errors.Errorf(loadSentFpVectorErr, err) - } - - // Load sent part store from storage - st.sentParts, err = loadPartStore(st.kv) - if err != nil { - return nil, errors.Errorf(loadSentPartStoreErr, err) - } - - // Load in-progress transfer bundle from storage - st.inProgressTransfers, err = loadTransferredBundle(inProgressKey, st.kv) - if err != nil { - return nil, errors.Errorf(loadInProgressTransfersErr, err) - } - - // Load finished transfer bundle from storage - st.finishedTransfers, err = loadTransferredBundle(finishedKey, st.kv) - if err != nil { - return nil, errors.Errorf(loadFinishedTransfersErr, err) - } - - // Load the part status MultiStateVector from storage - st.partStats, err = utility.LoadMultiStateVector( - sentTransferStateMap, sentPartStatsVectorKey, st.kv) - if err != nil { - return nil, errors.Errorf(loadSentPartStatusVectorErr, err) - } - - return st, nil -} - -// saveInfo saves all fields in SentTransfer that do not have their own storage -// (recipient ID, transfer key, number of file parts, number of fingerprints, -// and transfer status) to storage. -func (st *SentTransfer) saveInfo() error { - st.mux.Lock() - defer st.mux.Unlock() - - // Create new versioned object for the SentTransfer - obj := &versioned.Object{ - Version: sentTransferVersion, - Timestamp: netTime.Now(), - Data: st.marshal(), - } - - // Save versioned object - return st.kv.Set(sentTransferKey, sentTransferVersion, obj) -} - -// loadInfo gets the recipient ID, transfer key, number of part, number of -// fingerprints, and transfer status from storage and saves it to the -// SentTransfer. -func (st *SentTransfer) loadInfo() error { - vo, err := st.kv.Get(sentTransferKey, sentTransferVersion) - if err != nil { - return err - } - - // Unmarshal the transfer key and numParts - st.recipient, st.key, st.numParts, st.numFps, st.status = - unmarshalSentTransfer(vo.Data) - - return nil -} - -// delete deletes all data in the SentTransfer from storage. -func (st *SentTransfer) delete() error { - st.mux.Lock() - defer st.mux.Unlock() - - // Delete sent transfer info from storage - err := st.deleteInfo() - if err != nil { - return errors.Errorf(deleteSentTransferInfoErr, err) - } - - // Delete fingerprint vector from storage - err = st.fpVector.Delete() - if err != nil { - return errors.Errorf(deleteSentFpVectorErr, err) - } - - // Delete sent file parts from storage - err = st.sentParts.delete() - if err != nil { - return errors.Errorf(deleteSentFilePartsErr, err) - } - - // Delete in-progress transfer bundles from storage - err = st.inProgressTransfers.delete() - if err != nil { - return errors.Errorf(deleteInProgressTransfersErr, err) - } - - // Delete finished transfer bundles from storage - err = st.finishedTransfers.delete() - if err != nil { - return errors.Errorf(deleteFinishedTransfersErr, err) - } - - // Delete the part status MultiStateVector from storage - err = st.partStats.Delete() - if err != nil { - return errors.Errorf(deleteSentPartStatusVectorErr, err) - } - - return nil -} - -// deleteInfo removes received transfer info (recipient, transfer key, number -// of parts, and number of fingerprints) from storage. -func (st *SentTransfer) deleteInfo() error { - return st.kv.Delete(sentTransferKey, sentTransferVersion) -} - -// marshal serializes all primitive fields in SentTransfer (recipient, key, -// numParts, numFps, and status). -func (st *SentTransfer) marshal() []byte { - // Construct the buffer to the correct size - // (size of ID + size of key + numParts (2 bytes) + numFps (2 bytes)) - buff := bytes.NewBuffer(nil) - buff.Grow(id.ArrIDLen + ftCrypto.TransferKeyLength + 2 + 2) - - // Write the recipient ID to the buffer - if st.recipient != nil { - buff.Write(st.recipient.Marshal()) - } else { - buff.Write((&id.ID{}).Marshal()) - } - - // Write the key to the buffer - buff.Write(st.key.Bytes()) - - // Write the number of parts to the buffer - b := make([]byte, 2) - binary.LittleEndian.PutUint16(b, st.numParts) - buff.Write(b) - - // Write the number of fingerprints to the buffer - b = make([]byte, 2) - binary.LittleEndian.PutUint16(b, st.numFps) - buff.Write(b) - - // Write the transfer status to the buffer - buff.Write(st.status.Marshal()) - - // Return the serialized data - return buff.Bytes() -} - -// unmarshalSentTransfer deserializes a byte slice into the primitive fields -// of SentTransfer (recipient, key, numParts, numFps, and status). -func unmarshalSentTransfer(b []byte) (recipient *id.ID, - key ftCrypto.TransferKey, numParts, numFps uint16, status TransferStatus) { - - buff := bytes.NewBuffer(b) - - // Read the recipient ID from the buffer - recipient = &id.ID{} - copy(recipient[:], buff.Next(id.ArrIDLen)) - - // Read the transfer key from the buffer - key = ftCrypto.UnmarshalTransferKey(buff.Next(ftCrypto.TransferKeyLength)) - - // Read the number of part from the buffer - numParts = binary.LittleEndian.Uint16(buff.Next(2)) - - // Read the number of fingerprints from the buffer - numFps = binary.LittleEndian.Uint16(buff.Next(2)) - - // Read the transfer status from the buffer - status = UnmarshalTransferStatus(buff.Next(8)) - - return recipient, key, numParts, numFps, status -} - -// makeSentTransferPrefix generates the unique prefix used on the key value -// store to store sent transfers for the given transfer ID. -func makeSentTransferPrefix(tid ftCrypto.TransferID) string { - return sentTransferPrefix + tid.String() -} - -// uint16SliceToUint32Slice converts a slice of uint16 to a slice of uint32. -func uint16SliceToUint32Slice(slice []uint16) []uint32 { - newSlice := make([]uint32, len(slice)) - for i, val := range slice { - newSlice[i] = uint32(val) - } - return newSlice -} diff --git a/storage/fileTransfer/sentTransfer_test.go b/storage/fileTransfer/sentTransfer_test.go deleted file mode 100644 index aab3d2a6ff1272cdceeee0d90b630a3d9cca79d5..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/sentTransfer_test.go +++ /dev/null @@ -1,1780 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "fmt" - "github.com/pkg/errors" - "gitlab.com/elixxir/client/interfaces" - "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/elixxir/client/storage/versioned" - ftCrypto "gitlab.com/elixxir/crypto/fileTransfer" - "gitlab.com/elixxir/ekv" - "gitlab.com/elixxir/primitives/format" - "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/netTime" - "math/rand" - "reflect" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "testing" - "time" -) - -// Tests that NewSentTransfer creates the expected SentTransfer and that it is -// saved to storage. -func Test_NewSentTransfer(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - recipient, _ := id.NewRandomID(prng, id.User) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - kvPrefixed := kv.Prefix(makeSentTransferPrefix(tid)) - parts := [][]byte{ - []byte("test0"), []byte("test1"), []byte("test2"), - []byte("test3"), []byte("test4"), []byte("test5"), - } - numParts, numFps := uint16(len(parts)), uint16(float64(len(parts))*1.5) - fpVector, _ := utility.NewStateVector( - kvPrefixed, sentFpVectorKey, uint32(numFps)) - partStats, _ := utility.NewMultiStateVector( - numParts, 3, sentTransferStateMap, sentPartStatsVectorKey, kvPrefixed) - - type cbFields struct { - completed bool - sent, arrived, total uint16 - err error - } - - expectedCB := cbFields{ - completed: false, - sent: 0, - arrived: 0, - total: numParts, - err: nil, - } - - cbChan := make(chan cbFields) - cb := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{ - completed: completed, - sent: sent, - arrived: arrived, - total: total, - err: err, - } - } - - expectedPeriod := time.Second - - expected := &SentTransfer{ - recipient: recipient, - key: key, - numParts: numParts, - numFps: numFps, - fpVector: fpVector, - sentParts: &partStore{ - parts: partSliceToMap(parts...), - numParts: uint16(len(parts)), - kv: kvPrefixed, - }, - inProgressTransfers: &transferredBundle{ - list: make(map[id.Round][]uint16), - key: inProgressKey, - kv: kvPrefixed, - }, - finishedTransfers: &transferredBundle{ - list: make(map[id.Round][]uint16), - key: finishedKey, - kv: kvPrefixed, - }, - partStats: partStats, - progressCallbacks: []*sentCallbackTracker{ - newSentCallbackTracker(cb, expectedPeriod), - }, - status: Running, - kv: kvPrefixed, - } - - // Create new SentTransfer - st, err := NewSentTransfer( - recipient, tid, key, parts, numFps, cb, expectedPeriod, kv) - if err != nil { - t.Errorf("NewSentTransfer returned an error: %+v", err) - } - - // Check that the callback is called when added - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting fpr progress callback to be called.") - case cbResults := <-cbChan: - if !reflect.DeepEqual(expectedCB, cbResults) { - t.Errorf("Did not receive correct results from callback."+ - "\nexpected: %+v\nreceived: %+v", expectedCB, cbResults) - } - } - - st.progressCallbacks = expected.progressCallbacks - - // Check that the new object matches the expected - if !reflect.DeepEqual(expected, st) { - t.Errorf("New SentTransfer does not match expected."+ - "\nexpected: %#v\nreceived: %#v", expected, st) - } - - // Make sure it is saved to storage - _, err = kvPrefixed.Get(sentTransferKey, sentTransferVersion) - if err != nil { - t.Errorf("Failed to get new SentTransfer from storage: %+v", err) - } - - // Check that the fingerprint vector has correct values - if st.fpVector.GetNumAvailable() != uint32(numFps) { - t.Errorf("Incorrect number of available keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps, st.fpVector.GetNumAvailable()) - } - if st.fpVector.GetNumKeys() != uint32(numFps) { - t.Errorf("Incorrect number of keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps, st.fpVector.GetNumKeys()) - } - if st.fpVector.GetNumUsed() != 0 { - t.Errorf("Incorrect number of used keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", 0, st.fpVector.GetNumUsed()) - } -} - -// Tests that SentTransfer.ReInit overwrites the fingerprint vector, in-progress -// transfer, finished transfers, and progress callbacks with new and empty -// objects. -func TestSentTransfer_ReInit(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - recipient, _ := id.NewRandomID(prng, id.User) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - kvPrefixed := kv.Prefix(makeSentTransferPrefix(tid)) - parts := [][]byte{ - []byte("test0"), []byte("test1"), []byte("test2"), - []byte("test3"), []byte("test4"), []byte("test5"), - } - numParts, numFps1 := uint16(len(parts)), uint16(float64(len(parts))*1.5) - numFps2 := 2 * numFps1 - fpVector, _ := utility.NewStateVector( - kvPrefixed, sentFpVectorKey, uint32(numFps2)) - partStats, _ := utility.NewMultiStateVector( - numParts, 3, sentTransferStateMap, sentPartStatsVectorKey, kvPrefixed) - - type cbFields struct { - completed bool - sent, arrived, total uint16 - err error - } - - expectedCB := cbFields{ - completed: false, - sent: 0, - arrived: 0, - total: numParts, - err: nil, - } - - cbChan := make(chan cbFields) - cb := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- cbFields{ - completed: completed, - sent: sent, - arrived: arrived, - total: total, - err: err, - } - } - - expectedPeriod := time.Millisecond - - expected := &SentTransfer{ - recipient: recipient, - key: key, - numParts: numParts, - numFps: numFps2, - fpVector: fpVector, - sentParts: &partStore{ - parts: partSliceToMap(parts...), - numParts: uint16(len(parts)), - kv: kvPrefixed, - }, - inProgressTransfers: &transferredBundle{ - list: make(map[id.Round][]uint16), - key: inProgressKey, - kv: kvPrefixed, - }, - finishedTransfers: &transferredBundle{ - list: make(map[id.Round][]uint16), - key: finishedKey, - kv: kvPrefixed, - }, - partStats: partStats, - progressCallbacks: []*sentCallbackTracker{ - newSentCallbackTracker(cb, expectedPeriod), - }, - status: Running, - kv: kvPrefixed, - } - - // Create new SentTransfer - st, err := NewSentTransfer( - recipient, tid, key, parts, numFps1, nil, 2*expectedPeriod, kv) - if err != nil { - t.Errorf("NewSentTransfer returned an error: %+v", err) - } - - // Re-initialize SentTransfer with new number of fingerprints and callback - err = st.ReInit(numFps2, cb, expectedPeriod) - if err != nil { - t.Errorf("ReInit returned an error: %+v", err) - } - - // Check that the callback is called when added - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting fpr progress callback to be called.") - case cbResults := <-cbChan: - if !reflect.DeepEqual(expectedCB, cbResults) { - t.Errorf("Did not receive correct results from callback."+ - "\nexpected: %+v\nreceived: %+v", expectedCB, cbResults) - } - } - - st.progressCallbacks = expected.progressCallbacks - - // Check that the new object matches the expected - if !reflect.DeepEqual(expected, st) { - t.Errorf("New SentTransfer does not match expected."+ - "\nexpected: %#v\nreceived: %#v", expected, st) - } - - // Make sure it is saved to storage - _, err = kvPrefixed.Get(sentTransferKey, sentTransferVersion) - if err != nil { - t.Errorf("Failed to get new SentTransfer from storage: %+v", err) - } - - // Check that the fingerprint vector has correct values - if st.fpVector.GetNumAvailable() != uint32(numFps2) { - t.Errorf("Incorrect number of available keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps2, st.fpVector.GetNumAvailable()) - } - if st.fpVector.GetNumKeys() != uint32(numFps2) { - t.Errorf("Incorrect number of keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", numFps2, st.fpVector.GetNumKeys()) - } - if st.fpVector.GetNumUsed() != 0 { - t.Errorf("Incorrect number of used keys in fingerprint list."+ - "\nexpected: %d\nreceived: %d", 0, st.fpVector.GetNumUsed()) - } -} - -// Tests that SentTransfer.GetRecipient returns the expected ID. -func TestSentTransfer_GetRecipient(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - expectedRecipient, _ := id.NewRandomID(prng, id.User) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - - // Create new SentTransfer - st, err := NewSentTransfer( - expectedRecipient, tid, key, [][]byte{}, 5, nil, 0, kv) - if err != nil { - t.Errorf("Failed to create new SentTransfer: %+v", err) - } - - if expectedRecipient != st.GetRecipient() { - t.Errorf("Failed to get expected transfer key."+ - "\nexpected: %s\nreceived: %s", expectedRecipient, st.GetRecipient()) - } -} - -// Tests that SentTransfer.GetTransferKey returns the expected transfer key. -func TestSentTransfer_GetTransferKey(t *testing.T) { - prng := NewPrng(42) - kv := versioned.NewKV(make(ekv.Memstore)) - recipient, _ := id.NewRandomID(prng, id.User) - tid, _ := ftCrypto.NewTransferID(prng) - expectedKey, _ := ftCrypto.NewTransferKey(prng) - - // Create new SentTransfer - st, err := NewSentTransfer( - recipient, tid, expectedKey, [][]byte{}, 5, nil, 0, kv) - if err != nil { - t.Errorf("Failed to create new SentTransfer: %+v", err) - } - - if expectedKey != st.GetTransferKey() { - t.Errorf("Failed to get expected transfer key."+ - "\nexpected: %s\nreceived: %s", expectedKey, st.GetTransferKey()) - } -} - -// Tests that SentTransfer.GetNumParts returns the expected number of parts. -func TestSentTransfer_GetNumParts(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedNumParts := uint16(16) - _, st := newRandomSentTransfer(expectedNumParts, 24, kv, t) - - if expectedNumParts != st.GetNumParts() { - t.Errorf("Failed to get expected number of parts."+ - "\nexpected: %d\nreceived: %d", expectedNumParts, st.GetNumParts()) - } -} - -// Tests that SentTransfer.GetNumFps returns the expected number of -// fingerprints. -func TestSentTransfer_GetNumFps(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedNumFps := uint16(24) - _, st := newRandomSentTransfer(16, expectedNumFps, kv, t) - - if expectedNumFps != st.GetNumFps() { - t.Errorf("Failed to get expected number of fingerprints."+ - "\nexpected: %d\nreceived: %d", expectedNumFps, st.GetNumFps()) - } -} - -// Tests that SentTransfer.GetNumAvailableFps returns the expected number of -// available fingerprints. -func TestSentTransfer_GetNumAvailableFps(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts, numFps := uint16(16), uint16(24) - _, st := newRandomSentTransfer(numParts, numFps, kv, t) - - if numFps != st.GetNumAvailableFps() { - t.Errorf("Failed to get expected number of available fingerprints."+ - "\nexpected: %d\nreceived: %d", - numFps, st.GetNumAvailableFps()) - } - - for i := uint16(0); i < numParts; i++ { - _, _ = st.fpVector.Next() - } - - if numFps-numParts != st.GetNumAvailableFps() { - t.Errorf("Failed to get expected number of available fingerprints."+ - "\nexpected: %d\nreceived: %d", - numFps-numParts, st.GetNumAvailableFps()) - } -} - -// Tests that SentTransfer.GetStatus returns the expected status at each stage -// of the transfer. -func TestSentTransfer_GetStatus(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts, numFps := uint16(2), uint16(4) - _, st := newRandomSentTransfer(numParts, numFps, kv, t) - - status := st.GetStatus() - if status != Running { - t.Errorf("Unexpected transfer status.\nexpected: %s\nreceived: %s", - Running, status) - } - - _, _ = st.SetInProgress(0, 0, 1) - _, _ = st.FinishTransfer(0) - - status = st.GetStatus() - if status != Stopping { - t.Errorf("Unexpected transfer status.\nexpected: %s\nreceived: %s", - Stopping, status) - } - - st.CallProgressCB(nil) - - status = st.GetStatus() - if status != Stopped { - t.Errorf("Unexpected transfer status.\nexpected: %s\nreceived: %s", - Stopped, status) - } -} - -// Tests that SentTransfer.IsPartInProgress returns false before a part is set -// as in-progress and true after it is set via SentTransfer.SetInProgress. Also -// tests that it returns false after the part has been unset via -// SentTransfer.UnsetInProgress. -func TestSentTransfer_IsPartInProgress(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(0) - partNum := uint16(7) - - // Test that the part has not been set to in-progress - inProgress, err := st.IsPartInProgress(partNum) - if err != nil { - t.Errorf("IsPartInProgress returned an error: %+v", err) - } - if inProgress { - t.Errorf("Part number %d set as in-progress.", partNum) - } - - // Set the part number to in-progress - _, _ = st.SetInProgress(rid, partNum) - - // Test that the part has been set to in-progress - inProgress, err = st.IsPartInProgress(partNum) - if err != nil { - t.Errorf("IsPartInProgress returned an error: %+v", err) - } - if !inProgress { - t.Errorf("Part number %d not set as in-progress.", partNum) - } - - // Unset the part as in-progress - _, _ = st.UnsetInProgress(rid) - - // Test that the part has been unset - inProgress, err = st.IsPartInProgress(partNum) - if err != nil { - t.Errorf("IsPartInProgress returned an error: %+v", err) - } - if inProgress { - t.Errorf("Part number %d set as in-progress.", partNum) - } -} - -// Error path: tests that SentTransfer.IsPartInProgress returns the expected -// error when the part number is out of range. -func TestSentTransfer_IsPartInProgress_InvalidPartNumError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - expectedErr := fmt.Sprintf(getStatusErr, st.numParts, "") - _, err := st.IsPartInProgress(st.numParts) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("IsPartInProgress did not return the expected error when the "+ - "part number is out of range.\nexpected: %s\nreceived: %v", - expectedErr, err) - } -} - -// Tests that SentTransfer.IsPartFinished returns false before a part is set as -// finished and true after it is set via SentTransfer.FinishTransfer. -func TestSentTransfer_IsPartFinished(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(0) - partNum := uint16(7) - - // Set the part number to in-progress - _, _ = st.SetInProgress(rid, partNum) - - // Test that the part has not been set to finished - isFinished, err := st.IsPartFinished(partNum) - if err != nil { - t.Errorf("IsPartFinished returned an error: %+v", err) - } - if isFinished { - t.Errorf("Part number %d set as finished.", partNum) - } - - // Set the part number to finished - _, _ = st.FinishTransfer(rid) - - // Test that the part has been set to finished - isFinished, err = st.IsPartFinished(partNum) - if err != nil { - t.Errorf("IsPartFinished returned an error: %+v", err) - } - if !isFinished { - t.Errorf("Part number %d not set as finished.", partNum) - } -} - -// Error path: tests that SentTransfer.IsPartFinished returns the expected -// error when the part number is out of range. -func TestSentTransfer_IsPartFinished_InvalidPartNumError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - expectedErr := fmt.Sprintf(getStatusErr, st.numParts, "") - _, err := st.IsPartFinished(st.numParts) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("IsPartFinished did not return the expected error when the "+ - "part number is out of range.\nexpected: %s\nreceived: %v", - expectedErr, err) - } -} - -// Tests that SentTransfer.GetProgress returns the expected progress metrics for -// various transfer states. -func TestSentTransfer_GetProgress(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - numParts := uint16(16) - _, st := newRandomSentTransfer(16, 24, kv, t) - - completed, sent, arrived, total, track := st.GetProgress() - err := checkSentProgress( - completed, sent, arrived, total, false, 0, 0, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, nil, nil, t) - - _, _ = st.SetInProgress(1, 0, 1, 2) - - completed, sent, arrived, total, track = st.GetProgress() - err = checkSentProgress(completed, sent, arrived, total, false, 3, 0, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, []uint16{0, 1, 2}, nil, t) - - _, _ = st.SetInProgress(2, 3, 4, 5) - - completed, sent, arrived, total, track = st.GetProgress() - err = checkSentProgress(completed, sent, arrived, total, false, 6, 0, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, []uint16{0, 1, 2, 3, 4, 5}, nil, t) - - _, _ = st.FinishTransfer(1) - _, _ = st.UnsetInProgress(2) - - completed, sent, arrived, total, track = st.GetProgress() - err = checkSentProgress(completed, sent, arrived, total, false, 0, 3, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, nil, []uint16{0, 1, 2}, t) - - _, _ = st.SetInProgress(3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) - - completed, sent, arrived, total, track = st.GetProgress() - err = checkSentProgress( - completed, sent, arrived, total, false, 10, 3, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, - []uint16{6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, []uint16{0, 1, 2}, t) - - _, _ = st.FinishTransfer(3) - _, _ = st.SetInProgress(4, 3, 4, 5) - - completed, sent, arrived, total, track = st.GetProgress() - err = checkSentProgress( - completed, sent, arrived, total, false, 3, 13, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, []uint16{3, 4, 5}, - []uint16{0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, t) - - _, _ = st.FinishTransfer(4) - - completed, sent, arrived, total, track = st.GetProgress() - err = checkSentProgress(completed, sent, arrived, total, true, 0, 16, numParts) - if err != nil { - t.Error(err) - } - checkSentTracker(track, st.numParts, nil, - []uint16{0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 3, 4, 5}, t) -} - -// Tests that 5 different callbacks all receive the expected data when -// SentTransfer.CallProgressCB is called at different stages of transfer. -func TestSentTransfer_CallProgressCB(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - type progressResults struct { - completed bool - sent, arrived, total uint16 - err error - } - - period := time.Millisecond - - wg := sync.WaitGroup{} - var step0, step1, step2, step3 uint64 - numCallbacks := 5 - - for i := 0; i < numCallbacks; i++ { - progressChan := make(chan progressResults) - - cbFunc := func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - progressChan <- progressResults{completed, sent, arrived, total, err} - } - wg.Add(1) - - go func(i int) { - defer wg.Done() - n := 0 - for { - select { - case <-time.NewTimer(time.Second).C: - t.Errorf("Timed out after %s waiting for callback (%d).", - period*5, i) - return - case r := <-progressChan: - switch n { - case 0: - if err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, 0, 0, st.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) - } - atomic.AddUint64(&step0, 1) - case 1: - if err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, 0, 0, st.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) - } - atomic.AddUint64(&step1, 1) - case 2: - if err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, false, 0, 6, st.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) - } - atomic.AddUint64(&step2, 1) - case 3: - if err := checkSentProgress(r.completed, r.sent, r.arrived, - r.total, true, 0, 16, st.numParts); err != nil { - t.Errorf("%2d: %+v", i, err) - } - atomic.AddUint64(&step3, 1) - return - default: - t.Errorf("n (%d) is great than 3 (%d)", n, i) - return - } - n++ - } - } - }(i) - - st.AddProgressCB(cbFunc, period) - } - - for !atomic.CompareAndSwapUint64(&step0, uint64(numCallbacks), 0) { - } - - st.CallProgressCB(nil) - - for !atomic.CompareAndSwapUint64(&step1, uint64(numCallbacks), 0) { - } - - _, _ = st.SetInProgress(0, 0, 1, 2) - _, _ = st.SetInProgress(1, 3, 4, 5) - _, _ = st.SetInProgress(2, 6, 7, 8) - _, _ = st.UnsetInProgress(1) - _, _ = st.FinishTransfer(0) - _, _ = st.FinishTransfer(2) - - st.CallProgressCB(nil) - - for !atomic.CompareAndSwapUint64(&step2, uint64(numCallbacks), 0) { - } - - _, _ = st.SetInProgress(4, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15) - _, _ = st.FinishTransfer(4) - - st.CallProgressCB(nil) - - wg.Wait() -} - -// Tests that SentTransfer.stopScheduledProgressCB stops a scheduled callback -// from being triggered. -func TestSentTransfer_stopScheduledProgressCB(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - cbChan := make(chan struct{}, 5) - cbFunc := interfaces.SentProgressCallback( - func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- struct{}{} - }) - st.AddProgressCB(cbFunc, 150*time.Millisecond) - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") - case <-cbChan: - } - - st.CallProgressCB(nil) - st.CallProgressCB(nil) - select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for callback.") - case <-cbChan: - } - - err := st.stopScheduledProgressCB() - if err != nil { - t.Errorf("stopScheduledProgressCB returned an error: %+v", err) - } - - select { - case <-time.NewTimer(200 * time.Millisecond).C: - case <-cbChan: - t.Error("Callback called when it should have been stopped.") - } -} - -// Tests that SentTransfer.AddProgressCB adds an item to the progress callback -// list. -func TestSentTransfer_AddProgressCB(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - type callbackResults struct { - completed bool - sent, arrived, total uint16 - err error - } - cbChan := make(chan callbackResults) - cbFunc := interfaces.SentProgressCallback( - func(completed bool, sent, arrived, total uint16, - t interfaces.FilePartTracker, err error) { - cbChan <- callbackResults{completed, sent, arrived, total, err} - }) - - done := make(chan bool) - go func() { - select { - case <-time.NewTimer(time.Millisecond).C: - t.Error("Timed out waiting for progress callback to be called.") - case r := <-cbChan: - err := checkSentProgress( - r.completed, r.sent, r.arrived, r.total, false, 0, 0, 16) - if err != nil { - t.Error(err) - } - if r.err != nil { - t.Errorf("Callback returned an error: %+v", err) - } - } - done <- true - }() - - period := time.Millisecond - st.AddProgressCB(cbFunc, period) - - if len(st.progressCallbacks) != 1 { - t.Errorf("Callback list should only have one item."+ - "\nexpected: %d\nreceived: %d", 1, len(st.progressCallbacks)) - } - - if st.progressCallbacks[0].period != period { - t.Errorf("Callback has wrong lastCall.\nexpected: %s\nreceived: %s", - period, st.progressCallbacks[0].period) - } - - if st.progressCallbacks[0].lastCall != (time.Time{}) { - t.Errorf("Callback has wrong time.\nexpected: %s\nreceived: %s", - time.Time{}, st.progressCallbacks[0].lastCall) - } - - if st.progressCallbacks[0].scheduled { - t.Errorf("Callback has wrong scheduled.\nexpected: %t\nreceived: %t", - false, st.progressCallbacks[0].scheduled) - } - <-done -} - -// Loops through each file part encrypting it with SentTransfer.GetEncryptedPart -// and tests that it returns an encrypted part, MAC, and padding (nonce) that -// can be used to successfully decrypt and get the original part. Also tests -// that fingerprints are valid and not used more than once. -// It also tests that -func TestSentTransfer_GetEncryptedPart(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - // Create and fill fingerprint map used to check fingerprint validity - // The first item in the uint16 slice is the fingerprint number and the - // second item is the number of times it has been used - fpMap := make(map[format.Fingerprint][]uint16, st.numFps) - for num, fp := range ftCrypto.GenerateFingerprints(st.key, st.numFps) { - fpMap[fp] = []uint16{uint16(num), 0} - } - - for i := uint16(0); i < st.numFps; i++ { - partNum := i % st.numParts - - encPart, mac, fp, err := st.GetEncryptedPart(partNum, 18) - if err != nil { - t.Fatalf("GetEncryptedPart returned an error for part number "+ - "%d (%d): %+v", partNum, i, err) - } - - // Check that the fingerprint is valid - fpNum, exists := fpMap[fp] - if !exists { - t.Errorf("Fingerprint %s invalid for part number %d (%d).", - fp, partNum, i) - } - - // Check that the fingerprint has not been used - if fpNum[1] > 0 { - t.Errorf("Fingerprint %s for part number %d already used by %d "+ - "other parts (%d).", fp, partNum, fpNum[1], i) - } - - // Attempt to decrypt the part - partMarshaled, err := ftCrypto.DecryptPart(st.key, encPart, mac, fpNum[0], fp) - if err != nil { - t.Errorf("Failed to decrypt file part number %d (%d): %+v", - partNum, i, err) - } - - partMsg, _ := UnmarshalPartMessage(partMarshaled) - - // Make sure the decrypted part matches the original - expectedPart, _ := st.sentParts.getPart(i % st.numParts) - if !bytes.Equal(expectedPart, partMsg.GetPart()) { - t.Errorf("Decyrpted part number %d does not match expected (%d)."+ - "\nexpected: %+v\nreceived: %+v", partNum, i, expectedPart, partMsg.GetPart()) - } - - if partMsg.GetPartNum() != i%st.numParts { - t.Errorf("Number of part did not match, expected: %d, "+ - "received: %d", i%st.numParts, partMsg.GetPartNum()) - } - } -} - -// Error path: tests that SentTransfer.GetEncryptedPart returns the expected -// error when no part for the given part number exists. -func TestSentTransfer_GetEncryptedPart_NoPartError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - partNum := st.numParts + 1 - expectedErr := fmt.Sprintf(noPartNumErr, partNum) - - _, _, _, err := st.GetEncryptedPart(partNum, 16) - if err == nil || err.Error() != expectedErr { - t.Errorf("GetEncryptedPart did not return the expected error for a "+ - "nonexistent part number %d.\nexpected: %s\nreceived: %+v", - partNum, expectedErr, err) - } -} - -// Error path: tests that SentTransfer.GetEncryptedPart returns the expected -// error when no fingerprints are available. -func TestSentTransfer_GetEncryptedPart_NoFingerprintsError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - // Use up all the fingerprints - for i := uint16(0); i < st.numFps; i++ { - partNum := i % st.numParts - _, _, _, err := st.GetEncryptedPart(partNum, 18) - if err != nil { - t.Errorf("Error when encyrpting part number %d (%d): %+v", - partNum, i, err) - } - } - - // Try to encrypt without any fingerprints - _, _, _, err := st.GetEncryptedPart(5, 18) - if err != MaxRetriesErr { - t.Errorf("GetEncryptedPart did not return MaxRetriesErr when all "+ - "fingerprints have been used.\nexpected: %s\nreceived: %+v", - MaxRetriesErr, err) - } -} - -// Tests that SentTransfer.SetInProgress correctly adds the part numbers for the -// given round ID to the in-progress map and sets the correct parts as -// in-progress in the state vector. -func TestSentTransfer_SetInProgress(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(5) - expectedPartNums := []uint16{1, 2, 3} - - // Add parts to the in-progress list - exists, err := st.SetInProgress(rid, expectedPartNums...) - if err != nil { - t.Errorf("SetInProgress returned an error: %+v", err) - } - - // Check that the round does not already exist - if exists { - t.Errorf("Round %d already exists.", rid) - } - - // Check that the round ID is in the map - partNums, exists := st.inProgressTransfers.list[rid] - if !exists { - t.Errorf("Part numbers for round %d not found.", rid) - } - - // Check that the returned part numbers are correct - if !reflect.DeepEqual(expectedPartNums, partNums) { - t.Errorf("Received part numbers do not match expected."+ - "\nexpected: %v\nreceived: %v", expectedPartNums, partNums) - } - - // Check that only one item was added to the list - if len(st.inProgressTransfers.list) > 1 { - t.Errorf("Extra items in in-progress list."+ - "\nexpected: %d\nreceived: %d", 1, len(st.inProgressTransfers.list)) - } - - // Check that the part numbers were set on the in-progress status vector - for i, partNum := range expectedPartNums { - if status, _ := st.partStats.Get(partNum); status != inProgress { - t.Errorf("Part number %d not marked as in-progress in status "+ - "vector (%d).", partNum, i) - } - } - - // Check that the correct number of parts were marked as in-progress in the - // status vector - count, _ := st.partStats.GetCount(inProgress) - if int(count) != len(expectedPartNums) { - t.Errorf("Incorrect number of parts marked as in-progress."+ - "\nexpected: %d\nreceived: %d", len(expectedPartNums), count) - } - - // Add more parts to the in-progress list - expectedPartNums2 := []uint16{4, 5, 6} - exists, err = st.SetInProgress(rid, expectedPartNums2...) - if err != nil { - t.Errorf("SetInProgress returned an error: %+v", err) - } - - // Check that the round already exists - if !exists { - t.Errorf("Round %d should already exist.", rid) - } - - // Check that the number of parts were marked as in-progress is unchanged - count, _ = st.partStats.GetCount(inProgress) - if int(count) != len(expectedPartNums2)+len(expectedPartNums) { - t.Errorf("Incorrect number of parts marked as in-progress."+ - "\nexpected: %d\nreceived: %d", - len(expectedPartNums2)+len(expectedPartNums), count) - } -} - -// Tests that SentTransfer.GetInProgress returns the correct part numbers for -// the given round ID in the in-progress map. -func TestSentTransfer_GetInProgress(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(5) - expectedPartNums := []uint16{1, 2, 3, 4, 5, 6} - - // Add parts to the in-progress list - _, err := st.SetInProgress(rid, expectedPartNums[:3]...) - if err != nil { - t.Errorf("Failed to set parts %v to in-progress: %+v", - expectedPartNums[3:], err) - } - - // Add parts to the in-progress list - _, err = st.SetInProgress(rid, expectedPartNums[3:]...) - if err != nil { - t.Errorf("Failed to set parts %v to in-progress: %+v", - expectedPartNums[:3], err) - } - - // get the in-progress parts - receivedPartNums, exists := st.GetInProgress(rid) - if !exists { - t.Errorf("Failed to find parts for round %d that should exist.", rid) - } - - // Check that the returned part numbers are correct - if !reflect.DeepEqual(expectedPartNums, receivedPartNums) { - t.Errorf("Received part numbers do not match expected."+ - "\nexpected: %v\nreceived: %v", expectedPartNums, receivedPartNums) - } -} - -// Tests that SentTransfer.UnsetInProgress correctly removes the part numbers -// for the given round ID from the in-progress map and unsets the correct parts -// as in-progress in the state vector. -func TestSentTransfer_UnsetInProgress(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(5) - expectedPartNums := []uint16{1, 2, 3, 4, 5, 6} - - // Add parts to the in-progress list - if _, err := st.SetInProgress(rid, expectedPartNums[:3]...); err != nil { - t.Errorf("Failed to set parts in-progress: %+v", err) - } - if _, err := st.SetInProgress(rid, expectedPartNums[3:]...); err != nil { - t.Errorf("Failed to set parts in-progress: %+v", err) - } - - // Remove parts from in-progress list - receivedPartNums, err := st.UnsetInProgress(rid) - if err != nil { - t.Errorf("UnsetInProgress returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedPartNums, receivedPartNums) { - t.Errorf("Received part numbers do not match expected."+ - "\nexpected: %v\nreceived: %v", expectedPartNums, receivedPartNums) - } - - // Check that the round ID is not the map - partNums, exists := st.inProgressTransfers.list[rid] - if exists { - t.Errorf("Part numbers for round %d found: %v", rid, partNums) - } - - // Check that the list is empty - if len(st.inProgressTransfers.list) != 0 { - t.Errorf("Extra items in in-progress list."+ - "\nexpected: %d\nreceived: %d", 0, len(st.inProgressTransfers.list)) - } - - // Check that there are no set parts in the in-progress status vector - status, _ := st.partStats.Get(inProgress) - if status != unsent { - t.Errorf("Failed to unset all parts in the in-progress vector."+ - "\nexpected: %d\nreceived: %d", unsent, status) - } -} - -// Tests that SentTransfer.FinishTransfer removes the parts from the in-progress -// list and moved them to the finished list and that it unsets the correct parts -// in the in-progress vector in the state vector. -func TestSentTransfer_FinishTransfer(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(5) - expectedPartNums := []uint16{1, 2, 3} - - // Add parts to the in-progress list - _, err := st.SetInProgress(rid, expectedPartNums...) - if err != nil { - t.Errorf("Failed to add parts to in-progress list: %+v", err) - } - - // Move transfers to the finished list - complete, err := st.FinishTransfer(rid) - if err != nil { - t.Errorf("FinishTransfer returned an error: %+v", err) - } - - // Ensure the transfer is not reported as complete - if complete { - t.Error("FinishTransfer reported transfer as complete.") - } - - // Check that the round ID is not in the in-progress map - _, exists := st.inProgressTransfers.list[rid] - if exists { - t.Errorf("Found parts for round %d that should not be in map.", rid) - } - - // Check that the round ID is in the finished map - partNums, exists := st.finishedTransfers.list[rid] - if !exists { - t.Errorf("Part numbers for round %d not found.", rid) - } - - // Check that the returned part numbers are correct - if !reflect.DeepEqual(expectedPartNums, partNums) { - t.Errorf("Received part numbers do not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedPartNums, partNums) - } - - // Check that only one item was added to the list - if len(st.finishedTransfers.list) > 1 { - t.Errorf("Extra items in finished list."+ - "\nexpected: %d\nreceived: %d", 1, len(st.finishedTransfers.list)) - } - - // Check that there are no set parts in the in-progress status vector - count, _ := st.partStats.GetCount(inProgress) - if count != 0 { - t.Errorf("Failed to unset all parts in the in-progress vector."+ - "\nexpected: %d\nreceived: %d", 0, count) - } - - // Check that the part numbers were set on the finished status vector - for i, partNum := range expectedPartNums { - status, _ := st.partStats.Get(inProgress) - if status != finished { - t.Errorf("Part number %d not marked as finished in status vector "+ - "(%d).", partNum, i) - } - } - - // Check that the correct number of parts were marked as finished in the - // status vector - count, _ = st.partStats.GetCount(finished) - if int(count) != len(expectedPartNums) { - t.Errorf("Incorrect number of parts marked as finished."+ - "\nexpected: %d\nreceived: %d", len(expectedPartNums), count) - } -} - -// Tests that SentTransfer.FinishTransfer returns true and sets the status to -// stopping when all file parts are marked as complete. -func TestSentTransfer_FinishTransfer_Complete(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(5) - expectedPartNums := make([]uint16, st.numParts) - for i := range expectedPartNums { - expectedPartNums[i] = uint16(i) - } - - // Add parts to the in-progress list - _, err := st.SetInProgress(rid, expectedPartNums...) - if err != nil { - t.Errorf("Failed to add parts to in-progress list: %+v", err) - } - - // Move transfers to the finished list - complete, err := st.FinishTransfer(rid) - if err != nil { - t.Errorf("FinishTransfer returned an error: %+v", err) - } - - // Ensure the transfer is not reported as complete - if !complete { - t.Error("FinishTransfer reported transfer as not complete.") - } - - // Test that the status is correctly set - if st.status != Stopping { - t.Errorf("Status not set to expected value when transfer is complete."+ - "\nexpected: %s\nreceived: %s", Stopping, st.status) - } -} - -// Error path: tests that SentTransfer.FinishTransfer returns the expected error -// when the round ID is found in the in-progress map. -func TestSentTransfer_FinishTransfer_NoRoundErr(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - rid := id.Round(5) - expectedErr := fmt.Sprintf(noPartsForRoundErr, rid) - - // Move transfers to the finished list - complete, err := st.FinishTransfer(rid) - if err == nil || err.Error() != expectedErr { - t.Errorf("Did not get expected error when round ID not in in-progress "+ - "map.\nexpected: %s\nreceived: %+v", expectedErr, err) - } - - // Ensure the transfer is not reported as complete - if complete { - t.Error("FinishTransfer reported transfer as complete.") - } -} - -// Tests that SentTransfer.GetUnsentPartNums returns only part numbers that are -// not marked as in-progress or finished. -func TestSentTransfer_GetUnsentPartNums(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(18, 27, kv, t) - - expectedPartNums := make([]uint16, 0, st.numParts/3) - - // Loop through each part and set it individually - for i := uint16(0); i < st.numParts; i++ { - switch i % 3 { - case 0: - // Part is sent (in-progress) - _, _ = st.SetInProgress(id.Round(i), i) - case 1: - // Part is sent and arrived (finished) - _, _ = st.SetInProgress(id.Round(i), i) - _, _ = st.FinishTransfer(id.Round(i)) - case 2: - // Part is unsent (neither in-progress nor arrived) - expectedPartNums = append(expectedPartNums, i) - } - } - - unsentPartNums, err := st.GetUnsentPartNums() - if err != nil { - t.Errorf("GetUnsentPartNums returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedPartNums, unsentPartNums) { - t.Errorf("Unexpected unsent part numbers.\nexpected: %d\nreceived: %d", - expectedPartNums, unsentPartNums) - } -} - -// Tests that SentTransfer.GetSentRounds returns the expected round IDs when -// every round is either in-progress, finished, or unsent. -func TestSentTransfer_GetSentRounds(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(18, 27, kv, t) - - expectedRounds := make([]id.Round, 0, st.numParts/3) - - // Loop through each part and set it individually - for i := uint16(0); i < st.numParts; i++ { - rid := id.Round(i) - switch i % 3 { - case 0: - // Part is sent (in-progress) - _, _ = st.SetInProgress(rid, i) - expectedRounds = append(expectedRounds, rid) - case 1: - // Part is sent and arrived (finished) - _, _ = st.SetInProgress(rid, i) - _, _ = st.FinishTransfer(rid) - case 2: - // Part is unsent (neither in-progress nor arrived) - } - } - - // get the sent - sentRounds := st.GetSentRounds() - sort.SliceStable(sentRounds, - func(i, j int) bool { return sentRounds[i] < sentRounds[j] }) - - if !reflect.DeepEqual(expectedRounds, sentRounds) { - t.Errorf("Unexpected sent rounds.\nexpected: %d\nreceived: %d", - expectedRounds, sentRounds) - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Function Testing // -//////////////////////////////////////////////////////////////////////////////// - -// Tests that loadSentTransfer returns a SentTransfer that matches the original -// object in memory. -func Test_loadSentTransfer(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, expectedST := newRandomSentTransfer(16, 24, kv, t) - _, err := expectedST.SetInProgress(5, 3, 4, 5) - if err != nil { - t.Errorf("Failed to add parts to in-progress transfer: %+v", err) - } - _, err = expectedST.SetInProgress(10, 10, 11, 12) - if err != nil { - t.Errorf("Failed to add parts to in-progress transfer: %+v", err) - } - - _, err = expectedST.FinishTransfer(10) - if err != nil { - t.Errorf("Failed to move parts to finished transfer: %+v", err) - } - - loadedST, err := loadSentTransfer(tid, kv) - if err != nil { - t.Errorf("loadSentTransfer returned an error: %+v", err) - } - - // Progress callbacks cannot be compared - loadedST.progressCallbacks = expectedST.progressCallbacks - - if !reflect.DeepEqual(expectedST, loadedST) { - t.Errorf("Loaded SentTransfer does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedST, loadedST) - } -} - -// Error path: tests that loadSentTransfer returns the expected error when no -// transfer with the given ID exists in storage. -func Test_loadSentTransfer_LoadInfoError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid := ftCrypto.UnmarshalTransferID([]byte("invalidTransferID")) - - expectedErr := strings.Split(loadSentStoreErr, "%")[0] - _, err := loadSentTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadSentTransfer did not return the expected error when no "+ - "transfer with the ID %s exists in storage."+ - "\nexpected: %s\nreceived: %+v", tid, expectedErr, err) - } -} - -// Error path: tests that loadSentTransfer returns the expected error when the -// fingerprint state vector was deleted from storage. -func Test_loadSentTransfer_LoadFingerprintStateVectorError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, st := newRandomSentTransfer(16, 24, kv, t) - - // Delete the fingerprint state vector from storage - err := st.fpVector.Delete() - if err != nil { - t.Errorf("Failed to delete the fingerprint vector: %+v", err) - } - - expectedErr := strings.Split(loadSentFpVectorErr, "%")[0] - _, err = loadSentTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadSentTransfer did not return the expected error when "+ - "the fingerprint vector was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadSentTransfer returns the expected error when the -// part store was deleted from storage. -func Test_loadSentTransfer_LoadPartStoreError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, st := newRandomSentTransfer(16, 24, kv, t) - - // Delete the part store from storage - err := st.sentParts.delete() - if err != nil { - t.Errorf("Failed to delete the part store: %+v", err) - } - - expectedErr := strings.Split(loadSentPartStoreErr, "%")[0] - _, err = loadSentTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadSentTransfer did not return the expected error when "+ - "the part store was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadSentTransfer returns the expected error when the -// in-progress transfers bundle was deleted from storage. -func Test_loadSentTransfer_LoadInProgressTransfersError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, st := newRandomSentTransfer(16, 24, kv, t) - - // Delete the in-progress transfers bundle from storage - err := st.inProgressTransfers.delete() - if err != nil { - t.Errorf("Failed to delete the in-progress transfers bundle: %+v", err) - } - - expectedErr := strings.Split(loadInProgressTransfersErr, "%")[0] - _, err = loadSentTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadSentTransfer did not return the expected error when "+ - "the in-progress transfers bundle was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadSentTransfer returns the expected error when the -// finished transfer bundle was deleted from storage. -func Test_loadSentTransfer_LoadFinishedTransfersError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, st := newRandomSentTransfer(16, 24, kv, t) - - // Delete the finished transfers bundle from storage - err := st.finishedTransfers.delete() - if err != nil { - t.Errorf("Failed to delete the finished transfers bundle: %+v", err) - } - - expectedErr := strings.Split(loadFinishedTransfersErr, "%")[0] - _, err = loadSentTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadSentTransfer did not return the expected error when "+ - "the finished transfers bundle was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Error path: tests that loadSentTransfer returns the expected error when the -// part statuses multi state vector was deleted from storage. -func Test_loadSentTransfer_LoadPartStatsMultiStateVectorError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - tid, st := newRandomSentTransfer(16, 24, kv, t) - - // Delete the in-progress state vector from storage - err := st.partStats.Delete() - if err != nil { - t.Errorf("Failed to delete the partStats vector: %+v", err) - } - - expectedErr := strings.Split(loadSentPartStatusVectorErr, "%")[0] - _, err = loadSentTransfer(tid, kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadSentTransfer did not return the expected error when "+ - "the partStats vector was deleted from storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that SentTransfer.saveInfo saves the expected data to storage. -func TestSentTransfer_saveInfo(t *testing.T) { - st := &SentTransfer{ - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - numParts: 16, - kv: versioned.NewKV(make(ekv.Memstore)), - } - - err := st.saveInfo() - if err != nil { - t.Errorf("saveInfo returned an error: %+v", err) - } - - vo, err := st.kv.Get(sentTransferKey, sentTransferVersion) - if err != nil { - t.Errorf("Failed to load SentTransfer from storage: %+v", err) - } - - if !bytes.Equal(st.marshal(), vo.Data) { - t.Errorf("Marshalled data loaded from storage does not match expected."+ - "\nexpected: %+v\nreceived: %+v", st.marshal(), vo.Data) - } -} - -// Tests that SentTransfer.loadInfo loads a saved SentTransfer from storage. -func TestSentTransfer_loadInfo(t *testing.T) { - st := &SentTransfer{ - recipient: id.NewIdFromString("recipient", id.User, t), - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - numParts: 16, - kv: versioned.NewKV(make(ekv.Memstore)), - } - - err := st.saveInfo() - if err != nil { - t.Errorf("failed to save new SentTransfer to storage: %+v", err) - } - - loadedST := &SentTransfer{kv: st.kv} - err = loadedST.loadInfo() - if err != nil { - t.Errorf("load returned an error: %+v", err) - } - - if !reflect.DeepEqual(st, loadedST) { - t.Errorf("Loaded SentTransfer does not match expected."+ - "\nexpected: %+v\nreceived: %+v", st, loadedST) - } -} - -// Error path: tests that SentTransfer.loadInfo returns an error when there is -// no object in storage to load -func TestSentTransfer_loadInfo_Error(t *testing.T) { - loadedST := &SentTransfer{kv: versioned.NewKV(make(ekv.Memstore))} - err := loadedST.loadInfo() - if err == nil { - t.Errorf("Loaded object that should not be in storage: %+v", err) - } -} - -// Tests that SentTransfer.delete removes all data from storage. -func TestSentTransfer_delete(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - _, st := newRandomSentTransfer(16, 24, kv, t) - - // Add in-progress transfers - _, err := st.SetInProgress(5, 3, 4, 5) - if err != nil { - t.Errorf("Failed to add parts to in-progress transfer: %+v", err) - } - _, err = st.SetInProgress(10, 10, 11, 12) - if err != nil { - t.Errorf("Failed to add parts to in-progress transfer: %+v", err) - } - - // Mark last in-progress transfer and finished - _, err = st.FinishTransfer(10) - if err != nil { - t.Errorf("Failed to move parts to finished transfer: %+v", err) - } - - // Delete everything from storage - err = st.delete() - if err != nil { - t.Errorf("delete returned an error: %+v", err) - } - - // Check that the SentTransfer info was deleted - err = st.loadInfo() - if err == nil { - t.Error("Successfully loaded SentTransfer info from storage when it " + - "should have been deleted.") - } - - // Check that the parts store were deleted - _, err = loadPartStore(st.kv) - if err == nil { - t.Error("Successfully loaded file parts from storage when it should " + - "have been deleted.") - } - - // Check that the in-progress transfers were deleted - _, err = loadTransferredBundle(inProgressKey, st.kv) - if err == nil { - t.Error("Successfully loaded in-progress transfers from storage when " + - "it should have been deleted.") - } - - // Check that the finished transfers were deleted - _, err = loadTransferredBundle(finishedKey, st.kv) - if err == nil { - t.Error("Successfully loaded finished transfers from storage when " + - "it should have been deleted.") - } - - // Check that the fingerprint vector was deleted - _, err = utility.LoadStateVector(st.kv, sentFpVectorKey) - if err == nil { - t.Error("Successfully loaded fingerprint vector from storage when it " + - "should have been deleted.") - } - - // Check that the in-progress status vector was deleted - _, err = utility.LoadStateVector(st.kv, sentInProgressVectorKey) - if err == nil { - t.Error("Successfully loaded in-progress vector from storage when it " + - "should have been deleted.") - } - - // Check that the finished status vector was deleted - _, err = utility.LoadStateVector(st.kv, sentFinishedVectorKey) - if err == nil { - t.Error("Successfully loaded finished vector from storage when it " + - "should have been deleted.") - } - -} - -// Tests that SentTransfer.deleteInfo removes the saved SentTransfer data from -// storage. -func TestSentTransfer_deleteInfo(t *testing.T) { - st := &SentTransfer{ - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - numParts: 16, - kv: versioned.NewKV(make(ekv.Memstore)), - } - - // Save from storage - err := st.saveInfo() - if err != nil { - t.Errorf("failed to save new SentTransfer to storage: %+v", err) - } - - // Delete from storage - err = st.deleteInfo() - if err != nil { - t.Errorf("deleteInfo returned an error: %+v", err) - } - - // Make sure deleted object cannot be loaded from storage - _, err = st.kv.Get(sentTransferKey, sentTransferVersion) - if err == nil { - t.Error("Loaded object that should be deleted from storage.") - } -} - -// Tests that a SentTransfer marshalled with SentTransfer.marshal and then -// unmarshalled with unmarshalSentTransfer matches the original. -func TestSentTransfer_marshal_unmarshalSentTransfer(t *testing.T) { - st := &SentTransfer{ - recipient: id.NewIdFromString("testRecipient", id.User, t), - key: ftCrypto.UnmarshalTransferKey([]byte("key")), - numParts: 16, - numFps: 20, - status: Stopped, - } - - marshaledData := st.marshal() - - recipient, key, numParts, numFps, status := - unmarshalSentTransfer(marshaledData) - - if !st.recipient.Cmp(recipient) { - t.Errorf("Failed to get recipient ID.\nexpected: %s\nreceived: %s", - st.recipient, recipient) - } - - if st.key != key { - t.Errorf("Failed to get expected key.\nexpected: %s\nreceived: %s", - st.key, key) - } - - if st.numParts != numParts { - t.Errorf("Failed to get expected number of parts."+ - "\nexpected: %d\nreceived: %d", st.numParts, numParts) - } - - if st.numFps != numFps { - t.Errorf("Failed to get expected number of fingerprints."+ - "\nexpected: %d\nreceived: %d", st.numFps, numFps) - } - - if st.status != status { - t.Errorf("Failed to get expected transfer status."+ - "\nexpected: %s\nreceived: %s", st.status, status) - } -} - -// Consistency test: tests that makeSentTransferPrefix returns the expected -// prefixes for the provided transfer IDs. -func Test_makeSentTransferPrefix_Consistency(t *testing.T) { - prng := NewPrng(42) - expectedPrefixes := []string{ - "FileTransferSentTransferStoreU4x/lrFkvxuXu59LtHLon1sUhPJSCcnZND6SugndnVI=", - "FileTransferSentTransferStore39ebTXZCm2F6DJ+fDTulWwzA1hRMiIU1hBrL4HCbB1g=", - "FileTransferSentTransferStoreCD9h03W8ArQd9PkZKeGP2p5vguVOdI6B555LvW/jTNw=", - "FileTransferSentTransferStoreuoQ+6NY+jE/+HOvqVG2PrBPdGqwEzi6ih3xVec+ix44=", - "FileTransferSentTransferStoreGwuvrogbgqdREIpC7TyQPKpDRlp4YgYWl4rtDOPGxPM=", - "FileTransferSentTransferStorernvD4ElbVxL+/b4MECiH4QDazS2IX2kstgfaAKEcHHA=", - "FileTransferSentTransferStoreceeWotwtwlpbdLLhKXBeJz8FySMmgo4rBW44F2WOEGE=", - "FileTransferSentTransferStoreSYlH/fNEQQ7UwRYCP6jjV2tv7Sf/iXS6wMr9mtBWkrE=", - "FileTransferSentTransferStoreNhnnOJZN/ceejVNDc2Yc/WbXT+weG4lJGrcjbkt1IWI=", - "FileTransferSentTransferStorekM8r60LDyicyhWDxqsBnzqbov0bUqytGgEAsX7KCDog=", - } - - for i, expected := range expectedPrefixes { - tid, _ := ftCrypto.NewTransferID(prng) - prefix := makeSentTransferPrefix(tid) - - if expected != prefix { - t.Errorf("New SentTransfer prefix does not match expected (%d)."+ - "\nexpected: %s\nreceived: %s", i, expected, prefix) - } - } -} - -// Tests that each of the elements in the uint32 slice returned by -// uint16SliceToUint32Slice matches the elements in the original uint16 slice. -func Test_uint16SliceToUint32Slice(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - uint16Slice := make([]uint16, 100) - - for i := range uint16Slice { - uint16Slice[i] = uint16(prng.Uint32()) - } - - uint32Slice := uint16SliceToUint32Slice(uint16Slice) - - // Check that each element is correct - for i, expected := range uint16Slice { - if uint32(expected) != uint32Slice[i] { - t.Errorf("Element #%d is incorrect.\nexpected: %d\nreceived: %d", - i, uint32(expected), uint32Slice[i]) - } - } -} - -// newRandomSentTransfer generates a new SentTransfer with random data. -func newRandomSentTransfer(numParts, numFps uint16, kv *versioned.KV, - t *testing.T) (ftCrypto.TransferID, *SentTransfer) { - // Generate new PRNG with the seed generated by multiplying the pointer for - // numParts with the current UNIX time in nanoseconds - seed, _ := strconv.ParseInt(fmt.Sprintf("%d", &numParts), 10, 64) - seed *= netTime.Now().UnixNano() - prng := NewPrng(seed) - - recipient, _ := id.NewRandomID(prng, id.User) - tid, _ := ftCrypto.NewTransferID(prng) - key, _ := ftCrypto.NewTransferKey(prng) - parts := make([][]byte, numParts) - for i := uint16(0); i < numParts; i++ { - parts[i] = make([]byte, 16) - _, err := prng.Read(parts[i]) - if err != nil { - t.Errorf("Failed to generate random part: %+v", err) - } - } - - st, err := NewSentTransfer(recipient, tid, key, parts, numFps, nil, 0, kv) - if err != nil { - t.Errorf("Failed to create new SentTansfer: %+v", err) - } - - return tid, st -} - -// checkSentProgress compares the output of SentTransfer.GetProgress to expected -// values. -func checkSentProgress(completed bool, sent, arrived, total uint16, - eCompleted bool, eSent, eArrived, eTotal uint16) error { - if eCompleted != completed || eSent != sent || eArrived != arrived || - eTotal != total { - return errors.Errorf("Returned progress does not match expected."+ - "\n completed sent arrived total"+ - "\nexpected: %5t %3d %3d %3d"+ - "\nreceived: %5t %3d %3d %3d", - eCompleted, eSent, eArrived, eTotal, - completed, sent, arrived, total) - } - - return nil -} - -// checkSentTracker checks that the sentPartTracker is reporting the correct -// values for each part. Also checks that sentPartTracker.GetNumParts returns -// the expected value (make sure numParts comes from a correct source). -func checkSentTracker(track interfaces.FilePartTracker, numParts uint16, - inProgress, finished []uint16, t *testing.T) { - if track.GetNumParts() != numParts { - t.Errorf("Tracker reported incorrect number of parts."+ - "\nexpected: %d\nreceived: %d", numParts, track.GetNumParts()) - return - } - - for partNum := uint16(0); partNum < numParts; partNum++ { - var done bool - for _, inProgressNum := range inProgress { - if inProgressNum == partNum { - status := track.GetPartStatus(partNum) - if status != interfaces.FpSent { - t.Errorf("Part number %d has unexpected status."+ - "\nexpected: %d\nreceived: %d", - partNum, interfaces.FpSent, status) - } - done = true - break - } - } - if done { - continue - } - - for _, finishedNum := range finished { - if finishedNum == partNum { - status := track.GetPartStatus(partNum) - if status != interfaces.FpArrived { - t.Errorf("Part number %d has unexpected status."+ - "\nexpected: %d\nreceived: %d", - partNum, interfaces.FpArrived, status) - } - done = true - break - } - } - if done { - continue - } - - status := track.GetPartStatus(partNum) - if status != interfaces.FpUnsent { - t.Errorf("Part number %d has incorrect status."+ - "\nexpected: %d\nreceived: %d", - partNum, interfaces.FpUnsent, status) - } - } -} diff --git a/storage/fileTransfer/transferredBundle.go b/storage/fileTransfer/transferredBundle.go deleted file mode 100644 index dff0c2a240daf93e488c5e7a1667e112b5660a91..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/transferredBundle.go +++ /dev/null @@ -1,192 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "encoding/binary" - "github.com/pkg/errors" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/xx_network/primitives/id" - "gitlab.com/xx_network/primitives/netTime" -) - -// Storage keys and versions. -const ( - transferredBundleVersion = 0 - transferredBundleKey = "FileTransferBundle" - inProgressKey = "inProgressTransfers" - finishedKey = "finishedTransfers" -) - -// Error messages. -const ( - loadTransferredBundleErr = "failed to get transferredBundle from storage: %+v" -) - -// transferredBundle lists the file parts sent per round ID. -type transferredBundle struct { - list map[id.Round][]uint16 - numParts uint16 - key string - kv *versioned.KV -} - -// newTransferredBundle generates a new transferredBundle and saves it to -// storage. -func newTransferredBundle(key string, kv *versioned.KV) ( - *transferredBundle, error) { - tb := &transferredBundle{ - list: make(map[id.Round][]uint16), - key: key, - kv: kv, - } - - return tb, tb.save() -} - -// addPartNums adds a round to the map with the specified part numbers. -func (tb *transferredBundle) addPartNums(rid id.Round, partNums ...uint16) error { - tb.list[rid] = append(tb.list[rid], partNums...) - - // Increment number of parts - tb.numParts += uint16(len(partNums)) - - return tb.save() -} - -// getPartNums returns the list of part numbers for the given round ID. If there -// are no part numbers for the round ID, then it returns false. -func (tb *transferredBundle) getPartNums(rid id.Round) ([]uint16, bool) { - partNums, exists := tb.list[rid] - return partNums, exists -} - -// getNumParts returns the number of file parts stored in list. -func (tb *transferredBundle) getNumParts() uint16 { - return tb.numParts -} - -// deletePartNums deletes the round and its part numbers from the map. -func (tb *transferredBundle) deletePartNums(rid id.Round) error { - // Decrement number of parts - tb.numParts -= uint16(len(tb.list[rid])) - - // Remove from list - delete(tb.list, rid) - - // Remove from storage - return tb.save() -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// loadTransferredBundle loads a transferredBundle from storage. -func loadTransferredBundle(key string, kv *versioned.KV) (*transferredBundle, - error) { - vo, err := kv.Get(makeTransferredBundleKey(key), transferredBundleVersion) - if err != nil { - return nil, errors.Errorf(loadTransferredBundleErr, err) - } - - tb := &transferredBundle{ - list: make(map[id.Round][]uint16), - key: key, - kv: kv, - } - - tb.unmarshal(vo.Data) - - return tb, nil -} - -// save stores the transferredBundle to storage. -func (tb *transferredBundle) save() error { - obj := &versioned.Object{ - Version: transferredBundleVersion, - Timestamp: netTime.Now(), - Data: tb.marshal(), - } - - return tb.kv.Set( - makeTransferredBundleKey(tb.key), transferredBundleVersion, obj) -} - -// delete remove the transferredBundle from storage. -func (tb *transferredBundle) delete() error { - return tb.kv.Delete( - makeTransferredBundleKey(tb.key), transferredBundleVersion) -} - -// marshal serialises the map into a byte slice. -func (tb *transferredBundle) marshal() []byte { - // Create buffer - buff := bytes.NewBuffer(nil) - - for rid, partNums := range tb.list { - // Write round ID to buffer - b := make([]byte, 8) - binary.LittleEndian.PutUint64(b, uint64(rid)) - buff.Write(b) - - // Write number of part numbers to buffer - b = make([]byte, 2) - binary.LittleEndian.PutUint16(b, uint16(len(partNums))) - buff.Write(b) - - // Write list of part numbers to buffer - for _, partNum := range partNums { - b = make([]byte, 2) - binary.LittleEndian.PutUint16(b, partNum) - buff.Write(b) - } - } - - return buff.Bytes() -} - -// unmarshal deserializes the byte slice into the transferredBundle. -func (tb *transferredBundle) unmarshal(b []byte) { - buff := bytes.NewBuffer(b) - - // Iterate over all map entries - for n := buff.Next(8); len(n) == 8; n = buff.Next(8) { - // get the round ID from the first 8 bytes - rid := id.Round(binary.LittleEndian.Uint64(n)) - - // get number of part numbers listed - partNumsLen := binary.LittleEndian.Uint16(buff.Next(2)) - - // Increment number of parts - tb.numParts += partNumsLen - - // Initialize part number list to the correct size - tb.list[rid] = make([]uint16, 0, partNumsLen) - - // Add all part numbers to list - for i := uint16(0); i < partNumsLen; i++ { - partNum := binary.LittleEndian.Uint16(buff.Next(2)) - tb.list[rid] = append(tb.list[rid], partNum) - } - } -} - -// makeTransferredBundleKey concatenates a unique key with a constant to create -// a key for saving a transferredBundle to storage. -func makeTransferredBundleKey(key string) string { - return transferredBundleKey + key -} diff --git a/storage/fileTransfer/transferredBundle_test.go b/storage/fileTransfer/transferredBundle_test.go deleted file mode 100644 index 36e618778072601a134a5daf0c3b9dc3446bbcd4..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/transferredBundle_test.go +++ /dev/null @@ -1,317 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "bytes" - "gitlab.com/elixxir/client/storage/versioned" - "gitlab.com/elixxir/ekv" - "gitlab.com/xx_network/primitives/id" - "math/rand" - "reflect" - "sort" - "strings" - "testing" -) - -// Tests that newTransferredBundle returns the expected transferredBundle and -// that it can be loaded from storage. -func Test_newTransferredBundle(t *testing.T) { - expectedTB := &transferredBundle{ - list: make(map[id.Round][]uint16), - key: "testKey1", - kv: versioned.NewKV(make(ekv.Memstore)), - } - - tb, err := newTransferredBundle(expectedTB.key, expectedTB.kv) - if err != nil { - t.Errorf("newTransferredBundle produced an error: %+v", err) - } - - if !reflect.DeepEqual(expectedTB, tb) { - t.Errorf("New transferredBundle does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedTB, tb) - } - - _, err = expectedTB.kv.Get( - makeTransferredBundleKey(expectedTB.key), transferredBundleVersion) - if err != nil { - t.Errorf("Failed to load transferredBundle from storage: %+v", err) - } -} - -// Tests that transferredBundle.addPartNums adds the part numbers for the round -// ID to the map correctly. -func Test_transferredBundle_addPartNums(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - key := "testKey" - tb, err := newTransferredBundle(key, kv) - if err != nil { - t.Errorf("Failed to create new transferredBundle: %+v", err) - } - - rid := id.Round(10) - expectedPartNums := []uint16{5, 128, 23, 1} - - err = tb.addPartNums(rid, expectedPartNums...) - if err != nil { - t.Errorf("addPartNums returned an error: %+v", err) - } - - partNums, exists := tb.list[rid] - if !exists || !reflect.DeepEqual(expectedPartNums, partNums) { - t.Errorf("Part numbers in memory does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedPartNums, partNums) - } -} - -// Tests that transferredBundle.getPartNums returns the expected part numbers -func Test_transferredBundle_getPartNums(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - key := "testKey" - tb, err := newTransferredBundle(key, kv) - if err != nil { - t.Errorf("Failed to create new transferredBundle: %+v", err) - } - - rid := id.Round(10) - expectedPartNums := []uint16{5, 128, 23, 1} - - err = tb.addPartNums(rid, expectedPartNums...) - if err != nil { - t.Errorf("failed to add part numbers: %+v", err) - } - - partNums, exists := tb.getPartNums(rid) - if !exists || !reflect.DeepEqual(expectedPartNums, partNums) { - t.Errorf("Part numbers in memory does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedPartNums, partNums) - } -} - -// Tests that transferredBundle.getNumParts returns the correct number of parts -// after parts are added and removed from the list. -func Test_transferredBundle_getNumParts(t *testing.T) { - prng := rand.New(rand.NewSource(42)) - tb, err := newTransferredBundle("testKey", versioned.NewKV(make(ekv.Memstore))) - if err != nil { - t.Errorf("Failed to create new transferredBundle: %+v", err) - } - - // Add 10 random lists of part numbers to the map - var expectedNumParts uint16 - for i := 0; i < 10; i++ { - partNums := make([]uint16, prng.Intn(16)) - for j := range partNums { - partNums[j] = uint16(prng.Uint32()) - } - - // Add number of parts for odd numbered rounds - if i%2 == 1 { - expectedNumParts += uint16(len(partNums)) - } - - err = tb.addPartNums(id.Round(i), partNums...) - if err != nil { - t.Errorf("Failed to add part #%d: %+v", i, err) - } - } - - // Delete num parts for even numbered rounds - for i := 0; i < 10; i += 2 { - err = tb.deletePartNums(id.Round(i)) - if err != nil { - t.Errorf("Failed to delete part #%d: %+v", i, err) - } - } - - // get number of parts - receivedNumParts := tb.getNumParts() - - if expectedNumParts != receivedNumParts { - t.Errorf("Failed to get expected number of parts."+ - "\nexpected: %d\nreceived: %d", expectedNumParts, receivedNumParts) - } -} - -// Tests that transferredBundle.deletePartNums deletes the part number from -// memory. -func Test_transferredBundle_deletePartNums(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - key := "testKey" - tb, err := newTransferredBundle(key, kv) - if err != nil { - t.Errorf("Failed to create new transferredBundle: %+v", err) - } - - rid := id.Round(10) - expectedPartNums := []uint16{5, 128, 23, 1} - - err = tb.addPartNums(rid, expectedPartNums...) - if err != nil { - t.Errorf("failed to add part numbers: %+v", err) - } - - err = tb.deletePartNums(rid) - if err != nil { - t.Errorf("deletePartNums returned an error: %+v", err) - } - - _, exists := tb.list[rid] - if exists { - t.Error("Found part numbers that should have been deleted.") - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Storage Functions // -//////////////////////////////////////////////////////////////////////////////// - -// Tests that loadTransferredBundle returns a transferredBundle from storage -// that matches the original in memory. -func Test_loadTransferredBundle(t *testing.T) { - expectedTB := &transferredBundle{ - list: map[id.Round][]uint16{ - 1: {1, 2, 3}, - 2: {4, 5, 6}, - 3: {7, 8, 9}, - }, - numParts: 9, - key: "testKey2", - kv: versioned.NewKV(make(ekv.Memstore)), - } - - err := expectedTB.save() - if err != nil { - t.Errorf("Failed to save transferredBundle to storage: %+v", err) - } - - tb, err := loadTransferredBundle(expectedTB.key, expectedTB.kv) - if err != nil { - t.Errorf("loadTransferredBundle returned an error: %+v", err) - } - - if !reflect.DeepEqual(expectedTB, tb) { - t.Errorf("transferredBundle loaded from storage does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedTB, tb) - } -} - -// Error path: tests that loadTransferredBundle returns the expected error when -// there is no transferredBundle in storage. -func Test_loadTransferredBundle_LoadError(t *testing.T) { - kv := versioned.NewKV(make(ekv.Memstore)) - expectedErr := strings.Split(loadTransferredBundleErr, "%")[0] - - _, err := loadTransferredBundle("", kv) - if err == nil || !strings.Contains(err.Error(), expectedErr) { - t.Errorf("loadTransferredBundle did not returned the expected error "+ - "when no transferredBundle exists in storage."+ - "\nexpected: %s\nreceived: %+v", expectedErr, err) - } -} - -// Tests that transferredBundle.save saves the correct data to storage. -func Test_transferredBundle_save(t *testing.T) { - tb := &transferredBundle{ - list: map[id.Round][]uint16{ - 1: {1, 2, 3}, - 2: {4, 5, 6}, - 3: {7, 8, 9}, - }, - key: "testKey3", - kv: versioned.NewKV(make(ekv.Memstore)), - } - expectedData := tb.marshal() - - err := tb.save() - if err != nil { - t.Errorf("save returned an error: %+v", err) - } - - vo, err := tb.kv.Get( - makeTransferredBundleKey(tb.key), transferredBundleVersion) - if err != nil { - t.Errorf("Failed to load transferredBundle from storage: %+v", err) - } - - sort.SliceStable(expectedData, func(i, j int) bool { - return expectedData[i] > expectedData[j] - }) - - sort.SliceStable(vo.Data, func(i, j int) bool { - return vo.Data[i] > vo.Data[j] - }) - - if !bytes.Equal(expectedData, vo.Data) { - t.Errorf("Loaded transferredBundle does not match expected."+ - "\nexpected: %+v\nreceived: %+v", expectedData, vo.Data) - } -} - -// Tests that transferredBundle.delete removes a saved transferredBundle from -// storage. -func Test_transferredBundle_delete(t *testing.T) { - tb := &transferredBundle{ - list: map[id.Round][]uint16{ - 1: {1, 2, 3}, - 2: {4, 5, 6}, - 3: {7, 8, 9}, - }, - key: "testKey4", - kv: versioned.NewKV(make(ekv.Memstore)), - } - - err := tb.save() - if err != nil { - t.Errorf("Failed to save transferredBundle to storage: %+v", err) - } - - err = tb.delete() - if err != nil { - t.Errorf("delete returned an error: %+v", err) - } - - _, err = tb.kv.Get(makeTransferredBundleKey(tb.key), transferredBundleVersion) - if err == nil { - t.Error("Read transferredBundleVersion from storage when it should" + - "have been deleted.") - } -} - -// Tests that a transferredBundle that is marshalled via -// transferredBundle.marshal and unmarshalled via transferredBundle.unmarshal -// matches the original. -func Test_transferredBundle_marshal_unmarshal(t *testing.T) { - expectedTB := &transferredBundle{ - list: map[id.Round][]uint16{ - 1: {1, 2, 3}, - 2: {4, 5, 6}, - 3: {7, 8, 9}, - }, - numParts: 9, - } - - b := expectedTB.marshal() - - tb := &transferredBundle{list: make(map[id.Round][]uint16)} - - tb.unmarshal(b) - - if !reflect.DeepEqual(expectedTB, tb) { - t.Errorf("Failed to marshal and unmarshal transferredBundle into "+ - "original.\nexpected: %+v\nreceived: %+v", expectedTB, tb) - } -} diff --git a/storage/fileTransfer/utils_test.go b/storage/fileTransfer/utils_test.go deleted file mode 100644 index cd59306b455bc2533b1f264ab73687a26db69b76..0000000000000000000000000000000000000000 --- a/storage/fileTransfer/utils_test.go +++ /dev/null @@ -1,34 +0,0 @@ -//////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -//////////////////////////////////////////////////////////////////////////////// - -package fileTransfer - -import ( - "github.com/pkg/errors" - "gitlab.com/xx_network/crypto/csprng" - "io" - "math/rand" -) - -//////////////////////////////////////////////////////////////////////////////// -// PRNG // -//////////////////////////////////////////////////////////////////////////////// - -// Prng is a PRNG that satisfies the csprng.Source interface. -type Prng struct{ prng io.Reader } - -func NewPrng(seed int64) csprng.Source { return &Prng{rand.New(rand.NewSource(seed))} } -func (s *Prng) Read(b []byte) (int, error) { return s.prng.Read(b) } -func (s *Prng) SetSeed([]byte) error { return nil } - -// PrngErr is a PRNG that satisfies the csprng.Source interface. However, it -// always returns an error -type PrngErr struct{} - -func NewPrngErr() csprng.Source { return &PrngErr{} } -func (s *PrngErr) Read([]byte) (int, error) { return 0, errors.New("ReadFailure") } -func (s *PrngErr) SetSeed([]byte) error { return errors.New("SetSeedFailure") } diff --git a/storage/ndf_test.go b/storage/ndf_test.go index ece36461e5368cd110b0004d95fc09bf3f28837f..4d48ab2d28d7330511f8474a9a335ef023505eb6 100644 --- a/storage/ndf_test.go +++ b/storage/ndf_test.go @@ -17,10 +17,10 @@ func TestSession_SetGetNDF(t *testing.T) { testNdf := getNDF() sess.SetNDF(testNdf) - if !reflect.DeepEqual(testNdf, sess.ndf) { + if !reflect.DeepEqual(testNdf, sess.GetNDF()) { t.Errorf("SetNDF error: "+ "Unexpected value after setting ndf:"+ - "Expected: %v\n\tReceived: %v", testNdf, sess.ndf) + "Expected: %v\n\tReceived: %v", testNdf, sess.GetNDF()) } receivedNdf := sess.GetNDF() diff --git a/storage/session.go b/storage/session.go index 2f2c543e8f3096c3875b4f553fc8df964213a003..e9bee8ae363597a4e43e6b4b4fb2a30cfc938ecf 100644 --- a/storage/session.go +++ b/storage/session.go @@ -10,15 +10,15 @@ package storage import ( - "gitlab.com/elixxir/client/storage/utility" - "gitlab.com/xx_network/crypto/large" "sync" "testing" "time" + "gitlab.com/elixxir/client/storage/utility" + "gitlab.com/xx_network/crypto/large" + "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" - userInterface "gitlab.com/elixxir/client/interfaces/user" "gitlab.com/elixxir/client/storage/clientVersion" "gitlab.com/elixxir/client/storage/user" "gitlab.com/elixxir/client/storage/versioned" @@ -60,7 +60,7 @@ type Session interface { IsPrecanned() bool SetUsername(username string) error GetUsername() (string, error) - PortableUserInfo() userInterface.Info + PortableUserInfo() user.Info GetTransmissionRegistrationValidationSignature() []byte GetReceptionRegistrationValidationSignature() []byte GetRegistrationTimestamp() time.Time @@ -103,7 +103,7 @@ func initStore(baseDir, password string) (*session, error) { } // Creates new UserData in the session -func New(baseDir, password string, u userInterface.Info, +func New(baseDir, password string, u user.Info, currentVersion version.Version, cmixGrp, e2eGrp *cyclic.Group) (Session, error) { s, err := initStore(baseDir, password) diff --git a/storage/user/Info.go b/storage/user/Info.go deleted file mode 100644 index 8091bffe0c38791c93a944621639a1679b3d8895..0000000000000000000000000000000000000000 --- a/storage/user/Info.go +++ /dev/null @@ -1,28 +0,0 @@ -package user - -import "gitlab.com/elixxir/client/interfaces/user" - -func (u *User) PortableUserInfo() user.Info { - ci := u.CryptographicIdentity - return user.Info{ - TransmissionID: ci.GetTransmissionID().DeepCopy(), - TransmissionSalt: copySlice(ci.GetTransmissionSalt()), - TransmissionRSA: ci.GetTransmissionRSA(), - ReceptionID: ci.GetReceptionID().DeepCopy(), - RegistrationTimestamp: u.GetRegistrationTimestamp().UnixNano(), - ReceptionSalt: copySlice(ci.GetReceptionSalt()), - ReceptionRSA: ci.GetReceptionRSA(), - Precanned: ci.IsPrecanned(), - //fixme: set these in the e2e layer, the command line layer - //needs more logical seperation so this can be removed - E2eDhPrivateKey: nil, - E2eDhPublicKey: nil, - } - -} - -func copySlice(s []byte) []byte { - n := make([]byte, len(s)) - copy(n, s) - return n -} diff --git a/interfaces/user/user.go b/storage/user/info.go similarity index 66% rename from interfaces/user/user.go rename to storage/user/info.go index 0e3c6ca0331756882b38ce9706c1ffdbbcf65e96..11d9381ef07cbad22a91dc49db8cff48d551fed8 100644 --- a/interfaces/user/user.go +++ b/storage/user/info.go @@ -16,6 +16,28 @@ import ( "gitlab.com/xx_network/primitives/id" ) +type Proto struct { + //General Identity + TransmissionID *id.ID + TransmissionSalt []byte + TransmissionRSA *rsa.PrivateKey + ReceptionID *id.ID + ReceptionSalt []byte + ReceptionRSA *rsa.PrivateKey + Precanned bool + // Timestamp in which user has registered with the network + RegistrationTimestamp int64 + + RegCode string + + TransmissionRegValidationSig []byte + ReceptionRegValidationSig []byte + + //e2e Identity + E2eDhPrivateKey *cyclic.Int + E2eDhPublicKey *cyclic.Int +} + type Info struct { //General Identity TransmissionID *id.ID @@ -70,3 +92,28 @@ func NewUserFromBackup(backup *backup.Backup) Info { E2eDhPublicKey: backup.ReceptionIdentity.DHPublicKey, } } + +func (u *User) PortableUserInfo() Info { + ci := u.CryptographicIdentity + return Info{ + TransmissionID: ci.GetTransmissionID().DeepCopy(), + TransmissionSalt: copySlice(ci.GetTransmissionSalt()), + TransmissionRSA: ci.GetTransmissionRSA(), + ReceptionID: ci.GetReceptionID().DeepCopy(), + RegistrationTimestamp: u.GetRegistrationTimestamp().UnixNano(), + ReceptionSalt: copySlice(ci.GetReceptionSalt()), + ReceptionRSA: ci.GetReceptionRSA(), + Precanned: ci.IsPrecanned(), + //fixme: set these in the e2e layer, the command line layer + //needs more logical seperation so this can be removed + E2eDhPrivateKey: nil, + E2eDhPublicKey: nil, + } + +} + +func copySlice(s []byte) []byte { + n := make([]byte, len(s)) + copy(n, s) + return n +} diff --git a/storage/utility/stateVector.go b/storage/utility/stateVector.go index 4e4f9d3cfec0c7cc4688b461ded54f2c4d21e5da..2ed89e2faef5fe3ebcc591e0ca0c933bcd71faf3 100644 --- a/storage/utility/stateVector.go +++ b/storage/utility/stateVector.go @@ -214,7 +214,7 @@ func (sv *StateVector) Next() (uint32, error) { // Save to storage if err := sv.save(); err != nil { - return 0, errors.Errorf(saveNextErr, sv, err) + jww.FATAL.Panicf(saveNextErr, sv, err) } return nextKey, nil @@ -296,6 +296,9 @@ func (sv *StateVector) GetUsedKeyNums() []uint32 { // DeepCopy creates a deep copy of the StateVector without a storage backend. // The deep copy can only be used for functions that do not access storage. func (sv *StateVector) DeepCopy() *StateVector { + sv.mux.RLock() + defer sv.mux.RUnlock() + newSV := &StateVector{ vect: make([]uint64, len(sv.vect)), firstAvailable: sv.firstAvailable, diff --git a/ud/utils_test.go b/ud/utils_test.go index 425d22d9ae31775bc13b2efe6391037ba3a63f93..7d3b3626c3b479e0a14d87746ac707f970a7327e 100644 --- a/ud/utils_test.go +++ b/ud/utils_test.go @@ -15,6 +15,9 @@ package ud import ( + "testing" + "time" + "gitlab.com/elixxir/client/cmix/gateway" "gitlab.com/elixxir/client/event" "gitlab.com/elixxir/client/interfaces" @@ -28,8 +31,6 @@ import ( "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/ndf" - "testing" - "time" ) func newTestNetworkManager(t *testing.T) interfaces.NetworkManager { @@ -76,7 +77,7 @@ func (tnm *testNetworkManager) SendManyCMIX([]message.TargetedCmixMessage, param type dummyEventMgr struct{} func (d *dummyEventMgr) Report(int, string, string, string) {} -func (tnm *testNetworkManager) GetEventManager() event.Manager { +func (tnm *testNetworkManager) GetEventManager() event.Reporter { return &dummyEventMgr{} }